warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/marching.cu
CHANGED
@@ -1,8 +1,6 @@
 #include "warp.h"
 #include "cuda_util.h"
-
-#include "thrust/device_ptr.h"
-#include "thrust/sort.h"
+#include "scan.h"
 
 namespace wp {
 
@@ -162,13 +160,17 @@ namespace wp {
 };
 
 
-
-
-
-
+// ---------------------------------------------------------------------------------------
+struct MarchingCubes
+{
+    MarchingCubes()
     {
-
-
+        memset(this, 0, sizeof(MarchingCubes));
+        first_cell_vert = nullptr;
+        first_cell_tri = nullptr;
+        cell_verts = nullptr;
+        context = nullptr;
+    }
 
     __device__ __host__ int cell_index(int xi, int yi, int zi) const
     {
@@ -181,169 +183,169 @@ namespace wp {
         xi = cell_index / ny;
     }
 
-
+    // grid
    int nx;
    int ny;
    int nz;
 
-
-
-
+    int* first_cell_vert;
+    int* first_cell_tri;
+    int* cell_verts;
 
    int num_cells;
    int max_cells;
 
    void* context;
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+};
+
+
+// -----------------------------------------------------------------------------------
+__global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+
+    mc.first_cell_vert[cell_index] = 0;
+    if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
+        return;
+
+    float d0 = density[cell_index];
+    float dx = density[mc.cell_index(xi + 1, yi, zi)];
+    float dy = density[mc.cell_index(xi, yi + 1, zi)];
+    float dz = density[mc.cell_index(xi, yi, zi + 1)];
+
+    int num = 0;
+    if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
+        num++;
+    if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
+        num++;
+    if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
+        num++;
+
+    mc.first_cell_vert[cell_index] = num;
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+    if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
+        return;
+
+    vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
+
+    float d0 = density[cell_index];
+    float ds[3];
+    ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
+    ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
+    ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
+
+    // vec3 n0 = densityNormal[cell_index];
+    // vec3 ns[3];
+    // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
+    // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
+    // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
+
+    int first = mc.first_cell_vert[cell_index];
+
+    for (int dim = 0; dim < 3; dim++)
     {
-
-
+        float d = ds[dim];
+        mc.cell_verts[3 * cell_index + dim] = 0;
 
-
+        if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
         {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
+            int id = first++;
+
+            vec3 off;
+            off[dim] = t;
+            vertices[id] = p + off;
+
+            // vec3 n = normalize(n0 + t * (ns[dim] - n0));
+            // normals[id] = -n;
+
+            mc.cell_verts[3 * cell_index + dim] = id;
+        }
+    }
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+
+    mc.first_cell_tri[cell_index] = 0;
+    if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
+        return;
+
+    int code = 0;
+    for (int i = 0; i < 8; i++) {
+        int cxi = xi + marchingCubeCorners[i][0];
+        int cyi = yi + marchingCubeCorners[i][1];
+        int czi = zi + marchingCubeCorners[i][2];
+
+        if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
+            code |= (1 << i);
+    }
+
+    mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+    if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
+        return;
+
+    int code = 0;
+    for (int i = 0; i < 8; i++)
     {
-
-
-
+        int cxi = xi + marchingCubeCorners[i][0];
+        int cyi = yi + marchingCubeCorners[i][1];
+        int czi = zi + marchingCubeCorners[i][2];
 
-
-
-
+        if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
+            code |= (1 << i);
+    }
 
-
-
-
+    int firstIn = firstMarchingCubesId[code];
+    int num = firstMarchingCubesId[code + 1] - firstIn;
+    int firstOut = mc.first_cell_tri[cell_index];
 
-
+    for (int i = 0; i < num; i++)
     {
-
+        int eid = marchingCubesIds[firstIn + i];
 
-
-
-
-
+        int exi = xi + marchingCubesEdgeLocations[eid][0];
+        int eyi = yi + marchingCubesEdgeLocations[eid][1];
+        int ezi = zi + marchingCubesEdgeLocations[eid][2];
+        int edgeNr = marchingCubesEdgeLocations[eid][3];
 
-
-
-
-
+        int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
+        triangles[firstOut + i] = id;
+    }
+}
 
 // -------------------------
 void marching_cubes_resize(MarchingCubes& mc, int nx, int ny, int nz)
@@ -444,10 +446,7 @@ WP_API int marching_cubes_surface_device(
    int num_last;
    memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
 
-
-        thrust::device_ptr<int>(mc.first_cell_vert),
-        thrust::device_ptr<int>(mc.first_cell_vert + mc.num_cells),
-        thrust::device_ptr<int>(mc.first_cell_vert));
+    scan_device(mc.first_cell_vert, mc.first_cell_vert, mc.num_cells, false);
 
    int num_verts;
    memcpy_d2h(WP_CURRENT_CONTEXT, &num_verts, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
@@ -472,10 +471,7 @@ WP_API int marching_cubes_surface_device(
 
    memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_tri[mc.num_cells - 1], sizeof(int));
 
-
-        thrust::device_ptr<int>(mc.first_cell_tri),
-        thrust::device_ptr<int>(mc.first_cell_tri + mc.num_cells),
-        thrust::device_ptr<int>(mc.first_cell_tri));
+    scan_device(mc.first_cell_tri, mc.first_cell_tri, mc.num_cells, false);
 
 
    int num_indices;
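Reading note: the marching.cu rewrite drops the Thrust dependency and sizes its output buffers with a count/scan/write pattern visible in the hunks above — count_cell_verts and count_cell_tris record a per-cell output count, scan_device converts those counts in place into exclusive-scan offsets, and create_cell_verts/create_cell_tris then write each cell's outputs into its reserved slot range without atomics. Below is a minimal host-side sketch of that pattern; C++17 std::exclusive_scan stands in for warp's device-side scan_device, and the counts are made-up example data.

    // Minimal host-side sketch of the count/scan/write allocation pattern used by
    // the new marching-cubes kernels (std::exclusive_scan stands in for scan_device).
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main()
    {
        // Pass 1 ("count_cell_verts"): each cell reports how many outputs it will emit.
        std::vector<int> counts = {2, 0, 3, 1, 0, 2};

        // Exclusive scan turns counts into starting offsets, mirroring
        // scan_device(mc.first_cell_vert, mc.first_cell_vert, mc.num_cells, false).
        std::vector<int> offsets(counts.size());
        std::exclusive_scan(counts.begin(), counts.end(), offsets.begin(), 0);

        // Total output size = last offset + last count, which is why the device code
        // reads back num_last before the in-place scan and the last offset after it.
        int total = offsets.back() + counts.back();
        std::vector<int> out(total, -1);

        // Pass 2 ("create_cell_verts"): each cell writes into its reserved slots,
        // so no atomics are needed and the output order is deterministic.
        for (int cell = 0; cell < (int)counts.size(); ++cell)
            for (int k = 0; k < counts[cell]; ++k)
                out[offsets[cell] + k] = cell;

        for (int v : out)
            std::printf("%d ", v);
        std::printf("\n"); // prints: 0 0 2 2 2 3 5 5
        return 0;
    }

Because each cell owns a disjoint slot range, the second pass is race-free, and two single-int readbacks suffice to size the vertex and triangle buffers.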
warp/native/mat.h
CHANGED
@@ -21,7 +21,9 @@ struct quat_t;
 template<unsigned Rows, unsigned Cols, typename Type>
 struct mat_t
 {
-    inline mat_t()
+    inline CUDA_CALLABLE mat_t()
+        : data()
+    {}
 
    inline CUDA_CALLABLE mat_t(Type s)
    {
@@ -193,7 +195,7 @@ struct mat_t
    }
 
    // row major storage assumed to be compatible with PyTorch
-    Type data[Rows][Cols]
+    Type data[Rows][Cols];
 };
 
 
@@ -298,7 +300,19 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> atomic_max(mat_t<Rows,Cols,Type> * ad
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE void adj_atomic_minmax(
+    mat_t<Rows,Cols,Type> *addr,
+    mat_t<Rows,Cols,Type> *adj_addr,
+    const mat_t<Rows,Cols,Type> &value,
+    mat_t<Rows,Cols,Type> &adj_value)
+{
+    for (unsigned i=0; i < Rows; ++i)
+        for (unsigned j=0; j < Cols; ++j)
+            adj_atomic_minmax(&addr->data[i][j], &adj_addr->data[i][j], value.data[i][j], adj_value.data[i][j]);
+}
+
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE vec_t<Cols,Type> extract(const mat_t<Rows,Cols,Type>& m, int row)
 {
    vec_t<Cols,Type> ret;
    for(unsigned i=0; i < Cols; ++i)
@@ -309,7 +323,7 @@ inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE Type
+inline CUDA_CALLABLE Type extract(const mat_t<Rows,Cols,Type>& m, int row, int col)
 {
 #ifndef NDEBUG
    if (row < 0 || row >= Rows)
@@ -327,7 +341,7 @@ inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE vec_t<Cols, Type>* index(mat_t<Rows,Cols,Type>& m, int row)
 {
 #ifndef NDEBUG
    if (row < 0 || row >= Rows)
@@ -337,12 +351,11 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols
    }
 #endif
 
-
-        m.data[row][i] = value[i];
+    return reinterpret_cast<vec_t<Cols, Type>*>(&m.data[row]);
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE Type* index(mat_t<Rows,Cols,Type>& m, int row, int col)
 {
 #ifndef NDEBUG
    if (row < 0 || row >= Rows)
@@ -356,18 +369,19 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, T
        assert(0);
    }
 #endif
-
+
+    return &m.data[row][col];
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row,
    const mat_t<Rows,Cols,Type>& adj_m, int adj_row, const vec_t<Cols, Type>& adj_value)
 {
    // nop
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col,
    const mat_t<Rows,Cols,Type>& adj_m, int adj_row, int adj_col, Type adj_value)
 {
    // nop
@@ -425,7 +439,22 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(const mat_t<Rows,Cols,Type>& a, T
        }
    }
 
-   return t;
+    return t;
+}
+
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(Type b, const mat_t<Rows,Cols,Type>& a)
+{
+    mat_t<Rows,Cols,Type> t;
+    for (unsigned i=0; i < Rows; ++i)
+    {
+        for (unsigned j=0; j < Cols; ++j)
+        {
+            t.data[i][j] = b / a.data[i][j];
+        }
+    }
+
+    return t;
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
@@ -440,7 +469,7 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> mul(const mat_t<Rows,Cols,Type>& a, T
        }
    }
 
-   return t;
+    return t;
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
@@ -473,6 +502,17 @@ inline CUDA_CALLABLE vec_t<Rows,Type> mul(const mat_t<Rows,Cols,Type>& a, const
    return r;
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE vec_t<Cols,Type> mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a)
+{
+    vec_t<Cols,Type> r = a.get_row(0)*b[0];
+    for( unsigned i=1; i < Rows; ++i )
+    {
+        r += a.get_row(i)*b[i];
+    }
+    return r;
+}
+
 template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
 inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b)
 {
@@ -861,14 +901,14 @@ inline CUDA_CALLABLE vec_t<3,Type> transform_vector(const mat_t<4,4,Type>& m, co
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_extract(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
 {
    for( unsigned col=0; col < Cols; ++col )
        adj_m.data[row][col] += adj_ret[col];
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline void CUDA_CALLABLE
+inline void CUDA_CALLABLE adj_extract(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
 {
 #ifndef NDEBUG
    if (row < 0 || row > Rows)
@@ -932,6 +972,20 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
    }
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
+{
+    adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
+
+    for (unsigned i=0; i < Rows; ++i)
+    {
+        for (unsigned j=0; j < Cols; ++j)
+        {
+            adj_a.data[i][j] += s / adj_ret.data[i][j];
+        }
+    }
+}
+
 template<unsigned Rows, unsigned Cols, typename Type>
 inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, Type b, mat_t<Rows,Cols,Type>& adj_a, Type& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
 {
@@ -965,6 +1019,13 @@ inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const vec_t<Co
    adj_b += mul(transpose(a), adj_ret);
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a, vec_t<Rows,Type>& adj_b, mat_t<Rows,Cols,Type>& adj_a, const vec_t<Cols,Type>& adj_ret)
+{
+    adj_a += outer(b, adj_ret);
+    adj_b += mul(adj_ret, transpose(a));
+}
+
 template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
 inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Cols,ColsOut,Type>& adj_b, const mat_t<Rows,ColsOut,Type>& adj_ret)
 {
@@ -1105,10 +1166,10 @@ inline CUDA_CALLABLE void adj_determinant(const mat_t<4,4,Type>& m, mat_t<4,4,Ty
 }
 
 template<unsigned Rows, typename Type>
-inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
+inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& ret, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
 {
    // todo: how to cache this from the forward pass?
-    mat_t<Rows,Rows,Type> invt = transpose(
+    mat_t<Rows,Rows,Type> invt = transpose(ret);
 
    // see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf 2.2.3
    adj_m -= mul(mul(invt, adj_ret), invt);
@@ -1150,10 +1211,10 @@ inline CUDA_CALLABLE void adj_cw_mul(const mat_t<Rows,Cols,Type>& a, const mat_t
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
+inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& ret, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
 {
    adj_a += cw_div(adj_ret, b);
-    adj_b -= cw_mul(adj_ret, cw_div(
+    adj_b -= cw_mul(adj_ret, cw_div(ret, b));
 }
 
 // adjoint for the constant constructor:
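Reading note: among the mat.h additions is a vector-matrix overload mul(b, A) (row vector times matrix) and its reverse-mode adjoint, which accumulates adj_a += outer(b, adj_ret) and adj_b += mul(adj_ret, transpose(a)). Below is a small self-contained numeric check of those two rules; plain C arrays and fixed 2x3 sizes stand in for warp's vec_t/mat_t, and the seed values are made-up example data.

    // Numeric check of the adjoint rules for y = b*A (vector-matrix multiply):
    // adj_A += outer(b, adj_y) and adj_b += adj_y * A^T.
    #include <cstdio>

    const int R = 2, C = 3;

    // reverse pass, mirroring adj_mul(b, A, adj_b, adj_A, adj_ret) in the diff above
    void adj_mul_vec_mat(const float b[R], const float A[R][C],
                         float adj_b[R], float adj_A[R][C], const float adj_y[C])
    {
        for (int i = 0; i < R; ++i)
            for (int j = 0; j < C; ++j)
            {
                adj_A[i][j] += b[i] * adj_y[j];    // outer(b, adj_y)
                adj_b[i]    += A[i][j] * adj_y[j]; // mul(adj_y, transpose(A))
            }
    }

    int main()
    {
        float b[R] = {1.0f, 2.0f};
        float A[R][C] = {{1, 2, 3}, {4, 5, 6}};
        float adj_y[C] = {1.0f, 0.0f, 0.0f}; // seed dL/dy = e0, i.e. L = y[0]

        float adj_b[R] = {};
        float adj_A[R][C] = {};
        adj_mul_vec_mat(b, A, adj_b, adj_A, adj_y);

        // L = y[0] = b . A[:,0], so dL/db is column 0 of A and dL/dA puts b in column 0.
        std::printf("adj_b = %.1f %.1f\n", adj_b[0], adj_b[1]);            // 1.0 4.0
        std::printf("adj_A col0 = %.1f %.1f\n", adj_A[0][0], adj_A[1][0]); // 1.0 2.0
        return 0;
    }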
warp/native/matnn.h
CHANGED
@@ -248,7 +248,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
            tmp += weights.data[i*n + j]*x.data[index + b*j];
        }
 
-        // adjoint w.r.t to
+        // adjoint w.r.t to activation
        float adj_f = 0.0f;
 
        if (adj_out.data)
@@ -313,7 +313,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
    // tmp += weights[i*n + j]*x[index + b*j];
    // }
 
-    // // adjoint w.r.t to
+    // // adjoint w.r.t to activation
    // float adj_f = 0.0f;
    // adj_activation(tmp, adj_f, adj_out[index + b*i]);
 
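Reading note: the comment fixed here names the activation step of adj_mlp — the kernel recomputes the pre-activation tmp and then calls adj_activation(tmp, adj_f, adj_out[...]) to pull the output adjoint back through the activation function. Below is a hedged sketch of that convention, with ReLU standing in for the MLP's actual activation and made-up seed values.

    // Sketch of the "adjoint w.r.t. activation" step; ReLU stands in for the
    // activation. warp's adjoint convention is adj_fn(x, adj_x, adj_ret),
    // accumulating d(fn)/dx * adj_ret into adj_x.
    #include <cstdio>

    float activation(float x) { return x > 0.0f ? x : 0.0f; } // ReLU stand-in

    void adj_activation(float x, float& adj_x, float adj_ret)
    {
        adj_x += (x > 0.0f ? 1.0f : 0.0f) * adj_ret; // f'(x) * incoming adjoint
    }

    int main()
    {
        float tmp = 0.5f;     // recomputed pre-activation, as in adj_mlp
        float adj_out = 2.0f; // adjoint arriving from the layer output

        // mirrors: float adj_f = 0.0f; adj_activation(tmp, adj_f, adj_out[...]);
        float adj_f = 0.0f;
        adj_activation(tmp, adj_f, adj_out);

        std::printf("adj_f = %.1f\n", adj_f); // 2.0, since tmp > 0
        return 0;
    }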