warp-lang 1.4.2-py3-none-manylinux2014_aarch64.whl → 1.5.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1783 -2
- warp/codegen.py +177 -45
- warp/config.py +2 -2
- warp/context.py +321 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +88 -15
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +13 -0
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +338 -72
- warp/utils.py +22 -1
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
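
Most of the new surface in 1.5.0 is the tile programming model: warp/native/tile.h, tile_gemm.h, tile_reduce.h, roughly 1,800 new lines in builtins.py, and the new warp/examples/tile/ samples. The sketch below shows the general shape of a tiled kernel, modeled loosely on the bundled example_tile_matmul.py; the builtin signatures used here (wp.tile_zeros, wp.tile_load, wp.tile_matmul, wp.tile_store, wp.launch_tiled) are assumptions and should be checked against the shipped examples, and a CUDA device is assumed.

import numpy as np
import warp as wp

TILE_M, TILE_N, TILE_K = 32, 32, 8
TILE_THREADS = 64

@wp.kernel
def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    # each thread block cooperatively computes one TILE_M x TILE_N output tile
    i, j = wp.tid()
    acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
    for k in range(int(A.shape[1] / TILE_K)):
        a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)   # stage a TILE_M x TILE_K block
        b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)   # stage a TILE_K x TILE_N block
        wp.tile_matmul(a, b, acc)                       # acc += a @ b, cooperatively across the block
    wp.tile_store(C, i, j, acc)

wp.init()
M = N = K = 256
A = wp.array(np.random.rand(M, K).astype(np.float32), dtype=float, device="cuda")
B = wp.array(np.random.rand(K, N).astype(np.float32), dtype=float, device="cuda")
C = wp.zeros((M, N), dtype=float, device="cuda")
wp.launch_tiled(tile_gemm, dim=(M // TILE_M, N // TILE_N), inputs=[A, B, C], block_dim=TILE_THREADS)

The codegen.py diff below shows the machinery behind this: a per-kernel WP_TILE_BLOCK_DIM define, shared-memory "roofline" accounting for nested function calls, and grid-stride thread-index loops in the CUDA kernel templates.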
warp/codegen.py
CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 import warp.config
 from warp.types import *
 
+# used as a globally accessible copy
+# of current compile options (block_dim) etc
+options = {}
+
 
 class WarpCodegenError(RuntimeError):
     def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
     return None
 
 
+def get_type_origin(tp):
+    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+    return getattr(tp, "__origin__", None)
+
+
+def get_type_args(tp):
+    # Compatible version of `typing.get_args()` for Python 3.7 and older.
+    return getattr(tp, "__args__", ())
+
+
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -637,6 +651,8 @@ class Var:
             dtypestr = f"wp::{t.dtype.__name__}"
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
+        elif is_tile(t):
+            return t.ctype()
         elif isinstance(t, Struct):
             return t.native_name
         elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
-            adj.arg_types = argspec.annotations
+            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
+        # Collect the LTOIR required at link-time
+        adj.ltoirs = []
+
+    # allocate extra space for a function call that requires its
+    # own shared memory space, we treat shared memory as a stack
+    # where each function pushes and pops space off, the extra
+    # quantity is the 'roofline' amount required for the entire kernel
+    def alloc_shared_extra(adj, num_bytes):
+        adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+    # returns the total number of bytes for a function
+    # based on it's own requirements + worst case
+    # requirements of any dependent functions
+    def get_total_required_shared(adj):
+        total_shared = 0
+
+        for var in adj.variables:
+            if is_tile(var.type) and var.type.storage == "shared":
+                total_shared += var.type.size_in_bytes()
+
+        return total_shared + adj.max_required_extra_shared_memory
+
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
         # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
@@ -934,6 +972,9 @@ class Adjoint:
         else:
             adj.builder_options = default_builder_options
 
+        global options
+        options = adj.builder_options
+
         adj.symbols = {}  # map from symbols to adjoint variables
         adj.variables = []  # list of local variables (in order)
 
@@ -953,6 +994,9 @@ class Adjoint:
         # used to generate new label indices
        adj.label_count = 0
 
+        # tracks how much additional shared memory is required by any dependent function calls
+        adj.max_required_extra_shared_memory = 0
+
         # update symbol map for each argument
         for a in adj.args:
             adj.symbols[a.label] = a
@@ -969,6 +1013,7 @@ class Adjoint:
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
                 adj.skip_build = True
+                adj.builder = None
                 raise e
 
         if builder is not None:
@@ -978,6 +1023,9 @@ class Adjoint:
                 elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                     builder.build_struct_recursive(a.type.dtype)
 
+        # release builder reference for GC
+        adj.builder = None
+
    # code generation methods
    def format_template(adj, template, input_vars, output_var):
        # output var is always the 0th index
@@ -994,9 +1042,9 @@ class Adjoint:
            if isinstance(a, warp.context.Function):
                # functions don't have a var_ prefix so strip it off here
                if prefix == "var":
-                    arg_strs.append(a.native_func)
+                    arg_strs.append(f"{a.namespace}{a.native_func}")
                else:
-                    arg_strs.append(f"{prefix}_{a.native_func}")
+                    arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
            elif is_reference(a.type):
                arg_strs.append(f"{prefix}_{a}")
            elif isinstance(a, Var):
@@ -1278,15 +1326,34 @@ class Adjoint:
            bound_arg_values,
        )
 
-
-
-        #
-
-
-
-        #
-
+        # immediately allocate output variables so we can pass them into the dispatch method
+        if return_type is None:
+            # void function
+            output = None
+            output_list = []
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+            # single return value function
+            if isinstance(return_type, Sequence):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
+        else:
+            # multiple return value function
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
 
+        # If we have a built-in that requires special handling to dispatch
+        # the arguments to the underlying C++ function, then we can resolve
+        # these using the `dispatch_func`. Since this is only called from
+        # within codegen, we pass it directly `codegen.Var` objects,
+        # which allows for some more advanced resolution to be performed,
+        # for example by checking whether an argument corresponds to
+        # a literal value or references a variable.
+        if func.lto_dispatch_func is not None:
+            func_args, template_args, ltoirs = func.lto_dispatch_func(
+                func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+            )
+        elif func.dispatch_func is not None:
            func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
        else:
            func_args = tuple(bound_args.values())
@@ -1301,18 +1368,14 @@ class Adjoint:
            if not isinstance(func_arg, (Reference, warp.context.Function)):
                func_arg = adj.load(func_arg)
 
-            # if the argument is a function, build it recursively
-            if isinstance(func_arg, warp.context.Function):
+            # if the argument is a function (and not a builtin), then build it recursively
+            if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
                adj.builder.build_function(func_arg)
 
            fwd_args.append(strip_reference(func_arg))
 
        if return_type is None:
            # handles expression (zero output) functions, e.g.: void do_something();
-
-            output = None
-            output_list = []
-
            forward_call = (
                f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
            )
@@ -1322,12 +1385,6 @@ class Adjoint:
 
        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
            # handle simple function (one output)
-
-            if isinstance(return_type, Sequence):
-                return_type = return_type[0]
-            output = adj.add_var(return_type)
-            output_list = [output]
-
            forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
            replay_call = forward_call
            if func.custom_replay_func is not None:
@@ -1335,10 +1392,6 @@ class Adjoint:
 
        else:
            # handle multiple value functions
-
-            output = [adj.add_var(v) for v in return_type]
-            output_list = output
-
            forward_call = (
                f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
            )
@@ -1366,6 +1419,11 @@ class Adjoint:
            reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
            adj.add_reverse(reverse_call)
 
+        # update our smem roofline requirements based on any
+        # shared memory required by the dependent function call
+        if not func.is_builtin():
+            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
        return output
 
    def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1466,7 +1524,10 @@ class Adjoint:
 
        # zero adjoints
        for i in body_block.vars:
-
+            if is_tile(i.type):
+                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+            else:
+                reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
        # replay
        for i in body_block.body_replay:
@@ -2206,7 +2267,7 @@ class Adjoint:
 
    # returns the object being indexed, and the list of indices
    def eval_subscript(adj, node):
-        # We want to coalesce multi-
+        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
        # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
        # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
        root = node
@@ -2286,6 +2347,14 @@ class Adjoint:
            out.is_read = target.is_read
            out.is_write = target.is_write
 
+        elif is_tile(target_type):
+            if len(indices) == 2:
+                # handles extracting a single element from a tile
+                out = adj.add_builtin_call("tile_extract", [target, *indices])
+            else:
+                # handles tile views
+                out = adj.add_builtin_call("tile_view", [target, *indices])
+
        else:
            # handles non-array type indexing, e.g: vec3, mat33, etc
            out = adj.add_builtin_call("extract", [target, *indices])
@@ -2527,11 +2596,22 @@ class Adjoint:
        target_type = strip_reference(target.type)
 
        if is_array(target_type):
-            #
-            if target_type.dtype
+            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+            if target_type.dtype in warp.types.non_atomic_types:
                make_new_assign_statement()
                return
 
+            # the same holds true for vecs/mats/quats that are composed of these types
+            if (
+                type_is_vector(target_type.dtype)
+                or type_is_quaternion(target_type.dtype)
+                or type_is_matrix(target_type.dtype)
+            ):
+                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                if dtype in warp.types.non_atomic_types:
+                    make_new_assign_statement()
+                    return
+
        kernel_name = adj.fun_name
        filename = adj.filename
        lineno = adj.lineno + adj.fun_lineno
@@ -2955,6 +3035,7 @@ class Adjoint:
 # code generation
 
 cpu_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2965,7 +3046,7 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid1d() wp::tid(task_index, dim)
 #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
 #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
 #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2973,6 +3054,7 @@ cpu_module_header = """
 """
 
 cuda_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2983,10 +3065,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(
-#define builtin_tid2d(x, y) wp::tid(x, y,
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z,
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w,
+#define builtin_tid1d() wp::tid(_idx, dim)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 """
 
@@ -3058,20 +3140,26 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t
-
-
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {forward_body} }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t
-
-
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
    {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {reverse_body} }}
 }}
 
@@ -3309,7 +3397,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
    lines += ["// primal vars\n"]
 
    for var in adj.variables:
-        if var.
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+        elif var.constant is None:
            lines += [f"{var.ctype()} {var.emit()};\n"]
        else:
            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3344,7 +3434,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    lines += ["// primal vars\n"]
 
    for var in adj.variables:
-        if var.
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+        elif var.constant is None:
            lines += [f"{var.ctype()} {var.emit()};\n"]
        else:
            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3354,7 +3446,20 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    lines += ["// dual vars\n"]
 
    for var in adj.variables:
-
+        name = var.emit_adj()
+        ctype = var.ctype(value_type=True)
+
+        if is_tile(var.type):
+            if var.type.storage == "register":
+                lines += [
+                    f"{var.type.ctype()} {name}(0.0);\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+            elif var.type.storage == "shared":
+                lines += [
+                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+        else:
+            lines += [f"{ctype} {name} = {{}};\n"]
 
    # forward pass
    lines += ["//---------\n"]
@@ -3383,6 +3488,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
    if options is None:
        options = {}
 
+    if adj.return_var is not None and "return" in adj.arg_types:
+        if get_type_origin(adj.arg_types["return"]) is tuple:
+            if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+                    f"but the code returns {len(adj.return_var)} values."
+                )
+            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                )
+        elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns {len(adj.return_var)} values."
+            )
+        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+            )
+
    # forward header
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
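
Two smaller additions above are easy to miss next to the tile work: the get_type_origin()/get_type_args() helpers that back-port typing.get_origin()/typing.get_args() to Python 3.7, and the new block at the top of codegen_func() that uses them to validate a function's annotated return type against the values the body actually returns, raising WarpCodegenError with a descriptive message when they disagree. A standalone illustration of what the helpers resolve (no Warp required; the example annotation is arbitrary):

from typing import Tuple

def get_type_origin(tp):
    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
    return getattr(tp, "__origin__", None)

def get_type_args(tp):
    # Compatible version of `typing.get_args()` for Python 3.7 and older.
    return getattr(tp, "__args__", ())

annotation = Tuple[float, float]
print(get_type_origin(annotation))     # <class 'tuple'>: the annotation declares a tuple return
print(len(get_type_args(annotation)))  # 2: must match the number of values the function returns
print(get_type_origin(float))          # None: a plain single-value annotation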
warp/config.py
CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.4.2"
+version: str = "1.5.0"
 """Warp version string"""
 
 verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
 """
 
 verify_cuda: bool = False
-"""If `True`, Warp will check for CUDA errors after every launch
+"""If `True`, Warp will check for CUDA errors after every launch operation.
 CUDA error verification cannot be used during graph capture. Has performance implications.
 """
 
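
Both touched settings are ordinary module-level attributes on warp.config, so the version bump is observable directly from Python; a minimal check against the new wheel (standard install assumed):

import warp as wp

wp.config.verify_cuda = True  # check for CUDA errors after every launch operation
                              # (cannot be used during graph capture; has a performance cost)
wp.init()

print(wp.config.version)      # expected to print "1.5.0" for this wheel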