warp-lang 1.8.1__py3-none-win_amd64.whl → 1.9.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +1904 -114
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +331 -101
- warp/builtins.py +1244 -160
- warp/codegen.py +317 -206
- warp/config.py +1 -1
- warp/context.py +1465 -789
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +25 -2
- warp/jax_experimental/ffi.py +22 -1
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +86 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +40 -31
- warp/native/sort.h +2 -0
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +471 -82
- warp/native/vec.h +328 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +377 -216
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +99 -18
- warp/render/render_usd.py +1 -0
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +58 -5
- warp/tests/test_codegen.py +4 -3
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +49 -6
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +15 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +245 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +571 -267
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -18,9 +18,11 @@ from __future__ import annotations
|
|
|
18
18
|
import ast
|
|
19
19
|
import builtins
|
|
20
20
|
import ctypes
|
|
21
|
+
import enum
|
|
21
22
|
import functools
|
|
22
23
|
import hashlib
|
|
23
24
|
import inspect
|
|
25
|
+
import itertools
|
|
24
26
|
import math
|
|
25
27
|
import re
|
|
26
28
|
import sys
|
|
@@ -614,6 +616,8 @@ def compute_type_str(base_name, template_params):
|
|
|
614
616
|
return base_name
|
|
615
617
|
|
|
616
618
|
def param2str(p):
|
|
619
|
+
if isinstance(p, builtins.bool):
|
|
620
|
+
return "true" if p else "false"
|
|
617
621
|
if isinstance(p, int):
|
|
618
622
|
return str(p)
|
|
619
623
|
elif hasattr(p, "_wp_generic_type_str_"):
|
|
@@ -625,6 +629,8 @@ def compute_type_str(base_name, template_params):
|
|
|
625
629
|
return f"wp::{p.__name__}"
|
|
626
630
|
elif is_tile(p):
|
|
627
631
|
return p.ctype()
|
|
632
|
+
elif isinstance(p, Struct):
|
|
633
|
+
return p.native_name
|
|
628
634
|
|
|
629
635
|
return p.__name__
|
|
630
636
|
|
|
@@ -684,7 +690,12 @@ class Var:
|
|
|
684
690
|
|
|
685
691
|
@staticmethod
|
|
686
692
|
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
687
|
-
if
|
|
693
|
+
if isinstance(t, fixedarray):
|
|
694
|
+
template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
|
|
695
|
+
dtypestr = ", ".join(template_args)
|
|
696
|
+
classstr = f"wp::{type(t).__name__}"
|
|
697
|
+
return f"{classstr}_t<{dtypestr}>"
|
|
698
|
+
elif is_array(t):
|
|
688
699
|
dtypestr = Var.dtype_to_ctype(t.dtype)
|
|
689
700
|
classstr = f"wp::{type(t).__name__}"
|
|
690
701
|
return f"{classstr}_t<{dtypestr}>"
|
|
@@ -780,11 +791,10 @@ def apply_defaults(
|
|
|
780
791
|
arguments = bound_args.arguments
|
|
781
792
|
new_arguments = []
|
|
782
793
|
for name in bound_args._signature.parameters.keys():
|
|
783
|
-
|
|
794
|
+
if name in arguments:
|
|
784
795
|
new_arguments.append((name, arguments[name]))
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
new_arguments.append((name, values[name]))
|
|
796
|
+
elif name in values:
|
|
797
|
+
new_arguments.append((name, values[name]))
|
|
788
798
|
|
|
789
799
|
bound_args.arguments = dict(new_arguments)
|
|
790
800
|
|
|
@@ -837,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
|
|
|
837
847
|
if isinstance(arg, Sequence):
|
|
838
848
|
return tuple(get_arg_type(x) for x in arg)
|
|
839
849
|
|
|
850
|
+
if is_array(arg):
|
|
851
|
+
return arg
|
|
852
|
+
|
|
840
853
|
if get_origin(arg) is tuple:
|
|
841
854
|
return tuple(get_arg_type(x) for x in get_args(arg))
|
|
842
855
|
|
|
@@ -896,6 +909,8 @@ class Adjoint:
|
|
|
896
909
|
adj.skip_forward_codegen = skip_forward_codegen
|
|
897
910
|
# whether the generation of the adjoint code is skipped for this function
|
|
898
911
|
adj.skip_reverse_codegen = skip_reverse_codegen
|
|
912
|
+
# Whether this function is used by a kernel that has the backward pass enabled.
|
|
913
|
+
adj.used_by_backward_kernel = False
|
|
899
914
|
|
|
900
915
|
# extract name of source file
|
|
901
916
|
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
@@ -962,7 +977,7 @@ class Adjoint:
|
|
|
962
977
|
continue
|
|
963
978
|
|
|
964
979
|
# add variable for argument
|
|
965
|
-
arg = Var(name, type, False)
|
|
980
|
+
arg = Var(name, type, requires_grad=False)
|
|
966
981
|
adj.args.append(arg)
|
|
967
982
|
|
|
968
983
|
# pre-populate symbol dictionary with function argument names
|
|
@@ -1071,17 +1086,21 @@ class Adjoint:
|
|
|
1071
1086
|
# recursively evaluate function body
|
|
1072
1087
|
try:
|
|
1073
1088
|
adj.eval(adj.tree.body[0])
|
|
1074
|
-
except Exception:
|
|
1089
|
+
except Exception as original_exc:
|
|
1075
1090
|
try:
|
|
1076
1091
|
lineno = adj.lineno + adj.fun_lineno
|
|
1077
1092
|
line = adj.source_lines[adj.lineno]
|
|
1078
1093
|
msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
|
|
1079
|
-
|
|
1080
|
-
|
|
1094
|
+
|
|
1095
|
+
# Combine the new message with the original exception's arguments
|
|
1096
|
+
new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
|
|
1097
|
+
|
|
1098
|
+
# Enhance the original exception with parser context before re-raising.
|
|
1099
|
+
# 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
|
|
1100
|
+
raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
|
|
1081
1101
|
finally:
|
|
1082
1102
|
adj.skip_build = True
|
|
1083
1103
|
adj.builder = None
|
|
1084
|
-
raise e
|
|
1085
1104
|
|
|
1086
1105
|
if builder is not None:
|
|
1087
1106
|
for a in adj.args:
|
|
@@ -1225,11 +1244,16 @@ class Adjoint:
|
|
|
1225
1244
|
A line directive for the given statement, or None if no line directive is needed.
|
|
1226
1245
|
"""
|
|
1227
1246
|
|
|
1247
|
+
if adj.filename == "unknown source file" or adj.fun_lineno == 0:
|
|
1248
|
+
# Early return if function is not associated with a source file or is otherwise invalid
|
|
1249
|
+
# TODO: Get line directives working with wp.map() functions
|
|
1250
|
+
return None
|
|
1251
|
+
|
|
1228
1252
|
# lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
|
|
1229
1253
|
# emit line directives in generated code if it's not being compiled with line information
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
)
|
|
1254
|
+
build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
|
|
1255
|
+
|
|
1256
|
+
lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"
|
|
1233
1257
|
|
|
1234
1258
|
if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
|
|
1235
1259
|
is_comment = statement.strip().startswith("//")
|
|
@@ -1348,7 +1372,7 @@ class Adjoint:
|
|
|
1348
1372
|
# unresolved function, report error
|
|
1349
1373
|
arg_type_reprs = []
|
|
1350
1374
|
|
|
1351
|
-
for x in arg_types:
|
|
1375
|
+
for x in itertools.chain(arg_types, kwarg_types.values()):
|
|
1352
1376
|
if isinstance(x, warp.context.Function):
|
|
1353
1377
|
arg_type_reprs.append("function")
|
|
1354
1378
|
else:
|
|
@@ -1378,7 +1402,7 @@ class Adjoint:
|
|
|
1378
1402
|
# in order to process them as Python does it.
|
|
1379
1403
|
bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
|
|
1380
1404
|
|
|
1381
|
-
# Type args are the
|
|
1405
|
+
# Type args are the "compile time" argument values we get from codegen.
|
|
1382
1406
|
# For example, when calling `wp.vec3f(...)` from within a kernel,
|
|
1383
1407
|
# this translates in fact to calling the `vector()` built-in augmented
|
|
1384
1408
|
# with the type args `length=3, dtype=float`.
|
|
@@ -1416,20 +1440,30 @@ class Adjoint:
|
|
|
1416
1440
|
bound_args = bound_args.arguments
|
|
1417
1441
|
|
|
1418
1442
|
# if it is a user-function then build it recursively
|
|
1419
|
-
if not func.is_builtin()
|
|
1420
|
-
|
|
1421
|
-
#
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
if
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
|
|
1443
|
+
if not func.is_builtin():
|
|
1444
|
+
# If the function called is a user function,
|
|
1445
|
+
# we need to ensure its adjoint is also being generated.
|
|
1446
|
+
if adj.used_by_backward_kernel:
|
|
1447
|
+
func.adj.used_by_backward_kernel = True
|
|
1448
|
+
|
|
1449
|
+
if adj.builder is None:
|
|
1450
|
+
func.build(None)
|
|
1451
|
+
|
|
1452
|
+
elif func not in adj.builder.functions:
|
|
1453
|
+
adj.builder.build_function(func)
|
|
1454
|
+
# add custom grad, replay functions to the list of functions
|
|
1455
|
+
# to be built later (invalid code could be generated if we built them now)
|
|
1456
|
+
# so that they are not missed when only the forward function is imported
|
|
1457
|
+
# from another module
|
|
1458
|
+
if func.custom_grad_func:
|
|
1459
|
+
adj.builder.deferred_functions.append(func.custom_grad_func)
|
|
1460
|
+
if func.custom_replay_func:
|
|
1461
|
+
adj.builder.deferred_functions.append(func.custom_replay_func)
|
|
1429
1462
|
|
|
1430
1463
|
# Resolve the return value based on the types and values of the given arguments.
|
|
1431
1464
|
bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
|
|
1432
1465
|
bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
|
|
1466
|
+
|
|
1433
1467
|
return_type = func.value_func(
|
|
1434
1468
|
{k: strip_reference(v) for k, v in bound_arg_types.items()},
|
|
1435
1469
|
bound_arg_values,
|
|
@@ -1493,6 +1527,9 @@ class Adjoint:
|
|
|
1493
1527
|
|
|
1494
1528
|
# if the argument is a function (and not a builtin), then build it recursively
|
|
1495
1529
|
if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
|
|
1530
|
+
if adj.used_by_backward_kernel:
|
|
1531
|
+
func_arg_var.adj.used_by_backward_kernel = True
|
|
1532
|
+
|
|
1496
1533
|
adj.builder.build_function(func_arg_var)
|
|
1497
1534
|
|
|
1498
1535
|
fwd_args.append(strip_reference(func_arg_var))
|
|
@@ -1886,6 +1923,9 @@ class Adjoint:
|
|
|
1886
1923
|
return obj
|
|
1887
1924
|
if isinstance(obj, type):
|
|
1888
1925
|
return obj
|
|
1926
|
+
if isinstance(obj, Struct):
|
|
1927
|
+
adj.builder.build_struct_recursive(obj)
|
|
1928
|
+
return obj
|
|
1889
1929
|
if isinstance(obj, types.ModuleType):
|
|
1890
1930
|
return obj
|
|
1891
1931
|
|
|
@@ -1938,11 +1978,17 @@ class Adjoint:
|
|
|
1938
1978
|
aggregate = adj.eval(node.value)
|
|
1939
1979
|
|
|
1940
1980
|
try:
|
|
1981
|
+
if isinstance(aggregate, Var) and aggregate.constant is not None:
|
|
1982
|
+
# this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
|
|
1983
|
+
return aggregate
|
|
1984
|
+
|
|
1941
1985
|
if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
|
|
1942
1986
|
out = getattr(aggregate, node.attr)
|
|
1943
1987
|
|
|
1944
1988
|
if warp.types.is_value(out):
|
|
1945
1989
|
return adj.add_constant(out)
|
|
1990
|
+
if isinstance(out, (enum.IntEnum, enum.IntFlag)):
|
|
1991
|
+
return adj.add_constant(int(out))
|
|
1946
1992
|
|
|
1947
1993
|
return out
|
|
1948
1994
|
|
|
@@ -1970,18 +2016,29 @@ class Adjoint:
|
|
|
1970
2016
|
return adj.add_builtin_call("transform_get_rotation", [aggregate])
|
|
1971
2017
|
|
|
1972
2018
|
else:
|
|
1973
|
-
|
|
2019
|
+
attr_var = aggregate_type.vars[node.attr]
|
|
2020
|
+
|
|
2021
|
+
# represent pointer types as uint64
|
|
2022
|
+
if isinstance(attr_var.type, pointer_t):
|
|
2023
|
+
cast = f"({Var.dtype_to_ctype(uint64)}*)"
|
|
2024
|
+
adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
|
|
2025
|
+
attr_type = Reference(uint64)
|
|
2026
|
+
else:
|
|
2027
|
+
cast = ""
|
|
2028
|
+
adj_cast = ""
|
|
2029
|
+
attr_type = Reference(attr_var.type)
|
|
2030
|
+
|
|
1974
2031
|
attr = adj.add_var(attr_type)
|
|
1975
2032
|
|
|
1976
2033
|
if is_reference(aggregate.type):
|
|
1977
|
-
adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{
|
|
2034
|
+
adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
|
|
1978
2035
|
else:
|
|
1979
|
-
adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{
|
|
2036
|
+
adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")
|
|
1980
2037
|
|
|
1981
2038
|
if adj.is_differentiable_value_type(strip_reference(attr_type)):
|
|
1982
|
-
adj.add_reverse(f"{aggregate.emit_adj()}.{
|
|
2039
|
+
adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
|
|
1983
2040
|
else:
|
|
1984
|
-
adj.add_reverse(f"{aggregate.emit_adj()}.{
|
|
2041
|
+
adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")
|
|
1985
2042
|
|
|
1986
2043
|
return attr
|
|
1987
2044
|
|
|
@@ -2309,9 +2366,12 @@ class Adjoint:
|
|
|
2309
2366
|
|
|
2310
2367
|
return var
|
|
2311
2368
|
|
|
2312
|
-
if isinstance(expr, (type, Var, warp.context.Function)):
|
|
2369
|
+
if isinstance(expr, (type, Struct, Var, warp.context.Function)):
|
|
2313
2370
|
return expr
|
|
2314
2371
|
|
|
2372
|
+
if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
|
|
2373
|
+
return adj.add_constant(int(expr))
|
|
2374
|
+
|
|
2315
2375
|
return adj.add_constant(expr)
|
|
2316
2376
|
|
|
2317
2377
|
def emit_Call(adj, node):
|
|
@@ -2360,7 +2420,8 @@ class Adjoint:
|
|
|
2360
2420
|
|
|
2361
2421
|
# struct constructor
|
|
2362
2422
|
if func is None and isinstance(caller, Struct):
|
|
2363
|
-
adj.builder
|
|
2423
|
+
if adj.builder is not None:
|
|
2424
|
+
adj.builder.build_struct_recursive(caller)
|
|
2364
2425
|
if node.args or node.keywords:
|
|
2365
2426
|
func = caller.value_constructor
|
|
2366
2427
|
else:
|
|
@@ -2420,68 +2481,45 @@ class Adjoint:
|
|
|
2420
2481
|
|
|
2421
2482
|
return adj.eval(node.value)
|
|
2422
2483
|
|
|
2423
|
-
|
|
2424
|
-
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
|
|
2430
|
-
|
|
2431
|
-
|
|
2432
|
-
|
|
2433
|
-
|
|
2434
|
-
|
|
2435
|
-
|
|
2436
|
-
|
|
2437
|
-
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
|
|
2441
|
-
|
|
2442
|
-
if isinstance(root.value, ast.Name):
|
|
2443
|
-
symbol = adj.emit_Name(root.value)
|
|
2444
|
-
symbol_type = strip_reference(symbol.type)
|
|
2445
|
-
if is_array(symbol_type):
|
|
2446
|
-
array = symbol
|
|
2447
|
-
break
|
|
2448
|
-
|
|
2449
|
-
root = root.value
|
|
2450
|
-
|
|
2451
|
-
# If not all indices index into the array, just evaluate the right-most indexing operation.
|
|
2452
|
-
if not array or (count > array.type.ndim):
|
|
2453
|
-
count = 1
|
|
2454
|
-
|
|
2455
|
-
indices = []
|
|
2456
|
-
root = node
|
|
2457
|
-
while len(indices) < count:
|
|
2458
|
-
if isinstance(root.slice, ast.Tuple):
|
|
2459
|
-
ij = [adj.eval(arg) for arg in root.slice.elts]
|
|
2460
|
-
elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
|
|
2461
|
-
ij = [adj.eval(arg) for arg in root.slice.value.elts]
|
|
2462
|
-
else:
|
|
2463
|
-
ij = [adj.eval(root.slice)]
|
|
2464
|
-
|
|
2465
|
-
indices = ij + indices # prepend
|
|
2466
|
-
|
|
2467
|
-
root = root.value
|
|
2468
|
-
|
|
2469
|
-
target = adj.eval(root)
|
|
2484
|
+
def eval_indices(adj, target_type, indices):
|
|
2485
|
+
nodes = indices
|
|
2486
|
+
if hasattr(target_type, "_wp_generic_type_hint_"):
|
|
2487
|
+
indices = []
|
|
2488
|
+
for dim, node in enumerate(nodes):
|
|
2489
|
+
if isinstance(node, ast.Slice):
|
|
2490
|
+
# In the context of slicing a vec/mat type, indices are expected
|
|
2491
|
+
# to be compile-time constants, hence we can infer the actual slice
|
|
2492
|
+
# bounds also at compile-time.
|
|
2493
|
+
length = target_type._shape_[dim]
|
|
2494
|
+
step = 1 if node.step is None else adj.eval(node.step).constant
|
|
2495
|
+
|
|
2496
|
+
if node.lower is None:
|
|
2497
|
+
start = length - 1 if step < 0 else 0
|
|
2498
|
+
else:
|
|
2499
|
+
start = adj.eval(node.lower).constant
|
|
2500
|
+
start = min(max(start, -length), length)
|
|
2501
|
+
start = start + length if start < 0 else start
|
|
2470
2502
|
|
|
2471
|
-
|
|
2503
|
+
if node.upper is None:
|
|
2504
|
+
stop = -1 if step < 0 else length
|
|
2505
|
+
else:
|
|
2506
|
+
stop = adj.eval(node.upper).constant
|
|
2507
|
+
stop = min(max(stop, -length), length)
|
|
2508
|
+
stop = stop + length if stop < 0 else stop
|
|
2472
2509
|
|
|
2473
|
-
|
|
2474
|
-
|
|
2475
|
-
|
|
2476
|
-
|
|
2477
|
-
var = adj.eval(node.slice)
|
|
2478
|
-
var_name = var.label
|
|
2479
|
-
var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
|
|
2480
|
-
return var
|
|
2510
|
+
slice = adj.add_builtin_call("slice", (start, stop, step))
|
|
2511
|
+
indices.append(slice)
|
|
2512
|
+
else:
|
|
2513
|
+
indices.append(adj.eval(node))
|
|
2481
2514
|
|
|
2482
|
-
|
|
2515
|
+
return tuple(indices)
|
|
2516
|
+
else:
|
|
2517
|
+
return tuple(adj.eval(x) for x in nodes)
|
|
2483
2518
|
|
|
2519
|
+
def emit_indexing(adj, target, indices):
|
|
2484
2520
|
target_type = strip_reference(target.type)
|
|
2521
|
+
indices = adj.eval_indices(target_type, indices)
|
|
2522
|
+
|
|
2485
2523
|
if is_array(target_type):
|
|
2486
2524
|
if len(indices) == target_type.ndim:
|
|
2487
2525
|
# handles array loads (where each dimension has an index specified)
|
|
@@ -2520,47 +2558,116 @@ class Adjoint:
|
|
|
2520
2558
|
|
|
2521
2559
|
return out
|
|
2522
2560
|
|
|
2561
|
+
# from a list of lists of indices, strip the first `count` indices
|
|
2562
|
+
@staticmethod
|
|
2563
|
+
def strip_indices(indices, count):
|
|
2564
|
+
dim = count
|
|
2565
|
+
while count > 0:
|
|
2566
|
+
ij = indices[0]
|
|
2567
|
+
indices = indices[1:]
|
|
2568
|
+
count -= len(ij)
|
|
2569
|
+
|
|
2570
|
+
# report straddling like in `arr2d[0][1,2]` as a syntax error
|
|
2571
|
+
if count < 0:
|
|
2572
|
+
raise WarpCodegenError(
|
|
2573
|
+
f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
|
|
2574
|
+
)
|
|
2575
|
+
|
|
2576
|
+
return indices
|
|
2577
|
+
|
|
2578
|
+
def recurse_subscript(adj, node, indices):
|
|
2579
|
+
if isinstance(node, ast.Name):
|
|
2580
|
+
target = adj.eval(node)
|
|
2581
|
+
return target, indices
|
|
2582
|
+
|
|
2583
|
+
if isinstance(node, ast.Subscript):
|
|
2584
|
+
if hasattr(node.value, "attr") and node.value.attr == "adjoint":
|
|
2585
|
+
return adj.eval(node), indices
|
|
2586
|
+
|
|
2587
|
+
if isinstance(node.slice, ast.Tuple):
|
|
2588
|
+
ij = node.slice.elts
|
|
2589
|
+
elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
|
|
2590
|
+
# The node `ast.Index` is deprecated in Python 3.9.
|
|
2591
|
+
ij = node.slice.value.elts
|
|
2592
|
+
elif isinstance(node.slice, ast.ExtSlice):
|
|
2593
|
+
# The node `ast.ExtSlice` is deprecated in Python 3.9.
|
|
2594
|
+
ij = node.slice.dims
|
|
2595
|
+
else:
|
|
2596
|
+
ij = [node.slice]
|
|
2597
|
+
|
|
2598
|
+
indices = [ij, *indices] # prepend
|
|
2599
|
+
|
|
2600
|
+
target, indices = adj.recurse_subscript(node.value, indices)
|
|
2601
|
+
|
|
2602
|
+
target_type = strip_reference(target.type)
|
|
2603
|
+
if is_array(target_type):
|
|
2604
|
+
flat_indices = [i for ij in indices for i in ij]
|
|
2605
|
+
if len(flat_indices) > target_type.ndim:
|
|
2606
|
+
target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
|
|
2607
|
+
indices = adj.strip_indices(indices, target_type.ndim)
|
|
2608
|
+
|
|
2609
|
+
return target, indices
|
|
2610
|
+
|
|
2611
|
+
target = adj.eval(node)
|
|
2612
|
+
return target, indices
|
|
2613
|
+
|
|
2614
|
+
# returns the object being indexed, and the list of indices
|
|
2615
|
+
def eval_subscript(adj, node):
|
|
2616
|
+
target, indices = adj.recurse_subscript(node, [])
|
|
2617
|
+
flat_indices = [i for ij in indices for i in ij]
|
|
2618
|
+
return target, flat_indices
|
|
2619
|
+
|
|
2620
|
+
def emit_Subscript(adj, node):
|
|
2621
|
+
if hasattr(node.value, "attr") and node.value.attr == "adjoint":
|
|
2622
|
+
# handle adjoint of a variable, i.e. wp.adjoint[var]
|
|
2623
|
+
node.slice.is_adjoint = True
|
|
2624
|
+
var = adj.eval(node.slice)
|
|
2625
|
+
var_name = var.label
|
|
2626
|
+
var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
|
|
2627
|
+
return var
|
|
2628
|
+
|
|
2629
|
+
target, indices = adj.eval_subscript(node)
|
|
2630
|
+
|
|
2631
|
+
return adj.emit_indexing(target, indices)
|
|
2632
|
+
|
|
2523
2633
|
def emit_Assign(adj, node):
|
|
2524
2634
|
if len(node.targets) != 1:
|
|
2525
2635
|
raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
|
|
2526
2636
|
|
|
2527
|
-
|
|
2637
|
+
# Check if the rhs corresponds to an unsupported construct.
|
|
2638
|
+
# Tuples are supported in the context of assigning multiple variables
|
|
2639
|
+
# at once, but not for simple assignments like `x = (1, 2, 3)`.
|
|
2640
|
+
# Therefore, we need to catch this specific case here instead of
|
|
2641
|
+
# more generally in `adj.eval()`.
|
|
2642
|
+
if isinstance(node.value, ast.List):
|
|
2643
|
+
raise WarpCodegenError(
|
|
2644
|
+
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2645
|
+
)
|
|
2528
2646
|
|
|
2529
|
-
|
|
2530
|
-
# Check if the rhs corresponds to an unsupported construct.
|
|
2531
|
-
# Tuples are supported in the context of assigning multiple variables
|
|
2532
|
-
# at once, but not for simple assignments like `x = (1, 2, 3)`.
|
|
2533
|
-
# Therefore, we need to catch this specific case here instead of
|
|
2534
|
-
# more generally in `adj.eval()`.
|
|
2535
|
-
if isinstance(node.value, ast.List):
|
|
2536
|
-
raise WarpCodegenError(
|
|
2537
|
-
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2538
|
-
)
|
|
2647
|
+
lhs = node.targets[0]
|
|
2539
2648
|
|
|
2540
|
-
|
|
2541
|
-
if isinstance(lhs, ast.Tuple):
|
|
2649
|
+
if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
|
|
2542
2650
|
# record the expected number of outputs on the node
|
|
2543
2651
|
# we do this so we can decide which function to
|
|
2544
2652
|
# call based on the number of expected outputs
|
|
2545
|
-
|
|
2546
|
-
node.value.expects = len(lhs.elts)
|
|
2653
|
+
node.value.expects = len(lhs.elts)
|
|
2547
2654
|
|
|
2548
|
-
|
|
2549
|
-
|
|
2550
|
-
|
|
2551
|
-
|
|
2552
|
-
|
|
2655
|
+
# evaluate rhs
|
|
2656
|
+
if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
|
|
2657
|
+
rhs = [adj.eval(v) for v in node.value.elts]
|
|
2658
|
+
else:
|
|
2659
|
+
rhs = adj.eval(node.value)
|
|
2660
|
+
|
|
2661
|
+
# handle the case where we are assigning multiple output variables
|
|
2662
|
+
if isinstance(lhs, ast.Tuple):
|
|
2663
|
+
subtype = getattr(rhs, "type", None)
|
|
2553
2664
|
|
|
2554
|
-
subtype = getattr(out, "type", None)
|
|
2555
2665
|
if isinstance(subtype, warp.types.tuple_t):
|
|
2556
|
-
if len(
|
|
2666
|
+
if len(rhs.type.types) != len(lhs.elts):
|
|
2557
2667
|
raise WarpCodegenError(
|
|
2558
|
-
f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(
|
|
2668
|
+
f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
|
|
2559
2669
|
)
|
|
2560
|
-
|
|
2561
|
-
out = tuple(
|
|
2562
|
-
adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
|
|
2563
|
-
)
|
|
2670
|
+
rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))
|
|
2564
2671
|
|
|
2565
2672
|
names = []
|
|
2566
2673
|
for v in lhs.elts:
|
|
@@ -2571,11 +2678,12 @@ class Adjoint:
|
|
|
2571
2678
|
"Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
|
|
2572
2679
|
)
|
|
2573
2680
|
|
|
2574
|
-
if len(names) != len(
|
|
2681
|
+
if len(names) != len(rhs):
|
|
2575
2682
|
raise WarpCodegenError(
|
|
2576
|
-
f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(
|
|
2683
|
+
f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
|
|
2577
2684
|
)
|
|
2578
2685
|
|
|
2686
|
+
out = rhs
|
|
2579
2687
|
for name, rhs in zip(names, out):
|
|
2580
2688
|
if name in adj.symbols:
|
|
2581
2689
|
if not types_equal(rhs.type, adj.symbols[name].type):
|
|
@@ -2587,8 +2695,6 @@ class Adjoint:
|
|
|
2587
2695
|
|
|
2588
2696
|
# handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
|
|
2589
2697
|
elif isinstance(lhs, ast.Subscript):
|
|
2590
|
-
rhs = adj.eval(node.value)
|
|
2591
|
-
|
|
2592
2698
|
if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
|
|
2593
2699
|
# handle adjoint of a variable, i.e. wp.adjoint[var]
|
|
2594
2700
|
lhs.slice.is_adjoint = True
|
|
@@ -2600,6 +2706,7 @@ class Adjoint:
|
|
|
2600
2706
|
target, indices = adj.eval_subscript(lhs)
|
|
2601
2707
|
|
|
2602
2708
|
target_type = strip_reference(target.type)
|
|
2709
|
+
indices = adj.eval_indices(target_type, indices)
|
|
2603
2710
|
|
|
2604
2711
|
if is_array(target_type):
|
|
2605
2712
|
adj.add_builtin_call("array_store", [target, *indices, rhs])
|
|
@@ -2621,14 +2728,11 @@ class Adjoint:
|
|
|
2621
2728
|
or type_is_transformation(target_type)
|
|
2622
2729
|
):
|
|
2623
2730
|
# recursively unwind AST, stopping at penultimate node
|
|
2624
|
-
|
|
2625
|
-
while hasattr(
|
|
2626
|
-
|
|
2627
|
-
node = node.value
|
|
2628
|
-
else:
|
|
2629
|
-
break
|
|
2731
|
+
root = lhs
|
|
2732
|
+
while hasattr(root.value, "value"):
|
|
2733
|
+
root = root.value
|
|
2630
2734
|
# lhs is updating a variable adjoint (i.e. wp.adjoint[var])
|
|
2631
|
-
if hasattr(
|
|
2735
|
+
if hasattr(root, "attr") and root.attr == "adjoint":
|
|
2632
2736
|
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2633
2737
|
adj.add_builtin_call("store", [attr, rhs])
|
|
2634
2738
|
return
|
|
@@ -2666,9 +2770,6 @@ class Adjoint:
|
|
|
2666
2770
|
# symbol name
|
|
2667
2771
|
name = lhs.id
|
|
2668
2772
|
|
|
2669
|
-
# evaluate rhs
|
|
2670
|
-
rhs = adj.eval(node.value)
|
|
2671
|
-
|
|
2672
2773
|
# check type matches if symbol already defined
|
|
2673
2774
|
if name in adj.symbols:
|
|
2674
2775
|
if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
|
|
@@ -2689,7 +2790,6 @@ class Adjoint:
|
|
|
2689
2790
|
adj.symbols[name] = out
|
|
2690
2791
|
|
|
2691
2792
|
elif isinstance(lhs, ast.Attribute):
|
|
2692
|
-
rhs = adj.eval(node.value)
|
|
2693
2793
|
aggregate = adj.eval(lhs.value)
|
|
2694
2794
|
aggregate_type = strip_reference(aggregate.type)
|
|
2695
2795
|
|
|
@@ -2777,9 +2877,9 @@ class Adjoint:
|
|
|
2777
2877
|
new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
|
|
2778
2878
|
adj.eval(new_node)
|
|
2779
2879
|
|
|
2780
|
-
|
|
2781
|
-
rhs = adj.eval(node.value)
|
|
2880
|
+
rhs = adj.eval(node.value)
|
|
2782
2881
|
|
|
2882
|
+
if isinstance(lhs, ast.Subscript):
|
|
2783
2883
|
# wp.adjoint[var] appears in custom grad functions, and does not require
|
|
2784
2884
|
# special consideration in the AugAssign case
|
|
2785
2885
|
if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
|
|
@@ -2789,6 +2889,7 @@ class Adjoint:
|
|
|
2789
2889
|
target, indices = adj.eval_subscript(lhs)
|
|
2790
2890
|
|
|
2791
2891
|
target_type = strip_reference(target.type)
|
|
2892
|
+
indices = adj.eval_indices(target_type, indices)
|
|
2792
2893
|
|
|
2793
2894
|
if is_array(target_type):
|
|
2794
2895
|
# target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
|
|
@@ -2861,7 +2962,6 @@ class Adjoint:
|
|
|
2861
2962
|
|
|
2862
2963
|
elif isinstance(lhs, ast.Name):
|
|
2863
2964
|
target = adj.eval(node.target)
|
|
2864
|
-
rhs = adj.eval(node.value)
|
|
2865
2965
|
|
|
2866
2966
|
if is_tile(target.type) and is_tile(rhs.type):
|
|
2867
2967
|
if isinstance(node.op, ast.Add):
|
|
@@ -3163,6 +3263,8 @@ class Adjoint:
|
|
|
3163
3263
|
|
|
3164
3264
|
try:
|
|
3165
3265
|
value = eval(code_to_eval, vars_dict)
|
|
3266
|
+
if isinstance(value, (enum.IntEnum, enum.IntFlag)):
|
|
3267
|
+
value = int(value)
|
|
3166
3268
|
if warp.config.verbose:
|
|
3167
3269
|
print(f"Evaluated static command: {static_code} = {value}")
|
|
3168
3270
|
except NameError as e:
|
|
@@ -3373,6 +3475,11 @@ cuda_module_header = """
|
|
|
3373
3475
|
#define WP_NO_CRT
|
|
3374
3476
|
#include "builtin.h"
|
|
3375
3477
|
|
|
3478
|
+
// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
|
|
3479
|
+
#if defined(__CUDACC__) && !defined(_MSC_VER)
|
|
3480
|
+
#define __debugbreak() __brkpt()
|
|
3481
|
+
#endif
|
|
3482
|
+
|
|
3376
3483
|
// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
|
|
3377
3484
|
#define float(x) cast_float(x)
|
|
3378
3485
|
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
|
|
@@ -3410,6 +3517,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
|
|
|
3410
3517
|
{{
|
|
3411
3518
|
{reverse_body}}}
|
|
3412
3519
|
|
|
3520
|
+
// Required when compiling adjoints.
|
|
3521
|
+
CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
|
|
3522
|
+
{{
|
|
3523
|
+
return {name}();
|
|
3524
|
+
}}
|
|
3525
|
+
|
|
3413
3526
|
CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
|
|
3414
3527
|
{{
|
|
3415
3528
|
{atomic_add_body}}}
|
|
@@ -3490,7 +3603,8 @@ cuda_kernel_template_backward = """
|
|
|
3490
3603
|
cpu_kernel_template_forward = """
|
|
3491
3604
|
|
|
3492
3605
|
void {name}_cpu_kernel_forward(
|
|
3493
|
-
{forward_args}
|
|
3606
|
+
{forward_args},
|
|
3607
|
+
wp_args_{name} *_wp_args)
|
|
3494
3608
|
{{
|
|
3495
3609
|
{forward_body}}}
|
|
3496
3610
|
|
|
@@ -3499,7 +3613,9 @@ void {name}_cpu_kernel_forward(
|
|
|
3499
3613
|
cpu_kernel_template_backward = """
|
|
3500
3614
|
|
|
3501
3615
|
void {name}_cpu_kernel_backward(
|
|
3502
|
-
{reverse_args}
|
|
3616
|
+
{reverse_args},
|
|
3617
|
+
wp_args_{name} *_wp_args,
|
|
3618
|
+
wp_args_{name} *_wp_adj_args)
|
|
3503
3619
|
{{
|
|
3504
3620
|
{reverse_body}}}
|
|
3505
3621
|
|
|
@@ -3511,15 +3627,15 @@ extern "C" {{
|
|
|
3511
3627
|
|
|
3512
3628
|
// Python CPU entry points
|
|
3513
3629
|
WP_API void {name}_cpu_forward(
|
|
3514
|
-
|
|
3630
|
+
wp::launch_bounds_t dim,
|
|
3631
|
+
wp_args_{name} *_wp_args)
|
|
3515
3632
|
{{
|
|
3516
3633
|
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3517
3634
|
{{
|
|
3518
3635
|
// init shared memory allocator
|
|
3519
3636
|
wp::tile_alloc_shared(0, true);
|
|
3520
3637
|
|
|
3521
|
-
{name}_cpu_kernel_forward(
|
|
3522
|
-
{forward_params});
|
|
3638
|
+
{name}_cpu_kernel_forward(dim, task_index, _wp_args);
|
|
3523
3639
|
|
|
3524
3640
|
// check shared memory allocator
|
|
3525
3641
|
wp::tile_alloc_shared(0, false, true);
|
|
@@ -3536,15 +3652,16 @@ cpu_module_template_backward = """
|
|
|
3536
3652
|
extern "C" {{
|
|
3537
3653
|
|
|
3538
3654
|
WP_API void {name}_cpu_backward(
|
|
3539
|
-
|
|
3655
|
+
wp::launch_bounds_t dim,
|
|
3656
|
+
wp_args_{name} *_wp_args,
|
|
3657
|
+
wp_args_{name} *_wp_adj_args)
|
|
3540
3658
|
{{
|
|
3541
3659
|
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3542
3660
|
{{
|
|
3543
3661
|
// initialize shared memory allocator
|
|
3544
3662
|
wp::tile_alloc_shared(0, true);
|
|
3545
3663
|
|
|
3546
|
-
{name}_cpu_kernel_backward(
|
|
3547
|
-
{reverse_params});
|
|
3664
|
+
{name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);
|
|
3548
3665
|
|
|
3549
3666
|
// check shared memory allocator
|
|
3550
3667
|
wp::tile_alloc_shared(0, false, true);
|
|
@@ -3575,7 +3692,7 @@ def constant_str(value):
|
|
|
3575
3692
|
# special case for float16, which is stored as uint16 in the ctypes.Array
|
|
3576
3693
|
from warp.context import runtime
|
|
3577
3694
|
|
|
3578
|
-
scalar_value = runtime.core.
|
|
3695
|
+
scalar_value = runtime.core.wp_half_bits_to_float
|
|
3579
3696
|
else:
|
|
3580
3697
|
|
|
3581
3698
|
def scalar_value(x):
|
|
@@ -3713,8 +3830,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
|
|
|
3713
3830
|
|
|
3714
3831
|
indent_block = " " * indent
|
|
3715
3832
|
|
|
3716
|
-
# primal vars
|
|
3717
3833
|
lines = []
|
|
3834
|
+
|
|
3835
|
+
# argument vars
|
|
3836
|
+
if device == "cpu" and func_type == "kernel":
|
|
3837
|
+
lines += ["//---------\n"]
|
|
3838
|
+
lines += ["// argument vars\n"]
|
|
3839
|
+
|
|
3840
|
+
for var in adj.args:
|
|
3841
|
+
lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
|
|
3842
|
+
|
|
3843
|
+
# primal vars
|
|
3718
3844
|
lines += ["//---------\n"]
|
|
3719
3845
|
lines += ["// primal vars\n"]
|
|
3720
3846
|
|
|
@@ -3758,6 +3884,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3758
3884
|
|
|
3759
3885
|
lines = []
|
|
3760
3886
|
|
|
3887
|
+
# argument vars
|
|
3888
|
+
if device == "cpu" and func_type == "kernel":
|
|
3889
|
+
lines += ["//---------\n"]
|
|
3890
|
+
lines += ["// argument vars\n"]
|
|
3891
|
+
|
|
3892
|
+
for var in adj.args:
|
|
3893
|
+
lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
|
|
3894
|
+
|
|
3895
|
+
for var in adj.args:
|
|
3896
|
+
lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
|
|
3897
|
+
|
|
3761
3898
|
# primal vars
|
|
3762
3899
|
lines += ["//---------\n"]
|
|
3763
3900
|
lines += ["// primal vars\n"]
|
|
@@ -3849,6 +3986,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3849
3986
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|
|
3850
3987
|
f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
|
|
3851
3988
|
)
|
|
3989
|
+
elif (
|
|
3990
|
+
isinstance(adj.return_var[0].type, warp.types.fixedarray)
|
|
3991
|
+
and type(adj.arg_types["return"]) is warp.types.array
|
|
3992
|
+
):
|
|
3993
|
+
# If the return statement yields a `fixedarray` while the function is annotated
|
|
3994
|
+
# to return a standard `array`, then raise an error since the `fixedarray` storage
|
|
3995
|
+
# allocated on the stack will be freed once the function exits, meaning that the
|
|
3996
|
+
# resulting `array` instance will point to an invalid data.
|
|
3997
|
+
raise WarpCodegenError(
|
|
3998
|
+
f"The function `{adj.fun_name}` returns a fixed-size array "
|
|
3999
|
+
f"whereas it has its return type annotated as "
|
|
4000
|
+
f"`{warp.context.type_str(adj.arg_types['return'])}`."
|
|
4001
|
+
)
|
|
3852
4002
|
|
|
3853
4003
|
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
3854
4004
|
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
@@ -3927,10 +4077,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3927
4077
|
if adj.custom_reverse_mode:
|
|
3928
4078
|
reverse_body = "\t// user-defined adjoint code\n" + forward_body
|
|
3929
4079
|
else:
|
|
3930
|
-
if options.get("enable_backward", True):
|
|
4080
|
+
if options.get("enable_backward", True) and adj.used_by_backward_kernel:
|
|
3931
4081
|
reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
|
|
3932
4082
|
else:
|
|
3933
|
-
reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
|
|
4083
|
+
reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
|
|
3934
4084
|
s += reverse_template.format(
|
|
3935
4085
|
name=c_func_name,
|
|
3936
4086
|
return_type=return_type,
|
|
@@ -4022,6 +4172,13 @@ def codegen_kernel(kernel, device, options):
|
|
|
4022
4172
|
|
|
4023
4173
|
adj = kernel.adj
|
|
4024
4174
|
|
|
4175
|
+
args_struct = ""
|
|
4176
|
+
if device == "cpu":
|
|
4177
|
+
args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
|
|
4178
|
+
for i in adj.args:
|
|
4179
|
+
args_struct += f" {i.ctype()} {i.label};\n"
|
|
4180
|
+
args_struct += "};\n"
|
|
4181
|
+
|
|
4025
4182
|
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
4026
4183
|
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
4027
4184
|
# a direct mapping to a Python source line.
|
|
@@ -4047,9 +4204,9 @@ def codegen_kernel(kernel, device, options):
|
|
|
4047
4204
|
forward_args = ["wp::launch_bounds_t dim"]
|
|
4048
4205
|
if device == "cpu":
|
|
4049
4206
|
forward_args.append("size_t task_index")
|
|
4050
|
-
|
|
4051
|
-
|
|
4052
|
-
|
|
4207
|
+
else:
|
|
4208
|
+
for arg in adj.args:
|
|
4209
|
+
forward_args.append(arg.ctype() + " var_" + arg.label)
|
|
4053
4210
|
|
|
4054
4211
|
forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
|
|
4055
4212
|
template_fmt_args.update(
|
|
@@ -4066,17 +4223,16 @@ def codegen_kernel(kernel, device, options):
|
|
|
4066
4223
|
reverse_args = ["wp::launch_bounds_t dim"]
|
|
4067
4224
|
if device == "cpu":
|
|
4068
4225
|
reverse_args.append("size_t task_index")
|
|
4069
|
-
|
|
4070
|
-
|
|
4071
|
-
|
|
4072
|
-
|
|
4073
|
-
|
|
4074
|
-
|
|
4075
|
-
|
|
4076
|
-
|
|
4077
|
-
|
|
4078
|
-
|
|
4079
|
-
reverse_args.append(arg.ctype() + " adj_" + arg.label)
|
|
4226
|
+
else:
|
|
4227
|
+
for arg in adj.args:
|
|
4228
|
+
reverse_args.append(arg.ctype() + " var_" + arg.label)
|
|
4229
|
+
for arg in adj.args:
|
|
4230
|
+
# indexed array gradients are regular arrays
|
|
4231
|
+
if isinstance(arg.type, indexedarray):
|
|
4232
|
+
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
4233
|
+
reverse_args.append(_arg.ctype() + " adj_" + arg.label)
|
|
4234
|
+
else:
|
|
4235
|
+
reverse_args.append(arg.ctype() + " adj_" + arg.label)
|
|
4080
4236
|
|
|
4081
4237
|
reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
|
|
4082
4238
|
template_fmt_args.update(
|
|
@@ -4088,7 +4244,7 @@ def codegen_kernel(kernel, device, options):
|
|
|
4088
4244
|
template += template_backward
|
|
4089
4245
|
|
|
4090
4246
|
s = template.format(**template_fmt_args)
|
|
4091
|
-
return s
|
|
4247
|
+
return args_struct + s
|
|
4092
4248
|
|
|
4093
4249
|
|
|
4094
4250
|
def codegen_module(kernel, device, options):
|
|
@@ -4099,59 +4255,14 @@ def codegen_module(kernel, device, options):
|
|
|
4099
4255
|
options = dict(options)
|
|
4100
4256
|
options.update(kernel.options)
|
|
4101
4257
|
|
|
4102
|
-
adj = kernel.adj
|
|
4103
|
-
|
|
4104
4258
|
template = ""
|
|
4105
4259
|
template_fmt_args = {
|
|
4106
4260
|
"name": kernel.get_mangled_name(),
|
|
4107
4261
|
}
|
|
4108
4262
|
|
|
4109
|
-
# build forward signature
|
|
4110
|
-
forward_args = ["wp::launch_bounds_t dim"]
|
|
4111
|
-
forward_params = ["dim", "task_index"]
|
|
4112
|
-
|
|
4113
|
-
for arg in adj.args:
|
|
4114
|
-
if hasattr(arg.type, "_wp_generic_type_str_"):
|
|
4115
|
-
# vectors and matrices are passed from Python by pointer
|
|
4116
|
-
forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
|
|
4117
|
-
forward_params.append(f"*var_{arg.label}")
|
|
4118
|
-
else:
|
|
4119
|
-
forward_args.append(f"{arg.ctype()} var_{arg.label}")
|
|
4120
|
-
forward_params.append("var_" + arg.label)
|
|
4121
|
-
|
|
4122
|
-
template_fmt_args.update(
|
|
4123
|
-
{
|
|
4124
|
-
"forward_args": indent(forward_args),
|
|
4125
|
-
"forward_params": indent(forward_params, 3),
|
|
4126
|
-
}
|
|
4127
|
-
)
|
|
4128
4263
|
template += cpu_module_template_forward
|
|
4129
4264
|
|
|
4130
4265
|
if options["enable_backward"]:
|
|
4131
|
-
# build reverse signature
|
|
4132
|
-
reverse_args = [*forward_args]
|
|
4133
|
-
reverse_params = [*forward_params]
|
|
4134
|
-
|
|
4135
|
-
for arg in adj.args:
|
|
4136
|
-
if isinstance(arg.type, indexedarray):
|
|
4137
|
-
# indexed array gradients are regular arrays
|
|
4138
|
-
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
4139
|
-
reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
|
|
4140
|
-
reverse_params.append(f"adj_{_arg.label}")
|
|
4141
|
-
elif hasattr(arg.type, "_wp_generic_type_str_"):
|
|
4142
|
-
# vectors and matrices are passed from Python by pointer
|
|
4143
|
-
reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
|
|
4144
|
-
reverse_params.append(f"*adj_{arg.label}")
|
|
4145
|
-
else:
|
|
4146
|
-
reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
|
|
4147
|
-
reverse_params.append(f"adj_{arg.label}")
|
|
4148
|
-
|
|
4149
|
-
template_fmt_args.update(
|
|
4150
|
-
{
|
|
4151
|
-
"reverse_args": indent(reverse_args),
|
|
4152
|
-
"reverse_params": indent(reverse_params, 3),
|
|
4153
|
-
}
|
|
4154
|
-
)
|
|
4155
4266
|
template += cpu_module_template_backward
|
|
4156
4267
|
|
|
4157
4268
|
s = template.format(**template_fmt_args)
|