oafuncs 0.0.98.31__py3-none-any.whl → 0.0.98.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oafuncs/oa_draw.py CHANGED
@@ -1,30 +1,15 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- """
4
- Author: Liu Kun && 16031215@qq.com
5
- Date: 2024-09-17 17:26:11
6
- LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2024-11-21 13:10:47
8
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_draw.py
9
- Description:
10
- EditPlatform: vscode
11
- ComputerInfo: XPS 15 9510
12
- SystemInfo: Windows 11
13
- Python Version: 3.11
14
- """
15
-
16
1
  import warnings
17
2
 
18
- import cv2
19
3
  import cartopy.crs as ccrs
20
4
  import cartopy.feature as cfeature
5
+ import cv2
21
6
  import matplotlib as mpl
22
7
  import matplotlib.pyplot as plt
23
8
  import numpy as np
24
9
  from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
25
10
  from rich import print
26
11
 
27
- __all__ = ["fig_minus", "gif", "add_cartopy", "add_gridlines", "MidpointNormalize", "add_lonlat_unit"]
12
+ __all__ = ["fig_minus", "gif", "movie", "setup_map", "MidpointNormalize"]
28
13
 
29
14
  warnings.filterwarnings("ignore")
30
15
 
@@ -43,15 +28,24 @@ def fig_minus(x_axis: plt.Axes = None, y_axis: plt.Axes = None, colorbar: mpl.co
43
28
  plt.Axes | mpl.colorbar.Colorbar | None: The modified axis or colorbar object.
44
29
 
45
30
  Example:
46
- >>> fig_minus(x_axis=ax, y_axis=None, colorbar=colorbar, decimal_places=2, add_spacing=True)
31
+ >>> fig_minus(x_axis=ax, decimal_places=2, add_spacing=True)
47
32
  """
33
+ current_ticks = None
34
+ target_object = None
35
+
48
36
  # Determine which object to use and get its ticks
49
37
  if x_axis is not None:
50
38
  current_ticks = x_axis.get_xticks()
51
- if y_axis is not None:
39
+ target_object = x_axis
40
+ elif y_axis is not None:
52
41
  current_ticks = y_axis.get_yticks()
53
- if colorbar is not None:
42
+ target_object = y_axis
43
+ elif colorbar is not None:
54
44
  current_ticks = colorbar.get_ticks()
45
+ target_object = colorbar
46
+ else:
47
+ print("[yellow]Warning:[/yellow] No valid axis or colorbar provided.")
48
+ return None
55
49
 
56
50
  # Find index for adding space to non-negative values if needed
57
51
  if add_spacing:
@@ -75,326 +69,330 @@ def fig_minus(x_axis: plt.Axes = None, y_axis: plt.Axes = None, colorbar: mpl.co
75
69
  # Apply formatted ticks to the appropriate object
76
70
  if x_axis is not None:
77
71
  x_axis.set_xticklabels(out_ticks)
78
- if y_axis is not None:
72
+ elif y_axis is not None:
79
73
  y_axis.set_yticklabels(out_ticks)
80
- if colorbar is not None:
74
+ elif colorbar is not None:
81
75
  colorbar.set_ticklabels(out_ticks)
82
76
 
83
77
  print("[green]Axis tick labels updated successfully.[/green]")
84
- return x_axis or y_axis or colorbar
78
+ return target_object
85
79
 
86
80
 
87
- def gif(image_paths: list[str], output_gif_name: str, frame_duration: float = 200, resize_dimensions: tuple[int, int] = None) -> None:
81
+ def gif(image_paths: list[str], output_gif_name: str, frame_duration: float = 0.2, resize_dimensions: tuple[int, int] = None) -> None:
88
82
  """Create a GIF from a list of images.
89
83
 
90
84
  Args:
91
85
  image_paths (list[str]): List of image file paths.
92
86
  output_gif_name (str): Name of the output GIF file.
93
- frame_duration (float): Duration of each frame in milliseconds.
87
+ frame_duration (float): Duration of each frame in seconds. Defaults to 0.2.
94
88
  resize_dimensions (tuple[int, int], optional): Resize dimensions (width, height). Defaults to None.
95
89
 
96
90
  Returns:
97
91
  None
98
92
 
99
93
  Example:
100
- >>> gif(['image1.png', 'image2.png'], 'output.gif', frame_duration=200, resize_dimensions=(800, 600))
94
+ >>> gif(['image1.png', 'image2.png'], 'output.gif', frame_duration=0.5, resize_dimensions=(800, 600))
101
95
  """
102
96
  import imageio.v2 as imageio
103
- import numpy as np
104
97
  from PIL import Image
105
98
 
99
+ if not image_paths:
100
+ print("[red]Error:[/red] Image paths list is empty.")
101
+ return
102
+
106
103
  frames = []
107
104
 
108
- # 获取目标尺寸
105
+ # Get target dimensions
109
106
  if resize_dimensions is None and image_paths:
110
- # 使用第一张图片的尺寸作为标准
111
107
  with Image.open(image_paths[0]) as img:
112
108
  resize_dimensions = img.size
113
109
 
114
- # 读取并调整所有图片的尺寸
110
+ # Read and resize all images
115
111
  for image_name in image_paths:
116
- with Image.open(image_name) as img:
117
- if resize_dimensions:
118
- img = img.resize(resize_dimensions, Image.LANCZOS)
119
- frames.append(np.array(img))
112
+ try:
113
+ with Image.open(image_name) as img:
114
+ if resize_dimensions:
115
+ img = img.resize(resize_dimensions, Image.LANCZOS)
116
+ frames.append(np.array(img))
117
+ except Exception as e:
118
+ print(f"[yellow]Warning:[/yellow] Failed to read image {image_name}: {e}")
119
+ continue
120
120
 
121
- # 修改此处:明确使用 frame_duration 值,并将其作为每帧的持续时间(以秒为单位)
122
- # 某些版本的 imageio 可能需要以毫秒为单位,或者使用 fps 参数
121
+ if not frames:
122
+ print("[red]Error:[/red] No valid images found.")
123
+ return
124
+
125
+ # Create GIF
123
126
  try:
124
- # 先尝试直接使用 frame_duration 参数(以秒为单位)
125
127
  imageio.mimsave(output_gif_name, frames, format="GIF", duration=frame_duration)
128
+ print(f"[green]GIF created successfully![/green] Size: {resize_dimensions}, Frame duration: {frame_duration}s")
126
129
  except Exception as e:
127
- print(f"[yellow]Warning:[/yellow] Attempting to use fps parameter instead of duration: {e}")
128
- # 如果失败,尝试使用 fps 参数(fps = 1/frame_duration)
129
- fps = 1.0 / frame_duration if frame_duration > 0 else 5.0
130
- imageio.mimsave(output_gif_name, frames, format="GIF", fps=fps)
130
+ print(f"[red]Error:[/red] Failed to create GIF: {e}")
131
131
 
132
- print(f"[green]GIF created successfully![/green] Size: {resize_dimensions}, Frame interval: {frame_duration} ms")
133
- return
134
132
 
135
-
136
- def movie(image_files, output_video_path, fps):
137
- """
138
- 从图像文件列表创建视频。
133
+ def movie(image_files: list[str], output_video_path: str, fps: int) -> None:
134
+ """Create a video from a list of image files.
139
135
 
140
136
  Args:
141
- image_files (list): 按顺序排列的图像文件路径列表。
142
- output_video_path (str): 输出视频文件的路径 (例如 'output.mp4')
143
- fps (int): 视频的帧率。
137
+ image_files (list[str]): List of image file paths in order.
138
+ output_video_path (str): Output video file path (e.g., 'output.mp4').
139
+ fps (int): Video frame rate.
140
+
141
+ Returns:
142
+ None
143
+
144
+ Example:
145
+ >>> movie(['img1.jpg', 'img2.jpg'], 'output.mp4', fps=30)
144
146
  """
145
147
  if not image_files:
146
- print("错误:图像文件列表为空。")
148
+ print("[red]Error:[/red] Image files list is empty.")
147
149
  return
148
150
 
149
- # 读取第一张图片以获取帧尺寸
151
+ # Read first image to get frame dimensions
150
152
  try:
151
153
  frame = cv2.imread(image_files[0])
152
154
  if frame is None:
153
- print(f"错误:无法读取第一张图片:{image_files[0]}")
155
+ print(f"[red]Error:[/red] Cannot read first image: {image_files[0]}")
154
156
  return
155
157
  height, width, layers = frame.shape
156
158
  size = (width, height)
157
- print(f"视频尺寸设置为:{size}")
159
+ print(f"Video dimensions set to: {size}")
158
160
  except Exception as e:
159
- print(f"读取第一张图片时出错:{e}")
161
+ print(f"[red]Error:[/red] Error reading first image: {e}")
160
162
  return
161
163
 
162
- # 选择编解码器并创建VideoWriter对象
163
- # 对于 .mp4 文件,常用 'mp4v' 或 'avc1'
164
- # 对于 .avi 文件,常用 'XVID' 或 'MJPG'
165
- fourcc = cv2.VideoWriter_fourcc(*"mp4v") # 或者尝试 'avc1', 'XVID' 等
164
+ # Create VideoWriter object
165
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
166
166
  out = cv2.VideoWriter(output_video_path, fourcc, fps, size)
167
167
 
168
168
  if not out.isOpened():
169
- print(f"错误:无法打开视频文件进行写入:{output_video_path}")
170
- print("请检查编解码器 ('fourcc') 是否受支持以及路径是否有效。")
169
+ print(f"[red]Error:[/red] Cannot open video file for writing: {output_video_path}")
170
+ print("Please check if the codec is supported and the path is valid.")
171
171
  return
172
172
 
173
- print(f"开始将图像写入视频:{output_video_path}...")
173
+ print(f"Starting to write images to video: {output_video_path}...")
174
+ successful_frames = 0
175
+
174
176
  for i, filename in enumerate(image_files):
175
177
  try:
176
178
  frame = cv2.imread(filename)
177
179
  if frame is None:
178
- print(f"警告:跳过无法读取的图像:{filename}")
180
+ print(f"[yellow]Warning:[/yellow] Skipping unreadable image: {filename}")
179
181
  continue
180
- # 确保帧尺寸与初始化时相同,如果需要可以调整大小
182
+
183
+ # Ensure frame dimensions match initialization
181
184
  current_height, current_width, _ = frame.shape
182
185
  if (current_width, current_height) != size:
183
- print(f"警告:图像 {filename} 的尺寸 ({current_width}, {current_height}) 与初始尺寸 {size} 不同。将调整大小。")
184
186
  frame = cv2.resize(frame, size)
185
187
 
186
188
  out.write(frame)
187
- # 打印进度(可选)
189
+ successful_frames += 1
190
+
191
+ # Print progress
188
192
  if (i + 1) % 50 == 0 or (i + 1) == len(image_files):
189
- print(f"已处理 {i + 1}/{len(image_files)} ")
193
+ print(f"Processed {i + 1}/{len(image_files)} frames")
190
194
 
191
195
  except Exception as e:
192
- print(f"处理图像 {filename} 时出错:{e}")
193
- continue # 跳过有问题的帧
196
+ print(f"[yellow]Warning:[/yellow] Error processing image {filename}: {e}")
197
+ continue
194
198
 
195
- # 释放资源
199
+ # Release resources
196
200
  out.release()
197
- print(f"视频创建成功:{output_video_path}")
198
-
199
-
200
- def add_lonlat_unit(longitudes: list[float] = None, latitudes: list[float] = None, decimal_places: int = 2) -> tuple[list[str], list[str]] | list[str]:
201
- """Convert longitude and latitude values to formatted string labels.
201
+ print(f"[green]Video created successfully:[/green] {output_video_path} ({successful_frames} frames)")
202
+
203
+
204
+ def setup_map(
205
+ axes: plt.Axes,
206
+ longitude_data: np.ndarray = None,
207
+ latitude_data: np.ndarray = None,
208
+ map_projection: ccrs.Projection = ccrs.PlateCarree(),
209
+ # Map features
210
+ show_land: bool = True,
211
+ show_ocean: bool = True,
212
+ show_coastline: bool = True,
213
+ show_borders: bool = False,
214
+ land_color: str = "lightgrey",
215
+ ocean_color: str = "lightblue",
216
+ coastline_linewidth: float = 0.5,
217
+ # Gridlines and ticks
218
+ show_gridlines: bool = False,
219
+ longitude_ticks: list[float] = None,
220
+ latitude_ticks: list[float] = None,
221
+ tick_decimals: int = 0,
222
+ # Gridline styling
223
+ grid_color: str = "k",
224
+ grid_alpha: float = 0.5,
225
+ grid_style: str = "--",
226
+ grid_width: float = 0.5,
227
+ # Label options
228
+ show_labels: bool = True,
229
+ left_labels: bool = True,
230
+ bottom_labels: bool = True,
231
+ right_labels: bool = False,
232
+ top_labels: bool = False,
233
+ ) -> plt.Axes:
234
+ """Setup a complete cartopy map with customizable features.
202
235
 
203
236
  Args:
204
- longitudes (list[float], optional): List of longitude values to format.
205
- latitudes (list[float], optional): List of latitude values to format.
206
- decimal_places (int, optional): Number of decimal places to display. Defaults to 2.
207
-
208
- Returns:
209
- tuple[list[str], list[str]] | list[str]: Formatted longitude and/or latitude labels.
210
- Returns a tuple of two lists if both longitudes and latitudes are provided,
211
- otherwise returns a single list of formatted values.
212
-
213
- Examples:
214
- >>> add_lonlat_unit(longitudes=[120, 180], latitudes=[30, 60], decimal_places=1)
215
- (['120.0°E', '180.0°'], ['30.0°N', '60.0°N'])
216
- >>> add_lonlat_unit(longitudes=[120, -60])
217
- ['120.00°E', '60.00°W']
218
- """
219
-
220
- def _format_longitude(longitude_values: list[float]) -> list[str] | str:
221
- """Format longitude values to string labels with directional indicators.
222
-
223
- Converts numerical longitude values to formatted strings with degree symbols
224
- and East/West indicators. Values outside the -180 to 180 range are normalized.
225
-
226
- Args:
227
- longitude_values: List of longitude values to format.
228
-
229
- Returns:
230
- List of formatted strings if input contains multiple values,
231
- or a single string if input contains just one value.
232
- """
233
- out_list = []
234
- for x in longitude_values:
235
- if x > 180 or x < -180:
236
- print(f"[yellow]Warning:[/yellow] Longitude value {x} outside normal range (-180 to 180)")
237
- x = ((x + 180) % 360) - 180 # Normalize to -180 to 180 range
238
-
239
- degrees = round(abs(x), decimal_places)
240
- direction = "E" if x >= 0 else "W"
241
- out_list.append(f"{degrees:.{decimal_places}f}°{direction}" if x != 0 and x != 180 else f"{degrees}°")
242
- return out_list if len(out_list) > 1 else out_list[0]
243
-
244
- def _format_latitude(latitude_values: list[float]) -> list[str] | str:
245
- """Format latitude values to string labels with directional indicators.
246
-
247
- Converts numerical latitude values to formatted strings with degree symbols
248
- and North/South indicators. Values outside the -90 to 90 range are normalized.
249
-
250
- Args:
251
- latitude_values (list[float]): List of latitude values to format
252
-
253
- Returns:
254
- list[str] | str: List of formatted strings if input contains multiple values,
255
- or a single string if input contains just one value
256
- """
257
- out_list = []
258
- for y in latitude_values:
259
- if y > 90 or y < -90:
260
- print(f"[yellow]Warning:[/yellow] Latitude value {y} outside valid range (-90 to 90)")
261
- y = min(max(y % 180 - 90, -90), 90) # Normalize to -90 to 90 range
262
-
263
- degrees = round(abs(y), decimal_places)
264
- direction = "N" if y >= 0 else "S"
265
- out_list.append(f"{degrees:.{decimal_places}f}°{direction}" if y != 0 else f"{degrees}°")
266
- return out_list if len(out_list) > 1 else out_list[0]
267
-
268
- # Input validation
269
- if longitudes is not None and not isinstance(longitudes, list):
270
- longitudes = [longitudes] # Convert single value to list
271
- if latitudes is not None and not isinstance(latitudes, list):
272
- latitudes = [latitudes] # Convert single value to list
273
-
274
- if longitudes and latitudes:
275
- result = _format_longitude(longitudes), _format_latitude(latitudes)
276
- elif longitudes:
277
- result = _format_longitude(longitudes)
278
- elif latitudes:
279
- result = _format_latitude(latitudes)
280
- else:
281
- result = []
282
-
283
- print("[green]Longitude and latitude values formatted successfully.[/green]")
284
- return result
285
-
286
-
287
- def add_gridlines(axes: plt.Axes, longitude_lines: list[float] = None, latitude_lines: list[float] = None, map_projection: ccrs.Projection = ccrs.PlateCarree(), line_color: str = "k", line_alpha: float = 0.5, line_style: str = "--", line_width: float = 0.5) -> tuple[plt.Axes, mpl.ticker.Locator]:
288
- """Add gridlines to a map.
289
-
290
- Args:
291
- axes (plt.Axes): The axes to add gridlines to.
292
- longitude_lines (list[float], optional): List of longitude positions for gridlines.
293
- latitude_lines (list[float], optional): List of latitude positions for gridlines.
294
- map_projection (ccrs.Projection, optional): Coordinate reference system. Defaults to PlateCarree.
295
- line_color (str, optional): Line color. Defaults to "k".
296
- line_alpha (float, optional): Line transparency. Defaults to 0.5.
297
- line_style (str, optional): Line style. Defaults to "--".
298
- line_width (float, optional): Line width. Defaults to 0.5.
299
-
300
- Returns:
301
- tuple[plt.Axes, mpl.ticker.Locator]: The axes and gridlines objects.
302
-
303
- Example:
304
- >>> add_gridlines(axes, longitude_lines=[0, 30], latitude_lines=[-90, 90], map_projection=ccrs.PlateCarree())
305
- >>> axes, gl = add_gridlines(axes, longitude_lines=[0, 30], latitude_lines=[-90, 90])
306
- """
307
- from matplotlib import ticker as mticker
308
-
309
- # add gridlines
310
- gl = axes.gridlines(crs=map_projection, draw_labels=True, linewidth=line_width, color=line_color, alpha=line_alpha, linestyle=line_style)
311
- gl.right_labels = False
312
- gl.top_labels = False
313
- gl.xformatter = LongitudeFormatter(zero_direction_label=False)
314
- gl.yformatter = LatitudeFormatter()
315
-
316
- if longitude_lines is not None:
317
- gl.xlocator = mticker.FixedLocator(np.array(longitude_lines))
318
- if latitude_lines is not None:
319
- gl.ylocator = mticker.FixedLocator(np.array(latitude_lines))
320
-
321
- # print("[green]Gridlines added successfully.[/green]")
322
- return axes, gl
323
-
324
-
325
- def add_cartopy(axes: plt.Axes, longitude_data: np.ndarray = None, latitude_data: np.ndarray = None, map_projection: ccrs.Projection = ccrs.PlateCarree(), show_gridlines: bool = True, land_color: str = "lightgrey", ocean_color: str = "lightblue", coastline_linewidth: float = 0.5) -> None:
326
- """Add cartopy features to a map.
327
-
328
- Args:
329
- axes (plt.Axes): The axes to add map features to.
237
+ axes (plt.Axes): The axes to setup as a map.
330
238
  longitude_data (np.ndarray, optional): Array of longitudes to set map extent.
331
239
  latitude_data (np.ndarray, optional): Array of latitudes to set map extent.
332
240
  map_projection (ccrs.Projection, optional): Coordinate reference system. Defaults to PlateCarree.
333
- show_gridlines (bool, optional): Whether to add gridlines. Defaults to True.
241
+
242
+ show_land (bool, optional): Whether to show land features. Defaults to True.
243
+ show_ocean (bool, optional): Whether to show ocean features. Defaults to True.
244
+ show_coastline (bool, optional): Whether to show coastlines. Defaults to True.
245
+ show_borders (bool, optional): Whether to show country borders. Defaults to False.
334
246
  land_color (str, optional): Color of land. Defaults to "lightgrey".
335
247
  ocean_color (str, optional): Color of oceans. Defaults to "lightblue".
336
248
  coastline_linewidth (float, optional): Line width for coastlines. Defaults to 0.5.
337
249
 
338
- Returns:
339
- None
250
+ show_gridlines (bool, optional): Whether to show gridlines. Defaults to False.
251
+ longitude_ticks (list[float], optional): Longitude tick positions.
252
+ latitude_ticks (list[float], optional): Latitude tick positions.
253
+ tick_decimals (int, optional): Number of decimal places for tick labels. Defaults to 0.
340
254
 
341
- Example:
342
- >>> add_cartopy(axes, longitude_data=lon_data, latitude_data=lat_data, map_projection=ccrs.PlateCarree(), show_gridlines=True)
343
- >>> axes = add_cartopy(axes, longitude_data=None, latitude_data=None, map_projection=ccrs.PlateCarree(), show_gridlines=False)
255
+ grid_color (str, optional): Gridline color. Defaults to "k".
256
+ grid_alpha (float, optional): Gridline transparency. Defaults to 0.5.
257
+ grid_style (str, optional): Gridline style. Defaults to "--".
258
+ grid_width (float, optional): Gridline width. Defaults to 0.5.
344
259
 
345
- """
346
- # add coastlines
347
- axes.add_feature(cfeature.LAND, facecolor=land_color)
348
- axes.add_feature(cfeature.OCEAN, facecolor=ocean_color)
349
- axes.add_feature(cfeature.COASTLINE, linewidth=coastline_linewidth)
350
- # axes.add_feature(cfeature.BORDERS, linewidth=coastline_linewidth, linestyle=":")
260
+ show_labels (bool, optional): Whether to show coordinate labels. Defaults to True.
261
+ left_labels (bool, optional): Show labels on left side. Defaults to True.
262
+ bottom_labels (bool, optional): Show labels on bottom. Defaults to True.
263
+ right_labels (bool, optional): Show labels on right side. Defaults to False.
264
+ top_labels (bool, optional): Show labels on top. Defaults to False.
351
265
 
352
- # add gridlines
353
- if show_gridlines:
354
- axes, gl = add_gridlines(axes, map_projection=map_projection)
266
+ Returns:
267
+ plt.Axes: The configured map axes.
355
268
 
356
- # set longitude and latitude format
357
- lon_formatter = LongitudeFormatter(zero_direction_label=False)
358
- lat_formatter = LatitudeFormatter()
359
- axes.xaxis.set_major_formatter(lon_formatter)
360
- axes.yaxis.set_major_formatter(lat_formatter)
269
+ Examples:
270
+ >>> # Basic map setup
271
+ >>> ax = setup_map(ax)
272
+
273
+ >>> # Map with gridlines and custom extent
274
+ >>> ax = setup_map(ax, longitude_data=lon, latitude_data=lat, show_gridlines=True)
275
+
276
+ >>> # Customized map
277
+ >>> ax = setup_map(
278
+ ... ax,
279
+ ... show_gridlines=True,
280
+ ... longitude_ticks=[0, 30, 60],
281
+ ... latitude_ticks=[-30, 0, 30],
282
+ ... land_color='wheat',
283
+ ... ocean_color='lightcyan'
284
+ ... )
285
+ """
286
+ from matplotlib import ticker as mticker
361
287
 
362
- # set extent
288
+ # Add map features
289
+ if show_land:
290
+ axes.add_feature(cfeature.LAND, facecolor=land_color)
291
+ if show_ocean:
292
+ axes.add_feature(cfeature.OCEAN, facecolor=ocean_color)
293
+ if show_coastline:
294
+ axes.add_feature(cfeature.COASTLINE, linewidth=coastline_linewidth)
295
+ if show_borders:
296
+ axes.add_feature(cfeature.BORDERS, linewidth=coastline_linewidth, linestyle=":")
297
+
298
+ # Setup coordinate formatting
299
+ lon_formatter = LongitudeFormatter(zero_direction_label=False, number_format=f".{tick_decimals}f")
300
+ lat_formatter = LatitudeFormatter(number_format=f".{tick_decimals}f")
301
+
302
+ # Handle gridlines and ticks
303
+ if show_gridlines:
304
+ # Add gridlines with labels
305
+ gl = axes.gridlines(crs=map_projection, draw_labels=show_labels, linewidth=grid_width, color=grid_color, alpha=grid_alpha, linestyle=grid_style)
306
+
307
+ # Configure label positions
308
+ gl.left_labels = left_labels
309
+ gl.bottom_labels = bottom_labels
310
+ gl.right_labels = right_labels
311
+ gl.top_labels = top_labels
312
+
313
+ # Set formatters
314
+ gl.xformatter = lon_formatter
315
+ gl.yformatter = lat_formatter
316
+
317
+ # Set custom tick positions if provided
318
+ if longitude_ticks is not None:
319
+ gl.xlocator = mticker.FixedLocator(np.array(longitude_ticks))
320
+ if latitude_ticks is not None:
321
+ gl.ylocator = mticker.FixedLocator(np.array(latitude_ticks))
322
+
323
+ elif show_labels:
324
+ # Add tick labels without gridlines
325
+ # Generate default tick positions based on current extent if not provided
326
+ if longitude_ticks is None:
327
+ current_extent = axes.get_extent(crs=map_projection)
328
+ lon_range = current_extent[1] - current_extent[0]
329
+ # Generate reasonable tick spacing
330
+ tick_spacing = 5 if lon_range <= 30 else (10 if lon_range <= 90 else 20)
331
+ longitude_ticks = np.arange(np.ceil(current_extent[0] / tick_spacing) * tick_spacing, current_extent[1] + tick_spacing, tick_spacing)
332
+
333
+ if latitude_ticks is None:
334
+ current_extent = axes.get_extent(crs=map_projection)
335
+ lat_range = current_extent[3] - current_extent[2]
336
+ # Generate reasonable tick spacing
337
+ tick_spacing = 5 if lat_range <= 30 else (10 if lat_range <= 90 else 20)
338
+ latitude_ticks = np.arange(np.ceil(current_extent[2] / tick_spacing) * tick_spacing, current_extent[3] + tick_spacing, tick_spacing)
339
+
340
+ # Set tick positions and formatters
341
+ axes.set_xticks(longitude_ticks, crs=map_projection)
342
+ axes.set_yticks(latitude_ticks, crs=map_projection)
343
+ axes.xaxis.set_major_formatter(lon_formatter)
344
+ axes.yaxis.set_major_formatter(lat_formatter)
345
+
346
+ # 只要传入经纬度数据就自动设置范围
347
+ # 范围必须在cartopy添加地图特征之后设置,因为添加特征可能会改变axes的范围
363
348
  if longitude_data is not None and latitude_data is not None:
364
- lon_min, lon_max = np.nanmin(longitude_data), np.nanmax(longitude_data)
365
- lat_min, lat_max = np.nanmin(latitude_data), np.nanmax(latitude_data)
366
- axes.set_extent([lon_min, lon_max, lat_min, lat_max], crs=map_projection)
367
-
368
- # print("[green]Cartopy features added successfully.[/green]")
349
+ # 过滤掉NaN,避免极端值影响
350
+ lon_valid = np.asarray(longitude_data)[~np.isnan(longitude_data)]
351
+ lat_valid = np.asarray(latitude_data)[~np.isnan(latitude_data)]
352
+ if lon_valid.size > 0 and lat_valid.size > 0:
353
+ lon_min, lon_max = np.min(lon_valid), np.max(lon_valid)
354
+ lat_min, lat_max = np.min(lat_valid), np.max(lat_valid)
355
+ axes.set_extent([lon_min, lon_max, lat_min, lat_max], crs=map_projection)
356
+ else:
357
+ # 若全是NaN则不设置范围
358
+ pass
369
359
  return axes
370
360
 
371
361
 
372
362
  class MidpointNormalize(mpl.colors.Normalize):
373
- """Custom normalization class to center 0 value.
363
+ """Custom normalization class to center a specific value.
374
364
 
375
365
  Args:
376
- min_value (float, optional): Minimum data value. Defaults to None.
377
- max_value (float, optional): Maximum data value. Defaults to None.
378
- center_value (float, optional): Center value for normalization. Defaults to None.
379
- clip_values (bool, optional): Whether to clip data outside the range. Defaults to False.
366
+ vmin (float, optional): Minimum data value. Defaults to None.
367
+ vmax (float, optional): Maximum data value. Defaults to None.
368
+ vcenter (float, optional): Center value for normalization. Defaults to 0.
369
+ clip (bool, optional): Whether to clip data outside the range. Defaults to False.
380
370
 
381
371
  Example:
382
- >>> norm = MidpointNormalize(min_value=-2, max_value=1, center_value=0)
372
+ >>> norm = MidpointNormalize(vmin=-2, vmax=1, vcenter=0)
383
373
  """
384
374
 
385
- def __init__(self, min_value: float = None, max_value: float = None, center_value: float = None, clip_values: bool = False) -> None:
386
- self.vcenter = center_value
387
- super().__init__(min_value, max_value, clip_values)
375
+ def __init__(self, vmin: float = None, vmax: float = None, vcenter: float = 0, clip: bool = False) -> None:
376
+ self.vcenter = vcenter
377
+ super().__init__(vmin, vmax, clip)
378
+
379
+ def __call__(self, value: np.ndarray, clip: bool = None) -> np.ma.MaskedArray:
380
+ # Use the clip parameter from initialization if not provided
381
+ if clip is None:
382
+ clip = self.clip
388
383
 
389
- def __call__(self, input_values: np.ndarray, clip_values: bool = None) -> np.ma.MaskedArray:
390
384
  x, y = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1.0]
391
- return np.ma.masked_array(np.interp(input_values, x, y, left=-np.inf, right=np.inf))
385
+ result = np.interp(value, x, y)
392
386
 
393
- def inverse(self, normalized_values: np.ndarray) -> np.ndarray:
394
- y, x = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1]
395
- return np.interp(normalized_values, x, y, left=-np.inf, right=np.inf)
387
+ # Apply clipping if requested
388
+ if clip:
389
+ result = np.clip(result, 0, 1)
396
390
 
397
- # print("[green]Midpoint normalization applied successfully.[/green]")
391
+ return np.ma.masked_array(result)
392
+
393
+ def inverse(self, value: np.ndarray) -> np.ndarray:
394
+ y, x = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1]
395
+ return np.interp(value, x, y)
398
396
 
399
397
 
400
398
  if __name__ == "__main__":