warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1077 -481
- warp/codegen.py +250 -122
- warp/config.py +65 -21
- warp/context.py +500 -149
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_marching_cubes.py +1 -1
- warp/examples/core/example_mesh.py +1 -1
- warp/examples/core/example_torch.py +18 -34
- warp/examples/core/example_wave.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +314 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +191 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +6 -2
- warp/native/crt.h +1 -0
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +57 -3
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1189 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +132 -59
- warp/render/render_usd.py +10 -2
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +289 -32
- warp/sim/import_urdf.py +20 -5
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +147 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +173 -112
- warp/sim/render.py +2 -2
- warp/stubs.py +249 -116
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +100 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +8 -8
- warp/tests/test_examples.py +16 -1
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_launch.py +77 -26
- warp/tests/test_mat.py +213 -168
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +6 -5
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +399 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +5 -2
- warp/types.py +419 -111
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -282,9 +282,9 @@ class StructInstance:
|
|
|
282
282
|
else:
|
|
283
283
|
# wp.array
|
|
284
284
|
assert isinstance(value, array)
|
|
285
|
-
assert types_equal(
|
|
286
|
-
|
|
287
|
-
)
|
|
285
|
+
assert types_equal(value.dtype, var.type.dtype), (
|
|
286
|
+
f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
|
|
287
|
+
)
|
|
288
288
|
setattr(self._ctype, name, value.__ctype__())
|
|
289
289
|
|
|
290
290
|
elif isinstance(var.type, Struct):
|
|
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params):
|
|
|
606
606
|
return "bool"
|
|
607
607
|
else:
|
|
608
608
|
return f"wp::{p.__name__}"
|
|
609
|
+
elif is_tile(p):
|
|
610
|
+
return p.ctype()
|
|
611
|
+
|
|
609
612
|
return p.__name__
|
|
610
613
|
|
|
611
614
|
return f"{base_name}<{','.join(map(param2str, template_params))}>"
|
|
@@ -947,7 +950,7 @@ class Adjoint:
|
|
|
947
950
|
total_shared = 0
|
|
948
951
|
|
|
949
952
|
for var in adj.variables:
|
|
950
|
-
if is_tile(var.type) and var.type.storage == "shared":
|
|
953
|
+
if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
|
|
951
954
|
total_shared += var.type.size_in_bytes()
|
|
952
955
|
|
|
953
956
|
return total_shared + adj.max_required_extra_shared_memory
|
|
@@ -1139,6 +1142,9 @@ class Adjoint:
|
|
|
1139
1142
|
if isinstance(var, (Reference, warp.context.Function)):
|
|
1140
1143
|
return var
|
|
1141
1144
|
|
|
1145
|
+
if isinstance(var, int):
|
|
1146
|
+
return adj.add_constant(var)
|
|
1147
|
+
|
|
1142
1148
|
if var.label is None:
|
|
1143
1149
|
return adj.add_var(var.type, var.constant)
|
|
1144
1150
|
|
|
@@ -1349,8 +1355,9 @@ class Adjoint:
|
|
|
1349
1355
|
# which allows for some more advanced resolution to be performed,
|
|
1350
1356
|
# for example by checking whether an argument corresponds to
|
|
1351
1357
|
# a literal value or references a variable.
|
|
1358
|
+
extra_shared_memory = 0
|
|
1352
1359
|
if func.lto_dispatch_func is not None:
|
|
1353
|
-
func_args, template_args, ltoirs = func.lto_dispatch_func(
|
|
1360
|
+
func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
|
|
1354
1361
|
func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
|
|
1355
1362
|
)
|
|
1356
1363
|
elif func.dispatch_func is not None:
|
|
@@ -1424,7 +1431,9 @@ class Adjoint:
|
|
|
1424
1431
|
# update our smem roofline requirements based on any
|
|
1425
1432
|
# shared memory required by the dependent function call
|
|
1426
1433
|
if not func.is_builtin():
|
|
1427
|
-
adj.alloc_shared_extra(func.adj.get_total_required_shared())
|
|
1434
|
+
adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
|
|
1435
|
+
else:
|
|
1436
|
+
adj.alloc_shared_extra(extra_shared_memory)
|
|
1428
1437
|
|
|
1429
1438
|
return output
|
|
1430
1439
|
|
|
@@ -1527,7 +1536,8 @@ class Adjoint:
|
|
|
1527
1536
|
# zero adjoints
|
|
1528
1537
|
for i in body_block.vars:
|
|
1529
1538
|
if is_tile(i.type):
|
|
1530
|
-
|
|
1539
|
+
if i.type.owner:
|
|
1540
|
+
reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
|
|
1531
1541
|
else:
|
|
1532
1542
|
reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
|
|
1533
1543
|
|
|
@@ -1857,6 +1867,17 @@ class Adjoint:
|
|
|
1857
1867
|
# stubbed @wp.native_func
|
|
1858
1868
|
return
|
|
1859
1869
|
|
|
1870
|
+
def emit_Assert(adj, node):
|
|
1871
|
+
# eval condition
|
|
1872
|
+
cond = adj.eval(node.test)
|
|
1873
|
+
cond = adj.load(cond)
|
|
1874
|
+
|
|
1875
|
+
source_segment = ast.get_source_segment(adj.source, node)
|
|
1876
|
+
# If a message was provided with the assert, " marks can interfere with the generated code
|
|
1877
|
+
escaped_segment = source_segment.replace('"', '\\"')
|
|
1878
|
+
|
|
1879
|
+
adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
|
|
1880
|
+
|
|
1860
1881
|
def emit_NameConstant(adj, node):
|
|
1861
1882
|
if node.value:
|
|
1862
1883
|
return adj.add_constant(node.value)
|
|
@@ -1900,12 +1921,25 @@ class Adjoint:
|
|
|
1900
1921
|
|
|
1901
1922
|
name = builtin_operators[type(node.op)]
|
|
1902
1923
|
|
|
1924
|
+
try:
|
|
1925
|
+
# Check if there is any user-defined overload for this operator
|
|
1926
|
+
user_func = adj.resolve_external_reference(name)
|
|
1927
|
+
if isinstance(user_func, warp.context.Function):
|
|
1928
|
+
return adj.add_call(user_func, (left, right), {}, {})
|
|
1929
|
+
except WarpCodegenError:
|
|
1930
|
+
pass
|
|
1931
|
+
|
|
1903
1932
|
return adj.add_builtin_call(name, [left, right])
|
|
1904
1933
|
|
|
1905
1934
|
def emit_UnaryOp(adj, node):
|
|
1906
1935
|
# evaluate unary op arguments
|
|
1907
1936
|
arg = adj.eval(node.operand)
|
|
1908
1937
|
|
|
1938
|
+
# evaluate expression to a compile-time constant if arg is a constant
|
|
1939
|
+
if arg.constant is not None and math.isfinite(arg.constant):
|
|
1940
|
+
if isinstance(node.op, ast.USub):
|
|
1941
|
+
return adj.add_constant(-arg.constant)
|
|
1942
|
+
|
|
1909
1943
|
name = builtin_operators[type(node.op)]
|
|
1910
1944
|
|
|
1911
1945
|
return adj.add_builtin_call(name, [arg])
|
|
@@ -2244,15 +2278,22 @@ class Adjoint:
|
|
|
2244
2278
|
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2245
2279
|
|
|
2246
2280
|
if warp.config.verify_autograd_array_access:
|
|
2281
|
+
# Extract the types and values passed as arguments to the function call.
|
|
2282
|
+
arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
|
|
2283
|
+
kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
|
|
2284
|
+
|
|
2285
|
+
# Resolve the exact function signature among any existing overload.
|
|
2286
|
+
resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
|
|
2287
|
+
|
|
2247
2288
|
# update arg read/write states according to what happens to that arg in the called function
|
|
2248
|
-
if hasattr(
|
|
2289
|
+
if hasattr(resolved_func, "adj"):
|
|
2249
2290
|
for i, arg in enumerate(args):
|
|
2250
|
-
if
|
|
2291
|
+
if resolved_func.adj.args[i].is_write:
|
|
2251
2292
|
kernel_name = adj.fun_name
|
|
2252
2293
|
filename = adj.filename
|
|
2253
2294
|
lineno = adj.lineno + adj.fun_lineno
|
|
2254
2295
|
arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2255
|
-
if
|
|
2296
|
+
if resolved_func.adj.args[i].is_read:
|
|
2256
2297
|
arg.mark_read()
|
|
2257
2298
|
|
|
2258
2299
|
return out
|
|
@@ -2350,12 +2391,16 @@ class Adjoint:
|
|
|
2350
2391
|
out.is_write = target.is_write
|
|
2351
2392
|
|
|
2352
2393
|
elif is_tile(target_type):
|
|
2353
|
-
if len(indices) ==
|
|
2394
|
+
if len(indices) == len(target_type.shape):
|
|
2354
2395
|
# handles extracting a single element from a tile
|
|
2355
2396
|
out = adj.add_builtin_call("tile_extract", [target, *indices])
|
|
2356
|
-
|
|
2397
|
+
elif len(indices) < len(target_type.shape):
|
|
2357
2398
|
# handles tile views
|
|
2358
|
-
out = adj.add_builtin_call("tile_view", [target,
|
|
2399
|
+
out = adj.add_builtin_call("tile_view", [target, indices])
|
|
2400
|
+
else:
|
|
2401
|
+
raise RuntimeError(
|
|
2402
|
+
f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
|
|
2403
|
+
)
|
|
2359
2404
|
|
|
2360
2405
|
else:
|
|
2361
2406
|
# handles non-array type indexing, e.g: vec3, mat33, etc
|
|
@@ -2447,6 +2492,9 @@ class Adjoint:
|
|
|
2447
2492
|
|
|
2448
2493
|
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2449
2494
|
|
|
2495
|
+
elif is_tile(target_type):
|
|
2496
|
+
adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2497
|
+
|
|
2450
2498
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2451
2499
|
# recursively unwind AST, stopping at penultimate node
|
|
2452
2500
|
node = lhs
|
|
@@ -2473,15 +2521,18 @@ class Adjoint:
|
|
|
2473
2521
|
print(
|
|
2474
2522
|
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2475
2523
|
)
|
|
2476
|
-
|
|
2477
2524
|
else:
|
|
2478
|
-
|
|
2479
|
-
|
|
2480
|
-
|
|
2481
|
-
|
|
2482
|
-
|
|
2483
|
-
adj.symbols[id]
|
|
2484
|
-
|
|
2525
|
+
if adj.builder_options.get("enable_backward", True):
|
|
2526
|
+
out = adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2527
|
+
|
|
2528
|
+
# re-point target symbol to out var
|
|
2529
|
+
for id in adj.symbols:
|
|
2530
|
+
if adj.symbols[id] == target:
|
|
2531
|
+
adj.symbols[id] = out
|
|
2532
|
+
break
|
|
2533
|
+
else:
|
|
2534
|
+
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2535
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2485
2536
|
|
|
2486
2537
|
else:
|
|
2487
2538
|
raise WarpCodegenError(
|
|
@@ -2518,22 +2569,23 @@ class Adjoint:
|
|
|
2518
2569
|
|
|
2519
2570
|
# assigning to a vector or quaternion component
|
|
2520
2571
|
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
2521
|
-
# TODO: handle wp.adjoint case
|
|
2522
|
-
|
|
2523
2572
|
index = adj.vector_component_index(lhs.attr, aggregate_type)
|
|
2524
2573
|
|
|
2525
|
-
# TODO: array vec component case
|
|
2526
2574
|
if is_reference(aggregate.type):
|
|
2527
2575
|
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2528
2576
|
adj.add_builtin_call("store", [attr, rhs])
|
|
2529
2577
|
else:
|
|
2530
|
-
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2535
|
-
adj.symbols[id]
|
|
2536
|
-
|
|
2578
|
+
if adj.builder_options.get("enable_backward", True):
|
|
2579
|
+
out = adj.add_builtin_call("assign", [aggregate, index, rhs])
|
|
2580
|
+
|
|
2581
|
+
# re-point target symbol to out var
|
|
2582
|
+
for id in adj.symbols:
|
|
2583
|
+
if adj.symbols[id] == aggregate:
|
|
2584
|
+
adj.symbols[id] = out
|
|
2585
|
+
break
|
|
2586
|
+
else:
|
|
2587
|
+
attr = adj.add_builtin_call("index", [aggregate, index])
|
|
2588
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2537
2589
|
|
|
2538
2590
|
else:
|
|
2539
2591
|
attr = adj.emit_Attribute(lhs)
|
|
@@ -2637,10 +2689,14 @@ class Adjoint:
|
|
|
2637
2689
|
make_new_assign_statement()
|
|
2638
2690
|
return
|
|
2639
2691
|
|
|
2640
|
-
# TODO
|
|
2641
2692
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2642
|
-
|
|
2643
|
-
|
|
2693
|
+
if isinstance(node.op, ast.Add):
|
|
2694
|
+
adj.add_builtin_call("augassign_add", [target, *indices, rhs])
|
|
2695
|
+
elif isinstance(node.op, ast.Sub):
|
|
2696
|
+
adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
|
|
2697
|
+
else:
|
|
2698
|
+
make_new_assign_statement()
|
|
2699
|
+
return
|
|
2644
2700
|
|
|
2645
2701
|
else:
|
|
2646
2702
|
raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
|
|
@@ -2688,6 +2744,7 @@ class Adjoint:
|
|
|
2688
2744
|
ast.Tuple: emit_Tuple,
|
|
2689
2745
|
ast.Pass: emit_Pass,
|
|
2690
2746
|
ast.Ellipsis: emit_Ellipsis,
|
|
2747
|
+
ast.Assert: emit_Assert,
|
|
2691
2748
|
}
|
|
2692
2749
|
|
|
2693
2750
|
def eval(adj, node):
|
|
@@ -2850,11 +2907,62 @@ class Adjoint:
|
|
|
2850
2907
|
if static_code is None:
|
|
2851
2908
|
raise WarpCodegenError("Error extracting source code from wp.static() expression")
|
|
2852
2909
|
|
|
2910
|
+
# Since this is an expression, we can enforce it to be defined on a single line.
|
|
2911
|
+
static_code = static_code.replace("\n", "")
|
|
2912
|
+
|
|
2853
2913
|
vars_dict = adj.get_static_evaluation_context()
|
|
2854
2914
|
# add constant variables to the static call context
|
|
2855
2915
|
constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
|
|
2856
2916
|
vars_dict.update(constant_vars)
|
|
2857
2917
|
|
|
2918
|
+
# Replace all constant `len()` expressions with their value.
|
|
2919
|
+
if "len" in static_code:
|
|
2920
|
+
|
|
2921
|
+
def eval_len(obj):
|
|
2922
|
+
if type_is_vector(obj):
|
|
2923
|
+
return obj._length_
|
|
2924
|
+
elif type_is_quaternion(obj):
|
|
2925
|
+
return obj._length_
|
|
2926
|
+
elif type_is_matrix(obj):
|
|
2927
|
+
return obj._shape_[0]
|
|
2928
|
+
elif type_is_transformation(obj):
|
|
2929
|
+
return obj._length_
|
|
2930
|
+
elif is_tile(obj):
|
|
2931
|
+
return obj.shape[0]
|
|
2932
|
+
|
|
2933
|
+
return len(obj)
|
|
2934
|
+
|
|
2935
|
+
len_expr_ctx = vars_dict.copy()
|
|
2936
|
+
constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
|
|
2937
|
+
len_expr_ctx.update(constant_types)
|
|
2938
|
+
len_expr_ctx.update({"len": eval_len})
|
|
2939
|
+
|
|
2940
|
+
# We want to replace the expression code in-place,
|
|
2941
|
+
# so reparse it to get the correct column info.
|
|
2942
|
+
len_value_locs = []
|
|
2943
|
+
expr_tree = ast.parse(static_code)
|
|
2944
|
+
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
|
|
2945
|
+
expr_root = expr_tree.body[0].value
|
|
2946
|
+
for expr_node in ast.walk(expr_root):
|
|
2947
|
+
if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
|
|
2948
|
+
len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
|
|
2949
|
+
try:
|
|
2950
|
+
len_value = eval(len_expr, len_expr_ctx)
|
|
2951
|
+
except Exception:
|
|
2952
|
+
pass
|
|
2953
|
+
else:
|
|
2954
|
+
len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
|
|
2955
|
+
|
|
2956
|
+
if len_value_locs:
|
|
2957
|
+
new_static_code = ""
|
|
2958
|
+
loc = 0
|
|
2959
|
+
for value, start, end in len_value_locs:
|
|
2960
|
+
new_static_code += f"{static_code[loc:start]}{value}"
|
|
2961
|
+
loc = end
|
|
2962
|
+
|
|
2963
|
+
new_static_code += static_code[len_value_locs[-1][2] :]
|
|
2964
|
+
static_code = new_static_code
|
|
2965
|
+
|
|
2858
2966
|
try:
|
|
2859
2967
|
value = eval(static_code, vars_dict)
|
|
2860
2968
|
if warp.config.verbose:
|
|
@@ -3139,7 +3247,7 @@ static CUDA_CALLABLE void adj_{name}(
|
|
|
3139
3247
|
|
|
3140
3248
|
"""
|
|
3141
3249
|
|
|
3142
|
-
|
|
3250
|
+
cuda_kernel_template_forward = """
|
|
3143
3251
|
|
|
3144
3252
|
extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
3145
3253
|
{forward_args})
|
|
@@ -3154,6 +3262,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
|
3154
3262
|
{forward_body} }}
|
|
3155
3263
|
}}
|
|
3156
3264
|
|
|
3265
|
+
"""
|
|
3266
|
+
|
|
3267
|
+
cuda_kernel_template_backward = """
|
|
3268
|
+
|
|
3157
3269
|
extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
3158
3270
|
{reverse_args})
|
|
3159
3271
|
{{
|
|
@@ -3169,13 +3281,17 @@ extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
|
3169
3281
|
|
|
3170
3282
|
"""
|
|
3171
3283
|
|
|
3172
|
-
|
|
3284
|
+
cpu_kernel_template_forward = """
|
|
3173
3285
|
|
|
3174
3286
|
void {name}_cpu_kernel_forward(
|
|
3175
3287
|
{forward_args})
|
|
3176
3288
|
{{
|
|
3177
3289
|
{forward_body}}}
|
|
3178
3290
|
|
|
3291
|
+
"""
|
|
3292
|
+
|
|
3293
|
+
cpu_kernel_template_backward = """
|
|
3294
|
+
|
|
3179
3295
|
void {name}_cpu_kernel_backward(
|
|
3180
3296
|
{reverse_args})
|
|
3181
3297
|
{{
|
|
@@ -3183,7 +3299,7 @@ void {name}_cpu_kernel_backward(
|
|
|
3183
3299
|
|
|
3184
3300
|
"""
|
|
3185
3301
|
|
|
3186
|
-
|
|
3302
|
+
cpu_module_template_forward = """
|
|
3187
3303
|
|
|
3188
3304
|
extern "C" {{
|
|
3189
3305
|
|
|
@@ -3198,6 +3314,14 @@ WP_API void {name}_cpu_forward(
|
|
|
3198
3314
|
}}
|
|
3199
3315
|
}}
|
|
3200
3316
|
|
|
3317
|
+
}} // extern C
|
|
3318
|
+
|
|
3319
|
+
"""
|
|
3320
|
+
|
|
3321
|
+
cpu_module_template_backward = """
|
|
3322
|
+
|
|
3323
|
+
extern "C" {{
|
|
3324
|
+
|
|
3201
3325
|
WP_API void {name}_cpu_backward(
|
|
3202
3326
|
{reverse_args})
|
|
3203
3327
|
{{
|
|
@@ -3212,36 +3336,6 @@ WP_API void {name}_cpu_backward(
|
|
|
3212
3336
|
|
|
3213
3337
|
"""
|
|
3214
3338
|
|
|
3215
|
-
cuda_module_header_template = """
|
|
3216
|
-
|
|
3217
|
-
extern "C" {{
|
|
3218
|
-
|
|
3219
|
-
// Python CUDA entry points
|
|
3220
|
-
WP_API void {name}_cuda_forward(
|
|
3221
|
-
void* stream,
|
|
3222
|
-
{forward_args});
|
|
3223
|
-
|
|
3224
|
-
WP_API void {name}_cuda_backward(
|
|
3225
|
-
void* stream,
|
|
3226
|
-
{reverse_args});
|
|
3227
|
-
|
|
3228
|
-
}} // extern C
|
|
3229
|
-
"""
|
|
3230
|
-
|
|
3231
|
-
cpu_module_header_template = """
|
|
3232
|
-
|
|
3233
|
-
extern "C" {{
|
|
3234
|
-
|
|
3235
|
-
// Python CPU entry points
|
|
3236
|
-
WP_API void {name}_cpu_forward(
|
|
3237
|
-
{forward_args});
|
|
3238
|
-
|
|
3239
|
-
WP_API void {name}_cpu_backward(
|
|
3240
|
-
{reverse_args});
|
|
3241
|
-
|
|
3242
|
-
}} // extern C
|
|
3243
|
-
"""
|
|
3244
|
-
|
|
3245
3339
|
|
|
3246
3340
|
# converts a constant Python value to equivalent C-repr
|
|
3247
3341
|
def constant_str(value):
|
|
@@ -3679,59 +3773,82 @@ def codegen_kernel(kernel, device, options):
|
|
|
3679
3773
|
|
|
3680
3774
|
adj = kernel.adj
|
|
3681
3775
|
|
|
3682
|
-
|
|
3683
|
-
|
|
3776
|
+
if device == "cpu":
|
|
3777
|
+
template_forward = cpu_kernel_template_forward
|
|
3778
|
+
template_backward = cpu_kernel_template_backward
|
|
3779
|
+
elif device == "cuda":
|
|
3780
|
+
template_forward = cuda_kernel_template_forward
|
|
3781
|
+
template_backward = cuda_kernel_template_backward
|
|
3782
|
+
else:
|
|
3783
|
+
raise ValueError(f"Device {device} is not supported")
|
|
3784
|
+
|
|
3785
|
+
template = ""
|
|
3786
|
+
template_fmt_args = {
|
|
3787
|
+
"name": kernel.get_mangled_name(),
|
|
3788
|
+
}
|
|
3684
3789
|
|
|
3790
|
+
# build forward signature
|
|
3791
|
+
forward_args = ["wp::launch_bounds_t dim"]
|
|
3685
3792
|
if device == "cpu":
|
|
3686
3793
|
forward_args.append("size_t task_index")
|
|
3687
|
-
reverse_args.append("size_t task_index")
|
|
3688
3794
|
|
|
3689
|
-
# forward args
|
|
3690
3795
|
for arg in adj.args:
|
|
3691
3796
|
forward_args.append(arg.ctype() + " var_" + arg.label)
|
|
3692
|
-
reverse_args.append(arg.ctype() + " var_" + arg.label)
|
|
3693
3797
|
|
|
3694
|
-
# reverse args
|
|
3695
|
-
for arg in adj.args:
|
|
3696
|
-
# indexed array gradients are regular arrays
|
|
3697
|
-
if isinstance(arg.type, indexedarray):
|
|
3698
|
-
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
3699
|
-
reverse_args.append(_arg.ctype() + " adj_" + arg.label)
|
|
3700
|
-
else:
|
|
3701
|
-
reverse_args.append(arg.ctype() + " adj_" + arg.label)
|
|
3702
|
-
|
|
3703
|
-
# codegen body
|
|
3704
3798
|
forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
|
|
3799
|
+
template_fmt_args.update(
|
|
3800
|
+
{
|
|
3801
|
+
"forward_args": indent(forward_args),
|
|
3802
|
+
"forward_body": forward_body,
|
|
3803
|
+
}
|
|
3804
|
+
)
|
|
3805
|
+
template += template_forward
|
|
3705
3806
|
|
|
3706
3807
|
if options["enable_backward"]:
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
|
|
3808
|
+
# build reverse signature
|
|
3809
|
+
reverse_args = ["wp::launch_bounds_t dim"]
|
|
3810
|
+
if device == "cpu":
|
|
3811
|
+
reverse_args.append("size_t task_index")
|
|
3710
3812
|
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
elif device == "cuda":
|
|
3714
|
-
template = cuda_kernel_template
|
|
3715
|
-
else:
|
|
3716
|
-
raise ValueError(f"Device {device} is not supported")
|
|
3813
|
+
for arg in adj.args:
|
|
3814
|
+
reverse_args.append(arg.ctype() + " var_" + arg.label)
|
|
3717
3815
|
|
|
3718
|
-
|
|
3719
|
-
|
|
3720
|
-
|
|
3721
|
-
|
|
3722
|
-
|
|
3723
|
-
|
|
3724
|
-
|
|
3816
|
+
for arg in adj.args:
|
|
3817
|
+
# indexed array gradients are regular arrays
|
|
3818
|
+
if isinstance(arg.type, indexedarray):
|
|
3819
|
+
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
3820
|
+
reverse_args.append(_arg.ctype() + " adj_" + arg.label)
|
|
3821
|
+
else:
|
|
3822
|
+
reverse_args.append(arg.ctype() + " adj_" + arg.label)
|
|
3725
3823
|
|
|
3824
|
+
reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
|
|
3825
|
+
template_fmt_args.update(
|
|
3826
|
+
{
|
|
3827
|
+
"reverse_args": indent(reverse_args),
|
|
3828
|
+
"reverse_body": reverse_body,
|
|
3829
|
+
}
|
|
3830
|
+
)
|
|
3831
|
+
template += template_backward
|
|
3832
|
+
|
|
3833
|
+
s = template.format(**template_fmt_args)
|
|
3726
3834
|
return s
|
|
3727
3835
|
|
|
3728
3836
|
|
|
3729
|
-
def codegen_module(kernel, device
|
|
3837
|
+
def codegen_module(kernel, device, options):
|
|
3730
3838
|
if device != "cpu":
|
|
3731
3839
|
return ""
|
|
3732
3840
|
|
|
3841
|
+
# Update the module's options with the ones defined on the kernel, if any.
|
|
3842
|
+
options = dict(options)
|
|
3843
|
+
options.update(kernel.options)
|
|
3844
|
+
|
|
3733
3845
|
adj = kernel.adj
|
|
3734
3846
|
|
|
3847
|
+
template = ""
|
|
3848
|
+
template_fmt_args = {
|
|
3849
|
+
"name": kernel.get_mangled_name(),
|
|
3850
|
+
}
|
|
3851
|
+
|
|
3735
3852
|
# build forward signature
|
|
3736
3853
|
forward_args = ["wp::launch_bounds_t dim"]
|
|
3737
3854
|
forward_params = ["dim", "task_index"]
|
|
@@ -3745,29 +3862,40 @@ def codegen_module(kernel, device="cpu"):
|
|
|
3745
3862
|
forward_args.append(f"{arg.ctype()} var_{arg.label}")
|
|
3746
3863
|
forward_params.append("var_" + arg.label)
|
|
3747
3864
|
|
|
3748
|
-
|
|
3749
|
-
|
|
3750
|
-
|
|
3865
|
+
template_fmt_args.update(
|
|
3866
|
+
{
|
|
3867
|
+
"forward_args": indent(forward_args),
|
|
3868
|
+
"forward_params": indent(forward_params, 3),
|
|
3869
|
+
}
|
|
3870
|
+
)
|
|
3871
|
+
template += cpu_module_template_forward
|
|
3751
3872
|
|
|
3752
|
-
|
|
3753
|
-
|
|
3754
|
-
|
|
3755
|
-
|
|
3756
|
-
reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
|
|
3757
|
-
reverse_params.append(f"adj_{_arg.label}")
|
|
3758
|
-
elif hasattr(arg.type, "_wp_generic_type_str_"):
|
|
3759
|
-
# vectors and matrices are passed from Python by pointer
|
|
3760
|
-
reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
|
|
3761
|
-
reverse_params.append(f"*adj_{arg.label}")
|
|
3762
|
-
else:
|
|
3763
|
-
reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
|
|
3764
|
-
reverse_params.append(f"adj_{arg.label}")
|
|
3873
|
+
if options["enable_backward"]:
|
|
3874
|
+
# build reverse signature
|
|
3875
|
+
reverse_args = [*forward_args]
|
|
3876
|
+
reverse_params = [*forward_params]
|
|
3765
3877
|
|
|
3766
|
-
|
|
3767
|
-
|
|
3768
|
-
|
|
3769
|
-
|
|
3770
|
-
|
|
3771
|
-
|
|
3772
|
-
|
|
3878
|
+
for arg in adj.args:
|
|
3879
|
+
if isinstance(arg.type, indexedarray):
|
|
3880
|
+
# indexed array gradients are regular arrays
|
|
3881
|
+
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
3882
|
+
reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
|
|
3883
|
+
reverse_params.append(f"adj_{_arg.label}")
|
|
3884
|
+
elif hasattr(arg.type, "_wp_generic_type_str_"):
|
|
3885
|
+
# vectors and matrices are passed from Python by pointer
|
|
3886
|
+
reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
|
|
3887
|
+
reverse_params.append(f"*adj_{arg.label}")
|
|
3888
|
+
else:
|
|
3889
|
+
reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
|
|
3890
|
+
reverse_params.append(f"adj_{arg.label}")
|
|
3891
|
+
|
|
3892
|
+
template_fmt_args.update(
|
|
3893
|
+
{
|
|
3894
|
+
"reverse_args": indent(reverse_args),
|
|
3895
|
+
"reverse_params": indent(reverse_params, 3),
|
|
3896
|
+
}
|
|
3897
|
+
)
|
|
3898
|
+
template += cpu_module_template_backward
|
|
3899
|
+
|
|
3900
|
+
s = template.format(**template_fmt_args)
|
|
3773
3901
|
return s
|