warp-lang 1.2.2-py3-none-manylinux2014_aarch64.whl → 1.3.0-py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1410 -886
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +66 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.0.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/autograd.py ADDED
@@ -0,0 +1,823 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import itertools
+ from typing import Any, Dict, List, Sequence, Tuple, Union
+
+ import numpy as np
+
+ import warp as wp
+
+ __all__ = [
+     "jacobian",
+     "jacobian_fd",
+     "gradcheck",
+     "gradcheck_tape",
+     "plot_kernel_jacobians",
+ ]
+
+
+ def gradcheck(
+     function: wp.Kernel,
+     dim: Tuple[int],
+     inputs: Sequence,
+     outputs: Sequence,
+     *,
+     eps=1e-4,
+     atol=1e-3,
+     rtol=1e-2,
+     raise_exception=True,
+     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
+     device: wp.context.Devicelike = None,
+     max_blocks=0,
+     max_inputs_per_var=-1,
+     max_outputs_per_var=-1,
+     plot_relative_error=False,
+     plot_absolute_error=False,
+     show_summary: bool = True,
+ ) -> bool:
+     """
+     Checks whether the autodiff gradient of a Warp kernel matches finite differences.
+     Fails if the relative or absolute errors between the autodiff and finite-difference gradients exceed the specified tolerances, or if the autodiff gradients contain NaN values.
+
+     The kernel function and its adjoint are launched with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).
+
+     Note:
+         This function only supports Warp kernels whose input arguments precede the output arguments.
+
+         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
+
+         Struct arguments are not yet supported for computing Jacobians.
+
+     Args:
+         function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+         dim: The number of threads to launch the kernel with; can be an integer or a tuple of ints.
+         inputs: List of input variables.
+         outputs: List of output variables.
+         eps: The finite-difference step size.
+         atol: The absolute tolerance for the gradient check.
+         rtol: The relative tolerance for the gradient check.
+         raise_exception: If True, raises a ``ValueError`` if the gradient check fails.
+         input_output_mask: List of tuples specifying the input-output pairs for which to compute the Jacobian. Inputs and outputs can be identified either by the integer indices at which they appear among the kernel's input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
+         device: The device to launch on (optional).
+         max_blocks: The maximum number of CUDA thread blocks to use.
+         max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if the value is <= 0.
+         max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if the value is <= 0.
+         plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
+         plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
+         show_summary: If True, prints a summary table of the gradient check results.
+
+     Returns:
+         True if the gradient check passes, False otherwise.
+     """
+
+     assert isinstance(function, wp.Kernel), "The function argument must be a Warp kernel"
+
+     jacs_fd = jacobian_fd(
+         function,
+         dim=dim,
+         inputs=inputs,
+         outputs=outputs,
+         input_output_mask=input_output_mask,
+         device=device,
+         max_blocks=max_blocks,
+         max_inputs_per_var=max_inputs_per_var,
+         eps=eps,
+         plot_jacobians=False,
+     )
+
+     jacs_ad = jacobian(
+         function,
+         dim=dim,
+         inputs=inputs,
+         outputs=outputs,
+         input_output_mask=input_output_mask,
+         device=device,
+         max_blocks=max_blocks,
+         max_outputs_per_var=max_outputs_per_var,
+         plot_jacobians=False,
+     )
+
+     relative_error_jacs = {}
+     absolute_error_jacs = {}
+
+     if show_summary:
+         summary = []
+         summary_header = ["Input", "Output", "Max Abs Error", "Max Rel Error", "Pass"]
+
+     class FontColors:
+         OKGREEN = "\033[92m"
+         WARNING = "\033[93m"
+         FAIL = "\033[91m"
+         ENDC = "\033[0m"
+
+     success = True
+     for (input_i, output_i), jac_fd in jacs_fd.items():
+         jac_ad = jacs_ad[input_i, output_i]
+         if plot_relative_error or plot_absolute_error:
+             jac_rel_error = wp.empty_like(jac_fd)
+             jac_abs_error = wp.empty_like(jac_fd)
+             flat_jac_fd = scalarize_array_1d(jac_fd)
+             flat_jac_ad = scalarize_array_1d(jac_ad)
+             flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
+             flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
+             wp.launch(
+                 compute_error_kernel,
+                 dim=len(flat_jac_fd),
+                 inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
+                 device=jac_fd.device,
+             )
+             relative_error_jacs[(input_i, output_i)] = jac_rel_error
+             absolute_error_jacs[(input_i, output_i)] = jac_abs_error
+         cut_jac_fd = jac_fd.numpy()
+         cut_jac_ad = jac_ad.numpy()
+         if max_outputs_per_var > 0:
+             cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
+             cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
+         if max_inputs_per_var > 0:
+             cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
+             cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]
+         grad_matches = np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol)
+         success = success and grad_matches
+         if not grad_matches:
+             if raise_exception:
+                 raise ValueError(
+                     f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
+                     f"finite difference and autodiff gradients do not match"
+                 )
+             elif not show_summary:
+                 return False
+         isnan = np.any(np.isnan(cut_jac_ad))
+         success = success and not isnan
+         if isnan:
+             if raise_exception:
+                 raise ValueError(
+                     f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
+                     f"gradient contains NaN values"
+                 )
+             elif not show_summary:
+                 return False
+
+         if show_summary:
+             max_abs_error = np.abs(cut_jac_ad - cut_jac_fd).max()
+             max_rel_error = np.abs((cut_jac_ad - cut_jac_fd) / (cut_jac_fd + 1e-8)).max()
+             if isnan:
+                 pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
+             elif grad_matches:
+                 pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
+             else:
+                 pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
+             input_name = function.adj.args[input_i].label
+             output_name = function.adj.args[len(inputs) + output_i].label
+             summary.append([input_name, output_name, f"{max_abs_error:.7e}", f"{max_rel_error:.7e}", pass_str])
+
+     if show_summary:
+         print_table(summary_header, summary)
+         if not success:
+             print(FontColors.FAIL + f"Gradient check for kernel {function.key} failed" + FontColors.ENDC)
+         else:
+             print(FontColors.OKGREEN + f"Gradient check for kernel {function.key} passed" + FontColors.ENDC)
+     if plot_relative_error:
+         plot_kernel_jacobians(
+             relative_error_jacs,
+             function,
+             inputs,
+             outputs,
+             title=f"{function.key} kernel Jacobian relative error",
+         )
+     if plot_absolute_error:
+         plot_kernel_jacobians(
+             absolute_error_jacs,
+             function,
+             inputs,
+             outputs,
+             title=f"{function.key} kernel Jacobian absolute error",
+         )
+
+     return success
+
+
+ def gradcheck_tape(
+     tape: wp.Tape,
+     *,
+     eps=1e-4,
+     atol=1e-3,
+     rtol=1e-2,
+     raise_exception=True,
+     input_output_masks: Dict[str, List[Tuple[Union[str, int], Union[str, int]]]] = None,
+     blacklist_kernels: List[str] = None,
+     whitelist_kernels: List[str] = None,
+     max_inputs_per_var=-1,
+     max_outputs_per_var=-1,
+     plot_relative_error=False,
+     plot_absolute_error=False,
+     show_summary: bool = True,
+ ) -> bool:
+     """
+     Checks whether the autodiff gradients for the kernels recorded on the Warp tape match finite differences.
+     Fails if the relative or absolute errors between the autodiff and finite-difference gradients exceed the specified tolerances, or if the autodiff gradients contain NaN values.
+
+     Note:
+         Only Warp kernels recorded on the tape are checked; arbitrary functions recorded via, e.g., :meth:`Tape.record_func` are not.
+
+         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
+
+         Struct arguments are not yet supported for computing Jacobians.
+
+     Args:
+         tape: The Warp tape to perform the gradient check on.
+         eps: The finite-difference step size.
+         atol: The absolute tolerance for the gradient check.
+         rtol: The relative tolerance for the gradient check.
+         raise_exception: If True, raises a ``ValueError`` if the gradient check fails.
+         input_output_masks: Dictionary mapping kernel keys to input-output masks, one per kernel on the tape. Inputs and outputs can be identified either by the integer indices at which they appear among the kernel's input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
+         blacklist_kernels: List of kernel keys to exclude from the gradient check.
+         whitelist_kernels: List of kernel keys to include in the gradient check. If not empty or None, only kernels in this list are checked.
+         max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if the value is <= 0.
+         max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if the value is <= 0.
+         plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
+         plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
+         show_summary: If True, prints a summary table of the gradient check results.
+
+     Returns:
+         True if the gradient check passes for all kernels on the tape, False otherwise.
+     """
+     if input_output_masks is None:
+         input_output_masks = {}
+     if blacklist_kernels is None:
+         blacklist_kernels = []
+     else:
+         blacklist_kernels = set(blacklist_kernels)
+     if whitelist_kernels is None:
+         whitelist_kernels = []
+     else:
+         whitelist_kernels = set(whitelist_kernels)
+
+     overall_success = True
+     for launch in tape.launches:
+         if not isinstance(launch[0], wp.Kernel):
+             continue
+         kernel, dim, max_blocks, inputs, outputs, device = launch[:6]
+         if len(whitelist_kernels) > 0 and kernel.key not in whitelist_kernels:
+             continue
+         if kernel.key in blacklist_kernels:
+             continue
+         input_output_mask = input_output_masks.get(kernel.key)
+         success = gradcheck(
+             kernel,
+             dim,
+             inputs,
+             outputs,
+             eps=eps,
+             atol=atol,
+             rtol=rtol,
+             raise_exception=raise_exception,
+             input_output_mask=input_output_mask,
+             device=device,
+             max_blocks=max_blocks,
+             max_inputs_per_var=max_inputs_per_var,
+             max_outputs_per_var=max_outputs_per_var,
+             plot_relative_error=plot_relative_error,
+             plot_absolute_error=plot_absolute_error,
+             show_summary=show_summary,
+         )
+         overall_success = overall_success and success
+
+     return overall_success
+
+
+ def get_struct_vars(x: wp.codegen.StructInstance):
+     return {varname: getattr(x, varname) for varname, _ in x._cls.ctype._fields_}
+
+
+ def infer_device(xs: list):
+     # retrieve best matching Warp device for a list of variables
+     for x in xs:
+         if isinstance(x, wp.array):
+             return x.device
+         elif isinstance(x, wp.codegen.StructInstance):
+             for var in get_struct_vars(x).values():
+                 if isinstance(var, wp.array):
+                     return var.device
+     return wp.get_preferred_device()
+
+
+ def plot_kernel_jacobians(
+     jacobians: Dict[Tuple[int, int], wp.array],
+     kernel: wp.Kernel,
+     inputs: Sequence,
+     outputs: Sequence,
+     show_plot=True,
+     show_colorbar=True,
+     scale_colors_per_submatrix=False,
+     title: str = None,
+     colormap: str = "coolwarm",
+     log_scale=False,
+ ):
+     """
+     Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
+     Requires the ``matplotlib`` package to be installed.
+
+     Args:
+         jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
+         kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+         inputs: List of input variables.
+         outputs: List of output variables.
+         show_plot: If True, displays the plot via ``plt.show()``.
+         show_colorbar: If True, displays a colorbar next to the plot (or next to every submatrix if ``scale_colors_per_submatrix`` is True).
+         scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
+         title: The title of the plot (optional).
+         colormap: The colormap to use for the plot.
+         log_scale: If True, uses a logarithmic scale for the matrix values shown in the image plot.
+
+     Returns:
+         The created Matplotlib figure.
+     """
+     import matplotlib.pyplot as plt
+     from matplotlib.ticker import FuncFormatter, MaxNLocator, MultipleLocator
+
+     jacobians = sorted(jacobians.items(), key=lambda x: (x[0][1], x[0][0]))
+     jacobians = dict(jacobians)
+
+     input_to_ax = {}
+     output_to_ax = {}
+     for i, j in jacobians.keys():
+         if i not in input_to_ax:
+             input_to_ax[i] = len(input_to_ax)
+         if j not in output_to_ax:
+             output_to_ax[j] = len(output_to_ax)
+
+     num_rows = len(output_to_ax)
+     num_cols = len(input_to_ax)
+     if num_rows == 0 or num_cols == 0:
+         return
+
+     # determine the width and height ratios for the subplots based on the
+     # dimensions of the Jacobians
+     width_ratios = []
+     height_ratios = []
+     for i, input in enumerate(inputs):
+         if not isinstance(input, wp.array) or not input.requires_grad:
+             continue
+         input_stride = input.dtype._length_
+         for j in range(len(outputs)):
+             if (i, j) not in jacobians:
+                 continue
+             jac_wp = jacobians[(i, j)]
+             width_ratios.append(jac_wp.shape[1] * input_stride)
+             break
+
+     for i, output in enumerate(outputs):
+         if not isinstance(output, wp.array) or not output.requires_grad:
+             continue
+         for j in range(len(inputs)):
+             if (j, i) not in jacobians:
+                 continue
+             jac_wp = jacobians[(j, i)]
+             height_ratios.append(jac_wp.shape[0])
+             break
+
+     fig, axs = plt.subplots(
+         ncols=num_cols,
+         nrows=num_rows,
+         figsize=(7, 7),
+         sharex="col",
+         sharey="row",
+         gridspec_kw={
+             "wspace": 0.1,
+             "hspace": 0.1,
+             "width_ratios": width_ratios,
+             "height_ratios": height_ratios,
+         },
+         subplot_kw={"aspect": 1},
+         squeeze=False,
+     )
+     if title is None:
+         title = f"{kernel.key} kernel Jacobian"
+     fig.suptitle(title)
+     fig.canvas.manager.set_window_title(title)
+
+     if not scale_colors_per_submatrix:
+         safe_jacobians = [jac.numpy().flatten() for jac in jacobians.values()]
+         safe_jacobians = [jac[~np.isnan(jac)] for jac in safe_jacobians]
+         safe_jacobians = [jac for jac in safe_jacobians if len(jac) > 0]
+         if len(safe_jacobians) == 0:
+             vmin = 0
+             vmax = 0
+         else:
+             vmin = min([jac.min() for jac in safe_jacobians])
+             vmax = max([jac.max() for jac in safe_jacobians])
+
+     has_plot = np.ones((num_rows, num_cols), dtype=bool)
+     for i in range(num_rows):
+         for j in range(num_cols):
+             if (j, i) not in jacobians:
+                 ax = axs[i, j]
+                 ax.axis("off")
+                 has_plot[i, j] = False
+
+     jac_i = 0
+     for (input_i, output_i), jac_wp in jacobians.items():
+         input = inputs[input_i]
+         output = outputs[output_i]
+         if not isinstance(input, wp.array) or not input.requires_grad:
+             continue
+         if not isinstance(output, wp.array) or not output.requires_grad:
+             continue
+
+         input_name = kernel.adj.args[input_i].label
+         output_name = kernel.adj.args[len(inputs) + output_i].label
+
+         ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
+         ax = axs[ax_i, ax_j]
+         ax.tick_params(which="major", width=1, length=7)
+         ax.tick_params(which="minor", width=1, length=4, color="gray")
+         # ax.yaxis.set_minor_formatter('{x:.0f}')
+
+         input_stride = input.dtype._length_
+         output_stride = output.dtype._length_
+
+         jac = jac_wp.numpy()
+         # Jacobian matrix has output stride already multiplied to first dimension
+         jac = jac.reshape(jac_wp.shape[0], jac_wp.shape[1] * input_stride)
+         ax.xaxis.set_minor_formatter("")
+         ax.yaxis.set_minor_formatter("")
+         ax.xaxis.set_minor_locator(MultipleLocator(1))
+         ax.yaxis.set_minor_locator(MultipleLocator(1))
+         # ax.set_xticks(np.arange(jac.shape[0]))
+         # stride = jac.shape[1] // jacobians[jac_i].shape[1]
+         # ax.xaxis.set_major_locator(MultipleLocator(input_stride))
+         if input_stride > 1:
+             ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[input_stride]))
+             ticks = FuncFormatter(lambda x, pos, input_stride=input_stride: "{0:g}".format(x // input_stride))
+             ax.xaxis.set_major_formatter(ticks)
+         # ax.xaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[1] + 1, input_stride)))
+         # ax.xaxis.set_major_formatter('{x:.0f}')
+         # ticks = np.arange(jac_wp.shape[1] + 1)
+         # ax.set_xticklabels(ticks)
+
+         # ax.yaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[0] + 1, output_stride)))
+         # ax.yaxis.set_major_formatter('{x:.0f}')
+         # ax.yaxis.set_major_locator(MultipleLocator(output_stride))
+
+         if output_stride > 1:
+             ax.yaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[output_stride]))
+             max_y = jac_wp.shape[0]
+             ticks = FuncFormatter(
+                 lambda y, pos, max_y=max_y, output_stride=output_stride: "{0:g}".format((max_y - y) // output_stride)
+             )
+             ax.yaxis.set_major_formatter(ticks)
+             # divide by output stride to get the correct number of rows
+             ticks = np.arange(jac_wp.shape[0] // output_stride + 1)
+             # flip y labels to match the order of matrix rows starting from the top
+             # ax.set_yticklabels(ticks[::-1])
+         if scale_colors_per_submatrix:
+             safe_jac = jac[~np.isnan(jac)]
+             vmin = safe_jac.min()
+             vmax = safe_jac.max()
+         img = ax.imshow(
+             np.log10(np.abs(jac) + 1e-8) if log_scale else jac,
+             cmap=colormap,
+             aspect="auto",
+             interpolation="nearest",
+             extent=[0, jac.shape[1], 0, jac.shape[0]],
+             vmin=vmin,
+             vmax=vmax,
+         )
+         if ax_i == len(outputs) - 1 or not has_plot[ax_i + 1 :, ax_j].any():
+             # last plot of this column
+             ax.set_xlabel(input_name)
+         if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
+             # first plot of this row
+             ax.set_ylabel(output_name)
+         ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
+         ax.grid(color="black", which="major", linewidth=1.0)
+
+         if show_colorbar and scale_colors_per_submatrix:
+             plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)
+
+         jac_i += 1
+
+     if show_colorbar and not scale_colors_per_submatrix:
+         m = plt.cm.ScalarMappable(cmap=colormap)
+         m.set_array([vmin, vmax])
+         m.set_clim(vmin, vmax)
+         plt.colorbar(m, ax=axs, orientation="vertical", pad=0.02)
+
+     plt.tight_layout()
+     if show_plot:
+         plt.show()
+     return fig
+
+
+ def scalarize_array_1d(arr):
+     # convert array to 1D array with scalar dtype
+     if arr.dtype in wp.types.scalar_types:
+         return arr.flatten()
+     elif arr.dtype in wp.types.vector_types:
+         return wp.array(
+             ptr=arr.ptr,
+             shape=(arr.size * arr.dtype._length_,),
+             dtype=arr.dtype._wp_scalar_type_,
+             device=arr.device,
+         )
+     else:
+         raise ValueError(
+             f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
+         )
+
+
+ def scalarize_array_2d(arr):
+     assert arr.ndim == 2
+     # convert array to 2D array with scalar dtype
+     if arr.dtype in wp.types.scalar_types:
+         return arr
+     elif arr.dtype in wp.types.vector_types:
+         return wp.array(
+             ptr=arr.ptr,
+             shape=(arr.shape[0], arr.shape[1] * arr.dtype._length_),
+             dtype=arr.dtype._wp_scalar_type_,
+             device=arr.device,
+         )
+     else:
+         raise ValueError(
+             f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
+         )
+
+
+ def jacobian(
+     kernel: wp.Kernel,
+     dim: Tuple[int],
+     inputs: Sequence,
+     outputs: Sequence = None,
+     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
+     device: wp.context.Devicelike = None,
+     max_blocks=0,
+     max_outputs_per_var=-1,
+     plot_jacobians=False,
+ ) -> Dict[Tuple[int, int], wp.array]:
+     """
+     Computes the Jacobians of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+
+     The kernel adjoint function is launched with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).
+
+     Note:
+         This function only supports Warp kernels whose input arguments precede the output arguments.
+
+         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
+
+         Struct arguments are not yet supported for computing Jacobians.
+
+     Args:
+         kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+         dim: The number of threads to launch the kernel with; can be an integer or a tuple of ints.
+         inputs: List of input variables.
+         outputs: List of output variables. If None, the outputs are inferred from the kernel argument flags.
+         input_output_mask: List of tuples specifying the input-output pairs for which to compute the Jacobian. Inputs and outputs can be identified either by the integer indices at which they appear among the kernel's input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
+         device: The device to launch on (optional).
+         max_blocks: The maximum number of CUDA thread blocks to use.
+         max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if the value is <= 0.
+         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+
+     Returns:
+         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
+     """
+     if outputs is None:
+         outputs = []
+     if input_output_mask is None:
+         input_output_mask = []
+     arg_names = [arg.label for arg in kernel.adj.args]
+
+     def resolve_arg(name):
+         if isinstance(name, int):
+             return name
+         return arg_names.index(name)
+
+     input_output_mask = [
+         (resolve_arg(input_name), resolve_arg(output_name) - len(inputs))
+         for input_name, output_name in input_output_mask
+     ]
+     input_output_mask = set(input_output_mask)
+
+     if device is None:
+         device = infer_device(inputs + outputs)
+
+     tape = wp.Tape()
+     tape.record_launch(kernel=kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=outputs, device=device)
+
+     jacobians = {}
+
+     for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
+         if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
+             continue
+         input = inputs[input_i]
+         output = outputs[output_i]
+         if not isinstance(input, wp.array) or not input.requires_grad:
+             continue
+         if not isinstance(output, wp.array) or not output.requires_grad:
+             continue
+         out_grad = scalarize_array_1d(output.grad)
+         output_num = out_grad.shape[0]
+         jacobian = wp.empty((output_num, input.size), dtype=input.dtype, device=input.device)
+         jacobian.fill_(wp.nan)
+         if max_outputs_per_var > 0:
+             output_num = min(output_num, max_outputs_per_var)
+         for i in range(output_num):
+             tape.zero()
+             if i > 0:
+                 set_element(out_grad, i - 1, 0.0)
+             set_element(out_grad, i, 1.0)
+             tape.backward()
+             jacobian[i].assign(input.grad)
+         output.grad.zero_()
+         jacobians[input_i, output_i] = jacobian
+
+     if plot_jacobians:
+         plot_kernel_jacobians(
+             jacobians,
+             kernel,
+             inputs,
+             outputs,
+         )
+
+     return jacobians
+
+
+ def jacobian_fd(
+     kernel: wp.Kernel,
+     dim: Tuple[int],
+     inputs: Sequence,
+     outputs: Sequence = None,
+     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
+     device: wp.context.Devicelike = None,
+     max_blocks=0,
+     max_inputs_per_var=-1,
+     eps=1e-4,
+     plot_jacobians=False,
+ ) -> Dict[Tuple[int, int], wp.array]:
+     """
+     Computes the finite-difference Jacobian of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+     The method uses a central-difference scheme to approximate the Jacobian.
+
+     The kernel is launched multiple times in forward-only mode with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).
+
+     Note:
+         This function only supports Warp kernels whose input arguments precede the output arguments.
+
+         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
+
+         Struct arguments are not yet supported for computing Jacobians.
+
+     Args:
+         kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+         dim: The number of threads to launch the kernel with; can be an integer or a tuple of ints.
+         inputs: List of input variables.
+         outputs: List of output variables. If None, the outputs are inferred from the kernel argument flags.
+         input_output_mask: List of tuples specifying the input-output pairs for which to compute the Jacobian. Inputs and outputs can be identified either by the integer indices at which they appear among the kernel's input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
+         device: The device to launch on (optional).
+         max_blocks: The maximum number of CUDA thread blocks to use.
+         max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if the value is <= 0.
+         eps: The finite-difference step size.
+         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+
+     Returns:
+         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
+     """
+     if outputs is None:
+         outputs = []
+     if input_output_mask is None:
+         input_output_mask = []
+     arg_names = [arg.label for arg in kernel.adj.args]
+
+     def resolve_arg(name):
+         if isinstance(name, int):
+             return name
+         return arg_names.index(name)
+
+     input_output_mask = [
+         (resolve_arg(input_name), resolve_arg(output_name) - len(inputs))
+         for input_name, output_name in input_output_mask
+     ]
+     input_output_mask = set(input_output_mask)
+
+     if device is None:
+         device = infer_device(inputs + outputs)
+
+     jacobians = {}
+
+     for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
+         if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
+             continue
+         input = inputs[input_i]
+         output = outputs[output_i]
+         if not isinstance(input, wp.array) or not input.requires_grad:
+             continue
+         if not isinstance(output, wp.array) or not output.requires_grad:
+             continue
+
+         flat_input = scalarize_array_1d(input)
+
+         left = wp.clone(output)
+         right = wp.clone(output)
+         flat_left = scalarize_array_1d(left)
+         flat_right = scalarize_array_1d(right)
+
+         left_outputs = outputs[:output_i] + [left] + outputs[output_i + 1 :]
+         right_outputs = outputs[:output_i] + [right] + outputs[output_i + 1 :]
+
+         input_num = flat_input.shape[0]
+         jacobian = wp.empty((flat_left.size, input.size), dtype=input.dtype, device=input.device)
+         jacobian.fill_(wp.nan)
+
+         jacobian_scalar = scalarize_array_2d(jacobian)
+         jacobian_t = jacobian_scalar.transpose()
+         if max_inputs_per_var > 0:
+             input_num = min(input_num, max_inputs_per_var)
+         for i in range(input_num):
+             set_element(flat_input, i, -eps, relative=True)
+             wp.launch(kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=left_outputs, device=device)
+
+             set_element(flat_input, i, 2 * eps, relative=True)
+             wp.launch(kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=right_outputs, device=device)
+
+             set_element(flat_input, i, -eps, relative=True)
+
+             compute_fd(flat_left, flat_right, eps, jacobian_t[i])
+
+         output.grad.zero_()
+         jacobians[input_i, output_i] = jacobian
+
+     if plot_jacobians:
+         plot_kernel_jacobians(
+             jacobians,
+             kernel,
+             inputs,
+             outputs,
+         )
+
+     return jacobians
+
+
+ @wp.kernel(enable_backward=False)
+ def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
+     if relative:
+         a[i] += val
+     else:
+         a[i] = val
+
+
+ def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
+     wp.launch(set_element_kernel, dim=1, inputs=[a, i, a.dtype(val), relative], device=a.device)
+
+
+ @wp.kernel(enable_backward=False)
+ def compute_fd_kernel(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: Any, fd: wp.array(dtype=Any)):
+     tid = wp.tid()
+     fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
+
+
+ def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
+     wp.launch(compute_fd_kernel, dim=len(left), inputs=[left, right, eps], outputs=[fd], device=left.device)
+
+
+ @wp.kernel(enable_backward=False)
+ def compute_error_kernel(
+     jacobian_ad: wp.array(dtype=Any),
+     jacobian_fd: wp.array(dtype=Any),
+     relative_error: wp.array(dtype=Any),
+     absolute_error: wp.array(dtype=Any),
+ ):
+     tid = wp.tid()
+     ad = jacobian_ad[tid]
+     fd = jacobian_fd[tid]
+     relative_error[tid] = (ad - fd) / (ad + 1e-8)
+     absolute_error[tid] = wp.abs(ad - fd)
+
+
+ def print_table(headers, cells):
+     """
+     Prints a table with the given headers and cells.
+
+     Args:
+         headers: List of header strings.
+         cells: List of lists of cell strings.
+     """
+     import re
+
+     def sanitized_len(s):
+         return len(re.sub(r"\033\[\d+m", "", str(s)))
+
+     col_widths = [max(sanitized_len(cell) for cell in col) for col in zip(headers, *cells)]
+     for header, col_width in zip(headers, col_widths):
+         print(f"{header:{col_width}}", end=" | ")
+     print()
+     print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
+     for cell_row in cells:
+         for cell, col_width in zip(cell_row, col_widths):
+             print(f"{cell:{col_width}}", end=" | ")
+         print()
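
For readers evaluating this new module, the snippet below is a minimal sketch of how the `gradcheck` entry point added in this release might be exercised. The kernel and array contents are illustrative assumptions, not part of the package; only `wp.autograd.gradcheck` and its documented arguments come from the diff above.

import warp as wp
import warp.autograd


@wp.kernel
def scale_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
    # hypothetical example kernel: y = 3 * x^2, so dy/dx = 6 * x
    tid = wp.tid()
    y[tid] = 3.0 * x[tid] * x[tid]


# arrays must be created with requires_grad=True to be considered for Jacobians
x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
y = wp.zeros_like(x)

# compare autodiff gradients against central finite differences
passed = wp.autograd.gradcheck(scale_kernel, dim=3, inputs=[x], outputs=[y], raise_exception=False)
print("gradcheck passed:", passed)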
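
Similarly, a sketch of the tape-level check and the dense-Jacobian helper, under the same assumptions (the toy `saxpy_kernel` and its array contents are illustrative); per the docstrings above, `jacobian` returns a dictionary keyed by `(input_index, output_index)` tuples.

import warp as wp
import warp.autograd


@wp.kernel
def saxpy_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)):
    # hypothetical example kernel: out = 2 * a + b
    tid = wp.tid()
    out[tid] = 2.0 * a[tid] + b[tid]


a = wp.array([1.0, 2.0], dtype=wp.float32, requires_grad=True)
b = wp.array([3.0, 4.0], dtype=wp.float32, requires_grad=True)
out = wp.zeros_like(a)

tape = wp.Tape()
with tape:
    wp.launch(saxpy_kernel, dim=2, inputs=[a, b], outputs=[out])

# check every kernel recorded on the tape against finite differences
ok = wp.autograd.gradcheck_tape(tape, raise_exception=False)

# dense Jacobians for all differentiable input/output pairs
jacs = wp.autograd.jacobian(saxpy_kernel, dim=2, inputs=[a, b], outputs=[out])
jac_a_out = jacs[(0, 0)]  # d(out)/d(a) as a Warp array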