warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -378
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -20,6 +20,27 @@ from typing import Any, Callable, Mapping
|
|
|
20
20
|
import warp.config
|
|
21
21
|
from warp.types import *
|
|
22
22
|
|
|
23
|
+
|
|
24
|
+
class WarpCodegenError(RuntimeError):
|
|
25
|
+
def __init__(self, message):
|
|
26
|
+
super().__init__(message)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class WarpCodegenTypeError(TypeError):
|
|
30
|
+
def __init__(self, message):
|
|
31
|
+
super().__init__(message)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class WarpCodegenAttributeError(AttributeError):
|
|
35
|
+
def __init__(self, message):
|
|
36
|
+
super().__init__(message)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class WarpCodegenKeyError(KeyError):
|
|
40
|
+
def __init__(self, message):
|
|
41
|
+
super().__init__(message)
|
|
42
|
+
|
|
43
|
+
|
|
23
44
|
# map operator to function name
|
|
24
45
|
builtin_operators = {}
|
|
25
46
|
|
|
@@ -52,6 +73,19 @@ builtin_operators[ast.Invert] = "invert"
|
|
|
52
73
|
builtin_operators[ast.LShift] = "lshift"
|
|
53
74
|
builtin_operators[ast.RShift] = "rshift"
|
|
54
75
|
|
|
76
|
+
comparison_chain_strings = [
|
|
77
|
+
builtin_operators[ast.Gt],
|
|
78
|
+
builtin_operators[ast.Lt],
|
|
79
|
+
builtin_operators[ast.LtE],
|
|
80
|
+
builtin_operators[ast.GtE],
|
|
81
|
+
builtin_operators[ast.Eq],
|
|
82
|
+
builtin_operators[ast.NotEq],
|
|
83
|
+
]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def op_str_is_chainable(op: str) -> builtins.bool:
|
|
87
|
+
return op in comparison_chain_strings
|
|
88
|
+
|
|
55
89
|
|
|
56
90
|
def get_annotations(obj: Any) -> Mapping[str, Any]:
|
|
57
91
|
"""Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
|
|
@@ -65,16 +99,14 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
|
|
|
65
99
|
def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
|
|
66
100
|
indent = "\t"
|
|
67
101
|
|
|
68
|
-
|
|
102
|
+
# handle empty structs
|
|
103
|
+
if len(inst._cls.vars) == 0:
|
|
69
104
|
return f"{inst._cls.key}()"
|
|
70
105
|
|
|
71
106
|
lines = []
|
|
72
107
|
lines.append(f"{inst._cls.key}(")
|
|
73
108
|
|
|
74
109
|
for field_name, _ in inst._cls.ctype._fields_:
|
|
75
|
-
if field_name == "_dummy_":
|
|
76
|
-
continue
|
|
77
|
-
|
|
78
110
|
field_value = getattr(inst, field_name, None)
|
|
79
111
|
|
|
80
112
|
if isinstance(field_value, StructInstance):
|
|
@@ -121,9 +153,7 @@ class StructInstance:
|
|
|
121
153
|
assert isinstance(value, array)
|
|
122
154
|
assert types_equal(
|
|
123
155
|
value.dtype, var.type.dtype
|
|
124
|
-
), "assign to struct member variable {} failed, expected type {}, got type {}"
|
|
125
|
-
name, type_repr(var.type.dtype), type_repr(value.dtype)
|
|
126
|
-
)
|
|
156
|
+
), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
|
|
127
157
|
setattr(self._ctype, name, value.__ctype__())
|
|
128
158
|
|
|
129
159
|
elif isinstance(var.type, Struct):
|
|
@@ -242,7 +272,7 @@ class Struct:
|
|
|
242
272
|
|
|
243
273
|
class StructType(ctypes.Structure):
|
|
244
274
|
# if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
|
|
245
|
-
_fields_ = fields or [("_dummy_", ctypes.
|
|
275
|
+
_fields_ = fields or [("_dummy_", ctypes.c_byte)]
|
|
246
276
|
|
|
247
277
|
self.ctype = StructType
|
|
248
278
|
|
|
@@ -363,21 +393,38 @@ class Struct:
|
|
|
363
393
|
return instance
|
|
364
394
|
|
|
365
395
|
|
|
396
|
+
class Reference:
|
|
397
|
+
def __init__(self, value_type):
|
|
398
|
+
self.value_type = value_type
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def is_reference(type):
|
|
402
|
+
return isinstance(type, Reference)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def strip_reference(arg):
|
|
406
|
+
if is_reference(arg):
|
|
407
|
+
return arg.value_type
|
|
408
|
+
else:
|
|
409
|
+
return arg
|
|
410
|
+
|
|
411
|
+
|
|
366
412
|
def compute_type_str(base_name, template_params):
|
|
367
|
-
if
|
|
413
|
+
if not template_params:
|
|
368
414
|
return base_name
|
|
369
|
-
else:
|
|
370
415
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
416
|
+
def param2str(p):
|
|
417
|
+
if isinstance(p, int):
|
|
418
|
+
return str(p)
|
|
419
|
+
elif hasattr(p, "_type_"):
|
|
420
|
+
return f"wp::{p.__name__}"
|
|
421
|
+
return p.__name__
|
|
375
422
|
|
|
376
|
-
|
|
423
|
+
return f"{base_name}<{','.join(map(param2str, template_params))}>"
|
|
377
424
|
|
|
378
425
|
|
|
379
426
|
class Var:
|
|
380
|
-
def __init__(self, label, type, requires_grad=False, constant=None, prefix=True
|
|
427
|
+
def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
|
|
381
428
|
# convert built-in types to wp types
|
|
382
429
|
if type == float:
|
|
383
430
|
type = float32
|
|
@@ -389,27 +436,39 @@ class Var:
|
|
|
389
436
|
self.requires_grad = requires_grad
|
|
390
437
|
self.constant = constant
|
|
391
438
|
self.prefix = prefix
|
|
392
|
-
self.is_adjoint = is_adjoint
|
|
393
439
|
|
|
394
440
|
def __str__(self):
|
|
395
441
|
return self.label
|
|
396
442
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
443
|
+
@staticmethod
|
|
444
|
+
def type_to_ctype(t, value_type=False):
|
|
445
|
+
if is_array(t):
|
|
446
|
+
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
447
|
+
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
448
|
+
elif isinstance(t.dtype, Struct):
|
|
449
|
+
dtypestr = make_full_qualified_name(t.dtype.cls)
|
|
450
|
+
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
451
|
+
dtypestr = t.dtype.__name__
|
|
403
452
|
else:
|
|
404
|
-
dtypestr =
|
|
405
|
-
classstr = type(
|
|
453
|
+
dtypestr = f"wp::{t.dtype.__name__}"
|
|
454
|
+
classstr = f"wp::{type(t).__name__}"
|
|
406
455
|
return f"{classstr}_t<{dtypestr}>"
|
|
407
|
-
elif isinstance(
|
|
408
|
-
return make_full_qualified_name(
|
|
409
|
-
elif
|
|
410
|
-
|
|
456
|
+
elif isinstance(t, Struct):
|
|
457
|
+
return make_full_qualified_name(t.cls)
|
|
458
|
+
elif is_reference(t):
|
|
459
|
+
if not value_type:
|
|
460
|
+
return Var.type_to_ctype(t.value_type) + "*"
|
|
461
|
+
else:
|
|
462
|
+
return Var.type_to_ctype(t.value_type)
|
|
463
|
+
elif hasattr(t, "_wp_generic_type_str_"):
|
|
464
|
+
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
|
|
465
|
+
elif t.__name__ in ("bool", "int", "float"):
|
|
466
|
+
return t.__name__
|
|
411
467
|
else:
|
|
412
|
-
return
|
|
468
|
+
return f"wp::{t.__name__}"
|
|
469
|
+
|
|
470
|
+
def ctype(self, value_type=False):
|
|
471
|
+
return Var.type_to_ctype(self.type, value_type)
|
|
413
472
|
|
|
414
473
|
def emit(self, prefix: str = "var"):
|
|
415
474
|
if self.prefix:
|
|
@@ -417,6 +476,9 @@ class Var:
|
|
|
417
476
|
else:
|
|
418
477
|
return self.label
|
|
419
478
|
|
|
479
|
+
def emit_adj(self):
|
|
480
|
+
return self.emit("adj")
|
|
481
|
+
|
|
420
482
|
|
|
421
483
|
class Block:
|
|
422
484
|
# Represents a basic block of instructions, e.g.: list
|
|
@@ -456,20 +518,17 @@ class Adjoint:
|
|
|
456
518
|
# whether the generation of the adjoint code is skipped for this function
|
|
457
519
|
adj.skip_reverse_codegen = skip_reverse_codegen
|
|
458
520
|
|
|
459
|
-
#
|
|
460
|
-
adj.
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
|
|
464
|
-
|
|
465
|
-
# keep track of line number in function code
|
|
466
|
-
adj.lineno = None
|
|
521
|
+
# extract name of source file
|
|
522
|
+
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
523
|
+
# get source file line number where function starts
|
|
524
|
+
_, adj.fun_lineno = inspect.getsourcelines(func)
|
|
467
525
|
|
|
526
|
+
# get function source code
|
|
527
|
+
adj.source = inspect.getsource(func)
|
|
468
528
|
# ensures that indented class methods can be parsed as kernels
|
|
469
529
|
adj.source = textwrap.dedent(adj.source)
|
|
470
530
|
|
|
471
|
-
|
|
472
|
-
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
531
|
+
adj.source_lines = adj.source.splitlines()
|
|
473
532
|
|
|
474
533
|
# build AST and apply node transformers
|
|
475
534
|
adj.tree = ast.parse(adj.source)
|
|
@@ -479,6 +538,9 @@ class Adjoint:
|
|
|
479
538
|
|
|
480
539
|
adj.fun_name = adj.tree.body[0].name
|
|
481
540
|
|
|
541
|
+
# for keeping track of line number in function code
|
|
542
|
+
adj.lineno = None
|
|
543
|
+
|
|
482
544
|
# whether the forward code shall be used for the reverse pass and a custom
|
|
483
545
|
# function signature is applied to the reverse version of the function
|
|
484
546
|
adj.custom_reverse_mode = custom_reverse_mode
|
|
@@ -493,16 +555,17 @@ class Adjoint:
|
|
|
493
555
|
if overload_annotations is None:
|
|
494
556
|
# use source-level argument annotations
|
|
495
557
|
if len(argspec.annotations) < len(argspec.args):
|
|
496
|
-
raise
|
|
558
|
+
raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
|
|
497
559
|
adj.arg_types = argspec.annotations
|
|
498
560
|
else:
|
|
499
561
|
# use overload argument annotations
|
|
500
562
|
for arg_name in argspec.args:
|
|
501
563
|
if arg_name not in overload_annotations:
|
|
502
|
-
raise
|
|
564
|
+
raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
|
|
503
565
|
adj.arg_types = overload_annotations.copy()
|
|
504
566
|
|
|
505
567
|
adj.args = []
|
|
568
|
+
adj.symbols = {}
|
|
506
569
|
|
|
507
570
|
for name, type in adj.arg_types.items():
|
|
508
571
|
# skip return hint
|
|
@@ -513,8 +576,23 @@ class Adjoint:
|
|
|
513
576
|
arg = Var(name, type, False)
|
|
514
577
|
adj.args.append(arg)
|
|
515
578
|
|
|
579
|
+
# pre-populate symbol dictionary with function argument names
|
|
580
|
+
# this is to avoid registering false references to overshadowed modules
|
|
581
|
+
adj.symbols[name] = arg
|
|
582
|
+
|
|
583
|
+
# There are cases where a same module might be rebuilt multiple times,
|
|
584
|
+
# for example when kernels are nested inside of functions, or when
|
|
585
|
+
# a kernel's launch raises an exception. Ideally we'd always want to
|
|
586
|
+
# avoid rebuilding kernels but some corner cases seem to depend on it,
|
|
587
|
+
# so we only avoid rebuilding kernels that errored out to give a chance
|
|
588
|
+
# for unit testing errors being spit out from kernels.
|
|
589
|
+
adj.skip_build = False
|
|
590
|
+
|
|
516
591
|
# generate function ssa form and adjoint
|
|
517
592
|
def build(adj, builder):
|
|
593
|
+
if adj.skip_build:
|
|
594
|
+
return
|
|
595
|
+
|
|
518
596
|
adj.builder = builder
|
|
519
597
|
|
|
520
598
|
adj.symbols = {} # map from symbols to adjoint variables
|
|
@@ -528,7 +606,7 @@ class Adjoint:
|
|
|
528
606
|
adj.loop_blocks = []
|
|
529
607
|
|
|
530
608
|
# holds current indent level
|
|
531
|
-
adj.
|
|
609
|
+
adj.indentation = ""
|
|
532
610
|
|
|
533
611
|
# used to generate new label indices
|
|
534
612
|
adj.label_count = 0
|
|
@@ -542,12 +620,17 @@ class Adjoint:
|
|
|
542
620
|
adj.eval(adj.tree.body[0])
|
|
543
621
|
except Exception as e:
|
|
544
622
|
try:
|
|
623
|
+
if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
|
|
624
|
+
msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
|
|
625
|
+
else:
|
|
626
|
+
msg = "Error"
|
|
545
627
|
lineno = adj.lineno + adj.fun_lineno
|
|
546
|
-
line = adj.
|
|
547
|
-
msg
|
|
628
|
+
line = adj.source_lines[adj.lineno]
|
|
629
|
+
msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
|
|
548
630
|
ex, data, traceback = sys.exc_info()
|
|
549
|
-
e = ex("".join([msg] +
|
|
631
|
+
e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
|
|
550
632
|
finally:
|
|
633
|
+
adj.skip_build = True
|
|
551
634
|
raise e
|
|
552
635
|
|
|
553
636
|
if builder is not None:
|
|
@@ -570,16 +653,18 @@ class Adjoint:
|
|
|
570
653
|
arg_strs = []
|
|
571
654
|
|
|
572
655
|
for a in args:
|
|
573
|
-
if
|
|
656
|
+
if isinstance(a, warp.context.Function):
|
|
574
657
|
# functions don't have a var_ prefix so strip it off here
|
|
575
658
|
if prefix == "var":
|
|
576
659
|
arg_strs.append(a.key)
|
|
577
660
|
else:
|
|
578
661
|
arg_strs.append(f"{prefix}_{a.key}")
|
|
662
|
+
elif is_reference(a.type):
|
|
663
|
+
arg_strs.append(f"{prefix}_{a}")
|
|
579
664
|
elif isinstance(a, Var):
|
|
580
665
|
arg_strs.append(a.emit(prefix))
|
|
581
666
|
else:
|
|
582
|
-
|
|
667
|
+
raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
|
|
583
668
|
|
|
584
669
|
return arg_strs
|
|
585
670
|
|
|
@@ -587,30 +672,37 @@ class Adjoint:
|
|
|
587
672
|
def format_forward_call_args(adj, args, use_initializer_list):
|
|
588
673
|
arg_str = ", ".join(adj.format_args("var", args))
|
|
589
674
|
if use_initializer_list:
|
|
590
|
-
return "{{{}}}"
|
|
675
|
+
return f"{{{arg_str}}}"
|
|
591
676
|
return arg_str
|
|
592
677
|
|
|
593
678
|
# generates argument string for a reverse function call
|
|
594
679
|
def format_reverse_call_args(
|
|
595
|
-
adj,
|
|
680
|
+
adj,
|
|
681
|
+
args_var,
|
|
682
|
+
args,
|
|
683
|
+
args_out,
|
|
684
|
+
use_initializer_list,
|
|
685
|
+
has_output_args=True,
|
|
686
|
+
require_original_output_arg=False,
|
|
596
687
|
):
|
|
597
|
-
formatted_var = adj.format_args("var",
|
|
688
|
+
formatted_var = adj.format_args("var", args_var)
|
|
598
689
|
formatted_out = []
|
|
599
|
-
if has_output_args and len(args_out) > 1:
|
|
690
|
+
if has_output_args and (require_original_output_arg or len(args_out) > 1):
|
|
600
691
|
formatted_out = adj.format_args("var", args_out)
|
|
601
692
|
formatted_var_adj = adj.format_args(
|
|
602
|
-
"&adj" if use_initializer_list else "adj",
|
|
693
|
+
"&adj" if use_initializer_list else "adj",
|
|
694
|
+
args,
|
|
603
695
|
)
|
|
604
|
-
formatted_out_adj = adj.format_args("adj",
|
|
696
|
+
formatted_out_adj = adj.format_args("adj", args_out)
|
|
605
697
|
|
|
606
698
|
if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
|
|
607
699
|
# there are no adjoint arguments, so we don't need to call the reverse function
|
|
608
700
|
return None
|
|
609
701
|
|
|
610
702
|
if use_initializer_list:
|
|
611
|
-
var_str = "{{{
|
|
612
|
-
out_str = "{{{
|
|
613
|
-
adj_str = "{{{
|
|
703
|
+
var_str = f"{{{', '.join(formatted_var)}}}"
|
|
704
|
+
out_str = f"{{{', '.join(formatted_out)}}}"
|
|
705
|
+
adj_str = f"{{{', '.join(formatted_var_adj)}}}"
|
|
614
706
|
out_adj_str = ", ".join(formatted_out_adj)
|
|
615
707
|
if len(args_out) > 1:
|
|
616
708
|
arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
|
|
@@ -621,10 +713,10 @@ class Adjoint:
|
|
|
621
713
|
return arg_str
|
|
622
714
|
|
|
623
715
|
def indent(adj):
|
|
624
|
-
adj.
|
|
716
|
+
adj.indentation = adj.indentation + " "
|
|
625
717
|
|
|
626
718
|
def dedent(adj):
|
|
627
|
-
adj.
|
|
719
|
+
adj.indentation = adj.indentation[:-4]
|
|
628
720
|
|
|
629
721
|
def begin_block(adj):
|
|
630
722
|
b = Block()
|
|
@@ -639,10 +731,9 @@ class Adjoint:
|
|
|
639
731
|
def end_block(adj):
|
|
640
732
|
return adj.blocks.pop()
|
|
641
733
|
|
|
642
|
-
def add_var(adj, type=None, constant=None
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
name = str(index)
|
|
734
|
+
def add_var(adj, type=None, constant=None):
|
|
735
|
+
index = len(adj.variables)
|
|
736
|
+
name = str(index)
|
|
646
737
|
|
|
647
738
|
# allocate new variable
|
|
648
739
|
v = Var(name, type=type, constant=constant)
|
|
@@ -655,30 +746,54 @@ class Adjoint:
|
|
|
655
746
|
|
|
656
747
|
# append a statement to the forward pass
|
|
657
748
|
def add_forward(adj, statement, replay=None, skip_replay=False):
|
|
658
|
-
adj.blocks[-1].body_forward.append(adj.
|
|
749
|
+
adj.blocks[-1].body_forward.append(adj.indentation + statement)
|
|
659
750
|
|
|
660
751
|
if not skip_replay:
|
|
661
752
|
if replay:
|
|
662
753
|
# if custom replay specified then output it
|
|
663
|
-
adj.blocks[-1].body_replay.append(adj.
|
|
754
|
+
adj.blocks[-1].body_replay.append(adj.indentation + replay)
|
|
664
755
|
else:
|
|
665
756
|
# by default just replay the original statement
|
|
666
|
-
adj.blocks[-1].body_replay.append(adj.
|
|
757
|
+
adj.blocks[-1].body_replay.append(adj.indentation + statement)
|
|
667
758
|
|
|
668
759
|
# append a statement to the reverse pass
|
|
669
760
|
def add_reverse(adj, statement):
|
|
670
|
-
adj.blocks[-1].body_reverse.append(adj.
|
|
761
|
+
adj.blocks[-1].body_reverse.append(adj.indentation + statement)
|
|
671
762
|
|
|
672
763
|
def add_constant(adj, n):
|
|
673
764
|
output = adj.add_var(type=type(n), constant=n)
|
|
674
765
|
return output
|
|
675
766
|
|
|
767
|
+
def load(adj, var):
|
|
768
|
+
if is_reference(var.type):
|
|
769
|
+
var = adj.add_builtin_call("load", [var])
|
|
770
|
+
return var
|
|
771
|
+
|
|
676
772
|
def add_comp(adj, op_strings, left, comps):
|
|
677
773
|
output = adj.add_var(builtins.bool)
|
|
678
774
|
|
|
679
|
-
|
|
775
|
+
left = adj.load(left)
|
|
776
|
+
s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
|
|
777
|
+
|
|
778
|
+
prev_comp = None
|
|
779
|
+
|
|
680
780
|
for op, comp in zip(op_strings, comps):
|
|
681
|
-
|
|
781
|
+
comp_chainable = op_str_is_chainable(op)
|
|
782
|
+
if comp_chainable and prev_comp:
|
|
783
|
+
# We restrict chaining to operands of the same type
|
|
784
|
+
if prev_comp.type is comp.type:
|
|
785
|
+
prev_comp = adj.load(prev_comp)
|
|
786
|
+
comp = adj.load(comp)
|
|
787
|
+
s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
|
|
788
|
+
else:
|
|
789
|
+
raise WarpCodegenTypeError(
|
|
790
|
+
f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
|
|
791
|
+
)
|
|
792
|
+
else:
|
|
793
|
+
comp = adj.load(comp)
|
|
794
|
+
s += op + " " + comp.emit() + ") "
|
|
795
|
+
|
|
796
|
+
prev_comp = comp
|
|
682
797
|
|
|
683
798
|
s = s.rstrip() + ";"
|
|
684
799
|
|
|
@@ -687,110 +802,106 @@ class Adjoint:
|
|
|
687
802
|
return output
|
|
688
803
|
|
|
689
804
|
def add_bool_op(adj, op_string, exprs):
|
|
805
|
+
exprs = [adj.load(expr) for expr in exprs]
|
|
690
806
|
output = adj.add_var(builtins.bool)
|
|
691
|
-
command = (
|
|
692
|
-
"var_" + str(output) + " = " + (" " + op_string + " ").join(["var_" + str(expr) for expr in exprs]) + ";"
|
|
693
|
-
)
|
|
807
|
+
command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
|
|
694
808
|
adj.add_forward(command)
|
|
695
809
|
|
|
696
810
|
return output
|
|
697
811
|
|
|
698
|
-
def
|
|
699
|
-
|
|
700
|
-
# we validate argument types before they go to generated native code
|
|
701
|
-
resolved_func = None
|
|
812
|
+
def resolve_func(adj, func, args, min_outputs, templates, kwds):
|
|
813
|
+
arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
|
|
702
814
|
|
|
703
|
-
if func.is_builtin():
|
|
815
|
+
if not func.is_builtin():
|
|
816
|
+
# user-defined function
|
|
817
|
+
overload = func.get_overload(arg_types)
|
|
818
|
+
if overload is not None:
|
|
819
|
+
return overload
|
|
820
|
+
else:
|
|
821
|
+
# if func is overloaded then perform overload resolution here
|
|
822
|
+
# we validate argument types before they go to generated native code
|
|
704
823
|
for f in func.overloads:
|
|
705
|
-
match = True
|
|
706
|
-
|
|
707
824
|
# skip type checking for variadic functions
|
|
708
825
|
if not f.variadic:
|
|
709
826
|
# check argument counts match are compatible (may be some default args)
|
|
710
827
|
if len(f.input_types) < len(args):
|
|
711
|
-
match = False
|
|
712
828
|
continue
|
|
713
829
|
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
830
|
+
def match_args(args, f):
|
|
831
|
+
# check argument types equal
|
|
832
|
+
for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
|
|
833
|
+
# if arg type registered as Any, treat as
|
|
834
|
+
# template allowing any type to match
|
|
835
|
+
if arg_type == Any:
|
|
836
|
+
continue
|
|
837
|
+
|
|
838
|
+
# handle function refs as a special case
|
|
839
|
+
if arg_type == Callable and type(args[i]) is warp.context.Function:
|
|
840
|
+
continue
|
|
841
|
+
|
|
842
|
+
if arg_type == Reference and is_reference(args[i].type):
|
|
843
|
+
continue
|
|
844
|
+
|
|
845
|
+
# look for default values for missing args
|
|
846
|
+
if i >= len(args):
|
|
847
|
+
if arg_name not in f.defaults:
|
|
848
|
+
return False
|
|
849
|
+
else:
|
|
850
|
+
# otherwise check arg type matches input variable type
|
|
851
|
+
if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
|
|
852
|
+
return False
|
|
853
|
+
|
|
854
|
+
return True
|
|
855
|
+
|
|
856
|
+
if not match_args(args, f):
|
|
857
|
+
continue
|
|
735
858
|
|
|
736
859
|
# check output dimensions match expectations
|
|
737
860
|
if min_outputs:
|
|
738
861
|
try:
|
|
739
862
|
value_type = f.value_func(args, kwds, templates)
|
|
740
|
-
if len(value_type) != min_outputs:
|
|
741
|
-
match = False
|
|
863
|
+
if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
|
|
742
864
|
continue
|
|
743
865
|
except Exception:
|
|
744
866
|
# value func may fail if the user has given
|
|
745
867
|
# incorrect args, so we need to catch this
|
|
746
|
-
match = False
|
|
747
868
|
continue
|
|
748
869
|
|
|
749
870
|
# found a match, use it
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
# shorten Warp primitive type names
|
|
765
|
-
if isinstance(x.type, list):
|
|
766
|
-
if len(x.type) != 1:
|
|
767
|
-
raise Exception("Argument must not be the result from a multi-valued function")
|
|
768
|
-
arg_type = x.type[0]
|
|
769
|
-
else:
|
|
770
|
-
arg_type = x.type
|
|
771
|
-
if arg_type.__module__ == "warp.types":
|
|
772
|
-
arg_types.append(arg_type.__name__)
|
|
773
|
-
else:
|
|
774
|
-
arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
|
|
775
|
-
|
|
776
|
-
if isinstance(x, warp.context.Function):
|
|
777
|
-
arg_types.append("function")
|
|
778
|
-
|
|
779
|
-
raise Exception(
|
|
780
|
-
f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
|
|
781
|
-
)
|
|
871
|
+
return f
|
|
872
|
+
|
|
873
|
+
# unresolved function, report error
|
|
874
|
+
arg_types = []
|
|
875
|
+
|
|
876
|
+
for x in args:
|
|
877
|
+
if isinstance(x, Var):
|
|
878
|
+
# shorten Warp primitive type names
|
|
879
|
+
if isinstance(x.type, list):
|
|
880
|
+
if len(x.type) != 1:
|
|
881
|
+
raise WarpCodegenError("Argument must not be the result from a multi-valued function")
|
|
882
|
+
arg_type = x.type[0]
|
|
883
|
+
else:
|
|
884
|
+
arg_type = x.type
|
|
782
885
|
|
|
783
|
-
|
|
784
|
-
|
|
886
|
+
arg_types.append(type_repr(arg_type))
|
|
887
|
+
|
|
888
|
+
if isinstance(x, warp.context.Function):
|
|
889
|
+
arg_types.append("function")
|
|
890
|
+
|
|
891
|
+
raise WarpCodegenError(
|
|
892
|
+
f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
|
|
893
|
+
)
|
|
894
|
+
|
|
895
|
+
def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
|
|
896
|
+
func = adj.resolve_func(func, args, min_outputs, templates, kwds)
|
|
785
897
|
|
|
786
898
|
# push any default values onto args
|
|
787
899
|
for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
|
|
788
900
|
if i >= len(args):
|
|
789
|
-
if arg_name in
|
|
901
|
+
if arg_name in func.defaults:
|
|
790
902
|
const = adj.add_constant(func.defaults[arg_name])
|
|
791
903
|
args.append(const)
|
|
792
904
|
else:
|
|
793
|
-
match = False
|
|
794
905
|
break
|
|
795
906
|
|
|
796
907
|
# if it is a user-function then build it recursively
|
|
@@ -798,105 +909,105 @@ class Adjoint:
|
|
|
798
909
|
adj.builder.build_function(func)
|
|
799
910
|
|
|
800
911
|
# evaluate the function type based on inputs
|
|
801
|
-
|
|
912
|
+
arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
|
|
913
|
+
return_type = func.value_func(arg_types, kwds, templates)
|
|
802
914
|
|
|
803
915
|
func_name = compute_type_str(func.native_func, templates)
|
|
916
|
+
param_types = list(func.input_types.values())
|
|
804
917
|
|
|
805
918
|
use_initializer_list = func.initializer_list_func(args, templates)
|
|
806
919
|
|
|
807
|
-
|
|
920
|
+
args_var = [
|
|
921
|
+
adj.load(a)
|
|
922
|
+
if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
|
|
923
|
+
else a
|
|
924
|
+
for i, a in enumerate(args)
|
|
925
|
+
]
|
|
926
|
+
|
|
927
|
+
if return_type is None:
|
|
808
928
|
# handles expression (zero output) functions, e.g.: void do_something();
|
|
809
929
|
|
|
810
|
-
|
|
811
|
-
|
|
930
|
+
output = None
|
|
931
|
+
output_list = []
|
|
932
|
+
|
|
933
|
+
forward_call = (
|
|
934
|
+
f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
|
|
812
935
|
)
|
|
813
936
|
replay_call = forward_call
|
|
814
937
|
if func.custom_replay_func is not None:
|
|
815
|
-
replay_call = "{}replay_{}({});"
|
|
816
|
-
func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
|
|
817
|
-
)
|
|
818
|
-
if func.skip_replay:
|
|
819
|
-
adj.add_forward(forward_call, replay="// " + replay_call)
|
|
820
|
-
else:
|
|
821
|
-
adj.add_forward(forward_call, replay=replay_call)
|
|
822
|
-
|
|
823
|
-
if not func.missing_grad and len(args):
|
|
824
|
-
arg_str = adj.format_reverse_call_args(args, [], {}, {}, use_initializer_list)
|
|
825
|
-
if arg_str is not None:
|
|
826
|
-
reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
|
|
827
|
-
adj.add_reverse(reverse_call)
|
|
828
|
-
|
|
829
|
-
return None
|
|
938
|
+
replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
|
|
830
939
|
|
|
831
|
-
elif not isinstance(
|
|
940
|
+
elif not isinstance(return_type, list) or len(return_type) == 1:
|
|
832
941
|
# handle simple function (one output)
|
|
833
942
|
|
|
834
|
-
if isinstance(
|
|
835
|
-
|
|
836
|
-
output = adj.add_var(
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
)
|
|
943
|
+
if isinstance(return_type, list):
|
|
944
|
+
return_type = return_type[0]
|
|
945
|
+
output = adj.add_var(return_type)
|
|
946
|
+
output_list = [output]
|
|
947
|
+
|
|
948
|
+
forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
|
|
840
949
|
replay_call = forward_call
|
|
841
950
|
if func.custom_replay_func is not None:
|
|
842
|
-
replay_call = "var_{} = {}replay_{}({});"
|
|
843
|
-
output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
|
|
844
|
-
)
|
|
845
|
-
|
|
846
|
-
if func.skip_replay:
|
|
847
|
-
adj.add_forward(forward_call, replay="// " + replay_call)
|
|
848
|
-
else:
|
|
849
|
-
adj.add_forward(forward_call, replay=replay_call)
|
|
850
|
-
|
|
851
|
-
if not func.missing_grad and len(args):
|
|
852
|
-
arg_str = adj.format_reverse_call_args(args, [output], {}, {}, use_initializer_list)
|
|
853
|
-
if arg_str is not None:
|
|
854
|
-
reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
|
|
855
|
-
adj.add_reverse(reverse_call)
|
|
856
|
-
|
|
857
|
-
return output
|
|
951
|
+
replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
|
|
858
952
|
|
|
859
953
|
else:
|
|
860
954
|
# handle multiple value functions
|
|
861
955
|
|
|
862
|
-
output = [adj.add_var(v) for v in
|
|
863
|
-
|
|
864
|
-
|
|
956
|
+
output = [adj.add_var(v) for v in return_type]
|
|
957
|
+
output_list = output
|
|
958
|
+
|
|
959
|
+
forward_call = (
|
|
960
|
+
f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
|
|
865
961
|
)
|
|
866
|
-
|
|
962
|
+
replay_call = forward_call
|
|
867
963
|
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
964
|
+
if func.skip_replay:
|
|
965
|
+
adj.add_forward(forward_call, replay="// " + replay_call)
|
|
966
|
+
else:
|
|
967
|
+
adj.add_forward(forward_call, replay=replay_call)
|
|
968
|
+
|
|
969
|
+
if not func.missing_grad and len(args):
|
|
970
|
+
reverse_has_output_args = (
|
|
971
|
+
func.require_original_output_arg or len(output_list) > 1
|
|
972
|
+
) and func.custom_grad_func is None
|
|
973
|
+
arg_str = adj.format_reverse_call_args(
|
|
974
|
+
args_var,
|
|
975
|
+
args,
|
|
976
|
+
output_list,
|
|
977
|
+
use_initializer_list,
|
|
978
|
+
has_output_args=reverse_has_output_args,
|
|
979
|
+
require_original_output_arg=func.require_original_output_arg,
|
|
980
|
+
)
|
|
981
|
+
if arg_str is not None:
|
|
982
|
+
reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
|
|
983
|
+
adj.add_reverse(reverse_call)
|
|
875
984
|
|
|
876
|
-
|
|
877
|
-
return output[0]
|
|
985
|
+
return output
|
|
878
986
|
|
|
879
|
-
|
|
987
|
+
def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
|
|
988
|
+
func = warp.context.builtin_functions[func_name]
|
|
989
|
+
return adj.add_call(func, args, min_outputs, templates, kwds)
|
|
880
990
|
|
|
881
991
|
def add_return(adj, var):
|
|
882
992
|
if var is None or len(var) == 0:
|
|
883
|
-
adj.add_forward("return;", "goto label{};"
|
|
993
|
+
adj.add_forward("return;", f"goto label{adj.label_count};")
|
|
884
994
|
elif len(var) == 1:
|
|
885
|
-
adj.add_forward("return
|
|
995
|
+
adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
|
|
886
996
|
adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
|
|
887
997
|
else:
|
|
888
998
|
for i, v in enumerate(var):
|
|
889
|
-
adj.add_forward("ret_{} =
|
|
890
|
-
adj.add_reverse("adj_{} += adj_ret_{};"
|
|
891
|
-
adj.add_forward("return;", "goto label{};"
|
|
999
|
+
adj.add_forward(f"ret_{i} = {v.emit()};")
|
|
1000
|
+
adj.add_reverse(f"adj_{v} += adj_ret_{i};")
|
|
1001
|
+
adj.add_forward("return;", f"goto label{adj.label_count};")
|
|
892
1002
|
|
|
893
|
-
adj.add_reverse("label{}:;"
|
|
1003
|
+
adj.add_reverse(f"label{adj.label_count}:;")
|
|
894
1004
|
|
|
895
1005
|
adj.label_count += 1
|
|
896
1006
|
|
|
897
1007
|
# define an if statement
|
|
898
1008
|
def begin_if(adj, cond):
|
|
899
|
-
|
|
1009
|
+
cond = adj.load(cond)
|
|
1010
|
+
adj.add_forward(f"if ({cond.emit()}) {{")
|
|
900
1011
|
adj.add_reverse("}")
|
|
901
1012
|
|
|
902
1013
|
adj.indent()
|
|
@@ -905,10 +1016,12 @@ class Adjoint:
|
|
|
905
1016
|
adj.dedent()
|
|
906
1017
|
|
|
907
1018
|
adj.add_forward("}")
|
|
908
|
-
adj.
|
|
1019
|
+
cond = adj.load(cond)
|
|
1020
|
+
adj.add_reverse(f"if ({cond.emit()}) {{")
|
|
909
1021
|
|
|
910
1022
|
def begin_else(adj, cond):
|
|
911
|
-
adj.
|
|
1023
|
+
cond = adj.load(cond)
|
|
1024
|
+
adj.add_forward(f"if (!{cond.emit()}) {{")
|
|
912
1025
|
adj.add_reverse("}")
|
|
913
1026
|
|
|
914
1027
|
adj.indent()
|
|
@@ -917,7 +1030,8 @@ class Adjoint:
|
|
|
917
1030
|
adj.dedent()
|
|
918
1031
|
|
|
919
1032
|
adj.add_forward("}")
|
|
920
|
-
adj.
|
|
1033
|
+
cond = adj.load(cond)
|
|
1034
|
+
adj.add_reverse(f"if (!{cond.emit()}) {{")
|
|
921
1035
|
|
|
922
1036
|
# define a for-loop
|
|
923
1037
|
def begin_for(adj, iter):
|
|
@@ -927,10 +1041,10 @@ class Adjoint:
|
|
|
927
1041
|
adj.indent()
|
|
928
1042
|
|
|
929
1043
|
# evaluate cond
|
|
930
|
-
adj.add_forward(f"if (iter_cmp(
|
|
1044
|
+
adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")
|
|
931
1045
|
|
|
932
1046
|
# evaluate iter
|
|
933
|
-
val = adj.
|
|
1047
|
+
val = adj.add_builtin_call("iter_next", [iter])
|
|
934
1048
|
|
|
935
1049
|
adj.begin_block()
|
|
936
1050
|
|
|
@@ -961,17 +1075,14 @@ class Adjoint:
|
|
|
961
1075
|
reverse = []
|
|
962
1076
|
|
|
963
1077
|
# reverse iterator
|
|
964
|
-
reverse.append(adj.
|
|
1078
|
+
reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
|
|
965
1079
|
|
|
966
1080
|
for i in cond_block.body_forward:
|
|
967
1081
|
reverse.append(i)
|
|
968
1082
|
|
|
969
1083
|
# zero adjoints
|
|
970
1084
|
for i in body_block.vars:
|
|
971
|
-
|
|
972
|
-
reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
|
|
973
|
-
else:
|
|
974
|
-
reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
|
|
1085
|
+
reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
|
|
975
1086
|
|
|
976
1087
|
# replay
|
|
977
1088
|
for i in body_block.body_replay:
|
|
@@ -981,14 +1092,14 @@ class Adjoint:
|
|
|
981
1092
|
for i in reversed(body_block.body_reverse):
|
|
982
1093
|
reverse.append(i)
|
|
983
1094
|
|
|
984
|
-
reverse.append(adj.
|
|
985
|
-
reverse.append(adj.
|
|
1095
|
+
reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
|
|
1096
|
+
reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")
|
|
986
1097
|
|
|
987
1098
|
adj.blocks[-1].body_reverse.extend(reversed(reverse))
|
|
988
1099
|
|
|
989
1100
|
# define a while loop
|
|
990
1101
|
def begin_while(adj, cond):
|
|
991
|
-
#
|
|
1102
|
+
# evaluate condition in its own block
|
|
992
1103
|
# so we can control replay
|
|
993
1104
|
cond_block = adj.begin_block()
|
|
994
1105
|
adj.loop_blocks.append(cond_block)
|
|
@@ -996,7 +1107,7 @@ class Adjoint:
|
|
|
996
1107
|
|
|
997
1108
|
c = adj.eval(cond)
|
|
998
1109
|
|
|
999
|
-
cond_block.body_forward.append(f"if ((
|
|
1110
|
+
cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")
|
|
1000
1111
|
|
|
1001
1112
|
# being block around loop
|
|
1002
1113
|
adj.begin_block()
|
|
@@ -1030,10 +1141,7 @@ class Adjoint:
|
|
|
1030
1141
|
|
|
1031
1142
|
# zero adjoints of local vars
|
|
1032
1143
|
for i in body_block.vars:
|
|
1033
|
-
|
|
1034
|
-
reverse.append(f"adj_{i} = {i.ctype()}{{}};")
|
|
1035
|
-
else:
|
|
1036
|
-
reverse.append(f"adj_{i} = {i.ctype()}(0);")
|
|
1144
|
+
reverse.append(f"{i.emit_adj()} = {{}};")
|
|
1037
1145
|
|
|
1038
1146
|
# replay
|
|
1039
1147
|
for i in body_block.body_replay:
|
|
@@ -1053,6 +1161,10 @@ class Adjoint:
|
|
|
1053
1161
|
for f in node.body:
|
|
1054
1162
|
adj.eval(f)
|
|
1055
1163
|
|
|
1164
|
+
if adj.return_var is not None and len(adj.return_var) == 1:
|
|
1165
|
+
if not isinstance(node.body[-1], ast.Return):
|
|
1166
|
+
adj.add_forward("return {};", skip_replay=True)
|
|
1167
|
+
|
|
1056
1168
|
def emit_If(adj, node):
|
|
1057
1169
|
if len(node.body) == 0:
|
|
1058
1170
|
return None
|
|
@@ -1080,7 +1192,7 @@ class Adjoint:
|
|
|
1080
1192
|
|
|
1081
1193
|
if var1 != var2:
|
|
1082
1194
|
# insert a phi function that selects var1, var2 based on cond
|
|
1083
|
-
out = adj.
|
|
1195
|
+
out = adj.add_builtin_call("select", [cond, var1, var2])
|
|
1084
1196
|
adj.symbols[sym] = out
|
|
1085
1197
|
|
|
1086
1198
|
symbols_prev = adj.symbols.copy()
|
|
@@ -1104,7 +1216,7 @@ class Adjoint:
|
|
|
1104
1216
|
if var1 != var2:
|
|
1105
1217
|
# insert a phi function that selects var1, var2 based on cond
|
|
1106
1218
|
# note the reversed order of vars since we want to use !cond as our select
|
|
1107
|
-
out = adj.
|
|
1219
|
+
out = adj.add_builtin_call("select", [cond, var2, var1])
|
|
1108
1220
|
adj.symbols[sym] = out
|
|
1109
1221
|
|
|
1110
1222
|
def emit_Compare(adj, node):
|
|
@@ -1126,7 +1238,7 @@ class Adjoint:
|
|
|
1126
1238
|
elif isinstance(op, ast.Or):
|
|
1127
1239
|
func = "||"
|
|
1128
1240
|
else:
|
|
1129
|
-
raise
|
|
1241
|
+
raise WarpCodegenKeyError(f"Op {op} is not supported")
|
|
1130
1242
|
|
|
1131
1243
|
return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
|
|
1132
1244
|
|
|
@@ -1146,7 +1258,7 @@ class Adjoint:
|
|
|
1146
1258
|
obj = capturedvars.get(str(node.id), None)
|
|
1147
1259
|
|
|
1148
1260
|
if obj is None:
|
|
1149
|
-
raise
|
|
1261
|
+
raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
|
|
1150
1262
|
|
|
1151
1263
|
if warp.types.is_value(obj):
|
|
1152
1264
|
# evaluate constant
|
|
@@ -1158,26 +1270,96 @@ class Adjoint:
|
|
|
1158
1270
|
# pass it back to the caller for processing
|
|
1159
1271
|
return obj
|
|
1160
1272
|
|
|
1273
|
+
@staticmethod
|
|
1274
|
+
def resolve_type_attribute(var_type: type, attr: str):
|
|
1275
|
+
if isinstance(var_type, type) and type_is_value(var_type):
|
|
1276
|
+
if attr == "dtype":
|
|
1277
|
+
return type_scalar_type(var_type)
|
|
1278
|
+
elif attr == "length":
|
|
1279
|
+
return type_length(var_type)
|
|
1280
|
+
|
|
1281
|
+
return getattr(var_type, attr, None)
|
|
1282
|
+
|
|
1283
|
+
def vector_component_index(adj, component, vector_type):
|
|
1284
|
+
if len(component) != 1:
|
|
1285
|
+
raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
|
|
1286
|
+
|
|
1287
|
+
dim = vector_type._shape_[0]
|
|
1288
|
+
swizzles = "xyzw"[0:dim]
|
|
1289
|
+
if component not in swizzles:
|
|
1290
|
+
raise WarpCodegenAttributeError(
|
|
1291
|
+
f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
|
|
1292
|
+
)
|
|
1293
|
+
|
|
1294
|
+
index = swizzles.index(component)
|
|
1295
|
+
index = adj.add_constant(index)
|
|
1296
|
+
return index
|
|
1297
|
+
|
|
1298
|
+
@staticmethod
|
|
1299
|
+
def is_differentiable_value_type(var_type):
|
|
1300
|
+
# checks that the argument type is a value type (i.e, not an array)
|
|
1301
|
+
# possibly holding differentiable values (for which gradients must be accumulated)
|
|
1302
|
+
return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
|
|
1303
|
+
|
|
1161
1304
|
def emit_Attribute(adj, node):
|
|
1162
|
-
|
|
1163
|
-
|
|
1305
|
+
if hasattr(node, "is_adjoint"):
|
|
1306
|
+
node.value.is_adjoint = True
|
|
1307
|
+
|
|
1308
|
+
aggregate = adj.eval(node.value)
|
|
1164
1309
|
|
|
1165
|
-
|
|
1166
|
-
|
|
1310
|
+
try:
|
|
1311
|
+
if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
|
|
1312
|
+
out = getattr(aggregate, node.attr)
|
|
1167
1313
|
|
|
1168
1314
|
if warp.types.is_value(out):
|
|
1169
1315
|
return adj.add_constant(out)
|
|
1170
1316
|
|
|
1171
1317
|
return out
|
|
1172
1318
|
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1319
|
+
if hasattr(node, "is_adjoint"):
|
|
1320
|
+
# create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
|
|
1321
|
+
attr_name = aggregate.label + "." + node.attr
|
|
1322
|
+
attr_type = aggregate.type.vars[node.attr].type
|
|
1323
|
+
|
|
1324
|
+
return Var(attr_name, attr_type)
|
|
1325
|
+
|
|
1326
|
+
aggregate_type = strip_reference(aggregate.type)
|
|
1176
1327
|
|
|
1177
|
-
|
|
1328
|
+
# reading a vector component
|
|
1329
|
+
if type_is_vector(aggregate_type):
|
|
1330
|
+
index = adj.vector_component_index(node.attr, aggregate_type)
|
|
1178
1331
|
|
|
1179
|
-
|
|
1180
|
-
|
|
1332
|
+
return adj.add_builtin_call("extract", [aggregate, index])
|
|
1333
|
+
|
|
1334
|
+
else:
|
|
1335
|
+
attr_type = Reference(aggregate_type.vars[node.attr].type)
|
|
1336
|
+
attr = adj.add_var(attr_type)
|
|
1337
|
+
|
|
1338
|
+
if is_reference(aggregate.type):
|
|
1339
|
+
adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
|
|
1340
|
+
else:
|
|
1341
|
+
adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
|
|
1342
|
+
|
|
1343
|
+
if adj.is_differentiable_value_type(strip_reference(attr_type)):
|
|
1344
|
+
adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
|
|
1345
|
+
else:
|
|
1346
|
+
adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
|
|
1347
|
+
|
|
1348
|
+
return attr
|
|
1349
|
+
|
|
1350
|
+
except (KeyError, AttributeError):
|
|
1351
|
+
# Try resolving as type attribute
|
|
1352
|
+
aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
|
|
1353
|
+
|
|
1354
|
+
type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
|
|
1355
|
+
if type_attribute is not None:
|
|
1356
|
+
return type_attribute
|
|
1357
|
+
|
|
1358
|
+
if isinstance(aggregate, Var):
|
|
1359
|
+
raise WarpCodegenAttributeError(
|
|
1360
|
+
f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
|
|
1361
|
+
)
|
|
1362
|
+
raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
|
|
1181
1363
|
|
|
1182
1364
|
def emit_String(adj, node):
|
|
1183
1365
|
# string constant
|
|
@@ -1194,19 +1376,25 @@ class Adjoint:
|
|
|
1194
1376
|
adj.symbols[key] = out
|
|
1195
1377
|
return out
|
|
1196
1378
|
|
|
1379
|
+
def emit_Ellipsis(adj, node):
|
|
1380
|
+
# stubbed @wp.native_func
|
|
1381
|
+
return
|
|
1382
|
+
|
|
1197
1383
|
def emit_NameConstant(adj, node):
|
|
1198
|
-
if node.value
|
|
1384
|
+
if node.value:
|
|
1199
1385
|
return adj.add_constant(True)
|
|
1200
|
-
elif node.value is False:
|
|
1201
|
-
return adj.add_constant(False)
|
|
1202
1386
|
elif node.value is None:
|
|
1203
|
-
raise
|
|
1387
|
+
raise WarpCodegenTypeError("None type unsupported")
|
|
1388
|
+
else:
|
|
1389
|
+
return adj.add_constant(False)
|
|
1204
1390
|
|
|
1205
1391
|
def emit_Constant(adj, node):
|
|
1206
1392
|
if isinstance(node, ast.Str):
|
|
1207
1393
|
return adj.emit_String(node)
|
|
1208
1394
|
elif isinstance(node, ast.Num):
|
|
1209
1395
|
return adj.emit_Num(node)
|
|
1396
|
+
elif isinstance(node, ast.Ellipsis):
|
|
1397
|
+
return adj.emit_Ellipsis(node)
|
|
1210
1398
|
else:
|
|
1211
1399
|
assert isinstance(node, ast.NameConstant)
|
|
1212
1400
|
return adj.emit_NameConstant(node)
|
|
@@ -1217,18 +1405,16 @@ class Adjoint:
|
|
|
1217
1405
|
right = adj.eval(node.right)
|
|
1218
1406
|
|
|
1219
1407
|
name = builtin_operators[type(node.op)]
|
|
1220
|
-
func = warp.context.builtin_functions[name]
|
|
1221
1408
|
|
|
1222
|
-
return adj.
|
|
1409
|
+
return adj.add_builtin_call(name, [left, right])
|
|
1223
1410
|
|
|
1224
1411
|
def emit_UnaryOp(adj, node):
|
|
1225
1412
|
# evaluate unary op arguments
|
|
1226
1413
|
arg = adj.eval(node.operand)
|
|
1227
1414
|
|
|
1228
1415
|
name = builtin_operators[type(node.op)]
|
|
1229
|
-
func = warp.context.builtin_functions[name]
|
|
1230
1416
|
|
|
1231
|
-
return adj.
|
|
1417
|
+
return adj.add_builtin_call(name, [arg])
|
|
1232
1418
|
|
|
1233
1419
|
def materialize_redefinitions(adj, symbols):
|
|
1234
1420
|
# detect symbols with conflicting definitions (assigned inside the for loop)
|
|
@@ -1240,19 +1426,17 @@ class Adjoint:
|
|
|
1240
1426
|
if var1 != var2:
|
|
1241
1427
|
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
1242
1428
|
lineno = adj.lineno + adj.fun_lineno
|
|
1243
|
-
line = adj.
|
|
1244
|
-
msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this
|
|
1429
|
+
line = adj.source_lines[adj.lineno]
|
|
1430
|
+
msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
|
|
1245
1431
|
print(msg)
|
|
1246
1432
|
|
|
1247
1433
|
if var1.constant is not None:
|
|
1248
|
-
raise
|
|
1249
|
-
"Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
|
|
1250
|
-
sym
|
|
1251
|
-
)
|
|
1434
|
+
raise WarpCodegenError(
|
|
1435
|
+
f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
|
|
1252
1436
|
)
|
|
1253
1437
|
|
|
1254
1438
|
# overwrite the old variable value (violates SSA)
|
|
1255
|
-
adj.
|
|
1439
|
+
adj.add_builtin_call("assign", [var1, var2])
|
|
1256
1440
|
|
|
1257
1441
|
# reset the symbol to point to the original variable
|
|
1258
1442
|
adj.symbols[sym] = var1
|
|
@@ -1271,35 +1455,20 @@ class Adjoint:
|
|
|
1271
1455
|
|
|
1272
1456
|
adj.end_while()
|
|
1273
1457
|
|
|
1274
|
-
def is_num(adj, a):
|
|
1275
|
-
# simple constant
|
|
1276
|
-
if isinstance(a, ast.Num):
|
|
1277
|
-
return True
|
|
1278
|
-
# expression of form -constant
|
|
1279
|
-
elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
|
|
1280
|
-
return True
|
|
1281
|
-
else:
|
|
1282
|
-
# try and resolve the expression to an object
|
|
1283
|
-
# e.g.: wp.constant in the globals scope
|
|
1284
|
-
obj, path = adj.resolve_path(a)
|
|
1285
|
-
if warp.types.is_int(obj):
|
|
1286
|
-
return True
|
|
1287
|
-
else:
|
|
1288
|
-
return False
|
|
1289
|
-
|
|
1290
1458
|
def eval_num(adj, a):
|
|
1291
1459
|
if isinstance(a, ast.Num):
|
|
1292
|
-
return a.n
|
|
1293
|
-
|
|
1294
|
-
return -a.operand.n
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1460
|
+
return True, a.n
|
|
1461
|
+
if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
|
|
1462
|
+
return True, -a.operand.n
|
|
1463
|
+
|
|
1464
|
+
# try and resolve the expression to an object
|
|
1465
|
+
# e.g.: wp.constant in the globals scope
|
|
1466
|
+
obj, _ = adj.resolve_static_expression(a)
|
|
1467
|
+
|
|
1468
|
+
if isinstance(obj, Var) and obj.constant is not None:
|
|
1469
|
+
obj = obj.constant
|
|
1470
|
+
|
|
1471
|
+
return warp.types.is_int(obj), obj
|
|
1303
1472
|
|
|
1304
1473
|
# detects whether a loop contains a break (or continue) statement
|
|
1305
1474
|
def contains_break(adj, body):
|
|
@@ -1322,61 +1491,82 @@ class Adjoint:
|
|
|
1322
1491
|
|
|
1323
1492
|
# returns a constant range() if unrollable, otherwise None
|
|
1324
1493
|
def get_unroll_range(adj, loop):
|
|
1325
|
-
if
|
|
1494
|
+
if (
|
|
1495
|
+
not isinstance(loop.iter, ast.Call)
|
|
1496
|
+
or not isinstance(loop.iter.func, ast.Name)
|
|
1497
|
+
or loop.iter.func.id != "range"
|
|
1498
|
+
or len(loop.iter.args) == 0
|
|
1499
|
+
or len(loop.iter.args) > 3
|
|
1500
|
+
):
|
|
1326
1501
|
return None
|
|
1327
1502
|
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1503
|
+
# if all range() arguments are numeric constants we will unroll
|
|
1504
|
+
# note that this only handles trivial constants, it will not unroll
|
|
1505
|
+
# constant compile-time expressions e.g.: range(0, 3*2)
|
|
1506
|
+
|
|
1507
|
+
# Evaluate the arguments and check that they are numeric constants
|
|
1508
|
+
# It is important to do that in one pass, so that if evaluating these arguments have side effects
|
|
1509
|
+
# the code does not get generated more than once
|
|
1510
|
+
range_args = [adj.eval_num(arg) for arg in loop.iter.args]
|
|
1511
|
+
arg_is_numeric, arg_values = zip(*range_args)
|
|
1512
|
+
|
|
1513
|
+
if all(arg_is_numeric):
|
|
1514
|
+
# All argument are numeric constants
|
|
1515
|
+
|
|
1516
|
+
# range(end)
|
|
1517
|
+
if len(loop.iter.args) == 1:
|
|
1518
|
+
start = 0
|
|
1519
|
+
end = arg_values[0]
|
|
1520
|
+
step = 1
|
|
1521
|
+
|
|
1522
|
+
# range(start, end)
|
|
1523
|
+
elif len(loop.iter.args) == 2:
|
|
1524
|
+
start = arg_values[0]
|
|
1525
|
+
end = arg_values[1]
|
|
1526
|
+
step = 1
|
|
1527
|
+
|
|
1528
|
+
# range(start, end, step)
|
|
1529
|
+
elif len(loop.iter.args) == 3:
|
|
1530
|
+
start = arg_values[0]
|
|
1531
|
+
end = arg_values[1]
|
|
1532
|
+
step = arg_values[2]
|
|
1533
|
+
|
|
1534
|
+
# test if we're above max unroll count
|
|
1535
|
+
max_iters = abs(end - start) // abs(step)
|
|
1536
|
+
max_unroll = adj.builder.options["max_unroll"]
|
|
1537
|
+
|
|
1538
|
+
ok_to_unroll = True
|
|
1539
|
+
|
|
1540
|
+
if max_iters > max_unroll:
|
|
1541
|
+
if warp.config.verbose:
|
|
1542
|
+
print(
|
|
1543
|
+
f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
|
|
1544
|
+
)
|
|
1545
|
+
ok_to_unroll = False
|
|
1363
1546
|
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1547
|
+
elif adj.contains_break(loop.body):
|
|
1548
|
+
if warp.config.verbose:
|
|
1549
|
+
print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
|
|
1550
|
+
ok_to_unroll = False
|
|
1368
1551
|
|
|
1369
|
-
|
|
1370
|
-
|
|
1552
|
+
if ok_to_unroll:
|
|
1553
|
+
return range(start, end, step)
|
|
1554
|
+
|
|
1555
|
+
# Unroll is not possible, range needs to be valuated dynamically
|
|
1556
|
+
range_call = adj.add_builtin_call(
|
|
1557
|
+
"range",
|
|
1558
|
+
[adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
|
|
1559
|
+
)
|
|
1560
|
+
return range_call
|
|
1371
1561
|
|
|
1372
1562
|
def emit_For(adj, node):
|
|
1373
1563
|
# try and unroll simple range() statements that use constant args
|
|
1374
1564
|
unroll_range = adj.get_unroll_range(node)
|
|
1375
1565
|
|
|
1376
|
-
if unroll_range:
|
|
1566
|
+
if isinstance(unroll_range, range):
|
|
1377
1567
|
for i in unroll_range:
|
|
1378
1568
|
const_iter = adj.add_constant(i)
|
|
1379
|
-
var_iter = adj.
|
|
1569
|
+
var_iter = adj.add_builtin_call("int", [const_iter])
|
|
1380
1570
|
adj.symbols[node.target.id] = var_iter
|
|
1381
1571
|
|
|
1382
1572
|
# eval body
|
|
@@ -1385,8 +1575,12 @@ class Adjoint:
|
|
|
1385
1575
|
|
|
1386
1576
|
# otherwise generate a dynamic loop
|
|
1387
1577
|
else:
|
|
1388
|
-
# evaluate the Iterable
|
|
1389
|
-
|
|
1578
|
+
# evaluate the Iterable -- only if not previously evaluated when trying to unroll
|
|
1579
|
+
if unroll_range is not None:
|
|
1580
|
+
# Range has already been evaluated when trying to unroll, do not re-evaluate
|
|
1581
|
+
iter = unroll_range
|
|
1582
|
+
else:
|
|
1583
|
+
iter = adj.eval(node.iter)
|
|
1390
1584
|
|
|
1391
1585
|
adj.symbols[node.target.id] = adj.begin_for(iter)
|
|
1392
1586
|
|
|
@@ -1415,15 +1609,28 @@ class Adjoint:
|
|
|
1415
1609
|
def emit_Expr(adj, node):
|
|
1416
1610
|
return adj.eval(node.value)
|
|
1417
1611
|
|
|
1612
|
+
def check_tid_in_func_error(adj, node):
|
|
1613
|
+
if adj.is_user_function:
|
|
1614
|
+
if hasattr(node.func, "attr") and node.func.attr == "tid":
|
|
1615
|
+
lineno = adj.lineno + adj.fun_lineno
|
|
1616
|
+
line = adj.source_lines[adj.lineno]
|
|
1617
|
+
raise WarpCodegenError(
|
|
1618
|
+
"tid() may only be called from a Warp kernel, not a Warp function. "
|
|
1619
|
+
"Instead, obtain the indices from a @wp.kernel and pass them as "
|
|
1620
|
+
f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
|
|
1621
|
+
)
|
|
1622
|
+
|
|
1418
1623
|
def emit_Call(adj, node):
|
|
1624
|
+
adj.check_tid_in_func_error(node)
|
|
1625
|
+
|
|
1419
1626
|
# try and lookup function in globals by
|
|
1420
1627
|
# resolving path (e.g.: module.submodule.attr)
|
|
1421
|
-
func, path = adj.
|
|
1628
|
+
func, path = adj.resolve_static_expression(node.func)
|
|
1422
1629
|
templates = []
|
|
1423
1630
|
|
|
1424
|
-
if isinstance(func, warp.context.Function)
|
|
1631
|
+
if not isinstance(func, warp.context.Function):
|
|
1425
1632
|
if len(path) == 0:
|
|
1426
|
-
raise
|
|
1633
|
+
raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")
|
|
1427
1634
|
|
|
1428
1635
|
attr = path[-1]
|
|
1429
1636
|
caller = func
|
|
@@ -1448,7 +1655,7 @@ class Adjoint:
|
|
|
1448
1655
|
func = caller.initializer()
|
|
1449
1656
|
|
|
1450
1657
|
if func is None:
|
|
1451
|
-
raise
|
|
1658
|
+
raise WarpCodegenError(
|
|
1452
1659
|
f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
|
|
1453
1660
|
)
|
|
1454
1661
|
|
|
@@ -1464,9 +1671,14 @@ class Adjoint:
|
|
|
1464
1671
|
if isinstance(kw.value, ast.Num):
|
|
1465
1672
|
return kw.value.n
|
|
1466
1673
|
elif isinstance(kw.value, ast.Tuple):
|
|
1467
|
-
|
|
1674
|
+
arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
|
|
1675
|
+
if not all(arg_is_numeric):
|
|
1676
|
+
raise WarpCodegenError(
|
|
1677
|
+
f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
|
|
1678
|
+
)
|
|
1679
|
+
return arg_values
|
|
1468
1680
|
else:
|
|
1469
|
-
return adj.
|
|
1681
|
+
return adj.resolve_static_expression(kw.value)[0]
|
|
1470
1682
|
|
|
1471
1683
|
kwds = {kw.arg: kwval(kw) for kw in node.keywords}
|
|
1472
1684
|
|
|
@@ -1483,15 +1695,19 @@ class Adjoint:
|
|
|
1483
1695
|
# the ast.Index node appears in 3.7 versions
|
|
1484
1696
|
# when performing array slices, e.g.: x = arr[i]
|
|
1485
1697
|
# but in version 3.8 and higher it does not appear
|
|
1698
|
+
|
|
1699
|
+
if hasattr(node, "is_adjoint"):
|
|
1700
|
+
node.value.is_adjoint = True
|
|
1701
|
+
|
|
1486
1702
|
return adj.eval(node.value)
|
|
1487
1703
|
|
|
1488
1704
|
def emit_Subscript(adj, node):
|
|
1489
1705
|
if hasattr(node.value, "attr") and node.value.attr == "adjoint":
|
|
1490
1706
|
# handle adjoint of a variable, i.e. wp.adjoint[var]
|
|
1707
|
+
node.slice.is_adjoint = True
|
|
1491
1708
|
var = adj.eval(node.slice)
|
|
1492
1709
|
var_name = var.label
|
|
1493
|
-
var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False
|
|
1494
|
-
adj.symbols[var.label] = var
|
|
1710
|
+
var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
|
|
1495
1711
|
return var
|
|
1496
1712
|
|
|
1497
1713
|
target = adj.eval(node.value)
|
|
@@ -1514,29 +1730,34 @@ class Adjoint:
|
|
|
1514
1730
|
var = adj.eval(node.slice)
|
|
1515
1731
|
indices.append(var)
|
|
1516
1732
|
|
|
1517
|
-
|
|
1518
|
-
|
|
1733
|
+
target_type = strip_reference(target.type)
|
|
1734
|
+
if is_array(target_type):
|
|
1735
|
+
if len(indices) == target_type.ndim:
|
|
1519
1736
|
# handles array loads (where each dimension has an index specified)
|
|
1520
|
-
out = adj.
|
|
1737
|
+
out = adj.add_builtin_call("address", [target, *indices])
|
|
1521
1738
|
else:
|
|
1522
1739
|
# handles array views (fewer indices than dimensions)
|
|
1523
|
-
out = adj.
|
|
1740
|
+
out = adj.add_builtin_call("view", [target, *indices])
|
|
1524
1741
|
|
|
1525
1742
|
else:
|
|
1526
1743
|
# handles non-array type indexing, e.g: vec3, mat33, etc
|
|
1527
|
-
out = adj.
|
|
1744
|
+
out = adj.add_builtin_call("extract", [target, *indices])
|
|
1528
1745
|
|
|
1529
|
-
out.is_adjoint = target.is_adjoint
|
|
1530
1746
|
return out
|
|
1531
1747
|
|
|
1532
1748
|
def emit_Assign(adj, node):
|
|
1749
|
+
if len(node.targets) != 1:
|
|
1750
|
+
raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
|
|
1751
|
+
|
|
1752
|
+
lhs = node.targets[0]
|
|
1753
|
+
|
|
1533
1754
|
# handle the case where we are assigning multiple output variables
|
|
1534
|
-
if isinstance(
|
|
1755
|
+
if isinstance(lhs, ast.Tuple):
|
|
1535
1756
|
# record the expected number of outputs on the node
|
|
1536
1757
|
# we do this so we can decide which function to
|
|
1537
1758
|
# call based on the number of expected outputs
|
|
1538
1759
|
if isinstance(node.value, ast.Call):
|
|
1539
|
-
node.value.expects = len(
|
|
1760
|
+
node.value.expects = len(lhs.elts)
|
|
1540
1761
|
|
|
1541
1762
|
# evaluate values
|
|
1542
1763
|
if isinstance(node.value, ast.Tuple):
|
|
@@ -1545,49 +1766,43 @@ class Adjoint:
|
|
|
1545
1766
|
out = adj.eval(node.value)
|
|
1546
1767
|
|
|
1547
1768
|
names = []
|
|
1548
|
-
for v in
|
|
1769
|
+
for v in lhs.elts:
|
|
1549
1770
|
if isinstance(v, ast.Name):
|
|
1550
1771
|
names.append(v.id)
|
|
1551
1772
|
else:
|
|
1552
|
-
raise
|
|
1773
|
+
raise WarpCodegenError(
|
|
1553
1774
|
"Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
|
|
1554
1775
|
)
|
|
1555
1776
|
|
|
1556
1777
|
if len(names) != len(out):
|
|
1557
|
-
raise
|
|
1558
|
-
"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})"
|
|
1559
|
-
len(out), len(names)
|
|
1560
|
-
)
|
|
1778
|
+
raise WarpCodegenError(
|
|
1779
|
+
f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
|
|
1561
1780
|
)
|
|
1562
1781
|
|
|
1563
1782
|
for name, rhs in zip(names, out):
|
|
1564
1783
|
if name in adj.symbols:
|
|
1565
1784
|
if not types_equal(rhs.type, adj.symbols[name].type):
|
|
1566
|
-
raise
|
|
1567
|
-
"Error, assigning to existing symbol {} ({}) with different type ({})"
|
|
1568
|
-
name, adj.symbols[name].type, rhs.type
|
|
1569
|
-
)
|
|
1785
|
+
raise WarpCodegenTypeError(
|
|
1786
|
+
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
1570
1787
|
)
|
|
1571
1788
|
|
|
1572
1789
|
adj.symbols[name] = rhs
|
|
1573
1790
|
|
|
1574
|
-
return out
|
|
1575
|
-
|
|
1576
1791
|
# handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
|
|
1577
|
-
elif isinstance(
|
|
1578
|
-
if hasattr(
|
|
1792
|
+
elif isinstance(lhs, ast.Subscript):
|
|
1793
|
+
if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
|
|
1579
1794
|
# handle adjoint of a variable, i.e. wp.adjoint[var]
|
|
1580
|
-
|
|
1795
|
+
lhs.slice.is_adjoint = True
|
|
1796
|
+
src_var = adj.eval(lhs.slice)
|
|
1581
1797
|
var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
|
|
1582
|
-
adj.symbols[var.label] = var
|
|
1583
1798
|
value = adj.eval(node.value)
|
|
1584
1799
|
adj.add_forward(f"{var.emit()} = {value.emit()};")
|
|
1585
|
-
return
|
|
1800
|
+
return
|
|
1586
1801
|
|
|
1587
|
-
target = adj.eval(
|
|
1802
|
+
target = adj.eval(lhs.value)
|
|
1588
1803
|
value = adj.eval(node.value)
|
|
1589
1804
|
|
|
1590
|
-
slice =
|
|
1805
|
+
slice = lhs.slice
|
|
1591
1806
|
indices = []
|
|
1592
1807
|
|
|
1593
1808
|
if isinstance(slice, ast.Tuple):
|
|
@@ -1595,7 +1810,6 @@ class Adjoint:
|
|
|
1595
1810
|
for arg in slice.elts:
|
|
1596
1811
|
var = adj.eval(arg)
|
|
1597
1812
|
indices.append(var)
|
|
1598
|
-
|
|
1599
1813
|
elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
|
|
1600
1814
|
# handles the x[i, j] case (Python 3.7.x)
|
|
1601
1815
|
for arg in slice.value.elts:
|
|
@@ -1606,65 +1820,84 @@ class Adjoint:
|
|
|
1606
1820
|
var = adj.eval(slice)
|
|
1607
1821
|
indices.append(var)
|
|
1608
1822
|
|
|
1609
|
-
|
|
1610
|
-
adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
|
|
1823
|
+
target_type = strip_reference(target.type)
|
|
1611
1824
|
|
|
1612
|
-
|
|
1613
|
-
adj.
|
|
1825
|
+
if is_array(target_type):
|
|
1826
|
+
adj.add_builtin_call("array_store", [target, *indices, value])
|
|
1827
|
+
|
|
1828
|
+
elif type_is_vector(target_type) or type_is_matrix(target_type):
|
|
1829
|
+
if is_reference(target.type):
|
|
1830
|
+
attr = adj.add_builtin_call("indexref", [target, *indices])
|
|
1831
|
+
else:
|
|
1832
|
+
attr = adj.add_builtin_call("index", [target, *indices])
|
|
1833
|
+
|
|
1834
|
+
adj.add_builtin_call("store", [attr, value])
|
|
1614
1835
|
|
|
1615
1836
|
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
1616
1837
|
lineno = adj.lineno + adj.fun_lineno
|
|
1617
|
-
line = adj.
|
|
1618
|
-
node_source = adj.get_node_source(
|
|
1838
|
+
line = adj.source_lines[adj.lineno]
|
|
1839
|
+
node_source = adj.get_node_source(lhs.value)
|
|
1619
1840
|
print(
|
|
1620
1841
|
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
1621
1842
|
)
|
|
1622
1843
|
|
|
1623
1844
|
else:
|
|
1624
|
-
raise
|
|
1845
|
+
raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
|
|
1625
1846
|
|
|
1626
|
-
|
|
1627
|
-
|
|
1628
|
-
elif isinstance(node.targets[0], ast.Name):
|
|
1847
|
+
elif isinstance(lhs, ast.Name):
|
|
1629
1848
|
# symbol name
|
|
1630
|
-
name =
|
|
1849
|
+
name = lhs.id
|
|
1631
1850
|
|
|
1632
1851
|
# evaluate rhs
|
|
1633
1852
|
rhs = adj.eval(node.value)
|
|
1634
1853
|
|
|
1635
1854
|
# check type matches if symbol already defined
|
|
1636
1855
|
if name in adj.symbols:
|
|
1637
|
-
if not types_equal(rhs.type, adj.symbols[name].type):
|
|
1638
|
-
raise
|
|
1639
|
-
"Error, assigning to existing symbol {} ({}) with different type ({})"
|
|
1640
|
-
name, adj.symbols[name].type, rhs.type
|
|
1641
|
-
)
|
|
1856
|
+
if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
|
|
1857
|
+
raise WarpCodegenTypeError(
|
|
1858
|
+
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
1642
1859
|
)
|
|
1643
1860
|
|
|
1644
1861
|
# handle simple assignment case (a = b), where we generate a value copy rather than reference
|
|
1645
|
-
if isinstance(node.value, ast.Name):
|
|
1646
|
-
out = adj.
|
|
1647
|
-
adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
|
|
1862
|
+
if isinstance(node.value, ast.Name) or is_reference(rhs.type):
|
|
1863
|
+
out = adj.add_builtin_call("copy", [rhs])
|
|
1648
1864
|
else:
|
|
1649
1865
|
out = rhs
|
|
1650
1866
|
|
|
1651
1867
|
# update symbol map (assumes lhs is a Name node)
|
|
1652
1868
|
adj.symbols[name] = out
|
|
1653
|
-
return out
|
|
1654
1869
|
|
|
1655
|
-
elif isinstance(
|
|
1870
|
+
elif isinstance(lhs, ast.Attribute):
|
|
1656
1871
|
rhs = adj.eval(node.value)
|
|
1657
|
-
|
|
1658
|
-
|
|
1872
|
+
aggregate = adj.eval(lhs.value)
|
|
1873
|
+
aggregate_type = strip_reference(aggregate.type)
|
|
1659
1874
|
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
|
|
1663
|
-
|
|
1664
|
-
|
|
1875
|
+
# assigning to a vector component
|
|
1876
|
+
if type_is_vector(aggregate_type):
|
|
1877
|
+
index = adj.vector_component_index(lhs.attr, aggregate_type)
|
|
1878
|
+
|
|
1879
|
+
if is_reference(aggregate.type):
|
|
1880
|
+
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
1881
|
+
else:
|
|
1882
|
+
attr = adj.add_builtin_call("index", [aggregate, index])
|
|
1883
|
+
|
|
1884
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
1885
|
+
|
|
1886
|
+
else:
|
|
1887
|
+
attr = adj.emit_Attribute(lhs)
|
|
1888
|
+
if is_reference(attr.type):
|
|
1889
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
1890
|
+
else:
|
|
1891
|
+
adj.add_builtin_call("assign", [attr, rhs])
|
|
1892
|
+
|
|
1893
|
+
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
1894
|
+
lineno = adj.lineno + adj.fun_lineno
|
|
1895
|
+
line = adj.source_lines[adj.lineno]
|
|
1896
|
+
msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
|
|
1897
|
+
print(msg)
|
|
1665
1898
|
|
|
1666
1899
|
else:
|
|
1667
|
-
raise
|
|
1900
|
+
raise WarpCodegenError("Error, unsupported assignment statement.")
|
|
1668
1901
|
|
|
1669
1902
|
def emit_Return(adj, node):
|
|
1670
1903
|
if node.value is None:
|
|
@@ -1675,37 +1908,26 @@ class Adjoint:
|
|
|
1675
1908
|
var = (adj.eval(node.value),)
|
|
1676
1909
|
|
|
1677
1910
|
if adj.return_var is not None:
|
|
1678
|
-
old_ctypes = tuple(v.ctype() for v in adj.return_var)
|
|
1679
|
-
new_ctypes = tuple(v.ctype() for v in var)
|
|
1911
|
+
old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
|
|
1912
|
+
new_ctypes = tuple(v.ctype(value_type=True) for v in var)
|
|
1680
1913
|
if old_ctypes != new_ctypes:
|
|
1681
|
-
raise
|
|
1914
|
+
raise WarpCodegenTypeError(
|
|
1682
1915
|
f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
|
|
1683
1916
|
)
|
|
1684
|
-
else:
|
|
1685
|
-
adj.return_var = var
|
|
1686
1917
|
|
|
1687
|
-
|
|
1918
|
+
if var is not None:
|
|
1919
|
+
adj.return_var = tuple()
|
|
1920
|
+
for ret in var:
|
|
1921
|
+
if is_reference(ret.type):
|
|
1922
|
+
ret = adj.add_builtin_call("copy", [ret])
|
|
1923
|
+
adj.return_var += (ret,)
|
|
1688
1924
|
|
|
1689
|
-
|
|
1690
|
-
# convert inplace operations (+=, -=, etc) to ssa form, e.g.: c = a + b
|
|
1691
|
-
left = adj.eval(node.target)
|
|
1925
|
+
adj.add_return(adj.return_var)
|
|
1692
1926
|
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
|
|
1696
|
-
|
|
1697
|
-
return
|
|
1698
|
-
|
|
1699
|
-
right = adj.eval(node.value)
|
|
1700
|
-
|
|
1701
|
-
# lookup
|
|
1702
|
-
name = builtin_operators[type(node.op)]
|
|
1703
|
-
func = warp.context.builtin_functions[name]
|
|
1704
|
-
|
|
1705
|
-
out = adj.add_call(func, [left, right])
|
|
1706
|
-
|
|
1707
|
-
# update symbol map
|
|
1708
|
-
adj.symbols[node.target.id] = out
|
|
1927
|
+
def emit_AugAssign(adj, node):
|
|
1928
|
+
# replace augmented assignment with assignment statement + binary op
|
|
1929
|
+
new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
|
|
1930
|
+
adj.eval(new_node)
|
|
1709
1931
|
|
|
1710
1932
|
def emit_Tuple(adj, node):
|
|
1711
1933
|
# LHS for expressions, such as i, j, k = 1, 2, 3
|
|
@@ -1715,131 +1937,159 @@ class Adjoint:
|
|
|
1715
1937
|
def emit_Pass(adj, node):
|
|
1716
1938
|
pass
|
|
1717
1939
|
|
|
1940
|
+
node_visitors = {
|
|
1941
|
+
ast.FunctionDef: emit_FunctionDef,
|
|
1942
|
+
ast.If: emit_If,
|
|
1943
|
+
ast.Compare: emit_Compare,
|
|
1944
|
+
ast.BoolOp: emit_BoolOp,
|
|
1945
|
+
ast.Name: emit_Name,
|
|
1946
|
+
ast.Attribute: emit_Attribute,
|
|
1947
|
+
ast.Str: emit_String, # Deprecated in 3.8; use Constant
|
|
1948
|
+
ast.Num: emit_Num, # Deprecated in 3.8; use Constant
|
|
1949
|
+
ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
|
|
1950
|
+
ast.Constant: emit_Constant,
|
|
1951
|
+
ast.BinOp: emit_BinOp,
|
|
1952
|
+
ast.UnaryOp: emit_UnaryOp,
|
|
1953
|
+
ast.While: emit_While,
|
|
1954
|
+
ast.For: emit_For,
|
|
1955
|
+
ast.Break: emit_Break,
|
|
1956
|
+
ast.Continue: emit_Continue,
|
|
1957
|
+
ast.Expr: emit_Expr,
|
|
1958
|
+
ast.Call: emit_Call,
|
|
1959
|
+
ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
|
|
1960
|
+
ast.Subscript: emit_Subscript,
|
|
1961
|
+
ast.Assign: emit_Assign,
|
|
1962
|
+
ast.Return: emit_Return,
|
|
1963
|
+
ast.AugAssign: emit_AugAssign,
|
|
1964
|
+
ast.Tuple: emit_Tuple,
|
|
1965
|
+
ast.Pass: emit_Pass,
|
|
1966
|
+
ast.Ellipsis: emit_Ellipsis,
|
|
1967
|
+
}
|
|
1968
|
+
|
|
1718
1969
|
def eval(adj, node):
|
|
1719
1970
|
if hasattr(node, "lineno"):
|
|
1720
1971
|
adj.set_lineno(node.lineno - 1)
|
|
1721
1972
|
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
ast.Compare: Adjoint.emit_Compare,
|
|
1726
|
-
ast.BoolOp: Adjoint.emit_BoolOp,
|
|
1727
|
-
ast.Name: Adjoint.emit_Name,
|
|
1728
|
-
ast.Attribute: Adjoint.emit_Attribute,
|
|
1729
|
-
ast.Str: Adjoint.emit_String, # Deprecated in 3.8; use Constant
|
|
1730
|
-
ast.Num: Adjoint.emit_Num, # Deprecated in 3.8; use Constant
|
|
1731
|
-
ast.NameConstant: Adjoint.emit_NameConstant, # Deprecated in 3.8; use Constant
|
|
1732
|
-
ast.Constant: Adjoint.emit_Constant,
|
|
1733
|
-
ast.BinOp: Adjoint.emit_BinOp,
|
|
1734
|
-
ast.UnaryOp: Adjoint.emit_UnaryOp,
|
|
1735
|
-
ast.While: Adjoint.emit_While,
|
|
1736
|
-
ast.For: Adjoint.emit_For,
|
|
1737
|
-
ast.Break: Adjoint.emit_Break,
|
|
1738
|
-
ast.Continue: Adjoint.emit_Continue,
|
|
1739
|
-
ast.Expr: Adjoint.emit_Expr,
|
|
1740
|
-
ast.Call: Adjoint.emit_Call,
|
|
1741
|
-
ast.Index: Adjoint.emit_Index, # Deprecated in 3.8; Use the index value directly instead.
|
|
1742
|
-
ast.Subscript: Adjoint.emit_Subscript,
|
|
1743
|
-
ast.Assign: Adjoint.emit_Assign,
|
|
1744
|
-
ast.Return: Adjoint.emit_Return,
|
|
1745
|
-
ast.AugAssign: Adjoint.emit_AugAssign,
|
|
1746
|
-
ast.Tuple: Adjoint.emit_Tuple,
|
|
1747
|
-
ast.Pass: Adjoint.emit_Pass,
|
|
1748
|
-
}
|
|
1749
|
-
|
|
1750
|
-
emit_node = node_visitors.get(type(node))
|
|
1751
|
-
|
|
1752
|
-
if emit_node is not None:
|
|
1753
|
-
if adj.is_user_function:
|
|
1754
|
-
if hasattr(node, "value") and hasattr(node.value, "func") and hasattr(node.value.func, "attr"):
|
|
1755
|
-
if node.value.func.attr == "tid":
|
|
1756
|
-
lineno = adj.lineno + adj.fun_lineno
|
|
1757
|
-
line = adj.source.splitlines()[adj.lineno]
|
|
1758
|
-
|
|
1759
|
-
warp.utils.warn(
|
|
1760
|
-
"Calling wp.tid() from a @wp.func is deprecated and will be removed in a future Warp "
|
|
1761
|
-
"version. Instead, obtain the indices from a @wp.kernel and pass them as "
|
|
1762
|
-
f"arguments to this function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n",
|
|
1763
|
-
PendingDeprecationWarning,
|
|
1764
|
-
stacklevel=2,
|
|
1765
|
-
)
|
|
1766
|
-
return emit_node(adj, node)
|
|
1767
|
-
else:
|
|
1768
|
-
raise Exception("Error, ast node of type {} not supported".format(type(node)))
|
|
1973
|
+
emit_node = adj.node_visitors[type(node)]
|
|
1974
|
+
|
|
1975
|
+
return emit_node(adj, node)
|
|
1769
1976
|
|
|
1770
1977
|
# helper to evaluate expressions of the form
|
|
1771
1978
|
# obj1.obj2.obj3.attr in the function's global scope
|
|
1772
|
-
def resolve_path(adj,
|
|
1773
|
-
|
|
1979
|
+
def resolve_path(adj, path):
|
|
1980
|
+
if len(path) == 0:
|
|
1981
|
+
return None
|
|
1774
1982
|
|
|
1775
|
-
|
|
1776
|
-
|
|
1777
|
-
|
|
1983
|
+
# if root is overshadowed by local symbols, bail out
|
|
1984
|
+
if path[0] in adj.symbols:
|
|
1985
|
+
return None
|
|
1778
1986
|
|
|
1779
|
-
if
|
|
1780
|
-
|
|
1987
|
+
if path[0] in __builtins__:
|
|
1988
|
+
return __builtins__[path[0]]
|
|
1781
1989
|
|
|
1782
|
-
#
|
|
1783
|
-
|
|
1990
|
+
# Look up the closure info and append it to adj.func.__globals__
|
|
1991
|
+
# in case you want to define a kernel inside a function and refer
|
|
1992
|
+
# to variables you've declared inside that function:
|
|
1993
|
+
extract_contents = (
|
|
1994
|
+
lambda contents: contents
|
|
1995
|
+
if isinstance(contents, warp.context.Function) or not callable(contents)
|
|
1996
|
+
else contents
|
|
1997
|
+
)
|
|
1998
|
+
capturedvars = dict(
|
|
1999
|
+
zip(
|
|
2000
|
+
adj.func.__code__.co_freevars,
|
|
2001
|
+
[extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
|
|
2002
|
+
)
|
|
2003
|
+
)
|
|
2004
|
+
vars_dict = {**adj.func.__globals__, **capturedvars}
|
|
1784
2005
|
|
|
1785
|
-
if
|
|
1786
|
-
|
|
2006
|
+
if path[0] in vars_dict:
|
|
2007
|
+
func = vars_dict[path[0]]
|
|
1787
2008
|
|
|
1788
|
-
#
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
# in case you want to define a kernel inside a function and refer
|
|
1792
|
-
# to variables you've declared inside that function:
|
|
1793
|
-
extract_contents = (
|
|
1794
|
-
lambda contents: contents
|
|
1795
|
-
if isinstance(contents, warp.context.Function) or not callable(contents)
|
|
1796
|
-
else contents
|
|
1797
|
-
)
|
|
1798
|
-
capturedvars = dict(
|
|
1799
|
-
zip(
|
|
1800
|
-
adj.func.__code__.co_freevars,
|
|
1801
|
-
[extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
|
|
1802
|
-
)
|
|
1803
|
-
)
|
|
2009
|
+
# Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
|
|
2010
|
+
else:
|
|
2011
|
+
func = getattr(warp, path[0], None)
|
|
1804
2012
|
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
pass
|
|
2013
|
+
if func:
|
|
2014
|
+
for i in range(1, len(path)):
|
|
2015
|
+
if hasattr(func, path[i]):
|
|
2016
|
+
func = getattr(func, path[i])
|
|
1810
2017
|
|
|
1811
|
-
|
|
1812
|
-
# in a kernel:
|
|
2018
|
+
return func
|
|
1813
2019
|
|
|
1814
|
-
|
|
2020
|
+
# Evaluates a static expression that does not depend on runtime values
|
|
2021
|
+
# if eval_types is True, try resolving the path using evaluated type information as well
|
|
2022
|
+
def resolve_static_expression(adj, root_node, eval_types=True):
|
|
2023
|
+
attributes = []
|
|
1815
2024
|
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
2025
|
+
node = root_node
|
|
2026
|
+
while isinstance(node, ast.Attribute):
|
|
2027
|
+
attributes.append(node.attr)
|
|
2028
|
+
node = node.value
|
|
1820
2029
|
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
|
|
2030
|
+
if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
|
|
2031
|
+
# support for operators returning modules
|
|
2032
|
+
# i.e. operator_name(*operator_args).x.y.z
|
|
2033
|
+
operator_args = node.args
|
|
2034
|
+
operator_name = node.func.id
|
|
2035
|
+
|
|
2036
|
+
if operator_name == "type":
|
|
2037
|
+
if len(operator_args) != 1:
|
|
2038
|
+
raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
|
|
2039
|
+
|
|
2040
|
+
# type() operator
|
|
2041
|
+
var = adj.eval(operator_args[0])
|
|
2042
|
+
|
|
2043
|
+
if isinstance(var, Var):
|
|
2044
|
+
var_type = strip_reference(var.type)
|
|
2045
|
+
# Allow accessing type attributes, for instance array.dtype
|
|
2046
|
+
while attributes:
|
|
2047
|
+
attr_name = attributes.pop()
|
|
2048
|
+
var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
|
|
2049
|
+
|
|
2050
|
+
if var_type is None:
|
|
2051
|
+
raise WarpCodegenAttributeError(
|
|
2052
|
+
f"{attr_name} is not an attribute of {type_repr(prev_type)}"
|
|
2053
|
+
)
|
|
2054
|
+
|
|
2055
|
+
return var_type, [type_repr(var_type)]
|
|
2056
|
+
else:
|
|
2057
|
+
raise WarpCodegenError(f"Cannot deduce the type of {var}")
|
|
2058
|
+
|
|
2059
|
+
# reverse list since ast presents it backward order
|
|
2060
|
+
path = [*reversed(attributes)]
|
|
2061
|
+
if isinstance(node, ast.Name):
|
|
2062
|
+
path.insert(0, node.id)
|
|
2063
|
+
|
|
2064
|
+
# Try resolving path from captured context
|
|
2065
|
+
captured_obj = adj.resolve_path(path)
|
|
2066
|
+
if captured_obj is not None:
|
|
2067
|
+
return captured_obj, path
|
|
2068
|
+
|
|
2069
|
+
# Still nothing found, maybe this is a predefined type attribute like `dtype`
|
|
2070
|
+
if eval_types:
|
|
2071
|
+
try:
|
|
2072
|
+
val = adj.eval(root_node)
|
|
2073
|
+
if val:
|
|
2074
|
+
return [val, type_repr(val)]
|
|
2075
|
+
|
|
2076
|
+
except Exception:
|
|
2077
|
+
pass
|
|
2078
|
+
|
|
2079
|
+
return None, path
|
|
1830
2080
|
|
|
1831
2081
|
# annotate generated code with the original source code line
|
|
1832
2082
|
def set_lineno(adj, lineno):
|
|
1833
2083
|
if adj.lineno is None or adj.lineno != lineno:
|
|
1834
2084
|
line = lineno + adj.fun_lineno
|
|
1835
|
-
source = adj.
|
|
2085
|
+
source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
|
|
1836
2086
|
adj.add_forward(f"// {source} <L {line}>")
|
|
1837
2087
|
adj.add_reverse(f"// adj: {source} <L {line}>")
|
|
1838
2088
|
adj.lineno = lineno
|
|
1839
2089
|
|
|
1840
2090
|
def get_node_source(adj, node):
|
|
1841
2091
|
# return the Python code corresponding to the given AST node
|
|
1842
|
-
return ast.get_source_segment(
|
|
2092
|
+
return ast.get_source_segment(adj.source, node)
|
|
1843
2093
|
|
|
1844
2094
|
|
|
1845
2095
|
# ----------------
|
|
@@ -1856,7 +2106,10 @@ cpu_module_header = """
|
|
|
1856
2106
|
#define int(x) cast_int(x)
|
|
1857
2107
|
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
|
|
1858
2108
|
|
|
1859
|
-
|
|
2109
|
+
#define builtin_tid1d() wp::tid(wp::s_threadIdx)
|
|
2110
|
+
#define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
|
|
2111
|
+
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
|
|
2112
|
+
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
|
|
1860
2113
|
|
|
1861
2114
|
"""
|
|
1862
2115
|
|
|
@@ -1871,8 +2124,10 @@ cuda_module_header = """
|
|
|
1871
2124
|
#define int(x) cast_int(x)
|
|
1872
2125
|
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
|
|
1873
2126
|
|
|
1874
|
-
|
|
1875
|
-
|
|
2127
|
+
#define builtin_tid1d() wp::tid(_idx)
|
|
2128
|
+
#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
|
|
2129
|
+
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
|
|
2130
|
+
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
|
|
1876
2131
|
|
|
1877
2132
|
"""
|
|
1878
2133
|
|
|
@@ -1886,7 +2141,9 @@ struct {name}
|
|
|
1886
2141
|
{{
|
|
1887
2142
|
}}
|
|
1888
2143
|
|
|
1889
|
-
CUDA_CALLABLE {name}& operator += (const {name}&)
|
|
2144
|
+
CUDA_CALLABLE {name}& operator += (const {name}& rhs)
|
|
2145
|
+
{{{prefix_add_body}
|
|
2146
|
+
return *this;}}
|
|
1890
2147
|
|
|
1891
2148
|
}};
|
|
1892
2149
|
|
|
@@ -1942,24 +2199,18 @@ cuda_kernel_template = """
|
|
|
1942
2199
|
extern "C" __global__ void {name}_cuda_kernel_forward(
|
|
1943
2200
|
{forward_args})
|
|
1944
2201
|
{{
|
|
1945
|
-
size_t _idx =
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
set_launch_bounds(dim);
|
|
1950
|
-
|
|
1951
|
-
{forward_body}}}
|
|
2202
|
+
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
2203
|
+
_idx < dim.size;
|
|
2204
|
+
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
|
|
2205
|
+
{forward_body}}}}}
|
|
1952
2206
|
|
|
1953
2207
|
extern "C" __global__ void {name}_cuda_kernel_backward(
|
|
1954
2208
|
{reverse_args})
|
|
1955
2209
|
{{
|
|
1956
|
-
size_t _idx =
|
|
1957
|
-
|
|
1958
|
-
|
|
1959
|
-
|
|
1960
|
-
set_launch_bounds(dim);
|
|
1961
|
-
|
|
1962
|
-
{reverse_body}}}
|
|
2210
|
+
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
|
|
2211
|
+
_idx < dim.size;
|
|
2212
|
+
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
|
|
2213
|
+
{reverse_body}}}}}
|
|
1963
2214
|
|
|
1964
2215
|
"""
|
|
1965
2216
|
|
|
@@ -1985,11 +2236,9 @@ extern "C" {{
|
|
|
1985
2236
|
WP_API void {name}_cpu_forward(
|
|
1986
2237
|
{forward_args})
|
|
1987
2238
|
{{
|
|
1988
|
-
set_launch_bounds(dim);
|
|
1989
|
-
|
|
1990
2239
|
for (size_t i=0; i < dim.size; ++i)
|
|
1991
2240
|
{{
|
|
1992
|
-
s_threadIdx = i;
|
|
2241
|
+
wp::s_threadIdx = i;
|
|
1993
2242
|
|
|
1994
2243
|
{name}_cpu_kernel_forward(
|
|
1995
2244
|
{forward_params});
|
|
@@ -1999,11 +2248,9 @@ WP_API void {name}_cpu_forward(
|
|
|
1999
2248
|
WP_API void {name}_cpu_backward(
|
|
2000
2249
|
{reverse_args})
|
|
2001
2250
|
{{
|
|
2002
|
-
set_launch_bounds(dim);
|
|
2003
|
-
|
|
2004
2251
|
for (size_t i=0; i < dim.size; ++i)
|
|
2005
2252
|
{{
|
|
2006
|
-
s_threadIdx = i;
|
|
2253
|
+
wp::s_threadIdx = i;
|
|
2007
2254
|
|
|
2008
2255
|
{name}_cpu_kernel_backward(
|
|
2009
2256
|
{reverse_params});
|
|
@@ -2109,8 +2356,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
2109
2356
|
|
|
2110
2357
|
body = []
|
|
2111
2358
|
indent_block = " " * indent_size
|
|
2112
|
-
|
|
2113
|
-
|
|
2359
|
+
|
|
2360
|
+
if len(struct.vars) > 0:
|
|
2361
|
+
for label, var in struct.vars.items():
|
|
2362
|
+
body.append(var.ctype() + " " + label + ";\n")
|
|
2363
|
+
else:
|
|
2364
|
+
# for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
|
|
2365
|
+
body.append("char _dummy_;\n")
|
|
2114
2366
|
|
|
2115
2367
|
forward_args = []
|
|
2116
2368
|
reverse_args = []
|
|
@@ -2118,17 +2370,25 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
2118
2370
|
forward_initializers = []
|
|
2119
2371
|
reverse_body = []
|
|
2120
2372
|
atomic_add_body = []
|
|
2373
|
+
prefix_add_body = []
|
|
2121
2374
|
|
|
2122
2375
|
# forward args
|
|
2123
2376
|
for label, var in struct.vars.items():
|
|
2124
|
-
|
|
2125
|
-
|
|
2377
|
+
var_ctype = var.ctype()
|
|
2378
|
+
forward_args.append(f"{var_ctype} const& {label} = {{}}")
|
|
2379
|
+
reverse_args.append(f"{var_ctype} const&")
|
|
2126
2380
|
|
|
2127
|
-
|
|
2381
|
+
namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
|
|
2382
|
+
atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
|
|
2128
2383
|
|
|
2129
2384
|
prefix = f"{indent_block}," if forward_initializers else ":"
|
|
2130
2385
|
forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
|
|
2131
2386
|
|
|
2387
|
+
# prefix-add operator
|
|
2388
|
+
for label, var in struct.vars.items():
|
|
2389
|
+
if not is_array(var.type):
|
|
2390
|
+
prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
|
|
2391
|
+
|
|
2132
2392
|
# reverse args
|
|
2133
2393
|
for label, var in struct.vars.items():
|
|
2134
2394
|
reverse_args.append(var.ctype() + " & adj_" + label)
|
|
@@ -2146,6 +2406,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
|
|
|
2146
2406
|
forward_initializers="".join(forward_initializers),
|
|
2147
2407
|
reverse_args=indent(reverse_args),
|
|
2148
2408
|
reverse_body="".join(reverse_body),
|
|
2409
|
+
prefix_add_body="".join(prefix_add_body),
|
|
2149
2410
|
atomic_add_body="".join(atomic_add_body),
|
|
2150
2411
|
)
|
|
2151
2412
|
|
|
@@ -2189,7 +2450,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
|
|
|
2189
2450
|
return s
|
|
2190
2451
|
|
|
2191
2452
|
|
|
2192
|
-
def codegen_func_reverse_body(adj, device="cpu", indent=4):
|
|
2453
|
+
def codegen_func_reverse_body(adj, device="cpu", indent=4, func_type="kernel"):
|
|
2193
2454
|
body = []
|
|
2194
2455
|
indent_block = " " * indent
|
|
2195
2456
|
|
|
@@ -2207,7 +2468,11 @@ def codegen_func_reverse_body(adj, device="cpu", indent=4):
|
|
|
2207
2468
|
for l in reversed(adj.blocks[0].body_reverse):
|
|
2208
2469
|
body += [l + "\n"]
|
|
2209
2470
|
|
|
2210
|
-
body
|
|
2471
|
+
# In grid-stride kernels the reverse body is in a for loop
|
|
2472
|
+
if device == "cuda" and func_type == "kernel":
|
|
2473
|
+
body += ["continue;\n"]
|
|
2474
|
+
else:
|
|
2475
|
+
body += ["return;\n"]
|
|
2211
2476
|
|
|
2212
2477
|
return "".join([indent_block + l for l in body])
|
|
2213
2478
|
|
|
@@ -2230,20 +2495,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
|
|
|
2230
2495
|
s += " // dual vars\n"
|
|
2231
2496
|
|
|
2232
2497
|
for var in adj.variables:
|
|
2233
|
-
|
|
2234
|
-
s += f" {var.ctype()} {var.emit('adj')};\n"
|
|
2235
|
-
else:
|
|
2236
|
-
s += f" {var.ctype()} {var.emit('adj')}(0);\n"
|
|
2498
|
+
s += f" {var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"
|
|
2237
2499
|
|
|
2238
2500
|
if device == "cpu":
|
|
2239
2501
|
s += codegen_func_reverse_body(adj, device=device, indent=4)
|
|
2240
2502
|
elif device == "cuda":
|
|
2241
2503
|
if func_type == "kernel":
|
|
2242
|
-
s += codegen_func_reverse_body(adj, device=device, indent=8)
|
|
2504
|
+
s += codegen_func_reverse_body(adj, device=device, indent=8, func_type=func_type)
|
|
2243
2505
|
else:
|
|
2244
|
-
s += codegen_func_reverse_body(adj, device=device, indent=4)
|
|
2506
|
+
s += codegen_func_reverse_body(adj, device=device, indent=4, func_type=func_type)
|
|
2245
2507
|
else:
|
|
2246
|
-
raise ValueError("Device {} not supported for codegen"
|
|
2508
|
+
raise ValueError(f"Device {device} not supported for codegen")
|
|
2247
2509
|
|
|
2248
2510
|
return s
|
|
2249
2511
|
|
|
@@ -2298,7 +2560,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
|
|
|
2298
2560
|
forward_template = cuda_forward_function_template
|
|
2299
2561
|
reverse_template = cuda_reverse_function_template
|
|
2300
2562
|
else:
|
|
2301
|
-
raise ValueError("Device {} is not supported"
|
|
2563
|
+
raise ValueError(f"Device {device} is not supported")
|
|
2302
2564
|
|
|
2303
2565
|
# codegen body
|
|
2304
2566
|
forward_body = codegen_func_forward(adj, func_type="function", device=device)
|
|
@@ -2335,6 +2597,55 @@ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
|
|
|
2335
2597
|
return s
|
|
2336
2598
|
|
|
2337
2599
|
|
|
2600
|
+
def codegen_snippet(adj, name, snippet, adj_snippet):
|
|
2601
|
+
forward_args = []
|
|
2602
|
+
reverse_args = []
|
|
2603
|
+
|
|
2604
|
+
# forward args
|
|
2605
|
+
for i, arg in enumerate(adj.args):
|
|
2606
|
+
s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
|
|
2607
|
+
forward_args.append(s)
|
|
2608
|
+
reverse_args.append(s)
|
|
2609
|
+
|
|
2610
|
+
# reverse args
|
|
2611
|
+
for i, arg in enumerate(adj.args):
|
|
2612
|
+
if isinstance(arg.type, indexedarray):
|
|
2613
|
+
_arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
|
|
2614
|
+
reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
|
|
2615
|
+
else:
|
|
2616
|
+
reverse_args.append(arg.ctype() + " & adj_" + arg.label)
|
|
2617
|
+
|
|
2618
|
+
forward_template = cuda_forward_function_template
|
|
2619
|
+
reverse_template = cuda_reverse_function_template
|
|
2620
|
+
|
|
2621
|
+
s = ""
|
|
2622
|
+
s += forward_template.format(
|
|
2623
|
+
name=name,
|
|
2624
|
+
return_type="void",
|
|
2625
|
+
forward_args=indent(forward_args),
|
|
2626
|
+
forward_body=snippet,
|
|
2627
|
+
filename=adj.filename,
|
|
2628
|
+
lineno=adj.fun_lineno,
|
|
2629
|
+
)
|
|
2630
|
+
|
|
2631
|
+
if adj_snippet:
|
|
2632
|
+
reverse_body = adj_snippet
|
|
2633
|
+
else:
|
|
2634
|
+
reverse_body = ""
|
|
2635
|
+
|
|
2636
|
+
s += reverse_template.format(
|
|
2637
|
+
name=name,
|
|
2638
|
+
return_type="void",
|
|
2639
|
+
reverse_args=indent(reverse_args),
|
|
2640
|
+
forward_body=snippet,
|
|
2641
|
+
reverse_body=reverse_body,
|
|
2642
|
+
filename=adj.filename,
|
|
2643
|
+
lineno=adj.fun_lineno,
|
|
2644
|
+
)
|
|
2645
|
+
|
|
2646
|
+
return s
|
|
2647
|
+
|
|
2648
|
+
|
|
2338
2649
|
def codegen_kernel(kernel, device, options):
|
|
2339
2650
|
# Update the module's options with the ones defined on the kernel, if any.
|
|
2340
2651
|
options = dict(options)
|
|
@@ -2342,8 +2653,8 @@ def codegen_kernel(kernel, device, options):
|
|
|
2342
2653
|
|
|
2343
2654
|
adj = kernel.adj
|
|
2344
2655
|
|
|
2345
|
-
forward_args = ["launch_bounds_t dim"]
|
|
2346
|
-
reverse_args = ["launch_bounds_t dim"]
|
|
2656
|
+
forward_args = ["wp::launch_bounds_t dim"]
|
|
2657
|
+
reverse_args = ["wp::launch_bounds_t dim"]
|
|
2347
2658
|
|
|
2348
2659
|
# forward args
|
|
2349
2660
|
for arg in adj.args:
|
|
@@ -2372,7 +2683,7 @@ def codegen_kernel(kernel, device, options):
|
|
|
2372
2683
|
elif device == "cuda":
|
|
2373
2684
|
template = cuda_kernel_template
|
|
2374
2685
|
else:
|
|
2375
|
-
raise ValueError("Device {} is not supported"
|
|
2686
|
+
raise ValueError(f"Device {device} is not supported")
|
|
2376
2687
|
|
|
2377
2688
|
s = template.format(
|
|
2378
2689
|
name=kernel.get_mangled_name(),
|
|
@@ -2392,7 +2703,7 @@ def codegen_module(kernel, device="cpu"):
|
|
|
2392
2703
|
adj = kernel.adj
|
|
2393
2704
|
|
|
2394
2705
|
# build forward signature
|
|
2395
|
-
forward_args = ["launch_bounds_t dim"]
|
|
2706
|
+
forward_args = ["wp::launch_bounds_t dim"]
|
|
2396
2707
|
forward_params = ["dim"]
|
|
2397
2708
|
|
|
2398
2709
|
for arg in adj.args:
|