warp-lang 1.8.0__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +48 -63
- warp/builtins.py +955 -137
- warp/codegen.py +327 -209
- warp/config.py +1 -1
- warp/context.py +1363 -800
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +266 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +200 -91
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +203 -54
- warp/marching_cubes.py +708 -0
- warp/native/array.h +103 -8
- warp/native/builtin.h +90 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +13 -3
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +42 -11
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +4 -4
- warp/native/mat.h +1913 -119
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +5 -3
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +337 -16
- warp/native/rand.h +7 -7
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +14 -14
- warp/native/spatial.h +366 -17
- warp/native/svd.h +23 -8
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +303 -70
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +385 -18
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +337 -193
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +137 -57
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/graph_coloring.py +2 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +559 -176
- warp/tape.py +2 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +82 -7
- warp/tests/test_array.py +56 -5
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1540 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +162 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +103 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tape.py +38 -0
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +216 -441
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +206 -152
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +16 -16
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +16 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -18,9 +18,11 @@ from __future__ import annotations
 import ast
 import builtins
 import ctypes
+import enum
 import functools
 import hashlib
 import inspect
+import itertools
 import math
 import re
 import sys
@@ -614,8 +616,12 @@ def compute_type_str(base_name, template_params):
         return base_name

     def param2str(p):
+        if isinstance(p, builtins.bool):
+            return "true" if p else "false"
         if isinstance(p, int):
             return str(p)
+        elif hasattr(p, "_wp_generic_type_str_"):
+            return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
         elif hasattr(p, "_type_"):
             if p.__name__ == "bool":
                 return "bool"
@@ -623,6 +629,8 @@ def compute_type_str(base_name, template_params):
             return f"wp::{p.__name__}"
         elif is_tile(p):
             return p.ctype()
+        elif isinstance(p, Struct):
+            return p.native_name

         return p.__name__

@@ -682,7 +690,12 @@ class Var:

     @staticmethod
     def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
-        if …
+        if isinstance(t, fixedarray):
+            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
+            dtypestr = ", ".join(template_args)
+            classstr = f"wp::{type(t).__name__}"
+            return f"{classstr}_t<{dtypestr}>"
+        elif is_array(t):
             dtypestr = Var.dtype_to_ctype(t.dtype)
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
@@ -778,11 +791,10 @@ def apply_defaults(
     arguments = bound_args.arguments
     new_arguments = []
     for name in bound_args._signature.parameters.keys():
- …
+        if name in arguments:
             new_arguments.append((name, arguments[name]))
- …
-            new_arguments.append((name, values[name]))
+        elif name in values:
+            new_arguments.append((name, values[name]))

     bound_args.arguments = dict(new_arguments)

@@ -835,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
     if isinstance(arg, Sequence):
         return tuple(get_arg_type(x) for x in arg)

+    if is_array(arg):
+        return arg
+
     if get_origin(arg) is tuple:
         return tuple(get_arg_type(x) for x in get_args(arg))

@@ -894,6 +909,8 @@ class Adjoint:
         adj.skip_forward_codegen = skip_forward_codegen
         # whether the generation of the adjoint code is skipped for this function
         adj.skip_reverse_codegen = skip_reverse_codegen
+        # Whether this function is used by a kernel that has the backward pass enabled.
+        adj.used_by_backward_kernel = False

         # extract name of source file
         adj.filename = inspect.getsourcefile(func) or "unknown source file"
@@ -960,13 +977,18 @@ class Adjoint:
                 continue

             # add variable for argument
-            arg = Var(name, type, False)
+            arg = Var(name, type, requires_grad=False)
             adj.args.append(arg)

             # pre-populate symbol dictionary with function argument names
             # this is to avoid registering false references to overshadowed modules
             adj.symbols[name] = arg

+        # Indicates whether there are unresolved static expressions in the function.
+        # These stem from wp.static() expressions that could not be evaluated at declaration time.
+        # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
+        adj.has_unresolved_static_expressions = False
+
         # try to replace static expressions by their constant result if the
         # expression can be evaluated at declaration time
         adj.static_expressions: dict[str, Any] = {}
@@ -1064,17 +1086,21 @@ class Adjoint:
         # recursively evaluate function body
         try:
             adj.eval(adj.tree.body[0])
-        except Exception:
+        except Exception as original_exc:
             try:
                 lineno = adj.lineno + adj.fun_lineno
                 line = adj.source_lines[adj.lineno]
                 msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
- …
+
+                # Combine the new message with the original exception's arguments
+                new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
+
+                # Enhance the original exception with parser context before re-raising.
+                # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
+                raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
             finally:
                 adj.skip_build = True
                 adj.builder = None
-            raise e

         if builder is not None:
             for a in adj.args:
@@ -1220,9 +1246,9 @@ class Adjoint:

         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
- …
-        )
+        build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
+
+        lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"

         if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
             is_comment = statement.strip().startswith("//")
@@ -1341,7 +1367,7 @@ class Adjoint:
             # unresolved function, report error
             arg_type_reprs = []

-            for x in arg_types:
+            for x in itertools.chain(arg_types, kwarg_types.values()):
                 if isinstance(x, warp.context.Function):
                     arg_type_reprs.append("function")
                 else:
@@ -1371,7 +1397,7 @@ class Adjoint:
         # in order to process them as Python does it.
         bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

-        # Type args are the …
+        # Type args are the "compile time" argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
         # this translates in fact to calling the `vector()` built-in augmented
         # with the type args `length=3, dtype=float`.
@@ -1409,20 +1435,30 @@ class Adjoint:
         bound_args = bound_args.arguments

         # if it is a user-function then build it recursively
-        if not func.is_builtin() …
- …
-            # …
- …
-            if …
- …
+        if not func.is_builtin():
+            # If the function called is a user function,
+            # we need to ensure its adjoint is also being generated.
+            if adj.used_by_backward_kernel:
+                func.adj.used_by_backward_kernel = True
+
+            if adj.builder is None:
+                func.build(None)
+
+            elif func not in adj.builder.functions:
+                adj.builder.build_function(func)
+                # add custom grad, replay functions to the list of functions
+                # to be built later (invalid code could be generated if we built them now)
+                # so that they are not missed when only the forward function is imported
+                # from another module
+                if func.custom_grad_func:
+                    adj.builder.deferred_functions.append(func.custom_grad_func)
+                if func.custom_replay_func:
+                    adj.builder.deferred_functions.append(func.custom_replay_func)

         # Resolve the return value based on the types and values of the given arguments.
         bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
         bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+
         return_type = func.value_func(
             {k: strip_reference(v) for k, v in bound_arg_types.items()},
             bound_arg_values,
@@ -1486,6 +1522,9 @@ class Adjoint:

             # if the argument is a function (and not a builtin), then build it recursively
             if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                if adj.used_by_backward_kernel:
+                    func_arg_var.adj.used_by_backward_kernel = True
+
                 adj.builder.build_function(func_arg_var)

             fwd_args.append(strip_reference(func_arg_var))
@@ -1879,6 +1918,9 @@ class Adjoint:
             return obj
         if isinstance(obj, type):
             return obj
+        if isinstance(obj, Struct):
+            adj.builder.build_struct_recursive(obj)
+            return obj
         if isinstance(obj, types.ModuleType):
             return obj

@@ -1931,11 +1973,17 @@ class Adjoint:
         aggregate = adj.eval(node.value)

         try:
+            if isinstance(aggregate, Var) and aggregate.constant is not None:
+                # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
+                return aggregate
+
             if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
                 out = getattr(aggregate, node.attr)

                 if warp.types.is_value(out):
                     return adj.add_constant(out)
+                if isinstance(out, (enum.IntEnum, enum.IntFlag)):
+                    return adj.add_constant(int(out))

                 return out

@@ -1963,18 +2011,29 @@ class Adjoint:
                 return adj.add_builtin_call("transform_get_rotation", [aggregate])

             else:
- …
+                attr_var = aggregate_type.vars[node.attr]
+
+                # represent pointer types as uint64
+                if isinstance(attr_var.type, pointer_t):
+                    cast = f"({Var.dtype_to_ctype(uint64)}*)"
+                    adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
+                    attr_type = Reference(uint64)
+                else:
+                    cast = ""
+                    adj_cast = ""
+                    attr_type = Reference(attr_var.type)
+
             attr = adj.add_var(attr_type)

             if is_reference(aggregate.type):
-                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{ …
+                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
             else:
-                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{ …
+                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

             if adj.is_differentiable_value_type(strip_reference(attr_type)):
-                adj.add_reverse(f"{aggregate.emit_adj()}.{ …
+                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
             else:
-                adj.add_reverse(f"{aggregate.emit_adj()}.{ …
+                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

             return attr

@@ -2302,9 +2361,12 @@ class Adjoint:

             return var

-        if isinstance(expr, (type, Var, warp.context.Function)):
+        if isinstance(expr, (type, Struct, Var, warp.context.Function)):
             return expr

+        if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
+            return adj.add_constant(int(expr))
+
         return adj.add_constant(expr)

     def emit_Call(adj, node):
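
The enum handling added above (in emit_Attribute and in the constant path) folds `enum.IntEnum` and `enum.IntFlag` members into plain integer constants during code generation, which is what the new warp/tests/test_enum.py exercises. A minimal sketch of the kind of kernel this enables (hypothetical example, assuming the 1.9.0 enum support):

    import enum
    import warp as wp

    class Shape(enum.IntEnum):
        SPHERE = 0
        BOX = 1

    @wp.kernel
    def count_boxes(kinds: wp.array(dtype=int), out: wp.array(dtype=int)):
        tid = wp.tid()
        # Shape.BOX is folded to the integer constant 1 at code-generation time
        if kinds[tid] == Shape.BOX:
            wp.atomic_add(out, 0, 1)
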
@@ -2322,8 +2384,9 @@ class Adjoint:

         if adj.is_static_expression(func):
             # try to evaluate wp.static() expressions
-            obj, …
+            obj, code = adj.evaluate_static_expression(node)
             if obj is not None:
+                adj.static_expressions[code] = obj
                 if isinstance(obj, warp.context.Function):
                     # special handling for wp.static() evaluating to a function
                     return obj
@@ -2352,7 +2415,8 @@ class Adjoint:

         # struct constructor
         if func is None and isinstance(caller, Struct):
-            adj.builder …
+            if adj.builder is not None:
+                adj.builder.build_struct_recursive(caller)
             if node.args or node.keywords:
                 func = caller.value_constructor
             else:
@@ -2412,68 +2476,45 @@ class Adjoint:

         return adj.eval(node.value)

- …
-            if isinstance(root.value, ast.Name):
-                symbol = adj.emit_Name(root.value)
-                symbol_type = strip_reference(symbol.type)
-                if is_array(symbol_type):
-                    array = symbol
-                    break
- …
-            root = root.value
- …
-        # If not all indices index into the array, just evaluate the right-most indexing operation.
-        if not array or (count > array.type.ndim):
-            count = 1
- …
-        indices = []
-        root = node
-        while len(indices) < count:
-            if isinstance(root.slice, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.elts]
-            elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.value.elts]
-            else:
-                ij = [adj.eval(root.slice)]
- …
-            indices = ij + indices  # prepend
- …
-            root = root.value
- …
-        target = adj.eval(root)
+    def eval_indices(adj, target_type, indices):
+        nodes = indices
+        if hasattr(target_type, "_wp_generic_type_hint_"):
+            indices = []
+            for dim, node in enumerate(nodes):
+                if isinstance(node, ast.Slice):
+                    # In the context of slicing a vec/mat type, indices are expected
+                    # to be compile-time constants, hence we can infer the actual slice
+                    # bounds also at compile-time.
+                    length = target_type._shape_[dim]
+                    step = 1 if node.step is None else adj.eval(node.step).constant
+
+                    if node.lower is None:
+                        start = length - 1 if step < 0 else 0
+                    else:
+                        start = adj.eval(node.lower).constant
+                        start = min(max(start, -length), length)
+                        start = start + length if start < 0 else start

- …
+                    if node.upper is None:
+                        stop = -1 if step < 0 else length
+                    else:
+                        stop = adj.eval(node.upper).constant
+                        stop = min(max(stop, -length), length)
+                        stop = stop + length if stop < 0 else stop

- …
-            var = adj.eval(node.slice)
-            var_name = var.label
-            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
-            return var
+                    slice = adj.add_builtin_call("slice", (start, stop, step))
+                    indices.append(slice)
+                else:
+                    indices.append(adj.eval(node))

- …
+            return tuple(indices)
+        else:
+            return tuple(adj.eval(x) for x in nodes)

+    def emit_indexing(adj, target, indices):
         target_type = strip_reference(target.type)
+        indices = adj.eval_indices(target_type, indices)
+
         if is_array(target_type):
             if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
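
The new eval_indices() path above turns Python slice syntax on vector and matrix values into a compile-time `slice` built-in call, with the bounds clamped and resolved while the kernel is generated. A rough usage sketch (hypothetical kernel, assuming the 1.9.0 vec/mat slicing support):

    import warp as wp

    @wp.kernel
    def sum_tail(v: wp.array(dtype=wp.vec4), out: wp.array(dtype=float)):
        tid = wp.tid()
        # slice bounds are compile-time constants; the result has 3 components
        tail = v[tid][1:]
        out[tid] = tail[0] + tail[1] + tail[2]
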
@@ -2512,47 +2553,116 @@ class Adjoint:

         return out

+    # from a list of lists of indices, strip the first `count` indices
+    @staticmethod
+    def strip_indices(indices, count):
+        dim = count
+        while count > 0:
+            ij = indices[0]
+            indices = indices[1:]
+            count -= len(ij)
+
+        # report straddling like in `arr2d[0][1,2]` as a syntax error
+        if count < 0:
+            raise WarpCodegenError(
+                f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
+            )
+
+        return indices
+
+    def recurse_subscript(adj, node, indices):
+        if isinstance(node, ast.Name):
+            target = adj.eval(node)
+            return target, indices
+
+        if isinstance(node, ast.Subscript):
+            if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+                return adj.eval(node), indices
+
+            if isinstance(node.slice, ast.Tuple):
+                ij = node.slice.elts
+            elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
+                # The node `ast.Index` is deprecated in Python 3.9.
+                ij = node.slice.value.elts
+            elif isinstance(node.slice, ast.ExtSlice):
+                # The node `ast.ExtSlice` is deprecated in Python 3.9.
+                ij = node.slice.dims
+            else:
+                ij = [node.slice]
+
+            indices = [ij, *indices]  # prepend
+
+            target, indices = adj.recurse_subscript(node.value, indices)
+
+            target_type = strip_reference(target.type)
+            if is_array(target_type):
+                flat_indices = [i for ij in indices for i in ij]
+                if len(flat_indices) > target_type.ndim:
+                    target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
+                    indices = adj.strip_indices(indices, target_type.ndim)
+
+            return target, indices
+
+        target = adj.eval(node)
+        return target, indices
+
+    # returns the object being indexed, and the list of indices
+    def eval_subscript(adj, node):
+        target, indices = adj.recurse_subscript(node, [])
+        flat_indices = [i for ij in indices for i in ij]
+        return target, flat_indices
+
+    def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
+        target, indices = adj.eval_subscript(node)
+
+        return adj.emit_indexing(target, indices)
+
     def emit_Assign(adj, node):
         if len(node.targets) != 1:
             raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

- …
+        # Check if the rhs corresponds to an unsupported construct.
+        # Tuples are supported in the context of assigning multiple variables
+        # at once, but not for simple assignments like `x = (1, 2, 3)`.
+        # Therefore, we need to catch this specific case here instead of
+        # more generally in `adj.eval()`.
+        if isinstance(node.value, ast.List):
+            raise WarpCodegenError(
+                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+            )

- …
-        # Check if the rhs corresponds to an unsupported construct.
-        # Tuples are supported in the context of assigning multiple variables
-        # at once, but not for simple assignments like `x = (1, 2, 3)`.
-        # Therefore, we need to catch this specific case here instead of
-        # more generally in `adj.eval()`.
-        if isinstance(node.value, ast.List):
-            raise WarpCodegenError(
-                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
-            )
+        lhs = node.targets[0]

- …
-        if isinstance(lhs, ast.Tuple):
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
- …
-            node.value.expects = len(lhs.elts)
+            node.value.expects = len(lhs.elts)

- …
+        # evaluate rhs
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
+            rhs = [adj.eval(v) for v in node.value.elts]
+        else:
+            rhs = adj.eval(node.value)
+
+        # handle the case where we are assigning multiple output variables
+        if isinstance(lhs, ast.Tuple):
+            subtype = getattr(rhs, "type", None)

-            subtype = getattr(out, "type", None)
             if isinstance(subtype, warp.types.tuple_t):
-                if len( …
+                if len(rhs.type.types) != len(lhs.elts):
                     raise WarpCodegenError(
-                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len( …
+                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                     )
- …
-                out = tuple(
-                    adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
-                )
+                rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

             names = []
             for v in lhs.elts:
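
The recurse_subscript()/eval_subscript() rework above collects indices across chained subscripts so that, for arrays, `a[i][j]` resolves to the same load as `a[i, j]`, while straddled patterns such as `arr2d[0][1, 2]` are reported as errors by strip_indices(). A hedged sketch of the pattern this supports (hypothetical kernel):

    import warp as wp

    @wp.kernel
    def copy_row(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float), row: int):
        j = wp.tid()
        # the chained indexing on the read resolves to a single 2D array load
        dst[row, j] = src[row][j]
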
@@ -2563,11 +2673,12 @@ class Adjoint:
                         "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                     )

-            if len(names) != len( …
+            if len(names) != len(rhs):
                 raise WarpCodegenError(
-                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len( …
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                 )

+            out = rhs
             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
@@ -2579,8 +2690,6 @@ class Adjoint:

         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
         elif isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
-
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                 # handle adjoint of a variable, i.e. wp.adjoint[var]
                 lhs.slice.is_adjoint = True
@@ -2592,6 +2701,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 adj.add_builtin_call("array_store", [target, *indices, rhs])
@@ -2613,14 +2723,11 @@ class Adjoint:
                 or type_is_transformation(target_type)
             ):
                 # recursively unwind AST, stopping at penultimate node
- …
-                while hasattr( …
- …
-                    node = node.value
-                else:
-                    break
+                root = lhs
+                while hasattr(root.value, "value"):
+                    root = root.value
                 # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
-                if hasattr( …
+                if hasattr(root, "attr") and root.attr == "adjoint":
                     attr = adj.add_builtin_call("index", [target, *indices])
                     adj.add_builtin_call("store", [attr, rhs])
                     return
@@ -2658,9 +2765,6 @@ class Adjoint:
             # symbol name
             name = lhs.id

-            # evaluate rhs
-            rhs = adj.eval(node.value)
-
             # check type matches if symbol already defined
             if name in adj.symbols:
                 if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
@@ -2681,7 +2785,6 @@ class Adjoint:
             adj.symbols[name] = out

         elif isinstance(lhs, ast.Attribute):
-            rhs = adj.eval(node.value)
             aggregate = adj.eval(lhs.value)
             aggregate_type = strip_reference(aggregate.type)

@@ -2769,9 +2872,9 @@ class Adjoint:
             new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
             adj.eval(new_node)

- …
-            rhs = adj.eval(node.value)
+        rhs = adj.eval(node.value)

+        if isinstance(lhs, ast.Subscript):
             # wp.adjoint[var] appears in custom grad functions, and does not require
             # special consideration in the AugAssign case
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
@@ -2781,6 +2884,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
@@ -2853,7 +2957,6 @@ class Adjoint:

         elif isinstance(lhs, ast.Name):
             target = adj.eval(node.target)
-            rhs = adj.eval(node.value)

             if is_tile(target.type) and is_tile(rhs.type):
                 if isinstance(node.op, ast.Add):
@@ -3109,6 +3212,7 @@ class Adjoint:

         # Since this is an expression, we can enforce it to be defined on a single line.
         static_code = static_code.replace("\n", "")
+        code_to_eval = static_code  # code to be evaluated

         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
@@ -3150,10 +3254,12 @@ class Adjoint:
                 loc = end

             new_static_code += static_code[len_value_locs[-1][2] :]
- …
+            code_to_eval = new_static_code

         try:
-            value = eval( …
+            value = eval(code_to_eval, vars_dict)
+            if isinstance(value, (enum.IntEnum, enum.IntFlag)):
+                value = int(value)
             if warp.config.verbose:
                 print(f"Evaluated static command: {static_code} = {value}")
         except NameError as e:
@@ -3206,6 +3312,9 @@ class Adjoint:
                 # (and is therefore not executable and raises this exception), in which
                 # case changing the constant, or the code affecting this constant, would lead to
                 # a different module hash anyway.
+                # In any case, we mark this Adjoint to have unresolvable static expressions.
+                # This will trigger a code generation step even if the module hash is unchanged.
+                adj.has_unresolved_static_expressions = True
                 pass

         return self.generic_visit(node)
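
The has_unresolved_static_expressions flag set above marks functions whose wp.static() expressions could not be evaluated when the kernel was declared, so the module is regenerated even if its hash is unchanged. For context, a small sketch of ordinary wp.static() usage (based on the public wp.static API; the kernel itself is hypothetical):

    import warp as wp

    SCALE = 3.0

    @wp.kernel
    def scale_values(values: wp.array(dtype=float)):
        tid = wp.tid()
        # evaluated once at code-generation time and folded into the kernel as a constant
        values[tid] = values[tid] * wp.static(SCALE * 2.0)
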
@@ -3361,6 +3470,11 @@ cuda_module_header = """
 #define WP_NO_CRT
 #include "builtin.h"

+// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
+#if defined(__CUDACC__) && !defined(_MSC_VER)
+#define __debugbreak() __brkpt()
+#endif
+
 // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
 #define float(x) cast_float(x)
 #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
@@ -3398,6 +3512,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
 {{
 {reverse_body}}}

+// Required when compiling adjoints.
+CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
+{{
+    return {name}();
+}}
+
 CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
 {{
 {atomic_add_body}}}
@@ -3478,7 +3598,8 @@ cuda_kernel_template_backward = """
 cpu_kernel_template_forward = """

 void {name}_cpu_kernel_forward(
-    {forward_args}
+    {forward_args},
+    wp_args_{name} *_wp_args)
 {{
 {forward_body}}}

@@ -3487,7 +3608,9 @@ void {name}_cpu_kernel_forward(
 cpu_kernel_template_backward = """

 void {name}_cpu_kernel_backward(
-    {reverse_args}
+    {reverse_args},
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
 {reverse_body}}}

@@ -3499,15 +3622,15 @@ extern "C" {{

 // Python CPU entry points
 WP_API void {name}_cpu_forward(
- …
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // init shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_forward(
-            {forward_params});
+        {name}_cpu_kernel_forward(dim, task_index, _wp_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3524,15 +3647,16 @@ cpu_module_template_backward = """
 extern "C" {{

 WP_API void {name}_cpu_backward(
- …
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // initialize shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_backward(
-            {reverse_params});
+        {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3563,7 +3687,7 @@ def constant_str(value):
         # special case for float16, which is stored as uint16 in the ctypes.Array
         from warp.context import runtime

-        scalar_value = runtime.core. …
+        scalar_value = runtime.core.wp_half_bits_to_float
     else:

         def scalar_value(x):
@@ -3701,8 +3825,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):

     indent_block = " " * indent

-    # primal vars
     lines = []
+
+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+    # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]

@@ -3746,6 +3879,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):

     lines = []

+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
+
     # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]
@@ -3837,6 +3981,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                 f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
             )
+        elif (
+            isinstance(adj.return_var[0].type, warp.types.fixedarray)
+            and type(adj.arg_types["return"]) is warp.types.array
+        ):
+            # If the return statement yields a `fixedarray` while the function is annotated
+            # to return a standard `array`, then raise an error since the `fixedarray` storage
+            # allocated on the stack will be freed once the function exits, meaning that the
+            # resulting `array` instance will point to an invalid data.
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` returns a fixed-size array "
+                f"whereas it has its return type annotated as "
+                f"`{warp.context.type_str(adj.arg_types['return'])}`."
+            )

     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
@@ -3915,10 +4072,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if adj.custom_reverse_mode:
         reverse_body = "\t// user-defined adjoint code\n" + forward_body
     else:
-        if options.get("enable_backward", True):
+        if options.get("enable_backward", True) and adj.used_by_backward_kernel:
             reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
         else:
-            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
     s += reverse_template.format(
         name=c_func_name,
         return_type=return_type,
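
With the change above, a user function's adjoint is only generated when it is actually reached from a kernel that needs a backward pass (adj.used_by_backward_kernel), in addition to the existing module-level switch. A brief reminder of that module option (standard Warp usage; the kernel is a hypothetical example):

    import warp as wp

    wp.set_module_options({"enable_backward": False})  # skip adjoint code generation for this module

    @wp.kernel
    def forward_only(x: wp.array(dtype=float)):
        tid = wp.tid()
        x[tid] = x[tid] * 2.0
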
@@ -4010,6 +4167,13 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

+    args_struct = ""
+    if device == "cpu":
+        args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
+        for i in adj.args:
+            args_struct += f" {i.ctype()} {i.label};\n"
+        args_struct += "};\n"
+
     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
     # a direct mapping to a Python source line.
@@ -4035,9 +4199,9 @@ def codegen_kernel(kernel, device, options):
     forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
- …
+    else:
+        for arg in adj.args:
+            forward_args.append(arg.ctype() + " var_" + arg.label)

     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
     template_fmt_args.update(
@@ -4054,17 +4218,16 @@ def codegen_kernel(kernel, device, options):
         reverse_args = ["wp::launch_bounds_t dim"]
         if device == "cpu":
             reverse_args.append("size_t task_index")
- …
-            reverse_args.append(arg.ctype() + " adj_" + arg.label)
+        else:
+            for arg in adj.args:
+                reverse_args.append(arg.ctype() + " var_" + arg.label)
+            for arg in adj.args:
+                # indexed array gradients are regular arrays
+                if isinstance(arg.type, indexedarray):
+                    _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                    reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+                else:
+                    reverse_args.append(arg.ctype() + " adj_" + arg.label)

         reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
         template_fmt_args.update(
@@ -4076,7 +4239,7 @@ def codegen_kernel(kernel, device, options):
         template += template_backward

     s = template.format(**template_fmt_args)
-    return s
+    return args_struct + s


 def codegen_module(kernel, device, options):
@@ -4087,59 +4250,14 @@ def codegen_module(kernel, device, options):
     options = dict(options)
     options.update(kernel.options)

-    adj = kernel.adj
- …
     template = ""
     template_fmt_args = {
         "name": kernel.get_mangled_name(),
     }

-    # build forward signature
-    forward_args = ["wp::launch_bounds_t dim"]
-    forward_params = ["dim", "task_index"]
- …
-    for arg in adj.args:
-        if hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
-            forward_params.append(f"*var_{arg.label}")
-        else:
-            forward_args.append(f"{arg.ctype()} var_{arg.label}")
-            forward_params.append("var_" + arg.label)
- …
-    template_fmt_args.update(
-        {
-            "forward_args": indent(forward_args),
-            "forward_params": indent(forward_params, 3),
-        }
-    )
     template += cpu_module_template_forward

     if options["enable_backward"]:
-        # build reverse signature
-        reverse_args = [*forward_args]
-        reverse_params = [*forward_params]
- …
-        for arg in adj.args:
-            if isinstance(arg.type, indexedarray):
-                # indexed array gradients are regular arrays
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{_arg.label}")
-            elif hasattr(arg.type, "_wp_generic_type_str_"):
-                # vectors and matrices are passed from Python by pointer
-                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-                reverse_params.append(f"*adj_{arg.label}")
-            else:
-                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{arg.label}")
- …
-        template_fmt_args.update(
-            {
-                "reverse_args": indent(reverse_args),
-                "reverse_params": indent(reverse_params, 3),
-            }
-        )
         template += cpu_module_template_backward

     s = template.format(**template_fmt_args)