warp-lang 1.3.3__py3-none-manylinux2014_x86_64.whl → 1.4.1__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +103 -3
  6. warp/codegen.py +447 -53
  7. warp/config.py +1 -1
  8. warp/context.py +682 -405
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +25 -5
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +118 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +4 -2
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -0
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +52 -1
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/range.h +11 -2
  45. warp/native/spatial.h +6 -0
  46. warp/native/vec.h +74 -0
  47. warp/native/warp.cpp +2 -1
  48. warp/native/warp.cu +10 -3
  49. warp/native/warp.h +8 -1
  50. warp/paddle.py +382 -0
  51. warp/sim/__init__.py +1 -0
  52. warp/sim/collide.py +519 -0
  53. warp/sim/integrator_euler.py +18 -5
  54. warp/sim/integrator_featherstone.py +5 -5
  55. warp/sim/integrator_vbd.py +1026 -0
  56. warp/sim/integrator_xpbd.py +2 -6
  57. warp/sim/model.py +50 -25
  58. warp/sparse.py +9 -7
  59. warp/stubs.py +459 -0
  60. warp/tape.py +2 -0
  61. warp/tests/aux_test_dependent.py +1 -0
  62. warp/tests/aux_test_name_clash1.py +32 -0
  63. warp/tests/aux_test_name_clash2.py +32 -0
  64. warp/tests/aux_test_square.py +1 -0
  65. warp/tests/test_array.py +188 -0
  66. warp/tests/test_async.py +3 -3
  67. warp/tests/test_atomic.py +6 -0
  68. warp/tests/test_closest_point_edge_edge.py +93 -1
  69. warp/tests/test_codegen.py +93 -15
  70. warp/tests/test_codegen_instancing.py +1457 -0
  71. warp/tests/test_collision.py +486 -0
  72. warp/tests/test_compile_consts.py +3 -28
  73. warp/tests/test_dlpack.py +170 -0
  74. warp/tests/test_examples.py +22 -8
  75. warp/tests/test_fast_math.py +10 -4
  76. warp/tests/test_fem.py +81 -1
  77. warp/tests/test_func.py +46 -0
  78. warp/tests/test_implicit_init.py +49 -0
  79. warp/tests/test_jax.py +58 -0
  80. warp/tests/test_mat.py +84 -0
  81. warp/tests/test_mesh_query_point.py +188 -0
  82. warp/tests/test_model.py +13 -0
  83. warp/tests/test_module_hashing.py +40 -0
  84. warp/tests/test_multigpu.py +3 -3
  85. warp/tests/test_overwrite.py +8 -0
  86. warp/tests/test_paddle.py +852 -0
  87. warp/tests/test_print.py +89 -0
  88. warp/tests/test_quat.py +111 -0
  89. warp/tests/test_reload.py +31 -1
  90. warp/tests/test_scalar_ops.py +2 -0
  91. warp/tests/test_static.py +568 -0
  92. warp/tests/test_streams.py +64 -3
  93. warp/tests/test_struct.py +4 -4
  94. warp/tests/test_torch.py +24 -0
  95. warp/tests/test_triangle_closest_point.py +137 -0
  96. warp/tests/test_types.py +1 -1
  97. warp/tests/test_vbd.py +386 -0
  98. warp/tests/test_vec.py +143 -0
  99. warp/tests/test_vec_scalar_ops.py +139 -0
  100. warp/tests/unittest_suites.py +12 -0
  101. warp/tests/unittest_utils.py +9 -5
  102. warp/thirdparty/dlpack.py +3 -1
  103. warp/types.py +167 -36
  104. warp/utils.py +37 -14
  105. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/METADATA +10 -8
  106. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/RECORD +109 -97
  107. warp/tests/test_point_triangle_closest_point.py +0 -143
  108. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/LICENSE.md +0 -0
  109. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/WHEEL +0 -0
  110. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -11,6 +11,7 @@ import ast
11
11
  import builtins
12
12
  import ctypes
13
13
  import functools
14
+ import hashlib
14
15
  import inspect
15
16
  import math
16
17
  import re
@@ -232,8 +233,11 @@ class StructInstance:
232
233
 
233
234
  def __getattribute__(self, name):
234
235
  cls = super().__getattribute__("_cls")
235
- if name in cls.vars:
236
- var = cls.vars[name]
236
+ if name == "native_name":
237
+ return cls.native_name
238
+
239
+ var = cls.vars.get(name)
240
+ if var is not None:
237
241
  if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
238
242
  # Each field stored in a `StructInstance` is exposed as
239
243
  # a standard Python attribute but also has a `ctypes`
@@ -408,6 +412,9 @@ class Struct:
408
412
  elif issubclass(var.type, ctypes.Array):
409
413
  fields.append((label, var.type))
410
414
  else:
415
+ # HACK: fp16 requires conversion functions from warp.so
416
+ if var.type is warp.float16:
417
+ warp.init()
411
418
  fields.append((label, var.type._type_))
412
419
 
413
420
  class StructType(ctypes.Structure):
@@ -416,15 +423,35 @@ class Struct:
416
423
 
417
424
  self.ctype = StructType
418
425
 
426
+ # Compute the hash. We can cache the hash because it's static, even with nested structs.
427
+ # All field types are specified in the annotations, so they're resolved at declaration time.
428
+ ch = hashlib.sha256()
429
+
430
+ ch.update(bytes(self.key, "utf-8"))
431
+
432
+ for name, type_hint in annotations.items():
433
+ s = f"{name}:{warp.types.get_type_code(type_hint)}"
434
+ ch.update(bytes(s, "utf-8"))
435
+
436
+ # recurse on nested structs
437
+ if isinstance(type_hint, Struct):
438
+ ch.update(type_hint.hash)
439
+
440
+ self.hash = ch.digest()
441
+
442
+ # generate unique identifier for structs in native code
443
+ hash_suffix = f"{self.hash.hex()[:8]}"
444
+ self.native_name = f"{self.key}_{hash_suffix}"
445
+
419
446
  # create default constructor (zero-initialize)
420
447
  self.default_constructor = warp.context.Function(
421
448
  func=None,
422
- key=self.key,
449
+ key=self.native_name,
423
450
  namespace="",
424
451
  value_func=lambda *_: self,
425
452
  input_types={},
426
453
  initializer_list_func=lambda *_: False,
427
- native_func=make_full_qualified_name(self.cls),
454
+ native_func=self.native_name,
428
455
  )
429
456
 
430
457
  # build a constructor that takes each param as a value
@@ -432,12 +459,12 @@ class Struct:
432
459
 
433
460
  self.value_constructor = warp.context.Function(
434
461
  func=None,
435
- key=self.key,
462
+ key=self.native_name,
436
463
  namespace="",
437
464
  value_func=lambda *_: self,
438
465
  input_types=input_types,
439
466
  initializer_list_func=lambda *_: False,
440
- native_func=make_full_qualified_name(self.cls),
467
+ native_func=self.native_name,
441
468
  )
442
469
 
443
470
  self.default_constructor.add_overload(self.value_constructor)
@@ -465,6 +492,10 @@ class Struct:
465
492
  def __init__(inst):
466
493
  StructInstance.__init__(inst, self, None)
467
494
 
495
+ # make sure warp.types.get_type_code works with this StructInstance
496
+ NewStructInstance.cls = self.cls
497
+ NewStructInstance.native_name = self.native_name
498
+
468
499
  return NewStructInstance()
469
500
 
470
501
  def initializer(self):
@@ -599,7 +630,7 @@ class Var:
599
630
  if hasattr(t.dtype, "_wp_generic_type_str_"):
600
631
  dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
601
632
  elif isinstance(t.dtype, Struct):
602
- dtypestr = make_full_qualified_name(t.dtype.cls)
633
+ dtypestr = t.dtype.native_name
603
634
  elif t.dtype.__name__ in ("bool", "int", "float"):
604
635
  dtypestr = t.dtype.__name__
605
636
  else:
@@ -607,7 +638,10 @@ class Var:
607
638
  classstr = f"wp::{type(t).__name__}"
608
639
  return f"{classstr}_t<{dtypestr}>"
609
640
  elif isinstance(t, Struct):
610
- return make_full_qualified_name(t.cls)
641
+ return t.native_name
642
+ elif isinstance(t, type) and issubclass(t, StructInstance):
643
+ # ensure the actual Struct name is used instead of "NewStructInstance"
644
+ return t.native_name
611
645
  elif is_reference(t):
612
646
  if not value_type:
613
647
  return Var.type_to_ctype(t.value_type) + "*"
@@ -743,6 +777,9 @@ def func_match_args(func, arg_types, kwarg_types):
743
777
 
744
778
 
745
779
  def get_arg_type(arg: Union[Var, Any]):
780
+ if isinstance(arg, str):
781
+ return str
782
+
746
783
  if isinstance(arg, Sequence):
747
784
  return tuple(get_arg_type(x) for x in arg)
748
785
 
@@ -863,6 +900,12 @@ class Adjoint:
863
900
  # this is to avoid registering false references to overshadowed modules
864
901
  adj.symbols[name] = arg
865
902
 
903
+ # try to replace static expressions by their constant result if the
904
+ # expression can be evaluated at declaration time
905
+ adj.static_expressions: Dict[str, Any] = {}
906
+ if "static" in adj.source:
907
+ adj.replace_static_expressions()
908
+
866
909
  # There are cases where a same module might be rebuilt multiple times,
867
910
  # for example when kernels are nested inside of functions, or when
868
911
  # a kernel's launch raises an exception. Ideally we'd always want to
@@ -896,6 +939,7 @@ class Adjoint:
896
939
 
897
940
  adj.return_var = None # return type for function or kernel
898
941
  adj.loop_symbols = [] # symbols at the start of each loop
942
+ adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops
899
943
 
900
944
  # blocks
901
945
  adj.blocks = [Block()]
@@ -948,9 +992,9 @@ class Adjoint:
948
992
  if isinstance(a, warp.context.Function):
949
993
  # functions don't have a var_ prefix so strip it off here
950
994
  if prefix == "var":
951
- arg_strs.append(a.key)
995
+ arg_strs.append(a.native_func)
952
996
  else:
953
- arg_strs.append(f"{prefix}_{a.key}")
997
+ arg_strs.append(f"{prefix}_{a.native_func}")
954
998
  elif is_reference(a.type):
955
999
  arg_strs.append(f"{prefix}_{a}")
956
1000
  elif isinstance(a, Var):
@@ -1255,6 +1299,10 @@ class Adjoint:
1255
1299
  if not isinstance(func_arg, (Reference, warp.context.Function)):
1256
1300
  func_arg = adj.load(func_arg)
1257
1301
 
1302
+ # if the argument is a function, build it recursively
1303
+ if isinstance(func_arg, warp.context.Function):
1304
+ adj.builder.build_function(func_arg)
1305
+
1258
1306
  fwd_args.append(strip_reference(func_arg))
1259
1307
 
1260
1308
  if return_type is None:
@@ -1440,6 +1488,7 @@ class Adjoint:
1440
1488
  cond_block.body_forward.append(f"start_{cond_block.label}:;")
1441
1489
 
1442
1490
  c = adj.eval(cond)
1491
+ c = adj.load(c)
1443
1492
 
1444
1493
  cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
1445
1494
 
@@ -1493,6 +1542,9 @@ class Adjoint:
1493
1542
 
1494
1543
  def emit_FunctionDef(adj, node):
1495
1544
  for f in node.body:
1545
+ # Skip variable creation for standalone constants, including docstrings
1546
+ if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
1547
+ continue
1496
1548
  adj.eval(f)
1497
1549
 
1498
1550
  if adj.return_var is not None and len(adj.return_var) == 1:
@@ -1523,6 +1575,16 @@ class Adjoint:
1523
1575
  # eval condition
1524
1576
  cond = adj.eval(node.test)
1525
1577
 
1578
+ if cond.constant is not None:
1579
+ # resolve constant condition
1580
+ if cond.constant:
1581
+ for stmt in node.body:
1582
+ adj.eval(stmt)
1583
+ else:
1584
+ for stmt in node.orelse:
1585
+ adj.eval(stmt)
1586
+ return None
1587
+
1526
1588
  # save symbol map
1527
1589
  symbols_prev = adj.symbols.copy()
1528
1590
 
@@ -1618,7 +1680,7 @@ class Adjoint:
1618
1680
  if isinstance(obj, types.ModuleType):
1619
1681
  return obj
1620
1682
 
1621
- raise RuntimeError("Cannot reference a global variable from a kernel unless `wp.constant()` is being used")
1683
+ raise TypeError(f"Invalid external reference type: {type(obj)}")
1622
1684
 
1623
1685
  @staticmethod
1624
1686
  def resolve_type_attribute(var_type: type, attr: str):
@@ -1732,7 +1794,7 @@ class Adjoint:
1732
1794
 
1733
1795
  def emit_NameConstant(adj, node):
1734
1796
  if node.value:
1735
- return adj.add_constant(True)
1797
+ return adj.add_constant(node.value)
1736
1798
  elif node.value is None:
1737
1799
  raise WarpCodegenTypeError("None type unsupported")
1738
1800
  else:
@@ -1746,7 +1808,7 @@ class Adjoint:
1746
1808
  elif isinstance(node, ast.Ellipsis):
1747
1809
  return adj.emit_Ellipsis(node)
1748
1810
  else:
1749
- assert isinstance(node, ast.NameConstant)
1811
+ assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
1750
1812
  return adj.emit_NameConstant(node)
1751
1813
 
1752
1814
  def emit_BinOp(adj, node):
@@ -1787,6 +1849,11 @@ class Adjoint:
1787
1849
  # detect symbols with conflicting definitions (assigned inside the for loop)
1788
1850
  for items in symbols.items():
1789
1851
  sym = items[0]
1852
+ if adj.is_constant_iter_symbol(sym):
1853
+ # ignore constant overwriting in for-loops if it is a loop iterator
1854
+ # (it is no problem to unroll static loops multiple times in sequence)
1855
+ continue
1856
+
1790
1857
  var1 = items[1]
1791
1858
  var2 = adj.symbols[sym]
1792
1859
 
@@ -1933,15 +2000,36 @@ class Adjoint:
1933
2000
  )
1934
2001
  return range_call
1935
2002
 
2003
+ def begin_record_constant_iter_symbols(adj):
2004
+ if len(adj.loop_const_iter_symbols) > 0:
2005
+ adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1])
2006
+ else:
2007
+ adj.loop_const_iter_symbols.append(set())
2008
+
2009
+ def end_record_constant_iter_symbols(adj):
2010
+ if len(adj.loop_const_iter_symbols) > 0:
2011
+ adj.loop_const_iter_symbols.pop()
2012
+
2013
+ def record_constant_iter_symbol(adj, sym):
2014
+ if len(adj.loop_const_iter_symbols) > 0:
2015
+ adj.loop_const_iter_symbols[-1].add(sym)
2016
+
2017
+ def is_constant_iter_symbol(adj, sym):
2018
+ return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1]
2019
+
1936
2020
  def emit_For(adj, node):
1937
2021
  # try and unroll simple range() statements that use constant args
1938
2022
  unroll_range = adj.get_unroll_range(node)
1939
2023
 
1940
2024
  if isinstance(unroll_range, range):
2025
+ const_iter_sym = node.target.id
2026
+ # prevent constant conflicts in `materialize_redefinitions()`
2027
+ adj.record_constant_iter_symbol(const_iter_sym)
2028
+
2029
+ # unroll static for-loop
1941
2030
  for i in unroll_range:
1942
2031
  const_iter = adj.add_constant(i)
1943
- var_iter = adj.add_builtin_call("int", [const_iter])
1944
- adj.symbols[node.target.id] = var_iter
2032
+ adj.symbols[const_iter_sym] = const_iter
1945
2033
 
1946
2034
  # eval body
1947
2035
  for s in node.body:
@@ -1957,6 +2045,7 @@ class Adjoint:
1957
2045
  iter = adj.eval(node.iter)
1958
2046
 
1959
2047
  adj.symbols[node.target.id] = adj.begin_for(iter)
2048
+ adj.begin_record_constant_iter_symbols()
1960
2049
 
1961
2050
  # for loops should be side-effect free, here we store a copy
1962
2051
  adj.loop_symbols.append(adj.symbols.copy())
@@ -1967,6 +2056,7 @@ class Adjoint:
1967
2056
 
1968
2057
  adj.materialize_redefinitions(adj.loop_symbols[-1])
1969
2058
  adj.loop_symbols.pop()
2059
+ adj.end_record_constant_iter_symbols()
1970
2060
 
1971
2061
  adj.end_for(iter)
1972
2062
 
@@ -2023,13 +2113,28 @@ class Adjoint:
2023
2113
 
2024
2114
  # try and lookup function in globals by
2025
2115
  # resolving path (e.g.: module.submodule.attr)
2026
- func, path = adj.resolve_static_expression(node.func)
2116
+ if hasattr(node.func, "warp_func"):
2117
+ func = node.func.warp_func
2118
+ path = []
2119
+ else:
2120
+ func, path = adj.resolve_static_expression(node.func)
2027
2121
  if func is None:
2028
2122
  func = adj.eval(node.func)
2029
2123
 
2124
+ if adj.is_static_expression(func):
2125
+ # try to evaluate wp.static() expressions
2126
+ obj, _ = adj.evaluate_static_expression(node)
2127
+ if obj is not None:
2128
+ if isinstance(obj, warp.context.Function):
2129
+ # special handling for wp.static() evaluating to a function
2130
+ return obj
2131
+ else:
2132
+ out = adj.add_constant(obj)
2133
+ return out
2134
+
2030
2135
  type_args = {}
2031
2136
 
2032
- if not isinstance(func, warp.context.Function):
2137
+ if len(path) > 0 and not isinstance(func, warp.context.Function):
2033
2138
  attr = path[-1]
2034
2139
  caller = func
2035
2140
  func = None
@@ -2083,6 +2188,9 @@ class Adjoint:
2083
2188
  args = tuple(adj.resolve_arg(x) for x in node.args)
2084
2189
  kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
2085
2190
 
2191
+ # add the call and build the callee adjoint if needed (func.adj)
2192
+ out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2193
+
2086
2194
  if warp.config.verify_autograd_array_access:
2087
2195
  # update arg read/write states according to what happens to that arg in the called function
2088
2196
  if hasattr(func, "adj"):
@@ -2095,7 +2203,6 @@ class Adjoint:
2095
2203
  if func.adj.args[i].is_read:
2096
2204
  arg.mark_read()
2097
2205
 
2098
- out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2099
2206
  return out
2100
2207
 
2101
2208
  def emit_Index(adj, node):
@@ -2281,20 +2388,40 @@ class Adjoint:
2281
2388
  target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2282
2389
 
2283
2390
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2391
+ # recursively unwind AST, stopping at penultimate node
2392
+ node = lhs
2393
+ while hasattr(node, "value"):
2394
+ if hasattr(node.value, "value"):
2395
+ node = node.value
2396
+ else:
2397
+ break
2398
+ # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
2399
+ if hasattr(node, "attr") and node.attr == "adjoint":
2400
+ attr = adj.add_builtin_call("index", [target, *indices])
2401
+ adj.add_builtin_call("store", [attr, rhs])
2402
+ return
2403
+
2404
+ # TODO: array vec component case
2284
2405
  if is_reference(target.type):
2285
2406
  attr = adj.add_builtin_call("indexref", [target, *indices])
2286
- else:
2287
- attr = adj.add_builtin_call("index", [target, *indices])
2407
+ adj.add_builtin_call("store", [attr, rhs])
2288
2408
 
2289
- adj.add_builtin_call("store", [attr, rhs])
2409
+ if warp.config.verbose and not adj.custom_reverse_mode:
2410
+ lineno = adj.lineno + adj.fun_lineno
2411
+ line = adj.source_lines[adj.lineno]
2412
+ node_source = adj.get_node_source(lhs.value)
2413
+ print(
2414
+ f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2415
+ )
2290
2416
 
2291
- if warp.config.verbose and not adj.custom_reverse_mode:
2292
- lineno = adj.lineno + adj.fun_lineno
2293
- line = adj.source_lines[adj.lineno]
2294
- node_source = adj.get_node_source(lhs.value)
2295
- print(
2296
- f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2297
- )
2417
+ else:
2418
+ out = adj.add_builtin_call("assign", [target, *indices, rhs])
2419
+
2420
+ # re-point target symbol to out var
2421
+ for id in adj.symbols:
2422
+ if adj.symbols[id] == target:
2423
+ adj.symbols[id] = out
2424
+ break
2298
2425
 
2299
2426
  else:
2300
2427
  raise WarpCodegenError(
@@ -2329,16 +2456,24 @@ class Adjoint:
2329
2456
  aggregate = adj.eval(lhs.value)
2330
2457
  aggregate_type = strip_reference(aggregate.type)
2331
2458
 
2332
- # assigning to a vector component
2333
- if type_is_vector(aggregate_type):
2459
+ # assigning to a vector or quaternion component
2460
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
2461
+ # TODO: handle wp.adjoint case
2462
+
2334
2463
  index = adj.vector_component_index(lhs.attr, aggregate_type)
2335
2464
 
2465
+ # TODO: array vec component case
2336
2466
  if is_reference(aggregate.type):
2337
2467
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2468
+ adj.add_builtin_call("store", [attr, rhs])
2338
2469
  else:
2339
- attr = adj.add_builtin_call("index", [aggregate, index])
2470
+ out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2340
2471
 
2341
- adj.add_builtin_call("store", [attr, rhs])
2472
+ # re-point target symbol to out var
2473
+ for id in adj.symbols:
2474
+ if adj.symbols[id] == aggregate:
2475
+ adj.symbols[id] = out
2476
+ break
2342
2477
 
2343
2478
  else:
2344
2479
  attr = adj.emit_Attribute(lhs)
@@ -2382,9 +2517,66 @@ class Adjoint:
2382
2517
  adj.add_return(adj.return_var)
2383
2518
 
2384
2519
  def emit_AugAssign(adj, node):
2385
- # replace augmented assignment with assignment statement + binary op
2386
- new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
2387
- adj.eval(new_node)
2520
+ lhs = node.target
2521
+
2522
+ # replace augmented assignment with assignment statement + binary op (default behaviour)
2523
+ def make_new_assign_statement():
2524
+ new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
2525
+ adj.eval(new_node)
2526
+
2527
+ if isinstance(lhs, ast.Subscript):
2528
+ rhs = adj.eval(node.value)
2529
+
2530
+ # wp.adjoint[var] appears in custom grad functions, and does not require
2531
+ # special consideration in the AugAssign case
2532
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
2533
+ make_new_assign_statement()
2534
+ return
2535
+
2536
+ target, indices = adj.eval_subscript(lhs)
2537
+
2538
+ target_type = strip_reference(target.type)
2539
+
2540
+ if is_array(target_type):
2541
+ # target_type is not suitable for atomic array accumulation
2542
+ if target_type.dtype not in warp.types.atomic_types:
2543
+ make_new_assign_statement()
2544
+ return
2545
+
2546
+ kernel_name = adj.fun_name
2547
+ filename = adj.filename
2548
+ lineno = adj.lineno + adj.fun_lineno
2549
+
2550
+ if isinstance(node.op, ast.Add):
2551
+ adj.add_builtin_call("atomic_add", [target, *indices, rhs])
2552
+
2553
+ if warp.config.verify_autograd_array_access:
2554
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2555
+
2556
+ elif isinstance(node.op, ast.Sub):
2557
+ adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
2558
+
2559
+ if warp.config.verify_autograd_array_access:
2560
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2561
+ else:
2562
+ print(f"Warning: in-place op {node.op} is not differentiable")
2563
+
2564
+ # TODO
2565
+ elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2566
+ make_new_assign_statement()
2567
+ return
2568
+
2569
+ else:
2570
+ raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
2571
+
2572
+ # TODO
2573
+ elif isinstance(lhs, ast.Attribute):
2574
+ make_new_assign_statement()
2575
+ return
2576
+
2577
+ else:
2578
+ make_new_assign_statement()
2579
+ return
2388
2580
 
2389
2581
  def emit_Tuple(adj, node):
2390
2582
  # LHS for expressions, such as i, j, k = 1, 2, 3
@@ -2445,9 +2637,6 @@ class Adjoint:
2445
2637
  if path[0] in adj.symbols:
2446
2638
  return None
2447
2639
 
2448
- if path[0] in __builtins__:
2449
- return __builtins__[path[0]]
2450
-
2451
2640
  # look up in closure/global variables
2452
2641
  expr = adj.resolve_external_reference(path[0])
2453
2642
 
@@ -2455,13 +2644,201 @@ class Adjoint:
2455
2644
  if expr is None:
2456
2645
  expr = getattr(warp, path[0], None)
2457
2646
 
2458
- if expr:
2647
+ # look up in builtins
2648
+ if expr is None:
2649
+ expr = __builtins__.get(path[0])
2650
+
2651
+ if expr is not None:
2459
2652
  for i in range(1, len(path)):
2460
2653
  if hasattr(expr, path[i]):
2461
2654
  expr = getattr(expr, path[i])
2462
2655
 
2463
2656
  return expr
2464
2657
 
2658
+ # retrieves a dictionary of all closure and global variables and their values
2659
+ # to be used in the evaluation context of wp.static() expressions
2660
+ def get_static_evaluation_context(adj):
2661
+ closure_vars = dict(
2662
+ zip(
2663
+ adj.func.__code__.co_freevars,
2664
+ [c.cell_contents for c in (adj.func.__closure__ or [])],
2665
+ )
2666
+ )
2667
+
2668
+ vars_dict = {}
2669
+ vars_dict.update(adj.func.__globals__)
2670
+ # variables captured in closure have precedence over global vars
2671
+ vars_dict.update(closure_vars)
2672
+
2673
+ return vars_dict
2674
+
2675
+ def is_static_expression(adj, func):
2676
+ return (
2677
+ isinstance(func, types.FunctionType)
2678
+ and func.__module__ == "warp.builtins"
2679
+ and func.__qualname__ == "static"
2680
+ )
2681
+
2682
+ # verify the return type of a wp.static() expression is supported inside a Warp kernel
2683
+ def verify_static_return_value(adj, value):
2684
+ if value is None:
2685
+ raise ValueError("None is returned")
2686
+ if warp.types.is_value(value):
2687
+ return True
2688
+ if warp.types.is_array(value):
2689
+ # more useful explanation for the common case of creating a Warp array
2690
+ raise ValueError("a Warp array cannot be created inside Warp kernels")
2691
+ if isinstance(value, str):
2692
+ # we want to support cases such as `print(wp.static("test"))`
2693
+ return True
2694
+ if isinstance(value, warp.context.Function):
2695
+ return True
2696
+
2697
+ def verify_struct(s: StructInstance, attr_path: List[str]):
2698
+ for key in s._cls.vars.keys():
2699
+ v = getattr(s, key)
2700
+ if issubclass(type(v), StructInstance):
2701
+ verify_struct(v, attr_path + [key])
2702
+ else:
2703
+ try:
2704
+ adj.verify_static_return_value(v)
2705
+ except ValueError as e:
2706
+ raise ValueError(
2707
+ f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
2708
+ ) from e
2709
+
2710
+ if issubclass(type(value), StructInstance):
2711
+ return verify_struct(value, [])
2712
+
2713
+ raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
2714
+
2715
+ # find the source code string of an AST node
2716
+ def extract_node_source(adj, node) -> Optional[str]:
2717
+ if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
2718
+ return None
2719
+
2720
+ start_line = node.lineno - 1 # line numbers start at 1
2721
+ start_col = node.col_offset
2722
+
2723
+ if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
2724
+ end_line = node.end_lineno - 1
2725
+ end_col = node.end_col_offset
2726
+ else:
2727
+ # fallback for Python versions before 3.8
2728
+ # we have to find the end line and column manually
2729
+ end_line = start_line
2730
+ end_col = start_col
2731
+ parenthesis_count = 1
2732
+ for lineno in range(start_line, len(adj.source_lines)):
2733
+ if lineno == start_line:
2734
+ c_start = start_col
2735
+ else:
2736
+ c_start = 0
2737
+ line = adj.source_lines[lineno]
2738
+ for i in range(c_start, len(line)):
2739
+ c = line[i]
2740
+ if c == "(":
2741
+ parenthesis_count += 1
2742
+ elif c == ")":
2743
+ parenthesis_count -= 1
2744
+ if parenthesis_count == 0:
2745
+ end_col = i
2746
+ end_line = lineno
2747
+ break
2748
+ if parenthesis_count == 0:
2749
+ break
2750
+
2751
+ if start_line == end_line:
2752
+ # single-line expression
2753
+ return adj.source_lines[start_line][start_col:end_col]
2754
+ else:
2755
+ # multi-line expression
2756
+ lines = []
2757
+ # first line (from start_col to the end)
2758
+ lines.append(adj.source_lines[start_line][start_col:])
2759
+ # middle lines (entire lines)
2760
+ lines.extend(adj.source_lines[start_line + 1 : end_line])
2761
+ # last line (from the start to end_col)
2762
+ lines.append(adj.source_lines[end_line][:end_col])
2763
+ return "\n".join(lines).strip()
2764
+
2765
+ # handles a wp.static() expression and returns the resulting object and a string representing the code
2766
+ # of the static expression
2767
+ def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
2768
+ if len(node.args) == 1:
2769
+ static_code = adj.extract_node_source(node.args[0])
2770
+ elif len(node.keywords) == 1:
2771
+ static_code = adj.extract_node_source(node.keywords[0])
2772
+ else:
2773
+ raise WarpCodegenError("warp.static() requires a single argument or keyword")
2774
+ if static_code is None:
2775
+ raise WarpCodegenError("Error extracting source code from wp.static() expression")
2776
+
2777
+ vars_dict = adj.get_static_evaluation_context()
2778
+ # add constant variables to the static call context
2779
+ constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
2780
+ vars_dict.update(constant_vars)
2781
+
2782
+ try:
2783
+ value = eval(static_code, vars_dict)
2784
+ if warp.config.verbose:
2785
+ print(f"Evaluated static command: {static_code} = {value}")
2786
+ except NameError as e:
2787
+ raise WarpCodegenError(
2788
+ f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
2789
+ ) from e
2790
+ except Exception as e:
2791
+ raise WarpCodegenError(
2792
+ f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
2793
+ ) from e
2794
+
2795
+ try:
2796
+ adj.verify_static_return_value(value)
2797
+ except ValueError as e:
2798
+ raise WarpCodegenError(
2799
+ f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
2800
+ ) from e
2801
+
2802
+ return value, static_code
2803
+
2804
+ # try to replace wp.static() expressions by their evaluated value if the
2805
+ # expression can be evaluated
2806
+ def replace_static_expressions(adj):
2807
+ class StaticExpressionReplacer(ast.NodeTransformer):
2808
+ def visit_Call(self, node):
2809
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
2810
+ if adj.is_static_expression(func):
2811
+ try:
2812
+ # the static expression will execute as long as the static expression is valid and
2813
+ # only depends on global or captured variables
2814
+ obj, code = adj.evaluate_static_expression(node)
2815
+ if code is not None:
2816
+ adj.static_expressions[code] = obj
2817
+ if isinstance(obj, warp.context.Function):
2818
+ name_node = ast.Name("__warp_func__")
2819
+ # we add a pointer to the Warp function here so that we can refer to it later at
2820
+ # codegen time (note that the function key itself is not sufficient to uniquely
2821
+ # identify the function, as the function may be redefined between the current time
2822
+ # of wp.static() declaration and the time of codegen during module building)
2823
+ name_node.warp_func = obj
2824
+ return ast.copy_location(name_node, node)
2825
+ else:
2826
+ return ast.copy_location(ast.Constant(value=obj), node)
2827
+ except Exception:
2828
+ # Ignoring failing static expressions should generally not be an issue because only
2829
+ # one of these cases should be possible:
2830
+ # 1) the static expression itself is invalid code, in which case the module cannot be
2831
+ # built all,
2832
+ # 2) the static expression contains a reference to a local (even if constant) variable
2833
+ # (and is therefore not executable and raises this exception), in which
2834
+ # case changing the constant, or the code affecting this constant, would lead to
2835
+ # a different module hash anyway.
2836
+ pass
2837
+
2838
+ return self.generic_visit(node)
2839
+
2840
+ adj.tree = StaticExpressionReplacer().visit(adj.tree)
2841
+
2465
2842
  # Evaluates a static expression that does not depend on runtime values
2466
2843
  # if eval_types is True, try resolving the path using evaluated type information as well
2467
2844
  def resolve_static_expression(adj, root_node, eval_types=True):
@@ -2536,34 +2913,42 @@ class Adjoint:
2536
2913
  # return the Python code corresponding to the given AST node
2537
2914
  return ast.get_source_segment(adj.source, node)
2538
2915
 
2539
- def get_constant_references(adj) -> Dict[str, Any]:
2540
- """Traverses ``adj.tree`` and returns a dictionary containing constant variable names and values.
2541
-
2542
- This function is meant to be used to populate a module's constants dictionary, which then feeds
2543
- into the computation of the module's ``content_hash``.
2544
- """
2916
+ def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
2917
+ """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
2545
2918
 
2546
2919
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
2547
- constants_dict = {}
2920
+
2921
+ constants = {}
2922
+ types = {}
2923
+ functions = {}
2548
2924
 
2549
2925
  for node in ast.walk(adj.tree):
2550
2926
  if isinstance(node, ast.Name) and node.id not in local_variables:
2551
2927
  # look up in closure/global variables
2552
2928
  obj = adj.resolve_external_reference(node.id)
2553
-
2554
2929
  if warp.types.is_value(obj):
2555
- constants_dict[node.id] = obj
2930
+ constants[node.id] = obj
2556
2931
 
2557
2932
  elif isinstance(node, ast.Attribute):
2558
2933
  obj, path = adj.resolve_static_expression(node, eval_types=False)
2559
-
2560
2934
  if warp.types.is_value(obj):
2561
- constants_dict[".".join(path)] = obj
2935
+ constants[".".join(path)] = obj
2936
+
2937
+ elif isinstance(node, ast.Call):
2938
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
2939
+ if isinstance(func, warp.context.Function) and not func.is_builtin():
2940
+ # calling user-defined function
2941
+ functions[func] = None
2942
+ elif isinstance(func, Struct):
2943
+ # calling struct constructor
2944
+ types[func] = None
2945
+ elif isinstance(func, type) and warp.types.type_is_value(func):
2946
+ # calling value type constructor
2947
+ types[func] = None
2562
2948
 
2563
2949
  elif isinstance(node, ast.Assign):
2564
2950
  # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
2565
2951
  lhs = node.targets[0]
2566
-
2567
2952
  if isinstance(lhs, ast.Tuple):
2568
2953
  for v in lhs.elts:
2569
2954
  if isinstance(v, ast.Name):
@@ -2571,7 +2956,7 @@ class Adjoint:
2571
2956
  elif isinstance(lhs, ast.Name):
2572
2957
  local_variables.add(lhs.id)
2573
2958
 
2574
- return constants_dict
2959
+ return constants, types, functions
2575
2960
 
2576
2961
 
2577
2962
  # ----------------
@@ -2817,6 +3202,15 @@ def constant_str(value):
2817
3202
  # make sure we emit the value of objects, e.g. uint32
2818
3203
  return str(value.value)
2819
3204
 
3205
+ elif issubclass(value_type, warp.codegen.StructInstance):
3206
+ # constant struct instance
3207
+ arg_strs = []
3208
+ for key, var in value._cls.vars.items():
3209
+ attr = getattr(value, key)
3210
+ arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
3211
+ arg_str = ", ".join(arg_strs)
3212
+ return f"{value.native_name}({arg_str})"
3213
+
2820
3214
  elif value == math.inf:
2821
3215
  return "INFINITY"
2822
3216
 
@@ -2845,7 +3239,7 @@ def make_full_qualified_name(func):
2845
3239
 
2846
3240
 
2847
3241
  def codegen_struct(struct, device="cpu", indent_size=4):
2848
- name = make_full_qualified_name(struct.cls)
3242
+ name = struct.native_name
2849
3243
 
2850
3244
  body = []
2851
3245
  indent_block = " " * indent_size