warp-lang 1.5.0__py3-none-manylinux2014_x86_64.whl → 1.6.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; the original page linked to more details here.

Files changed (132)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/native/bvh.h CHANGED
@@ -11,6 +11,14 @@
11
11
  #include "builtin.h"
12
12
  #include "intersect.h"
13
13
 
14
+ #define BVH_LEAF_SIZE (4)
15
+ #define SAH_NUM_BUCKETS (16)
16
+ #define USE_LOAD4
17
+
18
+ #define BVH_CONSTRUCTOR_SAH (0)
19
+ #define BVH_CONSTRUCTOR_MEDIAN (1)
20
+ #define BVH_CONSTRUCTOR_LBVH (2)
21
+
14
22
  namespace wp
15
23
  {
16
24
 
@@ -72,12 +80,38 @@ struct bounds3
72
80
  }
73
81
  }
74
82
 
83
+ CUDA_CALLABLE inline bool overlaps(const vec3& b_lower, const vec3& b_upper) const
84
+ {
85
+ if (lower[0] > b_upper[0] ||
86
+ lower[1] > b_upper[1] ||
87
+ lower[2] > b_upper[2] ||
88
+ upper[0] < b_lower[0] ||
89
+ upper[1] < b_lower[1] ||
90
+ upper[2] < b_lower[2])
91
+ {
92
+ return false;
93
+ }
94
+ else
95
+ {
96
+ return true;
97
+ }
98
+ }
99
+
75
100
  CUDA_CALLABLE inline void add_point(const vec3& p)
76
101
  {
77
102
  lower = min(lower, p);
78
103
  upper = max(upper, p);
79
104
  }
80
105
 
106
+ CUDA_CALLABLE inline void add_bounds(const vec3& lower_other, const vec3& upper_other)
107
+ {
108
+ // lower_other will only impact the lower of the new bounds
109
+ // upper_other will only impact the upper of the new bounds
110
+ // this costs only half of the computation of adding lower_other and upper_other separately
111
+ lower = min(lower, lower_other);
112
+ upper = max(upper, upper_other);
113
+ }
114
+
81
115
  CUDA_CALLABLE inline float area() const
82
116
  {
83
117
  vec3 e = upper-lower;
@@ -108,6 +142,13 @@ struct BVHPackedNodeHalf
108
142
  float x;
109
143
  float y;
110
144
  float z;
145
+ // For non-leaf nodes:
146
+ // - 'lower.i' represents the index of the left child node.
147
+ // - 'upper.i' represents the index of the right child node.
148
+ //
149
+ // For leaf nodes:
150
+ // - 'lower.i' indicates the start index of the primitives in 'primitive_indices'.
151
+ // - 'upper.i' indicates the index just after the last primitive in 'primitive_indices'
111
152
  unsigned int i : 31;
112
153
  unsigned int b : 1;
113
154
  };
@@ -120,11 +161,15 @@ struct BVH
120
161
  // used for fast refits
121
162
  int* node_parents;
122
163
  int* node_counts;
164
+ // reordered primitive indices corresponds to the ordering of leaf nodes
165
+ int* primitive_indices;
123
166
 
124
167
  int max_depth;
125
168
  int max_nodes;
126
169
  int num_nodes;
127
-
170
+ // since we use packed leaf nodes, the number of them is no longer the number of items, but variable
171
+ int num_leaf_nodes;
172
+
128
173
  // pointer (CPU or GPU) to a single integer index in node_lowers, node_uppers
129
174
  // representing the root of the tree, this is not always the first node
130
175
  // for bottom-up builders
@@ -161,6 +206,24 @@ CUDA_CALLABLE inline void make_node(volatile BVHPackedNodeHalf* n, const vec3& b
161
206
  n->b = (unsigned int)(leaf?1:0);
162
207
  }
163
208
 
209
+ #ifdef __CUDA_ARCH__
210
+ __device__ inline wp::BVHPackedNodeHalf bvh_load_node(const wp::BVHPackedNodeHalf* nodes, int index)
211
+ {
212
+ #ifdef USE_LOAD4
213
+ //return (const wp::BVHPackedNodeHalf&)(__ldg((const float4*)(nodes)+index));
214
+ return (const wp::BVHPackedNodeHalf&)(*((const float4*)(nodes)+index));
215
+ #else
216
+ return nodes[index];
217
+ #endif // USE_LOAD4
218
+
219
+ }
220
+ #else
221
+ inline wp::BVHPackedNodeHalf bvh_load_node(const wp::BVHPackedNodeHalf* nodes, int index)
222
+ {
223
+ return nodes[index];
224
+ }
225
+ #endif // __CUDACC__
226
+
164
227
  CUDA_CALLABLE inline int clz(int x)
165
228
  {
166
229
  int n;
@@ -215,7 +278,8 @@ struct bvh_query_t
215
278
  is_ray(false),
216
279
  input_lower(),
217
280
  input_upper(),
218
- bounds_nr(0)
281
+ bounds_nr(0),
282
+ primitive_counter(-1)
219
283
  {}
220
284
 
221
285
  // Required for adjoint computations.
@@ -230,22 +294,37 @@ struct bvh_query_t
230
294
  int stack[32];
231
295
  int count;
232
296
 
297
+ // >= 0 if currently in a packed leaf node
298
+ int primitive_counter;
299
+
233
300
  // inputs
234
- bool is_ray;
235
301
  wp::vec3 input_lower; // start for ray
236
302
  wp::vec3 input_upper; // dir for ray
237
303
 
238
304
  int bounds_nr;
305
+ bool is_ray;
239
306
  };
240
307
 
308
+ CUDA_CALLABLE inline bool bvh_query_intersection_test(const bvh_query_t& query, const vec3& node_lower, const vec3& node_upper)
309
+ {
310
+ if (query.is_ray)
311
+ {
312
+ float t = 0.0f;
313
+ return intersect_ray_aabb(query.input_lower, query.input_upper, node_lower, node_upper, t);
314
+ }
315
+ else
316
+ {
317
+ return intersect_aabb_aabb(query.input_lower, query.input_upper, node_lower, node_upper);
318
+ }
319
+ }
241
320
 
242
321
  CUDA_CALLABLE inline bvh_query_t bvh_query(
243
- uint64_t id, bool is_ray, const vec3& lower, const vec3& upper)
322
+ uint64_t id, bool is_ray, const vec3& lower, const vec3& upper)
244
323
  {
245
- // This routine traverses the BVH tree until it finds
324
+ // This routine traverses the BVH tree until it finds
246
325
  // the first overlapping bound.
247
326
 
248
- // initialize empty
327
+ // initialize empty
249
328
  bvh_query_t query;
250
329
 
251
330
  query.bounds_nr = -1;
@@ -255,57 +334,41 @@ CUDA_CALLABLE inline bvh_query_t bvh_query(
255
334
  query.bvh = bvh;
256
335
  query.is_ray = is_ray;
257
336
 
258
- // optimization: make the latest
337
+ // optimization: make the latest
259
338
  query.stack[0] = *bvh.root;
260
339
  query.count = 1;
261
- query.input_lower = lower;
262
- query.input_upper = upper;
340
+ query.input_lower = lower;
341
+ query.input_upper = upper;
263
342
 
264
- wp::bounds3 input_bounds(query.input_lower, query.input_upper);
265
-
266
- // Navigate through the bvh, find the first overlapping leaf node.
267
- while (query.count)
268
- {
343
+ // Navigate through the bvh, find the first overlapping leaf node.
344
+ while (query.count)
345
+ {
269
346
  const int node_index = query.stack[--query.count];
347
+ BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
348
+ BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
270
349
 
271
- BVHPackedNodeHalf node_lower = bvh.node_lowers[node_index];
272
- BVHPackedNodeHalf node_upper = bvh.node_uppers[node_index];
273
-
274
- wp::vec3 lower_pos(node_lower.x, node_lower.y, node_lower.z);
275
- wp::vec3 upper_pos(node_upper.x, node_upper.y, node_upper.z);
276
- wp::bounds3 current_bounds(lower_pos, upper_pos);
277
-
278
- if (query.is_ray)
279
- {
280
- float t = 0.0f;
281
- if (!intersect_ray_aabb(query.input_lower, query.input_upper, current_bounds.lower, current_bounds.upper, t))
282
- // Skip this box, it doesn't overlap with our ray.
283
- continue;
284
- }
285
- else
350
+ if (!bvh_query_intersection_test(query, (vec3&)node_lower, (vec3&)node_upper))
286
351
  {
287
- if (!input_bounds.overlaps(current_bounds))
288
- // Skip this box, it doesn't overlap with our target box.
289
- continue;
352
+ continue;
290
353
  }
291
354
 
292
355
  const int left_index = node_lower.i;
293
356
  const int right_index = node_upper.i;
294
-
295
- // Make bounds from this AABB
296
- if (node_lower.b)
297
- {
298
- // found very first leaf index.
357
+ // Make bounds from this AABB
358
+ if (node_lower.b)
359
+ {
360
+ // Reached a leaf node, point to its first primitive
299
361
  // Back up one level and return
362
+ query.primitive_counter = left_index;
300
363
  query.stack[query.count++] = node_index;
301
364
  return query;
302
- }
303
- else
304
- {
305
- query.stack[query.count++] = left_index;
306
- query.stack[query.count++] = right_index;
307
365
  }
308
- }
366
+ else
367
+ {
368
+ query.stack[query.count++] = left_index;
369
+ query.stack[query.count++] = right_index;
370
+ }
371
+ }
309
372
 
310
373
  return query;
311
374
  }
@@ -338,52 +401,100 @@ CUDA_CALLABLE inline void adj_bvh_query_ray(uint64_t id, const vec3& start, cons
338
401
 
339
402
  CUDA_CALLABLE inline bool bvh_query_next(bvh_query_t& query, int& index)
340
403
  {
341
- BVH bvh = query.bvh;
342
-
343
- wp::bounds3 input_bounds(query.input_lower, query.input_upper);
344
-
345
- // Navigate through the bvh, find the first overlapping leaf node.
346
- while (query.count)
347
- {
348
- const int node_index = query.stack[--query.count];
349
- BVHPackedNodeHalf node_lower = bvh.node_lowers[node_index];
350
- BVHPackedNodeHalf node_upper = bvh.node_uppers[node_index];
404
+ BVH bvh = query.bvh;
351
405
 
352
- wp::vec3 lower_pos(node_lower.x, node_lower.y, node_lower.z);
353
- wp::vec3 upper_pos(node_upper.x, node_upper.y, node_upper.z);
354
- wp::bounds3 current_bounds(lower_pos, upper_pos);
406
+ if (query.primitive_counter != -1)
407
+ // currently in a leaf node which is the last node in the stack
408
+ {
409
+ const int node_index = query.stack[query.count - 1];
410
+ BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
411
+ BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
355
412
 
356
- if (query.is_ray)
413
+ const int end = node_upper.i;
414
+ for (int primitive_counter = query.primitive_counter; primitive_counter < end; primitive_counter++)
357
415
  {
358
- float t = 0.0f;
359
- if (!intersect_ray_aabb(query.input_lower, query.input_upper, current_bounds.lower, current_bounds.upper, t))
360
- // Skip this box, it doesn't overlap with our ray.
361
- continue;
416
+ int primitive_index = bvh.primitive_indices[primitive_counter];
417
+ if (bvh_query_intersection_test(query, bvh.item_lowers[primitive_index], bvh.item_uppers[primitive_index]))
418
+ {
419
+ if (primitive_counter < end - 1)
420
+ // still need to come back to this leaf node for the leftover primitives
421
+ {
422
+ query.primitive_counter = primitive_counter + 1;
423
+ }
424
+ else
425
+ // no need to come back to this leaf node
426
+ {
427
+ query.count--;
428
+ query.primitive_counter = -1;
429
+ }
430
+ index = primitive_index;
431
+ query.bounds_nr = primitive_index;
432
+
433
+ return true;
434
+ }
362
435
  }
363
- else {
364
- if (!input_bounds.overlaps(current_bounds))
365
- // Skip this box, it doesn't overlap with our target box.
366
- continue;
436
+ // if we reach here that means we have finished the current leaf node without finding intersections
437
+ query.primitive_counter = -1;
438
+ // remove the leaf node from the back of the stack because it is finished
439
+ // and continue the bvh traversal
440
+ query.count--;
441
+ }
442
+
443
+ // Navigate through the bvh, find the first overlapping leaf node.
444
+ while (query.count)
445
+ {
446
+ const int node_index = query.stack[--query.count];
447
+ BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
448
+ BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
449
+
450
+ const int left_index = node_lower.i;
451
+ const int right_index = node_upper.i;
452
+
453
+ wp::vec3 lower_pos(node_lower.x, node_lower.y, node_lower.z);
454
+ wp::vec3 upper_pos(node_upper.x, node_upper.y, node_upper.z);
455
+ wp::bounds3 current_bounds(lower_pos, upper_pos);
456
+
457
+ if (!bvh_query_intersection_test(query, (vec3&)node_lower, (vec3&)node_upper))
458
+ {
459
+ continue;
367
460
  }
368
461
 
369
- const int left_index = node_lower.i;
370
- const int right_index = node_upper.i;
371
-
372
- if (node_lower.b)
373
- {
374
- // found leaf
375
- query.bounds_nr = left_index;
376
- index = left_index;
377
- return true;
378
- }
379
- else
380
- {
381
-
382
- query.stack[query.count++] = left_index;
383
- query.stack[query.count++] = right_index;
384
- }
385
- }
386
- return false;
462
+ if (node_lower.b)
463
+ {
464
+ // found leaf, loop through its content primitives
465
+ const int start = left_index;
466
+ const int end = right_index;
467
+
468
+ for (int primitive_counter = start; primitive_counter < end; primitive_counter++)
469
+ {
470
+ int primitive_index = bvh.primitive_indices[primitive_counter];
471
+ if (bvh_query_intersection_test(query, bvh.item_lowers[primitive_index], bvh.item_uppers[primitive_index]))
472
+ {
473
+ if (primitive_counter < end - 1)
474
+ // still need to come back to this leaf node for the leftover primitives
475
+ {
476
+ query.primitive_counter = primitive_counter + 1;
477
+ query.stack[query.count++] = node_index;
478
+ }
479
+ else
480
+ // no need to come back to this leaf node
481
+ {
482
+ query.primitive_counter = -1;
483
+ }
484
+ index = primitive_index;
485
+ query.bounds_nr = primitive_index;
486
+
487
+ return true;
488
+ }
489
+ }
490
+ }
491
+ else
492
+ {
493
+ query.stack[query.count++] = left_index;
494
+ query.stack[query.count++] = right_index;
495
+ }
496
+ }
497
+ return false;
387
498
  }
388
499
 
389
500
 
@@ -421,7 +532,7 @@ CUDA_CALLABLE void bvh_rem_descriptor(uint64_t id);
421
532
 
422
533
  #if !__CUDA_ARCH__
423
534
 
424
- void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, BVH& bvh);
535
+ void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, int constructor_type, BVH& bvh);
425
536
  void bvh_destroy_host(wp::BVH& bvh);
426
537
  void bvh_refit_host(wp::BVH& bvh);
427
538
 
@@ -431,4 +542,3 @@ void bvh_refit_device(uint64_t id);
431
542
  #endif
432
543
 
433
544
  } // namespace wp
434
-
warp/native/clang/clang.cpp CHANGED
@@ -218,7 +218,7 @@ static std::unique_ptr<llvm::Module> cuda_to_llvm(const std::string& input_file,
218
218
 
219
219
  extern "C" {
220
220
 
221
- WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug, bool verify_fp)
221
+ WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug, bool verify_fp, bool fuse_fp)
222
222
  {
223
223
  initialize_llvm();
224
224
 
@@ -236,6 +236,10 @@ WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char*
236
236
  const char* CPU = "generic";
237
237
  const char* features = "";
238
238
  llvm::TargetOptions target_options;
239
+ if (fuse_fp)
240
+ target_options.AllowFPOpFusion = llvm::FPOpFusion::Standard;
241
+ else
242
+ target_options.AllowFPOpFusion = llvm::FPOpFusion::Strict;
239
243
  llvm::Reloc::Model relocation_model = llvm::Reloc::PIC_; // Position Independent Code
240
244
  llvm::CodeModel::Model code_model = llvm::CodeModel::Large; // Don't make assumptions about displacement sizes
241
245
  llvm::TargetMachine* target_machine = target->createTargetMachine(target_triple, CPU, features, target_options, relocation_model, code_model);
warp/native/coloring.cpp CHANGED
@@ -590,7 +590,11 @@ extern "C"
590
590
  if (num_colors > 1) {
591
591
  std::vector<std::vector<int>> color_groups;
592
592
  convert_to_color_groups(num_colors, graph.node_colors, color_groups);
593
- return balance_color_groups(target_max_min_ratio, graph, color_groups);
593
+
594
+ float max_min_ratio = balance_color_groups(target_max_min_ratio, graph, color_groups);
595
+ memcpy(node_colors.data, graph.node_colors.data(), num_nodes * sizeof(int));
596
+
597
+ return max_min_ratio;
594
598
  }
595
599
  else
596
600
  {
warp/native/cuda_util.cpp CHANGED
@@ -102,6 +102,11 @@ static PFN_cuGraphicsGLRegisterBuffer_v3000 pfn_cuGraphicsGLRegisterBuffer;
102
102
  static PFN_cuGraphicsUnregisterResource_v3000 pfn_cuGraphicsUnregisterResource;
103
103
  static PFN_cuModuleGetGlobal_v3020 pfn_cuModuleGetGlobal;
104
104
  static PFN_cuFuncSetAttribute_v9000 pfn_cuFuncSetAttribute;
105
+ static PFN_cuIpcGetEventHandle_v4010 pfn_cuIpcGetEventHandle;
106
+ static PFN_cuIpcOpenEventHandle_v4010 pfn_cuIpcOpenEventHandle;
107
+ static PFN_cuIpcGetMemHandle_v4010 pfn_cuIpcGetMemHandle;
108
+ static PFN_cuIpcOpenMemHandle_v11000 pfn_cuIpcOpenMemHandle;
109
+ static PFN_cuIpcCloseMemHandle_v4010 pfn_cuIpcCloseMemHandle;
105
110
 
106
111
  static bool cuda_driver_initialized = false;
107
112
 
@@ -120,15 +125,17 @@ static inline int get_minor(int version)
120
125
  return (version % 1000) / 10;
121
126
  }
122
127
 
123
- static bool get_driver_entry_point(const char* name, void** pfn)
128
+ // Get versioned driver entry point. The version argument should match the function pointer type.
129
+ // For example, to initialize PFN_cuCtxCreate_v3020 use version 3020.
130
+ static bool get_driver_entry_point(const char* name, int version, void** pfn)
124
131
  {
125
132
  if (!pfn_cuGetProcAddress || !name || !pfn)
126
133
  return false;
127
134
 
128
135
  #if CUDA_VERSION < 12000
129
- CUresult r = pfn_cuGetProcAddress(name, pfn, WP_CUDA_DRIVER_VERSION, CU_GET_PROC_ADDRESS_DEFAULT);
136
+ CUresult r = pfn_cuGetProcAddress(name, pfn, version, CU_GET_PROC_ADDRESS_DEFAULT);
130
137
  #else
131
- CUresult r = pfn_cuGetProcAddress(name, pfn, WP_CUDA_DRIVER_VERSION, CU_GET_PROC_ADDRESS_DEFAULT, NULL);
138
+ CUresult r = pfn_cuGetProcAddress(name, pfn, version, CU_GET_PROC_ADDRESS_DEFAULT, NULL);
132
139
  #endif
133
140
 
134
141
  if (r != CUDA_SUCCESS)
@@ -170,7 +177,8 @@ bool init_cuda_driver()
170
177
 
171
178
  // check the CUDA driver version and report an error if it's too low
172
179
  int driver_version = 0;
173
- if (get_driver_entry_point("cuDriverGetVersion", &(void*&)pfn_cuDriverGetVersion) && check_cu(pfn_cuDriverGetVersion(&driver_version)))
180
+ if (get_driver_entry_point("cuDriverGetVersion", 2020, &(void*&)pfn_cuDriverGetVersion) &&
181
+ check_cu(pfn_cuDriverGetVersion(&driver_version)))
174
182
  {
175
183
  if (driver_version < WP_CUDA_DRIVER_VERSION)
176
184
  {
@@ -186,55 +194,60 @@ bool init_cuda_driver()
186
194
  }
187
195
 
188
196
  // initialize driver entry points
189
- get_driver_entry_point("cuGetErrorString", &(void*&)pfn_cuGetErrorString);
190
- get_driver_entry_point("cuGetErrorName", &(void*&)pfn_cuGetErrorName);
191
- get_driver_entry_point("cuInit", &(void*&)pfn_cuInit);
192
- get_driver_entry_point("cuDeviceGet", &(void*&)pfn_cuDeviceGet);
193
- get_driver_entry_point("cuDeviceGetCount", &(void*&)pfn_cuDeviceGetCount);
194
- get_driver_entry_point("cuDeviceGetName", &(void*&)pfn_cuDeviceGetName);
195
- get_driver_entry_point("cuDeviceGetAttribute", &(void*&)pfn_cuDeviceGetAttribute);
196
- get_driver_entry_point("cuDeviceGetUuid", &(void*&)pfn_cuDeviceGetUuid);
197
- get_driver_entry_point("cuDevicePrimaryCtxRetain", &(void*&)pfn_cuDevicePrimaryCtxRetain);
198
- get_driver_entry_point("cuDevicePrimaryCtxRelease", &(void*&)pfn_cuDevicePrimaryCtxRelease);
199
- get_driver_entry_point("cuDeviceCanAccessPeer", &(void*&)pfn_cuDeviceCanAccessPeer);
200
- get_driver_entry_point("cuMemGetInfo", &(void*&)pfn_cuMemGetInfo);
201
- get_driver_entry_point("cuCtxSetCurrent", &(void*&)pfn_cuCtxSetCurrent);
202
- get_driver_entry_point("cuCtxGetCurrent", &(void*&)pfn_cuCtxGetCurrent);
203
- get_driver_entry_point("cuCtxPushCurrent", &(void*&)pfn_cuCtxPushCurrent);
204
- get_driver_entry_point("cuCtxPopCurrent", &(void*&)pfn_cuCtxPopCurrent);
205
- get_driver_entry_point("cuCtxSynchronize", &(void*&)pfn_cuCtxSynchronize);
206
- get_driver_entry_point("cuCtxGetDevice", &(void*&)pfn_cuCtxGetDevice);
207
- get_driver_entry_point("cuCtxCreate", &(void*&)pfn_cuCtxCreate);
208
- get_driver_entry_point("cuCtxDestroy", &(void*&)pfn_cuCtxDestroy);
209
- get_driver_entry_point("cuCtxEnablePeerAccess", &(void*&)pfn_cuCtxEnablePeerAccess);
210
- get_driver_entry_point("cuCtxDisablePeerAccess", &(void*&)pfn_cuCtxDisablePeerAccess);
211
- get_driver_entry_point("cuStreamCreate", &(void*&)pfn_cuStreamCreate);
212
- get_driver_entry_point("cuStreamDestroy", &(void*&)pfn_cuStreamDestroy);
213
- get_driver_entry_point("cuStreamSynchronize", &(void*&)pfn_cuStreamSynchronize);
214
- get_driver_entry_point("cuStreamWaitEvent", &(void*&)pfn_cuStreamWaitEvent);
215
- get_driver_entry_point("cuStreamGetCtx", &(void*&)pfn_cuStreamGetCtx);
216
- get_driver_entry_point("cuStreamGetCaptureInfo", &(void*&)pfn_cuStreamGetCaptureInfo);
217
- get_driver_entry_point("cuStreamUpdateCaptureDependencies", &(void*&)pfn_cuStreamUpdateCaptureDependencies);
218
- get_driver_entry_point("cuStreamCreateWithPriority", &(void*&)pfn_cuStreamCreateWithPriority);
219
- get_driver_entry_point("cuStreamGetPriority", &(void*&)pfn_cuStreamGetPriority);
220
- get_driver_entry_point("cuEventCreate", &(void*&)pfn_cuEventCreate);
221
- get_driver_entry_point("cuEventDestroy", &(void*&)pfn_cuEventDestroy);
222
- get_driver_entry_point("cuEventRecord", &(void*&)pfn_cuEventRecord);
223
- get_driver_entry_point("cuEventRecordWithFlags", &(void*&)pfn_cuEventRecordWithFlags);
224
- get_driver_entry_point("cuEventSynchronize", &(void*&)pfn_cuEventSynchronize);
225
- get_driver_entry_point("cuModuleLoadDataEx", &(void*&)pfn_cuModuleLoadDataEx);
226
- get_driver_entry_point("cuModuleUnload", &(void*&)pfn_cuModuleUnload);
227
- get_driver_entry_point("cuModuleGetFunction", &(void*&)pfn_cuModuleGetFunction);
228
- get_driver_entry_point("cuLaunchKernel", &(void*&)pfn_cuLaunchKernel);
229
- get_driver_entry_point("cuMemcpyPeerAsync", &(void*&)pfn_cuMemcpyPeerAsync);
230
- get_driver_entry_point("cuPointerGetAttribute", &(void*&)pfn_cuPointerGetAttribute);
231
- get_driver_entry_point("cuGraphicsMapResources", &(void*&)pfn_cuGraphicsMapResources);
232
- get_driver_entry_point("cuGraphicsUnmapResources", &(void*&)pfn_cuGraphicsUnmapResources);
233
- get_driver_entry_point("cuGraphicsResourceGetMappedPointer", &(void*&)pfn_cuGraphicsResourceGetMappedPointer);
234
- get_driver_entry_point("cuGraphicsGLRegisterBuffer", &(void*&)pfn_cuGraphicsGLRegisterBuffer);
235
- get_driver_entry_point("cuGraphicsUnregisterResource", &(void*&)pfn_cuGraphicsUnregisterResource);
236
- get_driver_entry_point("cuModuleGetGlobal", &(void*&)pfn_cuModuleGetGlobal);
237
- get_driver_entry_point("cuFuncSetAttribute", &(void*&)pfn_cuFuncSetAttribute);
197
+ get_driver_entry_point("cuGetErrorString", 6000, &(void*&)pfn_cuGetErrorString);
198
+ get_driver_entry_point("cuGetErrorName", 6000, &(void*&)pfn_cuGetErrorName);
199
+ get_driver_entry_point("cuInit", 2000, &(void*&)pfn_cuInit);
200
+ get_driver_entry_point("cuDeviceGet", 2000, &(void*&)pfn_cuDeviceGet);
201
+ get_driver_entry_point("cuDeviceGetCount", 2000, &(void*&)pfn_cuDeviceGetCount);
202
+ get_driver_entry_point("cuDeviceGetName", 2000, &(void*&)pfn_cuDeviceGetName);
203
+ get_driver_entry_point("cuDeviceGetAttribute", 2000, &(void*&)pfn_cuDeviceGetAttribute);
204
+ get_driver_entry_point("cuDeviceGetUuid", 110400, &(void*&)pfn_cuDeviceGetUuid);
205
+ get_driver_entry_point("cuDevicePrimaryCtxRetain", 7000, &(void*&)pfn_cuDevicePrimaryCtxRetain);
206
+ get_driver_entry_point("cuDevicePrimaryCtxRelease", 11000, &(void*&)pfn_cuDevicePrimaryCtxRelease);
207
+ get_driver_entry_point("cuDeviceCanAccessPeer", 4000, &(void*&)pfn_cuDeviceCanAccessPeer);
208
+ get_driver_entry_point("cuMemGetInfo", 3020, &(void*&)pfn_cuMemGetInfo);
209
+ get_driver_entry_point("cuCtxSetCurrent", 4000, &(void*&)pfn_cuCtxSetCurrent);
210
+ get_driver_entry_point("cuCtxGetCurrent", 4000, &(void*&)pfn_cuCtxGetCurrent);
211
+ get_driver_entry_point("cuCtxPushCurrent", 4000, &(void*&)pfn_cuCtxPushCurrent);
212
+ get_driver_entry_point("cuCtxPopCurrent", 4000, &(void*&)pfn_cuCtxPopCurrent);
213
+ get_driver_entry_point("cuCtxSynchronize", 2000, &(void*&)pfn_cuCtxSynchronize);
214
+ get_driver_entry_point("cuCtxGetDevice", 2000, &(void*&)pfn_cuCtxGetDevice);
215
+ get_driver_entry_point("cuCtxCreate", 3020, &(void*&)pfn_cuCtxCreate);
216
+ get_driver_entry_point("cuCtxDestroy", 4000, &(void*&)pfn_cuCtxDestroy);
217
+ get_driver_entry_point("cuCtxEnablePeerAccess", 4000, &(void*&)pfn_cuCtxEnablePeerAccess);
218
+ get_driver_entry_point("cuCtxDisablePeerAccess", 4000, &(void*&)pfn_cuCtxDisablePeerAccess);
219
+ get_driver_entry_point("cuStreamCreate", 2000, &(void*&)pfn_cuStreamCreate);
220
+ get_driver_entry_point("cuStreamDestroy", 4000, &(void*&)pfn_cuStreamDestroy);
221
+ get_driver_entry_point("cuStreamSynchronize", 2000, &(void*&)pfn_cuStreamSynchronize);
222
+ get_driver_entry_point("cuStreamWaitEvent", 3020, &(void*&)pfn_cuStreamWaitEvent);
223
+ get_driver_entry_point("cuStreamGetCtx", 9020, &(void*&)pfn_cuStreamGetCtx);
224
+ get_driver_entry_point("cuStreamGetCaptureInfo", 11030, &(void*&)pfn_cuStreamGetCaptureInfo);
225
+ get_driver_entry_point("cuStreamUpdateCaptureDependencies", 11030, &(void*&)pfn_cuStreamUpdateCaptureDependencies);
226
+ get_driver_entry_point("cuStreamCreateWithPriority", 5050, &(void*&)pfn_cuStreamCreateWithPriority);
227
+ get_driver_entry_point("cuStreamGetPriority", 5050, &(void*&)pfn_cuStreamGetPriority);
228
+ get_driver_entry_point("cuEventCreate", 2000, &(void*&)pfn_cuEventCreate);
229
+ get_driver_entry_point("cuEventDestroy", 4000, &(void*&)pfn_cuEventDestroy);
230
+ get_driver_entry_point("cuEventRecord", 2000, &(void*&)pfn_cuEventRecord);
231
+ get_driver_entry_point("cuEventRecordWithFlags", 11010, &(void*&)pfn_cuEventRecordWithFlags);
232
+ get_driver_entry_point("cuEventSynchronize", 2000, &(void*&)pfn_cuEventSynchronize);
233
+ get_driver_entry_point("cuModuleLoadDataEx", 2010, &(void*&)pfn_cuModuleLoadDataEx);
234
+ get_driver_entry_point("cuModuleUnload", 2000, &(void*&)pfn_cuModuleUnload);
235
+ get_driver_entry_point("cuModuleGetFunction", 2000, &(void*&)pfn_cuModuleGetFunction);
236
+ get_driver_entry_point("cuLaunchKernel", 4000, &(void*&)pfn_cuLaunchKernel);
237
+ get_driver_entry_point("cuMemcpyPeerAsync", 4000, &(void*&)pfn_cuMemcpyPeerAsync);
238
+ get_driver_entry_point("cuPointerGetAttribute", 4000, &(void*&)pfn_cuPointerGetAttribute);
239
+ get_driver_entry_point("cuGraphicsMapResources", 3000, &(void*&)pfn_cuGraphicsMapResources);
240
+ get_driver_entry_point("cuGraphicsUnmapResources", 3000, &(void*&)pfn_cuGraphicsUnmapResources);
241
+ get_driver_entry_point("cuGraphicsResourceGetMappedPointer", 3020, &(void*&)pfn_cuGraphicsResourceGetMappedPointer);
242
+ get_driver_entry_point("cuGraphicsGLRegisterBuffer", 3000, &(void*&)pfn_cuGraphicsGLRegisterBuffer);
243
+ get_driver_entry_point("cuGraphicsUnregisterResource", 3000, &(void*&)pfn_cuGraphicsUnregisterResource);
244
+ get_driver_entry_point("cuModuleGetGlobal", 3020, &(void*&)pfn_cuModuleGetGlobal);
245
+ get_driver_entry_point("cuFuncSetAttribute", 9000, &(void*&)pfn_cuFuncSetAttribute);
246
+ get_driver_entry_point("cuIpcGetEventHandle", 4010, &(void*&)pfn_cuIpcGetEventHandle);
247
+ get_driver_entry_point("cuIpcOpenEventHandle", 4010, &(void*&)pfn_cuIpcOpenEventHandle);
248
+ get_driver_entry_point("cuIpcGetMemHandle", 4010, &(void*&)pfn_cuIpcGetMemHandle);
249
+ get_driver_entry_point("cuIpcOpenMemHandle", 11000, &(void*&)pfn_cuIpcOpenMemHandle);
250
+ get_driver_entry_point("cuIpcCloseMemHandle", 4010, &(void*&)pfn_cuIpcCloseMemHandle);
238
251
 
239
252
  if (pfn_cuInit)
240
253
  cuda_driver_initialized = check_cu(pfn_cuInit(0));
@@ -582,4 +595,29 @@ CUresult cuFuncSetAttribute_f(CUfunction hfunc, CUfunction_attribute attrib, int
582
595
  return pfn_cuFuncSetAttribute ? pfn_cuFuncSetAttribute(hfunc, attrib, value) : DRIVER_ENTRY_POINT_ERROR;
583
596
  }
584
597
 
598
+ CUresult cuIpcGetEventHandle_f(CUipcEventHandle *pHandle, CUevent event)
599
+ {
600
+ return pfn_cuIpcGetEventHandle ? pfn_cuIpcGetEventHandle(pHandle, event) : DRIVER_ENTRY_POINT_ERROR;
601
+ }
602
+
603
+ CUresult cuIpcOpenEventHandle_f(CUevent *phEvent, CUipcEventHandle handle)
604
+ {
605
+ return pfn_cuIpcOpenEventHandle ? pfn_cuIpcOpenEventHandle(phEvent, handle) : DRIVER_ENTRY_POINT_ERROR;
606
+ }
607
+
608
+ CUresult cuIpcGetMemHandle_f(CUipcMemHandle *pHandle, CUdeviceptr dptr)
609
+ {
610
+ return pfn_cuIpcGetMemHandle ? pfn_cuIpcGetMemHandle(pHandle, dptr) : DRIVER_ENTRY_POINT_ERROR;
611
+ }
612
+
613
+ CUresult cuIpcOpenMemHandle_f(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int flags)
614
+ {
615
+ return pfn_cuIpcOpenMemHandle ? pfn_cuIpcOpenMemHandle(pdptr, handle, flags) : DRIVER_ENTRY_POINT_ERROR;
616
+ }
617
+
618
+ CUresult cuIpcCloseMemHandle_f(CUdeviceptr dptr)
619
+ {
620
+ return pfn_cuIpcCloseMemHandle ? pfn_cuIpcCloseMemHandle(dptr) : DRIVER_ENTRY_POINT_ERROR;
621
+ }
622
+
585
623
  #endif // WP_ENABLE_CUDA
warp/native/cuda_util.h CHANGED
@@ -101,6 +101,11 @@ CUresult cuGraphicsGLRegisterBuffer_f(CUgraphicsResource *pCudaResource, unsigne
101
101
  CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);
102
102
  CUresult cuModuleGetGlobal_f(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name );
103
103
  CUresult cuFuncSetAttribute_f(CUfunction hfunc, CUfunction_attribute attrib, int value);
104
+ CUresult cuIpcGetEventHandle_f(CUipcEventHandle *pHandle, CUevent event);
105
+ CUresult cuIpcOpenEventHandle_f(CUevent *phEvent, CUipcEventHandle handle);
106
+ CUresult cuIpcGetMemHandle_f(CUipcMemHandle *pHandle, CUdeviceptr dptr);
107
+ CUresult cuIpcOpenMemHandle_f(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int flags);
108
+ CUresult cuIpcCloseMemHandle_f(CUdeviceptr dptr);
104
109
 
105
110
  bool init_cuda_driver();
106
111
  bool is_cuda_driver_initialized();