warp-lang 1.3.1__py3-none-win_amd64.whl → 1.3.3__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. (The original page linked to further details, but the link was not preserved in this copy.)
- warp/autograd.py +6 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/builtins.py +44 -41
- warp/codegen.py +27 -38
- warp/config.py +1 -1
- warp/context.py +159 -110
- warp/examples/fem/example_mixed_elasticity.py +33 -23
- warp/fem/field/nodal_field.py +1 -1
- warp/fem/quadrature/quadrature.py +1 -0
- warp/native/builtin.h +3 -3
- warp/native/bvh.h +1 -1
- warp/native/svd.h +22 -7
- warp/native/warp.cpp +1 -0
- warp/native/warp.cu +5 -0
- warp/native/warp.h +1 -0
- warp/sim/collide.py +1 -1
- warp/sim/model.py +16 -3
- warp/sim/utils.py +1 -1
- warp/stubs.py +112 -112
- warp/tests/test_array.py +45 -0
- warp/tests/test_async.py +3 -1
- warp/tests/test_bvh.py +33 -8
- warp/tests/test_compile_consts.py +15 -0
- warp/tests/test_examples.py +6 -1
- warp/tests/test_fem.py +51 -0
- warp/tests/test_grad_debug.py +2 -1
- warp/tests/test_model.py +55 -0
- warp/tests/test_point_triangle_closest_point.py +143 -0
- warp/tests/test_reload.py +28 -0
- warp/tests/test_struct.py +48 -30
- warp/tests/test_volume.py +30 -0
- warp/types.py +11 -8
- {warp_lang-1.3.1.dist-info → warp_lang-1.3.3.dist-info}/METADATA +14 -14
- {warp_lang-1.3.1.dist-info → warp_lang-1.3.3.dist-info}/RECORD +38 -37
- {warp_lang-1.3.1.dist-info → warp_lang-1.3.3.dist-info}/WHEEL +1 -1
- {warp_lang-1.3.1.dist-info → warp_lang-1.3.3.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.1.dist-info → warp_lang-1.3.3.dist-info}/top_level.txt +0 -0
warp/autograd.py
CHANGED
|
@@ -593,13 +593,13 @@ def jacobian(
|
|
|
593
593
|
input_output_mask = []
|
|
594
594
|
arg_names = [arg.label for arg in kernel.adj.args]
|
|
595
595
|
|
|
596
|
-
def resolve_arg(name):
|
|
596
|
+
def resolve_arg(name, offset: int = 0):
|
|
597
597
|
if isinstance(name, int):
|
|
598
598
|
return name
|
|
599
|
-
return arg_names.index(name)
|
|
599
|
+
return arg_names.index(name) + offset
|
|
600
600
|
|
|
601
601
|
input_output_mask = [
|
|
602
|
-
(resolve_arg(input_name), resolve_arg(output_name
|
|
602
|
+
(resolve_arg(input_name), resolve_arg(output_name, -len(inputs)))
|
|
603
603
|
for input_name, output_name in input_output_mask
|
|
604
604
|
]
|
|
605
605
|
input_output_mask = set(input_output_mask)
|
|
@@ -694,13 +694,13 @@ def jacobian_fd(
|
|
|
694
694
|
input_output_mask = []
|
|
695
695
|
arg_names = [arg.label for arg in kernel.adj.args]
|
|
696
696
|
|
|
697
|
-
def resolve_arg(name):
|
|
697
|
+
def resolve_arg(name, offset: int = 0):
|
|
698
698
|
if isinstance(name, int):
|
|
699
699
|
return name
|
|
700
|
-
return arg_names.index(name)
|
|
700
|
+
return arg_names.index(name) + offset
|
|
701
701
|
|
|
702
702
|
input_output_mask = [
|
|
703
|
-
(resolve_arg(input_name), resolve_arg(output_name
|
|
703
|
+
(resolve_arg(input_name), resolve_arg(output_name, -len(inputs)))
|
|
704
704
|
for input_name, output_name in input_output_mask
|
|
705
705
|
]
|
|
706
706
|
input_output_mask = set(input_output_mask)
|
warp/bin/warp-clang.dll
CHANGED
|
Binary file
|
warp/bin/warp.dll
CHANGED
|
Binary file
|
warp/builtins.py
CHANGED
|
@@ -549,7 +549,9 @@ add_builtin(
|
|
|
549
549
|
add_builtin(
|
|
550
550
|
"skew",
|
|
551
551
|
input_types={"vec": vector(length=3, dtype=Scalar)},
|
|
552
|
-
value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=
|
|
552
|
+
value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=Scalar)
|
|
553
|
+
if arg_types is None
|
|
554
|
+
else matrix(shape=(3, 3), dtype=arg_types["vec"]._wp_scalar_type_),
|
|
553
555
|
group="Vector Math",
|
|
554
556
|
doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``vec``.",
|
|
555
557
|
)
|
|
@@ -603,9 +605,9 @@ add_builtin(
|
|
|
603
605
|
add_builtin(
|
|
604
606
|
"transpose",
|
|
605
607
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
606
|
-
value_func=lambda arg_types, arg_values: matrix(
|
|
607
|
-
|
|
608
|
-
),
|
|
608
|
+
value_func=lambda arg_types, arg_values: matrix(shape=(Any, Any), dtype=Scalar)
|
|
609
|
+
if arg_types is None
|
|
610
|
+
else matrix(shape=(arg_types["a"]._shape_[1], arg_types["a"]._shape_[0]), dtype=arg_types["a"]._wp_scalar_type_),
|
|
609
611
|
group="Vector Math",
|
|
610
612
|
doc="Return the transpose of the matrix ``a``.",
|
|
611
613
|
)
|
|
@@ -1652,14 +1654,18 @@ add_builtin(
|
|
|
1652
1654
|
add_builtin(
|
|
1653
1655
|
"spatial_top",
|
|
1654
1656
|
input_types={"svec": vector(length=6, dtype=Float)},
|
|
1655
|
-
value_func=lambda arg_types, arg_values: vector(length=3, dtype=
|
|
1657
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=Float)
|
|
1658
|
+
if arg_types is None
|
|
1659
|
+
else vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
|
|
1656
1660
|
group="Spatial Math",
|
|
1657
1661
|
doc="Return the top (first) part of a 6D screw vector.",
|
|
1658
1662
|
)
|
|
1659
1663
|
add_builtin(
|
|
1660
1664
|
"spatial_bottom",
|
|
1661
1665
|
input_types={"svec": vector(length=6, dtype=Float)},
|
|
1662
|
-
value_func=lambda arg_types, arg_values: vector(length=3, dtype=
|
|
1666
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=Float)
|
|
1667
|
+
if arg_types is None
|
|
1668
|
+
else vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
|
|
1663
1669
|
group="Spatial Math",
|
|
1664
1670
|
doc="Return the bottom (second) part of a 6D screw vector.",
|
|
1665
1671
|
)
|
|
@@ -3079,7 +3085,7 @@ add_builtin(
|
|
|
3079
3085
|
add_builtin(
|
|
3080
3086
|
"select",
|
|
3081
3087
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
3082
|
-
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3088
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
3083
3089
|
doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
3084
3090
|
group="Utility",
|
|
3085
3091
|
)
|
|
@@ -3087,14 +3093,14 @@ for t in int_types:
|
|
|
3087
3093
|
add_builtin(
|
|
3088
3094
|
"select",
|
|
3089
3095
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
3090
|
-
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3096
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
3091
3097
|
doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
3092
3098
|
group="Utility",
|
|
3093
3099
|
)
|
|
3094
3100
|
add_builtin(
|
|
3095
3101
|
"select",
|
|
3096
3102
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
3097
|
-
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3103
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
3098
3104
|
doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
3099
3105
|
group="Utility",
|
|
3100
3106
|
)
|
|
@@ -3145,14 +3151,9 @@ def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, A
|
|
|
3145
3151
|
|
|
3146
3152
|
idx_count = len(idx_types)
|
|
3147
3153
|
|
|
3148
|
-
if idx_count
|
|
3154
|
+
if idx_count != arr_type.ndim:
|
|
3149
3155
|
raise RuntimeError(
|
|
3150
|
-
"
|
|
3151
|
-
)
|
|
3152
|
-
|
|
3153
|
-
if idx_count > arr_type.ndim:
|
|
3154
|
-
raise RuntimeError(
|
|
3155
|
-
f"Num indices > num dimensions for array load, received {idx_count}, but array only has {arr_type.ndim}"
|
|
3156
|
+
f"The number of indices provided ({idx_count}) does not match the array dimensions ({arr_type.ndim}) for array load"
|
|
3156
3157
|
)
|
|
3157
3158
|
|
|
3158
3159
|
# check index types
|
|
@@ -3229,14 +3230,9 @@ def array_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
3229
3230
|
|
|
3230
3231
|
idx_count = len(idx_types)
|
|
3231
3232
|
|
|
3232
|
-
if idx_count
|
|
3233
|
-
raise RuntimeError(
|
|
3234
|
-
"Num indices < num dimensions for array store, this is a codegen error, should have generated a view instead"
|
|
3235
|
-
)
|
|
3236
|
-
|
|
3237
|
-
if idx_count > arr_type.ndim:
|
|
3233
|
+
if idx_count != arr_type.ndim:
|
|
3238
3234
|
raise RuntimeError(
|
|
3239
|
-
f"
|
|
3235
|
+
f"The number of indices provided ({idx_count}) does not match the array dimensions ({arr_type.ndim}) for array store"
|
|
3240
3236
|
)
|
|
3241
3237
|
|
|
3242
3238
|
# check index types
|
|
@@ -3335,6 +3331,9 @@ add_builtin(
|
|
|
3335
3331
|
|
|
3336
3332
|
|
|
3337
3333
|
def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
3334
|
+
if arg_types is None:
|
|
3335
|
+
return Any
|
|
3336
|
+
|
|
3338
3337
|
arr_type = arg_types["arr"]
|
|
3339
3338
|
value_type = arg_types["value"]
|
|
3340
3339
|
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
@@ -3377,7 +3376,7 @@ for array_type in array_types:
|
|
|
3377
3376
|
hidden=hidden,
|
|
3378
3377
|
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
3379
3378
|
value_func=atomic_op_value_func,
|
|
3380
|
-
doc="Atomically add ``value`` onto ``arr[i]
|
|
3379
|
+
doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
|
|
3381
3380
|
group="Utility",
|
|
3382
3381
|
skip_replay=True,
|
|
3383
3382
|
)
|
|
@@ -3386,7 +3385,7 @@ for array_type in array_types:
|
|
|
3386
3385
|
hidden=hidden,
|
|
3387
3386
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
3388
3387
|
value_func=atomic_op_value_func,
|
|
3389
|
-
doc="Atomically add ``value`` onto ``arr[i,j]
|
|
3388
|
+
doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
3390
3389
|
group="Utility",
|
|
3391
3390
|
skip_replay=True,
|
|
3392
3391
|
)
|
|
@@ -3395,7 +3394,7 @@ for array_type in array_types:
|
|
|
3395
3394
|
hidden=hidden,
|
|
3396
3395
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3397
3396
|
value_func=atomic_op_value_func,
|
|
3398
|
-
doc="Atomically add ``value`` onto ``arr[i,j,k]
|
|
3397
|
+
doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
3399
3398
|
group="Utility",
|
|
3400
3399
|
skip_replay=True,
|
|
3401
3400
|
)
|
|
@@ -3404,7 +3403,7 @@ for array_type in array_types:
|
|
|
3404
3403
|
hidden=hidden,
|
|
3405
3404
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3406
3405
|
value_func=atomic_op_value_func,
|
|
3407
|
-
doc="Atomically add ``value`` onto ``arr[i,j,k,l]
|
|
3406
|
+
doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
3408
3407
|
group="Utility",
|
|
3409
3408
|
skip_replay=True,
|
|
3410
3409
|
)
|
|
@@ -3414,7 +3413,7 @@ for array_type in array_types:
|
|
|
3414
3413
|
hidden=hidden,
|
|
3415
3414
|
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
3416
3415
|
value_func=atomic_op_value_func,
|
|
3417
|
-
doc="Atomically subtract ``value`` onto ``arr[i]
|
|
3416
|
+
doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
|
|
3418
3417
|
group="Utility",
|
|
3419
3418
|
skip_replay=True,
|
|
3420
3419
|
)
|
|
@@ -3423,7 +3422,7 @@ for array_type in array_types:
|
|
|
3423
3422
|
hidden=hidden,
|
|
3424
3423
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
3425
3424
|
value_func=atomic_op_value_func,
|
|
3426
|
-
doc="Atomically subtract ``value`` onto ``arr[i,j]
|
|
3425
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
3427
3426
|
group="Utility",
|
|
3428
3427
|
skip_replay=True,
|
|
3429
3428
|
)
|
|
@@ -3432,7 +3431,7 @@ for array_type in array_types:
|
|
|
3432
3431
|
hidden=hidden,
|
|
3433
3432
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3434
3433
|
value_func=atomic_op_value_func,
|
|
3435
|
-
doc="Atomically subtract ``value`` onto ``arr[i,j,k]
|
|
3434
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
3436
3435
|
group="Utility",
|
|
3437
3436
|
skip_replay=True,
|
|
3438
3437
|
)
|
|
@@ -3441,7 +3440,7 @@ for array_type in array_types:
|
|
|
3441
3440
|
hidden=hidden,
|
|
3442
3441
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3443
3442
|
value_func=atomic_op_value_func,
|
|
3444
|
-
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]
|
|
3443
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
3445
3444
|
group="Utility",
|
|
3446
3445
|
skip_replay=True,
|
|
3447
3446
|
)
|
|
@@ -3451,7 +3450,7 @@ for array_type in array_types:
|
|
|
3451
3450
|
hidden=hidden,
|
|
3452
3451
|
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
3453
3452
|
value_func=atomic_op_value_func,
|
|
3454
|
-
doc="""Compute the minimum of ``value`` and ``arr[i]
|
|
3453
|
+
doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
3455
3454
|
|
|
3456
3455
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3457
3456
|
group="Utility",
|
|
@@ -3462,7 +3461,7 @@ for array_type in array_types:
|
|
|
3462
3461
|
hidden=hidden,
|
|
3463
3462
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
3464
3463
|
value_func=atomic_op_value_func,
|
|
3465
|
-
doc="""Compute the minimum of ``value`` and ``arr[i,j]
|
|
3464
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
3466
3465
|
|
|
3467
3466
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3468
3467
|
group="Utility",
|
|
@@ -3473,7 +3472,7 @@ for array_type in array_types:
|
|
|
3473
3472
|
hidden=hidden,
|
|
3474
3473
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3475
3474
|
value_func=atomic_op_value_func,
|
|
3476
|
-
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]
|
|
3475
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
3477
3476
|
|
|
3478
3477
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3479
3478
|
group="Utility",
|
|
@@ -3484,7 +3483,7 @@ for array_type in array_types:
|
|
|
3484
3483
|
hidden=hidden,
|
|
3485
3484
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3486
3485
|
value_func=atomic_op_value_func,
|
|
3487
|
-
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]
|
|
3486
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
3488
3487
|
|
|
3489
3488
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3490
3489
|
group="Utility",
|
|
@@ -3496,7 +3495,7 @@ for array_type in array_types:
|
|
|
3496
3495
|
hidden=hidden,
|
|
3497
3496
|
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
3498
3497
|
value_func=atomic_op_value_func,
|
|
3499
|
-
doc="""Compute the maximum of ``value`` and ``arr[i]
|
|
3498
|
+
doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
3500
3499
|
|
|
3501
3500
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3502
3501
|
group="Utility",
|
|
@@ -3507,7 +3506,7 @@ for array_type in array_types:
|
|
|
3507
3506
|
hidden=hidden,
|
|
3508
3507
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
3509
3508
|
value_func=atomic_op_value_func,
|
|
3510
|
-
doc="""Compute the maximum of ``value`` and ``arr[i,j]
|
|
3509
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
3511
3510
|
|
|
3512
3511
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3513
3512
|
group="Utility",
|
|
@@ -3518,7 +3517,7 @@ for array_type in array_types:
|
|
|
3518
3517
|
hidden=hidden,
|
|
3519
3518
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3520
3519
|
value_func=atomic_op_value_func,
|
|
3521
|
-
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]
|
|
3520
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
3522
3521
|
|
|
3523
3522
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3524
3523
|
group="Utility",
|
|
@@ -3529,7 +3528,7 @@ for array_type in array_types:
|
|
|
3529
3528
|
hidden=hidden,
|
|
3530
3529
|
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3531
3530
|
value_func=atomic_op_value_func,
|
|
3532
|
-
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]
|
|
3531
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
3533
3532
|
|
|
3534
3533
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3535
3534
|
group="Utility",
|
|
@@ -4056,7 +4055,7 @@ def matmat_mul_constraint(arg_types: Mapping[str, type]):
|
|
|
4056
4055
|
|
|
4057
4056
|
def matmat_mul_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
4058
4057
|
if arg_types is None:
|
|
4059
|
-
return matrix(
|
|
4058
|
+
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
4060
4059
|
|
|
4061
4060
|
if arg_types["a"]._wp_scalar_type_ != arg_types["b"]._wp_scalar_type_:
|
|
4062
4061
|
raise RuntimeError(
|
|
@@ -4102,7 +4101,11 @@ add_builtin(
|
|
|
4102
4101
|
)
|
|
4103
4102
|
|
|
4104
4103
|
add_builtin(
|
|
4105
|
-
"mod",
|
|
4104
|
+
"mod",
|
|
4105
|
+
input_types={"a": Scalar, "b": Scalar},
|
|
4106
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
4107
|
+
doc="Modulo operation using truncated division.",
|
|
4108
|
+
group="Operators",
|
|
4106
4109
|
)
|
|
4107
4110
|
|
|
4108
4111
|
add_builtin(
|
warp/codegen.py
CHANGED
|
@@ -795,7 +795,13 @@ class Adjoint:
|
|
|
795
795
|
# extract name of source file
|
|
796
796
|
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
797
797
|
# get source file line number where function starts
|
|
798
|
-
|
|
798
|
+
try:
|
|
799
|
+
_, adj.fun_lineno = inspect.getsourcelines(func)
|
|
800
|
+
except OSError as e:
|
|
801
|
+
raise RuntimeError(
|
|
802
|
+
"Directly evaluating Warp code defined as a string using `exec()` is not supported, "
|
|
803
|
+
"please save it on a file and use `importlib` if needed."
|
|
804
|
+
) from e
|
|
799
805
|
|
|
800
806
|
# get function source code
|
|
801
807
|
adj.source = inspect.getsource(func)
|
|
@@ -1592,15 +1598,7 @@ class Adjoint:
|
|
|
1592
1598
|
if node.id in adj.symbols:
|
|
1593
1599
|
return adj.symbols[node.id]
|
|
1594
1600
|
|
|
1595
|
-
|
|
1596
|
-
obj = adj.func.__globals__.get(node.id)
|
|
1597
|
-
|
|
1598
|
-
if obj is None:
|
|
1599
|
-
# Lookup constant in captured contents
|
|
1600
|
-
capturedvars = dict(
|
|
1601
|
-
zip(adj.func.__code__.co_freevars, [c.cell_contents for c in (adj.func.__closure__ or [])])
|
|
1602
|
-
)
|
|
1603
|
-
obj = capturedvars.get(str(node.id), None)
|
|
1601
|
+
obj = adj.resolve_external_reference(node.id)
|
|
1604
1602
|
|
|
1605
1603
|
if obj is None:
|
|
1606
1604
|
raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
|
|
@@ -2299,7 +2297,9 @@ class Adjoint:
|
|
|
2299
2297
|
)
|
|
2300
2298
|
|
|
2301
2299
|
else:
|
|
2302
|
-
raise WarpCodegenError(
|
|
2300
|
+
raise WarpCodegenError(
|
|
2301
|
+
f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
|
|
2302
|
+
)
|
|
2303
2303
|
|
|
2304
2304
|
elif isinstance(lhs, ast.Name):
|
|
2305
2305
|
# symbol name
|
|
@@ -2448,24 +2448,11 @@ class Adjoint:
|
|
|
2448
2448
|
if path[0] in __builtins__:
|
|
2449
2449
|
return __builtins__[path[0]]
|
|
2450
2450
|
|
|
2451
|
-
#
|
|
2452
|
-
|
|
2453
|
-
# to variables you've declared inside that function:
|
|
2454
|
-
def extract_contents(contents):
|
|
2455
|
-
return contents if isinstance(contents, warp.context.Function) or not callable(contents) else contents
|
|
2456
|
-
|
|
2457
|
-
capturedvars = dict(
|
|
2458
|
-
zip(
|
|
2459
|
-
adj.func.__code__.co_freevars, [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])]
|
|
2460
|
-
)
|
|
2461
|
-
)
|
|
2462
|
-
vars_dict = {**adj.func.__globals__, **capturedvars}
|
|
2463
|
-
|
|
2464
|
-
if path[0] in vars_dict:
|
|
2465
|
-
expr = vars_dict[path[0]]
|
|
2451
|
+
# look up in closure/global variables
|
|
2452
|
+
expr = adj.resolve_external_reference(path[0])
|
|
2466
2453
|
|
|
2467
2454
|
# Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
|
|
2468
|
-
|
|
2455
|
+
if expr is None:
|
|
2469
2456
|
expr = getattr(warp, path[0], None)
|
|
2470
2457
|
|
|
2471
2458
|
if expr:
|
|
@@ -2526,6 +2513,16 @@ class Adjoint:
|
|
|
2526
2513
|
|
|
2527
2514
|
return None, path
|
|
2528
2515
|
|
|
2516
|
+
def resolve_external_reference(adj, name: str):
|
|
2517
|
+
try:
|
|
2518
|
+
# look up in closure variables
|
|
2519
|
+
idx = adj.func.__code__.co_freevars.index(name)
|
|
2520
|
+
obj = adj.func.__closure__[idx].cell_contents
|
|
2521
|
+
except ValueError:
|
|
2522
|
+
# look up in global variables
|
|
2523
|
+
obj = adj.func.__globals__.get(name)
|
|
2524
|
+
return obj
|
|
2525
|
+
|
|
2529
2526
|
# annotate generated code with the original source code line
|
|
2530
2527
|
def set_lineno(adj, lineno):
|
|
2531
2528
|
if adj.lineno is None or adj.lineno != lineno:
|
|
@@ -2551,17 +2548,8 @@ class Adjoint:
|
|
|
2551
2548
|
|
|
2552
2549
|
for node in ast.walk(adj.tree):
|
|
2553
2550
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
2554
|
-
#
|
|
2555
|
-
|
|
2556
|
-
# try and resolve the name using the function's globals context (used to lookup constants + functions)
|
|
2557
|
-
obj = adj.func.__globals__.get(node.id)
|
|
2558
|
-
|
|
2559
|
-
if obj is None:
|
|
2560
|
-
# Lookup constant in captured contents
|
|
2561
|
-
capturedvars = dict(
|
|
2562
|
-
zip(adj.func.__code__.co_freevars, [c.cell_contents for c in (adj.func.__closure__ or [])])
|
|
2563
|
-
)
|
|
2564
|
-
obj = capturedvars.get(str(node.id), None)
|
|
2551
|
+
# look up in closure/global variables
|
|
2552
|
+
obj = adj.resolve_external_reference(node.id)
|
|
2565
2553
|
|
|
2566
2554
|
if warp.types.is_value(obj):
|
|
2567
2555
|
constants_dict[node.id] = obj
|
|
@@ -2571,6 +2559,7 @@ class Adjoint:
|
|
|
2571
2559
|
|
|
2572
2560
|
if warp.types.is_value(obj):
|
|
2573
2561
|
constants_dict[".".join(path)] = obj
|
|
2562
|
+
|
|
2574
2563
|
elif isinstance(node, ast.Assign):
|
|
2575
2564
|
# Add the LHS names to the local_variables so we know any subsequent uses are shadowed
|
|
2576
2565
|
lhs = node.targets[0]
|