warp-lang 1.4.1__py3-none-macosx_10_13_universal2.whl → 1.5.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 import warp.config
 from warp.types import *
 
+# used as a globally accessible copy
+# of current compile options (block_dim) etc
+options = {}
+
 
 class WarpCodegenError(RuntimeError):
     def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
     return None
 
 
+def get_type_origin(tp):
+    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+    return getattr(tp, "__origin__", None)
+
+
+def get_type_args(tp):
+    # Compatible version of `typing.get_args()` for Python 3.7 and older.
+    return getattr(tp, "__args__", ())
+
+
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
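
Note: `typing.get_origin()` and `typing.get_args()` were only added in Python 3.8, hence the attribute-based backports above. A standalone illustration of what they return (not part of the diff):

from typing import Tuple

ann = Tuple[int, float]
print(getattr(ann, "__origin__", None))  # <class 'tuple'>
print(getattr(ann, "__args__", ()))      # (<class 'int'>, <class 'float'>)

These results feed the new return-type validation in `codegen_func()` further down, where `get_type_origin(...) is tuple` detects tuple-annotated returns.
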
@@ -637,6 +651,8 @@ class Var:
             dtypestr = f"wp::{t.dtype.__name__}"
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
+        elif is_tile(t):
+            return t.ctype()
         elif isinstance(t, Struct):
             return t.native_name
         elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
-            adj.arg_types = argspec.annotations
+            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
+        # Collect the LTOIR required at link-time
+        adj.ltoirs = []
+
+    # allocate extra space for a function call that requires its
+    # own shared memory space, we treat shared memory as a stack
+    # where each function pushes and pops space off, the extra
+    # quantity is the 'roofline' amount required for the entire kernel
+    def alloc_shared_extra(adj, num_bytes):
+        adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+    # returns the total number of bytes for a function
+    # based on it's own requirements + worst case
+    # requirements of any dependent functions
+    def get_total_required_shared(adj):
+        total_shared = 0
+
+        for var in adj.variables:
+            if is_tile(var.type) and var.type.storage == "shared":
+                total_shared += var.type.size_in_bytes()
+
+        return total_shared + adj.max_required_extra_shared_memory
+
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
         # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
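
The comments above describe a stack discipline for shared memory: nested calls push and pop space, so sibling calls reuse the same region and only the deepest call path matters. A standalone sketch of that "roofline" accounting (hypothetical helper, not Warp API):

# a function needs its own shared-tile bytes plus the worst case over its
# callees, because stack-style allocation lets sibling calls reuse space
def total_required_shared(own_bytes, callee_requirements):
    return own_bytes + max(callee_requirements, default=0)

# kernel using 1 KiB itself and calling functions needing 2 KiB and 512 B:
print(total_required_shared(1024, [2048, 512]))  # 3072, not 1024 + 2048 + 512
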
@@ -934,12 +972,17 @@ class Adjoint:
         else:
             adj.builder_options = default_builder_options
 
+        global options
+        options = adj.builder_options
+
         adj.symbols = {}  # map from symbols to adjoint variables
         adj.variables = []  # list of local variables (in order)
 
         adj.return_var = None  # return type for function or kernel
         adj.loop_symbols = []  # symbols at the start of each loop
-        adj.loop_const_iter_symbols = []
+        adj.loop_const_iter_symbols = (
+            set()
+        )  # constant iteration variables for static loops (mutating them does not raise an error)
 
         # blocks
         adj.blocks = [Block()]
@@ -951,6 +994,9 @@ class Adjoint:
         # used to generate new label indices
         adj.label_count = 0
 
+        # tracks how much additional shared memory is required by any dependent function calls
+        adj.max_required_extra_shared_memory = 0
+
         # update symbol map for each argument
         for a in adj.args:
             adj.symbols[a.label] = a
@@ -967,6 +1013,7 @@ class Adjoint:
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
                 adj.skip_build = True
+                adj.builder = None
                 raise e
 
         if builder is not None:
@@ -976,6 +1023,9 @@ class Adjoint:
             elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                 builder.build_struct_recursive(a.type.dtype)
 
+        # release builder reference for GC
+        adj.builder = None
+
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
         # output var is always the 0th index
@@ -992,9 +1042,9 @@ class Adjoint:
             if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
                 if prefix == "var":
-                    arg_strs.append(a.native_func)
+                    arg_strs.append(f"{a.namespace}{a.native_func}")
                 else:
-                    arg_strs.append(f"{prefix}_{a.native_func}")
+                    arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
             elif is_reference(a.type):
                 arg_strs.append(f"{prefix}_{a}")
             elif isinstance(a, Var):
@@ -1276,15 +1326,34 @@ class Adjoint:
             bound_arg_values,
         )
 
-
-
-        #
-
-
-
-        #
-
+        # immediately allocate output variables so we can pass them into the dispatch method
+        if return_type is None:
+            # void function
+            output = None
+            output_list = []
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+            # single return value function
+            if isinstance(return_type, Sequence):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
+        else:
+            # multiple return value function
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
 
+        # If we have a built-in that requires special handling to dispatch
+        # the arguments to the underlying C++ function, then we can resolve
+        # these using the `dispatch_func`. Since this is only called from
+        # within codegen, we pass it directly `codegen.Var` objects,
+        # which allows for some more advanced resolution to be performed,
+        # for example by checking whether an argument corresponds to
+        # a literal value or references a variable.
+        if func.lto_dispatch_func is not None:
+            func_args, template_args, ltoirs = func.lto_dispatch_func(
+                func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+            )
+        elif func.dispatch_func is not None:
             func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
         else:
             func_args = tuple(bound_args.values())
@@ -1299,18 +1368,14 @@ class Adjoint:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
                 func_arg = adj.load(func_arg)
 
-            # if the argument is a function, build it recursively
-            if isinstance(func_arg, warp.context.Function):
+            # if the argument is a function (and not a builtin), then build it recursively
+            if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
                 adj.builder.build_function(func_arg)
 
             fwd_args.append(strip_reference(func_arg))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
-
-            output = None
-            output_list = []
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             )
@@ -1320,12 +1385,6 @@ class Adjoint:
 
         elif not isinstance(return_type, Sequence) or len(return_type) == 1:
             # handle simple function (one output)
-
-            if isinstance(return_type, Sequence):
-                return_type = return_type[0]
-            output = adj.add_var(return_type)
-            output_list = [output]
-
             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             replay_call = forward_call
             if func.custom_replay_func is not None:
@@ -1333,10 +1392,6 @@ class Adjoint:
 
         else:
             # handle multiple value functions
-
-            output = [adj.add_var(v) for v in return_type]
-            output_list = output
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
             )
@@ -1364,6 +1419,11 @@ class Adjoint:
             reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
             adj.add_reverse(reverse_call)
 
+        # update our smem roofline requirements based on any
+        # shared memory required by the dependent function call
+        if not func.is_builtin():
+            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
         return output
 
     def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1464,7 +1524,10 @@ class Adjoint:
 
         # zero adjoints
         for i in body_block.vars:
-            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
+            if is_tile(i.type):
+                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+            else:
+                reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -2000,22 +2063,11 @@ class Adjoint:
         )
         return range_call
 
-    def begin_record_constant_iter_symbols(adj):
-        if len(adj.loop_const_iter_symbols) > 0:
-            adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1])
-        else:
-            adj.loop_const_iter_symbols.append(set())
-
-    def end_record_constant_iter_symbols(adj):
-        if len(adj.loop_const_iter_symbols) > 0:
-            adj.loop_const_iter_symbols.pop()
-
     def record_constant_iter_symbol(adj, sym):
-
-        adj.loop_const_iter_symbols[-1].add(sym)
+        adj.loop_const_iter_symbols.add(sym)
 
     def is_constant_iter_symbol(adj, sym):
-        return
+        return sym in adj.loop_const_iter_symbols
 
     def emit_For(adj, node):
         # try and unroll simple range() statements that use constant args
@@ -2045,7 +2097,6 @@ class Adjoint:
         iter = adj.eval(node.iter)
 
         adj.symbols[node.target.id] = adj.begin_for(iter)
-        adj.begin_record_constant_iter_symbols()
 
         # for loops should be side-effect free, here we store a copy
         adj.loop_symbols.append(adj.symbols.copy())
@@ -2056,7 +2107,6 @@ class Adjoint:
 
         adj.materialize_redefinitions(adj.loop_symbols[-1])
         adj.loop_symbols.pop()
-        adj.end_record_constant_iter_symbols()
 
         adj.end_for(iter)
 
@@ -2217,7 +2267,7 @@ class Adjoint:
 
     # returns the object being indexed, and the list of indices
     def eval_subscript(adj, node):
-        # We want to coalesce multi-
+        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
         # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
         # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
         root = node
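
For illustration, a user-level sketch of the pattern this comment describes, assuming a 2D array of matrices:

import warp as wp

@wp.kernel
def read_entry(a: wp.array2d(dtype=wp.mat33), out: wp.array(dtype=float)):
    i = wp.tid()
    # the chained indexing is coalesced by codegen: evaluated as a[i, 0][1][2],
    # i.e. a single 2D array load followed by matrix element extraction
    out[i] = a[i][0][1][2]
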
@@ -2297,6 +2347,14 @@ class Adjoint:
             out.is_read = target.is_read
             out.is_write = target.is_write
 
+        elif is_tile(target_type):
+            if len(indices) == 2:
+                # handles extracting a single element from a tile
+                out = adj.add_builtin_call("tile_extract", [target, *indices])
+            else:
+                # handles tile views
+                out = adj.add_builtin_call("tile_view", [target, *indices])
+
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
             out = adj.add_builtin_call("extract", [target, *indices])
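
A sketch of how this surfaces in kernel code, loosely based on the 1.5 tile API exercised by the new warp/tests/test_tile.py (tile shape and load signature here are illustrative assumptions):

import warp as wp

TILE_M, TILE_N = 8, 8

@wp.kernel
def tile_index_demo(a: wp.array2d(dtype=float), out: wp.array(dtype=float)):
    # load one TILE_M x TILE_N tile (illustrative use of the 1.5 tile API)
    t = wp.tile_load(a, 0, 0, m=TILE_M, n=TILE_N)
    # two indices on a tile lower to the new `tile_extract` builtin,
    # while fewer indices than dimensions lower to `tile_view`
    out[0] = t[1, 2]
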
@@ -2538,11 +2596,22 @@ class Adjoint:
         target_type = strip_reference(target.type)
 
         if is_array(target_type):
-            #
-            if target_type.dtype
+            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+            if target_type.dtype in warp.types.non_atomic_types:
                 make_new_assign_statement()
                 return
 
+            # the same holds true for vecs/mats/quats that are composed of these types
+            if (
+                type_is_vector(target_type.dtype)
+                or type_is_quaternion(target_type.dtype)
+                or type_is_matrix(target_type.dtype)
+            ):
+                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                if dtype in warp.types.non_atomic_types:
+                    make_new_assign_statement()
+                    return
+
             kernel_name = adj.fun_name
             filename = adj.filename
             lineno = adj.lineno + adj.fun_lineno
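
In effect, augmented assignment on arrays of these small integer types (and of vectors, matrices, or quaternions built from them) now lowers to a plain read-modify-write instead of an atomic accumulation. A hedged user-level example:

import warp as wp

@wp.kernel
def accumulate(a: wp.array(dtype=wp.int8)):
    i = wp.tid()
    # int8 is among warp.types.non_atomic_types, so this compiles to a
    # plain load/add/store assignment rather than an atomic add
    a[i] += wp.int8(1)
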
@@ -2559,7 +2628,10 @@ class Adjoint:
             if warp.config.verify_autograd_array_access:
                 target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
             else:
-                make_new_assign_statement()
+                if warp.config.verbose:
+                    print(f"Warning: in-place op {node.op} is not differentiable")
+                make_new_assign_statement()
+                return
 
         # TODO
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
@@ -2963,6 +3035,7 @@
 # code generation
 
 cpu_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2973,7 +3046,7 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid1d() wp::tid(task_index, dim)
 #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
 #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
 #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2981,6 +3054,7 @@ cpu_module_header = """
 """
 
 cuda_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2991,10 +3065,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
-#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
+#define builtin_tid1d() wp::tid(_idx, dim)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 """
@@ -3066,20 +3140,26 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {forward_body} }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {reverse_body} }}
 }}
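
Both templates use the standard CUDA grid-stride loop. A small Python sketch of the index mapping it produces for a given thread:

# a thread starts at its global index and advances by the whole grid width
def indices_for_thread(block_idx, thread_idx, block_dim, grid_dim, size):
    start = block_dim * block_idx + thread_idx
    stride = block_dim * grid_dim
    return list(range(start, size, stride))

# e.g. 2 blocks of 4 threads covering 20 elements:
print(indices_for_thread(0, 0, 4, 2, 20))  # [0, 8, 16]
print(indices_for_thread(1, 3, 4, 2, 20))  # [7, 15]
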
@@ -3317,7 +3397,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3352,7 +3434,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3362,7 +3446,20 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// dual vars\n"]
 
     for var in adj.variables:
-        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
+        name = var.emit_adj()
+        ctype = var.ctype(value_type=True)
+
+        if is_tile(var.type):
+            if var.type.storage == "register":
+                lines += [
+                    f"{var.type.ctype()} {name}(0.0);\n"
+                ]
+            elif var.type.storage == "shared":
+                lines += [
+                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+        else:
+            lines += [f"{ctype} {name} = {{}};\n"]
 
     # forward pass
     lines += ["//---------\n"]
@@ -3391,6 +3488,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if options is None:
         options = {}
 
+    if adj.return_var is not None and "return" in adj.arg_types:
+        if get_type_origin(adj.arg_types["return"]) is tuple:
+            if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+                    f"but the code returns {len(adj.return_var)} values."
+                )
+            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                )
+        elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns {len(adj.return_var)} values."
+            )
+        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+            )
+
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
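
The new checks make return-annotation mismatches fail loudly at build time. An illustrative example of user code that would now raise `WarpCodegenError` when its module is built (hypothetical function, matching the "annotated as X but the code returns N values" branch above):

import warp as wp

@wp.func
def two_values() -> float:
    # annotated as a single float but the body returns two values;
    # the validation above reports this during module build
    return 1.0, 2.0
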
warp/config.py
CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.4.1"
+version: str = "1.5.0"
 """Warp version string"""
 
 verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
 """
 
 verify_cuda: bool = False
-"""If `True`, Warp will check for CUDA errors after every launch
+"""If `True`, Warp will check for CUDA errors after every launch operation.
 CUDA error verification cannot be used during graph capture. Has performance implications.
 """