coordinate-system 5.2.2__cp313-cp313-win_amd64.whl → 6.0.0__cp313-cp313-win_amd64.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -1,61 +1,99 @@
1
1
  """
2
- 3D Coordinate System and Curve Visualization Module
3
- ===================================================
2
+ 3D Coordinate System and Surface Visualization Module
3
+ =====================================================
4
4
 
5
- This module provides tools for visualizing coordinate systems and curves in 3D space.
5
+ Comprehensive visualization tools for coordinate systems, curves, and surfaces.
6
6
 
7
7
  Features:
8
- - Draw coordinate systems with RGB colors (X=Red, Y=Green, Z=Blue)
9
- - Visualize curves with vertices, tangents, and normals
10
- - Display coordinate frame fields along curves
11
- - Support for both single frames and frame arrays
12
-
13
- Author: PanGuoJun
14
- Date: 2025-12-01
8
+ - Coordinate frame visualization with RGB color scheme (X=Red, Y=Green, Z=Blue)
9
+ - Parametric curve visualization with Frenet frames
10
+ - Surface rendering with curvature coloring
11
+ - Frame field visualization on surfaces
12
+ - Multiple view angles and animation support
13
+
14
+ Author: Coordinate System Package
15
+ Date: 2025-12-03
15
16
  """
16
17
 
17
18
  import numpy as np
18
19
  import matplotlib.pyplot as plt
19
20
  from mpl_toolkits.mplot3d import Axes3D
20
- from typing import List, Optional, Tuple, Callable, Union
21
+ from matplotlib import cm
22
+ from matplotlib.colors import Normalize, LinearSegmentedColormap
23
+ import matplotlib.animation as animation
24
+ from typing import List, Optional, Tuple, Callable, Union, Dict
21
25
 
22
26
  # Import from the C extension module
23
27
  try:
24
28
  from .coordinate_system import vec3, coord3
25
29
  except ImportError:
26
- # Fallback for development
27
30
  import coordinate_system
28
31
  vec3 = coordinate_system.vec3
29
32
  coord3 = coordinate_system.coord3
30
33
 
34
+ # Import differential geometry for surface visualization
35
+ try:
36
+ from .differential_geometry import Surface, Sphere, Torus, compute_gaussian_curvature, compute_mean_curvature
37
+ except ImportError:
38
+ Surface = None
39
+ Sphere = None
40
+ Torus = None
41
+ compute_gaussian_curvature = None
42
+ compute_mean_curvature = None
43
+
44
+
45
+ # ============================================================
46
+ # Color Schemes
47
+ # ============================================================
48
+
49
+ # Curvature colormap: blue (negative) -> white (zero) -> red (positive)
50
+ CURVATURE_COLORS = [
51
+ (0.0, 'blue'),
52
+ (0.5, 'white'),
53
+ (1.0, 'red')
54
+ ]
55
+
56
+ def create_curvature_colormap():
57
+ """Create a diverging colormap for curvature visualization."""
58
+ return LinearSegmentedColormap.from_list('curvature',
59
+ [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')])
60
+
61
+
62
+ # ============================================================
63
+ # Coordinate System Visualizer
64
+ # ============================================================
31
65
 
32
66
  class CoordinateSystemVisualizer:
33
67
  """
34
- 三维坐标系可视化工具
68
+ 3D coordinate system visualization tool.
35
69
 
36
- 绘制坐标系时使用RGB颜色方案:
37
- - X轴: 红色 (Red)
38
- - Y轴: 绿色 (Green)
39
- - Z轴: 蓝色 (Blue)
70
+ RGB color scheme:
71
+ - X axis: Red
72
+ - Y axis: Green
73
+ - Z axis: Blue
40
74
  """
41
75
 
42
- def __init__(self, figsize: Tuple[int, int] = (12, 9)):
76
+ def __init__(self, figsize: Tuple[int, int] = (12, 9), dpi: int = 100):
43
77
  """
44
- 初始化可视化工具
78
+ Initialize visualizer.
45
79
 
46
80
  Args:
47
- figsize: 图形大小 (width, height)
81
+ figsize: Figure size (width, height) in inches
82
+ dpi: Dots per inch for rendering
48
83
  """
49
- self.fig = plt.figure(figsize=figsize)
84
+ self.fig = plt.figure(figsize=figsize, dpi=dpi)
50
85
  self.ax = self.fig.add_subplot(111, projection='3d')
51
86
  self._setup_axis()
52
87
 
53
88
  def _setup_axis(self):
54
- """设置坐标轴样式"""
55
- self.ax.set_xlabel('X', fontsize=12, color='red')
56
- self.ax.set_ylabel('Y', fontsize=12, color='green')
57
- self.ax.set_zlabel('Z', fontsize=12, color='blue')
58
- self.ax.grid(True, alpha=0.3)
89
+ """Configure axis style."""
90
+ self.ax.set_xlabel('X', fontsize=12, color='red', fontweight='bold')
91
+ self.ax.set_ylabel('Y', fontsize=12, color='green', fontweight='bold')
92
+ self.ax.set_zlabel('Z', fontsize=12, color='blue', fontweight='bold')
93
+ self.ax.grid(True, alpha=0.3, linestyle='--')
94
+ self.ax.xaxis.pane.fill = False
95
+ self.ax.yaxis.pane.fill = False
96
+ self.ax.zaxis.pane.fill = False
59
97
 
60
98
  def draw_coord_system(
61
99
  self,
@@ -63,49 +101,54 @@ class CoordinateSystemVisualizer:
63
101
  scale: float = 1.0,
64
102
  linewidth: float = 2.0,
65
103
  alpha: float = 0.8,
66
- label_prefix: str = ""
104
+ label_prefix: str = "",
105
+ arrow_style: bool = True
67
106
  ):
68
107
  """
69
- 绘制单个坐标系
108
+ Draw a single coordinate frame.
70
109
 
71
110
  Args:
72
- coord: coord3对象
73
- scale: 轴长度缩放因子
74
- linewidth: 线宽
75
- alpha: 透明度
76
- label_prefix: 标签前缀
111
+ coord: coord3 object
112
+ scale: Axis length scale factor
113
+ linewidth: Line width
114
+ alpha: Transparency
115
+ label_prefix: Label prefix for legend
116
+ arrow_style: Use arrow heads if True
77
117
  """
78
118
  origin = coord.o
79
119
 
80
- # X轴 (红色)
81
- x_end = origin + coord.ux * scale
82
- self.ax.plot(
83
- [origin.x, x_end.x],
84
- [origin.y, x_end.y],
85
- [origin.z, x_end.z],
86
- 'r-', linewidth=linewidth, alpha=alpha,
87
- label=f'{label_prefix}X' if label_prefix else None
88
- )
89
-
90
- # Y轴 (绿色)
91
- y_end = origin + coord.uy * scale
92
- self.ax.plot(
93
- [origin.x, y_end.x],
94
- [origin.y, y_end.y],
95
- [origin.z, y_end.z],
96
- 'g-', linewidth=linewidth, alpha=alpha,
97
- label=f'{label_prefix}Y' if label_prefix else None
98
- )
99
-
100
- # Z轴 (蓝色)
101
- z_end = origin + coord.uz * scale
102
- self.ax.plot(
103
- [origin.x, z_end.x],
104
- [origin.y, z_end.y],
105
- [origin.z, z_end.z],
106
- 'b-', linewidth=linewidth, alpha=alpha,
107
- label=f'{label_prefix}Z' if label_prefix else None
108
- )
120
+ if arrow_style:
121
+ # Use quiver for arrows
122
+ self.ax.quiver(
123
+ origin.x, origin.y, origin.z,
124
+ coord.ux.x * scale, coord.ux.y * scale, coord.ux.z * scale,
125
+ color='red', linewidth=linewidth, alpha=alpha,
126
+ arrow_length_ratio=0.15
127
+ )
128
+ self.ax.quiver(
129
+ origin.x, origin.y, origin.z,
130
+ coord.uy.x * scale, coord.uy.y * scale, coord.uy.z * scale,
131
+ color='green', linewidth=linewidth, alpha=alpha,
132
+ arrow_length_ratio=0.15
133
+ )
134
+ self.ax.quiver(
135
+ origin.x, origin.y, origin.z,
136
+ coord.uz.x * scale, coord.uz.y * scale, coord.uz.z * scale,
137
+ color='blue', linewidth=linewidth, alpha=alpha,
138
+ arrow_length_ratio=0.15
139
+ )
140
+ else:
141
+ # Use simple lines
142
+ x_end = origin + coord.ux * scale
143
+ y_end = origin + coord.uy * scale
144
+ z_end = origin + coord.uz * scale
145
+
146
+ self.ax.plot([origin.x, x_end.x], [origin.y, x_end.y], [origin.z, x_end.z],
147
+ 'r-', linewidth=linewidth, alpha=alpha)
148
+ self.ax.plot([origin.x, y_end.x], [origin.y, y_end.y], [origin.z, y_end.z],
149
+ 'g-', linewidth=linewidth, alpha=alpha)
150
+ self.ax.plot([origin.x, z_end.x], [origin.y, z_end.y], [origin.z, z_end.z],
151
+ 'b-', linewidth=linewidth, alpha=alpha)
109
152
 
110
153
  def draw_world_coord(
111
154
  self,
@@ -114,12 +157,12 @@ class CoordinateSystemVisualizer:
114
157
  linewidth: float = 3.0
115
158
  ):
116
159
  """
117
- 绘制世界坐标系
160
+ Draw world coordinate system.
118
161
 
119
162
  Args:
120
- origin: 原点位置,默认为(0,0,0)
121
- scale: 轴长度
122
- linewidth: 线宽
163
+ origin: Origin position, default is (0, 0, 0)
164
+ scale: Axis length
165
+ linewidth: Line width
123
166
  """
124
167
  if origin is None:
125
168
  origin = vec3(0, 0, 0)
@@ -135,91 +178,381 @@ class CoordinateSystemVisualizer:
135
178
  label_prefix="World-"
136
179
  )
137
180
 
181
+ def draw_point(
182
+ self,
183
+ point: vec3,
184
+ color: str = 'black',
185
+ size: float = 50,
186
+ marker: str = 'o',
187
+ label: str = None
188
+ ):
189
+ """
190
+ Draw a single point.
191
+
192
+ Args:
193
+ point: Point position
194
+ color: Point color
195
+ size: Marker size
196
+ marker: Marker style
197
+ label: Label for legend
198
+ """
199
+ self.ax.scatter([point.x], [point.y], [point.z],
200
+ c=color, s=size, marker=marker, label=label)
201
+
202
+ def draw_vector(
203
+ self,
204
+ start: vec3,
205
+ direction: vec3,
206
+ color: str = 'black',
207
+ linewidth: float = 2.0,
208
+ alpha: float = 0.8,
209
+ label: str = None
210
+ ):
211
+ """
212
+ Draw a vector as an arrow.
213
+
214
+ Args:
215
+ start: Starting point
216
+ direction: Direction vector
217
+ color: Arrow color
218
+ linewidth: Line width
219
+ alpha: Transparency
220
+ label: Label for legend
221
+ """
222
+ self.ax.quiver(
223
+ start.x, start.y, start.z,
224
+ direction.x, direction.y, direction.z,
225
+ color=color, linewidth=linewidth, alpha=alpha,
226
+ arrow_length_ratio=0.15, label=label
227
+ )
228
+
138
229
  def set_equal_aspect(self):
139
- """设置等比例显示"""
140
- # 获取当前所有数据的范围
230
+ """Set equal aspect ratio for all axes."""
141
231
  xlim = self.ax.get_xlim3d()
142
232
  ylim = self.ax.get_ylim3d()
143
233
  zlim = self.ax.get_zlim3d()
144
234
 
145
- # 计算范围
146
235
  x_range = abs(xlim[1] - xlim[0])
147
236
  y_range = abs(ylim[1] - ylim[0])
148
237
  z_range = abs(zlim[1] - zlim[0])
149
238
 
150
- # 使用最大范围
151
239
  max_range = max(x_range, y_range, z_range)
152
240
 
153
- # 计算中心点
154
241
  x_middle = np.mean(xlim)
155
242
  y_middle = np.mean(ylim)
156
243
  z_middle = np.mean(zlim)
157
244
 
158
- # 设置相等的范围
159
245
  self.ax.set_xlim3d([x_middle - max_range/2, x_middle + max_range/2])
160
246
  self.ax.set_ylim3d([y_middle - max_range/2, y_middle + max_range/2])
161
247
  self.ax.set_zlim3d([z_middle - max_range/2, z_middle + max_range/2])
162
248
 
249
+ def set_view(self, elev: float = 30, azim: float = 45):
250
+ """
251
+ Set camera view angle.
252
+
253
+ Args:
254
+ elev: Elevation angle in degrees
255
+ azim: Azimuth angle in degrees
256
+ """
257
+ self.ax.view_init(elev=elev, azim=azim)
258
+
259
+ def set_title(self, title: str, fontsize: int = 14):
260
+ """Set plot title."""
261
+ self.ax.set_title(title, fontsize=fontsize, fontweight='bold')
262
+
163
263
  def show(self):
164
- """显示图形"""
165
- self.ax.legend()
264
+ """Display the figure."""
265
+ self.ax.legend(loc='upper left')
166
266
  plt.tight_layout()
167
267
  plt.show()
168
268
 
169
269
  def save(self, filename: str, dpi: int = 300):
170
270
  """
171
- 保存图形
271
+ Save figure to file.
172
272
 
173
273
  Args:
174
- filename: 文件名
175
- dpi: 分辨率
274
+ filename: Output filename
275
+ dpi: Resolution
176
276
  """
177
- self.ax.legend()
277
+ self.ax.legend(loc='upper left')
178
278
  plt.tight_layout()
179
279
  plt.savefig(filename, dpi=dpi, bbox_inches='tight')
280
+ print(f"Saved: {filename}")
180
281
 
181
282
 
182
- class CurveVisualizer(CoordinateSystemVisualizer):
283
+ # ============================================================
284
+ # Surface Visualizer
285
+ # ============================================================
286
+
287
+ class SurfaceVisualizer(CoordinateSystemVisualizer):
183
288
  """
184
- 曲线可视化工具
289
+ Surface visualization with curvature coloring.
185
290
 
186
- 支持绘制:
187
- - 曲线顶点
188
- - 切线
189
- - 法线
190
- - 完整坐标系场
291
+ Supports:
292
+ - Wireframe and surface rendering
293
+ - Gaussian/Mean curvature coloring
294
+ - Frame field overlay
295
+ - Normal vector visualization
191
296
  """
192
297
 
193
- def __init__(self, figsize: Tuple[int, int] = (12, 9)):
298
+ def __init__(self, figsize: Tuple[int, int] = (14, 10), dpi: int = 100):
299
+ super().__init__(figsize, dpi)
300
+ self.colorbar = None
301
+
302
+ def draw_surface(
303
+ self,
304
+ surface: 'Surface',
305
+ u_range: Tuple[float, float] = (0.1, np.pi - 0.1),
306
+ v_range: Tuple[float, float] = (0, 2 * np.pi),
307
+ nu: int = 30,
308
+ nv: int = 40,
309
+ color: str = 'cyan',
310
+ alpha: float = 0.6,
311
+ wireframe: bool = True,
312
+ surface_plot: bool = True
313
+ ):
314
+ """
315
+ Draw a parametric surface.
316
+
317
+ Args:
318
+ surface: Surface object
319
+ u_range: Parameter u range
320
+ v_range: Parameter v range
321
+ nu: Number of u samples
322
+ nv: Number of v samples
323
+ color: Surface color
324
+ alpha: Transparency
325
+ wireframe: Show wireframe
326
+ surface_plot: Show filled surface
327
+ """
328
+ u = np.linspace(u_range[0], u_range[1], nu)
329
+ v = np.linspace(v_range[0], v_range[1], nv)
330
+ U, V = np.meshgrid(u, v)
331
+
332
+ X = np.zeros_like(U)
333
+ Y = np.zeros_like(U)
334
+ Z = np.zeros_like(U)
335
+
336
+ for i in range(nu):
337
+ for j in range(nv):
338
+ pos = surface.position(U[j, i], V[j, i])
339
+ X[j, i] = pos.x
340
+ Y[j, i] = pos.y
341
+ Z[j, i] = pos.z
342
+
343
+ if surface_plot:
344
+ self.ax.plot_surface(X, Y, Z, color=color, alpha=alpha,
345
+ edgecolor='none', shade=True)
346
+
347
+ if wireframe:
348
+ self.ax.plot_wireframe(X, Y, Z, color='gray', alpha=0.3,
349
+ linewidth=0.5, rstride=2, cstride=2)
350
+
351
+ def draw_surface_curvature(
352
+ self,
353
+ surface: 'Surface',
354
+ curvature_type: str = 'gaussian',
355
+ u_range: Tuple[float, float] = (0.1, np.pi - 0.1),
356
+ v_range: Tuple[float, float] = (0, 2 * np.pi),
357
+ nu: int = 30,
358
+ nv: int = 40,
359
+ alpha: float = 0.8,
360
+ show_colorbar: bool = True,
361
+ step_size: float = 1e-3
362
+ ):
363
+ """
364
+ Draw surface with curvature coloring.
365
+
366
+ Args:
367
+ surface: Surface object
368
+ curvature_type: 'gaussian' or 'mean'
369
+ u_range: Parameter u range
370
+ v_range: Parameter v range
371
+ nu: Number of u samples
372
+ nv: Number of v samples
373
+ alpha: Transparency
374
+ show_colorbar: Show colorbar
375
+ step_size: Step size for curvature computation
376
+ """
377
+ if compute_gaussian_curvature is None:
378
+ raise ImportError("differential_geometry module not available")
379
+
380
+ u = np.linspace(u_range[0], u_range[1], nu)
381
+ v = np.linspace(v_range[0], v_range[1], nv)
382
+ U, V = np.meshgrid(u, v)
383
+
384
+ X = np.zeros_like(U)
385
+ Y = np.zeros_like(U)
386
+ Z = np.zeros_like(U)
387
+ K = np.zeros_like(U)
388
+
389
+ compute_func = compute_gaussian_curvature if curvature_type == 'gaussian' else compute_mean_curvature
390
+
391
+ for i in range(nu):
392
+ for j in range(nv):
393
+ pos = surface.position(U[j, i], V[j, i])
394
+ X[j, i] = pos.x
395
+ Y[j, i] = pos.y
396
+ Z[j, i] = pos.z
397
+ K[j, i] = compute_func(surface, U[j, i], V[j, i], step_size)
398
+
399
+ # Normalize curvature for coloring
400
+ K_abs_max = max(abs(K.min()), abs(K.max()))
401
+ if K_abs_max > 1e-10:
402
+ K_normalized = (K / K_abs_max + 1) / 2 # Map to [0, 1]
403
+ else:
404
+ K_normalized = np.ones_like(K) * 0.5
405
+
406
+ # Create colormap
407
+ cmap = create_curvature_colormap()
408
+ colors = cmap(K_normalized)
409
+
410
+ # Draw surface
411
+ surf = self.ax.plot_surface(X, Y, Z, facecolors=colors, alpha=alpha,
412
+ shade=True, linewidth=0, antialiased=True)
413
+
414
+ if show_colorbar:
415
+ norm = Normalize(vmin=-K_abs_max, vmax=K_abs_max)
416
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
417
+ sm.set_array([])
418
+ self.colorbar = self.fig.colorbar(sm, ax=self.ax, shrink=0.6, aspect=20,
419
+ label=f'{curvature_type.capitalize()} Curvature')
420
+
421
+ def draw_surface_normals(
422
+ self,
423
+ surface: 'Surface',
424
+ u_range: Tuple[float, float] = (0.1, np.pi - 0.1),
425
+ v_range: Tuple[float, float] = (0, 2 * np.pi),
426
+ nu: int = 8,
427
+ nv: int = 12,
428
+ scale: float = 0.3,
429
+ color: str = 'blue',
430
+ alpha: float = 0.8
431
+ ):
194
432
  """
195
- 初始化曲线可视化工具
433
+ Draw surface normal vectors.
196
434
 
197
435
  Args:
198
- figsize: 图形大小
436
+ surface: Surface object
437
+ u_range: Parameter u range
438
+ v_range: Parameter v range
439
+ nu: Number of u samples
440
+ nv: Number of v samples
441
+ scale: Normal vector length
442
+ color: Vector color
443
+ alpha: Transparency
199
444
  """
200
- super().__init__(figsize)
445
+ u = np.linspace(u_range[0], u_range[1], nu)
446
+ v = np.linspace(v_range[0], v_range[1], nv)
447
+
448
+ for ui in u:
449
+ for vi in v:
450
+ pos = surface.position(ui, vi)
451
+ n = surface.normal(ui, vi)
452
+
453
+ self.ax.quiver(
454
+ pos.x, pos.y, pos.z,
455
+ n.x * scale, n.y * scale, n.z * scale,
456
+ color=color, alpha=alpha, arrow_length_ratio=0.2
457
+ )
458
+
459
+ def draw_surface_frames(
460
+ self,
461
+ surface: 'Surface',
462
+ u_range: Tuple[float, float] = (0.1, np.pi - 0.1),
463
+ v_range: Tuple[float, float] = (0, 2 * np.pi),
464
+ nu: int = 6,
465
+ nv: int = 8,
466
+ scale: float = 0.25,
467
+ linewidth: float = 1.5,
468
+ alpha: float = 0.8
469
+ ):
470
+ """
471
+ Draw frame field on surface.
472
+
473
+ Args:
474
+ surface: Surface object
475
+ u_range: Parameter u range
476
+ v_range: Parameter v range
477
+ nu: Number of u samples
478
+ nv: Number of v samples
479
+ scale: Frame axis length
480
+ linewidth: Line width
481
+ alpha: Transparency
482
+ """
483
+ u = np.linspace(u_range[0], u_range[1], nu)
484
+ v = np.linspace(v_range[0], v_range[1], nv)
485
+
486
+ for ui in u:
487
+ for vi in v:
488
+ pos = surface.position(ui, vi)
489
+ r_u = surface.tangent_u(ui, vi)
490
+ r_v = surface.tangent_v(ui, vi)
491
+ n = surface.normal(ui, vi)
492
+
493
+ # Normalize
494
+ r_u_norm = (r_u.x**2 + r_u.y**2 + r_u.z**2) ** 0.5
495
+ r_v_norm = (r_v.x**2 + r_v.y**2 + r_v.z**2) ** 0.5
496
+
497
+ if r_u_norm > 1e-10:
498
+ r_u = r_u * (1.0 / r_u_norm)
499
+ if r_v_norm > 1e-10:
500
+ r_v = r_v * (1.0 / r_v_norm)
501
+
502
+ # Draw frame (red=u, green=v, blue=normal)
503
+ self.ax.quiver(pos.x, pos.y, pos.z,
504
+ r_u.x * scale, r_u.y * scale, r_u.z * scale,
505
+ color='red', linewidth=linewidth, alpha=alpha,
506
+ arrow_length_ratio=0.15)
507
+ self.ax.quiver(pos.x, pos.y, pos.z,
508
+ r_v.x * scale, r_v.y * scale, r_v.z * scale,
509
+ color='green', linewidth=linewidth, alpha=alpha,
510
+ arrow_length_ratio=0.15)
511
+ self.ax.quiver(pos.x, pos.y, pos.z,
512
+ n.x * scale, n.y * scale, n.z * scale,
513
+ color='blue', linewidth=linewidth, alpha=alpha,
514
+ arrow_length_ratio=0.15)
515
+
516
+
517
+ # ============================================================
518
+ # Curve Visualizer
519
+ # ============================================================
520
+
521
+ class CurveVisualizer(CoordinateSystemVisualizer):
522
+ """
523
+ Curve visualization with Frenet frame support.
524
+
525
+ Features:
526
+ - Curve path rendering
527
+ - Tangent, normal, binormal vectors
528
+ - Frenet frame field
529
+ - Animation support
530
+ """
531
+
532
+ def __init__(self, figsize: Tuple[int, int] = (12, 9), dpi: int = 100):
533
+ super().__init__(figsize, dpi)
201
534
 
202
535
  def draw_curve_vertices(
203
536
  self,
204
537
  points: List[vec3],
205
538
  color: str = 'black',
206
539
  linewidth: float = 2.0,
207
- marker: str = 'o',
540
+ marker: str = '',
208
541
  markersize: float = 4,
209
- alpha: float = 0.7,
542
+ alpha: float = 0.9,
210
543
  label: str = "Curve"
211
544
  ):
212
545
  """
213
- 绘制曲线顶点
546
+ Draw curve through vertices.
214
547
 
215
548
  Args:
216
- points: 顶点列表
217
- color: 颜色
218
- linewidth: 线宽
219
- marker: 标记样式
220
- markersize: 标记大小
221
- alpha: 透明度
222
- label: 标签
549
+ points: Vertex list
550
+ color: Line color
551
+ linewidth: Line width
552
+ marker: Marker style (empty for no markers)
553
+ markersize: Marker size
554
+ alpha: Transparency
555
+ label: Legend label
223
556
  """
224
557
  if not points:
225
558
  return
@@ -228,109 +561,93 @@ class CurveVisualizer(CoordinateSystemVisualizer):
228
561
  y = [p.y for p in points]
229
562
  z = [p.z for p in points]
230
563
 
231
- self.ax.plot(
232
- x, y, z,
233
- color=color,
234
- linewidth=linewidth,
235
- marker=marker,
236
- markersize=markersize,
237
- alpha=alpha,
238
- label=label
239
- )
564
+ self.ax.plot(x, y, z, color=color, linewidth=linewidth,
565
+ marker=marker, markersize=markersize,
566
+ alpha=alpha, label=label)
240
567
 
241
568
  def draw_tangents(
242
569
  self,
243
570
  points: List[vec3],
244
571
  tangents: List[vec3],
245
572
  scale: float = 0.5,
246
- color: str = 'orange',
573
+ color: str = 'red',
247
574
  linewidth: float = 1.5,
248
- alpha: float = 0.8,
249
- arrow: bool = True
575
+ alpha: float = 0.8
250
576
  ):
251
577
  """
252
- 绘制切线
578
+ Draw tangent vectors.
253
579
 
254
580
  Args:
255
- points: 曲线点列表
256
- tangents: 切线向量列表
257
- scale: 切线长度缩放
258
- color: 颜色
259
- linewidth: 线宽
260
- alpha: 透明度
261
- arrow: 是否使用箭头
581
+ points: Curve points
582
+ tangents: Tangent vectors
583
+ scale: Vector length scale
584
+ color: Vector color
585
+ linewidth: Line width
586
+ alpha: Transparency
262
587
  """
263
- if len(points) != len(tangents):
264
- raise ValueError("Points and tangents must have the same length")
265
-
266
588
  for point, tangent in zip(points, tangents):
267
- end = point + tangent * scale
268
-
269
- if arrow:
270
- self.ax.quiver(
271
- point.x, point.y, point.z,
272
- tangent.x * scale, tangent.y * scale, tangent.z * scale,
273
- color=color,
274
- linewidth=linewidth,
275
- alpha=alpha,
276
- arrow_length_ratio=0.3
277
- )
278
- else:
279
- self.ax.plot(
280
- [point.x, end.x],
281
- [point.y, end.y],
282
- [point.z, end.z],
283
- color=color,
284
- linewidth=linewidth,
285
- alpha=alpha
286
- )
589
+ self.ax.quiver(
590
+ point.x, point.y, point.z,
591
+ tangent.x * scale, tangent.y * scale, tangent.z * scale,
592
+ color=color, linewidth=linewidth, alpha=alpha,
593
+ arrow_length_ratio=0.2
594
+ )
287
595
 
288
596
  def draw_normals(
289
597
  self,
290
598
  points: List[vec3],
291
599
  normals: List[vec3],
292
600
  scale: float = 0.5,
293
- color: str = 'purple',
601
+ color: str = 'green',
294
602
  linewidth: float = 1.5,
295
- alpha: float = 0.8,
296
- arrow: bool = True
603
+ alpha: float = 0.8
297
604
  ):
298
605
  """
299
- 绘制法线
606
+ Draw normal vectors.
300
607
 
301
608
  Args:
302
- points: 曲线点列表
303
- normals: 法线向量列表
304
- scale: 法线长度缩放
305
- color: 颜色
306
- linewidth: 线宽
307
- alpha: 透明度
308
- arrow: 是否使用箭头
609
+ points: Curve points
610
+ normals: Normal vectors
611
+ scale: Vector length scale
612
+ color: Vector color
613
+ linewidth: Line width
614
+ alpha: Transparency
309
615
  """
310
- if len(points) != len(normals):
311
- raise ValueError("Points and normals must have the same length")
312
-
313
616
  for point, normal in zip(points, normals):
314
- end = point + normal * scale
617
+ self.ax.quiver(
618
+ point.x, point.y, point.z,
619
+ normal.x * scale, normal.y * scale, normal.z * scale,
620
+ color=color, linewidth=linewidth, alpha=alpha,
621
+ arrow_length_ratio=0.2
622
+ )
623
+
624
+ def draw_binormals(
625
+ self,
626
+ points: List[vec3],
627
+ binormals: List[vec3],
628
+ scale: float = 0.5,
629
+ color: str = 'blue',
630
+ linewidth: float = 1.5,
631
+ alpha: float = 0.8
632
+ ):
633
+ """
634
+ Draw binormal vectors.
315
635
 
316
- if arrow:
317
- self.ax.quiver(
318
- point.x, point.y, point.z,
319
- normal.x * scale, normal.y * scale, normal.z * scale,
320
- color=color,
321
- linewidth=linewidth,
322
- alpha=alpha,
323
- arrow_length_ratio=0.3
324
- )
325
- else:
326
- self.ax.plot(
327
- [point.x, end.x],
328
- [point.y, end.y],
329
- [point.z, end.z],
330
- color=color,
331
- linewidth=linewidth,
332
- alpha=alpha
333
- )
636
+ Args:
637
+ points: Curve points
638
+ binormals: Binormal vectors
639
+ scale: Vector length scale
640
+ color: Vector color
641
+ linewidth: Line width
642
+ alpha: Transparency
643
+ """
644
+ for point, binormal in zip(points, binormals):
645
+ self.ax.quiver(
646
+ point.x, point.y, point.z,
647
+ binormal.x * scale, binormal.y * scale, binormal.z * scale,
648
+ color=color, linewidth=linewidth, alpha=alpha,
649
+ arrow_length_ratio=0.2
650
+ )
334
651
 
335
652
  def draw_curve_frames(
336
653
  self,
@@ -341,89 +658,87 @@ class CurveVisualizer(CoordinateSystemVisualizer):
341
658
  skip: int = 1
342
659
  ):
343
660
  """
344
- 绘制曲线上的坐标系场
661
+ Draw Frenet frames along curve.
345
662
 
346
663
  Args:
347
- frames: 坐标系列表
348
- scale: 坐标轴长度
349
- linewidth: 线宽
350
- alpha: 透明度
351
- skip: 跳过间隔(每skip个绘制一个)
664
+ frames: List of coord3 frames
665
+ scale: Axis length
666
+ linewidth: Line width
667
+ alpha: Transparency
668
+ skip: Draw every nth frame
352
669
  """
353
670
  for i, frame in enumerate(frames):
354
671
  if i % skip == 0:
355
- self.draw_coord_system(
356
- frame,
357
- scale=scale,
358
- linewidth=linewidth,
359
- alpha=alpha
360
- )
672
+ self.draw_coord_system(frame, scale=scale,
673
+ linewidth=linewidth, alpha=alpha)
361
674
 
362
675
  def draw_complete_curve(
363
676
  self,
364
677
  points: List[vec3],
365
678
  tangents: Optional[List[vec3]] = None,
366
679
  normals: Optional[List[vec3]] = None,
680
+ binormals: Optional[List[vec3]] = None,
367
681
  frames: Optional[List[coord3]] = None,
368
682
  curve_color: str = 'black',
369
683
  tangent_scale: float = 0.5,
370
684
  normal_scale: float = 0.5,
685
+ binormal_scale: float = 0.5,
371
686
  frame_scale: float = 0.3,
372
687
  frame_skip: int = 5,
373
688
  show_world_coord: bool = True
374
689
  ):
375
690
  """
376
- 绘制完整的曲线及其几何属性
691
+ Draw complete curve with geometric properties.
377
692
 
378
693
  Args:
379
- points: 曲线顶点
380
- tangents: 切线向量(可选)
381
- normals: 法线向量(可选)
382
- frames: 坐标系列表(可选)
383
- curve_color: 曲线颜色
384
- tangent_scale: 切线长度缩放
385
- normal_scale: 法线长度缩放
386
- frame_scale: 坐标系轴长度
387
- frame_skip: 坐标系绘制间隔
388
- show_world_coord: 是否显示世界坐标系
389
-
390
- Note:
391
- 如果提供了 frames (Frenet标架),将使用RGB颜色方案绘制完整标架,
392
- 此时会忽略单独的 tangents 和 normals 参数,避免颜色混淆。
393
- """
394
- # 绘制世界坐标系
694
+ points: Curve vertices
695
+ tangents: Tangent vectors (optional)
696
+ normals: Normal vectors (optional)
697
+ binormals: Binormal vectors (optional)
698
+ frames: Frenet frames (optional, overrides individual vectors)
699
+ curve_color: Curve color
700
+ tangent_scale: Tangent vector scale
701
+ normal_scale: Normal vector scale
702
+ binormal_scale: Binormal vector scale
703
+ frame_scale: Frame axis length
704
+ frame_skip: Frame drawing interval
705
+ show_world_coord: Show world coordinate system
706
+ """
395
707
  if show_world_coord:
396
708
  self.draw_world_coord(scale=1.0)
397
709
 
398
- # 绘制曲线
399
710
  self.draw_curve_vertices(points, color=curve_color, label="Curve")
400
711
 
401
- # 如果提供了完整标架,优先使用RGB标架,不再绘制单独的切线/法线
402
712
  if frames is not None:
403
- # 绘制完整Frenet标架 (红=T, 绿=N, 蓝=B)
713
+ # Use complete Frenet frames (RGB coloring)
404
714
  self.draw_curve_frames(frames, scale=frame_scale, skip=frame_skip)
405
715
  else:
406
- # 否则绘制单独的切线和法线 (橙色/紫色)
716
+ # Draw individual vectors
407
717
  if tangents is not None:
408
718
  self.draw_tangents(points, tangents, scale=tangent_scale)
409
-
410
719
  if normals is not None:
411
720
  self.draw_normals(points, normals, scale=normal_scale)
721
+ if binormals is not None:
722
+ self.draw_binormals(points, binormals, scale=binormal_scale)
412
723
 
413
- # 设置等比例
414
724
  self.set_equal_aspect()
415
725
 
416
726
 
727
+ # ============================================================
728
+ # Parametric Curve
729
+ # ============================================================
730
+
417
731
  class ParametricCurve:
418
732
  """
419
- 参数化曲线类
420
-
421
- 提供曲线的几何属性计算:
422
- - 位置
423
- - 切线
424
- - 法线
425
- - 副法线
426
- - Frenet标架
733
+ Parametric curve with Frenet frame computation.
734
+
735
+ Provides:
736
+ - Position r(t)
737
+ - Tangent T(t)
738
+ - Normal N(t)
739
+ - Binormal B(t)
740
+ - Frenet frame {T, N, B}
741
+ - Curvature and torsion
427
742
  """
428
743
 
429
744
  def __init__(
@@ -433,29 +748,29 @@ class ParametricCurve:
433
748
  num_points: int = 100
434
749
  ):
435
750
  """
436
- 初始化参数化曲线
751
+ Initialize parametric curve.
437
752
 
438
753
  Args:
439
- position_func: 位置函数 r(t) -> vec3
440
- t_range: 参数范围 (t_min, t_max)
441
- num_points: 采样点数
754
+ position_func: Position function r(t) -> vec3
755
+ t_range: Parameter range (t_min, t_max)
756
+ num_points: Number of sample points
442
757
  """
443
758
  self.position_func = position_func
444
759
  self.t_range = t_range
445
760
  self.num_points = num_points
446
- self.h = 1e-6 # 数值微分步长
761
+ self.h = 1e-6 # Numerical differentiation step
447
762
 
448
763
  def position(self, t: float) -> vec3:
449
- """计算位置"""
764
+ """Compute position at parameter t."""
450
765
  return self.position_func(t)
451
766
 
452
767
  def tangent(self, t: float, normalized: bool = True) -> vec3:
453
768
  """
454
- 计算切线 dr/dt
769
+ Compute tangent vector T = dr/dt.
455
770
 
456
771
  Args:
457
- t: 参数值
458
- normalized: 是否归一化
772
+ t: Parameter value
773
+ normalized: Normalize to unit vector
459
774
  """
460
775
  r_plus = self.position_func(t + self.h)
461
776
  r_minus = self.position_func(t - self.h)
@@ -469,51 +784,40 @@ class ParametricCurve:
469
784
  return tangent
470
785
 
471
786
  def second_derivative(self, t: float) -> vec3:
472
- """计算二阶导数 d²r/dt²"""
787
+ """Compute second derivative d^2r/dt^2."""
473
788
  r_plus = self.position_func(t + self.h)
474
789
  r_center = self.position_func(t)
475
790
  r_minus = self.position_func(t - self.h)
476
-
477
- d2r = (r_plus + r_minus - r_center * 2.0) * (1.0 / (self.h * self.h))
478
- return d2r
791
+ return (r_plus + r_minus - r_center * 2.0) * (1.0 / (self.h * self.h))
479
792
 
480
793
  def normal(self, t: float, normalized: bool = True) -> vec3:
481
794
  """
482
- 计算主法线(指向曲率中心)
483
-
484
- 使用Frenet-Serret公式:
485
- N = (dT/dt) / |dT/dt|
486
- 其中 T = (dr/dt) / |dr/dt|
795
+ Compute principal normal N = dT/ds / |dT/ds|.
487
796
 
488
797
  Args:
489
- t: 参数值
490
- normalized: 是否归一化
798
+ t: Parameter value
799
+ normalized: Normalize to unit vector
491
800
  """
492
- # 计算 T(t+h) 和 T(t-h)
493
801
  T_plus = self.tangent(t + self.h, normalized=True)
494
802
  T_minus = self.tangent(t - self.h, normalized=True)
495
-
496
- # 数值微分: dT/dt ≈ (T(t+h) - T(t-h)) / (2h)
497
803
  dT_dt = (T_plus - T_minus) * (1.0 / (2.0 * self.h))
498
804
 
499
- # 主法线是 dT/dt 的归一化
500
805
  length = (dT_dt.x**2 + dT_dt.y**2 + dT_dt.z**2) ** 0.5
501
806
 
502
807
  if length > 1e-10:
503
808
  N = dT_dt * (1.0 / length) if normalized else dT_dt
504
809
  else:
505
- # 如果长度太小,返回一个默认值
506
810
  N = vec3(0, 0, 1)
507
811
 
508
812
  return N
509
813
 
510
814
  def binormal(self, t: float, normalized: bool = True) -> vec3:
511
815
  """
512
- 计算副法线 B = T × N
816
+ Compute binormal B = T x N.
513
817
 
514
818
  Args:
515
- t: 参数值
516
- normalized: 是否归一化
819
+ t: Parameter value
820
+ normalized: Normalize to unit vector
517
821
  """
518
822
  T = self.tangent(t, normalized=True)
519
823
  N = self.normal(t, normalized=True)
@@ -526,81 +830,113 @@ class ParametricCurve:
526
830
 
527
831
  return B
528
832
 
529
- def frenet_frame(self, t: float) -> coord3:
833
def curvature(self, t: float) -> float:
    """
    Compute curvature kappa = |dT/ds| at parameter ``t``.

    Evaluated as |dT/dt| / |dr/dt|. Both derivatives are central
    differences with step ``self.h``, and T is the unit tangent.

    Args:
        t: Parameter value

    Returns:
        Curvature value; 0.0 where the parametric speed is <= 1e-10
    """
    step = self.h
    half_inv = 1.0 / (2.0 * step)

    # Rate of change of the unit tangent with respect to the parameter t.
    tangent_rate = (self.tangent(t + step, normalized=True)
                    - self.tangent(t - step, normalized=True)) * half_inv

    # Parametric speed |dr/dt|; dividing by it converts d/dt into d/ds.
    velocity = (self.position_func(t + step)
                - self.position_func(t - step)) * half_inv
    speed = (velocity.x**2 + velocity.y**2 + velocity.z**2) ** 0.5

    if speed > 1e-10:
        rate_norm = (tangent_rate.x**2 + tangent_rate.y**2 + tangent_rate.z**2) ** 0.5
        return rate_norm / speed
    # Stationary parameterization: arc length does not advance here.
    return 0.0
857
+
858
def frenet_frame(self, t: float) -> coord3:
    """
    Compute the Frenet frame {T, N, B} at parameter ``t``.

    Returns:
        coord3 whose origin ``o`` is the curve position and whose axes are
        ux = unit tangent T (red), uy = unit principal normal N (green)
        and uz = unit binormal B (blue).
    """
    basis = coord3()
    basis.o = self.position(t)
    # Axis order follows the RGB drawing convention: X -> T, Y -> N, Z -> B.
    basis.ux = self.tangent(t, normalized=True)
    basis.uy = self.normal(t, normalized=True)
    basis.uz = self.binormal(t, normalized=True)
    return basis
555
874
 
556
875
  def sample_points(self) -> List[vec3]:
557
- """采样曲线点"""
876
+ """Sample curve positions."""
558
877
  t_min, t_max = self.t_range
559
878
  t_values = np.linspace(t_min, t_max, self.num_points)
560
879
  return [self.position(t) for t in t_values]
561
880
 
562
881
  def sample_tangents(self) -> List[vec3]:
563
- """采样切线"""
882
+ """Sample tangent vectors."""
564
883
  t_min, t_max = self.t_range
565
884
  t_values = np.linspace(t_min, t_max, self.num_points)
566
885
  return [self.tangent(t) for t in t_values]
567
886
 
568
887
  def sample_normals(self) -> List[vec3]:
569
- """采样主法线"""
888
+ """Sample normal vectors."""
570
889
  t_min, t_max = self.t_range
571
890
  t_values = np.linspace(t_min, t_max, self.num_points)
572
891
  return [self.normal(t) for t in t_values]
573
892
 
893
def sample_binormals(self) -> List[vec3]:
    """Return binormal vectors at ``num_points`` evenly spaced parameters.

    ``binormal`` is called with its default arguments at each parameter in
    ``np.linspace`` over ``t_range``.
    """
    start, stop = self.t_range
    params = np.linspace(start, stop, self.num_points)
    return list(map(self.binormal, params))
898
+
574
899
  def sample_frames(self) -> List[coord3]:
575
- """采样Frenet标架"""
900
+ """Sample Frenet frames."""
576
901
  t_min, t_max = self.t_range
577
902
  t_values = np.linspace(t_min, t_max, self.num_points)
578
903
  return [self.frenet_frame(t) for t in t_values]
579
904
 
905
def sample_curvature(self) -> List[float]:
    """Return curvature values at ``num_points`` evenly spaced parameters.

    One ``curvature`` value per parameter in ``np.linspace`` over
    ``t_range``.
    """
    start, stop = self.t_range
    params = np.linspace(start, stop, self.num_points)
    return list(map(self.curvature, params))
580
910
 
581
- # ========== 便捷函数 ==========
911
+
912
+ # ============================================================
913
+ # Convenience Functions
914
+ # ============================================================
582
915
 
583
916
  def visualize_coord_system(
584
917
  coord: coord3,
585
918
  scale: float = 1.0,
586
919
  figsize: Tuple[int, int] = (10, 8),
587
920
  show: bool = True,
588
- save_path: Optional[str] = None
921
+ save_path: Optional[str] = None,
922
+ title: str = "Coordinate System"
589
923
  ):
590
924
  """
591
- 快速可视化单个坐标系
925
+ Quick visualization of a single coordinate frame.
592
926
 
593
927
  Args:
594
- coord: 坐标系对象
595
- scale: 轴长度
596
- figsize: 图形大小
597
- show: 是否显示
598
- save_path: 保存路径(可选)
928
+ coord: Coordinate frame
929
+ scale: Axis length
930
+ figsize: Figure size
931
+ show: Display figure
932
+ save_path: Save path (optional)
933
+ title: Plot title
599
934
  """
600
935
  vis = CoordinateSystemVisualizer(figsize=figsize)
601
936
  vis.draw_world_coord(scale=scale * 0.8)
602
937
  vis.draw_coord_system(coord, scale=scale, label_prefix="Frame-")
603
938
  vis.set_equal_aspect()
939
+ vis.set_title(title)
604
940
 
605
941
  if save_path:
606
942
  vis.save(save_path)
@@ -610,41 +946,99 @@ def visualize_coord_system(
610
946
 
611
947
  def visualize_curve(
612
948
  curve: ParametricCurve,
613
- show_tangents: bool = True,
614
- show_normals: bool = True,
615
- show_frames: bool = False,
949
+ show_tangents: bool = False,
950
+ show_normals: bool = False,
951
+ show_binormals: bool = False,
952
+ show_frames: bool = True,
616
953
  frame_skip: int = 5,
617
954
  figsize: Tuple[int, int] = (12, 9),
618
955
  show: bool = True,
619
- save_path: Optional[str] = None
956
+ save_path: Optional[str] = None,
957
+ title: str = "Parametric Curve"
620
958
  ):
621
959
  """
622
- 快速可视化参数化曲线
960
+ Quick visualization of a parametric curve.
623
961
 
624
962
  Args:
625
- curve: 参数化曲线对象
626
- show_tangents: 是否显示切线
627
- show_normals: 是否显示法线
628
- show_frames: 是否显示完整坐标系
629
- frame_skip: 坐标系绘制间隔
630
- figsize: 图形大小
631
- show: 是否显示
632
- save_path: 保存路径(可选)
963
+ curve: Parametric curve object
964
+ show_tangents: Show tangent vectors
965
+ show_normals: Show normal vectors
966
+ show_binormals: Show binormal vectors
967
+ show_frames: Show complete Frenet frames
968
+ frame_skip: Frame drawing interval
969
+ figsize: Figure size
970
+ show: Display figure
971
+ save_path: Save path (optional)
972
+ title: Plot title
633
973
  """
634
974
  vis = CurveVisualizer(figsize=figsize)
635
975
 
636
976
  points = curve.sample_points()
637
- tangents = curve.sample_tangents() if show_tangents else None
638
- normals = curve.sample_normals() if show_normals else None
977
+ tangents = curve.sample_tangents() if show_tangents and not show_frames else None
978
+ normals = curve.sample_normals() if show_normals and not show_frames else None
979
+ binormals = curve.sample_binormals() if show_binormals and not show_frames else None
639
980
  frames = curve.sample_frames() if show_frames else None
640
981
 
641
982
  vis.draw_complete_curve(
642
983
  points=points,
643
984
  tangents=tangents,
644
985
  normals=normals,
986
+ binormals=binormals,
645
987
  frames=frames,
646
988
  frame_skip=frame_skip
647
989
  )
990
+ vis.set_title(title)
991
+
992
+ if save_path:
993
+ vis.save(save_path)
994
+ if show:
995
+ vis.show()
996
+
997
+
998
+ def visualize_surface(
999
+ surface: 'Surface',
1000
+ curvature_type: Optional[str] = None,
1001
+ show_normals: bool = False,
1002
+ show_frames: bool = False,
1003
+ u_range: Tuple[float, float] = (0.1, np.pi - 0.1),
1004
+ v_range: Tuple[float, float] = (0, 2 * np.pi),
1005
+ figsize: Tuple[int, int] = (14, 10),
1006
+ show: bool = True,
1007
+ save_path: Optional[str] = None,
1008
+ title: str = "Surface"
1009
+ ):
1010
+ """
1011
+ Quick visualization of a parametric surface.
1012
+
1013
+ Args:
1014
+ surface: Surface object
1015
+ curvature_type: 'gaussian' or 'mean' for curvature coloring (None for plain)
1016
+ show_normals: Show normal vectors
1017
+ show_frames: Show frame field
1018
+ u_range: Parameter u range
1019
+ v_range: Parameter v range
1020
+ figsize: Figure size
1021
+ show: Display figure
1022
+ save_path: Save path (optional)
1023
+ title: Plot title
1024
+ """
1025
+ vis = SurfaceVisualizer(figsize=figsize)
1026
+
1027
+ if curvature_type:
1028
+ vis.draw_surface_curvature(surface, curvature_type=curvature_type,
1029
+ u_range=u_range, v_range=v_range)
1030
+ else:
1031
+ vis.draw_surface(surface, u_range=u_range, v_range=v_range)
1032
+
1033
+ if show_normals:
1034
+ vis.draw_surface_normals(surface, u_range=u_range, v_range=v_range)
1035
+
1036
+ if show_frames:
1037
+ vis.draw_surface_frames(surface, u_range=u_range, v_range=v_range)
1038
+
1039
+ vis.draw_world_coord(scale=0.5)
1040
+ vis.set_equal_aspect()
1041
+ vis.set_title(title)
648
1042
 
649
1043
  if save_path:
650
1044
  vis.save(save_path)
@@ -652,15 +1046,24 @@ def visualize_curve(
652
1046
  vis.show()
653
1047
 
654
1048
 
655
- # ========== Export ==========
1049
+ # ============================================================
1050
+ # Export
1051
+ # ============================================================
656
1052
 
657
1053
  # Public API of this module: the names exported by `from <module> import *`.
  # NOTE(review): 'create_curvature_colormap' is not defined in this chunk;
  # confirm it exists elsewhere in the module, otherwise star-imports raise
  # AttributeError.
  __all__ = [
      # Visualizer classes
      'CoordinateSystemVisualizer',
      'CurveVisualizer',
      'SurfaceVisualizer',

      # Curve class (finite-difference tangent/normal/binormal/curvature of r(t))
      'ParametricCurve',

      # Convenience functions: build a visualizer, draw, then optionally save/show
      'visualize_coord_system',
      'visualize_curve',
      'visualize_surface',

      # Utility
      'create_curvature_colormap',
  ]