warp-lang: warp_lang-1.4.2-py3-none-macosx_10_13_universal2.whl → warp_lang-1.5.1-py3-none-macosx_10_13_universal2.whl
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +160 -133
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
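The new warp/examples/tile/*.py scripts, the warp/tests/test_tile*.py suites, and the native tile headers (tile.h, tile_gemm.h, tile_reduce.h) listed above all belong to the tile programming model introduced in this release. As rough orientation only, a minimal tiled GEMM kernel written against the 1.5 tile API that these examples use (wp.tile_zeros, wp.tile_load, wp.tile_matmul, wp.tile_store, wp.launch_tiled) might look like the sketch below; the kernel and constant names are hypothetical, not taken from the package:

    import warp as wp

    TILE_M, TILE_N, TILE_K = 32, 32, 8

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        # one (i, j) output tile per block of threads; the block cooperates on the tile ops
        i, j = wp.tid()
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=float)
        for k in range(int(A.shape[1] / TILE_K)):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
            wp.tile_matmul(a, b, acc)  # accumulate the partial product into acc
        wp.tile_store(C, i, j, acc)

    # launched with one thread block per output tile, e.g.
    # wp.launch_tiled(tile_gemm, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, C], block_dim=64)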
warp/codegen.py
CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 import warp.config
 from warp.types import *
 
+# used as a globally accessible copy
+# of current compile options (block_dim) etc
+options = {}
+
 
 class WarpCodegenError(RuntimeError):
     def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
     return None
 
 
+def get_type_origin(tp):
+    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+    return getattr(tp, "__origin__", None)
+
+
+def get_type_args(tp):
+    # Compatible version of `typing.get_args()` for Python 3.7 and older.
+    return getattr(tp, "__args__", ())
+
+
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -637,6 +651,8 @@ class Var:
             dtypestr = f"wp::{t.dtype.__name__}"
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
+        elif is_tile(t):
+            return t.ctype()
         elif isinstance(t, Struct):
             return t.native_name
         elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
-            adj.arg_types = argspec.annotations
+            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
+        # Collect the LTOIR required at link-time
+        adj.ltoirs = []
+
+    # allocate extra space for a function call that requires its
+    # own shared memory space, we treat shared memory as a stack
+    # where each function pushes and pops space off, the extra
+    # quantity is the 'roofline' amount required for the entire kernel
+    def alloc_shared_extra(adj, num_bytes):
+        adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+    # returns the total number of bytes for a function
+    # based on it's own requirements + worst case
+    # requirements of any dependent functions
+    def get_total_required_shared(adj):
+        total_shared = 0
+
+        for var in adj.variables:
+            if is_tile(var.type) and var.type.storage == "shared":
+                total_shared += var.type.size_in_bytes()
+
+        return total_shared + adj.max_required_extra_shared_memory
+
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
         # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
@@ -934,6 +972,9 @@ class Adjoint:
         else:
             adj.builder_options = default_builder_options
 
+        global options
+        options = adj.builder_options
+
         adj.symbols = {}  # map from symbols to adjoint variables
         adj.variables = []  # list of local variables (in order)
 
@@ -953,6 +994,9 @@ class Adjoint:
         # used to generate new label indices
         adj.label_count = 0
 
+        # tracks how much additional shared memory is required by any dependent function calls
+        adj.max_required_extra_shared_memory = 0
+
         # update symbol map for each argument
         for a in adj.args:
             adj.symbols[a.label] = a
@@ -969,6 +1013,7 @@ class Adjoint:
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
                 adj.skip_build = True
+                adj.builder = None
                 raise e
 
         if builder is not None:
@@ -978,6 +1023,9 @@
                 elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                     builder.build_struct_recursive(a.type.dtype)
 
+        # release builder reference for GC
+        adj.builder = None
+
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
         # output var is always the 0th index
@@ -994,9 +1042,9 @@
             if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
                 if prefix == "var":
-                    arg_strs.append(a.native_func)
+                    arg_strs.append(f"{a.namespace}{a.native_func}")
                 else:
-                    arg_strs.append(f"{prefix}_{a.native_func}")
+                    arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
             elif is_reference(a.type):
                 arg_strs.append(f"{prefix}_{a}")
             elif isinstance(a, Var):
@@ -1127,25 +1175,25 @@ class Adjoint:
         left = adj.load(left)
         s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
 
-
+        prev_comp_var = None
 
         for op, comp in zip(op_strings, comps):
             comp_chainable = op_str_is_chainable(op)
-            if comp_chainable and
-                # We
-                if
-
-
-                    s += "&& (" +
+            if comp_chainable and prev_comp_var:
+                # We restrict chaining to operands of the same type
+                if prev_comp_var.type is comp.type:
+                    prev_comp_var = adj.load(prev_comp_var)
+                    comp_var = adj.load(comp)
+                    s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
                 else:
                     raise WarpCodegenTypeError(
-                        f"Cannot chain comparisons of unequal types: {
+                        f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                     )
             else:
-
-                s += op + " " +
+                comp_var = adj.load(comp)
+                s += op + " " + comp_var.emit() + ") "
 
-
+            prev_comp_var = comp_var
 
         s = s.rstrip() + ";"
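The rewritten comparison emitter above unrolls Python comparison chains and now restricts chaining to operands of the same type, tracking the previous operand in prev_comp_var. A hypothetical kernel (not part of the package) that exercises the supported case:

    import warp as wp

    @wp.kernel
    def inside_unit_interval(x: wp.array(dtype=float), flags: wp.array(dtype=int)):
        i = wp.tid()
        v = x[i]
        # both comparisons use float operands, so the chain is emitted as
        # ((0.0 <= v) && (v <= 1.0)); mixing operand types raises WarpCodegenTypeError
        if 0.0 <= v <= 1.0:
            flags[i] = 1
        else:
            flags[i] = 0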
@@ -1278,15 +1326,34 @@
             bound_arg_values,
         )
 
-
-
-        #
-
-
-
-        #
-
+        # immediately allocate output variables so we can pass them into the dispatch method
+        if return_type is None:
+            # void function
+            output = None
+            output_list = []
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+            # single return value function
+            if isinstance(return_type, Sequence):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
+        else:
+            # multiple return value function
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
 
+        # If we have a built-in that requires special handling to dispatch
+        # the arguments to the underlying C++ function, then we can resolve
+        # these using the `dispatch_func`. Since this is only called from
+        # within codegen, we pass it directly `codegen.Var` objects,
+        # which allows for some more advanced resolution to be performed,
+        # for example by checking whether an argument corresponds to
+        # a literal value or references a variable.
+        if func.lto_dispatch_func is not None:
+            func_args, template_args, ltoirs = func.lto_dispatch_func(
+                func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+            )
+        elif func.dispatch_func is not None:
             func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
         else:
             func_args = tuple(bound_args.values())
@@ -1299,20 +1366,18 @@
         fwd_args = []
         for func_arg in func_args:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
-
+                func_arg_var = adj.load(func_arg)
+            else:
+                func_arg_var = func_arg
 
-            # if the argument is a function, build it recursively
-            if isinstance(
-                adj.builder.build_function(
+            # if the argument is a function (and not a builtin), then build it recursively
+            if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                adj.builder.build_function(func_arg_var)
 
-            fwd_args.append(strip_reference(
+            fwd_args.append(strip_reference(func_arg_var))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
-
-            output = None
-            output_list = []
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             )
@@ -1322,12 +1387,6 @@
 
         elif not isinstance(return_type, Sequence) or len(return_type) == 1:
             # handle simple function (one output)
-
-            if isinstance(return_type, Sequence):
-                return_type = return_type[0]
-            output = adj.add_var(return_type)
-            output_list = [output]
-
             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             replay_call = forward_call
             if func.custom_replay_func is not None:
@@ -1335,10 +1394,6 @@
 
         else:
             # handle multiple value functions
-
-            output = [adj.add_var(v) for v in return_type]
-            output_list = output
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
             )
@@ -1366,6 +1421,11 @@
             reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
             adj.add_reverse(reverse_call)
 
+        # update our smem roofline requirements based on any
+        # shared memory required by the dependent function call
+        if not func.is_builtin():
+            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
         return output
 
     def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1466,7 +1526,10 @@
 
         # zero adjoints
         for i in body_block.vars:
-
+            if is_tile(i.type):
+                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+            else:
+                reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -2206,7 +2269,7 @@
 
     # returns the object being indexed, and the list of indices
     def eval_subscript(adj, node):
-        # We want to coalesce multi-
+        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
         # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
         # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
         root = node
@@ -2286,6 +2349,14 @@
             out.is_read = target.is_read
             out.is_write = target.is_write
 
+        elif is_tile(target_type):
+            if len(indices) == 2:
+                # handles extracting a single element from a tile
+                out = adj.add_builtin_call("tile_extract", [target, *indices])
+            else:
+                # handles tile views
+                out = adj.add_builtin_call("tile_view", [target, *indices])
+
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
             out = adj.add_builtin_call("extract", [target, *indices])
@@ -2500,8 +2571,10 @@
             adj.return_var = ()
             for ret in var:
                 if is_reference(ret.type):
-
-
+                    ret_var = adj.add_builtin_call("copy", [ret])
+                else:
+                    ret_var = ret
+                adj.return_var += (ret_var,)
 
         adj.add_return(adj.return_var)
 
@@ -2527,11 +2600,22 @@
         target_type = strip_reference(target.type)
 
         if is_array(target_type):
-            #
-            if target_type.dtype
+            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+            if target_type.dtype in warp.types.non_atomic_types:
                 make_new_assign_statement()
                 return
 
+            # the same holds true for vecs/mats/quats that are composed of these types
+            if (
+                type_is_vector(target_type.dtype)
+                or type_is_quaternion(target_type.dtype)
+                or type_is_matrix(target_type.dtype)
+            ):
+                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                if dtype in warp.types.non_atomic_types:
+                    make_new_assign_statement()
+                    return
+
             kernel_name = adj.fun_name
             filename = adj.filename
             lineno = adj.lineno + adj.fun_lineno
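The checks above route in-place array updates on 8- and 16-bit element types (and on vectors, matrices, and quaternions built from them) to a plain assignment, since no atomic add exists for those widths. A hypothetical kernel illustrating the case that now takes the non-atomic path:

    import warp as wp

    @wp.kernel
    def bump(counts: wp.array(dtype=wp.uint8)):
        i = wp.tid()
        # wp.uint8 is listed in warp.types.non_atomic_types, so this augmented
        # assignment is generated as a load/add/store rather than an atomic add
        counts[i] += wp.uint8(1)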
@@ -2955,6 +3039,7 @@
 # code generation
 
 cpu_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2965,7 +3050,7 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid1d() wp::tid(task_index, dim)
 #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
 #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
 #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2973,6 +3058,7 @@ cpu_module_header = """
 """
 
 cuda_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2983,10 +3069,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(
-#define builtin_tid2d(x, y) wp::tid(x, y,
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z,
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w,
+#define builtin_tid1d() wp::tid(_idx, dim)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 """
@@ -3058,20 +3144,26 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t
-
-
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {forward_body} }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t
-
-
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {reverse_body} }}
 }}
@@ -3309,7 +3401,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3344,7 +3438,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3354,7 +3450,20 @@
     lines += ["// dual vars\n"]
 
     for var in adj.variables:
-
+        name = var.emit_adj()
+        ctype = var.ctype(value_type=True)
+
+        if is_tile(var.type):
+            if var.type.storage == "register":
+                lines += [
+                    f"{var.type.ctype()} {name}(0.0);\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+            elif var.type.storage == "shared":
+                lines += [
+                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+        else:
+            lines += [f"{ctype} {name} = {{}};\n"]
 
     # forward pass
     lines += ["//---------\n"]
@@ -3383,6 +3492,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if options is None:
         options = {}
 
+    if adj.return_var is not None and "return" in adj.arg_types:
+        if get_type_origin(adj.arg_types["return"]) is tuple:
+            if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+                    f"but the code returns {len(adj.return_var)} values."
+                )
+            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                )
+        elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns {len(adj.return_var)} values."
+            )
+        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+            )
+
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
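The validation added to codegen_func compares a user function's return annotation against what its body actually returns, including tuple annotations for multi-value returns. A hypothetical example of an annotated wp.func that passes the check:

    import warp as wp

    @wp.func
    def length2(x: float, y: float) -> float:
        # the annotated return type (float) must match the returned value,
        # otherwise module build now fails with WarpCodegenError
        return wp.sqrt(x * x + y * y)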
warp/config.py
CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.4.2"
+version: str = "1.5.1"
 """Warp version string"""
 
 verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
 """
 
 verify_cuda: bool = False
-"""If `True`, Warp will check for CUDA errors after every launch
+"""If `True`, Warp will check for CUDA errors after every launch operation.
 CUDA error verification cannot be used during graph capture. Has performance implications.
 """
 
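For reference, the two warp/config.py fields touched by this release are plain module attributes; a short usage sketch (hypothetical, not from the package):

    import warp as wp

    # enable per-launch CUDA error checking; not usable during graph capture
    # and it has a performance cost
    wp.config.verify_cuda = True

    wp.init()
    print(wp.config.version)  # "1.5.1" for this release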