warp-lang 1.5.1-py3-none-manylinux2014_aarch64.whl → 1.6.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1076 -480
- warp/codegen.py +240 -119
- warp/config.py +1 -1
- warp/context.py +298 -84
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1185 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +124 -59
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +111 -53
- warp/stubs.py +248 -115
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_examples.py +9 -0
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +2 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +409 -99
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -282,9 +282,9 @@ class StructInstance:
             else:
                 # wp.array
                 assert isinstance(value, array)
-                assert types_equal(
-                    value.dtype, var.type.dtype
-                ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                assert types_equal(value.dtype, var.type.dtype), (
+                    f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                )
                 setattr(self._ctype, name, value.__ctype__())

         elif isinstance(var.type, Struct):
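The reworked assert above turns a dtype mismatch on struct array members into a readable error. A minimal sketch of the case it guards (the Params struct is illustrative, not from this diff):

    import warp as wp

    @wp.struct
    class Params:
        values: wp.array(dtype=float)

    p = Params()
    p.values = wp.array([1.0, 2.0], dtype=float)  # OK: dtype matches the declaration
    # p.values = wp.array([1, 2], dtype=int)      # would now fail with the formatted
    #                                             # "assign to struct member variable ..." message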
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params):
                 return "bool"
             else:
                 return f"wp::{p.__name__}"
+        elif is_tile(p):
+            return p.ctype()
+
         return p.__name__

     return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -947,7 +950,7 @@ class Adjoint:
         total_shared = 0

         for var in adj.variables:
-            if is_tile(var.type) and var.type.storage == "shared":
+            if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
                 total_shared += var.type.size_in_bytes()

         return total_shared + adj.max_required_extra_shared_memory
@@ -1139,6 +1142,9 @@ class Adjoint:
         if isinstance(var, (Reference, warp.context.Function)):
             return var

+        if isinstance(var, int):
+            return adj.add_constant(var)
+
         if var.label is None:
             return adj.add_var(var.type, var.constant)

@@ -1349,8 +1355,9 @@ class Adjoint:
         # which allows for some more advanced resolution to be performed,
         # for example by checking whether an argument corresponds to
         # a literal value or references a variable.
+        extra_shared_memory = 0
         if func.lto_dispatch_func is not None:
-            func_args, template_args, ltoirs = func.lto_dispatch_func(
+            func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
                 func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
             )
         elif func.dispatch_func is not None:
@@ -1424,7 +1431,9 @@ class Adjoint:
         # update our smem roofline requirements based on any
         # shared memory required by the dependent function call
         if not func.is_builtin():
-            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+            adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
+        else:
+            adj.alloc_shared_extra(extra_shared_memory)

         return output

@@ -1527,7 +1536,8 @@ class Adjoint:
         # zero adjoints
         for i in body_block.vars:
             if is_tile(i.type):
-                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+                if i.type.owner:
+                    reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
             else:
                 reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

@@ -1857,6 +1867,17 @@ class Adjoint:
             # stubbed @wp.native_func
             return

+    def emit_Assert(adj, node):
+        # eval condition
+        cond = adj.eval(node.test)
+        cond = adj.load(cond)
+
+        source_segment = ast.get_source_segment(adj.source, node)
+        # If a message was provided with the assert, " marks can interfere with the generated code
+        escaped_segment = source_segment.replace('"', '\\"')
+
+        adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
+
     def emit_NameConstant(adj, node):
         if node.value:
             return adj.add_constant(node.value)
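emit_Assert lowers a Python assert in kernel code to a native assert() whose message embeds the escaped source segment, so failures report the original expression. A minimal sketch of a kernel exercising this path (illustrative; native asserts typically fire only in debug builds of the runtime):

    import warp as wp

    @wp.kernel
    def checked_sqrt(x: wp.array(dtype=float), out: wp.array(dtype=float)):
        i = wp.tid()
        # lowered to: assert(("assert x[i] >= 0.0, \"negative input\"", var_cond));
        assert x[i] >= 0.0, "negative input"
        out[i] = wp.sqrt(x[i])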
@@ -1900,12 +1921,25 @@ class Adjoint:

         name = builtin_operators[type(node.op)]

+        try:
+            # Check if there is any user-defined overload for this operator
+            user_func = adj.resolve_external_reference(name)
+            if isinstance(user_func, warp.context.Function):
+                return adj.add_call(user_func, (left, right), {}, {})
+        except WarpCodegenError:
+            pass
+
         return adj.add_builtin_call(name, [left, right])

     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)

+        # evaluate expression to a compile-time constant if arg is a constant
+        if arg.constant is not None and math.isfinite(arg.constant):
+            if isinstance(node.op, ast.USub):
+                return adj.add_constant(-arg.constant)
+
         name = builtin_operators[type(node.op)]

         return adj.add_builtin_call(name, [arg])
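Two behavioral changes land in this hunk: binary operators now check for a user-defined overload before falling back to the builtin (the operator maps to a function name via builtin_operators, e.g. ast.Add -> "add"), and unary minus on a finite constant folds at codegen time. A sketch of the overload path, assuming the overload is declared as a plain wp.func whose name matches the operator, as the resolution by name above suggests (the Complex struct is illustrative):

    import warp as wp

    @wp.struct
    class Complex:
        r: float
        i: float

    @wp.func
    def add(a: Complex, b: Complex) -> Complex:
        c = Complex()
        c.r = a.r + b.r
        c.i = a.i + b.i
        return c

    @wp.kernel
    def combine(a: Complex, b: Complex, out: wp.array(dtype=Complex)):
        # resolve_external_reference("add") finds the wp.func above,
        # so `a + b` dispatches to it instead of a builtin
        out[0] = a + b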
@@ -2350,12 +2384,16 @@ class Adjoint:
             out.is_write = target.is_write

         elif is_tile(target_type):
-            if len(indices) == 2:
+            if len(indices) == len(target_type.shape):
                 # handles extracting a single element from a tile
                 out = adj.add_builtin_call("tile_extract", [target, *indices])
-            elif len(indices) == 1:
+            elif len(indices) < len(target_type.shape):
                 # handles tile views
-                out = adj.add_builtin_call("tile_view", [target, *indices])
+                out = adj.add_builtin_call("tile_view", [target, indices])
+            else:
+                raise RuntimeError(
+                    f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
+                )

         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
@@ -2447,6 +2485,9 @@ class Adjoint:

             target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

+        elif is_tile(target_type):
+            adj.add_builtin_call("assign", [target, *indices, rhs])
+
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             # recursively unwind AST, stopping at penultimate node
             node = lhs
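Taken together, the two tile hunks above give tiles NumPy-like subscripting: full-rank indices extract an element, fewer indices produce a view, and subscript writes go through the assign builtin. A sketch (the tile_load/tile_store shape arguments are assumptions about the 1.6 tile API, not taken verbatim from this diff):

    import warp as wp

    @wp.kernel
    def tile_index_demo(a: wp.array2d(dtype=float)):
        t = wp.tile_load(a, shape=(4, 8))  # 2D tile
        x = t[1, 2]   # indices == tile rank -> "tile_extract" (single element)
        row = t[1]    # indices <  tile rank -> "tile_view" (sub-tile)
        t[0, 0] = x   # subscript write      -> "assign" builtin
        # t[0, 1, 2] would raise the "Incorrect number of indices ..." RuntimeError
        wp.tile_store(a, t)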
@@ -2473,15 +2514,18 @@ class Adjoint:
                 print(
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                 )
-
         else:
-            out = adj.add_builtin_call("assign", [target, *indices, rhs])
-
-            # re-point target symbol to out var
-            for id in adj.symbols:
-                if adj.symbols[id] == target:
-                    adj.symbols[id] = out
-                    break
+            if adj.builder_options.get("enable_backward", True):
+                out = adj.add_builtin_call("assign", [target, *indices, rhs])
+
+                # re-point target symbol to out var
+                for id in adj.symbols:
+                    if adj.symbols[id] == target:
+                        adj.symbols[id] = out
+                        break
+            else:
+                attr = adj.add_builtin_call("index", [target, *indices])
+                adj.add_builtin_call("store", [attr, rhs])

         else:
             raise WarpCodegenError(
@@ -2518,22 +2562,23 @@ class Adjoint:

         # assigning to a vector or quaternion component
         if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
-            # TODO: handle wp.adjoint case
-
             index = adj.vector_component_index(lhs.attr, aggregate_type)

-            # TODO: array vec component case
             if is_reference(aggregate.type):
                 attr = adj.add_builtin_call("indexref", [aggregate, index])
                 adj.add_builtin_call("store", [attr, rhs])
             else:
-                out = adj.add_builtin_call("assign", [aggregate, index, rhs])
-
-                # re-point target symbol to out var
-                for id in adj.symbols:
-                    if adj.symbols[id] == aggregate:
-                        adj.symbols[id] = out
-                        break
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [aggregate, index, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == aggregate:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+                    adj.add_builtin_call("store", [attr, rhs])

         else:
             attr = adj.emit_Attribute(lhs)
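Both component-write paths above (subscript form and attribute form) now branch on enable_backward: when differentiable (the default) the write lowers to the assign builtin, which yields a fresh output variable that the symbol table is re-pointed to; with backward disabled it emits a cheaper index/store pair. In kernel terms, a small sketch:

    import warp as wp

    @wp.kernel
    def set_components(out: wp.array(dtype=wp.vec3)):
        v = wp.vec3(1.0, 2.0, 3.0)
        v[0] = 10.0  # subscript form: becomes v = assign(v, 0, 10.0) when differentiable
        v.y = 20.0   # attribute form: same rewrite via vector_component_index
        out[0] = v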
@@ -2637,10 +2682,14 @@ class Adjoint:
             make_new_assign_statement()
             return

-        # TODO
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
-            make_new_assign_statement()
-            return
+            if isinstance(node.op, ast.Add):
+                adj.add_builtin_call("augassign_add", [target, *indices, rhs])
+            elif isinstance(node.op, ast.Sub):
+                adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
+            else:
+                make_new_assign_statement()
+            return

         else:
             raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
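In-place += and -= on vector, quaternion, and matrix components now map to dedicated augassign_add/augassign_sub builtins; any other in-place operator still falls back to make_new_assign_statement(), i.e. an ordinary read-modify-write. A sketch:

    import warp as wp

    @wp.kernel
    def accumulate(out: wp.array(dtype=wp.vec3)):
        v = wp.vec3()
        v[0] += 1.0  # augassign_add
        v[1] -= 2.0  # augassign_sub
        v[2] *= 3.0  # falls back to v[2] = v[2] * 3.0
        out[0] = v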
@@ -2688,6 +2737,7 @@ class Adjoint:
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
         ast.Ellipsis: emit_Ellipsis,
+        ast.Assert: emit_Assert,
     }

     def eval(adj, node):
@@ -2850,11 +2900,62 @@ class Adjoint:
         if static_code is None:
             raise WarpCodegenError("Error extracting source code from wp.static() expression")

+        # Since this is an expression, we can enforce it to be defined on a single line.
+        static_code = static_code.replace("\n", "")
+
         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
         constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
         vars_dict.update(constant_vars)

+        # Replace all constant `len()` expressions with their value.
+        if "len" in static_code:
+
+            def eval_len(obj):
+                if type_is_vector(obj):
+                    return obj._length_
+                elif type_is_quaternion(obj):
+                    return obj._length_
+                elif type_is_matrix(obj):
+                    return obj._shape_[0]
+                elif type_is_transformation(obj):
+                    return obj._length_
+                elif is_tile(obj):
+                    return obj.shape[0]
+
+                return len(obj)
+
+            len_expr_ctx = vars_dict.copy()
+            constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
+            len_expr_ctx.update(constant_types)
+            len_expr_ctx.update({"len": eval_len})
+
+            # We want to replace the expression code in-place,
+            # so reparse it to get the correct column info.
+            len_value_locs = []
+            expr_tree = ast.parse(static_code)
+            assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
+            expr_root = expr_tree.body[0].value
+            for expr_node in ast.walk(expr_root):
+                if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
+                    len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
+                    try:
+                        len_value = eval(len_expr, len_expr_ctx)
+                    except Exception:
+                        pass
+                    else:
+                        len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
+
+            if len_value_locs:
+                new_static_code = ""
+                loc = 0
+                for value, start, end in len_value_locs:
+                    new_static_code += f"{static_code[loc:start]}{value}"
+                    loc = end
+
+                new_static_code += static_code[len_value_locs[-1][2] :]
+                static_code = new_static_code
+
         try:
             value = eval(static_code, vars_dict)
             if warp.config.verbose:
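The pre-pass above textually substitutes any constant-evaluable len(...) call inside a wp.static() expression with its literal value, using only the types of kernel symbols (vector/quaternion/matrix/transformation lengths, tile leading dimensions, or plain len() otherwise). A sketch of what this enables (names illustrative):

    import warp as wp

    @wp.kernel
    def sum_components(q: wp.array(dtype=wp.quat), out: wp.array(dtype=float)):
        r = q[0]
        s = float(0.0)
        # "len(r)" is rewritten to "4" before static evaluation, since only
        # r's type (wp.quat) is known at codegen time, not its value
        for i in range(wp.static(len(r))):
            s += r[i]
        out[0] = s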
@@ -3139,7 +3240,7 @@ static CUDA_CALLABLE void adj_{name}(


 """

-cuda_kernel_template = """
+cuda_kernel_template_forward = """

 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
@@ -3154,6 +3255,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_body} }}
 }}

+"""
+
+cuda_kernel_template_backward = """
+
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
@@ -3169,13 +3274,17 @@ extern "C" __global__ void {name}_cuda_kernel_backward(


 """

-cpu_kernel_template = """
+cpu_kernel_template_forward = """

 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
     {forward_body}}}

+"""
+
+cpu_kernel_template_backward = """
+
 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
@@ -3183,7 +3292,7 @@ void {name}_cpu_kernel_backward(


 """

-cpu_module_template = """
+cpu_module_template_forward = """

 extern "C" {{

@@ -3198,6 +3307,14 @@ WP_API void {name}_cpu_forward(
 }}
 }}

+}} // extern C
+
+"""
+
+cpu_module_template_backward = """
+
+extern "C" {{
+
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
@@ -3212,36 +3329,6 @@ WP_API void {name}_cpu_backward(


 """

-cuda_module_header_template = """
-
-extern "C" {{
-
-// Python CUDA entry points
-WP_API void {name}_cuda_forward(
-    void* stream,
-    {forward_args});
-
-WP_API void {name}_cuda_backward(
-    void* stream,
-    {reverse_args});
-
-}} // extern C
-"""
-
-cpu_module_header_template = """
-
-extern "C" {{
-
-// Python CPU entry points
-WP_API void {name}_cpu_forward(
-    {forward_args});
-
-WP_API void {name}_cpu_backward(
-    {reverse_args});
-
-}} // extern C
-"""
-

 # converts a constant Python value to equivalent C-repr
 def constant_str(value):
@@ -3679,59 +3766,82 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

-    forward_args = ["wp::launch_bounds_t dim"]
-    reverse_args = ["wp::launch_bounds_t dim"]
+    if device == "cpu":
+        template_forward = cpu_kernel_template_forward
+        template_backward = cpu_kernel_template_backward
+    elif device == "cuda":
+        template_forward = cuda_kernel_template_forward
+        template_backward = cuda_kernel_template_backward
+    else:
+        raise ValueError(f"Device {device} is not supported")
+
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }

+    # build forward signature
+    forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-        reverse_args.append("size_t task_index")

-    # forward args
     for arg in adj.args:
         forward_args.append(arg.ctype() + " var_" + arg.label)
-        reverse_args.append(arg.ctype() + " var_" + arg.label)
-
-    # reverse args
-    for arg in adj.args:
-        # indexed array gradients are regular arrays
-        if isinstance(arg.type, indexedarray):
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-        else:
-            reverse_args.append(arg.ctype() + " adj_" + arg.label)

-    # codegen body
     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_body": forward_body,
+        }
+    )
+    template += template_forward

     if options["enable_backward"]:
-        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
-    else:
-        reverse_body = ""
+        # build reverse signature
+        reverse_args = ["wp::launch_bounds_t dim"]
+        if device == "cpu":
+            reverse_args.append("size_t task_index")

-    if device == "cpu":
-        template = cpu_kernel_template
-    elif device == "cuda":
-        template = cuda_kernel_template
-    else:
-        raise ValueError(f"Device {device} is not supported")
+        for arg in adj.args:
+            reverse_args.append(arg.ctype() + " var_" + arg.label)

-    s = template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_body=forward_body,
-        reverse_body=reverse_body,
-    )
+        for arg in adj.args:
+            # indexed array gradients are regular arrays
+            if isinstance(arg.type, indexedarray):
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+            else:
+                reverse_args.append(arg.ctype() + " adj_" + arg.label)

+        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_body": reverse_body,
+            }
+        )
+        template += template_backward
+
+    s = template.format(**template_fmt_args)
     return s


-def codegen_module(kernel, device="cpu"):
+def codegen_module(kernel, device, options):
     if device != "cpu":
         return ""

+    # Update the module's options with the ones defined on the kernel, if any.
+    options = dict(options)
+    options.update(kernel.options)
+
     adj = kernel.adj

+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
+
     # build forward signature
     forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim", "task_index"]
@@ -3745,29 +3855,40 @@ def codegen_module(kernel, device="cpu"):
         forward_args.append(f"{arg.ctype()} var_{arg.label}")
         forward_params.append("var_" + arg.label)

-    # build reverse signature
-    reverse_args = [*forward_args]
-    reverse_params = [*forward_params]
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_params": indent(forward_params, 3),
+        }
+    )
+    template += cpu_module_template_forward

-    for arg in adj.args:
-        if isinstance(arg.type, indexedarray):
-            # indexed array gradients are regular arrays
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{_arg.label}")
-        elif hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-            reverse_params.append(f"*adj_{arg.label}")
-        else:
-            reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{arg.label}")
+    if options["enable_backward"]:
+        # build reverse signature
+        reverse_args = [*forward_args]
+        reverse_params = [*forward_params]

-    s = cpu_module_template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_params=indent(forward_params, 3),
-        reverse_params=indent(reverse_params, 3),
-    )
+        for arg in adj.args:
+            if isinstance(arg.type, indexedarray):
+                # indexed array gradients are regular arrays
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{_arg.label}")
+            elif hasattr(arg.type, "_wp_generic_type_str_"):
+                # vectors and matrices are passed from Python by pointer
+                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
+                reverse_params.append(f"*adj_{arg.label}")
+            else:
+                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{arg.label}")
+
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_params": indent(reverse_params, 3),
+            }
+        )
+        template += cpu_module_template_backward
+
+    s = template.format(**template_fmt_args)
     return s
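With the templates split, codegen_kernel and codegen_module now emit only the forward entry points when backward codegen is disabled, instead of formatting empty backward bodies. The kernel.options lookup above suggests the flag can also be set per kernel; module-wide it looks like this sketch:

    import warp as wp

    # only *_forward kernels and entry points are generated for this module
    wp.set_module_options({"enable_backward": False})

    @wp.kernel
    def scale(x: wp.array(dtype=float), s: float):
        i = wp.tid()
        x[i] = x[i] * s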