warp-lang 1.8.1__py3-none-manylinux_2_34_aarch64.whl → 1.9.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +47 -67
- warp/builtins.py +955 -137
- warp/codegen.py +312 -206
- warp/config.py +1 -1
- warp/context.py +1249 -784
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +2 -1
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +82 -5
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +283 -69
- warp/native/vec.h +381 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +323 -192
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +85 -6
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +56 -5
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +184 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
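
Among the additions listed above, the new warp/tests/test_enum.py and the enum handling in the warp/codegen.py diff below indicate that Python IntEnum/IntFlag members can now be referenced inside kernels, where codegen lowers them to integer constants. A minimal sketch of how that might look in user code (hypothetical example, not taken from the package):

    import enum
    import warp as wp

    class Phase(enum.IntEnum):
        SOLID = 0
        FLUID = 1

    @wp.kernel
    def classify(flags: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
        tid = wp.tid()
        # Enum members used in kernel code are converted to int constants by codegen.
        if flags[tid] == Phase.FLUID:
            out[tid] = 1
        else:
            out[tid] = 0
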
warp/codegen.py
CHANGED
@@ -18,9 +18,11 @@ from __future__ import annotations
 import ast
 import builtins
 import ctypes
+import enum
 import functools
 import hashlib
 import inspect
+import itertools
 import math
 import re
 import sys
@@ -614,6 +616,8 @@ def compute_type_str(base_name, template_params):
         return base_name

     def param2str(p):
+        if isinstance(p, builtins.bool):
+            return "true" if p else "false"
         if isinstance(p, int):
             return str(p)
         elif hasattr(p, "_wp_generic_type_str_"):
@@ -625,6 +629,8 @@ def compute_type_str(base_name, template_params):
             return f"wp::{p.__name__}"
         elif is_tile(p):
             return p.ctype()
+        elif isinstance(p, Struct):
+            return p.native_name

         return p.__name__

@@ -684,7 +690,12 @@ class Var:

     @staticmethod
     def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
-        if
+        if isinstance(t, fixedarray):
+            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
+            dtypestr = ", ".join(template_args)
+            classstr = f"wp::{type(t).__name__}"
+            return f"{classstr}_t<{dtypestr}>"
+        elif is_array(t):
             dtypestr = Var.dtype_to_ctype(t.dtype)
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
@@ -780,11 +791,10 @@ def apply_defaults(
     arguments = bound_args.arguments
     new_arguments = []
     for name in bound_args._signature.parameters.keys():
-
+        if name in arguments:
             new_arguments.append((name, arguments[name]))
-
-
-            new_arguments.append((name, values[name]))
+        elif name in values:
+            new_arguments.append((name, values[name]))

     bound_args.arguments = dict(new_arguments)

@@ -837,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
     if isinstance(arg, Sequence):
         return tuple(get_arg_type(x) for x in arg)

+    if is_array(arg):
+        return arg
+
     if get_origin(arg) is tuple:
         return tuple(get_arg_type(x) for x in get_args(arg))

@@ -896,6 +909,8 @@ class Adjoint:
         adj.skip_forward_codegen = skip_forward_codegen
         # whether the generation of the adjoint code is skipped for this function
         adj.skip_reverse_codegen = skip_reverse_codegen
+        # Whether this function is used by a kernel that has has the backward pass enabled.
+        adj.used_by_backward_kernel = False

         # extract name of source file
         adj.filename = inspect.getsourcefile(func) or "unknown source file"
@@ -962,7 +977,7 @@ class Adjoint:
                 continue

             # add variable for argument
-            arg = Var(name, type, False)
+            arg = Var(name, type, requires_grad=False)
             adj.args.append(arg)

         # pre-populate symbol dictionary with function argument names
@@ -1071,17 +1086,21 @@ class Adjoint:
         # recursively evaluate function body
         try:
             adj.eval(adj.tree.body[0])
-        except Exception:
+        except Exception as original_exc:
             try:
                 lineno = adj.lineno + adj.fun_lineno
                 line = adj.source_lines[adj.lineno]
                 msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
-
-
+
+                # Combine the new message with the original exception's arguments
+                new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
+
+                # Enhance the original exception with parser context before re-raising.
+                # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
+                raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
             finally:
                 adj.skip_build = True
                 adj.builder = None
-            raise e

         if builder is not None:
             for a in adj.args:
@@ -1227,9 +1246,9 @@ class Adjoint:

         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
-
-
-        )
+        build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
+
+        lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"

         if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
             is_comment = statement.strip().startswith("//")
@@ -1348,7 +1367,7 @@ class Adjoint:
             # unresolved function, report error
             arg_type_reprs = []

-            for x in arg_types:
+            for x in itertools.chain(arg_types, kwarg_types.values()):
                 if isinstance(x, warp.context.Function):
                     arg_type_reprs.append("function")
                 else:
@@ -1378,7 +1397,7 @@ class Adjoint:
         # in order to process them as Python does it.
         bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

-        # Type args are the
+        # Type args are the "compile time" argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
         # this translates in fact to calling the `vector()` built-in augmented
         # with the type args `length=3, dtype=float`.
@@ -1416,20 +1435,30 @@ class Adjoint:
         bound_args = bound_args.arguments

         # if it is a user-function then build it recursively
-        if not func.is_builtin()
-
-            #
-
-
-
-            if
-
-
-
+        if not func.is_builtin():
+            # If the function called is a user function,
+            # we need to ensure its adjoint is also being generated.
+            if adj.used_by_backward_kernel:
+                func.adj.used_by_backward_kernel = True
+
+            if adj.builder is None:
+                func.build(None)
+
+            elif func not in adj.builder.functions:
+                adj.builder.build_function(func)
+                # add custom grad, replay functions to the list of functions
+                # to be built later (invalid code could be generated if we built them now)
+                # so that they are not missed when only the forward function is imported
+                # from another module
+                if func.custom_grad_func:
+                    adj.builder.deferred_functions.append(func.custom_grad_func)
+                if func.custom_replay_func:
+                    adj.builder.deferred_functions.append(func.custom_replay_func)

         # Resolve the return value based on the types and values of the given arguments.
         bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
         bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+
         return_type = func.value_func(
             {k: strip_reference(v) for k, v in bound_arg_types.items()},
             bound_arg_values,
@@ -1493,6 +1522,9 @@ class Adjoint:

             # if the argument is a function (and not a builtin), then build it recursively
             if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                if adj.used_by_backward_kernel:
+                    func_arg_var.adj.used_by_backward_kernel = True
+
                 adj.builder.build_function(func_arg_var)

             fwd_args.append(strip_reference(func_arg_var))
@@ -1886,6 +1918,9 @@ class Adjoint:
             return obj
         if isinstance(obj, type):
             return obj
+        if isinstance(obj, Struct):
+            adj.builder.build_struct_recursive(obj)
+            return obj
         if isinstance(obj, types.ModuleType):
             return obj

@@ -1938,11 +1973,17 @@ class Adjoint:
         aggregate = adj.eval(node.value)

         try:
+            if isinstance(aggregate, Var) and aggregate.constant is not None:
+                # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
+                return aggregate
+
             if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
                 out = getattr(aggregate, node.attr)

                 if warp.types.is_value(out):
                     return adj.add_constant(out)
+                if isinstance(out, (enum.IntEnum, enum.IntFlag)):
+                    return adj.add_constant(int(out))

                 return out

@@ -1970,18 +2011,29 @@ class Adjoint:
                 return adj.add_builtin_call("transform_get_rotation", [aggregate])

             else:
-
+                attr_var = aggregate_type.vars[node.attr]
+
+                # represent pointer types as uint64
+                if isinstance(attr_var.type, pointer_t):
+                    cast = f"({Var.dtype_to_ctype(uint64)}*)"
+                    adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
+                    attr_type = Reference(uint64)
+                else:
+                    cast = ""
+                    adj_cast = ""
+                    attr_type = Reference(attr_var.type)
+
                 attr = adj.add_var(attr_type)

                 if is_reference(aggregate.type):
-                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{
+                    adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
                 else:
-                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{
+                    adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

                 if adj.is_differentiable_value_type(strip_reference(attr_type)):
-                    adj.add_reverse(f"{aggregate.emit_adj()}.{
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
                 else:
-                    adj.add_reverse(f"{aggregate.emit_adj()}.{
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

                 return attr

@@ -2309,9 +2361,12 @@ class Adjoint:

             return var

-        if isinstance(expr, (type, Var, warp.context.Function)):
+        if isinstance(expr, (type, Struct, Var, warp.context.Function)):
             return expr

+        if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
+            return adj.add_constant(int(expr))
+
         return adj.add_constant(expr)

     def emit_Call(adj, node):
@@ -2360,7 +2415,8 @@ class Adjoint:

         # struct constructor
         if func is None and isinstance(caller, Struct):
-            adj.builder
+            if adj.builder is not None:
+                adj.builder.build_struct_recursive(caller)
             if node.args or node.keywords:
                 func = caller.value_constructor
             else:
@@ -2420,68 +2476,45 @@ class Adjoint:

             return adj.eval(node.value)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if isinstance(root.value, ast.Name):
-                symbol = adj.emit_Name(root.value)
-                symbol_type = strip_reference(symbol.type)
-                if is_array(symbol_type):
-                    array = symbol
-                    break
-
-            root = root.value
-
-        # If not all indices index into the array, just evaluate the right-most indexing operation.
-        if not array or (count > array.type.ndim):
-            count = 1
-
-        indices = []
-        root = node
-        while len(indices) < count:
-            if isinstance(root.slice, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.elts]
-            elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.value.elts]
-            else:
-                ij = [adj.eval(root.slice)]
-
-            indices = ij + indices  # prepend
-
-            root = root.value
-
-        target = adj.eval(root)
+    def eval_indices(adj, target_type, indices):
+        nodes = indices
+        if hasattr(target_type, "_wp_generic_type_hint_"):
+            indices = []
+            for dim, node in enumerate(nodes):
+                if isinstance(node, ast.Slice):
+                    # In the context of slicing a vec/mat type, indices are expected
+                    # to be compile-time constants, hence we can infer the actual slice
+                    # bounds also at compile-time.
+                    length = target_type._shape_[dim]
+                    step = 1 if node.step is None else adj.eval(node.step).constant
+
+                    if node.lower is None:
+                        start = length - 1 if step < 0 else 0
+                    else:
+                        start = adj.eval(node.lower).constant
+                        start = min(max(start, -length), length)
+                        start = start + length if start < 0 else start

-
+                    if node.upper is None:
+                        stop = -1 if step < 0 else length
+                    else:
+                        stop = adj.eval(node.upper).constant
+                        stop = min(max(stop, -length), length)
+                        stop = stop + length if stop < 0 else stop

-
-
-
-
-            var = adj.eval(node.slice)
-            var_name = var.label
-            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
-            return var
+                    slice = adj.add_builtin_call("slice", (start, stop, step))
+                    indices.append(slice)
+                else:
+                    indices.append(adj.eval(node))

-
+            return tuple(indices)
+        else:
+            return tuple(adj.eval(x) for x in nodes)

+    def emit_indexing(adj, target, indices):
         target_type = strip_reference(target.type)
+        indices = adj.eval_indices(target_type, indices)
+
         if is_array(target_type):
             if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
@@ -2520,47 +2553,116 @@ class Adjoint:

         return out

+    # from a list of lists of indices, strip the first `count` indices
+    @staticmethod
+    def strip_indices(indices, count):
+        dim = count
+        while count > 0:
+            ij = indices[0]
+            indices = indices[1:]
+            count -= len(ij)
+
+        # report straddling like in `arr2d[0][1,2]` as a syntax error
+        if count < 0:
+            raise WarpCodegenError(
+                f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
+            )
+
+        return indices
+
+    def recurse_subscript(adj, node, indices):
+        if isinstance(node, ast.Name):
+            target = adj.eval(node)
+            return target, indices
+
+        if isinstance(node, ast.Subscript):
+            if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+                return adj.eval(node), indices
+
+            if isinstance(node.slice, ast.Tuple):
+                ij = node.slice.elts
+            elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
+                # The node `ast.Index` is deprecated in Python 3.9.
+                ij = node.slice.value.elts
+            elif isinstance(node.slice, ast.ExtSlice):
+                # The node `ast.ExtSlice` is deprecated in Python 3.9.
+                ij = node.slice.dims
+            else:
+                ij = [node.slice]
+
+            indices = [ij, *indices]  # prepend
+
+            target, indices = adj.recurse_subscript(node.value, indices)
+
+            target_type = strip_reference(target.type)
+            if is_array(target_type):
+                flat_indices = [i for ij in indices for i in ij]
+                if len(flat_indices) > target_type.ndim:
+                    target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
+                    indices = adj.strip_indices(indices, target_type.ndim)
+
+            return target, indices
+
+        target = adj.eval(node)
+        return target, indices
+
+    # returns the object being indexed, and the list of indices
+    def eval_subscript(adj, node):
+        target, indices = adj.recurse_subscript(node, [])
+        flat_indices = [i for ij in indices for i in ij]
+        return target, flat_indices
+
+    def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
+        target, indices = adj.eval_subscript(node)
+
+        return adj.emit_indexing(target, indices)
+
     def emit_Assign(adj, node):
         if len(node.targets) != 1:
             raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

-
+        # Check if the rhs corresponds to an unsupported construct.
+        # Tuples are supported in the context of assigning multiple variables
+        # at once, but not for simple assignments like `x = (1, 2, 3)`.
+        # Therefore, we need to catch this specific case here instead of
+        # more generally in `adj.eval()`.
+        if isinstance(node.value, ast.List):
+            raise WarpCodegenError(
+                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+            )

-
-        # Check if the rhs corresponds to an unsupported construct.
-        # Tuples are supported in the context of assigning multiple variables
-        # at once, but not for simple assignments like `x = (1, 2, 3)`.
-        # Therefore, we need to catch this specific case here instead of
-        # more generally in `adj.eval()`.
-        if isinstance(node.value, ast.List):
-            raise WarpCodegenError(
-                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
-            )
+        lhs = node.targets[0]

-
-        if isinstance(lhs, ast.Tuple):
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
-
-            node.value.expects = len(lhs.elts)
+            node.value.expects = len(lhs.elts)

-
-
-
-
-
+        # evaluate rhs
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
+            rhs = [adj.eval(v) for v in node.value.elts]
+        else:
+            rhs = adj.eval(node.value)
+
+        # handle the case where we are assigning multiple output variables
+        if isinstance(lhs, ast.Tuple):
+            subtype = getattr(rhs, "type", None)

-            subtype = getattr(out, "type", None)
             if isinstance(subtype, warp.types.tuple_t):
-                if len(
+                if len(rhs.type.types) != len(lhs.elts):
                     raise WarpCodegenError(
-                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(
+                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                     )
-
-                out = tuple(
-                    adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
-                )
+                rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

             names = []
             for v in lhs.elts:
@@ -2571,11 +2673,12 @@ class Adjoint:
                     "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                 )

-            if len(names) != len(
+            if len(names) != len(rhs):
                 raise WarpCodegenError(
-                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                 )

+            out = rhs
             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
@@ -2587,8 +2690,6 @@ class Adjoint:

         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
         elif isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
-
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                 # handle adjoint of a variable, i.e. wp.adjoint[var]
                 lhs.slice.is_adjoint = True
@@ -2600,6 +2701,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 adj.add_builtin_call("array_store", [target, *indices, rhs])
@@ -2621,14 +2723,11 @@ class Adjoint:
                 or type_is_transformation(target_type)
             ):
                 # recursively unwind AST, stopping at penultimate node
-
-                while hasattr(
-
-                    node = node.value
-                    else:
-                        break
+                root = lhs
+                while hasattr(root.value, "value"):
+                    root = root.value
                 # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
-                if hasattr(
+                if hasattr(root, "attr") and root.attr == "adjoint":
                     attr = adj.add_builtin_call("index", [target, *indices])
                     adj.add_builtin_call("store", [attr, rhs])
                     return
@@ -2666,9 +2765,6 @@ class Adjoint:
             # symbol name
             name = lhs.id

-            # evaluate rhs
-            rhs = adj.eval(node.value)
-
             # check type matches if symbol already defined
             if name in adj.symbols:
                 if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
@@ -2689,7 +2785,6 @@ class Adjoint:
             adj.symbols[name] = out

         elif isinstance(lhs, ast.Attribute):
-            rhs = adj.eval(node.value)
             aggregate = adj.eval(lhs.value)
             aggregate_type = strip_reference(aggregate.type)

@@ -2777,9 +2872,9 @@ class Adjoint:
             new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
             adj.eval(new_node)

-
-            rhs = adj.eval(node.value)
+        rhs = adj.eval(node.value)

+        if isinstance(lhs, ast.Subscript):
             # wp.adjoint[var] appears in custom grad functions, and does not require
             # special consideration in the AugAssign case
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
@@ -2789,6 +2884,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
@@ -2861,7 +2957,6 @@ class Adjoint:

         elif isinstance(lhs, ast.Name):
             target = adj.eval(node.target)
-            rhs = adj.eval(node.value)

             if is_tile(target.type) and is_tile(rhs.type):
                 if isinstance(node.op, ast.Add):
@@ -3163,6 +3258,8 @@ class Adjoint:

         try:
             value = eval(code_to_eval, vars_dict)
+            if isinstance(value, (enum.IntEnum, enum.IntFlag)):
+                value = int(value)
             if warp.config.verbose:
                 print(f"Evaluated static command: {static_code} = {value}")
         except NameError as e:
@@ -3373,6 +3470,11 @@ cuda_module_header = """
#define WP_NO_CRT
#include "builtin.h"

+// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
+#if defined(__CUDACC__) && !defined(_MSC_VER)
+#define __debugbreak() __brkpt()
+#endif
+
// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
@@ -3410,6 +3512,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
{{
{reverse_body}}}

+// Required when compiling adjoints.
+CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
+{{
+    return {name}();
+}}
+
CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
{{
{atomic_add_body}}}
@@ -3490,7 +3598,8 @@ cuda_kernel_template_backward = """
cpu_kernel_template_forward = """

void {name}_cpu_kernel_forward(
-    {forward_args}
+    {forward_args},
+    wp_args_{name} *_wp_args)
{{
{forward_body}}}

@@ -3499,7 +3608,9 @@ void {name}_cpu_kernel_forward(
cpu_kernel_template_backward = """

void {name}_cpu_kernel_backward(
-    {reverse_args}
+    {reverse_args},
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
{{
{reverse_body}}}

@@ -3511,15 +3622,15 @@ extern "C" {{

// Python CPU entry points
WP_API void {name}_cpu_forward(
-
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args)
{{
    for (size_t task_index = 0; task_index < dim.size; ++task_index)
    {{
        // init shared memory allocator
        wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_forward(
-            {forward_params});
+        {name}_cpu_kernel_forward(dim, task_index, _wp_args);

        // check shared memory allocator
        wp::tile_alloc_shared(0, false, true);
@@ -3536,15 +3647,16 @@ cpu_module_template_backward = """
extern "C" {{

WP_API void {name}_cpu_backward(
-
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
{{
    for (size_t task_index = 0; task_index < dim.size; ++task_index)
    {{
        // initialize shared memory allocator
        wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_backward(
-            {reverse_params});
+        {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

        // check shared memory allocator
        wp::tile_alloc_shared(0, false, true);
@@ -3575,7 +3687,7 @@ def constant_str(value):
        # special case for float16, which is stored as uint16 in the ctypes.Array
        from warp.context import runtime

-        scalar_value = runtime.core.
+        scalar_value = runtime.core.wp_half_bits_to_float
    else:

        def scalar_value(x):
@@ -3713,8 +3825,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):

    indent_block = " " * indent

-    # primal vars
    lines = []
+
+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+    # primal vars
    lines += ["//---------\n"]
    lines += ["// primal vars\n"]

@@ -3758,6 +3879,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):

    lines = []

+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
+
    # primal vars
    lines += ["//---------\n"]
    lines += ["// primal vars\n"]
@@ -3849,6 +3981,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
            )
+        elif (
+            isinstance(adj.return_var[0].type, warp.types.fixedarray)
+            and type(adj.arg_types["return"]) is warp.types.array
+        ):
+            # If the return statement yields a `fixedarray` while the function is annotated
+            # to return a standard `array`, then raise an error since the `fixedarray` storage
+            # allocated on the stack will be freed once the function exits, meaning that the
+            # resulting `array` instance will point to an invalid data.
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` returns a fixed-size array "
+                f"whereas it has its return type annotated as "
+                f"`{warp.context.type_str(adj.arg_types['return'])}`."
+            )

    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
@@ -3927,10 +4072,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
        if adj.custom_reverse_mode:
            reverse_body = "\t// user-defined adjoint code\n" + forward_body
        else:
-            if options.get("enable_backward", True):
+            if options.get("enable_backward", True) and adj.used_by_backward_kernel:
                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
            else:
-                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
        s += reverse_template.format(
            name=c_func_name,
            return_type=return_type,
@@ -4022,6 +4167,13 @@ def codegen_kernel(kernel, device, options):

    adj = kernel.adj

+    args_struct = ""
+    if device == "cpu":
+        args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
+        for i in adj.args:
+            args_struct += f" {i.ctype()} {i.label};\n"
+        args_struct += "};\n"
+
    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
    # a direct mapping to a Python source line.
@@ -4047,9 +4199,9 @@ def codegen_kernel(kernel, device, options):
    forward_args = ["wp::launch_bounds_t dim"]
    if device == "cpu":
        forward_args.append("size_t task_index")
-
-
-
+    else:
+        for arg in adj.args:
+            forward_args.append(arg.ctype() + " var_" + arg.label)

    forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
    template_fmt_args.update(
@@ -4066,17 +4218,16 @@ def codegen_kernel(kernel, device, options):
    reverse_args = ["wp::launch_bounds_t dim"]
    if device == "cpu":
        reverse_args.append("size_t task_index")
-
-
-
-
-
-
-
-
-
-
-        reverse_args.append(arg.ctype() + " adj_" + arg.label)
+    else:
+        for arg in adj.args:
+            reverse_args.append(arg.ctype() + " var_" + arg.label)
+        for arg in adj.args:
+            # indexed array gradients are regular arrays
+            if isinstance(arg.type, indexedarray):
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+            else:
+                reverse_args.append(arg.ctype() + " adj_" + arg.label)

    reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
    template_fmt_args.update(
@@ -4088,7 +4239,7 @@ def codegen_kernel(kernel, device, options):
    template += template_backward

    s = template.format(**template_fmt_args)
-    return s
+    return args_struct + s


def codegen_module(kernel, device, options):
@@ -4099,59 +4250,14 @@ def codegen_module(kernel, device, options):
    options = dict(options)
    options.update(kernel.options)

-    adj = kernel.adj
-
    template = ""
    template_fmt_args = {
        "name": kernel.get_mangled_name(),
    }

-    # build forward signature
-    forward_args = ["wp::launch_bounds_t dim"]
-    forward_params = ["dim", "task_index"]
-
-    for arg in adj.args:
-        if hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
-            forward_params.append(f"*var_{arg.label}")
-        else:
-            forward_args.append(f"{arg.ctype()} var_{arg.label}")
-            forward_params.append("var_" + arg.label)
-
-    template_fmt_args.update(
-        {
-            "forward_args": indent(forward_args),
-            "forward_params": indent(forward_params, 3),
-        }
-    )
    template += cpu_module_template_forward

    if options["enable_backward"]:
-        # build reverse signature
-        reverse_args = [*forward_args]
-        reverse_params = [*forward_params]
-
-        for arg in adj.args:
-            if isinstance(arg.type, indexedarray):
-                # indexed array gradients are regular arrays
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{_arg.label}")
-            elif hasattr(arg.type, "_wp_generic_type_str_"):
-                # vectors and matrices are passed from Python by pointer
-                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-                reverse_params.append(f"*adj_{arg.label}")
-            else:
-                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{arg.label}")
-
-        template_fmt_args.update(
-            {
-                "reverse_args": indent(reverse_args),
-                "reverse_params": indent(reverse_params, 3),
-            }
-        )
        template += cpu_module_template_backward

    s = template.format(**template_fmt_args)
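
The new eval_indices() path above resolves slice bounds on vector and matrix types at compile time. A rough sketch of kernel code this appears to enable (hypothetical example, assuming a constant slice of a wp.vec3 yields a length-2 vector, as the compile-time bound handling suggests):

    import warp as wp

    @wp.kernel
    def planar_norm(points: wp.array(dtype=wp.vec3), out: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        p = points[tid]
        # slice bounds (0:2) are compile-time constants, handled by eval_indices()
        xy = p[0:2]
        out[tid] = wp.length(xy)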
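
Separately, codegen_func() now emits a user function's reverse pass only when adj.used_by_backward_kernel is set, so helper functions that are never reached from a gradient-enabled kernel skip adjoint code generation. A hedged sketch of the user-facing effect, assuming the standard wp.set_module_options() API:

    import warp as wp

    # Disable gradients for this module; with the change above, helper's
    # adjoint body should be skipped rather than emitted as an unused stub.
    wp.set_module_options({"enable_backward": False})

    @wp.func
    def helper(x: float) -> float:
        return x * x + 1.0

    @wp.kernel
    def apply(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] = helper(a[tid])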