jacksung_dev-0.0.4.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. jacksung/__init__.py +1 -0
  2. jacksung/ai/GeoAttX.py +356 -0
  3. jacksung/ai/GeoNet/__init__.py +0 -0
  4. jacksung/ai/GeoNet/m_block.py +393 -0
  5. jacksung/ai/GeoNet/m_blockV2.py +442 -0
  6. jacksung/ai/GeoNet/m_network.py +107 -0
  7. jacksung/ai/GeoNet/m_networkV2.py +91 -0
  8. jacksung/ai/__init__.py +0 -0
  9. jacksung/ai/latex_tool.py +199 -0
  10. jacksung/ai/metrics.py +181 -0
  11. jacksung/ai/utils/__init__.py +0 -0
  12. jacksung/ai/utils/cmorph.py +42 -0
  13. jacksung/ai/utils/data_parallelV2.py +90 -0
  14. jacksung/ai/utils/fy.py +333 -0
  15. jacksung/ai/utils/goes.py +161 -0
  16. jacksung/ai/utils/gsmap.py +24 -0
  17. jacksung/ai/utils/imerg.py +159 -0
  18. jacksung/ai/utils/metsat.py +164 -0
  19. jacksung/ai/utils/norm_util.py +109 -0
  20. jacksung/ai/utils/util.py +300 -0
  21. jacksung/libs/times.ttf +0 -0
  22. jacksung/utils/__init__.py +1 -0
  23. jacksung/utils/base_db.py +72 -0
  24. jacksung/utils/cache.py +71 -0
  25. jacksung/utils/data_convert.py +273 -0
  26. jacksung/utils/exception.py +27 -0
  27. jacksung/utils/fastnumpy.py +115 -0
  28. jacksung/utils/figure.py +251 -0
  29. jacksung/utils/hash.py +26 -0
  30. jacksung/utils/image.py +221 -0
  31. jacksung/utils/log.py +86 -0
  32. jacksung/utils/login.py +149 -0
  33. jacksung/utils/mean_std.py +66 -0
  34. jacksung/utils/multi_task.py +129 -0
  35. jacksung/utils/number.py +6 -0
  36. jacksung/utils/nvidia.py +140 -0
  37. jacksung/utils/time.py +87 -0
  38. jacksung/utils/web.py +63 -0
  39. jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
  40. jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
  41. jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
  42. jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
  43. jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
  44. jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/ai/latex_tool.py ADDED
@@ -0,0 +1,199 @@
+ import os
+ from openai import OpenAI
+ from tqdm import tqdm
+ from jacksung.utils.time import Stopwatch, get_time_str
+
+
+ def get_en_polish_prompt(text, prompt_type='polish'):
+     polish_prompt = \
+         fr'''
+ # Rewrite the text in an academic writing style, using more appropriate vocabulary and sentence structure while keeping the original meaning unchanged:
+ - Make sure the rewritten version conveys the same information and intention as the original text.
+ - Please output the rewritten text directly in latex format, without including the original text, thinking logic, comments, explanations, etc.
+ - Do not output any control commands that do not exist in the input content (such as \documentclass, \begin, \end, etc.), just use latex format to output mathematical formulas, symbols, references, or other latex instructions contained in the input content.
+ - Note that special symbols and formulas are output in latex format, not directly output special characters.
+ - Only when the input content contains control codes such as \par, the code needs to be added to the corresponding position of the output content.
+ - If the input content only contains code and does not contain any substantial text content, the input content is directly output without any changes. Be careful not to miss symbols such as brackets.
+ - Make sure that the output content can be compiled normally after replacing the input content in the original document.
+ The following is the input content:
+ {text}
+ '''
+     check_prompt = \
+         fr'''
+ # Correct grammatical errors and misspelled words in the input text while preserving the original meaning and format:
+ - If the text is free of grammatical errors and spelling mistakes, output the original text without making any modifications
+ - Ensure the corrected version maintains the exact same information and intent as the original text
+ - Output only the corrected text directly without any additional content
+ - Preserve all mathematical formulas, symbols, and special formatting exactly as input
+ - Maintain the original document structure and formatting commands
+ - Only modify actual grammatical errors and misspelled words
+ - If the input contains LaTeX code, preserve it exactly and only correct text outside code blocks
+ - Do not alter technical terms, proper nouns, or specialized vocabulary
+ - If no errors are detected, output the original text unchanged
+ - Do not include any explanations, comments, or thinking process in the output
+
+ Input text:
+ {text}
+ '''
+     if prompt_type == 'polish':
+         return polish_prompt
+     elif prompt_type == 'check':
+         return check_prompt
+     else:
+         raise Exception(rf'Unknown prompt type {prompt_type}, please specify "polish" or "check"')
+
+
+ def get_cn_polish_prompt(text, prompt_type='polish'):
+     polish_prompt = \
+         fr'''
+ # 用学术写作风格重写下面的文本,在保持原本涵义不变的情况下使用更合适的词汇和句子结构:
+ - 确保改写后的版本传达的信息和意图与原文相同
+ - 请直接以latex格式输出重写后的文本,不需要包含原文、思考逻辑、注释、解释说明等其他内容。
+ - 不需要输出任何输入内容中不存在的控制命令(如\documentclass、\begin、\end等),只需要使用latex格式输出数学公式、符号、引用或者输入内容中所包含的其他latex指令。
+ - 注意特殊符号和公式以latex格式输出,而不是直接输出特殊字符。
+ - 仅在输入内容中包含\par类似的控制性代码时,在输出内容对应位置需要添加该代码。
+ - 如果输入内容仅包含代码,不包含任何实质性的文本内容,则直接将输入内容不做任何改动输出。注意不要遗漏括号等符号。
+ - 确保输出内容在原文档中替换输入内容后能够正常编译通过。
+ - 重要:你需要确定输入内容所使用的语言,然后使用相同语言进行输出,以保证输入和输出语言一致。
+ 以下为输入内容:
+ {text}
+ '''
+     check_prompt = \
+         fr'''
+ # 修正输入文本中的语法错误和单词拼写错误,同时保持原意和格式不变:
+ - 如果输入的文本没有语法错误和单词拼写错误则原文输出,不做任何改动
+ - 确保修正后的文本与原文的信息和意图完全相同
+ - 请直接输出修正后的文本,不需要包含原文、思考逻辑、注释、解释说明等其他内容
+ - 不需要输出任何输入内容中不存在的控制命令(如\documentclass、\begin、\end等),只需要使用latex格式输出数学公式、符号、引用或者输入内容中所包含的其他latex指令
+ - 注意特殊符号和公式以latex格式输出,而不是直接输出特殊字符
+ - 仅在输入内容中包含\par类似的控制性代码时,在输出内容对应位置需要添加该代码
+ - 如果输入内容仅包含代码,不包含任何实质性的文本内容,则直接将输入内容不做任何改动输出。注意不要遗漏括号等符号
+ - 确保输出内容在原文档中替换输入内容后能够正常编译通过
+ - 重要:你需要确定输入内容所使用的语言,然后使用相同语言进行输出,以保证输入和输出语言一致
+ 以下为输入内容:
+ {text}
+ '''
+     if prompt_type == 'polish':
+         return polish_prompt
+     elif prompt_type == 'check':
+         return check_prompt
+     else:
+         raise Exception(rf'Unknown prompt type {prompt_type}, please specify "polish" or "check"')
+
+
+ def merge_content(tex_dir, main_tex):
+     result_tex = ''
+     with open(os.path.join(tex_dir, main_tex), 'r', encoding='utf-8') as f:
+         while True:
+             line = f.readline()
+             if not line:
+                 break
+             if line.startswith(r'\input{') or line.startswith(r'\include{'):
+                 sub_tex_path = line.split('{')[1].split('}')[0]
+                 if not sub_tex_path.endswith('.tex'):
+                     sub_tex_path += r'.tex'
+                 result_tex += merge_content(tex_dir, sub_tex_path) + '\n'
+             else:
+                 result_tex += line
+     return result_tex
+
+
+ class AI:
+     def __init__(self, token, base_url, model_name='deepseek-r1:70b', prompt_type='polish'):
+         self.client = OpenAI(api_key=token, base_url=base_url)
+         self.model_name = model_name
+         self.prompt_type = prompt_type
+
+     def call_ai_polish(self, text, cn_prompt=False, prompt=None):
+         response = self.client.chat.completions.create(
+             model=self.model_name,
+             messages=[
+                 {"role": "user",
+                  "content": ((get_cn_polish_prompt(text, self.prompt_type) if cn_prompt else get_en_polish_prompt(
+                      text, self.prompt_type)) if prompt is None else prompt.replace('{content}', text))}
+             ],
+             temperature=0.6,
+             # max_tokens=1024,
+             stream=False
+         )
+         # Receive and handle the response chunk by chunk (kept for streaming mode):
+         # for chunk in response:
+         #     print(chunk.choices[0].delta.content, end='')
+         # print(response.choices[0].message.content)
+         content = response.choices[0].message.content
+         content = content.split('</think>')[-1].strip().replace('\n\n', ' ')  # drop the <think> block if the model emits one
+         if text.startswith(r'\par ') and not content.startswith(r'\par '):
+             print(rf'missing \par in polished text, prepending it to the text.')
+             content = r'\par ' + content
+         if text.startswith(r'\par{') and not content.startswith(r'\par{'):
+             print(rf'missing \par in polished text, prepending it to the text.')
+             content = r'\par{' + content
+         if text.endswith(r'}') and not content.endswith(r'}'):
+             print(r'missing } in polished text, appending it to the end of the text.')
+             content = content + r'}'
+         return content
+
+
+ def polish(main_dir_path, tex_file, server_url, token='Your token here', model_name='deepseek-r1:70b', cn_prompt=False,
+            prompt=None, rewrite_list=(r'\caption{', r'\par ', r'\par{'), skip_part_list=('figure', 'table', 'equation'),
+            ignore_length=100, prompt_type='polish'):
+     st = Stopwatch()
+     ai = AI(token=token, base_url=server_url, model_name=model_name, prompt_type=prompt_type)
+     result_tex = merge_content(main_dir_path, tex_file)
+     new_tex = ''
+     up_flag = False
+     result_split = result_tex.split('\n')
+     for idx, line in enumerate(result_split):
+         spend_count = Stopwatch()
+         line = line.strip()
+         line_up_flag = True
+         if line.startswith('%') or line.startswith('\\') or len(line) < ignore_length:
+             for flag in rewrite_list:
+                 if line.startswith(flag):
+                     line_up_flag = False
+                     break
+         else:
+             line_up_flag = False
+
+         for flag in skip_part_list:
+             if line.count(r'\begin{' + flag) > 0:
+                 up_flag = True
+                 break
+             if line.count(r'\end{' + flag) > 0:
+                 up_flag = False
+                 break
+
+         if up_flag or line_up_flag:
+             new_tex += line + '\n'
+         else:
+             try:
+                 print(rf'Processing line {idx}/{len(result_split)}, total elapsed {st.pinch()}, current time: {get_time_str()}')
+                 print(f'Input[{line[:100]}{"..." if len(line) > 100 else ""}]')
+                 polish_text = ai.call_ai_polish(line, cn_prompt, prompt)
+                 print(f'Polished:[{polish_text[:100]}{"..." if len(polish_text) > 100 else ""}]')
+                 print(rf'Done, took {spend_count.pinch()}, rewrote {len(line)} characters')
+                 new_tex += polish_text + '\n'
+             except Exception as e:
+                 tqdm.write(f'**e**{e}')
+                 new_tex += line + '\n'
+
+     with open(rf'{main_dir_path}\old.tex', 'w', encoding='utf-8') as f:
+         f.write(result_tex)
+     with open(rf'{main_dir_path}\new.tex', 'w', encoding='utf-8') as f:
+         f.write(new_tex)
+     write_diff(main_dir_path)
+
+
+ def write_diff(dir_path):
+     diff_tex = r'''\RequirePackage{shellesc}
+ \ShellEscape{pdfLatex new.tex} % compile the new document
+ \ShellEscape{pdfLatex old.tex} % compile the old document
+ \ShellEscape{latexdiff old.tex new.tex > diff_result.tex}
+ \input{diff_result}
+ \documentclass{dummy}'''
+     with open(rf'{dir_path}\diff.tex', 'w', encoding='utf-8') as f:
+         f.write(diff_tex)
+
+
+ if __name__ == "__main__":
+     pass
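
A minimal usage sketch for the polish() entry point above (not part of the package diff), assuming an OpenAI-compatible endpoint and a LaTeX project whose root file is main.tex; the directory, endpoint URL, and token below are placeholder values.

from jacksung.ai.latex_tool import polish

# Placeholder project directory, endpoint, and token (assumed values for illustration).
polish(
    main_dir_path=r'C:\papers\my_paper',     # folder holding the LaTeX sources
    tex_file='main.tex',                     # root file; \input/\include files are merged recursively
    server_url='http://localhost:11434/v1',  # any OpenAI-compatible API base URL
    token='sk-xxxx',                         # API key for that endpoint
    model_name='deepseek-r1:70b',
    cn_prompt=False,                         # True switches to the Chinese prompt
    prompt_type='polish',                    # or 'check' for grammar-only correction
)
# old.tex, new.tex and diff.tex are written into main_dir_path; compiling diff.tex
# runs latexdiff over old.tex/new.tex to produce a tracked-changes document.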
jacksung/ai/metrics.py ADDED
@@ -0,0 +1,181 @@
+ import os
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ import math
+ from pytorch_msssim import ssim
+ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+ from torchmetrics import R2Score, PearsonCorrCoef, AUROC
+ from torchmetrics.regression import MeanSquaredError
+ import importlib
+ from einops import rearrange
+ import cv2
+
+
+ def compute_rmse(da_fc, da_true):
+     error = da_fc - da_true
+     error = error ** 2
+     number = torch.sqrt(error.mean((-2, -1)))
+     return number.mean()
+
+
+ class Metrics:
+     def __init__(self):
+         self.psnr = PeakSignalNoiseRatio()
+         self.ssim = StructuralSimilarityIndexMeasure()
+         self.rr = R2Score()
+         self.p = PearsonCorrCoef()
+         self.AUROC = AUROC("binary")
+
+     def mask_nan(self, pred, target):
+         pred = pred.flatten()
+         target = target.flatten()
+         # Validity mask: True where both the prediction and the target are not NaN
+         valid_mask = ~(torch.isnan(pred) | torch.isnan(target))
+
+         # Drop the invalid elements
+         pred = pred[valid_mask]
+         target = target[valid_mask]
+
+         return pred, target
+
+     def calc_AUROC(self, preds, targets):
+         AUROC = 0
+         for i in range(len(preds)):
+             pred, target = self.mask_nan(preds[i], targets[i])
+             AUROC += self.AUROC(pred, target)
+         self.AUROC.reset()
+         return AUROC / len(preds)
+
+     def calc_psnr(self, preds, targets):
+         psnr = 0
+         for i in range(len(preds)):
+             psnr += self.psnr(rearrange(preds[i], '(b c) h w->b c h w', b=1),
+                               rearrange(targets[i], '(b c) h w->b c h w', b=1))
+         self.psnr.reset()
+         return psnr / len(preds)
+
+     def calc_ssim(self, preds, targets):
+         ssim = 0
+         for i in range(len(preds)):
+             ssim += self.ssim(rearrange(preds[i], '(b c) h w->b c h w', b=1),
+                               rearrange(targets[i], '(b c) h w->b c h w', b=1))
+         self.ssim.reset()
+         return ssim / len(preds)
+
+     def calc_rmse(self, preds, targets):
+         rmse = 0
+         for i in range(len(preds)):
+             rmse += compute_rmse(preds[i], targets[i])
+         return rmse / len(preds)
+
+     def calc_rr(self, preds, targets):
+         rr = 0
+         for i in range(len(preds)):
+             pred, tar = self.mask_nan(preds[i], targets[i])
+             rr += self.rr(pred, tar)
+         self.rr.reset()
+         return rr / len(preds)
+
+     def calc_p(self, preds, targets, exclude_zero=False):
+         p = 0
+         count = 0
+         for i in range(len(preds)):
+             pred, target = self.mask_nan(preds[i], targets[i])
+             if exclude_zero:
+                 mask = target != 0
+                 pred = pred[mask]
+                 target = target[mask]
+             if pred.var() == 0 or target.var() == 0:
+                 continue
+             count += 1
+             p += self.p(pred, target)
+         self.p.reset()
+         return p / count
+
+     def print_metrics(self, preds, targets, print_log=True):
+         rr = float(self.calc_rr(preds, targets))
+         p = float(self.calc_p(preds, targets))
+         rmse = float(self.calc_rmse(preds, targets))
+         ssim = float(self.calc_ssim(preds, targets))
+         psnr = float(self.calc_psnr(preds, targets))
+         if print_log:
+             print(rf'p: {p} rr: {rr} rmse: {rmse} ssim: {ssim} psnr: {psnr}')
+         return {'p': p, 'rr': rr, 'rmse': rmse, 'ssim': ssim, 'psnr': psnr}
+
+     def calculate_rain_metrics(self, preds, targets, threshold=0.1):
+         """
+         Compute the POD, FAR, ACC, and CSI scores for a batch of rainfall maps using flatten().
+
+         Args:
+             preds: predicted rainfall tensor of shape [num_samples, height, width]
+             targets: observed rainfall tensor with the same shape as preds
+             threshold: rainfall threshold for a "rain" event, default 0.1 mm
+
+         Returns:
+             POD, FAR, ACC, CSI: one score per sample (NaN where the denominator is zero)
+         """
+         # 1. Flatten the spatial dims of each sample ([num_samples, height, width] -> [num_samples, num_pixels])
+         preds_flat = preds.flatten(start_dim=1)  # flatten from dim 1, keeping the sample dimension
+         targets_flat = targets.flatten(start_dim=1)
+
+         # 2. Binarize (1 = rain, 0 = no rain)
+         preds_binary = (preds_flat >= threshold).float()
+         targets_binary = (targets_flat >= threshold).float()
+
+         # 3. Confusion-matrix terms, summed per sample
+         TP = torch.sum(preds_binary * targets_binary, dim=1)  # true positives per sample
+         FP = torch.sum(preds_binary * (1 - targets_binary), dim=1)
+         TN = torch.sum((1 - preds_binary) * (1 - targets_binary), dim=1)
+         FN = torch.sum((1 - preds_binary) * targets_binary, dim=1)
+
+         # 4. Compute the scores (zero denominators are handled below)
+         POD = TP / (TP + FN)
+         FAR = FP / (TP + FP)
+         ACC = (TP + TN) / (TP + FP + TN + FN)
+         CSI = TP / (TP + FP + FN)
+
+         # 5. Mark invalid values as NaN
+         POD = torch.where((TP + FN) == 0, torch.nan, POD)
+         FAR = torch.where((TP + FP) == 0, torch.nan, FAR)
+         ACC = torch.where((TP + FP + TN + FN) == 0, torch.nan, ACC)
+         CSI = torch.where((TP + FP + FN) == 0, torch.nan, CSI)
+
+         return POD, FAR, ACC, CSI
+
+
+ def img2tensor(img):
+     if type(img) == str:
+         img = cv2.imread(img, -1)
+     img = torch.from_numpy(img)
+     img = rearrange(img, '(b c h) w->b c h w', b=1, c=1)
+     return img
+
+
+ if __name__ == '__main__':
+     preds = torch.Tensor([[[0, 0, 1],
+                            [0, 1, 1],
+                            [1, 1, 1]
+                            ],
+                           [[0, 1, 1],
+                            [0, 1, 1],
+                            [0, 1, 1]
+                            ]])
+     # target = torch.rand(2, 1, 3, 3)
+     target = torch.Tensor([[[0, 1, 1],
+                             [0, 1, 1],
+                             [0, 1, 1]
+                             ],
+                            [[0, 1, 1],
+                             [0, 1, 1],
+                             [0, 0, 1]
+                             ]])
+     target[1, 0, 2] = torch.nan
+     # m = Metrics()
+     # print(m.calc_rr(preds, target))
+     m = Metrics()
+     AUROC = m.calc_AUROC(preds, target)
+     print(AUROC)
+     print(m.calculate_rain_metrics(preds, target))
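
A short usage sketch for the Metrics class above (not part of the package diff), assuming the image metrics receive float tensors shaped [samples, channels, height, width] and the rain metrics receive per-sample maps; the random tensors and shapes are illustrative only.

import torch
from jacksung.ai.metrics import Metrics

# Illustrative shapes (assumed): 4 samples, 1 channel, 32x32 pixels.
preds = torch.rand(4, 1, 32, 32)
targets = torch.rand(4, 1, 32, 32)

m = Metrics()
scores = m.print_metrics(preds, targets)  # {'p': ..., 'rr': ..., 'rmse': ..., 'ssim': ..., 'psnr': ...}
pod, far, acc, csi = m.calculate_rain_metrics(preds, targets, threshold=0.1)  # one score per sample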
jacksung/ai/utils/cmorph.py ADDED
@@ -0,0 +1,42 @@
+ import netCDF4 as nc
+ import numpy as np
+ from einops import rearrange, repeat
+ from jacksung.utils.data_convert import np2tif, get_transform_from_lonlat_matrices
+
+
+ def getNPfromHDF(hdf_path, lock=None, save_file=True, print_log=False):
+     if lock:
+         lock.acquire()
+     ds = nc.Dataset(hdf_path)
+     if lock:
+         lock.release()
+     np_data = np.array(ds['cmorph'][:]).astype(np.float32)
+     lon_array = np.array(ds['lon'][:]).astype(np.float32)
+     lat_array = np.array(ds['lat'][:]).astype(np.float32)
+     lon_dim = len(lon_array)
+     lat_dim = len(lat_array)
+     lon_array = repeat(lon_array, 'w -> w h', h=lat_dim)
+     lat_array = repeat(lat_array, 'h -> w h', w=lon_dim)
+     ds.close()
+     # np_data = rearrange(np_data[0], 'w h->h w')[::-1, :]
+     np_data[np_data < 0] = 0
+     # np_data = np_data[0] + np_data[1]
+     transform, avg_error = get_transform_from_lonlat_matrices(
+         lon_array=lon_array,
+         lat_array=lat_array,
+         gcp_density=20,  # the larger the spatial extent, the larger gcp_density should be
+         print_log=print_log
+     )
+     if save_file:
+         np2tif(np_data, save_path='np2tif_dir', out_name='CMORPH', dtype='float32', transform=transform)
+     return np_data, transform
+
+
+ if __name__ == '__main__':
+     data = getNPfromHDF(rf'C:\Users\ECNU\PycharmProjects\CMORPH_V1.0_ADJ_8km-30min_2022070203.nc')
+     # from datetime import datetime
+     #
+     # da = datetime.utcfromtimestamp(1656730800)
+     # print(da)
+     # da = datetime.utcfromtimestamp(1656732600)
+     # print(da)
jacksung/ai/utils/data_parallelV2.py ADDED
@@ -0,0 +1,90 @@
+ from torch.nn.parallel import DataParallel
+ import torch
+ from torch.nn.parallel._functions import Scatter
+ from torch.nn.parallel.parallel_apply import parallel_apply
+
+
+ def scatter(inputs, target_gpus, chunk_sizes, dim=0):
+     r"""
+     Slices tensors into approximately equal chunks and
+     distributes them across given GPUs. Duplicates
+     references to objects that are not tensors.
+     """
+
+     def scatter_map(obj):
+         if isinstance(obj, torch.Tensor):
+             try:
+                 return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
+             except:
+                 print('obj', obj.size())
+                 print('dim', dim)
+                 print('chunk_sizes', chunk_sizes)
+                 quit()
+         if isinstance(obj, tuple) and len(obj) > 0:
+             return list(zip(*map(scatter_map, obj)))
+         if isinstance(obj, list) and len(obj) > 0:
+             return list(map(list, zip(*map(scatter_map, obj))))
+         if isinstance(obj, dict) and len(obj) > 0:
+             return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+         return [obj for targets in target_gpus]
+
+     # After scatter_map is called, a scatter_map cell will exist. This cell
+     # has a reference to the actual function scatter_map, which has references
+     # to a closure that has a reference to the scatter_map cell (because the
+     # fn is recursive). To avoid this reference cycle, we set the function to
+     # None, clearing the cell
+     try:
+         return scatter_map(inputs)
+     finally:
+         scatter_map = None
+
+
+ def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
+     r"""Scatter with support for kwargs dictionary"""
+     inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
+     kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
+     if len(inputs) < len(kwargs):
+         inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+     elif len(kwargs) < len(inputs):
+         kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+     inputs = tuple(inputs)
+     kwargs = tuple(kwargs)
+     return inputs, kwargs
+
+
+ class BalancedDataParallel(DataParallel):
+     def __init__(self, gpu0_bsz, *args, **kwargs):
+         self.gpu0_bsz = gpu0_bsz  # number of samples per batch kept on GPU 0
+         super().__init__(*args, **kwargs)
+
+     def forward(self, *inputs, **kwargs):
+         if not self.device_ids:
+             return self.module(*inputs, **kwargs)
+         gpu_idx_start = 1 if self.gpu0_bsz == 0 else 0
+         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids[gpu_idx_start:])
+         if len(self.device_ids) == 1:
+             return self.module(*inputs[0], **kwargs[0])
+         replicas = self.replicate(self.module, self.device_ids[:len(inputs) + gpu_idx_start])
+         replicas = replicas[gpu_idx_start:]
+         outputs = self.parallel_apply(replicas, self.device_ids[gpu_idx_start:], inputs, kwargs)
+         return self.gather(outputs, self.output_device)
+
+     def parallel_apply(self, replicas, device_ids, inputs, kwargs):
+         return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
+
+     def scatter(self, inputs, kwargs, device_ids):
+         bsz = inputs[0].size(self.dim)
+         num_dev = len(self.device_ids)
+         gpu0_bsz = self.gpu0_bsz
+         bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
+         if gpu0_bsz < bsz_unit:
+             chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
+             delta = bsz - sum(chunk_sizes)
+             for i in range(delta):
+                 chunk_sizes[i + 1] += 1
+             if gpu0_bsz == 0:
+                 chunk_sizes = chunk_sizes[1:]
+         else:
+             return super().scatter(inputs, kwargs, device_ids)
+
+         return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
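
A minimal sketch of wrapping a model with BalancedDataParallel above (not part of the package diff), assuming at least two visible CUDA GPUs; the linear model and batch are placeholders. gpu0_bsz is the number of samples kept on GPU 0, with the remainder of the batch split evenly across the other devices.

import torch
import torch.nn as nn
from jacksung.ai.utils.data_parallelV2 import BalancedDataParallel

model = nn.Linear(64, 10).cuda()               # placeholder module
model = BalancedDataParallel(4, model, dim=0)  # keep 4 samples per step on GPU 0

x = torch.randn(20, 64).cuda()                 # batch of 20 samples
y = model(x)                                   # scattered as [4, 16] on 2 GPUs, then gathered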