oafuncs 0.0.97.16__py3-none-any.whl → 0.0.97.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oafuncs/oa_help.py CHANGED
@@ -116,6 +116,16 @@ def log():
116
116
  log()
117
117
  """
118
118
  print("更新日志:")
119
+ print(
120
+ """
121
+ 2025-04-06
122
+ 1. 给所有函数使用Python标准的docstring格式(英文)添加/修改说明
123
+ 2. 同时给所有参数添加类型声明
124
+ 3. 逻辑检查,优化
125
+ 4. 使用rich库的print函数,增加颜色
126
+ 5. 所有输出,如果是中文,改成英文
127
+ """
128
+ )
119
129
  print(
120
130
  """
121
131
  2025-01-15
oafuncs/oa_nc.py CHANGED
@@ -1,20 +1,5 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- """
4
- Author: Liu Kun && 16031215@qq.com
5
- Date: 2024-09-17 14:58:50
6
- LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2024-12-06 14:16:56
8
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_nc.py
9
- Description:
10
- EditPlatform: vscode
11
- ComputerInfo: XPS 15 9510
12
- SystemInfo: Windows 11
13
- Python Version: 3.11
14
- """
15
-
16
1
  import os
17
- from typing import List, Optional, Union, Tuple
2
+ from typing import List, Optional, Tuple, Union
18
3
 
19
4
  import netCDF4 as nc
20
5
  import numpy as np
@@ -24,225 +9,273 @@ from rich import print
24
9
  __all__ = ["save", "merge", "modify", "rename", "check", "convert_longitude", "isel", "draw"]
25
10
 
26
11
 
27
- def save(file: str, data: Union[np.ndarray, xr.DataArray], varname: Optional[str] = None, coords: Optional[dict] = None, mode: str = "w", scale_offset_switch: bool = True, compile_switch: bool = True) -> None:
12
+ def save(
13
+ file_path: str,
14
+ data: Union[np.ndarray, xr.DataArray],
15
+ variable_name: Optional[str] = None,
16
+ coordinates: Optional[dict] = None,
17
+ write_mode: str = "w",
18
+ use_scale_offset: bool = True,
19
+ use_compression: bool = True,
20
+ ) -> None:
28
21
  """
29
- Description:
30
- Write data to NetCDF file
31
- Parameters:
32
- file: str, file path
33
- data: np.ndarray or xr.DataArray, data to be written
34
- varname: Optional[str], variable name
35
- coords: Optional[dict], coordinates, key is the dimension name, value is the coordinate data
36
- mode: str, write mode, 'w' for write, 'a' for append
37
- scale_offset_switch: bool, whether to use scale_factor and add_offset, default is True
38
- compile_switch: bool, whether to use compression parameters, default is True
22
+ Write data to a NetCDF file.
23
+
24
+ Args:
25
+ file_path (str): File path to save the NetCDF file.
26
+ data (Union[np.ndarray, xr.DataArray]): Data to be written.
27
+ variable_name (Optional[str]): Variable name for the data.
28
+ coordinates (Optional[dict]): Coordinates, where keys are dimension names and values are coordinate data.
29
+ write_mode (str): Write mode, 'w' for write, 'a' for append. Default is 'w'.
30
+ use_scale_offset (bool): Whether to use scale_factor and add_offset. Default is True.
31
+ use_compression (bool): Whether to use compression parameters. Default is True.
32
+
39
33
  Example:
40
- save(r'test.nc', data, 'u', {'time': np.linspace(0, 120, 100), 'lev': np.linspace(0, 120, 50)}, 'a')
34
+ >>> save(r'test.nc', data, 'u', {'time': np.linspace(0, 120, 100), 'lev': np.linspace(0, 120, 50)}, 'a')
35
+ >>> save(r'test.nc', data, 'u', {'time': np.linspace(0, 120, 100), 'lev': np.linspace(0, 120, 50)}, 'w')
36
+ >>> save(r'test.nc', data, 'u', {'time': np.linspace(0, 120, 100), 'lev': np.linspace(0, 120, 50)}, 'w', use_scale_offset=False, use_compression=False)
37
+ >>> save(r'test.nc', data)
41
38
  """
42
39
  from ._script.netcdf_write import save_to_nc
43
40
 
44
- save_to_nc(file, data, varname, coords, mode, scale_offset_switch, compile_switch)
41
+ save_to_nc(file_path, data, variable_name, coordinates, write_mode, use_scale_offset, use_compression)
42
+ print(f"[green]Data successfully saved to {file_path}[/green]")
45
43
 
46
44
 
47
- def merge(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
45
+ def merge(
46
+ file_paths: Union[str, List[str]],
47
+ variable_names: Optional[Union[str, List[str]]] = None,
48
+ merge_dimension: Optional[str] = None,
49
+ output_file: Optional[str] = None,
50
+ ) -> None:
48
51
  """
49
- Description:
50
- Merge multiple NetCDF files into one.
51
- Parameters:
52
- file_list: Union[str, List[str]], list of file paths or a single file path
53
- var_name: Optional[Union[str, List[str]]], variable names to merge
54
- dim_name: Optional[str], dimension name to merge along
55
- target_filename: Optional[str], output file name
52
+ Merge multiple NetCDF files into one.
53
+
54
+ Args:
55
+ file_paths (Union[str, List[str]]): List of file paths or a single file path.
56
+ variable_names (Optional[Union[str, List[str]]]): Variable names to merge.
57
+ merge_dimension (Optional[str]): Dimension name to merge along.
58
+ output_file (Optional[str]): Output file name.
59
+
60
+ Example:
61
+ merge(['file1.nc', 'file2.nc'], variable_names='temperature', merge_dimension='time', output_file='merged.nc')
56
62
  """
57
63
  from ._script.netcdf_merge import merge_nc
58
64
 
59
- merge_nc(file_list, var_name, dim_name, target_filename)
65
+ merge_nc(file_paths, variable_names, merge_dimension, output_file)
66
+ print(f"[green]Files successfully merged into {output_file}[/green]")
60
67
 
61
68
 
62
- def modify(nc_file: str, var_name: str, attr_name: Optional[str] = None, new_value: Optional[Union[str, float, int, np.ndarray]] = None) -> None:
69
+ def modify(
70
+ file_path: str,
71
+ variable_name: str,
72
+ attribute_name: Optional[str] = None,
73
+ new_value: Optional[Union[str, float, int, np.ndarray]] = None,
74
+ ) -> None:
63
75
  """
64
- Description:
65
- Modify the value of a variable or the value of an attribute in a NetCDF file.
66
- Parameters:
67
- nc_file: str, the path to the NetCDF file
68
- var_name: str, the name of the variable to be modified
69
- attr_name: Optional[str], the name of the attribute to be modified. If None, the variable value will be modified
70
- new_value: Optional[Union[str, float, int, np.ndarray]], the new value of the variable or attribute
76
+ Modify the value of a variable or an attribute in a NetCDF file.
77
+
78
+ Args:
79
+ file_path (str): Path to the NetCDF file.
80
+ variable_name (str): Name of the variable to be modified.
81
+ attribute_name (Optional[str]): Name of the attribute to be modified. If None, the variable value will be modified.
82
+ new_value (Optional[Union[str, float, int, np.ndarray]]): New value for the variable or attribute.
83
+
84
+ Example:
85
+ >>> modify('file.nc', 'temperature', 'units', 'Celsius')
86
+ >>> modify('file.nc', 'temperature', new_value=np.array([1, 2, 3]))
71
87
  """
72
88
  from ._script.netcdf_modify import modify_nc
73
89
 
74
- modify_nc(nc_file, var_name, attr_name, new_value)
90
+ modify_nc(file_path, variable_name, attribute_name, new_value)
91
+ print(f"[green]Successfully modified {variable_name} in {file_path}[/green]")
75
92
 
76
93
 
77
- def rename(ncfile_path: str, old_name: str, new_name: str) -> None:
94
+ def rename(
95
+ file_path: str,
96
+ old_name: str,
97
+ new_name: str,
98
+ ) -> None:
78
99
  """
79
- Description:
80
- Rename a variable and/or dimension in a NetCDF file.
81
- Parameters:
82
- ncfile_path: str, the path to the NetCDF file
83
- old_name: str, the current name of the variable or dimension
84
- new_name: str, the new name to assign to the variable or dimension
100
+ Rename a variable or dimension in a NetCDF file.
101
+
102
+ Args:
103
+ file_path (str): Path to the NetCDF file.
104
+ old_name (str): Current name of the variable or dimension.
105
+ new_name (str): New name to assign to the variable or dimension.
106
+
107
+ Example:
108
+ >>> rename('file.nc', 'old_var', 'new_var')
85
109
  """
86
110
  try:
87
- with nc.Dataset(ncfile_path, "r+") as dataset:
88
- # If the old name is not found as a variable or dimension, print a message
111
+ with nc.Dataset(file_path, "r+") as dataset:
89
112
  if old_name not in dataset.variables and old_name not in dataset.dimensions:
90
- print(f"Variable or dimension {old_name} not found in the file.")
113
+ print(f"[yellow]Variable or dimension {old_name} not found in the file.[/yellow]")
114
+ return
91
115
 
92
- # Attempt to rename the variable
93
116
  if old_name in dataset.variables:
94
117
  dataset.renameVariable(old_name, new_name)
95
- print(f"Successfully renamed variable {old_name} to {new_name}.")
118
+ print(f"[green]Successfully renamed variable {old_name} to {new_name}.[/green]")
96
119
 
97
- # Attempt to rename the dimension
98
120
  if old_name in dataset.dimensions:
99
- # Check if the new dimension name already exists
100
121
  if new_name in dataset.dimensions:
101
122
  raise ValueError(f"Dimension name {new_name} already exists in the file.")
102
123
  dataset.renameDimension(old_name, new_name)
103
- print(f"Successfully renamed dimension {old_name} to {new_name}.")
124
+ print(f"[green]Successfully renamed dimension {old_name} to {new_name}.[/green]")
104
125
 
105
126
  except Exception as e:
106
- print(f"An error occurred: {e}")
127
+ print(f"[red]An error occurred: {e}[/red]")
107
128
 
108
129
 
109
- def check(ncfile: str, delete_switch: bool = False, print_switch: bool = True) -> bool:
130
+ def check(
131
+ file_path: str,
132
+ delete_if_invalid: bool = False,
133
+ print_messages: bool = True,
134
+ ) -> bool:
110
135
  """
111
- Description:
112
- Check if a NetCDF file is corrupted with enhanced error handling.
113
- Parameters:
114
- ncfile: str, the path to the NetCDF file
115
- delete_switch: bool, whether to delete the file if it is corrupted
116
- print_switch: bool, whether to print messages during the check
136
+ Check if a NetCDF file is corrupted.
137
+
138
+ Args:
139
+ file_path (str): Path to the NetCDF file.
140
+ delete_if_invalid (bool): Whether to delete the file if it is corrupted. Default is False.
141
+ print_messages (bool): Whether to print messages during the check. Default is True.
142
+
117
143
  Returns:
118
- bool: True if the file is valid, False otherwise
144
+ bool: True if the file is valid, False otherwise.
145
+
146
+ Example:
147
+ >>> is_valid = check('file.nc', delete_if_invalid=True)
119
148
  """
120
149
  is_valid = False
121
150
 
122
- if not os.path.exists(ncfile):
123
- if print_switch:
124
- print(f"[#ffeac5]Local file missing: [#009d88]{ncfile}")
125
- # 提示:提示文件缺失也许是正常的,这只是检查文件是否存在于本地
126
- print("[#d6d9fd]Note: File missing may be normal, this is just to check if the file exists locally.")
151
+ if not os.path.exists(file_path):
152
+ if print_messages:
153
+ print(f"[yellow]File not found: {file_path}[/yellow]")
127
154
  return False
128
155
 
129
156
  try:
130
- # # 深度验证文件结构
131
- # with nc.Dataset(ncfile, "r") as ds:
132
- # # 显式检查文件结构完整性
133
- # ds.sync() # 强制刷新缓冲区
134
- # ds.close() # 显式关闭后重新打开验证
135
-
136
- # 二次验证确保变量可访问
137
- with nc.Dataset(ncfile, "r") as ds_verify:
157
+ with nc.Dataset(file_path, "r") as ds_verify:
138
158
  if not ds_verify.variables:
139
- if print_switch:
140
- print(f"[red]Empty variables: {ncfile}[/red]")
159
+ if print_messages:
160
+ print(f"[red]Empty variables in file: {file_path}[/red]")
141
161
  else:
142
- # 尝试访问元数据
143
162
  _ = ds_verify.__dict__
144
- # 抽样检查第一个变量
145
163
  for var in ds_verify.variables.values():
146
- _ = var.shape # 触发实际数据访问
164
+ _ = var.shape
147
165
  break
148
166
  is_valid = True
149
167
 
150
- except Exception as e: # 捕获所有异常类型
151
- if print_switch:
152
- print(f"[red]HDF5 validation failed for {ncfile}: {str(e)}[/red]")
153
- error_type = type(e).__name__
154
- if "HDF5" in error_type or "h5" in error_type.lower():
155
- if print_switch:
156
- print(f"[red]Critical HDF5 structure error detected in {ncfile}[/red]")
157
-
158
- # 安全删除流程
159
- if not is_valid:
160
- if delete_switch:
161
- try:
162
- os.remove(ncfile)
163
- if print_switch:
164
- print(f"[red]Removed corrupted file: {ncfile}[/red]")
165
- except Exception as del_error:
166
- if print_switch:
167
- print(f"[red]Failed to delete corrupted file: {ncfile} - {str(del_error)}[/red]")
168
- return False
169
-
170
- return True
168
+ except Exception as e:
169
+ if print_messages:
170
+ print(f"[red]File validation failed: {file_path} - {str(e)}[/red]")
171
+
172
+ if not is_valid and delete_if_invalid:
173
+ try:
174
+ os.remove(file_path)
175
+ if print_messages:
176
+ print(f"[red]Deleted corrupted file: {file_path}[/red]")
177
+ except Exception as del_error:
178
+ if print_messages:
179
+ print(f"[red]Failed to delete file: {file_path} - {str(del_error)}[/red]")
180
+
181
+ return is_valid
182
+
183
+
184
+ def convert_longitude(
185
+ dataset: xr.Dataset,
186
+ longitude_name: str = "longitude",
187
+ target_range: int = 180,
188
+ ) -> xr.Dataset:
189
+ """
190
+ Convert the longitude array to a specified range.
171
191
 
192
+ Args:
193
+ dataset (xr.Dataset): The xarray dataset containing the longitude data.
194
+ longitude_name (str): Name of the longitude variable. Default is "longitude".
195
+ target_range (int): Target range to convert to, either 180 or 360. Default is 180.
172
196
 
173
- def convert_longitude(ds: xr.Dataset, lon_name: str = "longitude", convert: int = 180) -> xr.Dataset:
174
- """
175
- Description:
176
- Convert the longitude array to a specified range.
177
- Parameters:
178
- ds: xr.Dataset, the xarray dataset containing the longitude data
179
- lon_name: str, the name of the longitude variable, default is "longitude"
180
- convert: int, the target range to convert to, can be 180 or 360, default is 180
181
197
  Returns:
182
- xr.Dataset: The xarray dataset with the converted longitude
198
+ xr.Dataset: Dataset with converted longitude.
199
+
200
+ Example:
201
+ >>> dataset = convert_longitude(dataset, longitude_name="lon", target_range=360)
183
202
  """
184
- to_which = int(convert)
185
- if to_which not in [180, 360]:
186
- raise ValueError("convert value must be '180' or '360'")
203
+ if target_range not in [180, 360]:
204
+ raise ValueError("target_range value must be 180 or 360")
187
205
 
188
- if to_which == 180:
189
- ds = ds.assign_coords({lon_name: (ds[lon_name] + 180) % 360 - 180})
190
- elif to_which == 360:
191
- ds = ds.assign_coords({lon_name: (ds[lon_name] + 360) % 360})
206
+ if target_range == 180:
207
+ dataset = dataset.assign_coords({longitude_name: (dataset[longitude_name] + 180) % 360 - 180})
208
+ else:
209
+ dataset = dataset.assign_coords({longitude_name: (dataset[longitude_name] + 360) % 360})
192
210
 
193
- return ds.sortby(lon_name)
211
+ return dataset.sortby(longitude_name)
194
212
 
195
213
 
196
- def isel(ncfile: str, dim_name: str, slice_list: List[int]) -> xr.Dataset:
214
+ def isel(
215
+ file_path: str,
216
+ dimension_name: str,
217
+ indices: List[int],
218
+ ) -> xr.Dataset:
197
219
  """
198
- Description:
199
- Choose the data by the index of the dimension.
200
- Parameters:
201
- ncfile: str, the path of the netCDF file
202
- dim_name: str, the name of the dimension
203
- slice_list: List[int], the indices of the dimension
220
+ Select data by the index of a dimension.
221
+
222
+ Args:
223
+ file_path (str): Path to the NetCDF file.
224
+ dimension_name (str): Name of the dimension.
225
+ indices (List[int]): Indices of the dimension to select.
226
+
204
227
  Returns:
205
- xr.Dataset: The subset dataset
228
+ xr.Dataset: Subset dataset.
229
+
230
+ Example:
231
+ >>> subset = isel('file.nc', 'time', [0, 1, 2])
206
232
  """
207
- ds = xr.open_dataset(ncfile)
208
- slice_list = np.array(slice_list).flatten()
209
- slice_list = [int(i) for i in slice_list]
210
- ds_new = ds.isel(**{dim_name: slice_list})
233
+ ds = xr.open_dataset(file_path)
234
+ indices = [int(i) for i in np.array(indices).flatten()]
235
+ ds_new = ds.isel(**{dimension_name: indices})
211
236
  ds.close()
212
237
  return ds_new
213
238
 
214
239
 
215
- def draw(output_dir: Optional[str] = None, dataset: Optional[xr.Dataset] = None, ncfile: Optional[str] = None, xyzt_dims: Union[List[str], Tuple[str, str, str, str]] = ("longitude", "latitude", "level", "time"), plot_type: str = "contourf", fixed_colorscale: bool = False) -> None:
240
+ def draw(
241
+ output_directory: Optional[str] = None,
242
+ dataset: Optional[xr.Dataset] = None,
243
+ file_path: Optional[str] = None,
244
+ dimensions: Union[List[str], Tuple[str, str, str, str]] = ("longitude", "latitude", "level", "time"),
245
+ plot_style: str = "contourf",
246
+ use_fixed_colorscale: bool = False,
247
+ ) -> None:
216
248
  """
217
- Description:
218
- Draw the data in the netCDF file.
219
- Parameters:
220
- output_dir: Optional[str], the path of the output directory
221
- dataset: Optional[xr.Dataset], the xarray dataset to plot
222
- ncfile: Optional[str], the path of the netCDF file
223
- xyzt_dims: Union[List[str], Tuple[str, str, str, str]], the dimensions for plotting
224
- plot_type: str, the type of the plot, default is "contourf" (contourf, contour)
225
- fixed_colorscale: bool, whether to use fixed colorscale, default is False
249
+ Draw data from a NetCDF file.
250
+
251
+ Args:
252
+ output_directory (Optional[str]): Path of the output directory.
253
+ dataset (Optional[xr.Dataset]): Xarray dataset to plot.
254
+ file_path (Optional[str]): Path to the NetCDF file.
255
+ dimensions (Union[List[str], Tuple[str, str, str, str]]): Dimensions for plotting.
256
+ plot_style (str): Type of the plot, e.g., "contourf" or "contour". Default is "contourf".
257
+ use_fixed_colorscale (bool): Whether to use a fixed colorscale. Default is False.
258
+
259
+ Example:
260
+ >>> draw(output_directory="plots", file_path="file.nc", plot_style="contour")
226
261
  """
227
262
  from ._script.plot_dataset import func_plot_dataset
228
263
 
229
- if output_dir is None:
230
- output_dir = str(os.getcwd())
231
- if isinstance(xyzt_dims, (list, tuple)):
232
- xyzt_dims = tuple(xyzt_dims)
233
- else:
234
- raise ValueError("xyzt_dims must be a list or tuple")
264
+ if output_directory is None:
265
+ output_directory = os.getcwd()
266
+ if not isinstance(dimensions, (list, tuple)):
267
+ raise ValueError("dimensions must be a list or tuple")
268
+
235
269
  if dataset is not None:
236
- func_plot_dataset(dataset, output_dir, xyzt_dims, plot_type, fixed_colorscale)
237
- else:
238
- if ncfile is not None:
239
- if check(ncfile):
240
- ds = xr.open_dataset(ncfile)
241
- func_plot_dataset(ds, output_dir, xyzt_dims, plot_type, fixed_colorscale)
242
- else:
243
- print(f"Invalid file: {ncfile}")
270
+ func_plot_dataset(dataset, output_directory, tuple(dimensions), plot_style, use_fixed_colorscale)
271
+ elif file_path is not None:
272
+ if check(file_path):
273
+ ds = xr.open_dataset(file_path)
274
+ func_plot_dataset(ds, output_directory, tuple(dimensions), plot_style, use_fixed_colorscale)
244
275
  else:
245
- print("No dataset or file provided.")
276
+ print(f"[red]Invalid file: {file_path}[/red]")
277
+ else:
278
+ print("[red]No dataset or file provided.[/red]")
246
279
 
247
280
 
248
281
  if __name__ == "__main__":
oafuncs/oa_python.py CHANGED
@@ -14,23 +14,37 @@ Python Version: 3.12
14
14
  """
15
15
 
16
16
  import os
17
+ from typing import List, Optional
17
18
 
18
19
  from rich import print
19
20
 
20
21
  __all__ = ["install_packages", "upgrade_packages"]
21
22
 
22
23
 
23
- def install_packages(packages=None, python_executable="python", package_manager="pip"):
24
+ def install_packages(
25
+ packages: Optional[List[str]] = None,
26
+ python_executable: str = "python",
27
+ package_manager: str = "pip",
28
+ ) -> None:
24
29
  """
25
- packages: list, libraries to be installed
26
- python_executable: str, Python version; for example, on Windows, copy python.exe to python312.exe, then set python_executable='python312'
27
- package_manager: str, the package manager to use ('pip' or 'conda')
30
+ Install the specified Python packages using the given package manager.
31
+
32
+ Args:
33
+ packages (Optional[List[str]]): A list of libraries to be installed. If None, no packages will be installed.
34
+ python_executable (str): The Python executable to use (e.g., 'python312').
35
+ package_manager (str): The package manager to use ('pip' or 'conda').
36
+
37
+ Raises:
38
+ ValueError: If 'packages' is not a list or None, or if 'package_manager' is not 'pip' or 'conda'.
39
+
40
+ Example:
41
+ >>> install_packages(packages=["numpy", "pandas"], python_executable="python", package_manager="pip")
28
42
  """
29
43
  if not isinstance(packages, (list, type(None))):
30
- raise ValueError("The 'packages' parameter must be a list or None")
44
+ raise ValueError("[red]The 'packages' parameter must be a list or None[/red]")
31
45
 
32
46
  if package_manager not in ["pip", "conda"]:
33
- raise ValueError("The 'package_manager' parameter must be either 'pip' or 'conda'")
47
+ raise ValueError("[red]The 'package_manager' parameter must be either 'pip' or 'conda'[/red]")
34
48
 
35
49
  if package_manager == "conda":
36
50
  if not packages:
@@ -39,11 +53,11 @@ def install_packages(packages=None, python_executable="python", package_manager=
39
53
  package_count = len(packages)
40
54
  for i, package in enumerate(packages):
41
55
  os.system(f"conda install -c conda-forge {package} -y")
42
- print("-" * 100)
43
- print(f"Successfully installed {package} ({i + 1}/{package_count})")
44
- print("-" * 100)
56
+ print(f"[green]{'-' * 100}[/green]")
57
+ print(f"[green]Successfully installed {package} ({i + 1}/{package_count})[/green]")
58
+ print(f"[green]{'-' * 100}[/green]")
45
59
  except Exception as e:
46
- print(f"Installation failed: {str(e)}")
60
+ print(f"[red]Installation failed: {str(e)}[/red]")
47
61
  return
48
62
 
49
63
  os.system(f"{python_executable} -m ensurepip")
@@ -55,29 +69,41 @@ def install_packages(packages=None, python_executable="python", package_manager=
55
69
  installed_packages = {pkg.split("==")[0].lower() for pkg in installed_packages}
56
70
  package_count = len(packages)
57
71
  for i, package in enumerate(packages):
58
- # Check if the library is already installed, skip if installed
59
72
  if package.lower() in installed_packages:
60
- print(f"{package} is already installed")
73
+ print(f"[yellow]{package} is already installed[/yellow]")
61
74
  continue
62
75
  os.system(f"{python_executable} -m pip install {package}")
63
- print("-" * 100)
64
- print(f"Successfully installed {package} ({i + 1}/{package_count})")
65
- print("-" * 100)
76
+ print(f"[green]{'-' * 100}[/green]")
77
+ print(f"[green]Successfully installed {package} ({i + 1}/{package_count})[/green]")
78
+ print(f"[green]{'-' * 100}[/green]")
66
79
  except Exception as e:
67
- print(f"Installation failed: {str(e)}")
80
+ print(f"[red]Installation failed: {str(e)}[/red]")
68
81
 
69
82
 
70
- def upgrade_packages(packages=None, python_executable="python", package_manager="pip"):
83
+ def upgrade_packages(
84
+ packages: Optional[List[str]] = None,
85
+ python_executable: str = "python",
86
+ package_manager: str = "pip",
87
+ ) -> None:
71
88
  """
72
- packages: list, libraries to be upgraded
73
- python_executable: str, Python version; for example, on Windows, copy python.exe to python312.exe, then set python_executable='python312'
74
- package_manager: str, the package manager to use ('pip' or 'conda')
89
+ Upgrade the specified Python packages using the given package manager.
90
+
91
+ Args:
92
+ packages (Optional[List[str]]): A list of libraries to be upgraded. If None, all installed packages will be upgraded.
93
+ python_executable (str): The Python executable to use (e.g., 'python312').
94
+ package_manager (str): The package manager to use ('pip' or 'conda').
95
+
96
+ Raises:
97
+ ValueError: If 'packages' is not a list or None, or if 'package_manager' is not 'pip' or 'conda'.
98
+
99
+ Example:
100
+ >>> upgrade_packages(packages=["numpy", "pandas"], python_executable="python", package_manager="pip")
75
101
  """
76
102
  if not isinstance(packages, (list, type(None))):
77
- raise ValueError("The 'packages' parameter must be a list or None")
103
+ raise ValueError("[red]The 'packages' parameter must be a list or None[/red]")
78
104
 
79
105
  if package_manager not in ["pip", "conda"]:
80
- raise ValueError("The 'package_manager' parameter must be either 'pip' or 'conda'")
106
+ raise ValueError("[red]The 'package_manager' parameter must be either 'pip' or 'conda'[/red]")
81
107
 
82
108
  try:
83
109
  if package_manager == "conda":
@@ -86,13 +112,13 @@ def upgrade_packages(packages=None, python_executable="python", package_manager=
86
112
  packages = [pkg.split("=")[0] for pkg in installed_packages if not pkg.startswith("#")]
87
113
  for package in packages:
88
114
  os.system(f"conda update -c conda-forge {package} -y")
89
- print("Upgrade successful")
115
+ print("[green]Upgrade successful[/green]")
90
116
  else:
91
117
  if not packages:
92
118
  installed_packages = os.popen(f"{python_executable} -m pip list --format=freeze").read().splitlines()
93
119
  packages = [pkg.split("==")[0] for pkg in installed_packages]
94
120
  for package in packages:
95
121
  os.system(f"{python_executable} -m pip install --upgrade {package}")
96
- print("Upgrade successful")
122
+ print("[green]Upgrade successful[/green]")
97
123
  except Exception as e:
98
- print(f"Upgrade failed: {str(e)}")
124
+ print(f"[red]Upgrade failed: {str(e)}[/red]")