warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/marching.cu
CHANGED
@@ -1,8 +1,6 @@
 #include "warp.h"
 #include "cuda_util.h"
-
-#include "thrust/device_ptr.h"
-#include "thrust/sort.h"
+#include "scan.h"
 
 namespace wp {
 
@@ -162,13 +160,17 @@ namespace wp {
 };
 
 
-
-
-
-
+// ---------------------------------------------------------------------------------------
+struct MarchingCubes
+{
+    MarchingCubes()
     {
-
-
+        memset(this, 0, sizeof(MarchingCubes));
+        first_cell_vert = nullptr;
+        first_cell_tri = nullptr;
+        cell_verts = nullptr;
+        context = nullptr;
+    }
 
     __device__ __host__ int cell_index(int xi, int yi, int zi) const
     {
@@ -181,169 +183,169 @@ namespace wp {
         xi = cell_index / ny;
     }
 
-
+    // grid
     int nx;
     int ny;
     int nz;
 
-
-
-
+    int* first_cell_vert;
+    int* first_cell_tri;
+    int* cell_verts;
 
     int num_cells;
     int max_cells;
 
     void* context;
[62 removed lines (old 197-258); their content is not rendered in this view]
+};
+
+
+// -----------------------------------------------------------------------------------
+__global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+
+    mc.first_cell_vert[cell_index] = 0;
+    if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
+        return;
+
+    float d0 = density[cell_index];
+    float dx = density[mc.cell_index(xi + 1, yi, zi)];
+    float dy = density[mc.cell_index(xi, yi + 1, zi)];
+    float dz = density[mc.cell_index(xi, yi, zi + 1)];
+
+    int num = 0;
+    if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
+        num++;
+    if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
+        num++;
+    if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
+        num++;
+
+    mc.first_cell_vert[cell_index] = num;
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+    if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
+        return;
+
+    vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
+
+    float d0 = density[cell_index];
+    float ds[3];
+    ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
+    ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
+    ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
+
+    // vec3 n0 = densityNormal[cell_index];
+    // vec3 ns[3];
+    // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
+    // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
+    // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
+
+    int first = mc.first_cell_vert[cell_index];
+
+    for (int dim = 0; dim < 3; dim++)
     {
-
-
+        float d = ds[dim];
+        mc.cell_verts[3 * cell_index + dim] = 0;
 
-
+        if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
         {
[56 removed lines (old 265-320); their content is not rendered in this view]
+            float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
+            int id = first++;
+
+            vec3 off;
+            off[dim] = t;
+            vertices[id] = p + off;
+
+            // vec3 n = normalize(n0 + t * (ns[dim] - n0));
+            // normals[id] = -n;
+
+            mc.cell_verts[3 * cell_index + dim] = id;
+        }
+    }
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+
+    mc.first_cell_tri[cell_index] = 0;
+    if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
+        return;
+
+    int code = 0;
+    for (int i = 0; i < 8; i++) {
+        int cxi = xi + marchingCubeCorners[i][0];
+        int cyi = yi + marchingCubeCorners[i][1];
+        int czi = zi + marchingCubeCorners[i][2];
+
+        if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
+            code |= (1 << i);
+    }
+
+    mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
+}
+
+// -----------------------------------------------------------------------------------
+__global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
+{
+    int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (cell_index >= mc.num_cells)
+        return;
+
+    int xi, yi, zi;
+    mc.cell_coord(cell_index, xi, yi, zi);
+    if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
+        return;
+
+    int code = 0;
+    for (int i = 0; i < 8; i++)
     {
-
-
-
+        int cxi = xi + marchingCubeCorners[i][0];
+        int cyi = yi + marchingCubeCorners[i][1];
+        int czi = zi + marchingCubeCorners[i][2];
 
-
-
-
+        if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
+            code |= (1 << i);
+    }
 
-
-
-
+    int firstIn = firstMarchingCubesId[code];
+    int num = firstMarchingCubesId[code + 1] - firstIn;
+    int firstOut = mc.first_cell_tri[cell_index];
 
-
+    for (int i = 0; i < num; i++)
     {
-
+        int eid = marchingCubesIds[firstIn + i];
 
-
-
-
-
+        int exi = xi + marchingCubesEdgeLocations[eid][0];
+        int eyi = yi + marchingCubesEdgeLocations[eid][1];
+        int ezi = zi + marchingCubesEdgeLocations[eid][2];
+        int edgeNr = marchingCubesEdgeLocations[eid][3];
 
-
-
-
-
+        int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
+        triangles[firstOut + i] = id;
+    }
+}
 
 // -------------------------
 void marching_cubes_resize(MarchingCubes& mc, int nx, int ny, int nz)
@@ -444,10 +446,7 @@ WP_API int marching_cubes_surface_device(
     int num_last;
     memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
 
-
-        thrust::device_ptr<int>(mc.first_cell_vert),
-        thrust::device_ptr<int>(mc.first_cell_vert + mc.num_cells),
-        thrust::device_ptr<int>(mc.first_cell_vert));
+    scan_device(mc.first_cell_vert, mc.first_cell_vert, mc.num_cells, false);
 
     int num_verts;
     memcpy_d2h(WP_CURRENT_CONTEXT, &num_verts, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
@@ -472,10 +471,7 @@ WP_API int marching_cubes_surface_device(
 
     memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_tri[mc.num_cells - 1], sizeof(int));
 
-
-        thrust::device_ptr<int>(mc.first_cell_tri),
-        thrust::device_ptr<int>(mc.first_cell_tri + mc.num_cells),
-        thrust::device_ptr<int>(mc.first_cell_tri));
+    scan_device(mc.first_cell_tri, mc.first_cell_tri, mc.num_cells, false);
 
 
    int num_indices;
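Editor's note: the surfacing path in 0.11.0 drops the Thrust dependency and follows a two-pass allocation pattern. count_cell_verts / count_cell_tris write a per-cell output count into first_cell_vert / first_cell_tri, warp's own scan_device call (replacing the removed Thrust scan over device_ptr ranges) turns those counts into output offsets, and create_cell_verts / create_cell_tris then write each cell's vertices and triangle indices starting at its offset. The sketch below is a minimal CPU illustration of that count → scan → fill idea only; it is not Warp's actual host code, and the comment about scan_device's final flag is an assumption.

// Minimal CPU illustration (not Warp's host code) of the count -> exclusive
// scan -> fill pattern used by the new marching-cubes kernels.
#include <cstdio>
#include <vector>

// Exclusive prefix sum over `counts`, writing per-item offsets and returning
// the total -- the role scan_device plays on the GPU (its trailing `false`
// argument presumably selects the exclusive variant).
static int exclusive_scan(const std::vector<int>& counts, std::vector<int>& offsets)
{
    offsets.resize(counts.size());
    int sum = 0;
    for (size_t i = 0; i < counts.size(); ++i)
    {
        offsets[i] = sum;
        sum += counts[i];
    }
    return sum;
}

int main()
{
    // pass 1: per-cell output counts (what count_cell_verts produces)
    std::vector<int> counts = {2, 0, 3, 1};

    // scan: counts -> starting offsets {0, 2, 2, 5}, total = 6
    std::vector<int> offsets;
    int total = exclusive_scan(counts, offsets);

    // pass 2: each cell writes its items at its own offset
    // (what create_cell_verts does; here each slot is just tagged with its cell id)
    std::vector<int> out(total, -1);
    for (size_t i = 0; i < counts.size(); ++i)
        for (int k = 0; k < counts[i]; ++k)
            out[offsets[i] + k] = static_cast<int>(i);

    for (int v : out)
        std::printf("%d ", v); // prints: 0 0 2 2 2 3
    std::printf("\n");
    return 0;
}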
warp/native/mat.h
CHANGED
@@ -21,7 +21,9 @@ struct quat_t;
 template<unsigned Rows, unsigned Cols, typename Type>
 struct mat_t
 {
-    inline mat_t()
+    inline CUDA_CALLABLE mat_t()
+        : data()
+    {}
 
     inline CUDA_CALLABLE mat_t(Type s)
     {
@@ -30,6 +32,14 @@ struct mat_t
             data[i][j] = s;
     }
 
+    template <typename OtherType>
+    inline explicit CUDA_CALLABLE mat_t(const mat_t<Rows, Cols, OtherType>& other)
+    {
+        for (unsigned i=0; i < Rows; ++i)
+            for (unsigned j=0; j < Cols; ++j)
+                data[i][j] = other.data[i][j];
+    }
+
     inline CUDA_CALLABLE mat_t(vec_t<2,Type> c0, vec_t<2,Type> c1)
     {
         data[0][0] = c0[0];
@@ -185,7 +195,7 @@ struct mat_t
     }
 
     // row major storage assumed to be compatible with PyTorch
-    Type data[Rows][Cols]
+    Type data[Rows][Cols];
 };
 
 
@@ -290,7 +300,19 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> atomic_max(mat_t<Rows,Cols,Type> * ad
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE void adj_atomic_minmax(
+    mat_t<Rows,Cols,Type> *addr,
+    mat_t<Rows,Cols,Type> *adj_addr,
+    const mat_t<Rows,Cols,Type> &value,
+    mat_t<Rows,Cols,Type> &adj_value)
+{
+    for (unsigned i=0; i < Rows; ++i)
+        for (unsigned j=0; j < Cols; ++j)
+            adj_atomic_minmax(&addr->data[i][j], &adj_addr->data[i][j], value.data[i][j], adj_value.data[i][j]);
+}
+
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE vec_t<Cols,Type> extract(const mat_t<Rows,Cols,Type>& m, int row)
 {
     vec_t<Cols,Type> ret;
     for(unsigned i=0; i < Cols; ++i)
@@ -301,7 +323,7 @@ inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE Type
+inline CUDA_CALLABLE Type extract(const mat_t<Rows,Cols,Type>& m, int row, int col)
 {
 #ifndef NDEBUG
     if (row < 0 || row >= Rows)
@@ -319,7 +341,7 @@ inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE vec_t<Cols, Type>* index(mat_t<Rows,Cols,Type>& m, int row)
 {
 #ifndef NDEBUG
     if (row < 0 || row >= Rows)
@@ -329,12 +351,11 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols
     }
 #endif
 
-
-        m.data[row][i] = value[i];
+    return reinterpret_cast<vec_t<Cols, Type>*>(&m.data[row]);
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE Type* index(mat_t<Rows,Cols,Type>& m, int row, int col)
 {
 #ifndef NDEBUG
     if (row < 0 || row >= Rows)
@@ -348,18 +369,19 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, T
         assert(0);
     }
 #endif
-
+
+    return &m.data[row][col];
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row,
     const mat_t<Rows,Cols,Type>& adj_m, int adj_row, const vec_t<Cols, Type>& adj_value)
 {
     // nop
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col,
     const mat_t<Rows,Cols,Type>& adj_m, int adj_row, int adj_col, Type adj_value)
 {
     // nop
@@ -417,7 +439,22 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(const mat_t<Rows,Cols,Type>& a, T
         }
     }
 
-    return t;
+    return t;
+}
+
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(Type b, const mat_t<Rows,Cols,Type>& a)
+{
+    mat_t<Rows,Cols,Type> t;
+    for (unsigned i=0; i < Rows; ++i)
+    {
+        for (unsigned j=0; j < Cols; ++j)
+        {
+            t.data[i][j] = b / a.data[i][j];
+        }
+    }
+
+    return t;
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
@@ -432,7 +469,7 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> mul(const mat_t<Rows,Cols,Type>& a, T
         }
     }
 
-    return t;
+    return t;
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
@@ -465,6 +502,17 @@ inline CUDA_CALLABLE vec_t<Rows,Type> mul(const mat_t<Rows,Cols,Type>& a, const
     return r;
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE vec_t<Cols,Type> mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a)
+{
+    vec_t<Cols,Type> r = a.get_row(0)*b[0];
+    for( unsigned i=1; i < Rows; ++i )
+    {
+        r += a.get_row(i)*b[i];
+    }
+    return r;
+}
+
 template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
 inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b)
 {
@@ -608,6 +656,17 @@ inline CUDA_CALLABLE Type trace(const mat_t<Rows,Rows,Type>& m)
     return ret;
 }
 
+template<unsigned Rows, typename Type>
+inline CUDA_CALLABLE vec_t<Rows, Type> get_diag(const mat_t<Rows,Rows,Type>& m)
+{
+    vec_t<Rows, Type> ret;
+    for( unsigned i=0; i < Rows; ++i )
+    {
+        ret[i] = m.data[i][i];
+    }
+    return ret;
+}
+
 // Only implementing inverses for 2x2, 3x3 and 4x4 matrices for now...
 template<typename Type>
 inline CUDA_CALLABLE mat_t<2,2,Type> inverse(const mat_t<2,2,Type>& m)
@@ -842,14 +901,14 @@ inline CUDA_CALLABLE vec_t<3,Type> transform_vector(const mat_t<4,4,Type>& m, co
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_extract(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
 {
     for( unsigned col=0; col < Cols; ++col )
         adj_m.data[row][col] += adj_ret[col];
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline void CUDA_CALLABLE
+inline void CUDA_CALLABLE adj_extract(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
 {
 #ifndef NDEBUG
     if (row < 0 || row > Rows)
@@ -913,6 +972,20 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
     }
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
+{
+    adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
+
+    for (unsigned i=0; i < Rows; ++i)
+    {
+        for (unsigned j=0; j < Cols; ++j)
+        {
+            adj_a.data[i][j] += s / adj_ret.data[i][j];
+        }
+    }
+}
+
 template<unsigned Rows, unsigned Cols, typename Type>
 inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, Type b, mat_t<Rows,Cols,Type>& adj_a, Type& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
 {
@@ -946,6 +1019,13 @@ inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const vec_t<Co
     adj_b += mul(transpose(a), adj_ret);
 }
 
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a, vec_t<Rows,Type>& adj_b, mat_t<Rows,Cols,Type>& adj_a, const vec_t<Cols,Type>& adj_ret)
+{
+    adj_a += outer(b, adj_ret);
+    adj_b += mul(adj_ret, transpose(a));
+}
+
 template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
 inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Cols,ColsOut,Type>& adj_b, const mat_t<Rows,ColsOut,Type>& adj_ret)
 {
@@ -973,6 +1053,13 @@ inline CUDA_CALLABLE void adj_diag(const vec_t<Rows,Type>& d, vec_t<Rows,Type>&
         adj_d[i] += adj_ret.data[i][i];
 }
 
+template<unsigned Rows, typename Type>
+inline CUDA_CALLABLE void adj_get_diag(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const vec_t<Rows,Type>& adj_ret)
+{
+    for (unsigned i=0; i < Rows; ++i)
+        adj_m.data[i][i] += adj_ret[i];
+}
+
 template<typename Type>
 inline CUDA_CALLABLE void adj_determinant(const mat_t<2,2,Type>& m, mat_t<2,2,Type>& adj_m, Type adj_ret)
 {
@@ -1079,10 +1166,10 @@ inline CUDA_CALLABLE void adj_determinant(const mat_t<4,4,Type>& m, mat_t<4,4,Ty
 }
 
 template<unsigned Rows, typename Type>
-inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
+inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& ret, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
 {
     // todo: how to cache this from the forward pass?
-    mat_t<Rows,Rows,Type> invt = transpose(
+    mat_t<Rows,Rows,Type> invt = transpose(ret);
 
     // see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf 2.2.3
     adj_m -= mul(mul(invt, adj_ret), invt);
@@ -1124,10 +1211,10 @@ inline CUDA_CALLABLE void adj_cw_mul(const mat_t<Rows,Cols,Type>& a, const mat_t
 }
 
 template<unsigned Rows, unsigned Cols, typename Type>
-inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
+inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& ret, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
 {
     adj_a += cw_div(adj_ret, b);
-    adj_b -= cw_mul(adj_ret, cw_div(
+    adj_b -= cw_mul(adj_ret, cw_div(ret, b));
 }
 
 // adjoint for the constant constructor:
@@ -1143,6 +1230,19 @@ inline CUDA_CALLABLE void adj_mat_t(Type s, Type& adj_s, const mat_t<Rows, Cols,
     }
 }
 
+// adjoint for the casting constructor:
+template<unsigned Rows, unsigned Cols, typename Type, typename OtherType>
+inline CUDA_CALLABLE void adj_mat_t(const mat_t<Rows, Cols, OtherType>& other, mat_t<Rows, Cols, OtherType>& adj_other, const mat_t<Rows, Cols, Type>& adj_ret)
+{
+    for (unsigned i=0; i < Rows; ++i)
+    {
+        for (unsigned j=0; j < Cols; ++j)
+        {
+            adj_other.data[i][j] += adj_ret.data[i][j];
+        }
+    }
+}
+
 // adjoint for the initializer_array scalar constructor:
 template<unsigned Rows, unsigned Cols, typename Type>
 inline CUDA_CALLABLE void adj_mat_t(const initializer_array<Rows * Cols, Type> &cmps, const initializer_array<Rows * Cols, Type*> &adj_cmps, const mat_t<Rows, Cols, Type>& adj_ret)
warp/native/matnn.h
CHANGED
@@ -248,7 +248,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
             tmp += weights.data[i*n + j]*x.data[index + b*j];
         }
 
-        // adjoint w.r.t to
+        // adjoint w.r.t to activation
         float adj_f = 0.0f;
 
         if (adj_out.data)
@@ -313,7 +313,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
     //     tmp += weights[i*n + j]*x[index + b*j];
     // }
 
-    // // adjoint w.r.t to
+    // // adjoint w.r.t to activation
     // float adj_f = 0.0f;
     // adj_activation(tmp, adj_f, adj_out[index + b*i]);
 
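Editor's note: the corrected comments name the quantity being accumulated. adj_f is the adjoint with respect to the activation input: for $y_i = f(z_i)$ with $z_i = \sum_j W_{ij} x_j$ (batch indexing omitted), the chain-rule step is $\bar{z}_i = f'(z_i)\,\bar{y}_i$, which is what the adj_activation(tmp, adj_f, adj_out[index + b*i]) call in the surrounding code computes.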