warp-lang 1.3.3__py3-none-manylinux2014_aarch64.whl → 1.4.1__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +6 -0
- warp/autograd.py +59 -6
- warp/bin/warp.so +0 -0
- warp/build_dll.py +8 -10
- warp/builtins.py +103 -3
- warp/codegen.py +447 -53
- warp/config.py +1 -1
- warp/context.py +682 -405
- warp/dlpack.py +2 -0
- warp/examples/benchmarks/benchmark_cloth.py +10 -0
- warp/examples/core/example_render_opengl.py +12 -10
- warp/examples/fem/example_adaptive_grid.py +251 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +2 -2
- warp/examples/fem/example_magnetostatics.py +1 -1
- warp/examples/fem/example_streamlines.py +1 -0
- warp/examples/fem/utils.py +25 -5
- warp/examples/sim/example_cloth.py +50 -6
- warp/fem/__init__.py +2 -0
- warp/fem/adaptivity.py +493 -0
- warp/fem/field/field.py +2 -1
- warp/fem/field/nodal_field.py +18 -26
- warp/fem/field/test.py +4 -4
- warp/fem/field/trial.py +4 -4
- warp/fem/geometry/__init__.py +1 -0
- warp/fem/geometry/adaptive_nanogrid.py +843 -0
- warp/fem/geometry/nanogrid.py +55 -28
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/nanogrid_function_space.py +69 -35
- warp/fem/utils.py +118 -107
- warp/jax_experimental.py +28 -15
- warp/native/array.h +0 -1
- warp/native/builtin.h +103 -6
- warp/native/bvh.cu +4 -2
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/error.cpp +4 -2
- warp/native/exports.h +99 -0
- warp/native/mat.h +97 -0
- warp/native/mesh.cpp +36 -0
- warp/native/mesh.cu +52 -1
- warp/native/mesh.h +1 -0
- warp/native/quat.h +43 -0
- warp/native/range.h +11 -2
- warp/native/spatial.h +6 -0
- warp/native/vec.h +74 -0
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +10 -3
- warp/native/warp.h +8 -1
- warp/paddle.py +382 -0
- warp/sim/__init__.py +1 -0
- warp/sim/collide.py +519 -0
- warp/sim/integrator_euler.py +18 -5
- warp/sim/integrator_featherstone.py +5 -5
- warp/sim/integrator_vbd.py +1026 -0
- warp/sim/integrator_xpbd.py +2 -6
- warp/sim/model.py +50 -25
- warp/sparse.py +9 -7
- warp/stubs.py +459 -0
- warp/tape.py +2 -0
- warp/tests/aux_test_dependent.py +1 -0
- warp/tests/aux_test_name_clash1.py +32 -0
- warp/tests/aux_test_name_clash2.py +32 -0
- warp/tests/aux_test_square.py +1 -0
- warp/tests/test_array.py +188 -0
- warp/tests/test_async.py +3 -3
- warp/tests/test_atomic.py +6 -0
- warp/tests/test_closest_point_edge_edge.py +93 -1
- warp/tests/test_codegen.py +93 -15
- warp/tests/test_codegen_instancing.py +1457 -0
- warp/tests/test_collision.py +486 -0
- warp/tests/test_compile_consts.py +3 -28
- warp/tests/test_dlpack.py +170 -0
- warp/tests/test_examples.py +22 -8
- warp/tests/test_fast_math.py +10 -4
- warp/tests/test_fem.py +81 -1
- warp/tests/test_func.py +46 -0
- warp/tests/test_implicit_init.py +49 -0
- warp/tests/test_jax.py +58 -0
- warp/tests/test_mat.py +84 -0
- warp/tests/test_mesh_query_point.py +188 -0
- warp/tests/test_model.py +13 -0
- warp/tests/test_module_hashing.py +40 -0
- warp/tests/test_multigpu.py +3 -3
- warp/tests/test_overwrite.py +8 -0
- warp/tests/test_paddle.py +852 -0
- warp/tests/test_print.py +89 -0
- warp/tests/test_quat.py +111 -0
- warp/tests/test_reload.py +31 -1
- warp/tests/test_scalar_ops.py +2 -0
- warp/tests/test_static.py +568 -0
- warp/tests/test_streams.py +64 -3
- warp/tests/test_struct.py +4 -4
- warp/tests/test_torch.py +24 -0
- warp/tests/test_triangle_closest_point.py +137 -0
- warp/tests/test_types.py +1 -1
- warp/tests/test_vbd.py +386 -0
- warp/tests/test_vec.py +143 -0
- warp/tests/test_vec_scalar_ops.py +139 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +9 -5
- warp/thirdparty/dlpack.py +3 -1
- warp/types.py +167 -36
- warp/utils.py +37 -14
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/METADATA +10 -8
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/RECORD +109 -97
- warp/tests/test_point_triangle_closest_point.py +0 -143
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -11,6 +11,7 @@ import ast
|
|
|
11
11
|
import builtins
|
|
12
12
|
import ctypes
|
|
13
13
|
import functools
|
|
14
|
+
import hashlib
|
|
14
15
|
import inspect
|
|
15
16
|
import math
|
|
16
17
|
import re
|
|
@@ -232,8 +233,11 @@ class StructInstance:
|
|
|
232
233
|
|
|
233
234
|
def __getattribute__(self, name):
|
|
234
235
|
cls = super().__getattribute__("_cls")
|
|
235
|
-
if name
|
|
236
|
-
|
|
236
|
+
if name == "native_name":
|
|
237
|
+
return cls.native_name
|
|
238
|
+
|
|
239
|
+
var = cls.vars.get(name)
|
|
240
|
+
if var is not None:
|
|
237
241
|
if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
|
|
238
242
|
# Each field stored in a `StructInstance` is exposed as
|
|
239
243
|
# a standard Python attribute but also has a `ctypes`
|
|
@@ -408,6 +412,9 @@ class Struct:
|
|
|
408
412
|
elif issubclass(var.type, ctypes.Array):
|
|
409
413
|
fields.append((label, var.type))
|
|
410
414
|
else:
|
|
415
|
+
# HACK: fp16 requires conversion functions from warp.so
|
|
416
|
+
if var.type is warp.float16:
|
|
417
|
+
warp.init()
|
|
411
418
|
fields.append((label, var.type._type_))
|
|
412
419
|
|
|
413
420
|
class StructType(ctypes.Structure):
|
|
@@ -416,15 +423,35 @@ class Struct:
|
|
|
416
423
|
|
|
417
424
|
self.ctype = StructType
|
|
418
425
|
|
|
426
|
+
# Compute the hash. We can cache the hash because it's static, even with nested structs.
|
|
427
|
+
# All field types are specified in the annotations, so they're resolved at declaration time.
|
|
428
|
+
ch = hashlib.sha256()
|
|
429
|
+
|
|
430
|
+
ch.update(bytes(self.key, "utf-8"))
|
|
431
|
+
|
|
432
|
+
for name, type_hint in annotations.items():
|
|
433
|
+
s = f"{name}:{warp.types.get_type_code(type_hint)}"
|
|
434
|
+
ch.update(bytes(s, "utf-8"))
|
|
435
|
+
|
|
436
|
+
# recurse on nested structs
|
|
437
|
+
if isinstance(type_hint, Struct):
|
|
438
|
+
ch.update(type_hint.hash)
|
|
439
|
+
|
|
440
|
+
self.hash = ch.digest()
|
|
441
|
+
|
|
442
|
+
# generate unique identifier for structs in native code
|
|
443
|
+
hash_suffix = f"{self.hash.hex()[:8]}"
|
|
444
|
+
self.native_name = f"{self.key}_{hash_suffix}"
|
|
445
|
+
|
|
419
446
|
# create default constructor (zero-initialize)
|
|
420
447
|
self.default_constructor = warp.context.Function(
|
|
421
448
|
func=None,
|
|
422
|
-
key=self.
|
|
449
|
+
key=self.native_name,
|
|
423
450
|
namespace="",
|
|
424
451
|
value_func=lambda *_: self,
|
|
425
452
|
input_types={},
|
|
426
453
|
initializer_list_func=lambda *_: False,
|
|
427
|
-
native_func=
|
|
454
|
+
native_func=self.native_name,
|
|
428
455
|
)
|
|
429
456
|
|
|
430
457
|
# build a constructor that takes each param as a value
|
|
@@ -432,12 +459,12 @@ class Struct:
|
|
|
432
459
|
|
|
433
460
|
self.value_constructor = warp.context.Function(
|
|
434
461
|
func=None,
|
|
435
|
-
key=self.
|
|
462
|
+
key=self.native_name,
|
|
436
463
|
namespace="",
|
|
437
464
|
value_func=lambda *_: self,
|
|
438
465
|
input_types=input_types,
|
|
439
466
|
initializer_list_func=lambda *_: False,
|
|
440
|
-
native_func=
|
|
467
|
+
native_func=self.native_name,
|
|
441
468
|
)
|
|
442
469
|
|
|
443
470
|
self.default_constructor.add_overload(self.value_constructor)
|
|
@@ -465,6 +492,10 @@ class Struct:
|
|
|
465
492
|
def __init__(inst):
|
|
466
493
|
StructInstance.__init__(inst, self, None)
|
|
467
494
|
|
|
495
|
+
# make sure warp.types.get_type_code works with this StructInstance
|
|
496
|
+
NewStructInstance.cls = self.cls
|
|
497
|
+
NewStructInstance.native_name = self.native_name
|
|
498
|
+
|
|
468
499
|
return NewStructInstance()
|
|
469
500
|
|
|
470
501
|
def initializer(self):
|
|
@@ -599,7 +630,7 @@ class Var:
|
|
|
599
630
|
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
600
631
|
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
601
632
|
elif isinstance(t.dtype, Struct):
|
|
602
|
-
dtypestr =
|
|
633
|
+
dtypestr = t.dtype.native_name
|
|
603
634
|
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
604
635
|
dtypestr = t.dtype.__name__
|
|
605
636
|
else:
|
|
@@ -607,7 +638,10 @@ class Var:
|
|
|
607
638
|
classstr = f"wp::{type(t).__name__}"
|
|
608
639
|
return f"{classstr}_t<{dtypestr}>"
|
|
609
640
|
elif isinstance(t, Struct):
|
|
610
|
-
return
|
|
641
|
+
return t.native_name
|
|
642
|
+
elif isinstance(t, type) and issubclass(t, StructInstance):
|
|
643
|
+
# ensure the actual Struct name is used instead of "NewStructInstance"
|
|
644
|
+
return t.native_name
|
|
611
645
|
elif is_reference(t):
|
|
612
646
|
if not value_type:
|
|
613
647
|
return Var.type_to_ctype(t.value_type) + "*"
|
|
@@ -743,6 +777,9 @@ def func_match_args(func, arg_types, kwarg_types):
|
|
|
743
777
|
|
|
744
778
|
|
|
745
779
|
def get_arg_type(arg: Union[Var, Any]):
|
|
780
|
+
if isinstance(arg, str):
|
|
781
|
+
return str
|
|
782
|
+
|
|
746
783
|
if isinstance(arg, Sequence):
|
|
747
784
|
return tuple(get_arg_type(x) for x in arg)
|
|
748
785
|
|
|
@@ -863,6 +900,12 @@ class Adjoint:
|
|
|
863
900
|
# this is to avoid registering false references to overshadowed modules
|
|
864
901
|
adj.symbols[name] = arg
|
|
865
902
|
|
|
903
|
+
# try to replace static expressions by their constant result if the
|
|
904
|
+
# expression can be evaluated at declaration time
|
|
905
|
+
adj.static_expressions: Dict[str, Any] = {}
|
|
906
|
+
if "static" in adj.source:
|
|
907
|
+
adj.replace_static_expressions()
|
|
908
|
+
|
|
866
909
|
# There are cases where a same module might be rebuilt multiple times,
|
|
867
910
|
# for example when kernels are nested inside of functions, or when
|
|
868
911
|
# a kernel's launch raises an exception. Ideally we'd always want to
|
|
@@ -896,6 +939,7 @@ class Adjoint:
|
|
|
896
939
|
|
|
897
940
|
adj.return_var = None # return type for function or kernel
|
|
898
941
|
adj.loop_symbols = [] # symbols at the start of each loop
|
|
942
|
+
adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops
|
|
899
943
|
|
|
900
944
|
# blocks
|
|
901
945
|
adj.blocks = [Block()]
|
|
@@ -948,9 +992,9 @@ class Adjoint:
|
|
|
948
992
|
if isinstance(a, warp.context.Function):
|
|
949
993
|
# functions don't have a var_ prefix so strip it off here
|
|
950
994
|
if prefix == "var":
|
|
951
|
-
arg_strs.append(a.
|
|
995
|
+
arg_strs.append(a.native_func)
|
|
952
996
|
else:
|
|
953
|
-
arg_strs.append(f"{prefix}_{a.
|
|
997
|
+
arg_strs.append(f"{prefix}_{a.native_func}")
|
|
954
998
|
elif is_reference(a.type):
|
|
955
999
|
arg_strs.append(f"{prefix}_{a}")
|
|
956
1000
|
elif isinstance(a, Var):
|
|
@@ -1255,6 +1299,10 @@ class Adjoint:
|
|
|
1255
1299
|
if not isinstance(func_arg, (Reference, warp.context.Function)):
|
|
1256
1300
|
func_arg = adj.load(func_arg)
|
|
1257
1301
|
|
|
1302
|
+
# if the argument is a function, build it recursively
|
|
1303
|
+
if isinstance(func_arg, warp.context.Function):
|
|
1304
|
+
adj.builder.build_function(func_arg)
|
|
1305
|
+
|
|
1258
1306
|
fwd_args.append(strip_reference(func_arg))
|
|
1259
1307
|
|
|
1260
1308
|
if return_type is None:
|
|
@@ -1440,6 +1488,7 @@ class Adjoint:
|
|
|
1440
1488
|
cond_block.body_forward.append(f"start_{cond_block.label}:;")
|
|
1441
1489
|
|
|
1442
1490
|
c = adj.eval(cond)
|
|
1491
|
+
c = adj.load(c)
|
|
1443
1492
|
|
|
1444
1493
|
cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
|
|
1445
1494
|
|
|
@@ -1493,6 +1542,9 @@ class Adjoint:
|
|
|
1493
1542
|
|
|
1494
1543
|
def emit_FunctionDef(adj, node):
|
|
1495
1544
|
for f in node.body:
|
|
1545
|
+
# Skip variable creation for standalone constants, including docstrings
|
|
1546
|
+
if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
|
|
1547
|
+
continue
|
|
1496
1548
|
adj.eval(f)
|
|
1497
1549
|
|
|
1498
1550
|
if adj.return_var is not None and len(adj.return_var) == 1:
|
|
@@ -1523,6 +1575,16 @@ class Adjoint:
|
|
|
1523
1575
|
# eval condition
|
|
1524
1576
|
cond = adj.eval(node.test)
|
|
1525
1577
|
|
|
1578
|
+
if cond.constant is not None:
|
|
1579
|
+
# resolve constant condition
|
|
1580
|
+
if cond.constant:
|
|
1581
|
+
for stmt in node.body:
|
|
1582
|
+
adj.eval(stmt)
|
|
1583
|
+
else:
|
|
1584
|
+
for stmt in node.orelse:
|
|
1585
|
+
adj.eval(stmt)
|
|
1586
|
+
return None
|
|
1587
|
+
|
|
1526
1588
|
# save symbol map
|
|
1527
1589
|
symbols_prev = adj.symbols.copy()
|
|
1528
1590
|
|
|
@@ -1618,7 +1680,7 @@ class Adjoint:
|
|
|
1618
1680
|
if isinstance(obj, types.ModuleType):
|
|
1619
1681
|
return obj
|
|
1620
1682
|
|
|
1621
|
-
raise
|
|
1683
|
+
raise TypeError(f"Invalid external reference type: {type(obj)}")
|
|
1622
1684
|
|
|
1623
1685
|
@staticmethod
|
|
1624
1686
|
def resolve_type_attribute(var_type: type, attr: str):
|
|
@@ -1732,7 +1794,7 @@ class Adjoint:
|
|
|
1732
1794
|
|
|
1733
1795
|
def emit_NameConstant(adj, node):
|
|
1734
1796
|
if node.value:
|
|
1735
|
-
return adj.add_constant(
|
|
1797
|
+
return adj.add_constant(node.value)
|
|
1736
1798
|
elif node.value is None:
|
|
1737
1799
|
raise WarpCodegenTypeError("None type unsupported")
|
|
1738
1800
|
else:
|
|
@@ -1746,7 +1808,7 @@ class Adjoint:
|
|
|
1746
1808
|
elif isinstance(node, ast.Ellipsis):
|
|
1747
1809
|
return adj.emit_Ellipsis(node)
|
|
1748
1810
|
else:
|
|
1749
|
-
assert isinstance(node, ast.NameConstant)
|
|
1811
|
+
assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
|
|
1750
1812
|
return adj.emit_NameConstant(node)
|
|
1751
1813
|
|
|
1752
1814
|
def emit_BinOp(adj, node):
|
|
@@ -1787,6 +1849,11 @@ class Adjoint:
|
|
|
1787
1849
|
# detect symbols with conflicting definitions (assigned inside the for loop)
|
|
1788
1850
|
for items in symbols.items():
|
|
1789
1851
|
sym = items[0]
|
|
1852
|
+
if adj.is_constant_iter_symbol(sym):
|
|
1853
|
+
# ignore constant overwriting in for-loops if it is a loop iterator
|
|
1854
|
+
# (it is no problem to unroll static loops multiple times in sequence)
|
|
1855
|
+
continue
|
|
1856
|
+
|
|
1790
1857
|
var1 = items[1]
|
|
1791
1858
|
var2 = adj.symbols[sym]
|
|
1792
1859
|
|
|
@@ -1933,15 +2000,36 @@ class Adjoint:
|
|
|
1933
2000
|
)
|
|
1934
2001
|
return range_call
|
|
1935
2002
|
|
|
2003
|
+
def begin_record_constant_iter_symbols(adj):
    """Push a frame onto ``adj.loop_const_iter_symbols`` when entering a dynamic for-loop.

    If an enclosing frame exists, the new frame aliases that same set object
    (it is not copied), so iterator symbols recorded inside a nested loop are
    also seen by the enclosing loops. At the outermost level a fresh empty set
    is pushed.
    """
    stack = adj.loop_const_iter_symbols
    frame = stack[-1] if stack else set()
    stack.append(frame)
|
|
2008
|
+
|
|
2009
|
+
def end_record_constant_iter_symbols(adj):
    """Pop the innermost frame of ``adj.loop_const_iter_symbols``.

    Does nothing when the frame stack is already empty.
    """
    stack = adj.loop_const_iter_symbols
    if not stack:
        return
    stack.pop()
|
|
2012
|
+
|
|
2013
|
+
def record_constant_iter_symbol(adj, sym):
    """Record ``sym`` as the iterator of an unrolled (static) loop in the innermost frame.

    A no-op when no frame is active, i.e. outside of any dynamic loop.
    """
    stack = adj.loop_const_iter_symbols
    if stack:
        stack[-1].add(sym)
|
|
2016
|
+
|
|
2017
|
+
def is_constant_iter_symbol(adj, sym):
    """Return True if ``sym`` was recorded as a static-loop iterator in the innermost frame.

    Always False when no frame is active.
    """
    stack = adj.loop_const_iter_symbols
    if not stack:
        return False
    return sym in stack[-1]
|
|
2019
|
+
|
|
1936
2020
|
def emit_For(adj, node):
|
|
1937
2021
|
# try and unroll simple range() statements that use constant args
|
|
1938
2022
|
unroll_range = adj.get_unroll_range(node)
|
|
1939
2023
|
|
|
1940
2024
|
if isinstance(unroll_range, range):
|
|
2025
|
+
const_iter_sym = node.target.id
|
|
2026
|
+
# prevent constant conflicts in `materialize_redefinitions()`
|
|
2027
|
+
adj.record_constant_iter_symbol(const_iter_sym)
|
|
2028
|
+
|
|
2029
|
+
# unroll static for-loop
|
|
1941
2030
|
for i in unroll_range:
|
|
1942
2031
|
const_iter = adj.add_constant(i)
|
|
1943
|
-
|
|
1944
|
-
adj.symbols[node.target.id] = var_iter
|
|
2032
|
+
adj.symbols[const_iter_sym] = const_iter
|
|
1945
2033
|
|
|
1946
2034
|
# eval body
|
|
1947
2035
|
for s in node.body:
|
|
@@ -1957,6 +2045,7 @@ class Adjoint:
|
|
|
1957
2045
|
iter = adj.eval(node.iter)
|
|
1958
2046
|
|
|
1959
2047
|
adj.symbols[node.target.id] = adj.begin_for(iter)
|
|
2048
|
+
adj.begin_record_constant_iter_symbols()
|
|
1960
2049
|
|
|
1961
2050
|
# for loops should be side-effect free, here we store a copy
|
|
1962
2051
|
adj.loop_symbols.append(adj.symbols.copy())
|
|
@@ -1967,6 +2056,7 @@ class Adjoint:
|
|
|
1967
2056
|
|
|
1968
2057
|
adj.materialize_redefinitions(adj.loop_symbols[-1])
|
|
1969
2058
|
adj.loop_symbols.pop()
|
|
2059
|
+
adj.end_record_constant_iter_symbols()
|
|
1970
2060
|
|
|
1971
2061
|
adj.end_for(iter)
|
|
1972
2062
|
|
|
@@ -2023,13 +2113,28 @@ class Adjoint:
|
|
|
2023
2113
|
|
|
2024
2114
|
# try and lookup function in globals by
|
|
2025
2115
|
# resolving path (e.g.: module.submodule.attr)
|
|
2026
|
-
|
|
2116
|
+
if hasattr(node.func, "warp_func"):
|
|
2117
|
+
func = node.func.warp_func
|
|
2118
|
+
path = []
|
|
2119
|
+
else:
|
|
2120
|
+
func, path = adj.resolve_static_expression(node.func)
|
|
2027
2121
|
if func is None:
|
|
2028
2122
|
func = adj.eval(node.func)
|
|
2029
2123
|
|
|
2124
|
+
if adj.is_static_expression(func):
|
|
2125
|
+
# try to evaluate wp.static() expressions
|
|
2126
|
+
obj, _ = adj.evaluate_static_expression(node)
|
|
2127
|
+
if obj is not None:
|
|
2128
|
+
if isinstance(obj, warp.context.Function):
|
|
2129
|
+
# special handling for wp.static() evaluating to a function
|
|
2130
|
+
return obj
|
|
2131
|
+
else:
|
|
2132
|
+
out = adj.add_constant(obj)
|
|
2133
|
+
return out
|
|
2134
|
+
|
|
2030
2135
|
type_args = {}
|
|
2031
2136
|
|
|
2032
|
-
if not isinstance(func, warp.context.Function):
|
|
2137
|
+
if len(path) > 0 and not isinstance(func, warp.context.Function):
|
|
2033
2138
|
attr = path[-1]
|
|
2034
2139
|
caller = func
|
|
2035
2140
|
func = None
|
|
@@ -2083,6 +2188,9 @@ class Adjoint:
|
|
|
2083
2188
|
args = tuple(adj.resolve_arg(x) for x in node.args)
|
|
2084
2189
|
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
|
|
2085
2190
|
|
|
2191
|
+
# add the call and build the callee adjoint if needed (func.adj)
|
|
2192
|
+
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2193
|
+
|
|
2086
2194
|
if warp.config.verify_autograd_array_access:
|
|
2087
2195
|
# update arg read/write states according to what happens to that arg in the called function
|
|
2088
2196
|
if hasattr(func, "adj"):
|
|
@@ -2095,7 +2203,6 @@ class Adjoint:
|
|
|
2095
2203
|
if func.adj.args[i].is_read:
|
|
2096
2204
|
arg.mark_read()
|
|
2097
2205
|
|
|
2098
|
-
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2099
2206
|
return out
|
|
2100
2207
|
|
|
2101
2208
|
def emit_Index(adj, node):
|
|
@@ -2281,20 +2388,40 @@ class Adjoint:
|
|
|
2281
2388
|
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2282
2389
|
|
|
2283
2390
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2391
|
+
# recursively unwind AST, stopping at penultimate node
|
|
2392
|
+
node = lhs
|
|
2393
|
+
while hasattr(node, "value"):
|
|
2394
|
+
if hasattr(node.value, "value"):
|
|
2395
|
+
node = node.value
|
|
2396
|
+
else:
|
|
2397
|
+
break
|
|
2398
|
+
# lhs is updating a variable adjoint (i.e. wp.adjoint[var])
|
|
2399
|
+
if hasattr(node, "attr") and node.attr == "adjoint":
|
|
2400
|
+
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2401
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2402
|
+
return
|
|
2403
|
+
|
|
2404
|
+
# TODO: array vec component case
|
|
2284
2405
|
if is_reference(target.type):
|
|
2285
2406
|
attr = adj.add_builtin_call("indexref", [target, *indices])
|
|
2286
|
-
|
|
2287
|
-
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2407
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2288
2408
|
|
|
2289
|
-
|
|
2409
|
+
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
2410
|
+
lineno = adj.lineno + adj.fun_lineno
|
|
2411
|
+
line = adj.source_lines[adj.lineno]
|
|
2412
|
+
node_source = adj.get_node_source(lhs.value)
|
|
2413
|
+
print(
|
|
2414
|
+
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2415
|
+
)
|
|
2290
2416
|
|
|
2291
|
-
|
|
2292
|
-
|
|
2293
|
-
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
|
|
2297
|
-
|
|
2417
|
+
else:
|
|
2418
|
+
out = adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2419
|
+
|
|
2420
|
+
# re-point target symbol to out var
|
|
2421
|
+
for id in adj.symbols:
|
|
2422
|
+
if adj.symbols[id] == target:
|
|
2423
|
+
adj.symbols[id] = out
|
|
2424
|
+
break
|
|
2298
2425
|
|
|
2299
2426
|
else:
|
|
2300
2427
|
raise WarpCodegenError(
|
|
@@ -2329,16 +2456,24 @@ class Adjoint:
|
|
|
2329
2456
|
aggregate = adj.eval(lhs.value)
|
|
2330
2457
|
aggregate_type = strip_reference(aggregate.type)
|
|
2331
2458
|
|
|
2332
|
-
# assigning to a vector component
|
|
2333
|
-
if type_is_vector(aggregate_type):
|
|
2459
|
+
# assigning to a vector or quaternion component
|
|
2460
|
+
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
2461
|
+
# TODO: handle wp.adjoint case
|
|
2462
|
+
|
|
2334
2463
|
index = adj.vector_component_index(lhs.attr, aggregate_type)
|
|
2335
2464
|
|
|
2465
|
+
# TODO: array vec component case
|
|
2336
2466
|
if is_reference(aggregate.type):
|
|
2337
2467
|
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2468
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2338
2469
|
else:
|
|
2339
|
-
|
|
2470
|
+
out = adj.add_builtin_call("assign", [aggregate, index, rhs])
|
|
2340
2471
|
|
|
2341
|
-
|
|
2472
|
+
# re-point target symbol to out var
|
|
2473
|
+
for id in adj.symbols:
|
|
2474
|
+
if adj.symbols[id] == aggregate:
|
|
2475
|
+
adj.symbols[id] = out
|
|
2476
|
+
break
|
|
2342
2477
|
|
|
2343
2478
|
else:
|
|
2344
2479
|
attr = adj.emit_Attribute(lhs)
|
|
@@ -2382,9 +2517,66 @@ class Adjoint:
|
|
|
2382
2517
|
adj.add_return(adj.return_var)
|
|
2383
2518
|
|
|
2384
2519
|
def emit_AugAssign(adj, node):
    """Emit code for an augmented assignment such as ``x += y`` or ``a[i] -= v``.

    ``+=`` and ``-=`` applied to an element of an array whose dtype is in
    ``warp.types.atomic_types`` are lowered to ``atomic_add`` / ``atomic_sub``
    builtin calls. Every other form is rewritten as the plain assignment
    ``target = target <op> value`` and evaluated through ``adj.eval``.

    Raises:
        WarpCodegenError: if a subscripted target is not an array, vector,
            quaternion, or matrix.
    """
    lhs = node.target

    # replace augmented assignment with assignment statement + binary op (default behaviour)
    def make_new_assign_statement():
        new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
        adj.eval(new_node)

    # attribute targets, plain names, etc. always use the default rewrite
    if not isinstance(lhs, ast.Subscript):
        make_new_assign_statement()
        return

    # wp.adjoint[var] appears in custom grad functions, and does not require
    # special consideration in the AugAssign case
    if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
        make_new_assign_statement()
        return

    # NOTE: the fallback rewrites below evaluate the subscript once more as part
    # of the generated assignment; the right-hand side, however, is evaluated only
    # on the atomic path so it is never emitted twice.
    target, indices = adj.eval_subscript(lhs)
    target_type = strip_reference(target.type)

    if is_array(target_type):
        # target_type is not suitable for atomic array accumulation
        if target_type.dtype not in warp.types.atomic_types:
            make_new_assign_statement()
            return

        if isinstance(node.op, ast.Add):
            atomic_op = "atomic_add"
        elif isinstance(node.op, ast.Sub):
            atomic_op = "atomic_sub"
        else:
            # there is no atomic builtin for this operator: fall back to a plain
            # read-modify-write assignment instead of silently dropping the update
            print(f"Warning: in-place op {node.op} is not differentiable")
            make_new_assign_statement()
            return

        rhs = adj.eval(node.value)
        adj.add_builtin_call(atomic_op, [target, *indices, rhs])

        if warp.config.verify_autograd_array_access:
            lineno = adj.lineno + adj.fun_lineno
            target.mark_write(kernel_name=adj.fun_name, filename=adj.filename, lineno=lineno)

    elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
        make_new_assign_statement()

    else:
        raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
|
|
2388
2580
|
|
|
2389
2581
|
def emit_Tuple(adj, node):
|
|
2390
2582
|
# LHS for expressions, such as i, j, k = 1, 2, 3
|
|
@@ -2445,9 +2637,6 @@ class Adjoint:
|
|
|
2445
2637
|
if path[0] in adj.symbols:
|
|
2446
2638
|
return None
|
|
2447
2639
|
|
|
2448
|
-
if path[0] in __builtins__:
|
|
2449
|
-
return __builtins__[path[0]]
|
|
2450
|
-
|
|
2451
2640
|
# look up in closure/global variables
|
|
2452
2641
|
expr = adj.resolve_external_reference(path[0])
|
|
2453
2642
|
|
|
@@ -2455,13 +2644,201 @@ class Adjoint:
|
|
|
2455
2644
|
if expr is None:
|
|
2456
2645
|
expr = getattr(warp, path[0], None)
|
|
2457
2646
|
|
|
2458
|
-
|
|
2647
|
+
# look up in builtins
|
|
2648
|
+
if expr is None:
|
|
2649
|
+
expr = __builtins__.get(path[0])
|
|
2650
|
+
|
|
2651
|
+
if expr is not None:
|
|
2459
2652
|
for i in range(1, len(path)):
|
|
2460
2653
|
if hasattr(expr, path[i]):
|
|
2461
2654
|
expr = getattr(expr, path[i])
|
|
2462
2655
|
|
|
2463
2656
|
return expr
|
|
2464
2657
|
|
|
2658
|
+
# retrieves a dictionary of all closure and global variables and their values
|
|
2659
|
+
# to be used in the evaluation context of wp.static() expressions
|
|
2660
|
+
def get_static_evaluation_context(adj):
    """Build the namespace in which ``wp.static()`` expressions are evaluated.

    Returns a new dict containing the module globals of ``adj.func`` overlaid
    with the current values of its closure cells, so that captured variables
    take precedence over globals of the same name.
    """
    fn = adj.func
    context = dict(fn.__globals__)
    # co_freevars names the closure cells, in the same order as __closure__
    for var_name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        context[var_name] = cell.cell_contents
    return context
|
|
2674
|
+
|
|
2675
|
+
def is_static_expression(adj, func):
    """Return True if ``func`` is the ``static`` function defined in ``warp.builtins``."""
    if not isinstance(func, types.FunctionType):
        return False
    return func.__module__ == "warp.builtins" and func.__qualname__ == "static"
|
|
2681
|
+
|
|
2682
|
+
# verify the return type of a wp.static() expression is supported inside a Warp kernel
|
|
2683
|
+
def verify_static_return_value(adj, value):
    """Check that the result of a ``wp.static()`` expression is usable inside a Warp kernel.

    Accepted: Warp values (``warp.types.is_value``), strings, Warp functions,
    and Warp struct instances whose fields (recursively) are themselves accepted.

    Returns:
        True when the value is supported.

    Raises:
        ValueError: explaining why the value is unsupported (``None``, a Warp
            array, or any other type). For structs, the message names the full
            dotted path of the offending field.
    """
    if value is None:
        raise ValueError("None is returned")
    if warp.types.is_value(value):
        return True
    if warp.types.is_array(value):
        # more useful explanation for the common case of creating a Warp array
        raise ValueError("a Warp array cannot be created inside Warp kernels")
    if isinstance(value, str):
        # we want to support cases such as `print(wp.static("test"))`
        return True
    if isinstance(value, warp.context.Function):
        return True

    def verify_struct(s: StructInstance, attr_path: List[str]):
        # check every declared field; nested structs recurse with the extended path
        for key in s._cls.vars.keys():
            v = getattr(s, key)
            # include the field itself so the error points at the failing member
            field_path = attr_path + [key]
            if issubclass(type(v), StructInstance):
                verify_struct(v, field_path)
            else:
                try:
                    adj.verify_static_return_value(v)
                except ValueError as e:
                    raise ValueError(
                        f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(field_path)}"
                    ) from e

    if issubclass(type(value), StructInstance):
        verify_struct(value, [])
        return True

    raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
|
|
2714
|
+
|
|
2715
|
+
# find the source code string of an AST node
|
|
2716
|
+
def extract_node_source(adj, node) -> Optional[str]:
    """Return the source text covered by the AST ``node``, sliced from ``adj.source_lines``.

    Returns None if the node has no ``lineno``/``col_offset`` position. When the
    node lacks ``end_lineno``/``end_col_offset`` (Python < 3.8), the end is found
    by scanning forward for the ``)`` that closes one already-open parenthesis.
    That is, the node is assumed to sit directly inside a call's parentheses.
    Multi-line results are whitespace-stripped; single-line results are returned
    exactly as sliced.
    """
    if not (hasattr(node, "lineno") and hasattr(node, "col_offset")):
        return None

    src = adj.source_lines
    first_row = node.lineno - 1  # AST line numbers are 1-based
    first_col = node.col_offset

    if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
        last_row = node.end_lineno - 1
        last_col = node.end_col_offset
    else:
        # fallback for Python versions before 3.8: scan for the balancing ")"
        last_row, last_col = first_row, first_col
        depth = 1
        for row in range(first_row, len(src)):
            offset = first_col if row == first_row else 0
            for col, ch in enumerate(src[row][offset:], start=offset):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth == 0:
                        last_row, last_col = row, col
                        break
            if depth == 0:
                break

    if first_row == last_row:
        return src[first_row][first_col:last_col]

    # first line from the start column, whole middle lines, last line up to the end column
    pieces = [src[first_row][first_col:]]
    pieces += src[first_row + 1 : last_row]
    pieces.append(src[last_row][:last_col])
    return "\n".join(pieces).strip()
|
|
2764
|
+
|
|
2765
|
+
# handles a wp.static() expression and returns the resulting object and a string representing the code
|
|
2766
|
+
# of the static expression
|
|
2767
|
+
def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
    """Evaluate the argument of a ``wp.static(...)`` call node at code-generation time.

    The argument's source text is extracted and passed to ``eval``. The namespace
    is the function's globals and closure variables (from
    ``get_static_evaluation_context``), plus local symbols currently bound to
    compile-time constants.

    Returns:
        ``(value, static_code)``: the evaluated object and the expression's source text.

    Raises:
        WarpCodegenError: if the call does not have exactly one argument or keyword,
            the source cannot be extracted, evaluation fails, or the result is not
            usable inside a Warp kernel.
    """
    if len(node.args) == 1:
        expr_node = node.args[0]
    elif len(node.keywords) == 1:
        # use the keyword's value: the keyword node itself spans ``name=expr``,
        # which is not an evaluable expression
        expr_node = node.keywords[0].value
    else:
        raise WarpCodegenError("warp.static() requires a single argument or keyword")
    static_code = adj.extract_node_source(expr_node)
    if static_code is None:
        raise WarpCodegenError("Error extracting source code from wp.static() expression")

    vars_dict = adj.get_static_evaluation_context()
    # add constant variables to the static call context
    constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
    vars_dict.update(constant_vars)

    try:
        # parenthesize so that an expression broken across source lines (e.g. a
        # binary operator at the start of a continuation line) stays valid input
        value = eval(f"({static_code})", vars_dict)
        if warp.config.verbose:
            print(f"Evaluated static command: {static_code} = {value}")
    except NameError as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
        ) from e
    except Exception as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    try:
        adj.verify_static_return_value(value)
    except ValueError as e:
        raise WarpCodegenError(
            f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    return value, static_code
|
|
2803
|
+
|
|
2804
|
+
# Fold wp.static() calls in the function's AST into their evaluated values,
# wherever the static expression can already be evaluated.
def replace_static_expressions(adj):
    """Replace evaluable ``wp.static()`` calls in ``adj.tree`` with leaf nodes.

    The tree is traversed by an ``ast.NodeTransformer``, and the returned tree
    is stored back into ``adj.tree``. Each call whose callee is recognized by
    ``adj.is_static_expression`` is evaluated with
    ``adj.evaluate_static_expression``. On success the result is recorded in
    ``adj.static_expressions`` under its source text, and the call node is
    substituted by one of two leaves:

    * ``ast.Name("__warp_func__")`` whose ``warp_func`` attribute holds the
      result, when the result is a ``warp.context.Function``;
    * ``ast.Constant(value=result)`` for any other result.

    Calls that cannot be evaluated are kept, and their children are visited
    like those of any other node.
    """

    def make_leaf(value):
        # Build the AST node that stands in for an evaluated wp.static() call.
        if not isinstance(value, warp.context.Function):
            return ast.Constant(value=value)
        # Keep a direct reference to the Warp function object. Its key alone
        # does not identify it uniquely, because the function may be redefined
        # between the wp.static() declaration and codegen at module-build time.
        leaf = ast.Name("__warp_func__")
        leaf.warp_func = value
        return leaf

    class StaticCallFolder(ast.NodeTransformer):
        def visit_Call(self, node):
            callee, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if adj.is_static_expression(callee):
                try:
                    # Evaluation succeeds whenever the expression is valid and
                    # depends only on global or captured variables.
                    value, source = adj.evaluate_static_expression(node)
                    if source is not None:
                        adj.static_expressions[source] = value
                    return ast.copy_location(make_leaf(value), node)
                except Exception:
                    # Failures are ignored on purpose, because only two cases
                    # can occur:
                    # 1) the expression itself is invalid code, so the module
                    #    cannot be built at all;
                    # 2) it references a local (possibly constant) variable and
                    #    is therefore not executable here. Changing that constant,
                    #    or the code affecting it, changes the module hash anyway.
                    pass
            return self.generic_visit(node)

    adj.tree = StaticCallFolder().visit(adj.tree)
|
|
2841
|
+
|
|
2465
2842
|
# Evaluates a static expression that does not depend on runtime values
|
|
2466
2843
|
# if eval_types is True, try resolving the path using evaluated type information as well
|
|
2467
2844
|
def resolve_static_expression(adj, root_node, eval_types=True):
|
|
@@ -2536,34 +2913,42 @@ class Adjoint:
|
|
|
2536
2913
|
# return the Python code corresponding to the given AST node
|
|
2537
2914
|
return ast.get_source_segment(adj.source, node)
|
|
2538
2915
|
|
|
2539
|
-
def
|
|
2540
|
-
"""Traverses ``adj.tree`` and returns
|
|
2541
|
-
|
|
2542
|
-
This function is meant to be used to populate a module's constants dictionary, which then feeds
|
|
2543
|
-
into the computation of the module's ``content_hash``.
|
|
2544
|
-
"""
|
|
2916
|
+
def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
|
|
2917
|
+
"""Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
|
|
2545
2918
|
|
|
2546
2919
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
2547
|
-
|
|
2920
|
+
|
|
2921
|
+
constants = {}
|
|
2922
|
+
types = {}
|
|
2923
|
+
functions = {}
|
|
2548
2924
|
|
|
2549
2925
|
for node in ast.walk(adj.tree):
|
|
2550
2926
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
2551
2927
|
# look up in closure/global variables
|
|
2552
2928
|
obj = adj.resolve_external_reference(node.id)
|
|
2553
|
-
|
|
2554
2929
|
if warp.types.is_value(obj):
|
|
2555
|
-
|
|
2930
|
+
constants[node.id] = obj
|
|
2556
2931
|
|
|
2557
2932
|
elif isinstance(node, ast.Attribute):
|
|
2558
2933
|
obj, path = adj.resolve_static_expression(node, eval_types=False)
|
|
2559
|
-
|
|
2560
2934
|
if warp.types.is_value(obj):
|
|
2561
|
-
|
|
2935
|
+
constants[".".join(path)] = obj
|
|
2936
|
+
|
|
2937
|
+
elif isinstance(node, ast.Call):
|
|
2938
|
+
func, _ = adj.resolve_static_expression(node.func, eval_types=False)
|
|
2939
|
+
if isinstance(func, warp.context.Function) and not func.is_builtin():
|
|
2940
|
+
# calling user-defined function
|
|
2941
|
+
functions[func] = None
|
|
2942
|
+
elif isinstance(func, Struct):
|
|
2943
|
+
# calling struct constructor
|
|
2944
|
+
types[func] = None
|
|
2945
|
+
elif isinstance(func, type) and warp.types.type_is_value(func):
|
|
2946
|
+
# calling value type constructor
|
|
2947
|
+
types[func] = None
|
|
2562
2948
|
|
|
2563
2949
|
elif isinstance(node, ast.Assign):
|
|
2564
2950
|
# Add the LHS names to the local_variables so we know any subsequent uses are shadowed
|
|
2565
2951
|
lhs = node.targets[0]
|
|
2566
|
-
|
|
2567
2952
|
if isinstance(lhs, ast.Tuple):
|
|
2568
2953
|
for v in lhs.elts:
|
|
2569
2954
|
if isinstance(v, ast.Name):
|
|
@@ -2571,7 +2956,7 @@ class Adjoint:
|
|
|
2571
2956
|
elif isinstance(lhs, ast.Name):
|
|
2572
2957
|
local_variables.add(lhs.id)
|
|
2573
2958
|
|
|
2574
|
-
return
|
|
2959
|
+
return constants, types, functions
|
|
2575
2960
|
|
|
2576
2961
|
|
|
2577
2962
|
# ----------------
|
|
@@ -2817,6 +3202,15 @@ def constant_str(value):
|
|
|
2817
3202
|
# make sure we emit the value of objects, e.g. uint32
|
|
2818
3203
|
return str(value.value)
|
|
2819
3204
|
|
|
3205
|
+
elif issubclass(value_type, warp.codegen.StructInstance):
|
|
3206
|
+
# constant struct instance
|
|
3207
|
+
arg_strs = []
|
|
3208
|
+
for key, var in value._cls.vars.items():
|
|
3209
|
+
attr = getattr(value, key)
|
|
3210
|
+
arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
|
|
3211
|
+
arg_str = ", ".join(arg_strs)
|
|
3212
|
+
return f"{value.native_name}({arg_str})"
|
|
3213
|
+
|
|
2820
3214
|
elif value == math.inf:
|
|
2821
3215
|
return "INFINITY"
|
|
2822
3216
|
|
|
@@ -2845,7 +3239,7 @@ def make_full_qualified_name(func):
|
|
|
2845
3239
|
|
|
2846
3240
|
|
|
2847
3241
|
def codegen_struct(struct, device="cpu", indent_size=4):
|
|
2848
|
-
name =
|
|
3242
|
+
name = struct.native_name
|
|
2849
3243
|
|
|
2850
3244
|
body = []
|
|
2851
3245
|
indent_block = " " * indent_size
|