warp-lang 1.8.0__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +48 -63
- warp/builtins.py +955 -137
- warp/codegen.py +327 -209
- warp/config.py +1 -1
- warp/context.py +1363 -800
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +266 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +200 -91
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +203 -54
- warp/marching_cubes.py +708 -0
- warp/native/array.h +103 -8
- warp/native/builtin.h +90 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +13 -3
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +42 -11
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +4 -4
- warp/native/mat.h +1913 -119
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +5 -3
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +337 -16
- warp/native/rand.h +7 -7
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +14 -14
- warp/native/spatial.h +366 -17
- warp/native/svd.h +23 -8
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +303 -70
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +385 -18
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +337 -193
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +137 -57
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/graph_coloring.py +2 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +559 -176
- warp/tape.py +2 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +82 -7
- warp/tests/test_array.py +56 -5
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1540 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +162 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +103 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tape.py +38 -0
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +216 -441
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +206 -152
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +16 -16
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +16 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
|
@@ -42,7 +42,7 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
|
|
|
42
42
|
|
|
43
43
|
# The Python function to call.
|
|
44
44
|
# Note the argument annotations, just like Warp kernels.
|
|
45
|
-
def
|
|
45
|
+
def scale_func(
|
|
46
46
|
# inputs
|
|
47
47
|
a: wp.array(dtype=float),
|
|
48
48
|
b: wp.array(dtype=wp.vec2),
|
|
@@ -55,8 +55,23 @@ def example_func(
|
|
|
55
55
|
wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
|
|
56
56
|
|
|
57
57
|
|
|
58
|
+
@wp.kernel
|
|
59
|
+
def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
|
|
60
|
+
tid = wp.tid()
|
|
61
|
+
b[tid] += a[tid]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def in_out_func(
|
|
65
|
+
a: wp.array(dtype=float), # input only
|
|
66
|
+
b: wp.array(dtype=float), # input and output
|
|
67
|
+
c: wp.array(dtype=float), # output only
|
|
68
|
+
):
|
|
69
|
+
wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
|
|
70
|
+
wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
|
|
71
|
+
|
|
72
|
+
|
|
58
73
|
def example1():
|
|
59
|
-
jax_func = jax_callable(
|
|
74
|
+
jax_func = jax_callable(scale_func, num_outputs=2)
|
|
60
75
|
|
|
61
76
|
@jax.jit
|
|
62
77
|
def f():
|
|
@@ -78,7 +93,7 @@ def example1():
|
|
|
78
93
|
|
|
79
94
|
|
|
80
95
|
def example2():
|
|
81
|
-
jax_func = jax_callable(
|
|
96
|
+
jax_func = jax_callable(scale_func, num_outputs=2)
|
|
82
97
|
|
|
83
98
|
# NOTE: scalar arguments must be static compile-time constants
|
|
84
99
|
@partial(jax.jit, static_argnames=["s"])
|
|
@@ -100,11 +115,26 @@ def example2():
|
|
|
100
115
|
print(r2)
|
|
101
116
|
|
|
102
117
|
|
|
118
|
+
def example3():
|
|
119
|
+
# Using input-output arguments
|
|
120
|
+
|
|
121
|
+
jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
|
|
122
|
+
|
|
123
|
+
f = jax.jit(jax_func)
|
|
124
|
+
|
|
125
|
+
a = jnp.ones(10, dtype=jnp.float32)
|
|
126
|
+
b = jnp.arange(10, dtype=jnp.float32)
|
|
127
|
+
|
|
128
|
+
b, c = f(a, b)
|
|
129
|
+
print(b)
|
|
130
|
+
print(c)
|
|
131
|
+
|
|
132
|
+
|
|
103
133
|
def main():
|
|
104
134
|
wp.init()
|
|
105
135
|
wp.load_module(device=wp.get_device())
|
|
106
136
|
|
|
107
|
-
examples = [example1, example2]
|
|
137
|
+
examples = [example1, example2, example3]
|
|
108
138
|
|
|
109
139
|
for example in examples:
|
|
110
140
|
print("\n===========================================================================")
|
|
@@ -72,6 +72,17 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
|
|
|
72
72
|
output[tid] = a[tid] * s
|
|
73
73
|
|
|
74
74
|
|
|
75
|
+
@wp.kernel
|
|
76
|
+
def in_out_kernel(
|
|
77
|
+
a: wp.array(dtype=float), # input only
|
|
78
|
+
b: wp.array(dtype=float), # input and output
|
|
79
|
+
c: wp.array(dtype=float), # output only
|
|
80
|
+
):
|
|
81
|
+
tid = wp.tid()
|
|
82
|
+
b[tid] += a[tid]
|
|
83
|
+
c[tid] = 2.0 * a[tid]
|
|
84
|
+
|
|
85
|
+
|
|
75
86
|
def example1():
|
|
76
87
|
# two inputs and one output
|
|
77
88
|
jax_add = jax_kernel(add_kernel)
|
|
@@ -189,11 +200,26 @@ def example7():
|
|
|
189
200
|
print(f())
|
|
190
201
|
|
|
191
202
|
|
|
203
|
+
def example8():
|
|
204
|
+
# Using input-output arguments
|
|
205
|
+
|
|
206
|
+
jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
|
|
207
|
+
|
|
208
|
+
f = jax.jit(jax_func)
|
|
209
|
+
|
|
210
|
+
a = jnp.ones(10, dtype=jnp.float32)
|
|
211
|
+
b = jnp.arange(10, dtype=jnp.float32)
|
|
212
|
+
|
|
213
|
+
b, c = f(a, b)
|
|
214
|
+
print(b)
|
|
215
|
+
print(c)
|
|
216
|
+
|
|
217
|
+
|
|
192
218
|
def main():
|
|
193
219
|
wp.init()
|
|
194
220
|
wp.load_module(device=wp.get_device())
|
|
195
221
|
|
|
196
|
-
examples = [example1, example2, example3, example4, example5, example6, example7]
|
|
222
|
+
examples = [example1, example2, example3, example4, example5, example6, example7, example8]
|
|
197
223
|
|
|
198
224
|
for example in examples:
|
|
199
225
|
print("\n===========================================================================")
|
warp/fabric.py
CHANGED
|
@@ -211,7 +211,7 @@ class fabricarray(noncontiguous_array_base[T]):
|
|
|
211
211
|
allocator = self.device.get_allocator()
|
|
212
212
|
buckets_ptr = allocator.alloc(buckets_size)
|
|
213
213
|
cuda_stream = self.device.stream.cuda_stream
|
|
214
|
-
runtime.core.
|
|
214
|
+
runtime.core.wp_memcpy_h2d(
|
|
215
215
|
self.device.context, buckets_ptr, ctypes.addressof(buckets), buckets_size, cuda_stream
|
|
216
216
|
)
|
|
217
217
|
self.deleter = allocator.deleter
|
warp/fem/cache.py
CHANGED
|
@@ -34,7 +34,7 @@ _key_re = re.compile("[^0-9a-zA-Z_]+")
|
|
|
34
34
|
|
|
35
35
|
def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
|
|
36
36
|
# human-readable part
|
|
37
|
-
|
|
37
|
+
suffix = str(suffix)
|
|
38
38
|
|
|
39
39
|
sorted_opts = sorted(options.items()) if options is not None else ()
|
|
40
40
|
opts_str = "".join(
|
|
@@ -49,7 +49,7 @@ def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
|
|
|
49
49
|
uid = hashlib.blake2b(bytes(opts_str, encoding="utf-8"), digest_size=4).hexdigest()
|
|
50
50
|
|
|
51
51
|
# avoid long keys, issues on win
|
|
52
|
-
key = f"{
|
|
52
|
+
key = f"{obj.__name__}_{suffix[:32]}_{uid}"
|
|
53
53
|
|
|
54
54
|
return key
|
|
55
55
|
|
|
@@ -62,7 +62,10 @@ def _arg_type_name(arg_type):
|
|
|
62
62
|
return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
def _make_cache_key(func, key, argspec=None):
|
|
65
|
+
def _make_cache_key(func, key, argspec=None, allow_overloads: bool = True):
|
|
66
|
+
if not allow_overloads:
|
|
67
|
+
return key
|
|
68
|
+
|
|
66
69
|
if argspec is None:
|
|
67
70
|
annotations = get_annotations(func)
|
|
68
71
|
else:
|
|
@@ -80,6 +83,7 @@ def _register_function(
|
|
|
80
83
|
):
|
|
81
84
|
# wp.Function will override existing func for a given key...
|
|
82
85
|
# manually add back our overloads
|
|
86
|
+
key = _key_re.sub("", key)
|
|
83
87
|
existing = module.functions.get(key)
|
|
84
88
|
new_fn = wp.Function(
|
|
85
89
|
func=func,
|
|
@@ -95,9 +99,9 @@ def _register_function(
|
|
|
95
99
|
return module.functions[key]
|
|
96
100
|
|
|
97
101
|
|
|
98
|
-
def get_func(func, suffix: str, code_transformers=None):
|
|
102
|
+
def get_func(func, suffix: str, code_transformers=None, allow_overloads=False):
|
|
99
103
|
key = _make_key(func, suffix)
|
|
100
|
-
cache_key = _make_cache_key(func, key)
|
|
104
|
+
cache_key = _make_cache_key(func, key, allow_overloads=allow_overloads)
|
|
101
105
|
|
|
102
106
|
if cache_key not in _func_cache:
|
|
103
107
|
module = wp.get_module(func.__module__)
|
|
@@ -111,9 +115,9 @@ def get_func(func, suffix: str, code_transformers=None):
|
|
|
111
115
|
return _func_cache[cache_key]
|
|
112
116
|
|
|
113
117
|
|
|
114
|
-
def dynamic_func(suffix: str, code_transformers=None):
|
|
118
|
+
def dynamic_func(suffix: str, code_transformers=None, allow_overloads=False):
|
|
115
119
|
def wrap_func(func: Callable):
|
|
116
|
-
return get_func(func, suffix=suffix, code_transformers=code_transformers)
|
|
120
|
+
return get_func(func, suffix=suffix, code_transformers=code_transformers, allow_overloads=allow_overloads)
|
|
117
121
|
|
|
118
122
|
return wrap_func
|
|
119
123
|
|
|
@@ -122,46 +126,49 @@ def get_kernel(
|
|
|
122
126
|
func,
|
|
123
127
|
suffix: str,
|
|
124
128
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
129
|
+
allow_overloads=False,
|
|
125
130
|
):
|
|
126
131
|
if kernel_options is None:
|
|
127
132
|
kernel_options = {}
|
|
128
133
|
|
|
129
134
|
key = _make_key(func, suffix, kernel_options)
|
|
130
|
-
cache_key = _make_cache_key(func, key)
|
|
135
|
+
cache_key = _make_cache_key(func, key, allow_overloads=allow_overloads)
|
|
131
136
|
|
|
132
137
|
if cache_key not in _kernel_cache:
|
|
133
|
-
|
|
138
|
+
kernel_key = _key_re.sub("", key)
|
|
139
|
+
module_name = f"{func.__module__}.dyn.{kernel_key}"
|
|
134
140
|
module = wp.get_module(module_name)
|
|
135
141
|
module.options = dict(wp.get_module(func.__module__).options)
|
|
136
142
|
module.options.update(kernel_options)
|
|
137
|
-
_kernel_cache[cache_key] = wp.Kernel(func=func, key=
|
|
143
|
+
_kernel_cache[cache_key] = wp.Kernel(func=func, key=kernel_key, module=module, options=kernel_options)
|
|
138
144
|
return _kernel_cache[cache_key]
|
|
139
145
|
|
|
140
146
|
|
|
141
|
-
def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None):
|
|
147
|
+
def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None, allow_overloads=False):
|
|
142
148
|
if kernel_options is None:
|
|
143
149
|
kernel_options = {}
|
|
144
150
|
|
|
145
151
|
def wrap_kernel(func: Callable):
|
|
146
|
-
return get_kernel(func, suffix=suffix, kernel_options=kernel_options)
|
|
152
|
+
return get_kernel(func, suffix=suffix, kernel_options=kernel_options, allow_overloads=allow_overloads)
|
|
147
153
|
|
|
148
154
|
return wrap_kernel
|
|
149
155
|
|
|
150
156
|
|
|
151
157
|
def get_struct(struct: type, suffix: str):
|
|
152
158
|
key = _make_key(struct, suffix)
|
|
153
|
-
|
|
154
|
-
struct.__qualname__ = key
|
|
159
|
+
cache_key = key
|
|
155
160
|
|
|
156
|
-
if
|
|
161
|
+
if cache_key not in _struct_cache:
|
|
162
|
+
# used in codegen
|
|
163
|
+
struct.__qualname__ = _key_re.sub("", key)
|
|
157
164
|
module = wp.get_module(struct.__module__)
|
|
158
|
-
_struct_cache[
|
|
159
|
-
key=
|
|
165
|
+
_struct_cache[cache_key] = wp.codegen.Struct(
|
|
166
|
+
key=struct.__qualname__,
|
|
160
167
|
cls=struct,
|
|
161
168
|
module=module,
|
|
162
169
|
)
|
|
163
170
|
|
|
164
|
-
return _struct_cache[
|
|
171
|
+
return _struct_cache[cache_key]
|
|
165
172
|
|
|
166
173
|
|
|
167
174
|
def dynamic_struct(suffix: str):
|
|
@@ -293,12 +300,13 @@ def get_integrand_kernel(
|
|
|
293
300
|
options.update(kernel_options)
|
|
294
301
|
|
|
295
302
|
kernel_key = _make_key(integrand.func, suffix, options=options)
|
|
296
|
-
cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec)
|
|
303
|
+
cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec, allow_overloads=True)
|
|
297
304
|
|
|
298
305
|
if cache_key not in _kernel_cache:
|
|
299
306
|
if kernel_fn is None:
|
|
300
307
|
return None
|
|
301
308
|
|
|
309
|
+
kernel_key = _key_re.sub("", kernel_key)
|
|
302
310
|
module = wp.get_module(f"{integrand.module.name}.{kernel_key}")
|
|
303
311
|
module.options = options
|
|
304
312
|
_kernel_cache[cache_key] = wp.Kernel(
|
warp/fem/domain.py
CHANGED
|
@@ -237,11 +237,11 @@ class Cells(GeometryDomain):
|
|
|
237
237
|
filter_target = True
|
|
238
238
|
pos_type = cache.cached_vec_type(self.geometry.dimension, dtype=float)
|
|
239
239
|
|
|
240
|
-
@cache.dynamic_func(suffix=self.name)
|
|
240
|
+
@cache.dynamic_func(suffix=self.name, allow_overloads=True)
|
|
241
241
|
def cell_partition_lookup(args: self.DomainArg, pos: pos_type, max_dist: float):
|
|
242
242
|
return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
|
|
243
243
|
|
|
244
|
-
@cache.dynamic_func(suffix=self.name)
|
|
244
|
+
@cache.dynamic_func(suffix=self.name, allow_overloads=True)
|
|
245
245
|
def cell_partition_lookup(args: self.DomainArg, pos: pos_type):
|
|
246
246
|
max_dist = 0.0
|
|
247
247
|
return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
|
warp/fem/field/nodal_field.py
CHANGED
|
@@ -232,7 +232,7 @@ class NodalFieldBase(DiscreteField):
|
|
|
232
232
|
@cache.dynamic_func(suffix=self.name)
|
|
233
233
|
def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
|
|
234
234
|
grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
|
|
235
|
-
return
|
|
235
|
+
return eval_grad_outer(args, s, grad_transform)
|
|
236
236
|
|
|
237
237
|
return eval_grad_outer_world_space
|
|
238
238
|
else:
|
|
@@ -240,7 +240,7 @@ class NodalFieldBase(DiscreteField):
|
|
|
240
240
|
@cache.dynamic_func(suffix=self.name)
|
|
241
241
|
def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
|
|
242
242
|
grad_transform = 1.0
|
|
243
|
-
return
|
|
243
|
+
return eval_grad_outer(args, s, grad_transform)
|
|
244
244
|
|
|
245
245
|
return eval_grad_outer_ref_space
|
|
246
246
|
|