warp-lang 1.6.2 → 1.7.1 (wheel tag: py3-none-macosx_10_13_universal2)
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic; see the package's registry page for more details.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -26,7 +26,7 @@ import re
|
|
|
26
26
|
import sys
|
|
27
27
|
import textwrap
|
|
28
28
|
import types
|
|
29
|
-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
|
|
29
|
+
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
|
|
30
30
|
|
|
31
31
|
import warp.config
|
|
32
32
|
from warp.types import *
|
|
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
|
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
# map operator to function name
|
|
60
|
-
builtin_operators = {}
|
|
60
|
+
builtin_operators: Dict[type[ast.AST], str] = {}
|
|
61
61
|
|
|
62
62
|
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
|
|
63
63
|
# nice overview of python operators
|
|
@@ -122,16 +122,6 @@ def get_closure_cell_contents(obj):
|
|
|
122
122
|
return None
|
|
123
123
|
|
|
124
124
|
|
|
125
|
-
def get_type_origin(tp):
|
|
126
|
-
# Compatible version of `typing.get_origin()` for Python 3.7 and older.
|
|
127
|
-
return getattr(tp, "__origin__", None)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def get_type_args(tp):
|
|
131
|
-
# Compatible version of `typing.get_args()` for Python 3.7 and older.
|
|
132
|
-
return getattr(tp, "__args__", ())
|
|
133
|
-
|
|
134
|
-
|
|
135
125
|
def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
|
|
136
126
|
"""Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
|
|
137
127
|
# Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
|
|
@@ -212,7 +202,7 @@ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
|
|
|
212
202
|
return spec._replace(annotations=eval_annotations(spec.annotations, func))
|
|
213
203
|
|
|
214
204
|
|
|
215
|
-
def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
|
|
205
|
+
def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
|
|
216
206
|
indent = "\t"
|
|
217
207
|
|
|
218
208
|
# handle empty structs
|
|
@@ -226,9 +216,12 @@ def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
|
|
|
226
216
|
field_value = getattr(inst, field_name, None)
|
|
227
217
|
|
|
228
218
|
if isinstance(field_value, StructInstance):
|
|
229
|
-
field_value = struct_instance_repr_recursive(field_value, depth + 1)
|
|
219
|
+
field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
|
|
230
220
|
|
|
231
|
-
|
|
221
|
+
if use_repr:
|
|
222
|
+
lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
|
|
223
|
+
else:
|
|
224
|
+
lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
|
|
232
225
|
|
|
233
226
|
lines.append(f"{indent * depth})")
|
|
234
227
|
return "\n".join(lines)
|
|
@@ -351,7 +344,10 @@ class StructInstance:
|
|
|
351
344
|
return self._ctype
|
|
352
345
|
|
|
353
346
|
def __repr__(self):
|
|
354
|
-
return struct_instance_repr_recursive(self, 0)
|
|
347
|
+
return struct_instance_repr_recursive(self, 0, use_repr=True)
|
|
348
|
+
|
|
349
|
+
def __str__(self):
|
|
350
|
+
return struct_instance_repr_recursive(self, 0, use_repr=False)
|
|
355
351
|
|
|
356
352
|
def to(self, device):
|
|
357
353
|
"""Copies this struct with all array members moved onto the given device.
|
|
@@ -415,12 +411,14 @@ class StructInstance:
|
|
|
415
411
|
|
|
416
412
|
|
|
417
413
|
class Struct:
|
|
418
|
-
|
|
414
|
+
hash: bytes
|
|
415
|
+
|
|
416
|
+
def __init__(self, cls: type, key: str, module: warp.context.Module):
|
|
419
417
|
self.cls = cls
|
|
420
418
|
self.module = module
|
|
421
419
|
self.key = key
|
|
420
|
+
self.vars: Dict[str, Var] = {}
|
|
422
421
|
|
|
423
|
-
self.vars = {}
|
|
424
422
|
annotations = get_annotations(self.cls)
|
|
425
423
|
for label, type in annotations.items():
|
|
426
424
|
self.vars[label] = Var(label, type)
|
|
@@ -591,11 +589,11 @@ class Reference:
|
|
|
591
589
|
self.value_type = value_type
|
|
592
590
|
|
|
593
591
|
|
|
594
|
-
def is_reference(type):
|
|
592
|
+
def is_reference(type: Any) -> builtins.bool:
|
|
595
593
|
return isinstance(type, Reference)
|
|
596
594
|
|
|
597
595
|
|
|
598
|
-
def strip_reference(arg):
|
|
596
|
+
def strip_reference(arg: Any) -> Any:
|
|
599
597
|
if is_reference(arg):
|
|
600
598
|
return arg.value_type
|
|
601
599
|
else:
|
|
@@ -623,7 +621,15 @@ def compute_type_str(base_name, template_params):
|
|
|
623
621
|
|
|
624
622
|
|
|
625
623
|
class Var:
|
|
626
|
-
def __init__(
|
|
624
|
+
def __init__(
|
|
625
|
+
self,
|
|
626
|
+
label: str,
|
|
627
|
+
type: type,
|
|
628
|
+
requires_grad: builtins.bool = False,
|
|
629
|
+
constant: Optional[builtins.bool] = None,
|
|
630
|
+
prefix: builtins.bool = True,
|
|
631
|
+
relative_lineno: Optional[int] = None,
|
|
632
|
+
):
|
|
627
633
|
# convert built-in types to wp types
|
|
628
634
|
if type == float:
|
|
629
635
|
type = float32
|
|
@@ -646,11 +652,14 @@ class Var:
|
|
|
646
652
|
# used to associate a view array Var with its parent array Var
|
|
647
653
|
self.parent = None
|
|
648
654
|
|
|
655
|
+
# Used to associate the variable with the Python statement that resulted in it being created.
|
|
656
|
+
self.relative_lineno = relative_lineno
|
|
657
|
+
|
|
649
658
|
def __str__(self):
|
|
650
659
|
return self.label
|
|
651
660
|
|
|
652
661
|
@staticmethod
|
|
653
|
-
def type_to_ctype(t, value_type=False):
|
|
662
|
+
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
654
663
|
if is_array(t):
|
|
655
664
|
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
656
665
|
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
@@ -681,7 +690,7 @@ class Var:
|
|
|
681
690
|
else:
|
|
682
691
|
return f"wp::{t.__name__}"
|
|
683
692
|
|
|
684
|
-
def ctype(self, value_type=False):
|
|
693
|
+
def ctype(self, value_type: builtins.bool = False) -> str:
|
|
685
694
|
return Var.type_to_ctype(self.type, value_type)
|
|
686
695
|
|
|
687
696
|
def emit(self, prefix: str = "var"):
|
|
@@ -803,7 +812,7 @@ def func_match_args(func, arg_types, kwarg_types):
|
|
|
803
812
|
return True
|
|
804
813
|
|
|
805
814
|
|
|
806
|
-
def get_arg_type(arg: Union[Var, Any]):
|
|
815
|
+
def get_arg_type(arg: Union[Var, Any]) -> type:
|
|
807
816
|
if isinstance(arg, str):
|
|
808
817
|
return str
|
|
809
818
|
|
|
@@ -819,7 +828,7 @@ def get_arg_type(arg: Union[Var, Any]):
|
|
|
819
828
|
return type(arg)
|
|
820
829
|
|
|
821
830
|
|
|
822
|
-
def get_arg_value(arg:
|
|
831
|
+
def get_arg_value(arg: Any) -> Any:
|
|
823
832
|
if isinstance(arg, Sequence):
|
|
824
833
|
return tuple(get_arg_value(x) for x in arg)
|
|
825
834
|
|
|
@@ -867,6 +876,9 @@ class Adjoint:
|
|
|
867
876
|
"please save it on a file and use `importlib` if needed."
|
|
868
877
|
) from e
|
|
869
878
|
|
|
879
|
+
# Indicates where the function definition starts (excludes decorators)
|
|
880
|
+
adj.fun_def_lineno = None
|
|
881
|
+
|
|
870
882
|
# get function source code
|
|
871
883
|
adj.source = inspect.getsource(func)
|
|
872
884
|
# ensures that indented class methods can be parsed as kernels
|
|
@@ -941,9 +953,6 @@ class Adjoint:
|
|
|
941
953
|
# for unit testing errors being spit out from kernels.
|
|
942
954
|
adj.skip_build = False
|
|
943
955
|
|
|
944
|
-
# Collect the LTOIR required at link-time
|
|
945
|
-
adj.ltoirs = []
|
|
946
|
-
|
|
947
956
|
# allocate extra space for a function call that requires its
|
|
948
957
|
# own shared memory space, we treat shared memory as a stack
|
|
949
958
|
# where each function pushes and pops space off, the extra
|
|
@@ -1133,7 +1142,7 @@ class Adjoint:
|
|
|
1133
1142
|
name = str(index)
|
|
1134
1143
|
|
|
1135
1144
|
# allocate new variable
|
|
1136
|
-
v = Var(name, type=type, constant=constant)
|
|
1145
|
+
v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
|
|
1137
1146
|
|
|
1138
1147
|
adj.variables.append(v)
|
|
1139
1148
|
|
|
@@ -1158,11 +1167,44 @@ class Adjoint:
|
|
|
1158
1167
|
|
|
1159
1168
|
return var
|
|
1160
1169
|
|
|
1161
|
-
|
|
1162
|
-
|
|
1170
|
+
def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
|
|
1171
|
+
"""Get a line directive for the given statement.
|
|
1172
|
+
|
|
1173
|
+
Args:
|
|
1174
|
+
statement: The statement to get the line directive for.
|
|
1175
|
+
relative_lineno: The line number of the statement relative to the function.
|
|
1176
|
+
|
|
1177
|
+
Returns:
|
|
1178
|
+
A line directive for the given statement, or None if no line directive is needed.
|
|
1179
|
+
"""
|
|
1180
|
+
|
|
1181
|
+
# lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
|
|
1182
|
+
# emit line directives in generated code if it's not being compiled with line information
|
|
1183
|
+
lineinfo_enabled = (
|
|
1184
|
+
adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
|
|
1185
|
+
)
|
|
1186
|
+
|
|
1187
|
+
if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
|
|
1188
|
+
is_comment = statement.strip().startswith("//")
|
|
1189
|
+
if not is_comment:
|
|
1190
|
+
line = relative_lineno + adj.fun_lineno
|
|
1191
|
+
# Convert backslashes to forward slashes for CUDA compatibility
|
|
1192
|
+
normalized_path = adj.filename.replace("\\", "/")
|
|
1193
|
+
return f'#line {line} "{normalized_path}"'
|
|
1194
|
+
return None
|
|
1195
|
+
|
|
1196
|
+
def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
|
|
1197
|
+
"""Append a statement to the forward pass."""
|
|
1198
|
+
|
|
1199
|
+
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
1200
|
+
adj.blocks[-1].body_forward.append(line_directive)
|
|
1201
|
+
|
|
1163
1202
|
adj.blocks[-1].body_forward.append(adj.indentation + statement)
|
|
1164
1203
|
|
|
1165
1204
|
if not skip_replay:
|
|
1205
|
+
if line_directive:
|
|
1206
|
+
adj.blocks[-1].body_replay.append(line_directive)
|
|
1207
|
+
|
|
1166
1208
|
if replay:
|
|
1167
1209
|
# if custom replay specified then output it
|
|
1168
1210
|
adj.blocks[-1].body_replay.append(adj.indentation + replay)
|
|
@@ -1171,9 +1213,14 @@ class Adjoint:
|
|
|
1171
1213
|
adj.blocks[-1].body_replay.append(adj.indentation + statement)
|
|
1172
1214
|
|
|
1173
1215
|
# append a statement to the reverse pass
|
|
1174
|
-
def add_reverse(adj, statement):
|
|
1216
|
+
def add_reverse(adj, statement: str) -> None:
|
|
1217
|
+
"""Append a statement to the reverse pass."""
|
|
1218
|
+
|
|
1175
1219
|
adj.blocks[-1].body_reverse.append(adj.indentation + statement)
|
|
1176
1220
|
|
|
1221
|
+
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
1222
|
+
adj.blocks[-1].body_reverse.append(line_directive)
|
|
1223
|
+
|
|
1177
1224
|
def add_constant(adj, n):
|
|
1178
1225
|
output = adj.add_var(type=type(n), constant=n)
|
|
1179
1226
|
return output
|
|
@@ -1281,7 +1328,7 @@ class Adjoint:
|
|
|
1281
1328
|
|
|
1282
1329
|
# Bind the positional and keyword arguments to the function's signature
|
|
1283
1330
|
# in order to process them as Python does it.
|
|
1284
|
-
bound_args = func.signature.bind(*args, **kwargs)
|
|
1331
|
+
bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
|
|
1285
1332
|
|
|
1286
1333
|
# Type args are the “compile time” argument values we get from codegen.
|
|
1287
1334
|
# For example, when calling `wp.vec3f(...)` from within a kernel,
|
|
@@ -1451,6 +1498,8 @@ class Adjoint:
|
|
|
1451
1498
|
|
|
1452
1499
|
def add_return(adj, var):
|
|
1453
1500
|
if var is None or len(var) == 0:
|
|
1501
|
+
# NOTE: If this kernel gets compiled for a CUDA device, then we need
|
|
1502
|
+
# to convert the return; into a continue; in codegen_func_forward()
|
|
1454
1503
|
adj.add_forward("return;", f"goto label{adj.label_count};")
|
|
1455
1504
|
elif len(var) == 1:
|
|
1456
1505
|
adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
|
|
@@ -1624,6 +1673,8 @@ class Adjoint:
|
|
|
1624
1673
|
adj.blocks[-1].body_reverse.extend(reversed(reverse))
|
|
1625
1674
|
|
|
1626
1675
|
def emit_FunctionDef(adj, node):
|
|
1676
|
+
adj.fun_def_lineno = node.lineno
|
|
1677
|
+
|
|
1627
1678
|
for f in node.body:
|
|
1628
1679
|
# Skip variable creation for standalone constants, including docstrings
|
|
1629
1680
|
if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
|
|
@@ -1688,7 +1739,7 @@ class Adjoint:
|
|
|
1688
1739
|
|
|
1689
1740
|
if var1 != var2:
|
|
1690
1741
|
# insert a phi function that selects var1, var2 based on cond
|
|
1691
|
-
out = adj.add_builtin_call("
|
|
1742
|
+
out = adj.add_builtin_call("where", [cond, var2, var1])
|
|
1692
1743
|
adj.symbols[sym] = out
|
|
1693
1744
|
|
|
1694
1745
|
symbols_prev = adj.symbols.copy()
|
|
@@ -1712,7 +1763,7 @@ class Adjoint:
|
|
|
1712
1763
|
if var1 != var2:
|
|
1713
1764
|
# insert a phi function that selects var1, var2 based on cond
|
|
1714
1765
|
# note the reversed order of vars since we want to use !cond as our select
|
|
1715
|
-
out = adj.add_builtin_call("
|
|
1766
|
+
out = adj.add_builtin_call("where", [cond, var1, var2])
|
|
1716
1767
|
adj.symbols[sym] = out
|
|
1717
1768
|
|
|
1718
1769
|
def emit_Compare(adj, node):
|
|
@@ -1856,25 +1907,6 @@ class Adjoint:
|
|
|
1856
1907
|
) from e
|
|
1857
1908
|
raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
|
|
1858
1909
|
|
|
1859
|
-
def emit_String(adj, node):
|
|
1860
|
-
# string constant
|
|
1861
|
-
return adj.add_constant(node.s)
|
|
1862
|
-
|
|
1863
|
-
def emit_Num(adj, node):
|
|
1864
|
-
# lookup constant, if it has already been assigned then return existing var
|
|
1865
|
-
key = (node.n, type(node.n))
|
|
1866
|
-
|
|
1867
|
-
if key in adj.symbols:
|
|
1868
|
-
return adj.symbols[key]
|
|
1869
|
-
else:
|
|
1870
|
-
out = adj.add_constant(node.n)
|
|
1871
|
-
adj.symbols[key] = out
|
|
1872
|
-
return out
|
|
1873
|
-
|
|
1874
|
-
def emit_Ellipsis(adj, node):
|
|
1875
|
-
# stubbed @wp.native_func
|
|
1876
|
-
return
|
|
1877
|
-
|
|
1878
1910
|
def emit_Assert(adj, node):
|
|
1879
1911
|
# eval condition
|
|
1880
1912
|
cond = adj.eval(node.test)
|
|
@@ -1886,24 +1918,11 @@ class Adjoint:
|
|
|
1886
1918
|
|
|
1887
1919
|
adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
|
|
1888
1920
|
|
|
1889
|
-
def emit_NameConstant(adj, node):
|
|
1890
|
-
if node.value:
|
|
1891
|
-
return adj.add_constant(node.value)
|
|
1892
|
-
elif node.value is None:
|
|
1893
|
-
raise WarpCodegenTypeError("None type unsupported")
|
|
1894
|
-
else:
|
|
1895
|
-
return adj.add_constant(False)
|
|
1896
|
-
|
|
1897
1921
|
def emit_Constant(adj, node):
|
|
1898
|
-
if
|
|
1899
|
-
|
|
1900
|
-
elif isinstance(node, ast.Num):
|
|
1901
|
-
return adj.emit_Num(node)
|
|
1902
|
-
elif isinstance(node, ast.Ellipsis):
|
|
1903
|
-
return adj.emit_Ellipsis(node)
|
|
1922
|
+
if node.value is None:
|
|
1923
|
+
raise WarpCodegenTypeError("None type unsupported")
|
|
1904
1924
|
else:
|
|
1905
|
-
|
|
1906
|
-
return adj.emit_NameConstant(node)
|
|
1925
|
+
return adj.add_constant(node.value)
|
|
1907
1926
|
|
|
1908
1927
|
def emit_BinOp(adj, node):
|
|
1909
1928
|
# evaluate binary operator arguments
|
|
@@ -1997,10 +2016,11 @@ class Adjoint:
|
|
|
1997
2016
|
adj.end_while()
|
|
1998
2017
|
|
|
1999
2018
|
def eval_num(adj, a):
|
|
2000
|
-
if isinstance(a, ast.
|
|
2001
|
-
return True, a.
|
|
2002
|
-
if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.
|
|
2003
|
-
|
|
2019
|
+
if isinstance(a, ast.Constant):
|
|
2020
|
+
return True, a.value
|
|
2021
|
+
if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
|
|
2022
|
+
# Negative constant
|
|
2023
|
+
return True, -a.operand.value
|
|
2004
2024
|
|
|
2005
2025
|
# try and resolve the expression to an object
|
|
2006
2026
|
# e.g.: wp.constant in the globals scope
|
|
@@ -2530,8 +2550,8 @@ class Adjoint:
|
|
|
2530
2550
|
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2531
2551
|
)
|
|
2532
2552
|
else:
|
|
2533
|
-
if
|
|
2534
|
-
out = adj.add_builtin_call("
|
|
2553
|
+
if warp.config.enable_vector_component_overwrites:
|
|
2554
|
+
out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
|
|
2535
2555
|
|
|
2536
2556
|
# re-point target symbol to out var
|
|
2537
2557
|
for id in adj.symbols:
|
|
@@ -2539,8 +2559,7 @@ class Adjoint:
|
|
|
2539
2559
|
adj.symbols[id] = out
|
|
2540
2560
|
break
|
|
2541
2561
|
else:
|
|
2542
|
-
|
|
2543
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2562
|
+
adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
|
|
2544
2563
|
|
|
2545
2564
|
else:
|
|
2546
2565
|
raise WarpCodegenError(
|
|
@@ -2583,8 +2602,8 @@ class Adjoint:
|
|
|
2583
2602
|
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2584
2603
|
adj.add_builtin_call("store", [attr, rhs])
|
|
2585
2604
|
else:
|
|
2586
|
-
if
|
|
2587
|
-
out = adj.add_builtin_call("
|
|
2605
|
+
if warp.config.enable_vector_component_overwrites:
|
|
2606
|
+
out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
|
|
2588
2607
|
|
|
2589
2608
|
# re-point target symbol to out var
|
|
2590
2609
|
for id in adj.symbols:
|
|
@@ -2592,8 +2611,7 @@ class Adjoint:
|
|
|
2592
2611
|
adj.symbols[id] = out
|
|
2593
2612
|
break
|
|
2594
2613
|
else:
|
|
2595
|
-
|
|
2596
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2614
|
+
adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
|
|
2597
2615
|
|
|
2598
2616
|
else:
|
|
2599
2617
|
attr = adj.emit_Attribute(lhs)
|
|
@@ -2699,10 +2717,12 @@ class Adjoint:
|
|
|
2699
2717
|
|
|
2700
2718
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2701
2719
|
if isinstance(node.op, ast.Add):
|
|
2702
|
-
adj.add_builtin_call("
|
|
2720
|
+
adj.add_builtin_call("add_inplace", [target, *indices, rhs])
|
|
2703
2721
|
elif isinstance(node.op, ast.Sub):
|
|
2704
|
-
adj.add_builtin_call("
|
|
2722
|
+
adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
|
|
2705
2723
|
else:
|
|
2724
|
+
if warp.config.verbose:
|
|
2725
|
+
print(f"Warning: in-place op {node.op} is not differentiable")
|
|
2706
2726
|
make_new_assign_statement()
|
|
2707
2727
|
return
|
|
2708
2728
|
|
|
@@ -2732,9 +2752,6 @@ class Adjoint:
|
|
|
2732
2752
|
ast.BoolOp: emit_BoolOp,
|
|
2733
2753
|
ast.Name: emit_Name,
|
|
2734
2754
|
ast.Attribute: emit_Attribute,
|
|
2735
|
-
ast.Str: emit_String, # Deprecated in 3.8; use Constant
|
|
2736
|
-
ast.Num: emit_Num, # Deprecated in 3.8; use Constant
|
|
2737
|
-
ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
|
|
2738
2755
|
ast.Constant: emit_Constant,
|
|
2739
2756
|
ast.BinOp: emit_BinOp,
|
|
2740
2757
|
ast.UnaryOp: emit_UnaryOp,
|
|
@@ -2744,14 +2761,13 @@ class Adjoint:
|
|
|
2744
2761
|
ast.Continue: emit_Continue,
|
|
2745
2762
|
ast.Expr: emit_Expr,
|
|
2746
2763
|
ast.Call: emit_Call,
|
|
2747
|
-
ast.Index: emit_Index, # Deprecated in 3.
|
|
2764
|
+
ast.Index: emit_Index, # Deprecated in 3.9
|
|
2748
2765
|
ast.Subscript: emit_Subscript,
|
|
2749
2766
|
ast.Assign: emit_Assign,
|
|
2750
2767
|
ast.Return: emit_Return,
|
|
2751
2768
|
ast.AugAssign: emit_AugAssign,
|
|
2752
2769
|
ast.Tuple: emit_Tuple,
|
|
2753
2770
|
ast.Pass: emit_Pass,
|
|
2754
|
-
ast.Ellipsis: emit_Ellipsis,
|
|
2755
2771
|
ast.Assert: emit_Assert,
|
|
2756
2772
|
}
|
|
2757
2773
|
|
|
@@ -2947,12 +2963,16 @@ class Adjoint:
|
|
|
2947
2963
|
|
|
2948
2964
|
# We want to replace the expression code in-place,
|
|
2949
2965
|
# so reparse it to get the correct column info.
|
|
2950
|
-
len_value_locs = []
|
|
2966
|
+
len_value_locs: List[Tuple[int, int, int]] = []
|
|
2951
2967
|
expr_tree = ast.parse(static_code)
|
|
2952
2968
|
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
|
|
2953
2969
|
expr_root = expr_tree.body[0].value
|
|
2954
2970
|
for expr_node in ast.walk(expr_root):
|
|
2955
|
-
if
|
|
2971
|
+
if (
|
|
2972
|
+
isinstance(expr_node, ast.Call)
|
|
2973
|
+
and getattr(expr_node.func, "id", None) == "len"
|
|
2974
|
+
and len(expr_node.args) == 1
|
|
2975
|
+
):
|
|
2956
2976
|
len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
|
|
2957
2977
|
try:
|
|
2958
2978
|
len_value = eval(len_expr, len_expr_ctx)
|
|
@@ -3110,9 +3130,9 @@ class Adjoint:
|
|
|
3110
3130
|
|
|
3111
3131
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
3112
3132
|
|
|
3113
|
-
constants = {}
|
|
3114
|
-
types = {}
|
|
3115
|
-
functions = {}
|
|
3133
|
+
constants: Dict[str, Any] = {}
|
|
3134
|
+
types: Dict[Union[Struct, type], Any] = {}
|
|
3135
|
+
functions: Dict[warp.context.Function, Any] = {}
|
|
3116
3136
|
|
|
3117
3137
|
for node in ast.walk(adj.tree):
|
|
3118
3138
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
@@ -3155,7 +3175,7 @@ class Adjoint:
|
|
|
3155
3175
|
# code generation
|
|
3156
3176
|
|
|
3157
3177
|
cpu_module_header = """
|
|
3158
|
-
#define WP_TILE_BLOCK_DIM {
|
|
3178
|
+
#define WP_TILE_BLOCK_DIM {block_dim}
|
|
3159
3179
|
#define WP_NO_CRT
|
|
3160
3180
|
#include "builtin.h"
|
|
3161
3181
|
|
|
@@ -3174,7 +3194,7 @@ cpu_module_header = """
|
|
|
3174
3194
|
"""
|
|
3175
3195
|
|
|
3176
3196
|
cuda_module_header = """
|
|
3177
|
-
#define WP_TILE_BLOCK_DIM {
|
|
3197
|
+
#define WP_TILE_BLOCK_DIM {block_dim}
|
|
3178
3198
|
#define WP_NO_CRT
|
|
3179
3199
|
#include "builtin.h"
|
|
3180
3200
|
|
|
@@ -3197,6 +3217,7 @@ struct {name}
|
|
|
3197
3217
|
{{
|
|
3198
3218
|
{struct_body}
|
|
3199
3219
|
|
|
3220
|
+
{defaulted_constructor_def}
|
|
3200
3221
|
CUDA_CALLABLE {name}({forward_args})
|
|
3201
3222
|
{forward_initializers}
|
|
3202
3223
|
{{
|
|
@@ -3239,53 +3260,53 @@ static void adj_{name}(
|
|
|
3239
3260
|
|
|
3240
3261
|
cuda_forward_function_template = """
|
|
3241
3262
|
// {filename}:{lineno}
|
|
3242
|
-
static CUDA_CALLABLE {return_type} {name}(
|
|
3263
|
+
{line_directive}static CUDA_CALLABLE {return_type} {name}(
|
|
3243
3264
|
{forward_args})
|
|
3244
3265
|
{{
|
|
3245
|
-
{forward_body}}}
|
|
3266
|
+
{forward_body}{line_directive}}}
|
|
3246
3267
|
|
|
3247
3268
|
"""
|
|
3248
3269
|
|
|
3249
3270
|
cuda_reverse_function_template = """
|
|
3250
3271
|
// {filename}:{lineno}
|
|
3251
|
-
static CUDA_CALLABLE void adj_{name}(
|
|
3272
|
+
{line_directive}static CUDA_CALLABLE void adj_{name}(
|
|
3252
3273
|
{reverse_args})
|
|
3253
3274
|
{{
|
|
3254
|
-
{reverse_body}}}
|
|
3275
|
+
{reverse_body}{line_directive}}}
|
|
3255
3276
|
|
|
3256
3277
|
"""
|
|
3257
3278
|
|
|
3258
3279
|
cuda_kernel_template_forward = """
|
|
3259
3280
|
|
|
3260
|
-
extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
3281
|
+
{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
3261
3282
|
{forward_args})
|
|
3262
3283
|
{{
|
|
3263
|
-
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3264
|
-
_idx < dim.size;
|
|
3265
|
-
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3284
|
+
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3285
|
+
{line_directive} _idx < dim.size;
|
|
3286
|
+
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3266
3287
|
{{
|
|
3267
3288
|
// reset shared memory allocator
|
|
3268
|
-
wp::tile_alloc_shared(0, true);
|
|
3289
|
+
{line_directive} wp::tile_alloc_shared(0, true);
|
|
3269
3290
|
|
|
3270
|
-
{forward_body} }}
|
|
3271
|
-
}}
|
|
3291
|
+
{forward_body}{line_directive} }}
|
|
3292
|
+
{line_directive}}}
|
|
3272
3293
|
|
|
3273
3294
|
"""
|
|
3274
3295
|
|
|
3275
3296
|
cuda_kernel_template_backward = """
|
|
3276
3297
|
|
|
3277
|
-
extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
3298
|
+
{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
3278
3299
|
{reverse_args})
|
|
3279
3300
|
{{
|
|
3280
|
-
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3281
|
-
_idx < dim.size;
|
|
3282
|
-
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3301
|
+
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
3302
|
+
{line_directive} _idx < dim.size;
|
|
3303
|
+
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
|
|
3283
3304
|
{{
|
|
3284
3305
|
// reset shared memory allocator
|
|
3285
|
-
wp::tile_alloc_shared(0, true);
|
|
3306
|
+
{line_directive} wp::tile_alloc_shared(0, true);
|
|
3286
3307
|
|
|
3287
|
-
{reverse_body} }}
|
|
3288
|
-
}}
|
|
3308
|
+
{reverse_body}{line_directive} }}
|
|
3309
|
+
{line_directive}}}
|
|
3289
3310
|
|
|
3290
3311
|
"""
|
|
3291
3312
|
|
|
@@ -3315,10 +3336,17 @@ extern "C" {{
|
|
|
3315
3336
|
WP_API void {name}_cpu_forward(
|
|
3316
3337
|
{forward_args})
|
|
3317
3338
|
{{
|
|
3318
|
-
|
|
3339
|
+
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3319
3340
|
{{
|
|
3341
|
+
// init shared memory allocator
|
|
3342
|
+
wp::tile_alloc_shared(0, true);
|
|
3343
|
+
|
|
3320
3344
|
{name}_cpu_kernel_forward(
|
|
3321
3345
|
{forward_params});
|
|
3346
|
+
|
|
3347
|
+
// check shared memory allocator
|
|
3348
|
+
wp::tile_alloc_shared(0, false, true);
|
|
3349
|
+
|
|
3322
3350
|
}}
|
|
3323
3351
|
}}
|
|
3324
3352
|
|
|
@@ -3335,8 +3363,14 @@ WP_API void {name}_cpu_backward(
|
|
|
3335
3363
|
{{
|
|
3336
3364
|
for (size_t task_index = 0; task_index < dim.size; ++task_index)
|
|
3337
3365
|
{{
|
|
3366
|
+
// initialize shared memory allocator
|
|
3367
|
+
wp::tile_alloc_shared(0, true);
|
|
3368
|
+
|
|
3338
3369
|
{name}_cpu_kernel_backward(
|
|
3339
3370
|
{reverse_params});
|
|
3371
|
+
|
|
3372
|
+
// check shared memory allocator
|
|
3373
|
+
wp::tile_alloc_shared(0, false, true);
|
|
3340
3374
|
}}
|
|
3341
3375
|
}}
|
|
3342
3376
|
|
|
@@ -3418,7 +3452,7 @@ def indent(args, stops=1):
|
|
|
3418
3452
|
|
|
3419
3453
|
|
|
3420
3454
|
# generates a C function name based on the python function name
|
|
3421
|
-
def make_full_qualified_name(func):
|
|
3455
|
+
def make_full_qualified_name(func: Union[str, Callable]) -> str:
|
|
3422
3456
|
if not isinstance(func, str):
|
|
3423
3457
|
func = func.__qualname__
|
|
3424
3458
|
return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
|
|
@@ -3448,7 +3482,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3448
3482
|
# forward args
|
|
3449
3483
|
for label, var in struct.vars.items():
|
|
3450
3484
|
var_ctype = var.ctype()
|
|
3451
|
-
|
|
3485
|
+
default_arg_def = " = {}" if forward_args else ""
|
|
3486
|
+
forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
|
|
3452
3487
|
reverse_args.append(f"{var_ctype} const&")
|
|
3453
3488
|
|
|
3454
3489
|
namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
|
|
@@ -3472,6 +3507,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3472
3507
|
|
|
3473
3508
|
reverse_args.append(name + " & adj_ret")
|
|
3474
3509
|
|
|
3510
|
+
# explicitly defaulted default constructor if no default constructor has been defined
|
|
3511
|
+
defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
|
|
3512
|
+
|
|
3475
3513
|
return struct_template.format(
|
|
3476
3514
|
name=name,
|
|
3477
3515
|
struct_body="".join([indent_block + l for l in body]),
|
|
@@ -3481,6 +3519,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
3481
3519
|
reverse_body="".join(reverse_body),
|
|
3482
3520
|
prefix_add_body="".join(prefix_add_body),
|
|
3483
3521
|
atomic_add_body="".join(atomic_add_body),
|
|
3522
|
+
defaulted_constructor_def=defaulted_constructor_def,
|
|
3484
3523
|
)
|
|
3485
3524
|
|
|
3486
3525
|
|
|
@@ -3510,14 +3549,21 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
|
|
|
3510
3549
|
else:
|
|
3511
3550
|
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
|
|
3512
3551
|
|
|
3552
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3553
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3554
|
+
|
|
3513
3555
|
# forward pass
|
|
3514
3556
|
lines += ["//---------\n"]
|
|
3515
3557
|
lines += ["// forward\n"]
|
|
3516
3558
|
|
|
3517
3559
|
for f in adj.blocks[0].body_forward:
|
|
3518
|
-
|
|
3560
|
+
if func_type == "kernel" and device == "cuda" and f.lstrip().startswith("return;"):
|
|
3561
|
+
# Use of grid-stride loops in CUDA kernels requires that we convert return; to continue;
|
|
3562
|
+
lines += [f.replace("return;", "continue;") + "\n"]
|
|
3563
|
+
else:
|
|
3564
|
+
lines += [f + "\n"]
|
|
3519
3565
|
|
|
3520
|
-
return "".join(
|
|
3566
|
+
return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
|
|
3521
3567
|
|
|
3522
3568
|
|
|
3523
3569
|
def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
@@ -3547,6 +3593,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3547
3593
|
else:
|
|
3548
3594
|
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
|
|
3549
3595
|
|
|
3596
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3597
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3598
|
+
|
|
3550
3599
|
# dual vars
|
|
3551
3600
|
lines += ["//---------\n"]
|
|
3552
3601
|
lines += ["// dual vars\n"]
|
|
@@ -3567,6 +3616,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3567
3616
|
else:
|
|
3568
3617
|
lines += [f"{ctype} {name} = {{}};\n"]
|
|
3569
3618
|
|
|
3619
|
+
if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
|
|
3620
|
+
lines.insert(-1, f"{line_directive}\n")
|
|
3621
|
+
|
|
3570
3622
|
# forward pass
|
|
3571
3623
|
lines += ["//---------\n"]
|
|
3572
3624
|
lines += ["// forward\n"]
|
|
@@ -3587,7 +3639,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
3587
3639
|
else:
|
|
3588
3640
|
lines += ["return;\n"]
|
|
3589
3641
|
|
|
3590
|
-
return "".join(
|
|
3642
|
+
return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
|
|
3591
3643
|
|
|
3592
3644
|
|
|
3593
3645
|
def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
@@ -3595,11 +3647,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3595
3647
|
options = {}
|
|
3596
3648
|
|
|
3597
3649
|
if adj.return_var is not None and "return" in adj.arg_types:
|
|
3598
|
-
if
|
|
3599
|
-
if len(
|
|
3650
|
+
if get_origin(adj.arg_types["return"]) is tuple:
|
|
3651
|
+
if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
|
|
3600
3652
|
raise WarpCodegenError(
|
|
3601
3653
|
f"The function `{adj.fun_name}` has its return type "
|
|
3602
|
-
f"annotated as a tuple of {len(
|
|
3654
|
+
f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
|
|
3603
3655
|
f"but the code returns {len(adj.return_var)} values."
|
|
3604
3656
|
)
|
|
3605
3657
|
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
|
|
@@ -3608,7 +3660,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3608
3660
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|
|
3609
3661
|
f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
|
|
3610
3662
|
)
|
|
3611
|
-
elif len(adj.return_var) > 1 and
|
|
3663
|
+
elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
|
|
3612
3664
|
raise WarpCodegenError(
|
|
3613
3665
|
f"The function `{adj.fun_name}` has its return type "
|
|
3614
3666
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|
|
@@ -3621,6 +3673,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3621
3673
|
f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
|
|
3622
3674
|
)
|
|
3623
3675
|
|
|
3676
|
+
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
3677
|
+
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
3678
|
+
# a direct mapping to a Python source line.
|
|
3679
|
+
func_line_directive = ""
|
|
3680
|
+
if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
|
|
3681
|
+
func_line_directive = f"{line_directive}\n"
|
|
3682
|
+
|
|
3624
3683
|
# forward header
|
|
3625
3684
|
if adj.return_var is not None and len(adj.return_var) == 1:
|
|
3626
3685
|
return_type = adj.return_var[0].ctype()
|
|
@@ -3684,6 +3743,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3684
3743
|
forward_body=forward_body,
|
|
3685
3744
|
filename=adj.filename,
|
|
3686
3745
|
lineno=adj.fun_lineno,
|
|
3746
|
+
line_directive=func_line_directive,
|
|
3687
3747
|
)
|
|
3688
3748
|
|
|
3689
3749
|
if not adj.skip_reverse_codegen:
|
|
@@ -3702,6 +3762,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3702
3762
|
reverse_body=reverse_body,
|
|
3703
3763
|
filename=adj.filename,
|
|
3704
3764
|
lineno=adj.fun_lineno,
|
|
3765
|
+
line_directive=func_line_directive,
|
|
3705
3766
|
)
|
|
3706
3767
|
|
|
3707
3768
|
return s
|
|
@@ -3744,6 +3805,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3744
3805
|
forward_body=snippet,
|
|
3745
3806
|
filename=adj.filename,
|
|
3746
3807
|
lineno=adj.fun_lineno,
|
|
3808
|
+
line_directive="",
|
|
3747
3809
|
)
|
|
3748
3810
|
|
|
3749
3811
|
if replay_snippet is not None:
|
|
@@ -3754,6 +3816,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3754
3816
|
forward_body=replay_snippet,
|
|
3755
3817
|
filename=adj.filename,
|
|
3756
3818
|
lineno=adj.fun_lineno,
|
|
3819
|
+
line_directive="",
|
|
3757
3820
|
)
|
|
3758
3821
|
|
|
3759
3822
|
if adj_snippet:
|
|
@@ -3769,6 +3832,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
|
|
|
3769
3832
|
reverse_body=reverse_body,
|
|
3770
3833
|
filename=adj.filename,
|
|
3771
3834
|
lineno=adj.fun_lineno,
|
|
3835
|
+
line_directive="",
|
|
3772
3836
|
)
|
|
3773
3837
|
|
|
3774
3838
|
return s
|
|
@@ -3781,6 +3845,13 @@ def codegen_kernel(kernel, device, options):
|
|
|
3781
3845
|
|
|
3782
3846
|
adj = kernel.adj
|
|
3783
3847
|
|
|
3848
|
+
# Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
|
|
3849
|
+
# This is used as a catch-all C-to-Python source line mapping for any code that does not have
|
|
3850
|
+
# a direct mapping to a Python source line.
|
|
3851
|
+
func_line_directive = ""
|
|
3852
|
+
if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
|
|
3853
|
+
func_line_directive = f"{line_directive}\n"
|
|
3854
|
+
|
|
3784
3855
|
if device == "cpu":
|
|
3785
3856
|
template_forward = cpu_kernel_template_forward
|
|
3786
3857
|
template_backward = cpu_kernel_template_backward
|
|
@@ -3808,6 +3879,7 @@ def codegen_kernel(kernel, device, options):
|
|
|
3808
3879
|
{
|
|
3809
3880
|
"forward_args": indent(forward_args),
|
|
3810
3881
|
"forward_body": forward_body,
|
|
3882
|
+
"line_directive": func_line_directive,
|
|
3811
3883
|
}
|
|
3812
3884
|
)
|
|
3813
3885
|
template += template_forward
|