warp-lang 1.3.2__py3-none-manylinux2014_x86_64.whl → 1.4.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +6 -0
- warp/autograd.py +59 -6
- warp/bin/warp.so +0 -0
- warp/build_dll.py +8 -10
- warp/builtins.py +126 -4
- warp/codegen.py +435 -53
- warp/config.py +1 -1
- warp/context.py +678 -403
- warp/dlpack.py +2 -0
- warp/examples/benchmarks/benchmark_cloth.py +10 -0
- warp/examples/core/example_render_opengl.py +12 -10
- warp/examples/fem/example_adaptive_grid.py +251 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +2 -2
- warp/examples/fem/example_magnetostatics.py +1 -1
- warp/examples/fem/example_streamlines.py +1 -0
- warp/examples/fem/utils.py +23 -4
- warp/examples/sim/example_cloth.py +50 -6
- warp/fem/__init__.py +2 -0
- warp/fem/adaptivity.py +493 -0
- warp/fem/field/field.py +2 -1
- warp/fem/field/nodal_field.py +18 -26
- warp/fem/field/test.py +4 -4
- warp/fem/field/trial.py +4 -4
- warp/fem/geometry/__init__.py +1 -0
- warp/fem/geometry/adaptive_nanogrid.py +843 -0
- warp/fem/geometry/nanogrid.py +55 -28
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/nanogrid_function_space.py +69 -35
- warp/fem/utils.py +113 -107
- warp/jax_experimental.py +28 -15
- warp/native/array.h +0 -1
- warp/native/builtin.h +103 -6
- warp/native/bvh.cu +2 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/error.cpp +4 -2
- warp/native/exports.h +99 -17
- warp/native/mat.h +97 -0
- warp/native/mesh.cpp +36 -0
- warp/native/mesh.cu +51 -0
- warp/native/mesh.h +1 -0
- warp/native/quat.h +43 -0
- warp/native/spatial.h +6 -0
- warp/native/vec.h +74 -0
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +10 -3
- warp/native/warp.h +8 -1
- warp/paddle.py +382 -0
- warp/sim/__init__.py +1 -0
- warp/sim/collide.py +519 -0
- warp/sim/integrator_euler.py +18 -5
- warp/sim/integrator_featherstone.py +5 -5
- warp/sim/integrator_vbd.py +1026 -0
- warp/sim/model.py +49 -23
- warp/stubs.py +459 -0
- warp/tape.py +2 -0
- warp/tests/aux_test_dependent.py +1 -0
- warp/tests/aux_test_name_clash1.py +32 -0
- warp/tests/aux_test_name_clash2.py +32 -0
- warp/tests/aux_test_square.py +1 -0
- warp/tests/test_array.py +222 -0
- warp/tests/test_async.py +3 -3
- warp/tests/test_atomic.py +6 -0
- warp/tests/test_closest_point_edge_edge.py +93 -1
- warp/tests/test_codegen.py +62 -15
- warp/tests/test_codegen_instancing.py +1457 -0
- warp/tests/test_collision.py +486 -0
- warp/tests/test_compile_consts.py +3 -28
- warp/tests/test_dlpack.py +170 -0
- warp/tests/test_examples.py +22 -8
- warp/tests/test_fast_math.py +10 -4
- warp/tests/test_fem.py +64 -0
- warp/tests/test_func.py +46 -0
- warp/tests/test_implicit_init.py +49 -0
- warp/tests/test_jax.py +58 -0
- warp/tests/test_mat.py +84 -0
- warp/tests/test_mesh_query_point.py +188 -0
- warp/tests/test_module_hashing.py +40 -0
- warp/tests/test_multigpu.py +3 -3
- warp/tests/test_overwrite.py +8 -0
- warp/tests/test_paddle.py +852 -0
- warp/tests/test_print.py +89 -0
- warp/tests/test_quat.py +111 -0
- warp/tests/test_reload.py +31 -1
- warp/tests/test_scalar_ops.py +2 -0
- warp/tests/test_static.py +412 -0
- warp/tests/test_streams.py +64 -3
- warp/tests/test_struct.py +4 -4
- warp/tests/test_torch.py +24 -0
- warp/tests/test_triangle_closest_point.py +137 -0
- warp/tests/test_types.py +1 -1
- warp/tests/test_vbd.py +386 -0
- warp/tests/test_vec.py +143 -0
- warp/tests/test_vec_scalar_ops.py +139 -0
- warp/tests/test_volume.py +30 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +9 -5
- warp/thirdparty/dlpack.py +3 -1
- warp/types.py +157 -34
- warp/utils.py +37 -14
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +106 -94
- warp/tests/test_point_triangle_closest_point.py +0 -143
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -11,6 +11,7 @@ import ast
|
|
|
11
11
|
import builtins
|
|
12
12
|
import ctypes
|
|
13
13
|
import functools
|
|
14
|
+
import hashlib
|
|
14
15
|
import inspect
|
|
15
16
|
import math
|
|
16
17
|
import re
|
|
@@ -232,8 +233,11 @@ class StructInstance:
|
|
|
232
233
|
|
|
233
234
|
def __getattribute__(self, name):
|
|
234
235
|
cls = super().__getattribute__("_cls")
|
|
235
|
-
if name
|
|
236
|
-
|
|
236
|
+
if name == "native_name":
|
|
237
|
+
return cls.native_name
|
|
238
|
+
|
|
239
|
+
var = cls.vars.get(name)
|
|
240
|
+
if var is not None:
|
|
237
241
|
if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
|
|
238
242
|
# Each field stored in a `StructInstance` is exposed as
|
|
239
243
|
# a standard Python attribute but also has a `ctypes`
|
|
@@ -408,6 +412,9 @@ class Struct:
|
|
|
408
412
|
elif issubclass(var.type, ctypes.Array):
|
|
409
413
|
fields.append((label, var.type))
|
|
410
414
|
else:
|
|
415
|
+
# HACK: fp16 requires conversion functions from warp.so
|
|
416
|
+
if var.type is warp.float16:
|
|
417
|
+
warp.init()
|
|
411
418
|
fields.append((label, var.type._type_))
|
|
412
419
|
|
|
413
420
|
class StructType(ctypes.Structure):
|
|
@@ -416,15 +423,35 @@ class Struct:
|
|
|
416
423
|
|
|
417
424
|
self.ctype = StructType
|
|
418
425
|
|
|
426
|
+
# Compute the hash. We can cache the hash because it's static, even with nested structs.
|
|
427
|
+
# All field types are specified in the annotations, so they're resolved at declaration time.
|
|
428
|
+
ch = hashlib.sha256()
|
|
429
|
+
|
|
430
|
+
ch.update(bytes(self.key, "utf-8"))
|
|
431
|
+
|
|
432
|
+
for name, type_hint in annotations.items():
|
|
433
|
+
s = f"{name}:{warp.types.get_type_code(type_hint)}"
|
|
434
|
+
ch.update(bytes(s, "utf-8"))
|
|
435
|
+
|
|
436
|
+
# recurse on nested structs
|
|
437
|
+
if isinstance(type_hint, Struct):
|
|
438
|
+
ch.update(type_hint.hash)
|
|
439
|
+
|
|
440
|
+
self.hash = ch.digest()
|
|
441
|
+
|
|
442
|
+
# generate unique identifier for structs in native code
|
|
443
|
+
hash_suffix = f"{self.hash.hex()[:8]}"
|
|
444
|
+
self.native_name = f"{self.key}_{hash_suffix}"
|
|
445
|
+
|
|
419
446
|
# create default constructor (zero-initialize)
|
|
420
447
|
self.default_constructor = warp.context.Function(
|
|
421
448
|
func=None,
|
|
422
|
-
key=self.
|
|
449
|
+
key=self.native_name,
|
|
423
450
|
namespace="",
|
|
424
451
|
value_func=lambda *_: self,
|
|
425
452
|
input_types={},
|
|
426
453
|
initializer_list_func=lambda *_: False,
|
|
427
|
-
native_func=
|
|
454
|
+
native_func=self.native_name,
|
|
428
455
|
)
|
|
429
456
|
|
|
430
457
|
# build a constructor that takes each param as a value
|
|
@@ -432,12 +459,12 @@ class Struct:
|
|
|
432
459
|
|
|
433
460
|
self.value_constructor = warp.context.Function(
|
|
434
461
|
func=None,
|
|
435
|
-
key=self.
|
|
462
|
+
key=self.native_name,
|
|
436
463
|
namespace="",
|
|
437
464
|
value_func=lambda *_: self,
|
|
438
465
|
input_types=input_types,
|
|
439
466
|
initializer_list_func=lambda *_: False,
|
|
440
|
-
native_func=
|
|
467
|
+
native_func=self.native_name,
|
|
441
468
|
)
|
|
442
469
|
|
|
443
470
|
self.default_constructor.add_overload(self.value_constructor)
|
|
@@ -465,6 +492,10 @@ class Struct:
|
|
|
465
492
|
def __init__(inst):
|
|
466
493
|
StructInstance.__init__(inst, self, None)
|
|
467
494
|
|
|
495
|
+
# make sure warp.types.get_type_code works with this StructInstance
|
|
496
|
+
NewStructInstance.cls = self.cls
|
|
497
|
+
NewStructInstance.native_name = self.native_name
|
|
498
|
+
|
|
468
499
|
return NewStructInstance()
|
|
469
500
|
|
|
470
501
|
def initializer(self):
|
|
@@ -599,7 +630,7 @@ class Var:
|
|
|
599
630
|
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
600
631
|
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
601
632
|
elif isinstance(t.dtype, Struct):
|
|
602
|
-
dtypestr =
|
|
633
|
+
dtypestr = t.dtype.native_name
|
|
603
634
|
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
604
635
|
dtypestr = t.dtype.__name__
|
|
605
636
|
else:
|
|
@@ -607,7 +638,10 @@ class Var:
|
|
|
607
638
|
classstr = f"wp::{type(t).__name__}"
|
|
608
639
|
return f"{classstr}_t<{dtypestr}>"
|
|
609
640
|
elif isinstance(t, Struct):
|
|
610
|
-
return
|
|
641
|
+
return t.native_name
|
|
642
|
+
elif isinstance(t, type) and issubclass(t, StructInstance):
|
|
643
|
+
# ensure the actual Struct name is used instead of "NewStructInstance"
|
|
644
|
+
return t.native_name
|
|
611
645
|
elif is_reference(t):
|
|
612
646
|
if not value_type:
|
|
613
647
|
return Var.type_to_ctype(t.value_type) + "*"
|
|
@@ -863,6 +897,12 @@ class Adjoint:
|
|
|
863
897
|
# this is to avoid registering false references to overshadowed modules
|
|
864
898
|
adj.symbols[name] = arg
|
|
865
899
|
|
|
900
|
+
# try to replace static expressions by their constant result if the
|
|
901
|
+
# expression can be evaluated at declaration time
|
|
902
|
+
adj.static_expressions: Dict[str, Any] = {}
|
|
903
|
+
if "static" in adj.source:
|
|
904
|
+
adj.replace_static_expressions()
|
|
905
|
+
|
|
866
906
|
# There are cases where a same module might be rebuilt multiple times,
|
|
867
907
|
# for example when kernels are nested inside of functions, or when
|
|
868
908
|
# a kernel's launch raises an exception. Ideally we'd always want to
|
|
@@ -896,6 +936,7 @@ class Adjoint:
|
|
|
896
936
|
|
|
897
937
|
adj.return_var = None # return type for function or kernel
|
|
898
938
|
adj.loop_symbols = [] # symbols at the start of each loop
|
|
939
|
+
adj.loop_const_iter_symbols = set() # iteration variables (constant) for static loops
|
|
899
940
|
|
|
900
941
|
# blocks
|
|
901
942
|
adj.blocks = [Block()]
|
|
@@ -948,9 +989,9 @@ class Adjoint:
|
|
|
948
989
|
if isinstance(a, warp.context.Function):
|
|
949
990
|
# functions don't have a var_ prefix so strip it off here
|
|
950
991
|
if prefix == "var":
|
|
951
|
-
arg_strs.append(a.
|
|
992
|
+
arg_strs.append(a.native_func)
|
|
952
993
|
else:
|
|
953
|
-
arg_strs.append(f"{prefix}_{a.
|
|
994
|
+
arg_strs.append(f"{prefix}_{a.native_func}")
|
|
954
995
|
elif is_reference(a.type):
|
|
955
996
|
arg_strs.append(f"{prefix}_{a}")
|
|
956
997
|
elif isinstance(a, Var):
|
|
@@ -1255,6 +1296,10 @@ class Adjoint:
|
|
|
1255
1296
|
if not isinstance(func_arg, (Reference, warp.context.Function)):
|
|
1256
1297
|
func_arg = adj.load(func_arg)
|
|
1257
1298
|
|
|
1299
|
+
# if the argument is a function, build it recursively
|
|
1300
|
+
if isinstance(func_arg, warp.context.Function):
|
|
1301
|
+
adj.builder.build_function(func_arg)
|
|
1302
|
+
|
|
1258
1303
|
fwd_args.append(strip_reference(func_arg))
|
|
1259
1304
|
|
|
1260
1305
|
if return_type is None:
|
|
@@ -1440,6 +1485,7 @@ class Adjoint:
|
|
|
1440
1485
|
cond_block.body_forward.append(f"start_{cond_block.label}:;")
|
|
1441
1486
|
|
|
1442
1487
|
c = adj.eval(cond)
|
|
1488
|
+
c = adj.load(c)
|
|
1443
1489
|
|
|
1444
1490
|
cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
|
|
1445
1491
|
|
|
@@ -1493,6 +1539,9 @@ class Adjoint:
|
|
|
1493
1539
|
|
|
1494
1540
|
def emit_FunctionDef(adj, node):
|
|
1495
1541
|
for f in node.body:
|
|
1542
|
+
# Skip variable creation for standalone constants, including docstrings
|
|
1543
|
+
if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
|
|
1544
|
+
continue
|
|
1496
1545
|
adj.eval(f)
|
|
1497
1546
|
|
|
1498
1547
|
if adj.return_var is not None and len(adj.return_var) == 1:
|
|
@@ -1523,6 +1572,16 @@ class Adjoint:
|
|
|
1523
1572
|
# eval condition
|
|
1524
1573
|
cond = adj.eval(node.test)
|
|
1525
1574
|
|
|
1575
|
+
if cond.constant is not None:
|
|
1576
|
+
# resolve constant condition
|
|
1577
|
+
if cond.constant:
|
|
1578
|
+
for stmt in node.body:
|
|
1579
|
+
adj.eval(stmt)
|
|
1580
|
+
else:
|
|
1581
|
+
for stmt in node.orelse:
|
|
1582
|
+
adj.eval(stmt)
|
|
1583
|
+
return None
|
|
1584
|
+
|
|
1526
1585
|
# save symbol map
|
|
1527
1586
|
symbols_prev = adj.symbols.copy()
|
|
1528
1587
|
|
|
@@ -1618,7 +1677,7 @@ class Adjoint:
|
|
|
1618
1677
|
if isinstance(obj, types.ModuleType):
|
|
1619
1678
|
return obj
|
|
1620
1679
|
|
|
1621
|
-
raise
|
|
1680
|
+
raise TypeError(f"Invalid external reference type: {type(obj)}")
|
|
1622
1681
|
|
|
1623
1682
|
@staticmethod
|
|
1624
1683
|
def resolve_type_attribute(var_type: type, attr: str):
|
|
@@ -1732,7 +1791,7 @@ class Adjoint:
|
|
|
1732
1791
|
|
|
1733
1792
|
def emit_NameConstant(adj, node):
|
|
1734
1793
|
if node.value:
|
|
1735
|
-
return adj.add_constant(
|
|
1794
|
+
return adj.add_constant(node.value)
|
|
1736
1795
|
elif node.value is None:
|
|
1737
1796
|
raise WarpCodegenTypeError("None type unsupported")
|
|
1738
1797
|
else:
|
|
@@ -1746,7 +1805,7 @@ class Adjoint:
|
|
|
1746
1805
|
elif isinstance(node, ast.Ellipsis):
|
|
1747
1806
|
return adj.emit_Ellipsis(node)
|
|
1748
1807
|
else:
|
|
1749
|
-
assert isinstance(node, ast.NameConstant)
|
|
1808
|
+
assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
|
|
1750
1809
|
return adj.emit_NameConstant(node)
|
|
1751
1810
|
|
|
1752
1811
|
def emit_BinOp(adj, node):
|
|
@@ -1787,6 +1846,11 @@ class Adjoint:
|
|
|
1787
1846
|
# detect symbols with conflicting definitions (assigned inside the for loop)
|
|
1788
1847
|
for items in symbols.items():
|
|
1789
1848
|
sym = items[0]
|
|
1849
|
+
if adj.loop_const_iter_symbols is not None and sym in adj.loop_const_iter_symbols:
|
|
1850
|
+
# ignore constant overwriting in for-loops if it is a loop iterator
|
|
1851
|
+
# (it is no problem to unroll static loops multiple times in sequence)
|
|
1852
|
+
continue
|
|
1853
|
+
|
|
1790
1854
|
var1 = items[1]
|
|
1791
1855
|
var2 = adj.symbols[sym]
|
|
1792
1856
|
|
|
@@ -1933,15 +1997,27 @@ class Adjoint:
|
|
|
1933
1997
|
)
|
|
1934
1998
|
return range_call
|
|
1935
1999
|
|
|
2000
|
+
def begin_record_constant_iter_symbols(adj):
|
|
2001
|
+
if adj.loop_const_iter_symbols is None:
|
|
2002
|
+
adj.loop_const_iter_symbols = set()
|
|
2003
|
+
|
|
2004
|
+
def end_record_constant_iter_symbols(adj):
|
|
2005
|
+
adj.loop_const_iter_symbols = None
|
|
2006
|
+
|
|
1936
2007
|
def emit_For(adj, node):
|
|
1937
2008
|
# try and unroll simple range() statements that use constant args
|
|
1938
2009
|
unroll_range = adj.get_unroll_range(node)
|
|
1939
2010
|
|
|
1940
2011
|
if isinstance(unroll_range, range):
|
|
2012
|
+
const_iter_sym = node.target.id
|
|
2013
|
+
if adj.loop_const_iter_symbols is not None:
|
|
2014
|
+
# prevent constant conflicts in `materialize_redefinitions()`
|
|
2015
|
+
adj.loop_const_iter_symbols.add(const_iter_sym)
|
|
2016
|
+
|
|
2017
|
+
# unroll static for-loop
|
|
1941
2018
|
for i in unroll_range:
|
|
1942
2019
|
const_iter = adj.add_constant(i)
|
|
1943
|
-
|
|
1944
|
-
adj.symbols[node.target.id] = var_iter
|
|
2020
|
+
adj.symbols[const_iter_sym] = const_iter
|
|
1945
2021
|
|
|
1946
2022
|
# eval body
|
|
1947
2023
|
for s in node.body:
|
|
@@ -1957,6 +2033,7 @@ class Adjoint:
|
|
|
1957
2033
|
iter = adj.eval(node.iter)
|
|
1958
2034
|
|
|
1959
2035
|
adj.symbols[node.target.id] = adj.begin_for(iter)
|
|
2036
|
+
adj.begin_record_constant_iter_symbols()
|
|
1960
2037
|
|
|
1961
2038
|
# for loops should be side-effect free, here we store a copy
|
|
1962
2039
|
adj.loop_symbols.append(adj.symbols.copy())
|
|
@@ -1967,6 +2044,7 @@ class Adjoint:
|
|
|
1967
2044
|
|
|
1968
2045
|
adj.materialize_redefinitions(adj.loop_symbols[-1])
|
|
1969
2046
|
adj.loop_symbols.pop()
|
|
2047
|
+
adj.end_record_constant_iter_symbols()
|
|
1970
2048
|
|
|
1971
2049
|
adj.end_for(iter)
|
|
1972
2050
|
|
|
@@ -2023,13 +2101,28 @@ class Adjoint:
|
|
|
2023
2101
|
|
|
2024
2102
|
# try and lookup function in globals by
|
|
2025
2103
|
# resolving path (e.g.: module.submodule.attr)
|
|
2026
|
-
|
|
2104
|
+
if hasattr(node.func, "warp_func"):
|
|
2105
|
+
func = node.func.warp_func
|
|
2106
|
+
path = []
|
|
2107
|
+
else:
|
|
2108
|
+
func, path = adj.resolve_static_expression(node.func)
|
|
2027
2109
|
if func is None:
|
|
2028
2110
|
func = adj.eval(node.func)
|
|
2029
2111
|
|
|
2112
|
+
if adj.is_static_expression(func):
|
|
2113
|
+
# try to evaluate wp.static() expressions
|
|
2114
|
+
obj, _ = adj.evaluate_static_expression(node)
|
|
2115
|
+
if obj is not None:
|
|
2116
|
+
if isinstance(obj, warp.context.Function):
|
|
2117
|
+
# special handling for wp.static() evaluating to a function
|
|
2118
|
+
return obj
|
|
2119
|
+
else:
|
|
2120
|
+
out = adj.add_constant(obj)
|
|
2121
|
+
return out
|
|
2122
|
+
|
|
2030
2123
|
type_args = {}
|
|
2031
2124
|
|
|
2032
|
-
if not isinstance(func, warp.context.Function):
|
|
2125
|
+
if len(path) > 0 and not isinstance(func, warp.context.Function):
|
|
2033
2126
|
attr = path[-1]
|
|
2034
2127
|
caller = func
|
|
2035
2128
|
func = None
|
|
@@ -2083,6 +2176,9 @@ class Adjoint:
|
|
|
2083
2176
|
args = tuple(adj.resolve_arg(x) for x in node.args)
|
|
2084
2177
|
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
|
|
2085
2178
|
|
|
2179
|
+
# add the call and build the callee adjoint if needed (func.adj)
|
|
2180
|
+
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2181
|
+
|
|
2086
2182
|
if warp.config.verify_autograd_array_access:
|
|
2087
2183
|
# update arg read/write states according to what happens to that arg in the called function
|
|
2088
2184
|
if hasattr(func, "adj"):
|
|
@@ -2095,7 +2191,6 @@ class Adjoint:
|
|
|
2095
2191
|
if func.adj.args[i].is_read:
|
|
2096
2192
|
arg.mark_read()
|
|
2097
2193
|
|
|
2098
|
-
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2099
2194
|
return out
|
|
2100
2195
|
|
|
2101
2196
|
def emit_Index(adj, node):
|
|
@@ -2281,20 +2376,40 @@ class Adjoint:
|
|
|
2281
2376
|
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2282
2377
|
|
|
2283
2378
|
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2379
|
+
# recursively unwind AST, stopping at penultimate node
|
|
2380
|
+
node = lhs
|
|
2381
|
+
while hasattr(node, "value"):
|
|
2382
|
+
if hasattr(node.value, "value"):
|
|
2383
|
+
node = node.value
|
|
2384
|
+
else:
|
|
2385
|
+
break
|
|
2386
|
+
# lhs is updating a variable adjoint (i.e. wp.adjoint[var])
|
|
2387
|
+
if hasattr(node, "attr") and node.attr == "adjoint":
|
|
2388
|
+
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2389
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2390
|
+
return
|
|
2391
|
+
|
|
2392
|
+
# TODO: array vec component case
|
|
2284
2393
|
if is_reference(target.type):
|
|
2285
2394
|
attr = adj.add_builtin_call("indexref", [target, *indices])
|
|
2286
|
-
|
|
2287
|
-
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2395
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2288
2396
|
|
|
2289
|
-
|
|
2397
|
+
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
2398
|
+
lineno = adj.lineno + adj.fun_lineno
|
|
2399
|
+
line = adj.source_lines[adj.lineno]
|
|
2400
|
+
node_source = adj.get_node_source(lhs.value)
|
|
2401
|
+
print(
|
|
2402
|
+
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2403
|
+
)
|
|
2290
2404
|
|
|
2291
|
-
|
|
2292
|
-
|
|
2293
|
-
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
|
|
2297
|
-
|
|
2405
|
+
else:
|
|
2406
|
+
out = adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2407
|
+
|
|
2408
|
+
# re-point target symbol to out var
|
|
2409
|
+
for id in adj.symbols:
|
|
2410
|
+
if adj.symbols[id] == target:
|
|
2411
|
+
adj.symbols[id] = out
|
|
2412
|
+
break
|
|
2298
2413
|
|
|
2299
2414
|
else:
|
|
2300
2415
|
raise WarpCodegenError(
|
|
@@ -2329,16 +2444,24 @@ class Adjoint:
|
|
|
2329
2444
|
aggregate = adj.eval(lhs.value)
|
|
2330
2445
|
aggregate_type = strip_reference(aggregate.type)
|
|
2331
2446
|
|
|
2332
|
-
# assigning to a vector component
|
|
2333
|
-
if type_is_vector(aggregate_type):
|
|
2447
|
+
# assigning to a vector or quaternion component
|
|
2448
|
+
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
2449
|
+
# TODO: handle wp.adjoint case
|
|
2450
|
+
|
|
2334
2451
|
index = adj.vector_component_index(lhs.attr, aggregate_type)
|
|
2335
2452
|
|
|
2453
|
+
# TODO: array vec component case
|
|
2336
2454
|
if is_reference(aggregate.type):
|
|
2337
2455
|
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2456
|
+
adj.add_builtin_call("store", [attr, rhs])
|
|
2338
2457
|
else:
|
|
2339
|
-
|
|
2458
|
+
out = adj.add_builtin_call("assign", [aggregate, index, rhs])
|
|
2340
2459
|
|
|
2341
|
-
|
|
2460
|
+
# re-point target symbol to out var
|
|
2461
|
+
for id in adj.symbols:
|
|
2462
|
+
if adj.symbols[id] == aggregate:
|
|
2463
|
+
adj.symbols[id] = out
|
|
2464
|
+
break
|
|
2342
2465
|
|
|
2343
2466
|
else:
|
|
2344
2467
|
attr = adj.emit_Attribute(lhs)
|
|
@@ -2382,9 +2505,66 @@ class Adjoint:
|
|
|
2382
2505
|
adj.add_return(adj.return_var)
|
|
2383
2506
|
|
|
2384
2507
|
def emit_AugAssign(adj, node):
|
|
2385
|
-
|
|
2386
|
-
|
|
2387
|
-
|
|
2508
|
+
lhs = node.target
|
|
2509
|
+
|
|
2510
|
+
# replace augmented assignment with assignment statement + binary op (default behaviour)
|
|
2511
|
+
def make_new_assign_statement():
|
|
2512
|
+
new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
|
|
2513
|
+
adj.eval(new_node)
|
|
2514
|
+
|
|
2515
|
+
if isinstance(lhs, ast.Subscript):
|
|
2516
|
+
rhs = adj.eval(node.value)
|
|
2517
|
+
|
|
2518
|
+
# wp.adjoint[var] appears in custom grad functions, and does not require
|
|
2519
|
+
# special consideration in the AugAssign case
|
|
2520
|
+
if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
|
|
2521
|
+
make_new_assign_statement()
|
|
2522
|
+
return
|
|
2523
|
+
|
|
2524
|
+
target, indices = adj.eval_subscript(lhs)
|
|
2525
|
+
|
|
2526
|
+
target_type = strip_reference(target.type)
|
|
2527
|
+
|
|
2528
|
+
if is_array(target_type):
|
|
2529
|
+
# target_type is not suitable for atomic array accumulation
|
|
2530
|
+
if target_type.dtype not in warp.types.atomic_types:
|
|
2531
|
+
make_new_assign_statement()
|
|
2532
|
+
return
|
|
2533
|
+
|
|
2534
|
+
kernel_name = adj.fun_name
|
|
2535
|
+
filename = adj.filename
|
|
2536
|
+
lineno = adj.lineno + adj.fun_lineno
|
|
2537
|
+
|
|
2538
|
+
if isinstance(node.op, ast.Add):
|
|
2539
|
+
adj.add_builtin_call("atomic_add", [target, *indices, rhs])
|
|
2540
|
+
|
|
2541
|
+
if warp.config.verify_autograd_array_access:
|
|
2542
|
+
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2543
|
+
|
|
2544
|
+
elif isinstance(node.op, ast.Sub):
|
|
2545
|
+
adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
|
|
2546
|
+
|
|
2547
|
+
if warp.config.verify_autograd_array_access:
|
|
2548
|
+
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2549
|
+
else:
|
|
2550
|
+
print(f"Warning: in-place op {node.op} is not differentiable")
|
|
2551
|
+
|
|
2552
|
+
# TODO
|
|
2553
|
+
elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
|
|
2554
|
+
make_new_assign_statement()
|
|
2555
|
+
return
|
|
2556
|
+
|
|
2557
|
+
else:
|
|
2558
|
+
raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
|
|
2559
|
+
|
|
2560
|
+
# TODO
|
|
2561
|
+
elif isinstance(lhs, ast.Attribute):
|
|
2562
|
+
make_new_assign_statement()
|
|
2563
|
+
return
|
|
2564
|
+
|
|
2565
|
+
else:
|
|
2566
|
+
make_new_assign_statement()
|
|
2567
|
+
return
|
|
2388
2568
|
|
|
2389
2569
|
def emit_Tuple(adj, node):
|
|
2390
2570
|
# LHS for expressions, such as i, j, k = 1, 2, 3
|
|
@@ -2445,9 +2625,6 @@ class Adjoint:
|
|
|
2445
2625
|
if path[0] in adj.symbols:
|
|
2446
2626
|
return None
|
|
2447
2627
|
|
|
2448
|
-
if path[0] in __builtins__:
|
|
2449
|
-
return __builtins__[path[0]]
|
|
2450
|
-
|
|
2451
2628
|
# look up in closure/global variables
|
|
2452
2629
|
expr = adj.resolve_external_reference(path[0])
|
|
2453
2630
|
|
|
@@ -2455,13 +2632,201 @@ class Adjoint:
|
|
|
2455
2632
|
if expr is None:
|
|
2456
2633
|
expr = getattr(warp, path[0], None)
|
|
2457
2634
|
|
|
2458
|
-
|
|
2635
|
+
# look up in builtins
|
|
2636
|
+
if expr is None:
|
|
2637
|
+
expr = __builtins__.get(path[0])
|
|
2638
|
+
|
|
2639
|
+
if expr is not None:
|
|
2459
2640
|
for i in range(1, len(path)):
|
|
2460
2641
|
if hasattr(expr, path[i]):
|
|
2461
2642
|
expr = getattr(expr, path[i])
|
|
2462
2643
|
|
|
2463
2644
|
return expr
|
|
2464
2645
|
|
|
2646
|
+
# retrieves a dictionary of all closure and global variables and their values
|
|
2647
|
+
# to be used in the evaluation context of wp.static() expressions
|
|
2648
|
+
def get_static_evaluation_context(adj):
|
|
2649
|
+
closure_vars = dict(
|
|
2650
|
+
zip(
|
|
2651
|
+
adj.func.__code__.co_freevars,
|
|
2652
|
+
[c.cell_contents for c in (adj.func.__closure__ or [])],
|
|
2653
|
+
)
|
|
2654
|
+
)
|
|
2655
|
+
|
|
2656
|
+
vars_dict = {}
|
|
2657
|
+
vars_dict.update(adj.func.__globals__)
|
|
2658
|
+
# variables captured in closure have precedence over global vars
|
|
2659
|
+
vars_dict.update(closure_vars)
|
|
2660
|
+
|
|
2661
|
+
return vars_dict
|
|
2662
|
+
|
|
2663
|
+
def is_static_expression(adj, func):
|
|
2664
|
+
return (
|
|
2665
|
+
isinstance(func, types.FunctionType)
|
|
2666
|
+
and func.__module__ == "warp.builtins"
|
|
2667
|
+
and func.__qualname__ == "static"
|
|
2668
|
+
)
|
|
2669
|
+
|
|
2670
|
+
# verify the return type of a wp.static() expression is supported inside a Warp kernel
|
|
2671
|
+
def verify_static_return_value(adj, value):
|
|
2672
|
+
if value is None:
|
|
2673
|
+
raise ValueError("None is returned")
|
|
2674
|
+
if warp.types.is_value(value):
|
|
2675
|
+
return True
|
|
2676
|
+
if warp.types.is_array(value):
|
|
2677
|
+
# more useful explanation for the common case of creating a Warp array
|
|
2678
|
+
raise ValueError("a Warp array cannot be created inside Warp kernels")
|
|
2679
|
+
if isinstance(value, str):
|
|
2680
|
+
# we want to support cases such as `print(wp.static("test"))`
|
|
2681
|
+
return True
|
|
2682
|
+
if isinstance(value, warp.context.Function):
|
|
2683
|
+
return True
|
|
2684
|
+
|
|
2685
|
+
def verify_struct(s: StructInstance, attr_path: List[str]):
|
|
2686
|
+
for key in s._cls.vars.keys():
|
|
2687
|
+
v = getattr(s, key)
|
|
2688
|
+
if issubclass(type(v), StructInstance):
|
|
2689
|
+
verify_struct(v, attr_path + [key])
|
|
2690
|
+
else:
|
|
2691
|
+
try:
|
|
2692
|
+
adj.verify_static_return_value(v)
|
|
2693
|
+
except ValueError as e:
|
|
2694
|
+
raise ValueError(
|
|
2695
|
+
f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
|
|
2696
|
+
) from e
|
|
2697
|
+
|
|
2698
|
+
if issubclass(type(value), StructInstance):
|
|
2699
|
+
return verify_struct(value, [])
|
|
2700
|
+
|
|
2701
|
+
raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
|
|
2702
|
+
|
|
2703
|
+
# find the source code string of an AST node
|
|
2704
|
+
def extract_node_source(adj, node) -> Optional[str]:
|
|
2705
|
+
if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
|
|
2706
|
+
return None
|
|
2707
|
+
|
|
2708
|
+
start_line = node.lineno - 1 # line numbers start at 1
|
|
2709
|
+
start_col = node.col_offset
|
|
2710
|
+
|
|
2711
|
+
if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
|
|
2712
|
+
end_line = node.end_lineno - 1
|
|
2713
|
+
end_col = node.end_col_offset
|
|
2714
|
+
else:
|
|
2715
|
+
# fallback for Python versions before 3.8
|
|
2716
|
+
# we have to find the end line and column manually
|
|
2717
|
+
end_line = start_line
|
|
2718
|
+
end_col = start_col
|
|
2719
|
+
parenthesis_count = 1
|
|
2720
|
+
for lineno in range(start_line, len(adj.source_lines)):
|
|
2721
|
+
if lineno == start_line:
|
|
2722
|
+
c_start = start_col
|
|
2723
|
+
else:
|
|
2724
|
+
c_start = 0
|
|
2725
|
+
line = adj.source_lines[lineno]
|
|
2726
|
+
for i in range(c_start, len(line)):
|
|
2727
|
+
c = line[i]
|
|
2728
|
+
if c == "(":
|
|
2729
|
+
parenthesis_count += 1
|
|
2730
|
+
elif c == ")":
|
|
2731
|
+
parenthesis_count -= 1
|
|
2732
|
+
if parenthesis_count == 0:
|
|
2733
|
+
end_col = i
|
|
2734
|
+
end_line = lineno
|
|
2735
|
+
break
|
|
2736
|
+
if parenthesis_count == 0:
|
|
2737
|
+
break
|
|
2738
|
+
|
|
2739
|
+
if start_line == end_line:
|
|
2740
|
+
# single-line expression
|
|
2741
|
+
return adj.source_lines[start_line][start_col:end_col]
|
|
2742
|
+
else:
|
|
2743
|
+
# multi-line expression
|
|
2744
|
+
lines = []
|
|
2745
|
+
# first line (from start_col to the end)
|
|
2746
|
+
lines.append(adj.source_lines[start_line][start_col:])
|
|
2747
|
+
# middle lines (entire lines)
|
|
2748
|
+
lines.extend(adj.source_lines[start_line + 1 : end_line])
|
|
2749
|
+
# last line (from the start to end_col)
|
|
2750
|
+
lines.append(adj.source_lines[end_line][:end_col])
|
|
2751
|
+
return "\n".join(lines).strip()
|
|
2752
|
+
|
|
2753
|
+
# handles a wp.static() expression and returns the resulting object and a string representing the code
|
|
2754
|
+
# of the static expression
|
|
2755
|
+
def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
|
|
2756
|
+
if len(node.args) == 1:
|
|
2757
|
+
static_code = adj.extract_node_source(node.args[0])
|
|
2758
|
+
elif len(node.keywords) == 1:
|
|
2759
|
+
static_code = adj.extract_node_source(node.keywords[0])
|
|
2760
|
+
else:
|
|
2761
|
+
raise WarpCodegenError("warp.static() requires a single argument or keyword")
|
|
2762
|
+
if static_code is None:
|
|
2763
|
+
raise WarpCodegenError("Error extracting source code from wp.static() expression")
|
|
2764
|
+
|
|
2765
|
+
vars_dict = adj.get_static_evaluation_context()
|
|
2766
|
+
# add constant variables to the static call context
|
|
2767
|
+
constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
|
|
2768
|
+
vars_dict.update(constant_vars)
|
|
2769
|
+
|
|
2770
|
+
try:
|
|
2771
|
+
value = eval(static_code, vars_dict)
|
|
2772
|
+
if warp.config.verbose:
|
|
2773
|
+
print(f"Evaluated static command: {static_code} = {value}")
|
|
2774
|
+
except NameError as e:
|
|
2775
|
+
raise WarpCodegenError(
|
|
2776
|
+
f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
|
|
2777
|
+
) from e
|
|
2778
|
+
except Exception as e:
|
|
2779
|
+
raise WarpCodegenError(
|
|
2780
|
+
f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
|
|
2781
|
+
) from e
|
|
2782
|
+
|
|
2783
|
+
try:
|
|
2784
|
+
adj.verify_static_return_value(value)
|
|
2785
|
+
except ValueError as e:
|
|
2786
|
+
raise WarpCodegenError(
|
|
2787
|
+
f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
|
|
2788
|
+
) from e
|
|
2789
|
+
|
|
2790
|
+
return value, static_code
|
|
2791
|
+
|
|
2792
|
+
# try to replace wp.static() expressions by their evaluated value if the
|
|
2793
|
+
# expression can be evaluated
|
|
2794
|
+
def replace_static_expressions(adj):
    """Replace evaluable ``wp.static()`` calls in ``adj.tree`` by their values.

    Every ``ast.Call`` whose callee is recognized by ``adj.is_static_expression``
    is evaluated through ``adj.evaluate_static_expression``. On success:

    * the evaluated object is recorded in ``adj.static_expressions`` under the
      evaluated source string (when one is returned), and
    * the call node is substituted. A ``warp.context.Function`` becomes an
      ``ast.Name("__warp_func__")`` node carrying the function object in its
      ``warp_func`` attribute. Any other value becomes an ``ast.Constant``.

    Calls whose evaluation raises are left untouched, and their children are
    still visited. ``adj.tree`` is reassigned to the transformed tree.
    """

    class StaticExpressionReplacer(ast.NodeTransformer):
        def visit_Call(self, node):
            func, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if adj.is_static_expression(func):
                try:
                    # the static expression will execute as long as the static expression is valid and
                    # only depends on global or captured variables
                    obj, code = adj.evaluate_static_expression(node)
                    if code is not None:
                        adj.static_expressions[code] = obj
                    if isinstance(obj, warp.context.Function):
                        # Explicit Load context: before Python 3.13, ``ctx`` is a required
                        # field and would otherwise be missing from the synthesized node.
                        name_node = ast.Name("__warp_func__", ctx=ast.Load())
                        # we add a pointer to the Warp function here so that we can refer to it later at
                        # codegen time (note that the function key itself is not sufficient to uniquely
                        # identify the function, as the function may be redefined between the current time
                        # of wp.static() declaration and the time of codegen during module building)
                        name_node.warp_func = obj
                        return ast.copy_location(name_node, node)
                    else:
                        return ast.copy_location(ast.Constant(value=obj), node)
                except Exception:
                    # Ignoring failing static expressions should generally not be an issue because only
                    # one of these cases should be possible:
                    # 1) the static expression itself is invalid code, in which case the module cannot be
                    #    built at all,
                    # 2) the static expression contains a reference to a local (even if constant) variable
                    #    (and is therefore not executable and raises this exception), in which
                    #    case changing the constant, or the code affecting this constant, would lead to
                    #    a different module hash anyway.
                    pass

            return self.generic_visit(node)

    adj.tree = StaticExpressionReplacer().visit(adj.tree)
|
|
2829
|
+
|
|
2465
2830
|
# Evaluates a static expression that does not depend on runtime values
|
|
2466
2831
|
# if eval_types is True, try resolving the path using evaluated type information as well
|
|
2467
2832
|
def resolve_static_expression(adj, root_node, eval_types=True):
|
|
@@ -2536,34 +2901,42 @@ class Adjoint:
|
|
|
2536
2901
|
# return the Python code corresponding to the given AST node
|
|
2537
2902
|
return ast.get_source_segment(adj.source, node)
|
|
2538
2903
|
|
|
2539
|
-
def
|
|
2540
|
-
"""Traverses ``adj.tree`` and returns
|
|
2541
|
-
|
|
2542
|
-
This function is meant to be used to populate a module's constants dictionary, which then feeds
|
|
2543
|
-
into the computation of the module's ``content_hash``.
|
|
2544
|
-
"""
|
|
2904
|
+
def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
|
|
2905
|
+
"""Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
|
|
2545
2906
|
|
|
2546
2907
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
2547
|
-
|
|
2908
|
+
|
|
2909
|
+
constants = {}
|
|
2910
|
+
types = {}
|
|
2911
|
+
functions = {}
|
|
2548
2912
|
|
|
2549
2913
|
for node in ast.walk(adj.tree):
|
|
2550
2914
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
2551
2915
|
# look up in closure/global variables
|
|
2552
2916
|
obj = adj.resolve_external_reference(node.id)
|
|
2553
|
-
|
|
2554
2917
|
if warp.types.is_value(obj):
|
|
2555
|
-
|
|
2918
|
+
constants[node.id] = obj
|
|
2556
2919
|
|
|
2557
2920
|
elif isinstance(node, ast.Attribute):
|
|
2558
2921
|
obj, path = adj.resolve_static_expression(node, eval_types=False)
|
|
2559
|
-
|
|
2560
2922
|
if warp.types.is_value(obj):
|
|
2561
|
-
|
|
2923
|
+
constants[".".join(path)] = obj
|
|
2924
|
+
|
|
2925
|
+
elif isinstance(node, ast.Call):
|
|
2926
|
+
func, _ = adj.resolve_static_expression(node.func, eval_types=False)
|
|
2927
|
+
if isinstance(func, warp.context.Function) and not func.is_builtin():
|
|
2928
|
+
# calling user-defined function
|
|
2929
|
+
functions[func] = None
|
|
2930
|
+
elif isinstance(func, Struct):
|
|
2931
|
+
# calling struct constructor
|
|
2932
|
+
types[func] = None
|
|
2933
|
+
elif isinstance(func, type) and warp.types.type_is_value(func):
|
|
2934
|
+
# calling value type constructor
|
|
2935
|
+
types[func] = None
|
|
2562
2936
|
|
|
2563
2937
|
elif isinstance(node, ast.Assign):
|
|
2564
2938
|
# Add the LHS names to the local_variables so we know any subsequent uses are shadowed
|
|
2565
2939
|
lhs = node.targets[0]
|
|
2566
|
-
|
|
2567
2940
|
if isinstance(lhs, ast.Tuple):
|
|
2568
2941
|
for v in lhs.elts:
|
|
2569
2942
|
if isinstance(v, ast.Name):
|
|
@@ -2571,7 +2944,7 @@ class Adjoint:
|
|
|
2571
2944
|
elif isinstance(lhs, ast.Name):
|
|
2572
2945
|
local_variables.add(lhs.id)
|
|
2573
2946
|
|
|
2574
|
-
return
|
|
2947
|
+
return constants, types, functions
|
|
2575
2948
|
|
|
2576
2949
|
|
|
2577
2950
|
# ----------------
|
|
@@ -2817,6 +3190,15 @@ def constant_str(value):
|
|
|
2817
3190
|
# make sure we emit the value of objects, e.g. uint32
|
|
2818
3191
|
return str(value.value)
|
|
2819
3192
|
|
|
3193
|
+
elif issubclass(value_type, warp.codegen.StructInstance):
|
|
3194
|
+
# constant struct instance
|
|
3195
|
+
arg_strs = []
|
|
3196
|
+
for key, var in value._cls.vars.items():
|
|
3197
|
+
attr = getattr(value, key)
|
|
3198
|
+
arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
|
|
3199
|
+
arg_str = ", ".join(arg_strs)
|
|
3200
|
+
return f"{value.native_name}({arg_str})"
|
|
3201
|
+
|
|
2820
3202
|
elif value == math.inf:
|
|
2821
3203
|
return "INFINITY"
|
|
2822
3204
|
|
|
@@ -2845,7 +3227,7 @@ def make_full_qualified_name(func):
|
|
|
2845
3227
|
|
|
2846
3228
|
|
|
2847
3229
|
def codegen_struct(struct, device="cpu", indent_size=4):
|
|
2848
|
-
name =
|
|
3230
|
+
name = struct.native_name
|
|
2849
3231
|
|
|
2850
3232
|
body = []
|
|
2851
3233
|
indent_block = " " * indent_size
|