warp-lang 1.0.0b5-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +3 -4
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/example_dem.py +28 -26
- examples/example_diffray.py +37 -30
- examples/example_fluid.py +7 -3
- examples/example_jacobian_ik.py +1 -1
- examples/example_mesh_intersect.py +10 -7
- examples/example_nvdb.py +3 -3
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +9 -5
- examples/example_sim_cloth.py +29 -25
- examples/example_sim_fk_grad.py +2 -2
- examples/example_sim_fk_grad_torch.py +3 -3
- examples/example_sim_grad_bounce.py +11 -8
- examples/example_sim_grad_cloth.py +12 -9
- examples/example_sim_granular.py +2 -2
- examples/example_sim_granular_collision_sdf.py +13 -13
- examples/example_sim_neo_hookean.py +3 -3
- examples/example_sim_particle_chain.py +2 -2
- examples/example_sim_quadruped.py +8 -5
- examples/example_sim_rigid_chain.py +8 -5
- examples/example_sim_rigid_contact.py +13 -10
- examples/example_sim_rigid_fem.py +2 -2
- examples/example_sim_rigid_gyroscopic.py +2 -2
- examples/example_sim_rigid_kinematics.py +1 -1
- examples/example_sim_trajopt.py +3 -2
- examples/fem/example_apic_fluid.py +5 -7
- examples/fem/example_diffusion_mgpu.py +18 -16
- warp/__init__.py +3 -2
- warp/bin/warp.so +0 -0
- warp/build_dll.py +29 -9
- warp/builtins.py +206 -7
- warp/codegen.py +58 -38
- warp/config.py +3 -1
- warp/context.py +234 -128
- warp/fem/__init__.py +2 -2
- warp/fem/cache.py +2 -1
- warp/fem/field/nodal_field.py +18 -17
- warp/fem/geometry/hexmesh.py +11 -6
- warp/fem/geometry/quadmesh_2d.py +16 -12
- warp/fem/geometry/tetmesh.py +19 -8
- warp/fem/geometry/trimesh_2d.py +18 -7
- warp/fem/integrate.py +341 -196
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +138 -53
- warp/fem/quadrature/quadrature.py +81 -9
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_space.py +169 -51
- warp/fem/space/grid_2d_function_space.py +2 -2
- warp/fem/space/grid_3d_function_space.py +2 -2
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +9 -6
- warp/fem/space/quadmesh_2d_function_space.py +2 -2
- warp/fem/space/shape/cube_shape_function.py +27 -15
- warp/fem/space/shape/square_shape_function.py +29 -18
- warp/fem/space/tetmesh_function_space.py +2 -2
- warp/fem/space/topology.py +10 -0
- warp/fem/space/trimesh_2d_function_space.py +2 -2
- warp/fem/utils.py +10 -5
- warp/native/array.h +49 -8
- warp/native/builtin.h +31 -14
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1177 -1108
- warp/native/intersect.h +4 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +65 -6
- warp/native/mesh.h +126 -5
- warp/native/quat.h +28 -4
- warp/native/vec.h +76 -14
- warp/native/warp.cu +1 -6
- warp/render/render_opengl.py +261 -109
- warp/sim/import_mjcf.py +13 -7
- warp/sim/import_urdf.py +14 -14
- warp/sim/inertia.py +17 -18
- warp/sim/model.py +67 -67
- warp/sim/render.py +1 -1
- warp/sparse.py +6 -6
- warp/stubs.py +19 -81
- warp/tape.py +1 -1
- warp/tests/__main__.py +3 -6
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +102 -106
- warp/tests/test_arithmetic.py +39 -40
- warp/tests/test_array.py +46 -48
- warp/tests/test_array_reduce.py +25 -19
- warp/tests/test_atomic.py +62 -26
- warp/tests/test_bool.py +16 -11
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +9 -12
- warp/tests/test_closest_point_edge_edge.py +53 -57
- warp/tests/test_codegen.py +164 -134
- warp/tests/test_compile_consts.py +13 -19
- warp/tests/test_conditional.py +30 -32
- warp/tests/test_copy.py +9 -12
- warp/tests/test_ctypes.py +90 -98
- warp/tests/test_dense.py +20 -14
- warp/tests/test_devices.py +34 -35
- warp/tests/test_dlpack.py +74 -75
- warp/tests/test_examples.py +215 -97
- warp/tests/test_fabricarray.py +15 -21
- warp/tests/test_fast_math.py +14 -11
- warp/tests/test_fem.py +280 -97
- warp/tests/test_fp16.py +19 -15
- warp/tests/test_func.py +177 -194
- warp/tests/test_generics.py +71 -77
- warp/tests/test_grad.py +83 -32
- warp/tests/test_grad_customs.py +7 -9
- warp/tests/test_hash_grid.py +6 -10
- warp/tests/test_import.py +9 -23
- warp/tests/test_indexedarray.py +19 -21
- warp/tests/test_intersect.py +15 -9
- warp/tests/test_large.py +17 -19
- warp/tests/test_launch.py +14 -17
- warp/tests/test_lerp.py +63 -63
- warp/tests/test_lvalue.py +84 -35
- warp/tests/test_marching_cubes.py +9 -13
- warp/tests/test_mat.py +388 -3004
- warp/tests/test_mat_lite.py +9 -12
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +10 -11
- warp/tests/test_matmul.py +104 -100
- warp/tests/test_matmul_lite.py +72 -98
- warp/tests/test_mesh.py +35 -32
- warp/tests/test_mesh_query_aabb.py +18 -25
- warp/tests/test_mesh_query_point.py +39 -23
- warp/tests/test_mesh_query_ray.py +9 -21
- warp/tests/test_mlp.py +8 -9
- warp/tests/test_model.py +89 -93
- warp/tests/test_modules_lite.py +15 -25
- warp/tests/test_multigpu.py +87 -114
- warp/tests/test_noise.py +10 -12
- warp/tests/test_operators.py +14 -21
- warp/tests/test_options.py +10 -11
- warp/tests/test_pinned.py +16 -18
- warp/tests/test_print.py +16 -20
- warp/tests/test_quat.py +121 -88
- warp/tests/test_rand.py +12 -13
- warp/tests/test_reload.py +27 -32
- warp/tests/test_rounding.py +7 -10
- warp/tests/test_runlength_encode.py +105 -106
- warp/tests/test_smoothstep.py +8 -9
- warp/tests/test_snippet.py +13 -22
- warp/tests/test_sparse.py +30 -29
- warp/tests/test_spatial.py +179 -174
- warp/tests/test_streams.py +100 -107
- warp/tests/test_struct.py +98 -67
- warp/tests/test_tape.py +11 -17
- warp/tests/test_torch.py +89 -86
- warp/tests/test_transient_module.py +9 -12
- warp/tests/test_types.py +328 -50
- warp/tests/test_utils.py +217 -218
- warp/tests/test_vec.py +133 -2133
- warp/tests/test_vec_lite.py +8 -11
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +391 -382
- warp/tests/test_volume_write.py +122 -135
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/{test_base.py → unittest_utils.py} +138 -25
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
- warp/thirdparty/unittest_parallel.py +257 -54
- warp/types.py +119 -98
- warp/utils.py +14 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -239
- warp/tests/test_conditional_unequal_types_kernels.py +0 -14
- warp/tests/test_coverage.py +0 -38
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -518,20 +518,17 @@ class Adjoint:
         # whether the generation of the adjoint code is skipped for this function
         adj.skip_reverse_codegen = skip_reverse_codegen

-        #
-        adj.
-
-
-        adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
-
-        # keep track of line number in function code
-        adj.lineno = None
+        # extract name of source file
+        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        # get source file line number where function starts
+        _, adj.fun_lineno = inspect.getsourcelines(func)

+        # get function source code
+        adj.source = inspect.getsource(func)
         # ensures that indented class methods can be parsed as kernels
         adj.source = textwrap.dedent(adj.source)

-
-        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        adj.source_lines = adj.source.splitlines()

         # build AST and apply node transformers
         adj.tree = ast.parse(adj.source)

@@ -541,6 +538,9 @@ class Adjoint:

         adj.fun_name = adj.tree.body[0].name

+        # for keeping track of line number in function code
+        adj.lineno = None
+
         # whether the forward code shall be used for the reverse pass and a custom
         # function signature is applied to the reverse version of the function
         adj.custom_reverse_mode = custom_reverse_mode

@@ -625,7 +625,7 @@ class Adjoint:
             else:
                 msg = "Error"
             lineno = adj.lineno + adj.fun_lineno
-            line = adj.
+            line = adj.source_lines[adj.lineno]
             msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
             ex, data, traceback = sys.exc_info()
             e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)

@@ -683,10 +683,11 @@ class Adjoint:
         args_out,
         use_initializer_list,
         has_output_args=True,
+        require_original_output_arg=False,
     ):
         formatted_var = adj.format_args("var", args_var)
         formatted_out = []
-        if has_output_args and len(args_out) > 1:
+        if has_output_args and (require_original_output_arg or len(args_out) > 1):
             formatted_out = adj.format_args("var", args_out)
         formatted_var_adj = adj.format_args(
             "&adj" if use_initializer_list else "adj",

@@ -966,13 +967,16 @@ class Adjoint:
         adj.add_forward(forward_call, replay=replay_call)

         if not func.missing_grad and len(args):
-            reverse_has_output_args =
+            reverse_has_output_args = (
+                func.require_original_output_arg or len(output_list) > 1
+            ) and func.custom_grad_func is None
             arg_str = adj.format_reverse_call_args(
                 args_var,
                 args,
                 output_list,
                 use_initializer_list,
                 has_output_args=reverse_has_output_args,
+                require_original_output_arg=func.require_original_output_arg,
             )
             if arg_str is not None:
                 reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"

@@ -1291,6 +1295,12 @@ class Adjoint:
             index = adj.add_constant(index)
             return index

+    @staticmethod
+    def is_differentiable_value_type(var_type):
+        # checks that the argument type is a value type (i.e, not an array)
+        # possibly holding differentiable values (for which gradients must be accumulated)
+        return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
+
     def emit_Attribute(adj, node):
         if hasattr(node, "is_adjoint"):
             node.value.is_adjoint = True

@@ -1327,9 +1337,12 @@ class Adjoint:

             if is_reference(aggregate.type):
                 adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
-                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
             else:
                 adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+
+            if adj.is_differentiable_value_type(strip_reference(attr_type)):
+                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+            else:
                 adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")

             return attr

@@ -1344,7 +1357,7 @@ class Adjoint:

             if isinstance(aggregate, Var):
                 raise WarpCodegenAttributeError(
-                    f"Error, `{node.attr}` is not an attribute of '{
+                    f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
                 )
             raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")


@@ -1368,12 +1381,12 @@ class Adjoint:
         return

     def emit_NameConstant(adj, node):
-        if node.value
+        if node.value:
             return adj.add_constant(True)
-        elif node.value is False:
-            return adj.add_constant(False)
         elif node.value is None:
             raise WarpCodegenTypeError("None type unsupported")
+        else:
+            return adj.add_constant(False)

     def emit_Constant(adj, node):
         if isinstance(node, ast.Str):

@@ -1413,7 +1426,7 @@ class Adjoint:
         if var1 != var2:
             if warp.config.verbose and not adj.custom_reverse_mode:
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.
+                line = adj.source_lines[adj.lineno]
                 msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                 print(msg)


@@ -1450,7 +1463,11 @@ class Adjoint:

         # try and resolve the expression to an object
         # e.g.: wp.constant in the globals scope
-        obj,
+        obj, _ = adj.resolve_static_expression(a)
+
+        if isinstance(obj, Var) and obj.constant is not None:
+            obj = obj.constant
+
         return warp.types.is_int(obj), obj

     # detects whether a loop contains a break (or continue) statement

@@ -1596,7 +1613,7 @@ class Adjoint:
         if adj.is_user_function:
             if hasattr(node.func, "attr") and node.func.attr == "tid":
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.
+                line = adj.source_lines[adj.lineno]
                 raise WarpCodegenError(
                     "tid() may only be called from a Warp kernel, not a Warp function. "
                     "Instead, obtain the indices from a @wp.kernel and pass them as "

@@ -1613,7 +1630,7 @@ class Adjoint:

         if not isinstance(func, warp.context.Function):
             if len(path) == 0:
-                raise WarpCodegenError(f"
+                raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")

             attr = path[-1]
             caller = func

@@ -1818,7 +1835,7 @@ class Adjoint:

             if warp.config.verbose and not adj.custom_reverse_mode:
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.
+                line = adj.source_lines[adj.lineno]
                 node_source = adj.get_node_source(lhs.value)
                 print(
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"

@@ -1875,7 +1892,7 @@ class Adjoint:

             if warp.config.verbose and not adj.custom_reverse_mode:
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.
+                line = adj.source_lines[adj.lineno]
                 msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
                 print(msg)


@@ -1901,7 +1918,8 @@ class Adjoint:
         if var is not None:
             adj.return_var = tuple()
             for ret in var:
-
+                if is_reference(ret.type):
+                    ret = adj.add_builtin_call("copy", [ret])
                 adj.return_var += (ret,)

         adj.add_return(adj.return_var)

@@ -1945,7 +1963,7 @@ class Adjoint:
         ast.AugAssign: emit_AugAssign,
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
-        ast.Ellipsis: emit_Ellipsis
+        ast.Ellipsis: emit_Ellipsis,
     }

     def eval(adj, node):

@@ -2009,16 +2027,11 @@ class Adjoint:
             attributes.append(node.attr)
             node = node.value

-        if eval_types and isinstance(node, ast.Call):
+        if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
             # support for operators returning modules
             # i.e. operator_name(*operator_args).x.y.z
             operator_args = node.args
-            operator_name =
-
-            if operator_name is None:
-                raise WarpCodegenError(
-                    f"Invalid operator call syntax, expected a plain name, got {ast.dump(node.func, annotate_fields=False)}"
-                )
+            operator_name = node.func.id

             if operator_name == "type":
                 if len(operator_args) != 1:

@@ -2043,8 +2056,6 @@ class Adjoint:
                 else:
                     raise WarpCodegenError(f"Cannot deduce the type of {var}")

-            raise WarpCodegenError(f"Unknown operator '{operator_name}'")
-
         # reverse list since ast presents it backward order
         path = [*reversed(attributes)]
         if isinstance(node, ast.Name):

@@ -2071,14 +2082,14 @@ class Adjoint:
     def set_lineno(adj, lineno):
         if adj.lineno is None or adj.lineno != lineno:
             line = lineno + adj.fun_lineno
-            source = adj.
+            source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
             adj.add_forward(f"// {source} <L {line}>")
             adj.add_reverse(f"// adj: {source} <L {line}>")
         adj.lineno = lineno

     def get_node_source(adj, node):
         # return the Python code corresponding to the given AST node
-        return ast.get_source_segment(
+        return ast.get_source_segment(adj.source, node)


 # ----------------

@@ -2130,7 +2141,9 @@ struct {name}
     {{
     }}

-    CUDA_CALLABLE {name}& operator += (const {name}&)
+    CUDA_CALLABLE {name}& operator += (const {name}& rhs)
+    {{{prefix_add_body}
+        return *this;}}

 }};


@@ -2357,6 +2370,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
     forward_initializers = []
     reverse_body = []
     atomic_add_body = []
+    prefix_add_body = []

     # forward args
     for label, var in struct.vars.items():

@@ -2370,6 +2384,11 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         prefix = f"{indent_block}," if forward_initializers else ":"
         forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")

+    # prefix-add operator
+    for label, var in struct.vars.items():
+        if not is_array(var.type):
+            prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
     # reverse args
     for label, var in struct.vars.items():
         reverse_args.append(var.ctype() + " & adj_" + label)

@@ -2387,6 +2406,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         forward_initializers="".join(forward_initializers),
         reverse_args=indent(reverse_args),
         reverse_body="".join(reverse_body),
+        prefix_add_body="".join(prefix_add_body),
         atomic_add_body="".join(atomic_add_body),
     )
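
The codegen changes above make the reverse pass accumulate a struct member's adjoint with += when the member is a differentiable value type, instead of overwriting it, and back this with a member-wise prefix operator += emitted into generated structs (array members are skipped, mirroring the existing atomic-add path). The sketch below is not part of the package; it only assumes Warp's public @wp.struct and @wp.kernel decorators, and illustrates the kind of kernel the change affects, i.e. a value-type struct field that is read more than once in a differentiable kernel.

import warp as wp

wp.init()

@wp.struct
class Params:
    scale: float   # value-type member: its adjoint is accumulated, not overwritten
    offset: float

@wp.kernel
def apply(params: Params, x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    # params.scale is read twice; with the += adjoint codegen its gradient
    # receives the sum of both contributions rather than only the last one
    y[i] = params.scale * x[i] + params.scale + params.offset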
warp/config.py
CHANGED
@@ -5,7 +5,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

-version = "1.0.0-beta.
+version = "1.0.0-beta.6"

 cuda_path = (
     None  # path to local CUDA toolchain, if None at init time warp will attempt to find the SDK using CUDA_PATH env var

@@ -33,3 +33,5 @@ ptx_target_arch = 70 # target architecture for PTX generation, defaults to the
 enable_backward = True  # whether to compiler the backward passes of the kernels

 llvm_cuda = False  # use Clang/LLVM instead of NVRTC to compile CUDA
+
+graph_capture_module_load_default = True  # Default value of force_module_load for capture_begin()
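
The new graph_capture_module_load_default flag is documented in the diff as the default value of force_module_load for capture_begin(). A minimal usage sketch, not taken from the package: it assumes a CUDA-capable device named "cuda:0" and that the flag is consulted when capture_begin() is called without an explicit force_module_load argument; otherwise it uses only Warp's public graph-capture API.

import warp as wp

wp.init()

@wp.kernel
def scale(x: wp.array(dtype=float), s: float):
    i = wp.tid()
    x[i] = x[i] * s

# assumption: this config value supplies the default for force_module_load
# when capture_begin() is called without that argument
wp.config.graph_capture_module_load_default = True

x = wp.zeros(1024, dtype=float, device="cuda:0")

wp.capture_begin(device="cuda:0")
try:
    # kernel launches issued here are recorded into a CUDA graph
    wp.launch(scale, dim=x.shape[0], inputs=[x, 2.0], device="cuda:0")
finally:
    graph = wp.capture_end(device="cuda:0")

# replay the captured graph without re-issuing the individual launches
wp.capture_launch(graph)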