warp-lang 1.5.0__py3-none-win_amd64.whl → 1.6.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1124 -497
- warp/codegen.py +261 -136
- warp/config.py +1 -1
- warp/context.py +357 -119
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth.py +3 -1
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/fem/geometry/geometry.py +0 -2
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/coloring.cpp +5 -1
- warp/native/cuda_util.cpp +91 -53
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1187 -669
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +130 -64
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +134 -72
- warp/sparse.py +1 -1
- warp/stubs.py +265 -132
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_coloring.py +12 -2
- warp/tests/test_examples.py +12 -1
- warp/tests/test_func.py +21 -4
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +17 -16
- warp/tests/test_matmul_lite.py +10 -15
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +47 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_static.py +19 -3
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +178 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -13
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +411 -101
- warp/utils.py +10 -7
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -282,9 +282,9 @@ class StructInstance:
             else:
                 # wp.array
                 assert isinstance(value, array)
-                assert types_equal(
-
-                )
+                assert types_equal(value.dtype, var.type.dtype), (
+                    f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                )
                 setattr(self._ctype, name, value.__ctype__())
 
             elif isinstance(var.type, Struct):
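The rewritten assertion keeps the same dtype check but moves the failure message into a parenthesized continuation. A minimal sketch of the user-level check it guards (the `Particles` struct is a hypothetical example, not from the package):

import warp as wp

@wp.struct
class Particles:
    positions: wp.array(dtype=wp.vec3)

p = Particles()
p.positions = wp.array([wp.vec3()], dtype=wp.vec3)  # dtypes match: OK

# a dtype mismatch trips the assert with the message shown in the hunk:
# p.positions = wp.array([1.0], dtype=float)
# AssertionError: assign to struct member variable positions failed, ...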
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params)
                 return "bool"
             else:
                 return f"wp::{p.__name__}"
+        elif is_tile(p):
+            return p.ctype()
+
         return p.__name__
 
     return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -947,7 +950,7 @@ class Adjoint:
         total_shared = 0
 
         for var in adj.variables:
-            if is_tile(var.type) and var.type.storage == "shared":
+            if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
                 total_shared += var.type.size_in_bytes()
 
         return total_shared + adj.max_required_extra_shared_memory
@@ -1139,6 +1142,9 @@ class Adjoint:
         if isinstance(var, (Reference, warp.context.Function)):
             return var
 
+        if isinstance(var, int):
+            return adj.add_constant(var)
+
         if var.label is None:
             return adj.add_var(var.type, var.constant)
@@ -1175,25 +1181,25 @@ class Adjoint:
         left = adj.load(left)
         s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
 
-
+        prev_comp_var = None
 
         for op, comp in zip(op_strings, comps):
             comp_chainable = op_str_is_chainable(op)
-            if comp_chainable and
-                # We
-                if
-
-
-                    s += "&& (" +
+            if comp_chainable and prev_comp_var:
+                # We restrict chaining to operands of the same type
+                if prev_comp_var.type is comp.type:
+                    prev_comp_var = adj.load(prev_comp_var)
+                    comp_var = adj.load(comp)
+                    s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
                 else:
                     raise WarpCodegenTypeError(
-                        f"Cannot chain comparisons of unequal types: {
+                        f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                     )
             else:
-
-                s += op + " " +
+                comp_var = adj.load(comp)
+                s += op + " " + comp_var.emit() + ") "
 
-
+            prev_comp_var = comp_var
 
         s = s.rstrip() + ";"
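This hunk implements chained comparisons (e.g. `a < b < c`) by and-ing adjacent pairs, and only allows chaining when adjacent operands share a type. A small sketch of a kernel expression this now accepts, assuming the feature works as the codegen above suggests:

import warp as wp

@wp.kernel
def count_interior(x: wp.array(dtype=float), hits: wp.array(dtype=int)):
    tid = wp.tid()
    # adjacent operands are all float, so the chain is accepted and
    # lowers to (0.0 < x[tid]) && (x[tid] < 1.0)
    if 0.0 < x[tid] < 1.0:
        wp.atomic_add(hits, 0, 1)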
@@ -1349,8 +1355,9 @@ class Adjoint:
         # which allows for some more advanced resolution to be performed,
         # for example by checking whether an argument corresponds to
         # a literal value or references a variable.
+        extra_shared_memory = 0
         if func.lto_dispatch_func is not None:
-            func_args, template_args, ltoirs = func.lto_dispatch_func(
+            func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
                 func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
             )
         elif func.dispatch_func is not None:
@@ -1366,13 +1373,15 @@ class Adjoint:
         fwd_args = []
         for func_arg in func_args:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
-
+                func_arg_var = adj.load(func_arg)
+            else:
+                func_arg_var = func_arg
 
             # if the argument is a function (and not a builtin), then build it recursively
-            if isinstance(
-                adj.builder.build_function(
+            if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                adj.builder.build_function(func_arg_var)
 
-            fwd_args.append(strip_reference(
+            fwd_args.append(strip_reference(func_arg_var))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
@@ -1422,7 +1431,9 @@ class Adjoint:
         # update our smem roofline requirements based on any
         # shared memory required by the dependent function call
         if not func.is_builtin():
-            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+            adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
+        else:
+            adj.alloc_shared_extra(extra_shared_memory)
 
         return output
@@ -1525,7 +1536,8 @@ class Adjoint:
         # zero adjoints
         for i in body_block.vars:
             if is_tile(i.type):
-
+                if i.type.owner:
+                    reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
             else:
                 reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
@@ -1855,6 +1867,17 @@ class Adjoint:
             # stubbed @wp.native_func
             return
 
+    def emit_Assert(adj, node):
+        # eval condition
+        cond = adj.eval(node.test)
+        cond = adj.load(cond)
+
+        source_segment = ast.get_source_segment(adj.source, node)
+        # If a message was provided with the assert, " marks can interfere with the generated code
+        escaped_segment = source_segment.replace('"', '\\"')
+
+        adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
+
     def emit_NameConstant(adj, node):
         if node.value:
             return adj.add_constant(node.value)
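Per this hunk, Python `assert` statements in kernels are lowered to native `assert()` calls with the escaped source text embedded, so a failure reports the original expression. A minimal sketch of the user-facing feature, assuming (as is usual for native `assert()`) that it only fires in debug-mode module builds:

import warp as wp

wp.set_module_options({"mode": "debug"})  # native assert() is compiled out of release builds

@wp.kernel
def checked_sqrt(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    assert x[tid] >= 0.0, "input must be non-negative"
    y[tid] = wp.sqrt(x[tid])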
@@ -1898,12 +1921,25 @@ class Adjoint:
 
         name = builtin_operators[type(node.op)]
 
+        try:
+            # Check if there is any user-defined overload for this operator
+            user_func = adj.resolve_external_reference(name)
+            if isinstance(user_func, warp.context.Function):
+                return adj.add_call(user_func, (left, right), {}, {})
+        except WarpCodegenError:
+            pass
+
         return adj.add_builtin_call(name, [left, right])
 
     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)
 
+        # evaluate expression to a compile-time constant if arg is a constant
+        if arg.constant is not None and math.isfinite(arg.constant):
+            if isinstance(node.op, ast.USub):
+                return adj.add_constant(-arg.constant)
+
         name = builtin_operators[type(node.op)]
 
         return adj.add_builtin_call(name, [arg])
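The `emit_BinOp` change lets a user function shadow a builtin operator name (`add` for `+`, `sub` for `-`, and so on per `builtin_operators`) before codegen falls back to the builtin. A hedged sketch of what this appears to enable; the `Complex` struct and the convention of naming the overload after the operator's builtin name are assumptions drawn from this hunk, not documented API:

import warp as wp

@wp.struct
class Complex:
    re: float
    im: float

@wp.func
def add(a: Complex, b: Complex):
    # assumed: a function named "add" in scope resolves as the overload for `+`
    c = Complex()
    c.re = a.re + b.re
    c.im = a.im + b.im
    return c

@wp.kernel
def combine(out: wp.array(dtype=Complex)):
    a = Complex()
    a.re = 1.0
    b = Complex()
    b.im = 2.0
    out[0] = a + b  # dispatches to the user-defined add() above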
@@ -2348,12 +2384,16 @@ class Adjoint:
             out.is_write = target.is_write
 
         elif is_tile(target_type):
-            if len(indices) ==
+            if len(indices) == len(target_type.shape):
                 # handles extracting a single element from a tile
                 out = adj.add_builtin_call("tile_extract", [target, *indices])
-
+            elif len(indices) < len(target_type.shape):
                 # handles tile views
-                out = adj.add_builtin_call("tile_view", [target,
+                out = adj.add_builtin_call("tile_view", [target, indices])
+            else:
+                raise RuntimeError(
+                    f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
+                )
 
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
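Indexing a tile with as many indices as it has dimensions now lowers to `tile_extract`, fewer indices produce a `tile_view`, and too many indices raise an error. A sketch of the user-facing behaviour; the shape-based `tile_load` signature used here is an assumption and may differ from the actual 1.6 API:

import warp as wp

@wp.kernel
def tile_indexing(a: wp.array2d(dtype=float), out: wp.array(dtype=float)):
    t = wp.tile_load(a, shape=(4, 8))  # assumed signature, check the Warp docs
    x = t[1, 2]  # 2 indices on a 2d tile -> tile_extract (scalar)
    row = t[1]   # 1 index on a 2d tile  -> tile_view (1d view)
    out[0] = x + row[0]
    # t[0, 1, 2] would raise: 3 indices for a 2 dimensional tile

# tile kernels are launched cooperatively over a block, e.g.:
# wp.launch_tiled(tile_indexing, dim=[1], inputs=[a, out], block_dim=32)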
@@ -2445,6 +2485,9 @@ class Adjoint:
 
             target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
 
+        elif is_tile(target_type):
+            adj.add_builtin_call("assign", [target, *indices, rhs])
+
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             # recursively unwind AST, stopping at penultimate node
             node = lhs
@@ -2471,15 +2514,18 @@ class Adjoint:
                 print(
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                 )
-
             else:
-
-
-
-
-                adj.symbols[id]
-
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [target, *indices, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == target:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [target, *indices])
+                    adj.add_builtin_call("store", [attr, rhs])
 
         else:
             raise WarpCodegenError(
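With the backward pass enabled (the default), writing to a vector or matrix component now emits a differentiable `assign` builtin that yields a fresh variable and re-points the symbol table to it; with `enable_backward` off it falls back to a plain index/store. A minimal sketch of the user-side pattern this supports:

import warp as wp

@wp.kernel
def flatten_y(points: wp.array(dtype=wp.vec3)):
    tid = wp.tid()
    p = points[tid]
    # with enable_backward=True this lowers to assign(p, 1, 0.0) and
    # produces a new differentiable variable; otherwise index + store
    p[1] = 0.0
    points[tid] = p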
@@ -2516,22 +2562,23 @@ class Adjoint:
 
         # assigning to a vector or quaternion component
         if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
-            # TODO: handle wp.adjoint case
-
             index = adj.vector_component_index(lhs.attr, aggregate_type)
 
-            # TODO: array vec component case
             if is_reference(aggregate.type):
                 attr = adj.add_builtin_call("indexref", [aggregate, index])
                 adj.add_builtin_call("store", [attr, rhs])
             else:
-
-
-
-
-
-                adj.symbols[id]
-
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [aggregate, index, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == aggregate:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+                    adj.add_builtin_call("store", [attr, rhs])
 
         else:
             attr = adj.emit_Attribute(lhs)
@@ -2569,8 +2616,10 @@ class Adjoint:
         adj.return_var = ()
         for ret in var:
             if is_reference(ret.type):
-
-
+                ret_var = adj.add_builtin_call("copy", [ret])
+            else:
+                ret_var = ret
+            adj.return_var += (ret_var,)
 
         adj.add_return(adj.return_var)
@@ -2633,10 +2682,14 @@ class Adjoint:
             make_new_assign_statement()
             return
 
-        # TODO
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
-
-
+            if isinstance(node.op, ast.Add):
+                adj.add_builtin_call("augassign_add", [target, *indices, rhs])
+            elif isinstance(node.op, ast.Sub):
+                adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
+            else:
+                make_new_assign_statement()
+            return
 
         else:
             raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
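`+=` and `-=` on vector, quaternion, and matrix components now map to dedicated `augassign_add`/`augassign_sub` builtins, while other in-place operators are rewritten as ordinary assignments. A small sketch of the kernel-side effect:

import warp as wp

@wp.kernel
def nudge(v: wp.array(dtype=wp.vec3)):
    tid = wp.tid()
    a = v[tid]
    a[0] += 1.0  # lowers to augassign_add
    a[1] -= 2.0  # lowers to augassign_sub
    a[2] *= 0.5  # falls back to a regular assign statement
    v[tid] = a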
@@ -2684,6 +2737,7 @@ class Adjoint:
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
         ast.Ellipsis: emit_Ellipsis,
+        ast.Assert: emit_Assert,
     }
 
     def eval(adj, node):
@@ -2846,11 +2900,62 @@ class Adjoint:
         if static_code is None:
             raise WarpCodegenError("Error extracting source code from wp.static() expression")
 
+        # Since this is an expression, we can enforce it to be defined on a single line.
+        static_code = static_code.replace("\n", "")
+
         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
         constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
         vars_dict.update(constant_vars)
 
+        # Replace all constant `len()` expressions with their value.
+        if "len" in static_code:
+
+            def eval_len(obj):
+                if type_is_vector(obj):
+                    return obj._length_
+                elif type_is_quaternion(obj):
+                    return obj._length_
+                elif type_is_matrix(obj):
+                    return obj._shape_[0]
+                elif type_is_transformation(obj):
+                    return obj._length_
+                elif is_tile(obj):
+                    return obj.shape[0]
+
+                return len(obj)
+
+            len_expr_ctx = vars_dict.copy()
+            constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
+            len_expr_ctx.update(constant_types)
+            len_expr_ctx.update({"len": eval_len})
+
+            # We want to replace the expression code in-place,
+            # so reparse it to get the correct column info.
+            len_value_locs = []
+            expr_tree = ast.parse(static_code)
+            assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
+            expr_root = expr_tree.body[0].value
+            for expr_node in ast.walk(expr_root):
+                if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
+                    len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
+                    try:
+                        len_value = eval(len_expr, len_expr_ctx)
+                    except Exception:
+                        pass
+                    else:
+                        len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
+
+            if len_value_locs:
+                new_static_code = ""
+                loc = 0
+                for value, start, end in len_value_locs:
+                    new_static_code += f"{static_code[loc:start]}{value}"
+                    loc = end
+
+                new_static_code += static_code[len_value_locs[-1][2] :]
+                static_code = new_static_code
+
         try:
             value = eval(static_code, vars_dict)
             if warp.config.verbose:
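Because `wp.static()` expressions are evaluated at module build time, constant `len()` calls inside them are now substituted with their literal value first (via `_length_`/`_shape_` for vectors, quaternions, matrices, transformations, and tiles). A minimal sketch of what that enables, assuming the substitution works as implemented above:

import warp as wp

@wp.kernel
def norm_sq(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
    tid = wp.tid()
    v = vs[tid]
    s = float(0.0)
    # len(v) is rewritten to the constant 3 before the static expression
    # is evaluated, so the loop bound is a compile-time constant
    for i in range(wp.static(len(v))):
        s += v[i] * v[i]
    out[tid] = s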
@@ -3135,7 +3240,7 @@ static CUDA_CALLABLE void adj_{name}(
 
 """
 
-
+cuda_kernel_template_forward = """
 
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
@@ -3150,6 +3255,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
 {forward_body} }}
 }}
 
+"""
+
+cuda_kernel_template_backward = """
+
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
@@ -3165,13 +3274,17 @@ extern "C" __global__ void {name}_cuda_kernel_backward(
 
 """
 
-
+cpu_kernel_template_forward = """
 
 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
     {forward_body}}}
 
+"""
+
+cpu_kernel_template_backward = """
+
 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
@@ -3179,7 +3292,7 @@ void {name}_cpu_kernel_backward(
 
 """
 
-
+cpu_module_template_forward = """
 
 extern "C" {{
 
@@ -3194,6 +3307,14 @@ WP_API void {name}_cpu_forward(
 }}
 }}
 
+}} // extern C
+
+"""
+
+cpu_module_template_backward = """
+
+extern "C" {{
+
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
@@ -3208,36 +3329,6 @@ WP_API void {name}_cpu_backward(
 
 """
 
-cuda_module_header_template = """
-
-extern "C" {{
-
-// Python CUDA entry points
-WP_API void {name}_cuda_forward(
-    void* stream,
-    {forward_args});
-
-WP_API void {name}_cuda_backward(
-    void* stream,
-    {reverse_args});
-
-}} // extern C
-"""
-
-cpu_module_header_template = """
-
-extern "C" {{
-
-// Python CPU entry points
-WP_API void {name}_cpu_forward(
-    {forward_args});
-
-WP_API void {name}_cpu_backward(
-    {reverse_args});
-
-}} // extern C
-"""
-
 
 # converts a constant Python value to equivalent C-repr
 def constant_str(value):
@@ -3675,59 +3766,82 @@ def codegen_kernel(kernel, device, options):
 
     adj = kernel.adj
 
-
-
+    if device == "cpu":
+        template_forward = cpu_kernel_template_forward
+        template_backward = cpu_kernel_template_backward
+    elif device == "cuda":
+        template_forward = cuda_kernel_template_forward
+        template_backward = cuda_kernel_template_backward
+    else:
+        raise ValueError(f"Device {device} is not supported")
+
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
 
+    # build forward signature
+    forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-        reverse_args.append("size_t task_index")
 
-    # forward args
     for arg in adj.args:
        forward_args.append(arg.ctype() + " var_" + arg.label)
-        reverse_args.append(arg.ctype() + " var_" + arg.label)
 
-    # reverse args
-    for arg in adj.args:
-        # indexed array gradients are regular arrays
-        if isinstance(arg.type, indexedarray):
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-        else:
-            reverse_args.append(arg.ctype() + " adj_" + arg.label)
-
-    # codegen body
     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_body": forward_body,
+        }
+    )
+    template += template_forward
 
     if options["enable_backward"]:
-
-
-
+        # build reverse signature
+        reverse_args = ["wp::launch_bounds_t dim"]
+        if device == "cpu":
+            reverse_args.append("size_t task_index")
 
-
-
-    elif device == "cuda":
-        template = cuda_kernel_template
-    else:
-        raise ValueError(f"Device {device} is not supported")
+        for arg in adj.args:
+            reverse_args.append(arg.ctype() + " var_" + arg.label)
 
-
-
-
-
-
-
-
+        for arg in adj.args:
+            # indexed array gradients are regular arrays
+            if isinstance(arg.type, indexedarray):
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+            else:
+                reverse_args.append(arg.ctype() + " adj_" + arg.label)
 
+        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_body": reverse_body,
+            }
+        )
+        template += template_backward
+
+    s = template.format(**template_fmt_args)
     return s
 
 
-def codegen_module(kernel, device
+def codegen_module(kernel, device, options):
     if device != "cpu":
         return ""
 
+    # Update the module's options with the ones defined on the kernel, if any.
+    options = dict(options)
+    options.update(kernel.options)
+
     adj = kernel.adj
 
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
+
     # build forward signature
     forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim", "task_index"]
@@ -3741,29 +3855,40 @@ def codegen_module(kernel, device="cpu"):
         forward_args.append(f"{arg.ctype()} var_{arg.label}")
         forward_params.append("var_" + arg.label)
 
-
-
-
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_params": indent(forward_params, 3),
+        }
+    )
+    template += cpu_module_template_forward
 
-
-
-
-
-            reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{_arg.label}")
-        elif hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-            reverse_params.append(f"*adj_{arg.label}")
-        else:
-            reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{arg.label}")
+    if options["enable_backward"]:
+        # build reverse signature
+        reverse_args = [*forward_args]
+        reverse_params = [*forward_params]
 
-
-
-
-
-
-
-
+        for arg in adj.args:
+            if isinstance(arg.type, indexedarray):
+                # indexed array gradients are regular arrays
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{_arg.label}")
+            elif hasattr(arg.type, "_wp_generic_type_str_"):
+                # vectors and matrices are passed from Python by pointer
+                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
+                reverse_params.append(f"*adj_{arg.label}")
+            else:
+                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{arg.label}")
+
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_params": indent(reverse_params, 3),
+            }
+        )
+        template += cpu_module_template_backward
+
+    s = template.format(**template_fmt_args)
     return s
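With the kernel and module templates split into forward and backward halves, modules built with `enable_backward` disabled emit no adjoint kernels or entry points at all. A sketch of opting out per module (`enable_backward` is a standard Warp module option):

import warp as wp

# forward-only codegen: the *_backward kernels and entry points are
# simply never appended to the generated source
wp.set_module_options({"enable_backward": False})

@wp.kernel
def scale(x: wp.array(dtype=float), s: float):
    tid = wp.tid()
    x[tid] = x[tid] * s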