warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -11,26 +11,24 @@ and using semi-Lagrangian advection
11
11
  import argparse
12
12
 
13
13
  import warp as wp
14
+ import warp.fem as fem
14
15
 
15
16
  import numpy as np
16
17
 
17
- from warp.fem.types import *
18
- from warp.fem.geometry import Grid2D, Trimesh2D
19
- from warp.fem.field import make_test, make_trial
20
- from warp.fem.space import make_polynomial_space
21
- from warp.fem.quadrature import RegularQuadrature
22
- from warp.fem.domain import Cells, BoundarySides
23
- from warp.fem.integrate import integrate
24
- from warp.fem.operator import integrand, D, div, lookup
25
- from warp.fem.dirichlet import project_linear_system, normalize_dirichlet_projector
26
18
  from warp.fem.utils import array_axpy
27
19
 
28
20
  from warp.sparse import bsr_mm, bsr_mv, bsr_copy
29
21
 
30
- from bsr_utils import bsr_to_scipy
31
- from plot_utils import plot_grid_streamlines, plot_velocities
32
- from mesh_utils import gen_trimesh
22
+ try:
23
+ from .bsr_utils import bsr_to_scipy
24
+ from .plot_utils import Plot
25
+ from .mesh_utils import gen_trimesh
26
+ except ImportError:
27
+ from bsr_utils import bsr_to_scipy
28
+ from plot_utils import Plot
29
+ from mesh_utils import gen_trimesh
33
30
 
31
+ # need to solve a saddle-point system, use scipy for simplicity
34
32
  from scipy.sparse import bmat
35
33
  from scipy.sparse.linalg import factorized
36
34
 
@@ -38,8 +36,8 @@ import matplotlib.pyplot as plt
38
36
  import matplotlib.animation as animation
39
37
 
40
38
 
41
- @integrand
42
- def u_boundary_value(s: Sample, domain: Domain, v: Field, top_vel: float):
39
+ @fem.integrand
40
+ def u_boundary_value(s: fem.Sample, domain: fem.Domain, v: fem.Field, top_vel: float):
43
41
  # Horizontal velocity on top of domain, zero elsewhere
44
42
  if domain(s)[1] == 1.0:
45
43
  return wp.dot(wp.vec2f(top_vel, 0.0), v(s))
@@ -47,194 +45,209 @@ def u_boundary_value(s: Sample, domain: Domain, v: Field, top_vel: float):
47
45
  return wp.dot(wp.vec2f(0.0, 0.0), v(s))
48
46
 
49
47
 
50
- @integrand
48
+ @fem.integrand
51
49
  def mass_form(
52
- s: Sample,
53
- u: Field,
54
- v: Field,
50
+ s: fem.Sample,
51
+ u: fem.Field,
52
+ v: fem.Field,
55
53
  ):
56
54
  return wp.dot(u(s), v(s))
57
55
 
58
56
 
59
- @integrand
60
- def inertia_form(s: Sample, u: Field, v: Field, dt: float):
57
+ @fem.integrand
58
+ def inertia_form(s: fem.Sample, u: fem.Field, v: fem.Field, dt: float):
61
59
  return mass_form(s, u, v) / dt
62
60
 
63
61
 
64
- @integrand
65
- def viscosity_form(s: Sample, u: Field, v: Field, nu: float):
66
- return 2.0 * nu * wp.ddot(D(u, s), D(v, s))
62
+ @fem.integrand
63
+ def viscosity_form(s: fem.Sample, u: fem.Field, v: fem.Field, nu: float):
64
+ return 2.0 * nu * wp.ddot(fem.D(u, s), fem.D(v, s))
67
65
 
68
66
 
69
- @integrand
70
- def viscosity_and_inertia_form(s: Sample, u: Field, v: Field, dt: float, nu: float):
67
+ @fem.integrand
68
+ def viscosity_and_inertia_form(s: fem.Sample, u: fem.Field, v: fem.Field, dt: float, nu: float):
71
69
  return inertia_form(s, u, v, dt) + viscosity_form(s, u, v, nu)
72
70
 
73
71
 
74
- @integrand
75
- def transported_inertia_form(s: Sample, domain: Domain, u: Field, v: Field, dt: float):
72
+ @fem.integrand
73
+ def transported_inertia_form(s: fem.Sample, domain: fem.Domain, u: fem.Field, v: fem.Field, dt: float):
76
74
  pos = domain(s)
77
75
  vel = u(s)
78
76
 
79
77
  conv_pos = pos - 0.5 * vel * dt
80
- conv_s = lookup(domain, conv_pos, s)
78
+ conv_s = fem.lookup(domain, conv_pos, s)
81
79
  conv_vel = u(conv_s)
82
80
 
83
81
  conv_pos = conv_pos - 0.5 * conv_vel * dt
84
- conv_vel = u(lookup(domain, conv_pos, conv_s))
82
+ conv_vel = u(fem.lookup(domain, conv_pos, conv_s))
85
83
 
86
84
  return wp.dot(conv_vel, v(s)) / dt
87
85
 
88
86
 
89
- @integrand
87
+ @fem.integrand
90
88
  def div_form(
91
- s: Sample,
92
- u: Field,
93
- q: Field,
89
+ s: fem.Sample,
90
+ u: fem.Field,
91
+ q: fem.Field,
94
92
  ):
95
- return -q(s) * div(u, s)
93
+ return -q(s) * fem.div(u, s)
96
94
 
97
95
 
98
- if __name__ == "__main__":
99
- wp.init()
100
- wp.set_module_options({"enable_backward": False})
101
-
96
+ class Example:
102
97
  parser = argparse.ArgumentParser()
103
98
  parser.add_argument("--resolution", type=int, default=25)
104
99
  parser.add_argument("--degree", type=int, default=2)
105
- parser.add_argument("--n_frames", type=int, default=1000)
100
+ parser.add_argument("--num_frames", type=int, default=1000)
106
101
  parser.add_argument("--top_velocity", type=float, default=1.0)
107
102
  parser.add_argument("--Re", type=float, default=1000.0)
108
103
  parser.add_argument("--tri_mesh", action="store_true", help="Use a triangular mesh")
109
- args = parser.parse_args()
110
-
111
- if args.tri_mesh:
112
- positions, tri_vidx = gen_trimesh(res=vec2i(args.resolution))
113
- geo = Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
114
- else:
115
- geo = Grid2D(res=vec2i(args.resolution))
116
-
117
- boundary = BoundarySides(geo)
118
-
119
- viscosity = args.top_velocity / args.Re
120
- dt = 1.0 / args.resolution
121
-
122
- domain = Cells(geometry=geo)
123
-
124
- # Functions spaces: Q(d)-Q(d-1)
125
- u_degree = args.degree
126
- u_space = make_polynomial_space(geo, degree=u_degree, dtype=wp.vec2)
127
- p_space = make_polynomial_space(geo, degree=u_degree - 1)
128
- quadrature = RegularQuadrature(domain=domain, order=2 * u_degree)
129
-
130
- # Viscosity and inertia
131
- u_test = make_test(space=u_space, domain=domain)
132
- u_trial = make_trial(space=u_space, domain=domain)
133
-
134
- u_matrix = integrate(
135
- viscosity_and_inertia_form,
136
- fields={"u": u_trial, "v": u_test},
137
- values={"nu": viscosity, "dt": dt},
138
- )
139
-
140
- # Pressure-velocity coupling
141
- p_test = make_test(space=p_space, domain=domain)
142
- div_matrix = integrate(div_form, fields={"u": u_trial, "q": p_test})
143
-
144
- # Enforcing the Dirichlet boundary condition the hard way;
145
- # build projector for velocity left- and right-hand-sides
146
- u_bd_test = make_test(space=u_space, domain=boundary)
147
- u_bd_trial = make_trial(space=u_space, domain=boundary)
148
- u_bd_projector = integrate(mass_form, fields={"u": u_bd_trial, "v": u_bd_test}, nodal=True)
149
- u_bd_value = integrate(
150
- u_boundary_value,
151
- fields={"v": u_bd_test},
152
- values={"top_vel": args.top_velocity},
153
- nodal=True,
154
- output_dtype=wp.vec2d,
155
- )
156
-
157
- normalize_dirichlet_projector(u_bd_projector, u_bd_value)
158
-
159
- u_bd_rhs = wp.zeros_like(u_bd_value)
160
- project_linear_system(u_matrix, u_bd_rhs, u_bd_projector, u_bd_value, normalize_projector=False)
161
-
162
- # div_bd_rhs = div_matrix * u_bd_rhs
163
- div_bd_rhs = wp.zeros(shape=(div_matrix.nrow,), dtype=div_matrix.scalar_type)
164
- bsr_mv(div_matrix, u_bd_rhs, y=div_bd_rhs)
165
-
166
- # div_matrix = div_matrix - div_matrix * bd_projector
167
- bsr_mm(x=bsr_copy(div_matrix), y=u_bd_projector, z=div_matrix, alpha=-1.0, beta=1.0)
168
-
169
- # Assemble saddle system with Scipy
170
- div_matrix = bsr_to_scipy(div_matrix)
171
- u_matrix = bsr_to_scipy(u_matrix)
172
- div_bd_rhs = div_bd_rhs.numpy()
173
-
174
- ones = np.ones(shape=(p_space.node_count(), 1), dtype=float)
175
- saddle_system = bmat(
176
- [
177
- [u_matrix, div_matrix.transpose(), None],
178
- [div_matrix, None, ones],
179
- [None, ones.transpose(), None],
180
- ],
181
- )
182
-
183
- with wp.ScopedTimer("LU factorization"):
184
- solve_saddle = factorized(saddle_system)
185
-
186
- u_k = u_space.make_field()
187
- u_rhs = wp.zeros_like(u_bd_rhs)
188
-
189
- results = [u_k.dof_values.numpy()]
190
-
191
- for k in range(args.n_frames):
192
- print("Solving step", k)
193
-
194
- u_inertia_rhs = integrate(
104
+
105
+ def __init__(self, stage=None, quiet=False, args=None, **kwargs):
106
+ if args is None:
107
+ # Read args from kwargs, add default arg values from parser
108
+ args = argparse.Namespace(**kwargs)
109
+ args = Example.parser.parse_args(args=[], namespace=args)
110
+ self._args = args
111
+ self._quiet = quiet
112
+
113
+ res = args.resolution
114
+ self.sim_dt = 1.0 / args.resolution
115
+ self.current_frame = 0
116
+
117
+ viscosity = args.top_velocity / args.Re
118
+
119
+ if args.tri_mesh:
120
+ positions, tri_vidx = gen_trimesh(res=wp.vec2i(res))
121
+ geo = fem.Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
122
+ else:
123
+ geo = fem.Grid2D(res=wp.vec2i(res))
124
+
125
+ domain = fem.Cells(geometry=geo)
126
+ boundary = fem.BoundarySides(geo)
127
+
128
+ # Functions spaces: Q(d)-Q(d-1)
129
+ u_degree = args.degree
130
+ u_space = fem.make_polynomial_space(geo, degree=u_degree, dtype=wp.vec2)
131
+ p_space = fem.make_polynomial_space(geo, degree=u_degree - 1)
132
+
133
+ # Viscosity and inertia
134
+ u_test = fem.make_test(space=u_space, domain=domain)
135
+ u_trial = fem.make_trial(space=u_space, domain=domain)
136
+
137
+ u_matrix = fem.integrate(
138
+ viscosity_and_inertia_form,
139
+ fields={"u": u_trial, "v": u_test},
140
+ values={"nu": viscosity, "dt": self.sim_dt},
141
+ )
142
+
143
+ # Pressure-velocity coupling
144
+ p_test = fem.make_test(space=p_space, domain=domain)
145
+ div_matrix = fem.integrate(div_form, fields={"u": u_trial, "q": p_test})
146
+
147
+ # Enforcing the Dirichlet boundary condition the hard way;
148
+ # build projector for velocity left- and right-hand-sides
149
+ u_bd_test = fem.make_test(space=u_space, domain=boundary)
150
+ u_bd_trial = fem.make_trial(space=u_space, domain=boundary)
151
+ u_bd_projector = fem.integrate(mass_form, fields={"u": u_bd_trial, "v": u_bd_test}, nodal=True)
152
+ u_bd_value = fem.integrate(
153
+ u_boundary_value,
154
+ fields={"v": u_bd_test},
155
+ values={"top_vel": args.top_velocity},
156
+ nodal=True,
157
+ output_dtype=wp.vec2d,
158
+ )
159
+
160
+ fem.normalize_dirichlet_projector(u_bd_projector, u_bd_value)
161
+
162
+ u_bd_rhs = wp.zeros_like(u_bd_value)
163
+ fem.project_linear_system(u_matrix, u_bd_rhs, u_bd_projector, u_bd_value, normalize_projector=False)
164
+
165
+ # div_bd_rhs = div_matrix * u_bd_rhs
166
+ div_bd_rhs = wp.zeros(shape=(div_matrix.nrow,), dtype=div_matrix.scalar_type)
167
+ bsr_mv(div_matrix, u_bd_rhs, y=div_bd_rhs)
168
+
169
+ # div_matrix = div_matrix - div_matrix * bd_projector
170
+ bsr_mm(x=bsr_copy(div_matrix), y=u_bd_projector, z=div_matrix, alpha=-1.0, beta=1.0)
171
+
172
+ # Assemble saddle system with Scipy
173
+ div_matrix = bsr_to_scipy(div_matrix)
174
+ u_matrix = bsr_to_scipy(u_matrix)
175
+ div_bd_rhs = div_bd_rhs.numpy()
176
+
177
+ ones = np.ones(shape=(p_space.node_count(), 1), dtype=float)
178
+ saddle_system = bmat(
179
+ [
180
+ [u_matrix, div_matrix.transpose(), None],
181
+ [div_matrix, None, ones],
182
+ [None, ones.transpose(), None],
183
+ ],
184
+ )
185
+
186
+ with wp.ScopedTimer("LU factorization"):
187
+ self._solve_saddle = factorized(saddle_system)
188
+
189
+ # Save data for computing time steps rhs
190
+ self._u_bd_projector = u_bd_projector
191
+ self._u_bd_rhs = u_bd_rhs
192
+ self._u_test = u_test
193
+ self._div_bd_rhs = div_bd_rhs
194
+
195
+ # Velocity field
196
+
197
+ self._u_field = u_space.make_field()
198
+
199
+ self.renderer = Plot(stage)
200
+ self.renderer.add_surface_vector("velocity", self._u_field)
201
+
202
+ def update(self):
203
+ self.current_frame += 1
204
+
205
+ u_rhs = fem.integrate(
195
206
  transported_inertia_form,
196
- quadrature=quadrature,
197
- fields={"u": u_k, "v": u_test},
198
- values={"dt": dt},
207
+ fields={"u": self._u_field, "v": self._u_test},
208
+ values={"dt": self.sim_dt},
199
209
  output_dtype=wp.vec2d,
200
210
  )
201
- # u_rhs = (I - P) * u_inertia_rhs + u_bd_rhs
202
- bsr_mv(u_bd_projector, u_inertia_rhs, y=u_rhs, alpha=-1.0, beta=0.0)
203
- array_axpy(x=u_inertia_rhs, y=u_rhs, alpha=1.0, beta=1.0)
204
- array_axpy(x=u_bd_rhs, y=u_rhs, alpha=1.0, beta=1.0)
211
+
212
+ # Apply boundary conditions
213
+ # u_rhs = (I - P) * u_rhs + u_bd_rhs
214
+ bsr_mv(self._u_bd_projector, x=u_rhs, y=u_rhs, alpha=-1.0, beta=1.0)
215
+ array_axpy(x=self._u_bd_rhs, y=u_rhs, alpha=1.0, beta=1.0)
205
216
 
206
217
  # Assemble scipy saddle system rhs
207
- saddle_rhs = np.zeros(saddle_system.shape[0])
208
- u_slice = slice(0, 2 * u_space.node_count())
209
- p_slice = slice(2 * u_space.node_count(), 2 * u_space.node_count() + p_space.node_count())
218
+ u_dof_count = self._u_bd_projector.shape[0]
219
+ p_dof_count = self._div_bd_rhs.shape[0]
220
+ tot_dof_count = u_dof_count + p_dof_count + 1
221
+
222
+ u_slice = slice(0, u_dof_count)
223
+ p_slice = slice(u_dof_count, tot_dof_count - 1)
224
+
225
+ saddle_rhs = np.zeros(tot_dof_count)
210
226
  saddle_rhs[u_slice] = u_rhs.numpy().flatten()
211
- saddle_rhs[p_slice] = div_bd_rhs
227
+ saddle_rhs[p_slice] = self._div_bd_rhs
212
228
 
213
- x = solve_saddle(saddle_rhs)
229
+ x = self._solve_saddle(saddle_rhs)
214
230
 
215
231
  # Extract result
216
- x_u = x[u_slice].reshape((-1, 2))
217
- results.append(x_u)
218
-
219
- u_k.dof_values = x_u
220
- # p_field.dof_values = x[p_slice]
221
-
222
- if isinstance(geo, Grid2D):
223
- plot_grid_streamlines(u_k)
224
-
225
- quiver = plot_velocities(u_k)
226
- ax = quiver.axes
227
-
228
- def animate(i):
229
- ax.clear()
230
- u_k.dof_values = results[i]
231
- return plot_velocities(u_k, axes=ax)
232
-
233
- anim = animation.FuncAnimation(
234
- ax.figure,
235
- animate,
236
- interval=30,
237
- blit=False,
238
- frames=len(results),
239
- )
240
- plt.show()
232
+ self._u_field.dof_values = x[u_slice].reshape((-1, 2))
233
+
234
+ def render(self):
235
+ self.renderer.begin_frame(time = self.current_frame * self.sim_dt)
236
+ self.renderer.add_surface_vector("velocity", self._u_field)
237
+ self.renderer.end_frame()
238
+
239
+
240
+ if __name__ == "__main__":
241
+ wp.init()
242
+ wp.set_module_options({"enable_backward": False})
243
+
244
+ args = Example.parser.parse_args()
245
+
246
+ example = Example(args=args)
247
+ for k in range(args.num_frames):
248
+ print(f"Frame {k}:")
249
+ example.update()
250
+ example.render()
251
+
252
+ example.renderer.add_surface_vector("velocity_final", example._u_field)
253
+ example.renderer.plot(streamlines=set(["velocity_final"]))