xslope 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xslope/plot_fem.py ADDED
@@ -0,0 +1,1658 @@
1
+ # Copyright 2025 Norman L. Jones
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+
17
+ import matplotlib.patches as patches
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ from matplotlib.collections import LineCollection
21
+ from matplotlib.colors import LinearSegmentedColormap
22
+ from matplotlib.patches import Polygon
23
+
24
+
25
def plot_fem_data(fem_data, figsize=(14, 6), show_nodes=False, show_bc=True, material_table=False,
                  label_elements=False, label_nodes=False, alpha=0.4, bc_symbol_size=0.03):
    """
    Plot a FEM mesh colored by material zone, optionally with boundary conditions.

    Args:
        fem_data: Dictionary containing FEM data from build_fem_data. Reads "nodes",
            "elements", "element_materials", "bc_type" and "bc_values", plus the optional
            "element_types" and "material_names" (and the material property arrays when
            ``material_table`` is True).
        figsize: Figure size
        show_nodes: If True, plot node points
        show_bc: If True, plot boundary condition symbols
        material_table: If True, show material table
        label_elements: If True, label each element with its 1-based number at its centroid
        label_nodes: If True, label each node with its 1-based number, offset by 0.5 data
            units up and to the right
        alpha: Transparency for element faces
        bc_symbol_size: Size factor for boundary condition symbols (as a fraction of the
            smaller mesh extent)

    Element types are identified by node count: 3 (linear triangle), 6 (quadratic
    triangle), 4 (linear quad), 8 (8-node quad) and 9 (9-node quad); other types are
    not drawn. When "element_types" is absent every element is treated as a linear
    triangle. The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # Import get_material_color to ensure consistent colors with plot_mesh
    from .plot import get_material_color

    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_materials = fem_data["element_materials"]
    element_types = fem_data.get("element_types", None)
    # asarray so the element-wise ``== 4`` comparisons below also work on plain lists
    bc_type = np.asarray(fem_data["bc_type"])
    bc_values = fem_data["bc_values"]

    # If element_types is not provided, assume all triangles (backward compatibility)
    if element_types is None:
        element_types = np.full(len(elements), 3)
    else:
        element_types = np.asarray(element_types)

    fig, ax = plt.subplots(figsize=figsize)
    materials = np.unique(element_materials)
    mat_to_color = {mat: get_material_color(mat) for mat in materials}

    for idx, element_nodes in enumerate(elements):
        element_type = element_types[idx]
        _draw_fem_element(ax, nodes, element_nodes, element_type,
                          mat_to_color[element_materials[idx]], alpha)

        if label_elements:
            # Centroid over every node the element type defines (unknown types use up to 9)
            n_used = element_type if element_type in (3, 4, 6, 8) else 9
            centroid = np.mean(nodes[element_nodes[:n_used]], axis=0)
            ax.text(centroid[0], centroid[1], str(idx + 1),
                    ha='center', va='center', fontsize=6, color='black', alpha=0.4,
                    zorder=10)

    if show_nodes:
        ax.plot(nodes[:, 0], nodes[:, 1], 'k.', markersize=2)

    if label_nodes:
        for i, (x, y) in enumerate(nodes):
            ax.text(x + 0.5, y + 0.5, str(i + 1), fontsize=6, color='blue', alpha=0.7,
                    ha='left', va='bottom', zorder=11)

    # Legend entries; material ids are 1-based into material_names (id m -> names[m - 1])
    material_names = fem_data.get("material_names", [])
    legend_handles = []
    for mat in materials:
        if material_names and 1 <= mat <= len(material_names):
            label = material_names[mat - 1]
        else:
            label = f"Material {mat}"
        legend_handles.append(
            plt.Line2D([0], [0], color=mat_to_color[mat], lw=4, label=label)
        )

    if show_bc:
        _plot_boundary_conditions(ax, nodes, bc_type, bc_values, legend_handles, bc_symbol_size)

    # Single combined legend below the plot
    ax.legend(
        handles=legend_handles,
        loc='upper center',
        bbox_to_anchor=(0.5, -0.1),
        ncol=3,
        frameon=False
    )

    # Plot limits: 5% padding, with extra room above the mesh when force arrows exist
    x_min, x_max = nodes[:, 0].min(), nodes[:, 0].max()
    y_min, y_max = nodes[:, 1].min(), nodes[:, 1].max()
    x_padding = (x_max - x_min) * 0.05
    y_padding_bottom = (y_max - y_min) * 0.05
    if np.any(bc_type == 4):
        symbol_size = min(x_max - x_min, y_max - y_min) * bc_symbol_size
        y_padding = symbol_size * 4
    else:
        y_padding = y_padding_bottom

    ax.set_xlim(x_min - x_padding, x_max + x_padding)
    ax.set_ylim(y_min - y_padding_bottom, y_max + y_padding)
    ax.set_aspect("equal")

    # Count element families for the title; quadratic variants count with their linear family
    num_triangles = int(np.isin(element_types, (3, 6)).sum())
    num_quads = int(np.isin(element_types, (4, 8, 9)).sum())
    if num_triangles > 0 and num_quads > 0:
        title = f"FEM Mesh with Material Zones ({num_triangles} triangles, {num_quads} quads)"
    elif num_quads > 0:
        title = f"FEM Mesh with Material Zones ({num_quads} quadrilaterals)"
    else:
        title = f"FEM Mesh with Material Zones ({num_triangles} triangles)"

    # Material table placed above the axes (lower-left corner at axes fraction (0.3, 1.1))
    if material_table:
        _plot_fem_material_table(ax, fem_data, xloc=0.3, yloc=1.1)

    ax.set_title(title)
    plt.tight_layout()
    plt.show()


def _draw_fem_element(ax, nodes, element_nodes, element_type, color, alpha):
    """Fill one 2D element with ``color`` and outline its corner edges; unknown types are skipped."""
    if element_type in (3, 4):
        # Linear element: a single polygon that carries its own edge
        ax.add_patch(Polygon(nodes[element_nodes[:element_type]], edgecolor='k', facecolor=color,
                             linewidth=0.5, alpha=alpha))
        return

    if element_type == 6:
        # Midside nodes follow GMSH ordering: n3 = edge 0-1, n4 = edge 1-2, n5 = edge 2-0
        n0, n1, n2, n3, n4, n5 = (nodes[k] for k in element_nodes[:6])
        pieces = [[n0, n3, n5], [n3, n1, n4], [n5, n4, n2], [n3, n4, n5]]
        corners = [n0, n1, n2]
    elif element_type in (8, 9):
        # Corners n0..n3 and midside nodes n4..n7; the 9-node quad supplies its own center
        n0, n1, n2, n3, n4, n5, n6, n7 = (nodes[k] for k in element_nodes[:8])
        if element_type == 9:
            center = nodes[element_nodes[8]]
        else:
            # quad8 has no interior node: use the average of its 8 nodes
            center = np.mean(nodes[element_nodes[:8]], axis=0)
        pieces = [[n0, n4, center, n7], [n4, n1, n5, center],
                  [center, n5, n2, n6], [n7, center, n6, n3]]
        corners = [n0, n1, n2, n3]
    else:
        return

    # Quadratic elements: fill sub-polygons without internal edges, then draw the outer boundary
    for piece in pieces:
        ax.add_patch(Polygon(piece, edgecolor='none', facecolor=color, alpha=alpha))
    outline = corners + corners[:1]
    ax.plot([p[0] for p in outline], [p[1] for p in outline], 'k-', linewidth=0.5)
255
+
256
+
257
def _plot_boundary_conditions(ax, nodes, bc_type, bc_values, legend_handles, bc_symbol_size=0.03):
    """
    Draw boundary-condition symbols on ``ax`` and append one legend handle per BC type drawn.

    BC types:
        0 = free (do nothing)
        1 = fixed (small red triangle below node)
        2 = x roller (small blue circle + line on the outer side of left/right-side nodes)
        3 = y roller (shouldn't have any; not drawn)
        4 = specified force (green arrow pointing at the node; the largest force is drawn
            3 symbol sizes long)

    Args:
        ax: matplotlib Axes to draw on.
        nodes: node coordinates, indexed as nodes[i] -> (x, y).
        bc_type: per-node BC type codes (array or list).
        bc_values: per-node BC values; for type-4 nodes, the (fx, fy) force.
        legend_handles: list that legend handles are appended to (mutated in place).
        bc_symbol_size: symbol size as a fraction of the smaller mesh extent.
    """
    nodes = np.asarray(nodes)
    # asarray so ``bc_type == k`` is element-wise even when a plain list is passed
    bc_type = np.asarray(bc_type)

    # Symbol size scales with the smaller mesh dimension
    x_min, x_max = nodes[:, 0].min(), nodes[:, 0].max()
    y_min, y_max = nodes[:, 1].min(), nodes[:, 1].max()
    symbol_size = min(x_max - x_min, y_max - y_min) * bc_symbol_size

    # Fixed (type 1): open triangle with its apex at the node
    fixed_nodes = np.flatnonzero(bc_type == 1)
    if fixed_nodes.size > 0:
        half_width = symbol_size * 0.8 / 2
        for node_idx in fixed_nodes:
            x, y = nodes[node_idx]
            ax.add_patch(patches.Polygon([
                [x - half_width, y - symbol_size],
                [x + half_width, y - symbol_size],
                [x, y],
            ], closed=True, facecolor='none', edgecolor='red', linewidth=1.5))
        legend_handles.append(
            plt.Line2D([0], [0], marker='^', color='red', linestyle='None',
                       markersize=8, label='Fixed (bc_type=1)')
        )

    # X roller (type 2): circle touching the node on its outer side, plus a tangent line
    roller_nodes = np.flatnonzero(bc_type == 2)
    if roller_nodes.size > 0:
        x_mid = (x_min + x_max) / 2
        radius = symbol_size * 0.4
        for node_idx in roller_nodes:
            x, y = nodes[node_idx]
            # Left-half nodes get the symbol on their left, right-half nodes on their right
            side = -1.0 if x < x_mid else 1.0
            circle_center_x = x + side * radius
            line_x = circle_center_x + side * radius
            ax.add_patch(patches.Circle((circle_center_x, y), radius,
                                        facecolor='none', edgecolor='blue', linewidth=1))
            ax.plot([line_x, line_x], [y - symbol_size / 2, y + symbol_size / 2],
                    'b-', linewidth=1)
        legend_handles.append(
            plt.Line2D([0], [0], marker='o', color='blue', linestyle='None',
                       markersize=6, markerfacecolor='none', markeredgewidth=1,
                       label='X-Roller (bc_type=2)')
        )

    # Specified force (type 4): arrows ending at the node, scaled by the largest force
    force_nodes = np.flatnonzero(bc_type == 4)
    if force_nodes.size > 0:
        forces = np.array([bc_values[i] for i in force_nodes], dtype=float).reshape(-1, 2)
        max_force = np.hypot(forces[:, 0], forces[:, 1]).max()
        if max_force > 0:
            scale = symbol_size * 3 / max_force
            for node_idx, (fx, fy) in zip(force_nodes, forces):
                x, y = nodes[node_idx]
                # Arrow drawn from the force "tail" to the node so it points at the node
                ax.annotate('', xy=(x, y), xytext=(x - fx * scale, y - fy * scale),
                            arrowprops=dict(arrowstyle='->', color='green', lw=2))
        legend_handles.append(
            plt.Line2D([0], [0], marker='>', color='green', linestyle='-',
                       markersize=8, label='Applied Force (bc_type=4)')
        )
363
+
364
+
365
def _plot_fem_material_table(ax, fem_data, xloc=0.6, yloc=0.7):
    """
    Draw a table of FEM material properties on ``ax``.

    One row is produced per entry of ``fem_data["c_by_mat"]``. The columns are:
    material number (1-based), name, unit weight, cohesion, friction angle,
    Young's modulus and Poisson's ratio. Property arrays absent from ``fem_data``
    are shown as 0, and missing names fall back to "Material <n>". Nothing is drawn
    when "c_by_mat" is absent or empty.

    Parameters:
        ax: matplotlib Axes object
        fem_data: Dictionary containing FEM data with material properties
        xloc: x of the table's lower-left corner, in axes-fraction coordinates
        yloc: y of the table's lower-left corner, in axes-fraction coordinates
              (values above 1 place the table above the axes)

    Returns:
        None
    """
    cohesions = fem_data.get("c_by_mat")
    frictions = fem_data.get("phi_by_mat")
    moduli = fem_data.get("E_by_mat")
    poissons = fem_data.get("nu_by_mat")
    unit_weights = fem_data.get("gamma_by_mat")
    names = fem_data.get("material_names", [])

    # Nothing to tabulate without cohesion values
    if cohesions is None or len(cohesions) == 0:
        return

    def value_or_zero(values, k):
        # Property arrays that were not supplied are displayed as 0.0
        return 0.0 if values is None else values[k]

    rows = []
    for k, cohesion in enumerate(cohesions):
        name = names[k] if k < len(names) else f"Material {k+1}"
        rows.append([
            k + 1,
            name,
            f"{value_or_zero(unit_weights, k):.1f}",
            f"{cohesion:.1f}",
            f"{value_or_zero(frictions, k):.1f}",
            f"{value_or_zero(moduli, k):.0f}",
            f"{value_or_zero(poissons, k):.2f}",
        ])

    # bbox = [left, bottom, width, height]; width leaves room for the name column
    table = ax.table(cellText=rows,
                     colLabels=["Mat", "Name", "γ", "c", "φ", "E", "ν"],
                     loc='upper right',
                     colLoc='center',
                     cellLoc='center',
                     bbox=[xloc, yloc, 0.45, 0.25])
    table.auto_set_font_size(False)
    table.set_fontsize(8)
425
+
426
+
427
def plot_fem_results(fem_data, solution, plot_type='displacement', deform_scale=None,
                     show_mesh=True, show_reinforcement=True, figsize=(12, 8), label_elements=False,
                     plot_nodes=False, plot_elements=False, plot_boundary=True, displacement_tolerance=0.5,
                     scale_vectors=False):
    """
    Plot FEM results with various visualization options.

    Parameters:
        fem_data (dict): FEM data dictionary
        solution (dict): FEM solution dictionary
        plot_type (str): Type(s) of plot. One of 'displace_mag', 'displace_vector', 'deformation',
            'stress', 'strain', 'shear_strain', 'yield', or a comma-separated list of them
            (e.g. 'stress,deformation'). Multiple types are stacked vertically in the order
            given. Matching is case-insensitive and ignores surrounding spaces. 'displacement'
            (the default) is accepted as an alias for 'displace_mag'.
        deform_scale (float or None): Scale factor for the deformed-mesh plot. None means 1.0,
            i.e. actual displacements (matching the unscaled displacement-vector plot).
        show_mesh (bool): Whether to show mesh lines
        show_reinforcement (bool): Whether to show reinforcement elements
        figsize (tuple): Figure size (height is rescaled when several plots are stacked)
        label_elements (bool): If True, show element IDs at element centers
        plot_nodes (bool): For displace_vector plots, show dots at all node locations
        plot_elements (bool): For displace_vector plots, show all element edges
        plot_boundary (bool): For displace_vector plots, show only boundary edges (default mesh display)
        displacement_tolerance (float): Minimum displacement magnitude to show vectors (uses actual displacement)
        scale_vectors (bool): For displace_vector plots, scale vectors for visualization; if False, use actual displacement

    Returns:
        (fig, ax) for a single plot, or (fig, axes) with an array of axes for multiple plots.
        The figure is also displayed with plt.show().

    Raises:
        ValueError: if any requested plot type is not recognized.
    """
    # Parse and validate the requested types before creating a figure, so a bad
    # request does not leave an empty figure behind
    valid_types = ['displace_mag', 'displace_vector', 'deformation', 'stress', 'strain', 'shear_strain', 'yield']
    aliases = {'displacement': 'displace_mag'}
    plot_types = []
    for raw in plot_type.split(','):
        pt = raw.strip().lower()
        pt = aliases.get(pt, pt)
        if pt not in valid_types:
            raise ValueError(f"Unknown plot_type: '{pt}'. Valid types: {valid_types}")
        plot_types.append(pt)

    if deform_scale is None:
        deform_scale = 1.0  # actual displacement scale, consistent with the vector plot

    # Overall mesh bounds (+5% margin) so every panel shares the same axis limits
    nodes = fem_data["nodes"]
    x_min, x_max = np.min(nodes[:, 0]), np.max(nodes[:, 0])
    y_min, y_max = np.min(nodes[:, 1]), np.max(nodes[:, 1])
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05

    n_plots = len(plot_types)
    if n_plots == 1:
        fig, ax = plt.subplots(figsize=figsize)
        axes = [ax]
    else:
        # Stacked panels: shrink the per-panel height as more panels are added
        height_factor = min(0.8, 1.2 / n_plots)
        fig, axes = plt.subplots(n_plots, 1, figsize=(figsize[0], figsize[1] * n_plots * height_factor))

    # Colorbar layout depends only on how many panels share the figure
    if n_plots == 1:
        cbar_shrink, cbar_labelpad = 0.8, 20
    elif n_plots == 2:
        cbar_shrink, cbar_labelpad = 0.7, 15
    else:
        cbar_shrink, cbar_labelpad = 0.5, 12
    common = dict(cbar_shrink=cbar_shrink, cbar_labelpad=cbar_labelpad, label_elements=label_elements)

    for ax, pt in zip(axes, plot_types):
        if pt == 'displace_mag':
            plot_displacement_contours(ax, fem_data, solution, show_mesh, show_reinforcement, **common)
        elif pt == 'displace_vector':
            plot_displacement_vectors(ax, fem_data, solution, show_mesh, show_reinforcement, **common,
                                      plot_nodes=plot_nodes, plot_elements=plot_elements,
                                      plot_boundary=plot_boundary,
                                      displacement_tolerance=displacement_tolerance,
                                      scale_vectors=scale_vectors)
        elif pt == 'deformation':
            plot_deformed_mesh(ax, fem_data, solution, deform_scale, show_mesh, show_reinforcement, **common)
        elif pt == 'stress':
            plot_stress_contours(ax, fem_data, solution, show_mesh, show_reinforcement, **common)
        elif pt == 'strain':
            plot_strain_contours(ax, fem_data, solution, show_mesh, show_reinforcement, **common)
        elif pt == 'shear_strain':
            plot_shear_strain_contours(ax, fem_data, solution, show_mesh, show_reinforcement, **common)
        elif pt == 'yield':
            plot_yield_function_contours(ax, fem_data, solution, show_mesh, show_reinforcement, **common)

        ax.set_xlim(x_min - x_margin, x_max + x_margin)
        ax.set_ylim(y_min - y_margin, y_max + y_margin)
        ax.set_aspect('equal')

    plt.tight_layout()
    plt.show()

    if n_plots == 1:
        return fig, axes[0]
    return fig, axes
550
+
551
+
552
def plot_displacement_contours(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                               cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Draw filled contours of nodal displacement magnitude on ``ax``.

    Each element is reduced to triangles over its corner nodes. Triangles (3- and 6-node)
    are used as-is; quads (4-, 8- and 9-node) are split along the node 0-2 diagonal, and
    midside nodes are ignored. The magnitude sqrt(u**2 + v**2) is contoured with 20
    levels and a colorbar. Elements of any other type are skipped, and if none remain,
    no contours or colorbar are drawn.

    Parameters:
        ax: matplotlib Axes to draw on.
        fem_data: FEM data dictionary ("nodes", "elements", "element_types", and
            optionally "elements_1d" for reinforcement).
        solution: FEM solution dictionary; "displacements" is read as interleaved
            [u0, v0, u1, v1, ...] (zeros if absent).
        show_mesh: Whether to overlay mesh lines.
        show_reinforcement: Whether to overlay reinforcement elements.
        cbar_shrink, cbar_labelpad: Colorbar layout parameters.
        label_elements: If True, label elements with their IDs.
    """
    coords = fem_data["nodes"]
    connectivity = fem_data["elements"]
    kinds = fem_data["element_types"]
    disp = solution.get("displacements", np.zeros(2 * len(coords)))

    # Interleaved [u0, v0, u1, v1, ...] -> magnitude per node
    magnitude = np.sqrt(disp[0::2] ** 2 + disp[1::2] ** 2)

    # Corner-node triangulation used for contouring
    tri_list = []
    for idx, conn in enumerate(connectivity):
        kind = kinds[idx]
        if kind in (3, 6):
            tri_list.append([conn[0], conn[1], conn[2]])
        elif kind in (4, 8, 9):
            tri_list.append([conn[0], conn[1], conn[2]])
            tri_list.append([conn[0], conn[2], conn[3]])

    if tri_list:
        contours = ax.tricontourf(coords[:, 0], coords[:, 1], np.array(tri_list), magnitude,
                                  levels=20, cmap='viridis', alpha=0.8)
        colorbar = plt.colorbar(contours, ax=ax, shrink=cbar_shrink)
        colorbar.set_label('Displacement Magnitude', rotation=270, labelpad=cbar_labelpad)

    if show_mesh:
        plot_mesh_lines(ax, fem_data, color='black', alpha=0.3, linewidth=0.5)

    if show_reinforcement and 'elements_1d' in fem_data:
        plot_reinforcement_lines(ax, fem_data, solution)

    if label_elements:
        _add_element_labels(ax, fem_data)

    ax.set_aspect('equal')
    ax.set_title('Displacement Magnitude Contours')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
609
+
610
+
611
def _get_mesh_boundary(fem_data):
    """
    Compute the boundary edges of the 2D mesh.

    An edge is on the boundary when exactly one element uses it. Edges are taken
    between corner nodes only (midside nodes of 6-, 8- and 9-node elements are
    ignored), so each boundary edge is a straight corner-to-corner segment.
    Elements of any other type are skipped.

    Args:
        fem_data: dict with "elements" (per-element node indices) and "element_types"
            (node count per element: 3, 4, 6, 8 or 9).

    Returns:
        boundary_edges: List of (node1, node2) tuples with node1 <= node2, in the order
        the edges are first encountered.
    """
    from collections import Counter

    # Number of corner nodes for each supported element type
    corner_counts = {3: 3, 6: 3, 4: 4, 8: 4, 9: 4}
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]

    # Count each undirected edge; sorting the pair makes (a, b) and (b, a) the same key
    edge_count = Counter()
    for i, elem in enumerate(elements):
        n_corners = corner_counts.get(element_types[i])
        if n_corners is None:
            continue
        for k in range(n_corners):
            edge = tuple(sorted((elem[k], elem[(k + 1) % n_corners])))
            edge_count[edge] += 1

    # Interior edges are shared by two elements; boundary edges appear only once
    return [edge for edge, count in edge_count.items() if count == 1]
650
+
651
+
652
def plot_displacement_vectors(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                              cbar_shrink=0.8, cbar_labelpad=20, label_elements=False,
                              plot_nodes=True, plot_elements=False, plot_boundary=False,
                              displacement_tolerance=1e-6, scale_vectors=False):
    """
    Plot displacement vectors at nodes of yielded (plastic) elements.

    A vector is drawn at a node only when BOTH conditions hold:
      * the node belongs to an element flagged in ``solution["plastic_elements"]``;
      * the node's actual (unscaled) displacement magnitude exceeds
        ``displacement_tolerance``.
    Each arrow's tail is at the original node location and its tip is at the
    (optionally scaled) displaced location. If no node qualifies, a warning is
    printed, but the mesh, labels, title and spacer colorbar are still drawn.

    Parameters:
        ax: matplotlib Axes to draw on
        fem_data: FEM data dictionary ("nodes", "elements", "element_types", optional "elements_1d")
        solution: FEM solution dictionary. "displacements" is interleaved as [u0, v0, u1, v1, ...]
            and "plastic_elements" holds per-element flags; missing entries default to zeros / False.
        show_mesh: If False, neither element edges nor boundary edges are drawn
        show_reinforcement: Whether to draw reinforcement elements
        cbar_shrink, cbar_labelpad: layout of the invisible colorbar that keeps this panel
            aligned with the contour panels
        label_elements: If True, show element IDs
        plot_nodes: If True, show dots at all node locations
        plot_elements: If True, show all element edges
        plot_boundary: If True (and plot_elements is False), show only boundary edges
        displacement_tolerance: Minimum displacement magnitude to show vectors (uses actual displacement)
        scale_vectors: If True, scale vectors so the largest is 10% of the smaller mesh extent;
            if False, use actual displacement
    """
    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]
    displacements = np.asarray(solution.get("displacements", np.zeros(2 * len(nodes))), dtype=float)
    plastic_elements = solution.get("plastic_elements", np.zeros(len(elements), dtype=bool))

    u = displacements[0::2]  # x-displacements
    v = displacements[1::2]  # y-displacements
    disp_mag = np.sqrt(u**2 + v**2)

    # Nodes of plastic elements; only the first element_type entries are real node ids
    plastic_nodes = set()
    for i, elem in enumerate(elements):
        if plastic_elements[i]:
            plastic_nodes.update(elem[:element_types[i]])

    # Keep nodes that are plastic AND above tolerance (sorted for a deterministic draw order)
    n_nodes = len(nodes)
    target_nodes = sorted(
        n for n in plastic_nodes if 0 <= n < n_nodes and disp_mag[n] > displacement_tolerance
    )

    # Smaller mesh extent: used for arrow-head size and for vector scaling
    if n_nodes:
        mesh_x_size = np.max(nodes[:, 0]) - np.min(nodes[:, 0])
        mesh_y_size = np.max(nodes[:, 1]) - np.min(nodes[:, 1])
        mesh_size = min(mesh_x_size, mesh_y_size)
    else:
        mesh_size = 0.0

    max_disp_mag = np.max(disp_mag) if disp_mag.size else 0.0
    if scale_vectors and max_disp_mag > 0:
        # Scale vectors so the maximum displacement is about 10% of mesh size
        scale_factor = (mesh_size * 0.1) / max_disp_mag
    else:
        scale_factor = 1.0

    if target_nodes:
        for node_idx in target_nodes:
            # length_includes_head puts the arrow tip exactly on the displaced location
            ax.arrow(nodes[node_idx, 0], nodes[node_idx, 1],
                     u[node_idx] * scale_factor, v[node_idx] * scale_factor,
                     head_width=mesh_size * 0.01, head_length=mesh_size * 0.015,
                     length_includes_head=True,
                     fc='black', ec='black', alpha=0.8, linewidth=1.0)
        print(f"Plotted {len(target_nodes)} displacement vectors (tolerance = {displacement_tolerance:.2e})")
    else:
        # Still draw the rest of the panel so it matches the other subplots
        print("Warning: No target nodes found for displacement vector plot")

    if show_mesh:
        if plot_elements:
            plot_mesh_lines(ax, fem_data, color='lightgray', alpha=0.5, linewidth=0.5)
        elif plot_boundary:
            for edge in _get_mesh_boundary(fem_data):
                x_coords = [nodes[edge[0], 0], nodes[edge[1], 0]]
                y_coords = [nodes[edge[0], 1], nodes[edge[1], 1]]
                ax.plot(x_coords, y_coords, 'k-', alpha=0.7, linewidth=1.0)

    if plot_nodes:
        ax.plot(nodes[:, 0], nodes[:, 1], 'k.', markersize=2, alpha=0.6)

    if show_reinforcement and 'elements_1d' in fem_data:
        plot_reinforcement_lines(ax, fem_data, solution)

    if label_elements:
        _add_element_labels(ax, fem_data)

    # Invisible colorbar so this panel keeps the same width as the contour panels
    dummy_im = ax.imshow(np.array([[0, 1]]), cmap='viridis', alpha=0)
    cbar = plt.colorbar(dummy_im, ax=ax, shrink=cbar_shrink)
    cbar.set_label('Displacement Vectors', rotation=270, labelpad=cbar_labelpad, color='white')
    cbar.set_ticks([])
    cbar.set_ticklabels([])
    cbar.outline.set_color('white')
    cbar.outline.set_linewidth(0)

    ax.set_aspect('equal')
    ax.set_title(f'Displacement Vectors (Scale Factor = {scale_factor:.2f})')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
777
+
778
+
779
def plot_stress_contours(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                         cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Shade each 2D element by its von Mises stress and outline plastic elements in red.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements", "element_types"
            and, for reinforcement, "elements_1d".
        solution: Solution dictionary; reads "stresses" (column 3 is plotted as the
            von Mises stress), "yield_function" and, as a fallback, "plastic_elements".
        show_mesh: If True, overlay element edges.
        show_reinforcement: If True and 1D elements exist, draw them colored by force ratio.
        cbar_shrink: Shrink factor for the colorbar.
        cbar_labelpad: Padding of the colorbar label.
        label_elements: If True, write 1-based element numbers at element centroids.
    """
    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]
    stresses = solution.get("stresses", np.zeros((len(elements), 4)))

    # Prefer the yield function (F > 0 means yielding) so this plot agrees with the
    # yield-function plot; otherwise fall back to the solver's plastic flags.
    yield_function = solution.get("yield_function", None)
    if yield_function is None:
        plastic_elements = solution.get("plastic_elements", np.zeros(len(elements), dtype=bool))
    else:
        plastic_elements = yield_function > 0

    von_mises = stresses[:, 3]

    # Corner nodes drawn per supported element type; higher-order elements
    # (6-node triangles, 8/9-node quads) are drawn through their corners only.
    corner_count = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}

    shaded_polygons = []
    shaded_values = []
    for idx, elem in enumerate(elements):
        n_corners = corner_count.get(element_types[idx])
        if n_corners is None:
            continue
        shaded_polygons.append(Polygon(nodes[elem[:n_corners]], closed=True))
        shaded_values.append(von_mises[idx])

    if shaded_polygons:
        from matplotlib.collections import PatchCollection

        collection = PatchCollection(shaded_polygons, alpha=0.8, edgecolors='none')
        collection.set_array(np.array(shaded_values))
        collection.set_cmap('plasma')
        ax.add_collection(collection)

        colorbar = plt.colorbar(collection, ax=ax, shrink=cbar_shrink)
        colorbar.set_label('von Mises Stress', rotation=270, labelpad=cbar_labelpad)

    # Outline yielding elements with a thick closed red polyline.
    if np.any(plastic_elements):
        for idx, elem in enumerate(elements):
            if not plastic_elements[idx]:
                continue
            n_corners = corner_count.get(element_types[idx])
            if n_corners is None:
                continue
            ring = nodes[elem[:n_corners]]
            ring = np.vstack([ring, ring[0]])
            ax.plot(ring[:, 0], ring[:, 1], 'r-', linewidth=2, alpha=0.8)

    if show_mesh:
        plot_mesh_lines(ax, fem_data, color='gray', alpha=0.3, linewidth=0.3)

    if show_reinforcement and 'elements_1d' in fem_data:
        plot_reinforcement_forces(ax, fem_data, solution)

    if label_elements:
        _add_element_labels(ax, fem_data)

    ax.set_aspect('equal')
    ax.set_title('von Mises Stress (Red outline = Yielding/Plastic Elements)')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
878
+
879
+
880
def plot_deformed_mesh(ax, fem_data, solution, deform_scale=1.0, show_mesh=True, show_reinforcement=True,
                       cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Plot the deformed mesh over the undeformed mesh.

    Node displacements are read from ``solution["displacements"]`` as an interleaved
    vector (even entries = u, odd entries = v). They are multiplied by
    ``deform_scale`` for display only.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements", "element_types"
            and, for reinforcement, "elements_1d".
        solution: Solution dictionary; reads "displacements" (zeros if missing).
        deform_scale: Magnification applied to the displacements.
        show_mesh: If True, draw the undeformed mesh in light gray.
        show_reinforcement: If True and 1D elements exist, draw them in both
            configurations.
        cbar_shrink: Shrink factor of the invisible spacer colorbar.
        cbar_labelpad: Label padding of the invisible spacer colorbar.
        label_elements: If True, number the elements on the deformed mesh.

    Notes:
        An invisible colorbar is attached so this axes keeps the same width as sibling
        plots that have a real colorbar. It is built from a detached ScalarMappable, not
        from an ``imshow`` on ``ax``. An image on the mesh axes would add its own extent
        to the data limits and, with the default ``origin='upper'``, can flip the y-axis
        when the axes autoscale.
    """
    nodes = fem_data["nodes"]
    displacements = solution.get("displacements", np.zeros(2 * len(nodes)))

    u = displacements[0::2]
    v = displacements[1::2]
    nodes_deformed = nodes + deform_scale * np.column_stack([u, v])

    if show_mesh:
        plot_mesh_lines(ax, fem_data, color='lightgray', alpha=0.5, linewidth=1.0, label='Original')

    # Shallow copy: only "nodes" is replaced; connectivity arrays are shared.
    fem_data_deformed = fem_data.copy()
    fem_data_deformed["nodes"] = nodes_deformed
    plot_mesh_lines(ax, fem_data_deformed, color='blue', alpha=0.8, linewidth=1.5, label='Deformed')

    if show_reinforcement and 'elements_1d' in fem_data:
        plot_reinforcement_lines(ax, fem_data, solution, color='gray', alpha=0.5, linewidth=2,
                                 label='Original Reinforcement')
        plot_reinforcement_lines(ax, fem_data_deformed, solution, color='red', alpha=0.8, linewidth=2,
                                 label='Deformed Reinforcement')

    if label_elements:
        _add_element_labels(ax, fem_data_deformed)  # label on the deformed mesh

    # Spacer colorbar: not attached to the axes' data, fully transparent, no ticks,
    # and no visible outline. It only reserves the same space as a real colorbar.
    spacer = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1), cmap='viridis')
    spacer.set_array([])
    cbar = plt.colorbar(spacer, ax=ax, shrink=cbar_shrink, alpha=0)
    cbar.set_label('Deformation Scale', rotation=270, labelpad=cbar_labelpad, color='white')
    cbar.set_ticks([])
    cbar.set_ticklabels([])
    cbar.outline.set_color('white')
    cbar.outline.set_linewidth(0)

    # Equal aspect so the deformed shape is not distorted (as in the other FEM plots).
    # Axis limits are left to autoscaling or to the caller for multi-plot alignment.
    ax.set_aspect('equal')
    ax.set_title(f'Mesh Deformation (Scale Factor = {deform_scale:.1f})')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if show_mesh or show_reinforcement:
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)
933
+
934
+
935
def _add_element_labels(ax, fem_data):
    """
    Write each element's 1-based number at the centroid of its corner nodes.

    Higher-order elements are located by their corners only (3 for 6-node
    triangles, 4 for 8/9-node quads). Elements of any other type are not labeled.
    """
    coords = fem_data["nodes"]
    connectivity = fem_data["elements"]
    types = fem_data["element_types"]
    corners_for_type = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}

    for number, elem in enumerate(connectivity, start=1):
        n_corners = corners_for_type.get(types[number - 1])
        if n_corners is None:
            continue
        centroid = np.mean(coords[elem[:n_corners]], axis=0)
        ax.text(centroid[0], centroid[1], str(number),
                ha='center', va='center', fontsize=6,
                color='darkblue', alpha=0.7, zorder=100)
965
+
966
+
967
def _unique_mesh_edges(elements, element_types):
    """
    Return the unique corner-to-corner edges of the 2D mesh as (node_a, node_b) pairs.

    An edge shared by two neighbouring elements is returned once, in first-seen
    order and with the orientation of its first occurrence. Types 3 and 6
    contribute 3 corner edges. Types 4, 8 and 9 contribute 4. Other types are skipped.
    """
    corner_count = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}
    seen = set()
    edges = []
    for i, elem in enumerate(elements):
        n_corners = corner_count.get(element_types[i])
        if n_corners is None:
            continue
        for k in range(n_corners):
            a = int(elem[k])
            b = int(elem[(k + 1) % n_corners])
            key = (a, b) if a < b else (b, a)
            if key in seen:
                continue
            seen.add(key)
            edges.append((a, b))
    return edges


def plot_mesh_lines(ax, fem_data, color='black', alpha=1.0, linewidth=1.0, label=None):
    """
    Plot the 2D element edges as a single LineCollection.

    Each physical edge is drawn once. Interior edges are shared by two elements,
    and drawing them twice would make them render darker than boundary edges
    whenever ``alpha < 1``. Higher-order elements are drawn through their corner
    nodes only.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements", "element_types".
        color: Line color.
        alpha: Line transparency.
        linewidth: Line width.
        label: Optional legend label for the collection.
    """
    nodes = fem_data["nodes"]
    edges = _unique_mesh_edges(fem_data["elements"], fem_data["element_types"])

    if edges:
        lines = [nodes[[a, b]] for a, b in edges]
        lc = LineCollection(lines, colors=color, alpha=alpha, linewidths=linewidth, label=label)
        ax.add_collection(lc)
998
+
999
+
1000
def plot_reinforcement_lines(ax, fem_data, solution, color='red', alpha=1.0, linewidth=2, label=None):
    """
    Draw 1D reinforcement elements as straight segments through their first two nodes.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements_1d", "element_types_1d".
            Nothing is drawn if "elements_1d" is absent.
        solution: Accepted for call-site symmetry with the other plotters; not read.
        color: Line color.
        alpha: Line transparency.
        linewidth: Line width.
        label: Optional legend label for the collection.
    """
    if 'elements_1d' not in fem_data:
        return

    coords = fem_data["nodes"]
    connectivity_1d = fem_data["elements_1d"]
    types_1d = fem_data["element_types_1d"]

    # Only elements declared with at least 2 nodes are drawable.
    segments = [coords[elem[:2]]
                for k, elem in enumerate(connectivity_1d)
                if types_1d[k] >= 2]
    if not segments:
        return

    ax.add_collection(LineCollection(segments, colors=color, alpha=alpha,
                                     linewidths=linewidth, label=label))
1021
+
1022
+
1023
def _reinforcement_force_segments(nodes, elements_1d, element_types_1d, forces_1d, t_allow, failed_1d):
    """
    Collect drawable 1D segments, their capped force ratios, and the failed subset.

    Only elements with ``element_types_1d[i] >= 2`` are drawn, through their first
    two nodes. Failure is looked up by *element* index ``i``, so skipped elements
    cannot shift which drawn segment is marked as failed.

    Returns:
        (lines, force_ratios, failed_lines):
            lines: segment coordinate arrays.
            force_ratios: |force| / allowable capped at 1.5, or 0.0 where the
                allowable is not positive.
            failed_lines: coordinates of segments whose element is flagged failed.
    """
    lines = []
    force_ratios = []
    failed_lines = []
    for i, elem in enumerate(elements_1d):
        if element_types_1d[i] < 2:
            continue
        line_coords = nodes[elem[:2]]
        lines.append(line_coords)

        force_ratio = abs(forces_1d[i]) / t_allow[i] if t_allow[i] > 0 else 0.0
        force_ratios.append(min(force_ratio, 1.5))  # cap for color scaling

        if i < len(failed_1d) and failed_1d[i]:
            failed_lines.append(line_coords)
    return lines, force_ratios, failed_lines


def plot_reinforcement_forces(ax, fem_data, solution):
    """
    Plot reinforcement elements colored by force ratio (|force| / allowable).

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements_1d", "element_types_1d"
            and optionally "t_allow_by_1d_elem" (default 1.0 per element).
        solution: Solution dictionary; reads "forces_1d" (default 0) and
            "failed_1d_elements" (default all False).

    Failed elements are additionally overdrawn with a thick semi-transparent black line.
    """
    if 'elements_1d' not in fem_data:
        return

    nodes = fem_data["nodes"]
    elements_1d = fem_data["elements_1d"]
    element_types_1d = fem_data["element_types_1d"]
    forces_1d = solution.get("forces_1d", np.zeros(len(elements_1d)))
    t_allow = fem_data.get("t_allow_by_1d_elem", np.ones(len(elements_1d)))
    failed_1d = solution.get("failed_1d_elements", np.zeros(len(elements_1d), dtype=bool))

    lines, force_ratios, failed_lines = _reinforcement_force_segments(
        nodes, elements_1d, element_types_1d, forces_1d, t_allow, failed_1d)
    if not lines:
        return

    lc = LineCollection(lines, linewidths=3, alpha=0.8)
    lc.set_array(np.array(force_ratios))
    lc.set_cmap('coolwarm')  # blue = low force ratio, red = high
    ax.add_collection(lc)

    cbar = plt.colorbar(lc, ax=ax, shrink=0.6, pad=0.02)
    cbar.set_label('Force Ratio (Force/Allowable)', rotation=270, labelpad=15, fontsize=10)

    if failed_lines:
        lc_failed = LineCollection(failed_lines, colors='black', linewidths=5, alpha=0.6)
        ax.add_collection(lc_failed)
1072
+
1073
+
1074
def _reinforcement_line_profile(nodes, elements_1d, elem_ids, forces_1d, t_allow, t_res, failed_1d):
    """
    Gather one reinforcement line's per-element data, sorted by element midpoint x.

    Returns:
        (positions, forces, t_allow_line, t_res_line, failed_line) as NumPy arrays.
        ``failed_line`` is always boolean, so it can be used as a mask even when
        failure flags are stored as 0/1 integers.
    """
    elem_ids = np.asarray(elem_ids)
    # Position along the line is simplified to the midpoint x-coordinate.
    positions = np.array([0.5 * (nodes[elements_1d[k][0], 0] + nodes[elements_1d[k][1], 0])
                          for k in elem_ids])
    order = np.argsort(positions)
    ordered_ids = elem_ids[order]
    return (positions[order],
            np.asarray(forces_1d)[ordered_ids],
            np.asarray(t_allow)[ordered_ids],
            np.asarray(t_res)[ordered_ids],
            np.asarray(failed_1d)[ordered_ids].astype(bool))


def plot_reinforcement_force_profiles(fem_data, solution, figsize=(12, 8)):
    """
    Plot force profiles along each reinforcement line.

    1D elements are grouped by ``fem_data["element_materials_1d"]``, with one
    subplot per distinct value. For each group the element forces, allowable
    forces and (if any are positive) residual forces are plotted against the
    element midpoint x-coordinate. Failed elements are marked with red crosses.

    Args:
        fem_data: FEM data dictionary; reads "nodes", "elements_1d",
            "element_materials_1d", and optionally "t_allow_by_1d_elem" (default 1)
            and "t_res_by_1d_elem" (default 0).
        solution: Solution dictionary; reads "forces_1d" (default 0) and
            "failed_1d_elements" (default all False).
        figsize: Figure size.

    Returns:
        (fig, axes) with ``axes`` flattened, or (None, None) when there are no
        reinforcement elements or lines.
    """
    if 'elements_1d' not in fem_data:
        print("No reinforcement elements found")
        return None, None

    nodes = fem_data["nodes"]
    elements_1d = fem_data["elements_1d"]
    element_materials_1d = np.asarray(fem_data["element_materials_1d"])
    forces_1d = solution.get("forces_1d", np.zeros(len(elements_1d)))
    t_allow = fem_data.get("t_allow_by_1d_elem", np.ones(len(elements_1d)))
    t_res = fem_data.get("t_res_by_1d_elem", np.zeros(len(elements_1d)))
    failed_1d = solution.get("failed_1d_elements", np.zeros(len(elements_1d), dtype=bool))

    unique_lines = np.unique(element_materials_1d)
    n_lines = len(unique_lines)
    if n_lines == 0:
        print("No reinforcement lines found")
        return None, None

    # One column for up to three lines, otherwise two columns.
    if n_lines <= 3:
        n_rows, n_cols = n_lines, 1
    else:
        n_rows, n_cols = int(np.ceil(n_lines / 2)), 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax, line_id in zip(axes, unique_lines):
        line_elements = np.where(element_materials_1d == line_id)[0]
        if len(line_elements) == 0:
            continue

        positions, forces, t_allow_line, t_res_line, failed_line = _reinforcement_line_profile(
            nodes, elements_1d, line_elements, forces_1d, t_allow, t_res, failed_1d)

        ax.plot(positions, forces, 'b-o', linewidth=2, markersize=6, label='Tensile Force')
        ax.plot(positions, t_allow_line, 'g--', linewidth=1, label='Allowable Force')
        if np.any(t_res_line > 0):
            ax.plot(positions, t_res_line, 'orange', linestyle='--', linewidth=1, label='Residual Force')

        if np.any(failed_line):
            ax.scatter(positions[failed_line], forces[failed_line], color='red', s=100, marker='x',
                       linewidth=3, label='Failed Elements', zorder=10)

        ax.set_xlabel('Position along line')
        ax.set_ylabel('Force')
        ax.set_title(f'Reinforcement Line {line_id} Force Profile')
        ax.grid(True, alpha=0.3)
        ax.legend()

        max_val = max(np.max(np.abs(forces)), np.max(t_allow_line))
        if max_val > 0:
            # Extend below zero far enough that compressive (negative) forces stay visible.
            y_min = min(-max_val * 0.1, np.min(forces) * 1.1)
            ax.set_ylim([y_min, max_val * 1.1])

    for unused_ax in axes[n_lines:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    return fig, axes
1174
+
1175
+
1176
def plot_ssrm_convergence(ssrm_solution, figsize=(10, 6)):
    """
    Plot the SSRM reduction-factor history and the per-trial convergence status.

    Args:
        ssrm_solution: Dictionary with "F_history" (trial reduction factors) and
            "convergence_history" (truthy where a trial converged). An optional "FS"
            value is drawn as a dashed reference line.
        figsize: Figure size.

    Returns:
        (fig, (ax1, ax2)), or (None, None) if "F_history" is missing.
    """
    if 'F_history' not in ssrm_solution:
        print("No SSRM convergence history found")
        return None, None

    trial_factors = ssrm_solution['F_history']
    converged_flags = ssrm_solution['convergence_history']

    fig, (ax_factor, ax_status) = plt.subplots(2, 1, figsize=figsize)

    trial_numbers = range(1, len(trial_factors) + 1)
    point_colors = ['green' if ok else 'red' for ok in converged_flags]
    status_heights = [1 if ok else 0 for ok in converged_flags]

    # Top panel: reduction factor per trial, colored by convergence.
    ax_factor.scatter(trial_numbers, trial_factors, c=point_colors, s=50, alpha=0.7)
    ax_factor.plot(trial_numbers, trial_factors, 'k-', alpha=0.5)

    final_fs = ssrm_solution.get('FS')
    if final_fs is not None:
        ax_factor.axhline(y=final_fs, color='blue', linestyle='--',
                          linewidth=2, label=f"FS = {final_fs:.3f}")
        ax_factor.legend()

    ax_factor.set_xlabel('SSRM Iteration')
    ax_factor.set_ylabel('Reduction Factor F')
    ax_factor.set_title('SSRM Convergence History')
    ax_factor.grid(True, alpha=0.3)

    # Bottom panel: 1 = converged, 0 = failed.
    ax_status.bar(trial_numbers, status_heights, color=point_colors, alpha=0.7, width=0.8)
    ax_status.set_xlabel('SSRM Iteration')
    ax_status.set_ylabel('Converged')
    ax_status.set_title('Convergence Status (Green=Converged, Red=Failed)')
    ax_status.set_ylim([0, 1.2])
    ax_status.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig, (ax_factor, ax_status)
1218
+
1219
+
1220
def _von_mises_equivalent_strain(strains):
    """
    Return the von Mises equivalent strain for plane-strain element strains.

    Columns 0, 1 and 2 are taken as eps_x, eps_y and the engineering shear strain
    gamma_xy, with eps_z = 0 (plane strain). The result is sqrt(2/3 * e_ij e_ij),
    where e_ij is the deviatoric strain tensor. This reduces to

        (2/3) * sqrt(eps_x**2 + eps_y**2 - eps_x*eps_y + 0.75*gamma_xy**2)

    For example, simple shear gamma gives gamma / sqrt(3). The radicand equals
    ((eps_x - eps_y)**2 + eps_x**2 + eps_y**2) / 2 + 0.75*gamma_xy**2, so it is
    never negative.
    """
    strains = np.asarray(strains, dtype=float)
    eps_x = strains[:, 0]
    eps_y = strains[:, 1]
    gamma_xy = strains[:, 2]
    return (2.0 / 3.0) * np.sqrt(eps_x**2 + eps_y**2 - eps_x * eps_y + 0.75 * gamma_xy**2)


def plot_strain_contours(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                         cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Plot per-element von Mises equivalent strain.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary (see ``_plot_element_contours``).
        solution: Solution dictionary; reads "strains" with columns
            [eps_x, eps_y, gamma_xy, ...]. If fewer than 3 columns are present, a
            warning is printed and nothing is drawn.
        show_mesh, show_reinforcement, cbar_shrink, cbar_labelpad, label_elements:
            Passed through to ``_plot_element_contours``.
    """
    elements = fem_data["elements"]
    strains = solution.get("strains", np.zeros((len(elements), 4)))

    if strains.shape[1] < 3:
        print("Warning: Strain data not available or incomplete")
        return

    equiv_strain = _von_mises_equivalent_strain(strains)

    _plot_element_contours(ax, fem_data, equiv_strain, 'Equivalent Strain',
                           show_mesh, show_reinforcement, cbar_shrink, cbar_labelpad, label_elements)
1245
+
1246
+
1247
def plot_shear_strain_contours(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                               cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Plot smoothed maximum shear strain, a key indicator of failure surfaces in slopes.

    Element values are taken from column 3 of ``solution["strains"]`` and drawn as
    nodal contours with the 'coolwarm' colormap (red = high, blue = low). If fewer
    than 4 strain columns are available, a warning is printed and nothing is drawn.

    Note:
        ``show_mesh`` is accepted for signature compatibility with the other contour
        plotters, but the mesh overlay is always turned off for this plot.
    """
    strains = solution.get("strains", np.zeros((len(fem_data["elements"]), 4)))

    if strains.shape[1] < 4:
        print("Warning: Maximum shear strain data not available")
        return

    _plot_nodal_contours(ax, fem_data, strains[:, 3], 'Max Shear Strain',
                         show_mesh=False,
                         show_reinforcement=show_reinforcement,
                         cbar_shrink=cbar_shrink,
                         cbar_labelpad=cbar_labelpad,
                         colormap='coolwarm',
                         label_elements=label_elements)

    ax.set_title('Max Shear Strain (Failure Zone Indicator)', fontsize=12, pad=15)
1271
+
1272
+
1273
def plot_yield_function_contours(ax, fem_data, solution, show_mesh=True, show_reinforcement=True,
                                 cbar_shrink=0.8, cbar_labelpad=20, label_elements=False):
    """
    Plot per-element Mohr-Coulomb yield function values F.

    F > 0 indicates yielding/failure and F < 0 an elastic state. Elements are shaded
    blue for F < 0 and red for F >= 0, with values clipped to [-200, 50] for display.
    Other annotations:

    * The F value is written at the centroid of every element with F > -30, or of
      every element when ``label_elements`` is True.
    * Yielding elements get a thick black outline.
    * A summary box is drawn in the upper-left corner.

    Args:
        ax: Matplotlib axes to draw on.
        fem_data: FEM data dictionary; reads "nodes", "elements", "element_types" and,
            for reinforcement, "elements_1d".
        solution: Solution dictionary; reads "yield_function" (one value per element).
            If it is missing, a warning is printed and all values are treated as 0.
        show_mesh: Accepted for signature compatibility; element edges are always
            drawn in thin gray.
        show_reinforcement: If True and 1D elements exist, draw them as lines.
        cbar_shrink: Shrink factor for the colorbar.
        cbar_labelpad: Padding of the colorbar label.
        label_elements: If True, write the F value on every element.
    """
    from matplotlib.collections import PatchCollection
    from matplotlib.colors import ListedColormap, TwoSlopeNorm

    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]
    yield_function = solution.get("yield_function", None)

    if yield_function is None:
        print("Warning: Yield function data not available in solution")
        yield_function = np.zeros(len(elements))
    yield_function = np.asarray(yield_function, dtype=float)

    # Display range: negative side capped for contrast, positive side kept narrow.
    vmin = -200
    vmax = 50

    # Blue half for F < 0 and red half for F >= 0. TwoSlopeNorm maps F = 0 to the middle
    # of the colormap, so the colour switch sits exactly at F = 0 even though
    # [vmin, vmax] is asymmetric. A linear norm over [-200, 50] with a 70/30 colour split
    # would switch colours at F = -25 and paint elastic elements red.
    n_half = 128
    cmap_yield = ListedColormap(np.vstack([
        plt.cm.Blues_r(np.linspace(0.2, 0.9, n_half)),
        plt.cm.Reds(np.linspace(0.3, 1.0, n_half)),
    ]))
    norm_yield = TwoSlopeNorm(vmin=vmin, vcenter=0.0, vmax=vmax)

    # (element index, corner coordinates) for every supported element type;
    # higher-order elements are drawn through their corner nodes.
    corner_count = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}
    element_corners = []
    for i, elem in enumerate(elements):
        n_corners = corner_count.get(element_types[i])
        if n_corners is not None:
            element_corners.append((i, nodes[elem[:n_corners]]))

    if element_corners:
        patches_list = [Polygon(coords, closed=True) for _, coords in element_corners]
        values = np.clip([yield_function[i] for i, _ in element_corners], vmin, vmax)

        p = PatchCollection(patches_list, cmap=cmap_yield, norm=norm_yield,
                            alpha=0.9, edgecolors='gray', linewidths=0.3)
        p.set_array(values)
        ax.add_collection(p)

        cbar = plt.colorbar(p, ax=ax, shrink=cbar_shrink)
        cbar.set_label('Yield Function F', rotation=270, labelpad=cbar_labelpad)

        tick_values = [v for v in (-200, -100, -50, -20, -10, -5, 0, 5, 10, 20, 50) if vmin <= v <= vmax]
        if tick_values:
            cbar.set_ticks(tick_values)
            cbar.set_ticklabels([str(v) for v in tick_values])

        # Mark F = 0 on the colorbar.
        cbar.ax.axhline(y=0, color='black', linewidth=2)

        # Write F on elements near or past yield (or on all elements if requested).
        for i, coords in element_corners:
            f_val = yield_function[i]
            if not (label_elements or f_val > -30):
                continue
            centroid = np.mean(coords, axis=0)
            text = f'{f_val:.1f}' if abs(f_val) < 10 else f'{f_val:.0f}'
            if f_val > 0:
                color, fontweight = 'white', 'bold'     # on red background
            elif f_val > -10:
                color, fontweight = 'black', 'normal'   # on light background
            else:
                color, fontweight = 'white', 'normal'   # on blue background
            ax.text(centroid[0], centroid[1], text,
                    ha='center', va='center', fontsize=5,
                    color=color, fontweight=fontweight, alpha=0.8)

        # Thick black outline around yielding elements.
        for i, coords in element_corners:
            if yield_function[i] > 0:
                ring = np.vstack([coords, coords[0]])
                ax.plot(ring[:, 0], ring[:, 1], 'k-', linewidth=2.5, alpha=1.0)

    if show_reinforcement and 'elements_1d' in fem_data:
        plot_reinforcement_lines(ax, fem_data, solution)

    ax.set_title('Yield Function (Red: F>0 Yielding/Plastic, Blue: F<0 Elastic)', fontsize=12, pad=15)

    n_total = len(yield_function)
    if n_total == 0:
        return  # nothing to summarize; np.max/np.min would fail on an empty array

    n_yielding = np.sum(yield_function > 0)
    n_critical = np.sum((yield_function > -10) & (yield_function <= 0))  # near yielding

    stats_text = (f'Yielding: {n_yielding}/{n_total} elements\n'
                  f'Critical (F>-10): {n_critical} elements\n'
                  f'Max F: {np.max(yield_function):.1f}\n'
                  f'Min F: {np.min(yield_function):.1f}')

    ax.text(0.02, 0.98, stats_text,
            transform=ax.transAxes, fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))
1448
+
1449
+
1450
def _plot_element_contours(ax, fem_data, values, label, show_mesh=True, show_reinforcement=True,
                           cbar_shrink=0.8, cbar_labelpad=20, label_elements=False, colormap='viridis'):
    """
    Draw one flat-colored polygon per 2D element from per-element ``values``.

    Coloring:
        If the values vary, each element is colored through ``colormap`` scaled to
        [min(values), max(values)], and a colorbar labelled ``label`` is added. If
        all values are equal, every element is drawn light blue and no colorbar is
        added.

    Higher-order elements are drawn through their corner nodes only.

    Optional overlays:
        * element edges (``show_mesh``);
        * 1D reinforcement, drawn as red lines between the first two nodes of each
          entry in ``fem_data["elements_1d"]`` (``show_reinforcement``);
        * element numbers (``label_elements``).
    """
    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]
    corner_count = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}

    def element_outlines():
        # Yield (index, corner coordinates) for each element of a supported type.
        for idx, elem in enumerate(elements):
            n_corners = corner_count.get(element_types[idx])
            if n_corners is not None:
                yield idx, nodes[elem[:n_corners]]

    v_lo = np.min(values)
    v_hi = np.max(values)

    if v_hi > v_lo:
        norm = plt.Normalize(vmin=v_lo, vmax=v_hi)
        cmap = plt.get_cmap(colormap)
        for idx, corners in element_outlines():
            ax.add_patch(plt.Polygon(corners, facecolor=cmap(norm(values[idx])),
                                     edgecolor='none', alpha=0.8))

        # The patches are drawn individually, so the colorbar needs its own mappable.
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = plt.colorbar(mappable, ax=ax, shrink=cbar_shrink, pad=0.05)
        cbar.set_label(label, rotation=270, labelpad=cbar_labelpad)
    else:
        for _, corners in element_outlines():
            ax.add_patch(plt.Polygon(corners, facecolor='lightblue', edgecolor='none', alpha=0.7))

    if show_mesh:
        for _, corners in element_outlines():
            ax.add_patch(plt.Polygon(corners, fill=False, edgecolor='black', linewidth=0.5, alpha=0.7))

    if show_reinforcement:
        elements_1d = fem_data.get("elements_1d", np.array([]).reshape(0, 3))
        for elem in elements_1d:
            if len(elem) < 2:
                continue
            start, end = elem[0], elem[1]
            ax.plot([nodes[start, 0], nodes[end, 0]],
                    [nodes[start, 1], nodes[end, 1]],
                    'r-', linewidth=2, alpha=0.8)

    if label_elements:
        _add_element_labels(ax, fem_data)

    ax.set_aspect('equal')
1553
+
1554
+
1555
def _plot_nodal_contours(ax, fem_data, element_values, label, show_mesh=True, show_reinforcement=True,
                         cbar_shrink=0.8, cbar_labelpad=20, colormap='viridis', label_elements=False):
    """
    Plot smooth contours by interpolating element values to nodes.

    Each element's scalar value is added to every node the element references,
    and each nodal value is the average over its contributing elements. The
    corner nodes of every element are then triangulated (quads are split into
    two triangles along the 0-2 diagonal) and filled with ``tricontourf`` using
    20 evenly spaced levels spanning the plotted nodal values.

    Args:
        ax: Matplotlib Axes to draw on.
        fem_data: Dictionary with "nodes" (array indexed as ``nodes[:, 0]`` /
            ``nodes[:, 1]`` for x / y), "elements" (per-element rows of node
            indices), "element_types" (node count per element: 3, 4, 6, 8 or 9)
            and, optionally, "elements_1d" (1D reinforcement elements; the
            first two node indices of each row are drawn).
        element_values: One scalar per element, in the same order as
            ``fem_data["elements"]``.
        label: Colorbar label. No colorbar is drawn (so the label is unused)
            when all plotted nodal values are equal.
        show_mesh: If True, overlay element outlines (corner nodes only).
        show_reinforcement: If True, draw 1D elements as red lines.
        cbar_shrink: ``shrink`` factor passed to the colorbar.
        cbar_labelpad: Padding between the colorbar and its label.
        colormap: Colormap (name or object) used for the fill.
        label_elements: If True, label elements via ``_add_element_labels``.

    Element types other than 3, 4, 6, 8 and 9 still contribute to the nodal
    averages but are neither contoured nor outlined. If no element can be
    triangulated, a message is printed and nothing is drawn.
    """
    import matplotlib.tri as tri

    nodes = fem_data["nodes"]
    elements = fem_data["elements"]
    element_types = fem_data["element_types"]
    n_nodes = len(nodes)

    # Corner-node count used to draw / triangulate each supported element type;
    # mid-side (and centre) nodes of higher-order elements are not drawn.
    corner_counts = {3: 3, 4: 4, 6: 3, 8: 4, 9: 4}

    # Interpolate element values to nodes by simple averaging.
    nodal_values = np.zeros(n_nodes)
    node_counts = np.zeros(n_nodes)  # number of elements contributing to each node
    for i, elem in enumerate(elements):
        elem_type = element_types[i]
        elem_nodes = elem[:elem_type] if elem_type <= len(elem) else elem
        for node_id in elem_nodes:
            # Negative ids are rejected too: they would silently wrap to the end of the array.
            if 0 <= node_id < n_nodes:
                nodal_values[node_id] += element_values[i]
                node_counts[node_id] += 1

    # Average values at nodes (nodes with no contributions stay at 0.0).
    valid_nodes = node_counts > 0
    nodal_values[valid_nodes] /= node_counts[valid_nodes]

    # Build a triangulation from corner nodes for smooth contouring.
    triangles = []
    for i, elem in enumerate(elements):
        n_corners = corner_counts.get(element_types[i])
        if n_corners == 3:
            triangles.append([elem[0], elem[1], elem[2]])
        elif n_corners == 4:
            triangles.append([elem[0], elem[1], elem[2]])
            triangles.append([elem[0], elem[2], elem[3]])

    if not triangles:
        print("No valid elements for contouring")
        return

    triangles = np.array(triangles)
    triang = tri.Triangulation(nodes[:, 0], nodes[:, 1], triangles)

    # Only nodes referenced by a triangle reach the contour plot. Taking the range
    # from them keeps unreferenced nodes (left at 0.0, e.g. reinforcement-only nodes)
    # and unused mid-side nodes from stretching the levels and the colorbar.
    plotted_values = nodal_values[np.unique(triangles)]
    vmin = np.min(plotted_values)
    vmax = np.max(plotted_values)

    if vmax > vmin:  # Only contour if there's variation
        levels = np.linspace(vmin, vmax, 20)
        cs = ax.tricontourf(triang, nodal_values, levels=levels, cmap=colormap)

        cbar = plt.colorbar(cs, ax=ax, shrink=cbar_shrink, pad=0.05)
        cbar.set_label(label, rotation=270, labelpad=cbar_labelpad)
    else:
        # Uniform values - fill every triangle with the colormap's mid colour.
        uniform_color = plt.get_cmap(colormap)(0.5)
        for triangle_nodes in triangles:
            ax.add_patch(plt.Polygon(nodes[triangle_nodes], facecolor=uniform_color,
                                     edgecolor='none', alpha=0.8))
        # add_patch updates the data limits but does not rescale the view
        # (tricontourf does so in the other branch), so rescale explicitly.
        ax.autoscale_view()

    # Overlay mesh outlines (corner nodes only) if requested.
    if show_mesh:
        for i, elem in enumerate(elements):
            n_corners = corner_counts.get(element_types[i])
            if n_corners is None:
                continue
            ax.add_patch(plt.Polygon(nodes[elem[:n_corners]], fill=False, edgecolor='black',
                                     linewidth=0.5, alpha=0.7))

    # Add reinforcement if requested.
    if show_reinforcement:
        elements_1d = fem_data.get("elements_1d", np.array([]).reshape(0, 3))
        if len(elements_1d) > 0:
            for elem in elements_1d:
                if len(elem) >= 2:
                    x_coords = [nodes[elem[0], 0], nodes[elem[1], 0]]
                    y_coords = [nodes[elem[0], 1], nodes[elem[1], 1]]
                    ax.plot(x_coords, y_coords, 'r-', linewidth=2, alpha=0.8)

    # Add element labels if requested.
    if label_elements:
        _add_element_labels(ax, fem_data)

    ax.set_aspect('equal')