warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -26,7 +26,7 @@ import re
|
|
|
26
26
|
import sys
|
|
27
27
|
import textwrap
|
|
28
28
|
import types
|
|
29
|
-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
|
|
29
|
+
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
|
|
30
30
|
|
|
31
31
|
import warp.config
|
|
32
32
|
from warp.types import *
|
|
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
|
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
# map operator to function name
|
|
60
|
-
builtin_operators = {}
|
|
60
|
+
builtin_operators: Dict[type[ast.AST], str] = {}
|
|
61
61
|
|
|
62
62
|
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
|
|
63
63
|
# nice overview of python operators
|
|
@@ -122,16 +122,6 @@ def get_closure_cell_contents(obj):
|
|
|
122
122
|
return None
|
|
123
123
|
|
|
124
124
|
|
|
125
|
-
def get_type_origin(tp):
|
|
126
|
-
# Compatible version of `typing.get_origin()` for Python 3.7 and older.
|
|
127
|
-
return getattr(tp, "__origin__", None)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def get_type_args(tp):
|
|
131
|
-
# Compatible version of `typing.get_args()` for Python 3.7 and older.
|
|
132
|
-
return getattr(tp, "__args__", ())
|
|
133
|
-
|
|
134
|
-
|
|
135
125
|
def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
|
|
136
126
|
"""Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
|
|
137
127
|
# Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
|
|
@@ -415,12 +405,14 @@ class StructInstance:
|
|
|
415
405
|
|
|
416
406
|
|
|
417
407
|
class Struct:
|
|
418
|
-
|
|
408
|
+
hash: bytes
|
|
409
|
+
|
|
410
|
+
def __init__(self, cls: type, key: str, module: warp.context.Module):
|
|
419
411
|
self.cls = cls
|
|
420
412
|
self.module = module
|
|
421
413
|
self.key = key
|
|
414
|
+
self.vars: Dict[str, Var] = {}
|
|
422
415
|
|
|
423
|
-
self.vars = {}
|
|
424
416
|
annotations = get_annotations(self.cls)
|
|
425
417
|
for label, type in annotations.items():
|
|
426
418
|
self.vars[label] = Var(label, type)
|
|
@@ -591,11 +583,11 @@ class Reference:
|
|
|
591
583
|
self.value_type = value_type
|
|
592
584
|
|
|
593
585
|
|
|
594
|
-
def is_reference(type):
|
|
586
|
+
def is_reference(type: Any) -> builtins.bool:
|
|
595
587
|
return isinstance(type, Reference)
|
|
596
588
|
|
|
597
589
|
|
|
598
|
-
def strip_reference(arg):
|
|
590
|
+
def strip_reference(arg: Any) -> Any:
|
|
599
591
|
if is_reference(arg):
|
|
600
592
|
return arg.value_type
|
|
601
593
|
else:
|
|
@@ -623,7 +615,15 @@ def compute_type_str(base_name, template_params):
|
|
|
623
615
|
|
|
624
616
|
|
|
625
617
|
class Var:
|
|
626
|
-
def __init__(
|
|
618
|
+
def __init__(
|
|
619
|
+
self,
|
|
620
|
+
label: str,
|
|
621
|
+
type: type,
|
|
622
|
+
requires_grad: builtins.bool = False,
|
|
623
|
+
constant: Optional[builtins.bool] = None,
|
|
624
|
+
prefix: builtins.bool = True,
|
|
625
|
+
relative_lineno: Optional[int] = None,
|
|
626
|
+
):
|
|
627
627
|
# convert built-in types to wp types
|
|
628
628
|
if type == float:
|
|
629
629
|
type = float32
|
|
@@ -646,11 +646,14 @@ class Var:
|
|
|
646
646
|
# used to associate a view array Var with its parent array Var
|
|
647
647
|
self.parent = None
|
|
648
648
|
|
|
649
|
+
# Used to associate the variable with the Python statement that resulted in it being created.
|
|
650
|
+
self.relative_lineno = relative_lineno
|
|
651
|
+
|
|
649
652
|
def __str__(self):
|
|
650
653
|
return self.label
|
|
651
654
|
|
|
652
655
|
@staticmethod
|
|
653
|
-
def type_to_ctype(t, value_type=False):
|
|
656
|
+
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
654
657
|
if is_array(t):
|
|
655
658
|
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
656
659
|
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
@@ -681,7 +684,7 @@ class Var:
|
|
|
681
684
|
else:
|
|
682
685
|
return f"wp::{t.__name__}"
|
|
683
686
|
|
|
684
|
-
def ctype(self, value_type=False):
|
|
687
|
+
def ctype(self, value_type: builtins.bool = False) -> str:
|
|
685
688
|
return Var.type_to_ctype(self.type, value_type)
|
|
686
689
|
|
|
687
690
|
def emit(self, prefix: str = "var"):
|
|
@@ -803,7 +806,7 @@ def func_match_args(func, arg_types, kwarg_types):
|
|
|
803
806
|
return True
|
|
804
807
|
|
|
805
808
|
|
|
806
|
-
def get_arg_type(arg: Union[Var, Any]):
|
|
809
|
+
def get_arg_type(arg: Union[Var, Any]) -> type:
|
|
807
810
|
if isinstance(arg, str):
|
|
808
811
|
return str
|
|
809
812
|
|
|
@@ -819,7 +822,7 @@ def get_arg_type(arg: Union[Var, Any]):
|
|
|
819
822
|
return type(arg)
|
|
820
823
|
|
|
821
824
|
|
|
822
|
-
def get_arg_value(arg:
|
|
825
|
+
def get_arg_value(arg: Any) -> Any:
|
|
823
826
|
if isinstance(arg, Sequence):
|
|
824
827
|
return tuple(get_arg_value(x) for x in arg)
|
|
825
828
|
|
|
@@ -867,6 +870,9 @@ class Adjoint:
|
|
|
867
870
|
"please save it on a file and use `importlib` if needed."
|
|
868
871
|
) from e
|
|
869
872
|
|
|
873
|
+
# Indicates where the function definition starts (excludes decorators)
|
|
874
|
+
adj.fun_def_lineno = None
|
|
875
|
+
|
|
870
876
|
# get function source code
|
|
871
877
|
adj.source = inspect.getsource(func)
|
|
872
878
|
# ensures that indented class methods can be parsed as kernels
|
|
@@ -941,9 +947,6 @@ class Adjoint:
|
|
|
941
947
|
# for unit testing errors being spit out from kernels.
|
|
942
948
|
adj.skip_build = False
|
|
943
949
|
|
|
944
|
-
# Collect the LTOIR required at link-time
|
|
945
|
-
adj.ltoirs = []
|
|
946
|
-
|
|
947
950
|
# allocate extra space for a function call that requires its
|
|
948
951
|
# own shared memory space, we treat shared memory as a stack
|
|
949
952
|
# where each function pushes and pops space off, the extra
|
|
@@ -1133,7 +1136,7 @@ class Adjoint:
|
|
|
1133
1136
|
name = str(index)
|
|
1134
1137
|
|
|
1135
1138
|
# allocate new variable
|
|
1136
|
-
v = Var(name, type=type, constant=constant)
|
|
1139
|
+
v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
|
|
1137
1140
|
|
|
1138
1141
|
adj.variables.append(v)
|
|
1139
1142
|
|
|
@@ -1158,11 +1161,44 @@ class Adjoint:
|
|
|
1158
1161
|
|
|
1159
1162
|
return var
|
|
1160
1163
|
|
|
1161
|
-
|
|
1162
|
-
|
|
1164
|
+
def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
|
|
1165
|
+
"""Get a line directive for the given statement.
|
|
1166
|
+
|
|
1167
|
+
Args:
|
|
1168
|
+
statement: The statement to get the line directive for.
|
|
1169
|
+
relative_lineno: The line number of the statement relative to the function.
|
|
1170
|
+
|
|
1171
|
+
Returns:
|
|
1172
|
+
A line directive for the given statement, or None if no line directive is needed.
|
|
1173
|
+
"""
|
|
1174
|
+
|
|
1175
|
+
# lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
|
|
1176
|
+
# emit line directives in generated code if it's not being compiled with line information
|
|
1177
|
+
lineinfo_enabled = (
|
|
1178
|
+
adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
|
|
1179
|
+
)
|
|
1180
|
+
|
|
1181
|
+
if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
|
|
1182
|
+
is_comment = statement.strip().startswith("//")
|
|
1183
|
+
if not is_comment:
|
|
1184
|
+
line = relative_lineno + adj.fun_lineno
|
|
1185
|
+
# Convert backslashes to forward slashes for CUDA compatibility
|
|
1186
|
+
normalized_path = adj.filename.replace("\\", "/")
|
|
1187
|
+
return f'#line {line} "{normalized_path}"'
|
|
1188
|
+
return None
|
|
1189
|
+
|
|
1190
|
+
def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
|
|
1191
|
+
"""Append a statement to the forward pass."""
|
|
1192
|
+
|
|
1193
|
+
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
1194
|
+
adj.blocks[-1].body_forward.append(line_directive)
|
|
1195
|
+
|
|
1163
1196
|
adj.blocks[-1].body_forward.append(adj.indentation + statement)
|
|
1164
1197
|
|
|
1165
1198
|
if not skip_replay:
|
|
1199
|
+
if line_directive:
|
|
1200
|
+
adj.blocks[-1].body_replay.append(line_directive)
|
|
1201
|
+
|
|
1166
1202
|
if replay:
|
|
1167
1203
|
# if custom replay specified then output it
|
|
1168
1204
|
adj.blocks[-1].body_replay.append(adj.indentation + replay)
|
|
@@ -1171,9 +1207,14 @@ class Adjoint:
|
|
|
1171
1207
|
adj.blocks[-1].body_replay.append(adj.indentation + statement)
|
|
1172
1208
|
|
|
1173
1209
|
# append a statement to the reverse pass
|
|
1174
|
-
def add_reverse(adj, statement):
|
|
1210
|
+
def add_reverse(adj, statement: str) -> None:
|
|
1211
|
+
"""Append a statement to the reverse pass."""
|
|
1212
|
+
|
|
1175
1213
|
adj.blocks[-1].body_reverse.append(adj.indentation + statement)
|
|
1176
1214
|
|
|
1215
|
+
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
1216
|
+
adj.blocks[-1].body_reverse.append(line_directive)
|
|
1217
|
+
|
|
1177
1218
|
def add_constant(adj, n):
|
|
1178
1219
|
output = adj.add_var(type=type(n), constant=n)
|
|
1179
1220
|
return output
|
|
@@ -1281,7 +1322,7 @@ class Adjoint:
|
|
|
1281
1322
|
|
|
1282
1323
|
# Bind the positional and keyword arguments to the function's signature
|
|
1283
1324
|
# in order to process them as Python does it.
|
|
1284
|
-
bound_args = func.signature.bind(*args, **kwargs)
|
|
1325
|
+
bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
|
|
1285
1326
|
|
|
1286
1327
|
# Type args are the “compile time” argument values we get from codegen.
|
|
1287
1328
|
# For example, when calling `wp.vec3f(...)` from within a kernel,
|
|
@@ -1624,6 +1665,8 @@ class Adjoint:
|
|
|
1624
1665
|
adj.blocks[-1].body_reverse.extend(reversed(reverse))
|
|
1625
1666
|
|
|
1626
1667
|
def emit_FunctionDef(adj, node):
|
|
1668
|
+
adj.fun_def_lineno = node.lineno
|
|
1669
|
+
|
|
1627
1670
|
for f in node.body:
|
|
1628
1671
|
# Skip variable creation for standalone constants, including docstrings
|
|
1629
1672
|
if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
|
|
@@ -1688,7 +1731,7 @@ class Adjoint:
|
|
|
1688
1731
|
|
|
1689
1732
|
if var1 != var2:
|
|
1690
1733
|
# insert a phi function that selects var1, var2 based on cond
|
|
1691
|
-
out = adj.add_builtin_call("
|
|
1734
|
+
out = adj.add_builtin_call("where", [cond, var2, var1])
|
|
1692
1735
|
adj.symbols[sym] = out
|
|
1693
1736
|
|
|
1694
1737
|
symbols_prev = adj.symbols.copy()
|
|
@@ -1712,7 +1755,7 @@ class Adjoint:
|
|
|
1712
1755
|
if var1 != var2:
|
|
1713
1756
|
# insert a phi function that selects var1, var2 based on cond
|
|
1714
1757
|
# note the reversed order of vars since we want to use !cond as our select
|
|
1715
|
-
out = adj.add_builtin_call("
|
|
1758
|
+
out = adj.add_builtin_call("where", [cond, var1, var2])
|
|
1716
1759
|
adj.symbols[sym] = out
|
|
1717
1760
|
|
|
1718
1761
|
def emit_Compare(adj, node):
|
|
@@ -1856,25 +1899,6 @@ class Adjoint:
|
|
|
1856
1899
|
) from e
|
|
1857
1900
|
raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
|
|
1858
1901
|
|
|
1859
|
-
def emit_String(adj, node):
|
|
1860
|
-
# string constant
|
|
1861
|
-
return adj.add_constant(node.s)
|
|
1862
|
-
|
|
1863
|
-
def emit_Num(adj, node):
|
|
1864
|
-
# lookup constant, if it has already been assigned then return existing var
|
|
1865
|
-
key = (node.n, type(node.n))
|
|
1866
|
-
|
|
1867
|
-
if key in adj.symbols:
|
|
1868
|
-
return adj.symbols[key]
|
|
1869
|
-
else:
|
|
1870
|
-
out = adj.add_constant(node.n)
|
|
1871
|
-
adj.symbols[key] = out
|
|
1872
|
-
return out
|
|
1873
|
-
|
|
1874
|
-
def emit_Ellipsis(adj, node):
|
|
1875
|
-
# stubbed @wp.native_func
|
|
1876
|
-
return
|
|
1877
|
-
|
|
1878
1902
|
def emit_Assert(adj, node):
|
|
1879
1903
|
# eval condition
|
|
1880
1904
|
cond = adj.eval(node.test)
|
|
@@ -1886,24 +1910,11 @@ class Adjoint:
|
|
|
1886
1910
|
|
|
1887
1911
|
adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
|
|
1888
1912
|
|
|
1889
|
-
def emit_NameConstant(adj, node):
|
|
1890
|
-
if node.value:
|
|
1891
|
-
return adj.add_constant(node.value)
|
|
1892
|
-
elif node.value is None:
|
|
1893
|
-
raise WarpCodegenTypeError("None type unsupported")
|
|
1894
|
-
else:
|
|
1895
|
-
return adj.add_constant(False)
|
|
1896
|
-
|
|
1897
1913
|
def emit_Constant(adj, node):
|
|
1898
|
-
if
|
|
1899
|
-
|
|
1900
|
-
elif isinstance(node, ast.Num):
|
|
1901
|
-
return adj.emit_Num(node)
|
|
1902
|
-
elif isinstance(node, ast.Ellipsis):
|
|
1903
|
-
return adj.emit_Ellipsis(node)
|
|
1914
|
+
if node.value is None:
|
|
1915
|
+
raise WarpCodegenTypeError("None type unsupported")
|
|
1904
1916
|
else:
|
|
1905
|
-
|
|
1906
|
-
return adj.emit_NameConstant(node)
|
|
1917
|
+
return adj.add_constant(node.value)
|
|
1907
1918
|
|
|
1908
1919
|
def emit_BinOp(adj, node):
|
|
1909
1920
|
# evaluate binary operator arguments
|
|
@@ -1997,10 +2008,11 @@ class Adjoint:
|
|
|
1997
2008
|
adj.end_while()
|
|
1998
2009
|
|
|
1999
2010
|
def eval_num(adj, a):
|
|
2000
|
-
if isinstance(a, ast.
|
|
2001
|
-
return True, a.
|
|
2002
|
-
if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.
|
|
2003
|
-
|
|
2011
|
+
if isinstance(a, ast.Constant):
|
|
2012
|
+
return True, a.value
|
|
2013
|
+
if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
|
|
2014
|
+
# Negative constant
|
|
2015
|
+
return True, -a.operand.value
|
|
2004
2016
|
|
|
2005
2017
|
# try and resolve the expression to an object
|
|
2006
2018
|
# e.g.: wp.constant in the globals scope
|
|
@@ -2530,8 +2542,8 @@ class Adjoint:
|
|
|
2530
2542
|
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2531
2543
|
)
|
|
2532
2544
|
else:
|
|
2533
|
-
if
|
|
2534
|
-
out = adj.add_builtin_call("
|
|
2545
|
+
if warp.config.enable_vector_component_overwrites:
|
|
2546
|
+
out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
|
|
2535
2547
|
|
|
2536
2548
|
# re-point target symbol to out var
|
|
2537
2549
|
for id in adj.symbols:
|
|
@@ -2539,8 +2551,7 @@ class Adjoint:
|
|
|
2539
2551
|
adj.symbols[id] = out
|
|
2540
2552
|
break
|
|
2541
2553
|
else:
|
|
2542
|
-
|
|
2543
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2554
|
+
adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
|
|
2544
2555
|
|
|
2545
2556
|
else:
|
|
2546
2557
|
raise WarpCodegenError(
|
|
@@ -2583,8 +2594,8 @@ class Adjoint:
|
|
|
2583
2594
|
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2584
2595
|
adj.add_builtin_call("store", [attr, rhs])
|
|
2585
2596
|
else:
|
|
2586
|
-
if
|
|
2587
|
-
out = adj.add_builtin_call("
|
|
2597
|
+
if warp.config.enable_vector_component_overwrites:
|
|
2598
|
+
out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
|
|
2588
2599
|
|
|
2589
2600
|
# re-point target symbol to out var
|
|
2590
2601
|
for id in adj.symbols:
|
|
@@ -2592,8 +2603,7 @@ class Adjoint:
|
|
|
2592
2603
|
adj.symbols[id] = out
|
|
2593
2604
|
break
|
|
2594
2605
|
else:
|
|
2595
|
-
|
|
2596
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2606
|
+
adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
|
|
2597
2607
|
|
|
2598
2608
|
else:
|
|
2599
2609
|
attr = adj.emit_Attribute(lhs)
|
|
@@ -2699,10 +2709,12 @@ class Adjoint:
|
|
|
2699
2709
|
|
|
2700
2710
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2701
2711
|
if isinstance(node.op, ast.Add):
|
|
2702
|
-
adj.add_builtin_call("
|
|
2712
|
+
adj.add_builtin_call("add_inplace", [target, *indices, rhs])
|
|
2703
2713
|
elif isinstance(node.op, ast.Sub):
|
|
2704
|
-
adj.add_builtin_call("
|
|
2714
|
+
adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
|
|
2705
2715
|
else:
|
|
2716
|
+
if warp.config.verbose:
|
|
2717
|
+
print(f"Warning: in-place op {node.op} is not differentiable")
|
|
2706
2718
|
make_new_assign_statement()
|
|
2707
2719
|
return
|
|
2708
2720
|
|
|
@@ -2732,9 +2744,6 @@ class Adjoint:
|
|
|
2732
2744
|
ast.BoolOp: emit_BoolOp,
|
|
2733
2745
|
ast.Name: emit_Name,
|
|
2734
2746
|
ast.Attribute: emit_Attribute,
|
|
2735
|
-
ast.Str: emit_String, # Deprecated in 3.8; use Constant
|
|
2736
|
-
ast.Num: emit_Num, # Deprecated in 3.8; use Constant
|
|
2737
|
-
ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
|
|
2738
2747
|
ast.Constant: emit_Constant,
|
|
2739
2748
|
ast.BinOp: emit_BinOp,
|
|
2740
2749
|
ast.UnaryOp: emit_UnaryOp,
|
|
@@ -2744,14 +2753,13 @@ class Adjoint:
|
|
|
2744
2753
|
ast.Continue: emit_Continue,
|
|
2745
2754
|
ast.Expr: emit_Expr,
|
|
2746
2755
|
ast.Call: emit_Call,
|
|
2747
|
-
ast.Index: emit_Index, # Deprecated in 3.
|
|
2756
|
+
ast.Index: emit_Index, # Deprecated in 3.9
|
|
2748
2757
|
ast.Subscript: emit_Subscript,
|
|
2749
2758
|
ast.Assign: emit_Assign,
|
|
2750
2759
|
ast.Return: emit_Return,
|
|
2751
2760
|
ast.AugAssign: emit_AugAssign,
|
|
2752
2761
|
ast.Tuple: emit_Tuple,
|
|
2753
2762
|
ast.Pass: emit_Pass,
|
|
2754
|
-
ast.Ellipsis: emit_Ellipsis,
|
|
2755
2763
|
ast.Assert: emit_Assert,
|
|
2756
2764
|
}
|
|
2757
2765
|
|
|
@@ -2947,12 +2955,16 @@ class Adjoint:
|
|
|
2947
2955
|
|
|
2948
2956
|
# We want to replace the expression code in-place,
|
|
2949
2957
|
# so reparse it to get the correct column info.
|
|
2950
|
-
len_value_locs = []
|
|
2958
|
+
len_value_locs: List[Tuple[int, int, int]] = []
|
|
2951
2959
|
expr_tree = ast.parse(static_code)
|
|
2952
2960
|
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
|
|
2953
2961
|
expr_root = expr_tree.body[0].value
|
|
2954
2962
|
for expr_node in ast.walk(expr_root):
|
|
2955
|
-
if
|
|
2963
|
+
if (
|
|
2964
|
+
isinstance(expr_node, ast.Call)
|
|
2965
|
+
and getattr(expr_node.func, "id", None) == "len"
|
|
2966
|
+
and len(expr_node.args) == 1
|
|
2967
|
+
):
|
|
2956
2968
|
len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
|
|
2957
2969
|
try:
|
|
2958
2970
|
len_value = eval(len_expr, len_expr_ctx)
|
|
@@ -3110,9 +3122,9 @@ class Adjoint:
|
|
|
3110
3122
|
|
|
3111
3123
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
3112
3124
|
|
|
3113
|
-
constants = {}
|
|
3114
|
-
types = {}
|
|
3115
|
-
functions = {}
|
|
3125
|
+
constants: Dict[str, Any] = {}
|
|
3126
|
+
types: Dict[Union[Struct, type], Any] = {}
|
|
3127
|
+
functions: Dict[warp.context.Function, Any] = {}
|
|
3116
3128
|
|
|
3117
3129
|
for node in ast.walk(adj.tree):
|
|
3118
3130
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
@@ -3155,7 +3167,7 @@ class Adjoint:
|
|
|
3155
3167
|
# code generation
|
|
3156
3168
|
|
|
3157
3169
|
cpu_module_header = """
|
|
3158
|
-
#define WP_TILE_BLOCK_DIM {
|
|
3170
|
+
#define WP_TILE_BLOCK_DIM {block_dim}
|
|
3159
3171
|
#define WP_NO_CRT
|
|
3160
3172
|
#include "builtin.h"
|
|
3161
3173
|
|
|
@@ -3174,7 +3186,7 @@ cpu_module_header = """
|
|
|
3174
3186
|
"""
|
|
3175
3187
|
|
|
3176
3188
|
cuda_module_header = """
|
|
3177
|
-
#define WP_TILE_BLOCK_DIM {
|
|
3189
|
+
#define WP_TILE_BLOCK_DIM {block_dim}
|
|
3178
3190
|
#define WP_NO_CRT
|
|
3179
3191
|
#include "builtin.h"
|
|
3180
3192
|
|
|
@@ -3197,6 +3209,7 @@ struct {name}
|
|
|
3197
3209
|
{{
|
|
3198
3210
|
{struct_body}
|
|
3199
3211
|
|
|
3212
|
+
{defaulted_constructor_def}
|
|
3200
3213
|
CUDA_CALLABLE {name}({forward_args})
|
|
3201
3214
|
{forward_initializers}
|
|
3202
3215
|
{{
|
|
@@ -3239,53 +3252,53 @@ static void adj_{name}(
|
|
|
3239
3252
|
|
|
3240
3253
|
cuda_forward_function_template = """
|
|
3241
3254
|
// {filename}:{lineno}
|
|
3242
|
-
static CUDA_CALLABLE {return_type} {name}(
|
|
3255
|
+
{line_directive}static CUDA_CALLABLE {return_type} {name}(
|
|
3243
3256
|
{forward_args})
|
|
3244
3257
|
{{
|
|
3245
|
-
{forward_body}}}
|
|
3258
|
+
{forward_body}{line_directive}}}
|
|
3246
3259
|
|
|
3247
3260
|
"""
|
|
3248
3261
|
|
|
3249
3262
|
cuda_reverse_function_template = """
|
|
3250
3263
|
// {filename}:{lineno}
|
|
3251
|
-
static CUDA_CALLABLE void adj_{name}(
|
|
3264
|
+
{line_directive}static CUDA_CALLABLE void adj_{name}(
|
|
3252
3265
|
{reverse_args})
|
|
3253
3266
|
{{
|
|
3254
|
-
{reverse_body}}}
|
|
3267
|
+
{reverse_body}{line_directive}}}
|
|
3255
3268
|
|
|
3256
3269
|
"""
|
|
3257
3270
|
|
|
3258
3271
|
cuda_kernel_template_forward = """
|
|
3259
3272
|
|
|
3260
|
-
extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
3273
|
+
{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
3261
3274
|
{forward_args})
|
|
3262
3275
|
{{
|
|
3263
|
-
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3264
|
-
_idx < dim.size;
|
|
3265
|
-
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3276
|
+
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3277
|
+
{line_directive} _idx < dim.size;
|
|
3278
|
+
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3266
3279
|
{{
|
|
3267
3280
|
// reset shared memory allocator
|
|
3268
|
-
wp::tile_alloc_shared(0, true);
|
|
3281
|
+
{line_directive} wp::tile_alloc_shared(0, true);
|
|
3269
3282
|
|
|
3270
|
-
{forward_body} }}
|
|
3271
|
-
}}
|
|
3283
|
+
{forward_body}{line_directive} }}
|
|
3284
|
+
{line_directive}}}
|
|
3272
3285
|
|
|
3273
3286
|
"""
|
|
3274
3287
|
|
|
3275
3288
|
cuda_kernel_template_backward = """
|
|
3276
3289
|
|
|
3277
|
-
extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
3290
|
+
{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
3278
3291
|
{reverse_args})
|
|
3279
3292
|
{{
|
|
3280
|
-
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3281
|
-
_idx < dim.size;
|
|
3282
|
-
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3293
|
+
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3294
|
+
{line_directive} _idx < dim.size;
|
|
3295
|
+
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3283
3296
|
{{
|
|
3284
3297
|
// reset shared memory allocator
|
|
3285
|
-
wp::tile_alloc_shared(0, true);
|
|
3298
|
+
{line_directive} wp::tile_alloc_shared(0, true);
|
|
3286
3299
|
|
|
3287
|
-
{reverse_body} }}
|
|
3288
|
-
}}
|
|
3300
|
+
{reverse_body}{line_directive} }}
|
|
3301
|
+
{line_directive}}}
|
|
3289
3302
|
|
|
3290
3303
|
"""
|
|
3291
3304
|
|
|
@@ -3315,10 +3328,17 @@ extern "C" {{
|
|
|
3315
3328
|
WP_API void {name}_cpu_forward(
|
|
3316
3329
|
{forward_args})
|
|
3317
3330
|
{{
|
|
3318
|
-
|
|
3331
|
+
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3319
3332
|
{{
|
|
3333
|
+
// init shared memory allocator
|
|
3334
|
+
wp::tile_alloc_shared(0, true);
|
|
3335
|
+
|
|
3320
3336
|
{name}_cpu_kernel_forward(
|
|
3321
3337
|
{forward_params});
|
|
3338
|
+
|
|
3339
|
+
// check shared memory allocator
|
|
3340
|
+
wp::tile_alloc_shared(0, false, true);
|
|
3341
|
+
|
|
3322
3342
|
}}
|
|
3323
3343
|
}}
|
|
3324
3344
|
|
|
@@ -3335,8 +3355,14 @@ WP_API void {name}_cpu_backward(
|
|
|
3335
3355
|
{{
|
|
3336
3356
|
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3337
3357
|
{{
|
|
3358
|
+
// initialize shared memory allocator
|
|
3359
|
+
wp::tile_alloc_shared(0, true);
|
|
3360
|
+
|
|
3338
3361
|
{name}_cpu_kernel_backward(
|
|
3339
3362
|
{reverse_params});
|
|
3363
|
+
|
|
3364
|
+
// check shared memory allocator
|
|
3365
|
+
wp::tile_alloc_shared(0, false, true);
|
|
3340
3366
|
}}
|
|
3341
3367
|
}}
|
|
3342
3368
|
|
|
@@ -3418,7 +3444,7 @@ def indent(args, stops=1):
|
|
|
3418
3444
|
|
|
3419
3445
|
|
|
3420
3446
|
# generates a C function name based on the python function name
def make_full_qualified_name(func: Union[str, Callable]) -> str:
    """Return an identifier-safe name derived from ``func``.

    Args:
        func: Either a name string, or an object whose ``__qualname__`` is
            used as the name.

    Returns:
        The name with every ``.`` replaced by ``__`` and every remaining
        character outside ``[0-9a-zA-Z_]`` removed.

    Raises:
        AttributeError: If ``func`` is not a string and has no
            ``__qualname__`` attribute.
    """
    if not isinstance(func, str):
        func = func.__qualname__
    # Dots are rewritten to double underscores first so that scope separators
    # survive the character filter; anything else that is not alphanumeric or
    # an underscore (e.g. the angle brackets of "<locals>") is dropped.
    return re.sub(r"[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
|
|
@@ -3448,7 +3474,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3448
3474
|
# forward args
|
|
3449
3475
|
for label, var in struct.vars.items():
|
|
3450
3476
|
var_ctype = var.ctype()
|
|
3451
|
-
|
|
3477
|
+
default_arg_def = " = {}" if forward_args else ""
|
|
3478
|
+
forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
|
|
3452
3479
|
reverse_args.append(f"{var_ctype} const&")
|
|
3453
3480
|
|
|
3454
3481
|
namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
|
|
@@ -3472,6 +3499,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3472
3499
|
|
|
3473
3500
|
reverse_args.append(name + " & adj_ret")
|
|
3474
3501
|
|
|
3502
|
+
# explicitly defaulted default constructor if no default constructor has been defined
|
|
3503
|
+
defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
|
|
3504
|
+
|
|
3475
3505
|
return struct_template.format(
|
|
3476
3506
|
name=name,
|
|
3477
3507
|
struct_body="".join([indent_block + l for l in body]),
|
|
@@ -3481,6 +3511,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3481
3511
|
reverse_body="".join(reverse_body),
|
|
3482
3512
|
prefix_add_body="".join(prefix_add_body),
|
|
3483
3513
|
atomic_add_body="".join(atomic_add_body),
|
|
3514
|
+
defaulted_constructor_def=defaulted_constructor_def,
|
|
3484
3515
|
)
|
|
3485
3516
|
|
|
3486
3517
|
|
|
@@ -3510,6 +3541,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
|
|
|
3510
3541
|
else:
|
|
3511
3542
|
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
|
|
3512
3543
|
|
|
3544
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3545
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3546
|
+
|
|
3513
3547
|
# forward pass
|
|
3514
3548
|
lines += ["//---------\n"]
|
|
3515
3549
|
lines += ["// forward\n"]
|
|
@@ -3517,7 +3551,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
|
|
|
3517
3551
|
for f in adj.blocks[0].body_forward:
|
|
3518
3552
|
lines += [f + "\n"]
|
|
3519
3553
|
|
|
3520
|
-
return "".join(
|
|
3554
|
+
return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
|
|
3521
3555
|
|
|
3522
3556
|
|
|
3523
3557
|
def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
@@ -3547,6 +3581,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3547
3581
|
else:
|
|
3548
3582
|
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
|
|
3549
3583
|
|
|
3584
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3585
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3586
|
+
|
|
3550
3587
|
# dual vars
|
|
3551
3588
|
lines += ["//---------\n"]
|
|
3552
3589
|
lines += ["// dual vars\n"]
|
|
@@ -3567,6 +3604,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3567
3604
|
else:
|
|
3568
3605
|
lines += [f"{ctype} {name} = {{}};\n"]
|
|
3569
3606
|
|
|
3607
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3608
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3609
|
+
|
|
3570
3610
|
# forward pass
|
|
3571
3611
|
lines += ["//---------\n"]
|
|
3572
3612
|
lines += ["// forward\n"]
|
|
@@ -3587,7 +3627,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3587
3627
|
else:
|
|
3588
3628
|
lines += ["return;\n"]
|
|
3589
3629
|
|
|
3590
|
-
return "".join(
|
|
3630
|
+
return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
|
|
3591
3631
|
|
|
3592
3632
|
|
|
3593
3633
|
def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
@@ -3595,11 +3635,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3595
3635
|
options = {}
|
|
3596
3636
|
|
|
3597
3637
|
if adj.return_var is not None and "return" in adj.arg_types:
|
|
3598
|
-
if
|
|
3599
|
-
if len(
|
|
3638
|
+
if get_origin(adj.arg_types["return"]) is tuple:
|
|
3639
|
+
if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
|
|
3600
3640
|
raise WarpCodegenError(
|
|
3601
3641
|
f"The function `{adj.fun_name}` has its return type "
|
|
3602
|
-
f"annotated as a tuple of {len(
|
|
3642
|
+
f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
|
|
3603
3643
|
f"but the code returns {len(adj.return_var)} values."
|
|
3604
3644
|
)
|
|
3605
3645
|
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
|
|
@@ -3608,7 +3648,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3608
3648
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|
|
3609
3649
|
f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
|
|
3610
3650
|
)
|
|
3611
|
-
elif len(adj.return_var) > 1 and
|
|
3651
|
+
elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
|
|
3612
3652
|
raise WarpCodegenError(
|
|
3613
3653
|
f"The function `{adj.fun_name}` has its return type "
|
|
3614
3654
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|
|
@@ -3621,6 +3661,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3621
3661
|
f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
|
|
3622
3662
|
)
|
|
3623
3663
|
|
|
3664
|
+
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
3665
|
+
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
3666
|
+
# a direct mapping to a Python source line.
|
|
3667
|
+
func_line_directive = ""
|
|
3668
|
+
if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
|
|
3669
|
+
func_line_directive = f"{line_directive}\n"
|
|
3670
|
+
|
|
3624
3671
|
# forward header
|
|
3625
3672
|
if adj.return_var is not None and len(adj.return_var) == 1:
|
|
3626
3673
|
return_type = adj.return_var[0].ctype()
|
|
@@ -3684,6 +3731,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3684
3731
|
forward_body=forward_body,
|
|
3685
3732
|
filename=adj.filename,
|
|
3686
3733
|
lineno=adj.fun_lineno,
|
|
3734
|
+
line_directive=func_line_directive,
|
|
3687
3735
|
)
|
|
3688
3736
|
|
|
3689
3737
|
if not adj.skip_reverse_codegen:
|
|
@@ -3702,6 +3750,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3702
3750
|
reverse_body=reverse_body,
|
|
3703
3751
|
filename=adj.filename,
|
|
3704
3752
|
lineno=adj.fun_lineno,
|
|
3753
|
+
line_directive=func_line_directive,
|
|
3705
3754
|
)
|
|
3706
3755
|
|
|
3707
3756
|
return s
|
|
@@ -3744,6 +3793,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3744
3793
|
forward_body=snippet,
|
|
3745
3794
|
filename=adj.filename,
|
|
3746
3795
|
lineno=adj.fun_lineno,
|
|
3796
|
+
line_directive="",
|
|
3747
3797
|
)
|
|
3748
3798
|
|
|
3749
3799
|
if replay_snippet is not None:
|
|
@@ -3754,6 +3804,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3754
3804
|
forward_body=replay_snippet,
|
|
3755
3805
|
filename=adj.filename,
|
|
3756
3806
|
lineno=adj.fun_lineno,
|
|
3807
|
+
line_directive="",
|
|
3757
3808
|
)
|
|
3758
3809
|
|
|
3759
3810
|
if adj_snippet:
|
|
@@ -3769,6 +3820,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3769
3820
|
reverse_body=reverse_body,
|
|
3770
3821
|
filename=adj.filename,
|
|
3771
3822
|
lineno=adj.fun_lineno,
|
|
3823
|
+
line_directive="",
|
|
3772
3824
|
)
|
|
3773
3825
|
|
|
3774
3826
|
return s
|
|
@@ -3781,6 +3833,13 @@ def codegen_kernel(kernel, device, options):
|
|
|
3781
3833
|
|
|
3782
3834
|
adj = kernel.adj
|
|
3783
3835
|
|
|
3836
|
+
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
3837
|
+
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
3838
|
+
# a direct mapping to a Python source line.
|
|
3839
|
+
func_line_directive = ""
|
|
3840
|
+
if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
|
|
3841
|
+
func_line_directive = f"{line_directive}\n"
|
|
3842
|
+
|
|
3784
3843
|
if device == "cpu":
|
|
3785
3844
|
template_forward = cpu_kernel_template_forward
|
|
3786
3845
|
template_backward = cpu_kernel_template_backward
|
|
@@ -3808,6 +3867,7 @@ def codegen_kernel(kernel, device, options):
|
|
|
3808
3867
|
{
|
|
3809
3868
|
"forward_args": indent(forward_args),
|
|
3810
3869
|
"forward_body": forward_body,
|
|
3870
|
+
"line_directive": func_line_directive,
|
|
3811
3871
|
}
|
|
3812
3872
|
)
|
|
3813
3873
|
template += template_forward
|