warp-lang 1.5.1__py3-none-win_amd64.whl → 1.6.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (123)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/native/bvh.cpp CHANGED
@@ -23,11 +23,11 @@ namespace wp
23
23
 
24
24
  /////////////////////////////////////////////////////////////////////////////////////////////
25
25
 
26
- class MedianBVHBuilder
26
+ class TopDownBVHBuilder
27
27
  {
28
28
  public:
29
29
 
30
- void build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n);
30
+ void build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n, int in_constructor_type);
31
31
 
32
32
  private:
33
33
 
@@ -35,15 +35,26 @@ private:
35
35
 
36
36
  int partition_median(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
37
37
  int partition_midpoint(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
38
- int partition_sah(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
38
+ float partition_sah(BVH& bvh, const vec3* lowers, const vec3* uppers,
39
+ int start, int end, bounds3 range_bounds, int& split_axis);
39
40
 
40
- int build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int* indices, int start, int end, int depth, int parent);
41
+ int build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int start, int end, int depth, int parent);
42
+
43
+ int constructor_type = -1;
41
44
  };
42
45
 
43
46
  //////////////////////////////////////////////////////////////////////
44
47
 
45
- void MedianBVHBuilder::build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n)
48
+ void TopDownBVHBuilder::build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n, int in_constructor_type)
46
49
  {
50
+ constructor_type = in_constructor_type;
51
+ if (constructor_type != BVH_CONSTRUCTOR_SAH && constructor_type != BVH_CONSTRUCTOR_MEDIAN)
52
+ {
53
+ printf("Unrecognized Constructor type: %d! For CPU constructor it should be either SAH (%d) or Median (%d)!\n",
54
+ constructor_type, BVH_CONSTRUCTOR_SAH, BVH_CONSTRUCTOR_MEDIAN);
55
+ return;
56
+ }
57
+
47
58
  bvh.max_depth = 0;
48
59
  bvh.max_nodes = 2*n-1;
49
60
 
@@ -51,7 +62,7 @@ void MedianBVHBuilder::build(BVH& bvh, const vec3* lowers, const vec3* uppers, i
51
62
  bvh.node_uppers = new BVHPackedNodeHalf[bvh.max_nodes];
52
63
  bvh.node_parents = new int[bvh.max_nodes];
53
64
  bvh.node_counts = NULL;
54
-
65
+
55
66
  // root is always in first slot for top down builders
56
67
  bvh.root = new int[1];
57
68
  bvh.root[0] = 0;
@@ -59,22 +70,21 @@ void MedianBVHBuilder::build(BVH& bvh, const vec3* lowers, const vec3* uppers, i
59
70
  if (n == 0)
60
71
  return;
61
72
 
62
- std::vector<int> indices(n);
63
- for (int i=0; i < n; ++i)
64
- indices[i] = i;
73
+ bvh.primitive_indices = new int[n];
74
+ for (int i = 0; i < n; ++i)
75
+ bvh.primitive_indices[i] = i;
65
76
 
66
- build_recursive(bvh, lowers, uppers, &indices[0], 0, n, 0, -1);
77
+ build_recursive(bvh, lowers, uppers, 0, n, 0, -1);
67
78
  }
68
79
 
69
80
 
70
- bounds3 MedianBVHBuilder::calc_bounds(const vec3* lowers, const vec3* uppers, const int* indices, int start, int end)
81
+ bounds3 TopDownBVHBuilder::calc_bounds(const vec3* lowers, const vec3* uppers, const int* indices, int start, int end)
71
82
  {
72
83
  bounds3 u;
73
84
 
74
85
  for (int i=start; i < end; ++i)
75
86
  {
76
- u.add_point(lowers[indices[i]]);
77
- u.add_point(uppers[indices[i]]);
87
+ u.add_bounds(lowers[indices[i]], uppers[indices[i]]);
78
88
  }
79
89
 
80
90
  return u;
@@ -98,7 +108,7 @@ struct PartitionPredicateMedian
98
108
  };
99
109
 
100
110
 
101
- int MedianBVHBuilder::partition_median(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
111
+ int TopDownBVHBuilder::partition_median(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
102
112
  {
103
113
  assert(end-start >= 2);
104
114
 
@@ -113,9 +123,9 @@ int MedianBVHBuilder::partition_median(const vec3* lowers, const vec3* uppers, i
113
123
  return k;
114
124
  }
115
125
 
116
- struct PartitionPredictateMidPoint
126
+ struct PartitionPredicateMidPoint
117
127
  {
118
- PartitionPredictateMidPoint(const vec3* lowers, const vec3* uppers, int a, float m) : lowers(lowers), uppers(uppers), axis(a), mid(m) {}
128
+ PartitionPredicateMidPoint(const vec3* lowers, const vec3* uppers, int a, float m) : lowers(lowers), uppers(uppers), axis(a), mid(m) {}
119
129
 
120
130
  bool operator()(int index) const
121
131
  {
@@ -132,7 +142,7 @@ struct PartitionPredictateMidPoint
132
142
  };
133
143
 
134
144
 
135
- int MedianBVHBuilder::partition_midpoint(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
145
+ int TopDownBVHBuilder::partition_midpoint(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
136
146
  {
137
147
  assert(end-start >= 2);
138
148
 
@@ -142,7 +152,7 @@ int MedianBVHBuilder::partition_midpoint(const vec3* lowers, const vec3* uppers,
142
152
  int axis = longest_axis(edges);
143
153
  float mid = center[axis];
144
154
 
145
- int* upper = std::partition(indices+start, indices+end, PartitionPredictateMidPoint(lowers, uppers, axis, mid));
155
+ int* upper = std::partition(indices+start, indices+end, PartitionPredicateMidPoint(lowers, uppers, axis, mid));
146
156
 
147
157
  int k = upper-indices;
148
158
 
@@ -153,50 +163,88 @@ int MedianBVHBuilder::partition_midpoint(const vec3* lowers, const vec3* uppers,
153
163
  return k;
154
164
  }
155
165
 
156
- // disable std::sort workaround for macOS error
157
- #if 0
158
- int MedianBVHBuilder::partition_sah(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds)
166
+ float TopDownBVHBuilder::partition_sah(BVH& bvh, const vec3* lowers, const vec3* uppers, int start, int end, bounds3 range_bounds, int& split_axis)
159
167
  {
160
- assert(end-start >= 2);
168
+ int buckets_counts[SAH_NUM_BUCKETS];
169
+ bounds3 buckets[SAH_NUM_BUCKETS];
170
+ float left_areas[SAH_NUM_BUCKETS - 1];
171
+ float right_areas[SAH_NUM_BUCKETS - 1];
172
+
173
+ assert(end - start >= 2);
161
174
 
162
- int n = end-start;
175
+ int n = end - start;
163
176
  vec3 edges = range_bounds.edges();
164
177
 
165
- int longestAxis = longest_axis(edges);
178
+ bounds3 b = calc_bounds(lowers, uppers, bvh.primitive_indices, start, end);
166
179
 
167
- // sort along longest axis
168
- std::sort(&indices[0]+start, &indices[0]+end, PartitionPredicateMedian(&bounds[0], longestAxis));
180
+ split_axis = longest_axis(edges);
169
181
 
170
- // total area for range from [0, split]
171
- std::vector<float> left_areas(n);
172
- // total area for range from (split, end]
173
- std::vector<float> right_areas(n);
182
+ // compute each bucket
183
+ float range_start = b.lower[split_axis];
184
+ float range_end = b.upper[split_axis];
185
+
186
+ std::fill(buckets_counts, buckets_counts + SAH_NUM_BUCKETS, 0);
187
+ for (int item_idx = start; item_idx < end; item_idx++)
188
+ {
189
+ vec3 item_center = 0.5f * (lowers[bvh.primitive_indices[item_idx]] + uppers[bvh.primitive_indices[item_idx]]);
190
+ int bucket_idx = SAH_NUM_BUCKETS * (item_center[split_axis] - range_start) / (range_end - range_start);
191
+ assert(bucket_idx >= 0 && bucket_idx <= SAH_NUM_BUCKETS);
192
+ // one of them will have the range_end, we put it into the last bucket
193
+ bucket_idx = bucket_idx < SAH_NUM_BUCKETS ? bucket_idx : SAH_NUM_BUCKETS - 1;
194
+
195
+ bounds3 item_bound(lowers[bvh.primitive_indices[item_idx]], uppers[bvh.primitive_indices[item_idx]]);
196
+
197
+ if (buckets_counts[bucket_idx])
198
+ {
199
+ buckets[bucket_idx] = bounds_union(item_bound, buckets[bucket_idx]);
200
+ }
201
+ else
202
+ {
203
+ buckets[bucket_idx] = item_bound;
204
+ }
205
+
206
+ buckets_counts[bucket_idx]++;
207
+ }
174
208
 
175
209
  bounds3 left;
176
210
  bounds3 right;
177
211
 
212
+ // n - 1 division points for n buckets
213
+ int counts_l[SAH_NUM_BUCKETS - 1];
214
+ int counts_r[SAH_NUM_BUCKETS - 1];
215
+
216
+ int count_l = 0;
217
+ int count_r = 0;
178
218
  // build cumulative bounds and area from left and right
179
- for (int i=0; i < n; ++i)
219
+ for (int i = 0; i < SAH_NUM_BUCKETS - 1; ++i)
180
220
  {
181
- left = bounds_union(left, bounds[indices[start+i]]);
182
- right = bounds_union(right, bounds[indices[end-i-1]]);
221
+ bounds3 bound_start = buckets[i];
222
+ bounds3 bound_end = buckets[SAH_NUM_BUCKETS - i - 1];
223
+
224
+ left = bounds_union(left, bound_start);
225
+ right = bounds_union(right, bound_end);
183
226
 
184
227
  left_areas[i] = left.area();
185
- right_areas[n-i-1] = right.area();
228
+ right_areas[SAH_NUM_BUCKETS - i - 2] = right.area();
229
+
230
+ count_l += buckets_counts[i];
231
+ count_r += buckets_counts[SAH_NUM_BUCKETS - i - 1];
232
+
233
+ counts_l[i] = count_l;
234
+ counts_r[SAH_NUM_BUCKETS - i - 2] = count_r;
186
235
  }
187
236
 
188
- float invTotalArea = 1.0f/range_bounds.area();
237
+ float invTotalArea = 1.0f / range_bounds.area();
189
238
 
190
- // find split point i that minimizes area(left[i]) + area(right[i])
239
+ // find split point i that minimizes area(left[i]) * count[left[i]] + area(right[i]) * count[right[i]]
191
240
  int minSplit = 0;
192
241
  float minCost = FLT_MAX;
193
-
194
- for (int i=0; i < n; ++i)
242
+ for (int i = 0; i < SAH_NUM_BUCKETS - 1; ++i)
195
243
  {
196
- float pBelow = left_areas[i]*invTotalArea;
197
- float pAbove = right_areas[i]*invTotalArea;
244
+ float pBelow = left_areas[i] * invTotalArea;
245
+ float pAbove = right_areas[i] * invTotalArea;
198
246
 
199
- float cost = pBelow*i + pAbove*(n-i);
247
+ float cost = pBelow * counts_l[i] + pAbove * counts_r[i];
200
248
 
201
249
  if (cost < minCost)
202
250
  {
@@ -205,15 +253,20 @@ int MedianBVHBuilder::partition_sah(const bounds3* bounds, int* indices, int sta
205
253
  }
206
254
  }
207
255
 
208
- return start + minSplit + 1;
256
+ // return the dividing
257
+ assert(minSplit >= 0 && minSplit < SAH_NUM_BUCKETS - 1);
258
+ float split_point = range_start + (minSplit + 1) * (range_end - range_start) / SAH_NUM_BUCKETS;
259
+
260
+ return split_point;
209
261
  }
210
- #endif
211
262
 
212
- int MedianBVHBuilder::build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int* indices, int start, int end, int depth, int parent)
263
+ int TopDownBVHBuilder::build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int start, int end, int depth, int parent)
213
264
  {
214
265
  assert(start < end);
215
266
 
216
- const int n = end-start;
267
+ // printf("start %d end %d\n", start, end);
268
+
269
+ const int n = end - start;
217
270
  const int node_index = bvh.num_nodes++;
218
271
 
219
272
  assert(node_index < bvh.max_nodes);
@@ -221,31 +274,50 @@ int MedianBVHBuilder::build_recursive(BVH& bvh, const vec3* lowers, const vec3*
221
274
  if (depth > bvh.max_depth)
222
275
  bvh.max_depth = depth;
223
276
 
224
- bounds3 b = calc_bounds(lowers, uppers, indices, start, end);
225
-
226
- const int kMaxItemsPerLeaf = 1;
277
+ bounds3 b = calc_bounds(lowers, uppers, bvh.primitive_indices, start, end);
227
278
 
228
- if (n <= kMaxItemsPerLeaf)
279
+ if (n <= BVH_LEAF_SIZE)
229
280
  {
230
- bvh.node_lowers[node_index] = make_node(b.lower, indices[start], true);
231
- bvh.node_uppers[node_index] = make_node(b.upper, indices[start], false);
281
+ bvh.node_lowers[node_index] = make_node(b.lower, start, true);
282
+ bvh.node_uppers[node_index] = make_node(b.upper, end, false);
232
283
  bvh.node_parents[node_index] = parent;
284
+ bvh.num_leaf_nodes++;
233
285
  }
234
- else
286
+ else
235
287
  {
236
- //int split = partition_midpoint(bounds, indices, start, end, b);
237
- int split = partition_median(lowers, uppers, indices, start, end, b);
238
- //int split = partition_sah(bounds, indices, start, end, b);
288
+ int split = -1;
289
+ if (constructor_type == BVH_CONSTRUCTOR_SAH)
290
+ // SAH constructor
291
+ {
292
+ int split_axis = -1;
293
+ float split_point = partition_sah(bvh, lowers, uppers, start, end, b, split_axis);
294
+ auto boundary = std::partition(bvh.primitive_indices + start, bvh.primitive_indices + end,
295
+ [&](int i) {
296
+ return 0.5f * (lowers[i] + uppers[i])[split_axis] < split_point;
297
+ });
298
+
299
+ split = std::distance(bvh.primitive_indices + start, boundary) + start;
300
+ }
301
+ else if (constructor_type == BVH_CONSTRUCTOR_MEDIAN)
302
+ // Median constructor
303
+ {
304
+ split = partition_median(lowers, uppers, bvh.primitive_indices, start, end, b);
305
+ }
306
+ else
307
+ {
308
+ printf("Unknown type of BVH constructor: %d!\n", constructor_type);
309
+ return -1;
310
+ }
239
311
 
240
312
  if (split == start || split == end)
241
313
  {
242
314
  // partitioning failed, split down the middle
243
- split = (start+end)/2;
315
+ split = (start + end) / 2;
244
316
  }
245
-
246
- int left_child = build_recursive(bvh, lowers, uppers, indices, start, split, depth+1, node_index);
247
- int right_child = build_recursive(bvh, lowers, uppers, indices, split, end, depth+1, node_index);
248
-
317
+
318
+ int left_child = build_recursive(bvh, lowers, uppers, start, split, depth + 1, node_index);
319
+ int right_child = build_recursive(bvh, lowers, uppers, split, end, depth + 1, node_index);
320
+
249
321
  bvh.node_lowers[node_index] = make_node(b.lower, left_child, false);
250
322
  bvh.node_uppers[node_index] = make_node(b.upper, right_child, false);
251
323
  bvh.node_parents[node_index] = parent;
@@ -262,11 +334,16 @@ void bvh_refit_recursive(BVH& bvh, int index)
262
334
 
263
335
  if (lower.b)
264
336
  {
265
- const int leaf_index = lower.i;
266
-
267
337
  // update leaf from items
268
- (vec3&)lower = bvh.item_lowers[leaf_index];
269
- (vec3&)upper = bvh.item_uppers[leaf_index];
338
+ bounds3 bound;
339
+ for (int item_counter = lower.i; item_counter < upper.i; item_counter++)
340
+ {
341
+ const int item = bvh.primitive_indices[item_counter];
342
+ bound.add_bounds(bvh.item_lowers[item], bvh.item_uppers[item]);
343
+ }
344
+
345
+ (vec3&)lower = bound.lower;
346
+ (vec3&)upper = bound.upper;
270
347
  }
271
348
  else
272
349
  {
@@ -340,7 +417,7 @@ void bvh_rem_descriptor(uint64_t id)
340
417
 
341
418
 
342
419
  // create in-place given existing descriptor
343
- void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, BVH& bvh)
420
+ void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, int constructor_type, BVH& bvh)
344
421
  {
345
422
  memset(&bvh, 0, sizeof(BVH));
346
423
 
@@ -348,8 +425,8 @@ void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, BVH& bvh)
348
425
  bvh.item_uppers = uppers;
349
426
  bvh.num_items = num_items;
350
427
 
351
- MedianBVHBuilder builder;
352
- builder.build(bvh, lowers, uppers, num_items);
428
+ TopDownBVHBuilder builder;
429
+ builder.build(bvh, lowers, uppers, num_items, constructor_type);
353
430
  }
354
431
 
355
432
  void bvh_destroy_host(BVH& bvh)
@@ -357,11 +434,13 @@ void bvh_destroy_host(BVH& bvh)
357
434
  delete[] bvh.node_lowers;
358
435
  delete[] bvh.node_uppers;
359
436
  delete[] bvh.node_parents;
437
+ delete[] bvh.primitive_indices;
360
438
  delete[] bvh.root;
361
439
 
362
440
  bvh.node_lowers = NULL;
363
441
  bvh.node_uppers = NULL;
364
442
  bvh.node_parents = NULL;
443
+ bvh.primitive_indices = NULL;
365
444
  bvh.root = NULL;
366
445
 
367
446
  bvh.max_nodes = 0;
@@ -370,10 +449,10 @@ void bvh_destroy_host(BVH& bvh)
370
449
 
371
450
  } // namespace wp
372
451
 
373
- uint64_t bvh_create_host(vec3* lowers, vec3* uppers, int num_items)
452
+ uint64_t bvh_create_host(vec3* lowers, vec3* uppers, int num_items, int constructor_type)
374
453
  {
375
454
  BVH* bvh = new BVH();
376
- wp::bvh_create_host(lowers, uppers, num_items, *bvh);
455
+ wp::bvh_create_host(lowers, uppers, num_items, constructor_type, *bvh);
377
456
 
378
457
  return (uint64_t)bvh;
379
458
  }
@@ -395,7 +474,7 @@ void bvh_destroy_host(uint64_t id)
395
474
  // stubs for non-CUDA platforms
396
475
  #if !WP_ENABLE_CUDA
397
476
 
398
- uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items) { return 0; }
477
+ uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items, int constructor_type) { return 0; }
399
478
  void bvh_refit_device(uint64_t id) {}
400
479
  void bvh_destroy_device(uint64_t id) {}
401
480