warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187):
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -518,20 +518,17 @@ class Adjoint:
518
518
  # whether the generation of the adjoint code is skipped for this function
519
519
  adj.skip_reverse_codegen = skip_reverse_codegen
520
520
 
521
- # build AST from function object
522
- adj.source = inspect.getsource(func)
523
-
524
- # get source code lines and line number where function starts
525
- adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
526
-
527
- # keep track of line number in function code
528
- adj.lineno = None
521
+ # extract name of source file
522
+ adj.filename = inspect.getsourcefile(func) or "unknown source file"
523
+ # get source file line number where function starts
524
+ _, adj.fun_lineno = inspect.getsourcelines(func)
529
525
 
526
+ # get function source code
527
+ adj.source = inspect.getsource(func)
530
528
  # ensures that indented class methods can be parsed as kernels
531
529
  adj.source = textwrap.dedent(adj.source)
532
530
 
533
- # extract name of source file
534
- adj.filename = inspect.getsourcefile(func) or "unknown source file"
531
+ adj.source_lines = adj.source.splitlines()
535
532
 
536
533
  # build AST and apply node transformers
537
534
  adj.tree = ast.parse(adj.source)
@@ -541,6 +538,9 @@ class Adjoint:
541
538
 
542
539
  adj.fun_name = adj.tree.body[0].name
543
540
 
541
+ # for keeping track of line number in function code
542
+ adj.lineno = None
543
+
544
544
  # whether the forward code shall be used for the reverse pass and a custom
545
545
  # function signature is applied to the reverse version of the function
546
546
  adj.custom_reverse_mode = custom_reverse_mode
@@ -625,7 +625,7 @@ class Adjoint:
625
625
  else:
626
626
  msg = "Error"
627
627
  lineno = adj.lineno + adj.fun_lineno
628
- line = adj.source.splitlines()[adj.lineno]
628
+ line = adj.source_lines[adj.lineno]
629
629
  msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
630
630
  ex, data, traceback = sys.exc_info()
631
631
  e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
@@ -683,10 +683,11 @@ class Adjoint:
683
683
  args_out,
684
684
  use_initializer_list,
685
685
  has_output_args=True,
686
+ require_original_output_arg=False,
686
687
  ):
687
688
  formatted_var = adj.format_args("var", args_var)
688
689
  formatted_out = []
689
- if has_output_args and len(args_out) > 1:
690
+ if has_output_args and (require_original_output_arg or len(args_out) > 1):
690
691
  formatted_out = adj.format_args("var", args_out)
691
692
  formatted_var_adj = adj.format_args(
692
693
  "&adj" if use_initializer_list else "adj",
@@ -966,13 +967,16 @@ class Adjoint:
966
967
  adj.add_forward(forward_call, replay=replay_call)
967
968
 
968
969
  if not func.missing_grad and len(args):
969
- reverse_has_output_args = len(output_list) > 1 and func.custom_grad_func is None
970
+ reverse_has_output_args = (
971
+ func.require_original_output_arg or len(output_list) > 1
972
+ ) and func.custom_grad_func is None
970
973
  arg_str = adj.format_reverse_call_args(
971
974
  args_var,
972
975
  args,
973
976
  output_list,
974
977
  use_initializer_list,
975
978
  has_output_args=reverse_has_output_args,
979
+ require_original_output_arg=func.require_original_output_arg,
976
980
  )
977
981
  if arg_str is not None:
978
982
  reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
@@ -1291,6 +1295,12 @@ class Adjoint:
1291
1295
  index = adj.add_constant(index)
1292
1296
  return index
1293
1297
 
1298
+ @staticmethod
1299
+ def is_differentiable_value_type(var_type):
1300
+ # checks that the argument type is a value type (i.e, not an array)
1301
+ # possibly holding differentiable values (for which gradients must be accumulated)
1302
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
1303
+
1294
1304
  def emit_Attribute(adj, node):
1295
1305
  if hasattr(node, "is_adjoint"):
1296
1306
  node.value.is_adjoint = True
@@ -1327,9 +1337,12 @@ class Adjoint:
1327
1337
 
1328
1338
  if is_reference(aggregate.type):
1329
1339
  adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
1330
- adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
1331
1340
  else:
1332
1341
  adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
1342
+
1343
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
1344
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
1345
+ else:
1333
1346
  adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
1334
1347
 
1335
1348
  return attr
@@ -1344,7 +1357,7 @@ class Adjoint:
1344
1357
 
1345
1358
  if isinstance(aggregate, Var):
1346
1359
  raise WarpCodegenAttributeError(
1347
- f"Error, `{node.attr}` is not an attribute of '{aggregate.label}' ({type_repr(aggregate.type)})"
1360
+ f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
1348
1361
  )
1349
1362
  raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
1350
1363
 
@@ -1368,12 +1381,12 @@ class Adjoint:
1368
1381
  return
1369
1382
 
1370
1383
  def emit_NameConstant(adj, node):
1371
- if node.value is True:
1384
+ if node.value:
1372
1385
  return adj.add_constant(True)
1373
- elif node.value is False:
1374
- return adj.add_constant(False)
1375
1386
  elif node.value is None:
1376
1387
  raise WarpCodegenTypeError("None type unsupported")
1388
+ else:
1389
+ return adj.add_constant(False)
1377
1390
 
1378
1391
  def emit_Constant(adj, node):
1379
1392
  if isinstance(node, ast.Str):
@@ -1413,7 +1426,7 @@ class Adjoint:
1413
1426
  if var1 != var2:
1414
1427
  if warp.config.verbose and not adj.custom_reverse_mode:
1415
1428
  lineno = adj.lineno + adj.fun_lineno
1416
- line = adj.source.splitlines()[adj.lineno]
1429
+ line = adj.source_lines[adj.lineno]
1417
1430
  msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
1418
1431
  print(msg)
1419
1432
 
@@ -1450,7 +1463,11 @@ class Adjoint:
1450
1463
 
1451
1464
  # try and resolve the expression to an object
1452
1465
  # e.g.: wp.constant in the globals scope
1453
- obj, path = adj.resolve_static_expression(a)
1466
+ obj, _ = adj.resolve_static_expression(a)
1467
+
1468
+ if isinstance(obj, Var) and obj.constant is not None:
1469
+ obj = obj.constant
1470
+
1454
1471
  return warp.types.is_int(obj), obj
1455
1472
 
1456
1473
  # detects whether a loop contains a break (or continue) statement
@@ -1596,7 +1613,7 @@ class Adjoint:
1596
1613
  if adj.is_user_function:
1597
1614
  if hasattr(node.func, "attr") and node.func.attr == "tid":
1598
1615
  lineno = adj.lineno + adj.fun_lineno
1599
- line = adj.source.splitlines()[adj.lineno]
1616
+ line = adj.source_lines[adj.lineno]
1600
1617
  raise WarpCodegenError(
1601
1618
  "tid() may only be called from a Warp kernel, not a Warp function. "
1602
1619
  "Instead, obtain the indices from a @wp.kernel and pass them as "
@@ -1613,7 +1630,7 @@ class Adjoint:
1613
1630
 
1614
1631
  if not isinstance(func, warp.context.Function):
1615
1632
  if len(path) == 0:
1616
- raise WarpCodegenError(f"Unrecognized syntax for function call, path not valid: '{node.func}'")
1633
+ raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")
1617
1634
 
1618
1635
  attr = path[-1]
1619
1636
  caller = func
@@ -1818,7 +1835,7 @@ class Adjoint:
1818
1835
 
1819
1836
  if warp.config.verbose and not adj.custom_reverse_mode:
1820
1837
  lineno = adj.lineno + adj.fun_lineno
1821
- line = adj.source.splitlines()[adj.lineno]
1838
+ line = adj.source_lines[adj.lineno]
1822
1839
  node_source = adj.get_node_source(lhs.value)
1823
1840
  print(
1824
1841
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
@@ -1875,7 +1892,7 @@ class Adjoint:
1875
1892
 
1876
1893
  if warp.config.verbose and not adj.custom_reverse_mode:
1877
1894
  lineno = adj.lineno + adj.fun_lineno
1878
- line = adj.source.splitlines()[adj.lineno]
1895
+ line = adj.source_lines[adj.lineno]
1879
1896
  msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1880
1897
  print(msg)
1881
1898
 
@@ -1901,7 +1918,8 @@ class Adjoint:
1901
1918
  if var is not None:
1902
1919
  adj.return_var = tuple()
1903
1920
  for ret in var:
1904
- ret = adj.load(ret)
1921
+ if is_reference(ret.type):
1922
+ ret = adj.add_builtin_call("copy", [ret])
1905
1923
  adj.return_var += (ret,)
1906
1924
 
1907
1925
  adj.add_return(adj.return_var)
@@ -1945,7 +1963,7 @@ class Adjoint:
1945
1963
  ast.AugAssign: emit_AugAssign,
1946
1964
  ast.Tuple: emit_Tuple,
1947
1965
  ast.Pass: emit_Pass,
1948
- ast.Ellipsis: emit_Ellipsis
1966
+ ast.Ellipsis: emit_Ellipsis,
1949
1967
  }
1950
1968
 
1951
1969
  def eval(adj, node):
@@ -2009,16 +2027,11 @@ class Adjoint:
2009
2027
  attributes.append(node.attr)
2010
2028
  node = node.value
2011
2029
 
2012
- if eval_types and isinstance(node, ast.Call):
2030
+ if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
2013
2031
  # support for operators returning modules
2014
2032
  # i.e. operator_name(*operator_args).x.y.z
2015
2033
  operator_args = node.args
2016
- operator_name = getattr(node.func, "id", None)
2017
-
2018
- if operator_name is None:
2019
- raise WarpCodegenError(
2020
- f"Invalid operator call syntax, expected a plain name, got {ast.dump(node.func, annotate_fields=False)}"
2021
- )
2034
+ operator_name = node.func.id
2022
2035
 
2023
2036
  if operator_name == "type":
2024
2037
  if len(operator_args) != 1:
@@ -2043,8 +2056,6 @@ class Adjoint:
2043
2056
  else:
2044
2057
  raise WarpCodegenError(f"Cannot deduce the type of {var}")
2045
2058
 
2046
- raise WarpCodegenError(f"Unknown operator '{operator_name}'")
2047
-
2048
2059
  # reverse list since ast presents it backward order
2049
2060
  path = [*reversed(attributes)]
2050
2061
  if isinstance(node, ast.Name):
@@ -2071,14 +2082,14 @@ class Adjoint:
2071
2082
  def set_lineno(adj, lineno):
2072
2083
  if adj.lineno is None or adj.lineno != lineno:
2073
2084
  line = lineno + adj.fun_lineno
2074
- source = adj.raw_source[lineno].strip().ljust(80 - len(adj.indentation), " ")
2085
+ source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
2075
2086
  adj.add_forward(f"// {source} <L {line}>")
2076
2087
  adj.add_reverse(f"// adj: {source} <L {line}>")
2077
2088
  adj.lineno = lineno
2078
2089
 
2079
2090
  def get_node_source(adj, node):
2080
2091
  # return the Python code corresponding to the given AST node
2081
- return ast.get_source_segment("".join(adj.raw_source), node)
2092
+ return ast.get_source_segment(adj.source, node)
2082
2093
 
2083
2094
 
2084
2095
  # ----------------
@@ -2130,7 +2141,9 @@ struct {name}
2130
2141
  {{
2131
2142
  }}
2132
2143
 
2133
- CUDA_CALLABLE {name}& operator += (const {name}&) {{ return *this; }}
2144
+ CUDA_CALLABLE {name}& operator += (const {name}& rhs)
2145
+ {{{prefix_add_body}
2146
+ return *this;}}
2134
2147
 
2135
2148
  }};
2136
2149
 
@@ -2357,6 +2370,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2357
2370
  forward_initializers = []
2358
2371
  reverse_body = []
2359
2372
  atomic_add_body = []
2373
+ prefix_add_body = []
2360
2374
 
2361
2375
  # forward args
2362
2376
  for label, var in struct.vars.items():
@@ -2370,6 +2384,11 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2370
2384
  prefix = f"{indent_block}," if forward_initializers else ":"
2371
2385
  forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
2372
2386
 
2387
+ # prefix-add operator
2388
+ for label, var in struct.vars.items():
2389
+ if not is_array(var.type):
2390
+ prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
2391
+
2373
2392
  # reverse args
2374
2393
  for label, var in struct.vars.items():
2375
2394
  reverse_args.append(var.ctype() + " & adj_" + label)
@@ -2387,6 +2406,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2387
2406
  forward_initializers="".join(forward_initializers),
2388
2407
  reverse_args=indent(reverse_args),
2389
2408
  reverse_body="".join(reverse_body),
2409
+ prefix_add_body="".join(prefix_add_body),
2390
2410
  atomic_add_body="".join(atomic_add_body),
2391
2411
  )
2392
2412
 
warp/config.py CHANGED
@@ -5,7 +5,7 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- version = "1.0.0-beta.5"
8
+ version = "1.0.0-beta.6"
9
9
 
10
10
  cuda_path = (
11
11
  None # path to local CUDA toolchain, if None at init time warp will attempt to find the SDK using CUDA_PATH env var
@@ -33,3 +33,5 @@ ptx_target_arch = 70 # target architecture for PTX generation, defaults to the
33
33
  enable_backward = True # whether to compiler the backward passes of the kernels
34
34
 
35
35
  llvm_cuda = False # use Clang/LLVM instead of NVRTC to compile CUDA
36
+
37
+ graph_capture_module_load_default = True # Default value of force_module_load for capture_begin()