warp-lang 1.3.2__py3-none-manylinux2014_aarch64.whl → 1.4.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (107)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +126 -4
  6. warp/codegen.py +435 -53
  7. warp/config.py +1 -1
  8. warp/context.py +678 -403
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +23 -4
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +113 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +2 -0
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -17
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +51 -0
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/spatial.h +6 -0
  45. warp/native/vec.h +74 -0
  46. warp/native/warp.cpp +2 -1
  47. warp/native/warp.cu +10 -3
  48. warp/native/warp.h +8 -1
  49. warp/paddle.py +382 -0
  50. warp/sim/__init__.py +1 -0
  51. warp/sim/collide.py +519 -0
  52. warp/sim/integrator_euler.py +18 -5
  53. warp/sim/integrator_featherstone.py +5 -5
  54. warp/sim/integrator_vbd.py +1026 -0
  55. warp/sim/model.py +49 -23
  56. warp/stubs.py +459 -0
  57. warp/tape.py +2 -0
  58. warp/tests/aux_test_dependent.py +1 -0
  59. warp/tests/aux_test_name_clash1.py +32 -0
  60. warp/tests/aux_test_name_clash2.py +32 -0
  61. warp/tests/aux_test_square.py +1 -0
  62. warp/tests/test_array.py +222 -0
  63. warp/tests/test_async.py +3 -3
  64. warp/tests/test_atomic.py +6 -0
  65. warp/tests/test_closest_point_edge_edge.py +93 -1
  66. warp/tests/test_codegen.py +62 -15
  67. warp/tests/test_codegen_instancing.py +1457 -0
  68. warp/tests/test_collision.py +486 -0
  69. warp/tests/test_compile_consts.py +3 -28
  70. warp/tests/test_dlpack.py +170 -0
  71. warp/tests/test_examples.py +22 -8
  72. warp/tests/test_fast_math.py +10 -4
  73. warp/tests/test_fem.py +64 -0
  74. warp/tests/test_func.py +46 -0
  75. warp/tests/test_implicit_init.py +49 -0
  76. warp/tests/test_jax.py +58 -0
  77. warp/tests/test_mat.py +84 -0
  78. warp/tests/test_mesh_query_point.py +188 -0
  79. warp/tests/test_module_hashing.py +40 -0
  80. warp/tests/test_multigpu.py +3 -3
  81. warp/tests/test_overwrite.py +8 -0
  82. warp/tests/test_paddle.py +852 -0
  83. warp/tests/test_print.py +89 -0
  84. warp/tests/test_quat.py +111 -0
  85. warp/tests/test_reload.py +31 -1
  86. warp/tests/test_scalar_ops.py +2 -0
  87. warp/tests/test_static.py +412 -0
  88. warp/tests/test_streams.py +64 -3
  89. warp/tests/test_struct.py +4 -4
  90. warp/tests/test_torch.py +24 -0
  91. warp/tests/test_triangle_closest_point.py +137 -0
  92. warp/tests/test_types.py +1 -1
  93. warp/tests/test_vbd.py +386 -0
  94. warp/tests/test_vec.py +143 -0
  95. warp/tests/test_vec_scalar_ops.py +139 -0
  96. warp/tests/test_volume.py +30 -0
  97. warp/tests/unittest_suites.py +12 -0
  98. warp/tests/unittest_utils.py +9 -5
  99. warp/thirdparty/dlpack.py +3 -1
  100. warp/types.py +157 -34
  101. warp/utils.py +37 -14
  102. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
  103. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +106 -94
  104. warp/tests/test_point_triangle_closest_point.py +0 -143
  105. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
  106. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
  107. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -11,6 +11,7 @@ import ast
11
11
  import builtins
12
12
  import ctypes
13
13
  import functools
14
+ import hashlib
14
15
  import inspect
15
16
  import math
16
17
  import re
@@ -232,8 +233,11 @@ class StructInstance:
232
233
 
233
234
  def __getattribute__(self, name):
234
235
  cls = super().__getattribute__("_cls")
235
- if name in cls.vars:
236
- var = cls.vars[name]
236
+ if name == "native_name":
237
+ return cls.native_name
238
+
239
+ var = cls.vars.get(name)
240
+ if var is not None:
237
241
  if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
238
242
  # Each field stored in a `StructInstance` is exposed as
239
243
  # a standard Python attribute but also has a `ctypes`
@@ -408,6 +412,9 @@ class Struct:
408
412
  elif issubclass(var.type, ctypes.Array):
409
413
  fields.append((label, var.type))
410
414
  else:
415
+ # HACK: fp16 requires conversion functions from warp.so
416
+ if var.type is warp.float16:
417
+ warp.init()
411
418
  fields.append((label, var.type._type_))
412
419
 
413
420
  class StructType(ctypes.Structure):
@@ -416,15 +423,35 @@ class Struct:
416
423
 
417
424
  self.ctype = StructType
418
425
 
426
+ # Compute the hash. We can cache the hash because it's static, even with nested structs.
427
+ # All field types are specified in the annotations, so they're resolved at declaration time.
428
+ ch = hashlib.sha256()
429
+
430
+ ch.update(bytes(self.key, "utf-8"))
431
+
432
+ for name, type_hint in annotations.items():
433
+ s = f"{name}:{warp.types.get_type_code(type_hint)}"
434
+ ch.update(bytes(s, "utf-8"))
435
+
436
+ # recurse on nested structs
437
+ if isinstance(type_hint, Struct):
438
+ ch.update(type_hint.hash)
439
+
440
+ self.hash = ch.digest()
441
+
442
+ # generate unique identifier for structs in native code
443
+ hash_suffix = f"{self.hash.hex()[:8]}"
444
+ self.native_name = f"{self.key}_{hash_suffix}"
445
+
419
446
  # create default constructor (zero-initialize)
420
447
  self.default_constructor = warp.context.Function(
421
448
  func=None,
422
- key=self.key,
449
+ key=self.native_name,
423
450
  namespace="",
424
451
  value_func=lambda *_: self,
425
452
  input_types={},
426
453
  initializer_list_func=lambda *_: False,
427
- native_func=make_full_qualified_name(self.cls),
454
+ native_func=self.native_name,
428
455
  )
429
456
 
430
457
  # build a constructor that takes each param as a value
@@ -432,12 +459,12 @@ class Struct:
432
459
 
433
460
  self.value_constructor = warp.context.Function(
434
461
  func=None,
435
- key=self.key,
462
+ key=self.native_name,
436
463
  namespace="",
437
464
  value_func=lambda *_: self,
438
465
  input_types=input_types,
439
466
  initializer_list_func=lambda *_: False,
440
- native_func=make_full_qualified_name(self.cls),
467
+ native_func=self.native_name,
441
468
  )
442
469
 
443
470
  self.default_constructor.add_overload(self.value_constructor)
@@ -465,6 +492,10 @@ class Struct:
465
492
  def __init__(inst):
466
493
  StructInstance.__init__(inst, self, None)
467
494
 
495
+ # make sure warp.types.get_type_code works with this StructInstance
496
+ NewStructInstance.cls = self.cls
497
+ NewStructInstance.native_name = self.native_name
498
+
468
499
  return NewStructInstance()
469
500
 
470
501
  def initializer(self):
@@ -599,7 +630,7 @@ class Var:
599
630
  if hasattr(t.dtype, "_wp_generic_type_str_"):
600
631
  dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
601
632
  elif isinstance(t.dtype, Struct):
602
- dtypestr = make_full_qualified_name(t.dtype.cls)
633
+ dtypestr = t.dtype.native_name
603
634
  elif t.dtype.__name__ in ("bool", "int", "float"):
604
635
  dtypestr = t.dtype.__name__
605
636
  else:
@@ -607,7 +638,10 @@ class Var:
607
638
  classstr = f"wp::{type(t).__name__}"
608
639
  return f"{classstr}_t<{dtypestr}>"
609
640
  elif isinstance(t, Struct):
610
- return make_full_qualified_name(t.cls)
641
+ return t.native_name
642
+ elif isinstance(t, type) and issubclass(t, StructInstance):
643
+ # ensure the actual Struct name is used instead of "NewStructInstance"
644
+ return t.native_name
611
645
  elif is_reference(t):
612
646
  if not value_type:
613
647
  return Var.type_to_ctype(t.value_type) + "*"
@@ -863,6 +897,12 @@ class Adjoint:
863
897
  # this is to avoid registering false references to overshadowed modules
864
898
  adj.symbols[name] = arg
865
899
 
900
+ # try to replace static expressions by their constant result if the
901
+ # expression can be evaluated at declaration time
902
+ adj.static_expressions: Dict[str, Any] = {}
903
+ if "static" in adj.source:
904
+ adj.replace_static_expressions()
905
+
866
906
  # There are cases where a same module might be rebuilt multiple times,
867
907
  # for example when kernels are nested inside of functions, or when
868
908
  # a kernel's launch raises an exception. Ideally we'd always want to
@@ -896,6 +936,7 @@ class Adjoint:
896
936
 
897
937
  adj.return_var = None # return type for function or kernel
898
938
  adj.loop_symbols = [] # symbols at the start of each loop
939
+ adj.loop_const_iter_symbols = set() # iteration variables (constant) for static loops
899
940
 
900
941
  # blocks
901
942
  adj.blocks = [Block()]
@@ -948,9 +989,9 @@ class Adjoint:
948
989
  if isinstance(a, warp.context.Function):
949
990
  # functions don't have a var_ prefix so strip it off here
950
991
  if prefix == "var":
951
- arg_strs.append(a.key)
992
+ arg_strs.append(a.native_func)
952
993
  else:
953
- arg_strs.append(f"{prefix}_{a.key}")
994
+ arg_strs.append(f"{prefix}_{a.native_func}")
954
995
  elif is_reference(a.type):
955
996
  arg_strs.append(f"{prefix}_{a}")
956
997
  elif isinstance(a, Var):
@@ -1255,6 +1296,10 @@ class Adjoint:
1255
1296
  if not isinstance(func_arg, (Reference, warp.context.Function)):
1256
1297
  func_arg = adj.load(func_arg)
1257
1298
 
1299
+ # if the argument is a function, build it recursively
1300
+ if isinstance(func_arg, warp.context.Function):
1301
+ adj.builder.build_function(func_arg)
1302
+
1258
1303
  fwd_args.append(strip_reference(func_arg))
1259
1304
 
1260
1305
  if return_type is None:
@@ -1440,6 +1485,7 @@ class Adjoint:
1440
1485
  cond_block.body_forward.append(f"start_{cond_block.label}:;")
1441
1486
 
1442
1487
  c = adj.eval(cond)
1488
+ c = adj.load(c)
1443
1489
 
1444
1490
  cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
1445
1491
 
@@ -1493,6 +1539,9 @@ class Adjoint:
1493
1539
 
1494
1540
  def emit_FunctionDef(adj, node):
1495
1541
  for f in node.body:
1542
+ # Skip variable creation for standalone constants, including docstrings
1543
+ if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
1544
+ continue
1496
1545
  adj.eval(f)
1497
1546
 
1498
1547
  if adj.return_var is not None and len(adj.return_var) == 1:
@@ -1523,6 +1572,16 @@ class Adjoint:
1523
1572
  # eval condition
1524
1573
  cond = adj.eval(node.test)
1525
1574
 
1575
+ if cond.constant is not None:
1576
+ # resolve constant condition
1577
+ if cond.constant:
1578
+ for stmt in node.body:
1579
+ adj.eval(stmt)
1580
+ else:
1581
+ for stmt in node.orelse:
1582
+ adj.eval(stmt)
1583
+ return None
1584
+
1526
1585
  # save symbol map
1527
1586
  symbols_prev = adj.symbols.copy()
1528
1587
 
@@ -1618,7 +1677,7 @@ class Adjoint:
1618
1677
  if isinstance(obj, types.ModuleType):
1619
1678
  return obj
1620
1679
 
1621
- raise RuntimeError("Cannot reference a global variable from a kernel unless `wp.constant()` is being used")
1680
+ raise TypeError(f"Invalid external reference type: {type(obj)}")
1622
1681
 
1623
1682
  @staticmethod
1624
1683
  def resolve_type_attribute(var_type: type, attr: str):
@@ -1732,7 +1791,7 @@ class Adjoint:
1732
1791
 
1733
1792
  def emit_NameConstant(adj, node):
1734
1793
  if node.value:
1735
- return adj.add_constant(True)
1794
+ return adj.add_constant(node.value)
1736
1795
  elif node.value is None:
1737
1796
  raise WarpCodegenTypeError("None type unsupported")
1738
1797
  else:
@@ -1746,7 +1805,7 @@ class Adjoint:
1746
1805
  elif isinstance(node, ast.Ellipsis):
1747
1806
  return adj.emit_Ellipsis(node)
1748
1807
  else:
1749
- assert isinstance(node, ast.NameConstant)
1808
+ assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
1750
1809
  return adj.emit_NameConstant(node)
1751
1810
 
1752
1811
  def emit_BinOp(adj, node):
@@ -1787,6 +1846,11 @@ class Adjoint:
1787
1846
  # detect symbols with conflicting definitions (assigned inside the for loop)
1788
1847
  for items in symbols.items():
1789
1848
  sym = items[0]
1849
+ if adj.loop_const_iter_symbols is not None and sym in adj.loop_const_iter_symbols:
1850
+ # ignore constant overwriting in for-loops if it is a loop iterator
1851
+ # (it is no problem to unroll static loops multiple times in sequence)
1852
+ continue
1853
+
1790
1854
  var1 = items[1]
1791
1855
  var2 = adj.symbols[sym]
1792
1856
 
@@ -1933,15 +1997,27 @@ class Adjoint:
1933
1997
  )
1934
1998
  return range_call
1935
1999
 
2000
+ def begin_record_constant_iter_symbols(adj):
2001
+ if adj.loop_const_iter_symbols is None:
2002
+ adj.loop_const_iter_symbols = set()
2003
+
2004
+ def end_record_constant_iter_symbols(adj):
2005
+ adj.loop_const_iter_symbols = None
2006
+
1936
2007
  def emit_For(adj, node):
1937
2008
  # try and unroll simple range() statements that use constant args
1938
2009
  unroll_range = adj.get_unroll_range(node)
1939
2010
 
1940
2011
  if isinstance(unroll_range, range):
2012
+ const_iter_sym = node.target.id
2013
+ if adj.loop_const_iter_symbols is not None:
2014
+ # prevent constant conflicts in `materialize_redefinitions()`
2015
+ adj.loop_const_iter_symbols.add(const_iter_sym)
2016
+
2017
+ # unroll static for-loop
1941
2018
  for i in unroll_range:
1942
2019
  const_iter = adj.add_constant(i)
1943
- var_iter = adj.add_builtin_call("int", [const_iter])
1944
- adj.symbols[node.target.id] = var_iter
2020
+ adj.symbols[const_iter_sym] = const_iter
1945
2021
 
1946
2022
  # eval body
1947
2023
  for s in node.body:
@@ -1957,6 +2033,7 @@ class Adjoint:
1957
2033
  iter = adj.eval(node.iter)
1958
2034
 
1959
2035
  adj.symbols[node.target.id] = adj.begin_for(iter)
2036
+ adj.begin_record_constant_iter_symbols()
1960
2037
 
1961
2038
  # for loops should be side-effect free, here we store a copy
1962
2039
  adj.loop_symbols.append(adj.symbols.copy())
@@ -1967,6 +2044,7 @@ class Adjoint:
1967
2044
 
1968
2045
  adj.materialize_redefinitions(adj.loop_symbols[-1])
1969
2046
  adj.loop_symbols.pop()
2047
+ adj.end_record_constant_iter_symbols()
1970
2048
 
1971
2049
  adj.end_for(iter)
1972
2050
 
@@ -2023,13 +2101,28 @@ class Adjoint:
2023
2101
 
2024
2102
  # try and lookup function in globals by
2025
2103
  # resolving path (e.g.: module.submodule.attr)
2026
- func, path = adj.resolve_static_expression(node.func)
2104
+ if hasattr(node.func, "warp_func"):
2105
+ func = node.func.warp_func
2106
+ path = []
2107
+ else:
2108
+ func, path = adj.resolve_static_expression(node.func)
2027
2109
  if func is None:
2028
2110
  func = adj.eval(node.func)
2029
2111
 
2112
+ if adj.is_static_expression(func):
2113
+ # try to evaluate wp.static() expressions
2114
+ obj, _ = adj.evaluate_static_expression(node)
2115
+ if obj is not None:
2116
+ if isinstance(obj, warp.context.Function):
2117
+ # special handling for wp.static() evaluating to a function
2118
+ return obj
2119
+ else:
2120
+ out = adj.add_constant(obj)
2121
+ return out
2122
+
2030
2123
  type_args = {}
2031
2124
 
2032
- if not isinstance(func, warp.context.Function):
2125
+ if len(path) > 0 and not isinstance(func, warp.context.Function):
2033
2126
  attr = path[-1]
2034
2127
  caller = func
2035
2128
  func = None
@@ -2083,6 +2176,9 @@ class Adjoint:
2083
2176
  args = tuple(adj.resolve_arg(x) for x in node.args)
2084
2177
  kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
2085
2178
 
2179
+ # add the call and build the callee adjoint if needed (func.adj)
2180
+ out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2181
+
2086
2182
  if warp.config.verify_autograd_array_access:
2087
2183
  # update arg read/write states according to what happens to that arg in the called function
2088
2184
  if hasattr(func, "adj"):
@@ -2095,7 +2191,6 @@ class Adjoint:
2095
2191
  if func.adj.args[i].is_read:
2096
2192
  arg.mark_read()
2097
2193
 
2098
- out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2099
2194
  return out
2100
2195
 
2101
2196
  def emit_Index(adj, node):
@@ -2281,20 +2376,40 @@ class Adjoint:
2281
2376
  target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2282
2377
 
2283
2378
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2379
+ # recursively unwind AST, stopping at penultimate node
2380
+ node = lhs
2381
+ while hasattr(node, "value"):
2382
+ if hasattr(node.value, "value"):
2383
+ node = node.value
2384
+ else:
2385
+ break
2386
+ # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
2387
+ if hasattr(node, "attr") and node.attr == "adjoint":
2388
+ attr = adj.add_builtin_call("index", [target, *indices])
2389
+ adj.add_builtin_call("store", [attr, rhs])
2390
+ return
2391
+
2392
+ # TODO: array vec component case
2284
2393
  if is_reference(target.type):
2285
2394
  attr = adj.add_builtin_call("indexref", [target, *indices])
2286
- else:
2287
- attr = adj.add_builtin_call("index", [target, *indices])
2395
+ adj.add_builtin_call("store", [attr, rhs])
2288
2396
 
2289
- adj.add_builtin_call("store", [attr, rhs])
2397
+ if warp.config.verbose and not adj.custom_reverse_mode:
2398
+ lineno = adj.lineno + adj.fun_lineno
2399
+ line = adj.source_lines[adj.lineno]
2400
+ node_source = adj.get_node_source(lhs.value)
2401
+ print(
2402
+ f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2403
+ )
2290
2404
 
2291
- if warp.config.verbose and not adj.custom_reverse_mode:
2292
- lineno = adj.lineno + adj.fun_lineno
2293
- line = adj.source_lines[adj.lineno]
2294
- node_source = adj.get_node_source(lhs.value)
2295
- print(
2296
- f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2297
- )
2405
+ else:
2406
+ out = adj.add_builtin_call("assign", [target, *indices, rhs])
2407
+
2408
+ # re-point target symbol to out var
2409
+ for id in adj.symbols:
2410
+ if adj.symbols[id] == target:
2411
+ adj.symbols[id] = out
2412
+ break
2298
2413
 
2299
2414
  else:
2300
2415
  raise WarpCodegenError(
@@ -2329,16 +2444,24 @@ class Adjoint:
2329
2444
  aggregate = adj.eval(lhs.value)
2330
2445
  aggregate_type = strip_reference(aggregate.type)
2331
2446
 
2332
- # assigning to a vector component
2333
- if type_is_vector(aggregate_type):
2447
+ # assigning to a vector or quaternion component
2448
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
2449
+ # TODO: handle wp.adjoint case
2450
+
2334
2451
  index = adj.vector_component_index(lhs.attr, aggregate_type)
2335
2452
 
2453
+ # TODO: array vec component case
2336
2454
  if is_reference(aggregate.type):
2337
2455
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2456
+ adj.add_builtin_call("store", [attr, rhs])
2338
2457
  else:
2339
- attr = adj.add_builtin_call("index", [aggregate, index])
2458
+ out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2340
2459
 
2341
- adj.add_builtin_call("store", [attr, rhs])
2460
+ # re-point target symbol to out var
2461
+ for id in adj.symbols:
2462
+ if adj.symbols[id] == aggregate:
2463
+ adj.symbols[id] = out
2464
+ break
2342
2465
 
2343
2466
  else:
2344
2467
  attr = adj.emit_Attribute(lhs)
@@ -2382,9 +2505,66 @@ class Adjoint:
2382
2505
  adj.add_return(adj.return_var)
2383
2506
 
2384
2507
  def emit_AugAssign(adj, node):
2385
- # replace augmented assignment with assignment statement + binary op
2386
- new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
2387
- adj.eval(new_node)
2508
+ lhs = node.target
2509
+
2510
+ # replace augmented assignment with assignment statement + binary op (default behaviour)
2511
+ def make_new_assign_statement():
2512
+ new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
2513
+ adj.eval(new_node)
2514
+
2515
+ if isinstance(lhs, ast.Subscript):
2516
+ rhs = adj.eval(node.value)
2517
+
2518
+ # wp.adjoint[var] appears in custom grad functions, and does not require
2519
+ # special consideration in the AugAssign case
2520
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
2521
+ make_new_assign_statement()
2522
+ return
2523
+
2524
+ target, indices = adj.eval_subscript(lhs)
2525
+
2526
+ target_type = strip_reference(target.type)
2527
+
2528
+ if is_array(target_type):
2529
+ # target_type is not suitable for atomic array accumulation
2530
+ if target_type.dtype not in warp.types.atomic_types:
2531
+ make_new_assign_statement()
2532
+ return
2533
+
2534
+ kernel_name = adj.fun_name
2535
+ filename = adj.filename
2536
+ lineno = adj.lineno + adj.fun_lineno
2537
+
2538
+ if isinstance(node.op, ast.Add):
2539
+ adj.add_builtin_call("atomic_add", [target, *indices, rhs])
2540
+
2541
+ if warp.config.verify_autograd_array_access:
2542
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2543
+
2544
+ elif isinstance(node.op, ast.Sub):
2545
+ adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
2546
+
2547
+ if warp.config.verify_autograd_array_access:
2548
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2549
+ else:
2550
+ print(f"Warning: in-place op {node.op} is not differentiable")
2551
+
2552
+ # TODO
2553
+ elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2554
+ make_new_assign_statement()
2555
+ return
2556
+
2557
+ else:
2558
+ raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
2559
+
2560
+ # TODO
2561
+ elif isinstance(lhs, ast.Attribute):
2562
+ make_new_assign_statement()
2563
+ return
2564
+
2565
+ else:
2566
+ make_new_assign_statement()
2567
+ return
2388
2568
 
2389
2569
  def emit_Tuple(adj, node):
2390
2570
  # LHS for expressions, such as i, j, k = 1, 2, 3
@@ -2445,9 +2625,6 @@ class Adjoint:
2445
2625
  if path[0] in adj.symbols:
2446
2626
  return None
2447
2627
 
2448
- if path[0] in __builtins__:
2449
- return __builtins__[path[0]]
2450
-
2451
2628
  # look up in closure/global variables
2452
2629
  expr = adj.resolve_external_reference(path[0])
2453
2630
 
@@ -2455,13 +2632,201 @@ class Adjoint:
2455
2632
  if expr is None:
2456
2633
  expr = getattr(warp, path[0], None)
2457
2634
 
2458
- if expr:
2635
+ # look up in builtins
2636
+ if expr is None:
2637
+ expr = __builtins__.get(path[0])
2638
+
2639
+ if expr is not None:
2459
2640
  for i in range(1, len(path)):
2460
2641
  if hasattr(expr, path[i]):
2461
2642
  expr = getattr(expr, path[i])
2462
2643
 
2463
2644
  return expr
2464
2645
 
2646
+ # retrieves a dictionary of all closure and global variables and their values
2647
+ # to be used in the evaluation context of wp.static() expressions
2648
+ def get_static_evaluation_context(adj):
2649
+ closure_vars = dict(
2650
+ zip(
2651
+ adj.func.__code__.co_freevars,
2652
+ [c.cell_contents for c in (adj.func.__closure__ or [])],
2653
+ )
2654
+ )
2655
+
2656
+ vars_dict = {}
2657
+ vars_dict.update(adj.func.__globals__)
2658
+ # variables captured in closure have precedence over global vars
2659
+ vars_dict.update(closure_vars)
2660
+
2661
+ return vars_dict
2662
+
2663
+ def is_static_expression(adj, func):
2664
+ return (
2665
+ isinstance(func, types.FunctionType)
2666
+ and func.__module__ == "warp.builtins"
2667
+ and func.__qualname__ == "static"
2668
+ )
2669
+
2670
+ # verify the return type of a wp.static() expression is supported inside a Warp kernel
2671
+ def verify_static_return_value(adj, value):
2672
+ if value is None:
2673
+ raise ValueError("None is returned")
2674
+ if warp.types.is_value(value):
2675
+ return True
2676
+ if warp.types.is_array(value):
2677
+ # more useful explanation for the common case of creating a Warp array
2678
+ raise ValueError("a Warp array cannot be created inside Warp kernels")
2679
+ if isinstance(value, str):
2680
+ # we want to support cases such as `print(wp.static("test"))`
2681
+ return True
2682
+ if isinstance(value, warp.context.Function):
2683
+ return True
2684
+
2685
+ def verify_struct(s: StructInstance, attr_path: List[str]):
2686
+ for key in s._cls.vars.keys():
2687
+ v = getattr(s, key)
2688
+ if issubclass(type(v), StructInstance):
2689
+ verify_struct(v, attr_path + [key])
2690
+ else:
2691
+ try:
2692
+ adj.verify_static_return_value(v)
2693
+ except ValueError as e:
2694
+ raise ValueError(
2695
+ f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
2696
+ ) from e
2697
+
2698
+ if issubclass(type(value), StructInstance):
2699
+ return verify_struct(value, [])
2700
+
2701
+ raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
2702
+
2703
+ # find the source code string of an AST node
2704
+ def extract_node_source(adj, node) -> Optional[str]:
2705
+ if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
2706
+ return None
2707
+
2708
+ start_line = node.lineno - 1 # line numbers start at 1
2709
+ start_col = node.col_offset
2710
+
2711
+ if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
2712
+ end_line = node.end_lineno - 1
2713
+ end_col = node.end_col_offset
2714
+ else:
2715
+ # fallback for Python versions before 3.8
2716
+ # we have to find the end line and column manually
2717
+ end_line = start_line
2718
+ end_col = start_col
2719
+ parenthesis_count = 1
2720
+ for lineno in range(start_line, len(adj.source_lines)):
2721
+ if lineno == start_line:
2722
+ c_start = start_col
2723
+ else:
2724
+ c_start = 0
2725
+ line = adj.source_lines[lineno]
2726
+ for i in range(c_start, len(line)):
2727
+ c = line[i]
2728
+ if c == "(":
2729
+ parenthesis_count += 1
2730
+ elif c == ")":
2731
+ parenthesis_count -= 1
2732
+ if parenthesis_count == 0:
2733
+ end_col = i
2734
+ end_line = lineno
2735
+ break
2736
+ if parenthesis_count == 0:
2737
+ break
2738
+
2739
+ if start_line == end_line:
2740
+ # single-line expression
2741
+ return adj.source_lines[start_line][start_col:end_col]
2742
+ else:
2743
+ # multi-line expression
2744
+ lines = []
2745
+ # first line (from start_col to the end)
2746
+ lines.append(adj.source_lines[start_line][start_col:])
2747
+ # middle lines (entire lines)
2748
+ lines.extend(adj.source_lines[start_line + 1 : end_line])
2749
+ # last line (from the start to end_col)
2750
+ lines.append(adj.source_lines[end_line][:end_col])
2751
+ return "\n".join(lines).strip()
2752
+
2753
+ # handles a wp.static() expression and returns the resulting object and a string representing the code
2754
+ # of the static expression
2755
+ def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
2756
+ if len(node.args) == 1:
2757
+ static_code = adj.extract_node_source(node.args[0])
2758
+ elif len(node.keywords) == 1:
2759
+ static_code = adj.extract_node_source(node.keywords[0])
2760
+ else:
2761
+ raise WarpCodegenError("warp.static() requires a single argument or keyword")
2762
+ if static_code is None:
2763
+ raise WarpCodegenError("Error extracting source code from wp.static() expression")
2764
+
2765
+ vars_dict = adj.get_static_evaluation_context()
2766
+ # add constant variables to the static call context
2767
+ constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
2768
+ vars_dict.update(constant_vars)
2769
+
2770
+ try:
2771
+ value = eval(static_code, vars_dict)
2772
+ if warp.config.verbose:
2773
+ print(f"Evaluated static command: {static_code} = {value}")
2774
+ except NameError as e:
2775
+ raise WarpCodegenError(
2776
+ f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
2777
+ ) from e
2778
+ except Exception as e:
2779
+ raise WarpCodegenError(
2780
+ f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
2781
+ ) from e
2782
+
2783
+ try:
2784
+ adj.verify_static_return_value(value)
2785
+ except ValueError as e:
2786
+ raise WarpCodegenError(
2787
+ f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
2788
+ ) from e
2789
+
2790
+ return value, static_code
2791
+
2792
+ # try to replace wp.static() expressions by their evaluated value if the
2793
+ # expression can be evaluated
2794
+ def replace_static_expressions(adj):
2795
+ class StaticExpressionReplacer(ast.NodeTransformer):
2796
+ def visit_Call(self, node):
2797
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
2798
+ if adj.is_static_expression(func):
2799
+ try:
2800
+ # the static expression will execute as long as the static expression is valid and
2801
+ # only depends on global or captured variables
2802
+ obj, code = adj.evaluate_static_expression(node)
2803
+ if code is not None:
2804
+ adj.static_expressions[code] = obj
2805
+ if isinstance(obj, warp.context.Function):
2806
+ name_node = ast.Name("__warp_func__")
2807
+ # we add a pointer to the Warp function here so that we can refer to it later at
2808
+ # codegen time (note that the function key itself is not sufficient to uniquely
2809
+ # identify the function, as the function may be redefined between the current time
2810
+ # of wp.static() declaration and the time of codegen during module building)
2811
+ name_node.warp_func = obj
2812
+ return ast.copy_location(name_node, node)
2813
+ else:
2814
+ return ast.copy_location(ast.Constant(value=obj), node)
2815
+ except Exception:
2816
+ # Ignoring failing static expressions should generally not be an issue because only
2817
+ # one of these cases should be possible:
2818
+ # 1) the static expression itself is invalid code, in which case the module cannot be
2819
+ # built all,
2820
+ # 2) the static expression contains a reference to a local (even if constant) variable
2821
+ # (and is therefore not executable and raises this exception), in which
2822
+ # case changing the constant, or the code affecting this constant, would lead to
2823
+ # a different module hash anyway.
2824
+ pass
2825
+
2826
+ return self.generic_visit(node)
2827
+
2828
+ adj.tree = StaticExpressionReplacer().visit(adj.tree)
2829
+
2465
2830
  # Evaluates a static expression that does not depend on runtime values
2466
2831
  # if eval_types is True, try resolving the path using evaluated type information as well
2467
2832
  def resolve_static_expression(adj, root_node, eval_types=True):
@@ -2536,34 +2901,42 @@ class Adjoint:
2536
2901
  # return the Python code corresponding to the given AST node
2537
2902
  return ast.get_source_segment(adj.source, node)
2538
2903
 
2539
- def get_constant_references(adj) -> Dict[str, Any]:
2540
- """Traverses ``adj.tree`` and returns a dictionary containing constant variable names and values.
2541
-
2542
- This function is meant to be used to populate a module's constants dictionary, which then feeds
2543
- into the computation of the module's ``content_hash``.
2544
- """
2904
+ def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
2905
+ """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
2545
2906
 
2546
2907
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
2547
- constants_dict = {}
2908
+
2909
+ constants = {}
2910
+ types = {}
2911
+ functions = {}
2548
2912
 
2549
2913
  for node in ast.walk(adj.tree):
2550
2914
  if isinstance(node, ast.Name) and node.id not in local_variables:
2551
2915
  # look up in closure/global variables
2552
2916
  obj = adj.resolve_external_reference(node.id)
2553
-
2554
2917
  if warp.types.is_value(obj):
2555
- constants_dict[node.id] = obj
2918
+ constants[node.id] = obj
2556
2919
 
2557
2920
  elif isinstance(node, ast.Attribute):
2558
2921
  obj, path = adj.resolve_static_expression(node, eval_types=False)
2559
-
2560
2922
  if warp.types.is_value(obj):
2561
- constants_dict[".".join(path)] = obj
2923
+ constants[".".join(path)] = obj
2924
+
2925
+ elif isinstance(node, ast.Call):
2926
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
2927
+ if isinstance(func, warp.context.Function) and not func.is_builtin():
2928
+ # calling user-defined function
2929
+ functions[func] = None
2930
+ elif isinstance(func, Struct):
2931
+ # calling struct constructor
2932
+ types[func] = None
2933
+ elif isinstance(func, type) and warp.types.type_is_value(func):
2934
+ # calling value type constructor
2935
+ types[func] = None
2562
2936
 
2563
2937
  elif isinstance(node, ast.Assign):
2564
2938
  # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
2565
2939
  lhs = node.targets[0]
2566
-
2567
2940
  if isinstance(lhs, ast.Tuple):
2568
2941
  for v in lhs.elts:
2569
2942
  if isinstance(v, ast.Name):
@@ -2571,7 +2944,7 @@ class Adjoint:
2571
2944
  elif isinstance(lhs, ast.Name):
2572
2945
  local_variables.add(lhs.id)
2573
2946
 
2574
- return constants_dict
2947
+ return constants, types, functions
2575
2948
 
2576
2949
 
2577
2950
  # ----------------
@@ -2817,6 +3190,15 @@ def constant_str(value):
2817
3190
  # make sure we emit the value of objects, e.g. uint32
2818
3191
  return str(value.value)
2819
3192
 
3193
+ elif issubclass(value_type, warp.codegen.StructInstance):
3194
+ # constant struct instance
3195
+ arg_strs = []
3196
+ for key, var in value._cls.vars.items():
3197
+ attr = getattr(value, key)
3198
+ arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
3199
+ arg_str = ", ".join(arg_strs)
3200
+ return f"{value.native_name}({arg_str})"
3201
+
2820
3202
  elif value == math.inf:
2821
3203
  return "INFINITY"
2822
3204
 
@@ -2845,7 +3227,7 @@ def make_full_qualified_name(func):
2845
3227
 
2846
3228
 
2847
3229
  def codegen_struct(struct, device="cpu", indent_size=4):
2848
- name = make_full_qualified_name(struct.cls)
3230
+ name = struct.native_name
2849
3231
 
2850
3232
  body = []
2851
3233
  indent_block = " " * indent_size