qubitclient 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. qubitclient/__init__.py +5 -0
  2. qubitclient/draw/__init__.py +0 -0
  3. qubitclient/draw/optpipulsepltplotter.py +75 -0
  4. qubitclient/draw/optpipulseplyplotter.py +114 -0
  5. qubitclient/draw/pltmanager.py +50 -0
  6. qubitclient/draw/pltplotter.py +20 -0
  7. qubitclient/draw/plymanager.py +57 -0
  8. qubitclient/draw/plyplotter.py +21 -0
  9. qubitclient/draw/powershiftpltplotter.py +108 -0
  10. qubitclient/draw/powershiftplyplotter.py +194 -0
  11. qubitclient/draw/rabicospltplotter.py +74 -0
  12. qubitclient/draw/rabicosplyplotter.py +90 -0
  13. qubitclient/draw/rabipltplotter.py +66 -0
  14. qubitclient/draw/rabiplyplotter.py +86 -0
  15. qubitclient/draw/s21peakpltplotter.py +67 -0
  16. qubitclient/draw/s21peakplyplotter.py +124 -0
  17. qubitclient/draw/s21vfluxpltplotter.py +84 -0
  18. qubitclient/draw/s21vfluxplyplotter.py +163 -0
  19. qubitclient/draw/singleshotpltplotter.py +149 -0
  20. qubitclient/draw/singleshotplyplotter.py +324 -0
  21. qubitclient/draw/spectrum2dpltplotter.py +107 -0
  22. qubitclient/draw/spectrum2dplyplotter.py +244 -0
  23. qubitclient/draw/spectrum2dscopepltplotter.py +72 -0
  24. qubitclient/draw/spectrum2dscopeplyplotter.py +195 -0
  25. qubitclient/draw/spectrumpltplotter.py +106 -0
  26. qubitclient/draw/spectrumplyplotter.py +133 -0
  27. qubitclient/draw/t1fitpltplotter.py +76 -0
  28. qubitclient/draw/t1fitplyplotter.py +109 -0
  29. qubitclient/draw/t2fitpltplotter.py +70 -0
  30. qubitclient/draw/t2fitplyplotter.py +111 -0
  31. qubitclient/nnscope/nnscope.py +51 -0
  32. qubitclient/nnscope/nnscope_api/curve/__init__.py +0 -0
  33. qubitclient/nnscope/nnscope_api/curve/curve_type.py +15 -0
  34. qubitclient/nnscope/task.py +170 -0
  35. qubitclient/nnscope/utils/data_convert.py +114 -0
  36. qubitclient/nnscope/utils/data_parser.py +41 -0
  37. qubitclient/nnscope/utils/request_tool.py +41 -0
  38. qubitclient/nnscope/utils/result_parser.py +55 -0
  39. qubitclient/scope/scope.py +50 -0
  40. qubitclient/scope/scope_api/__init__.py +8 -0
  41. qubitclient/scope/scope_api/api/__init__.py +1 -0
  42. qubitclient/scope/scope_api/api/defined_tasks/__init__.py +1 -0
  43. qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_demo_pk_get.py +155 -0
  44. qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_scope_pk_get.py +155 -0
  45. qubitclient/scope/scope_api/api/defined_tasks/optpipulse_api_v1_tasks_scope_optpipulse_post.py +218 -0
  46. qubitclient/scope/scope_api/api/defined_tasks/powershift_api_v1_tasks_scope_powershift_post.py +218 -0
  47. qubitclient/scope/scope_api/api/defined_tasks/rabi_api_v1_tasks_scope_rabi_post.py +218 -0
  48. qubitclient/scope/scope_api/api/defined_tasks/rabicos_api_v1_tasks_scope_rabicospeak_post.py +218 -0
  49. qubitclient/scope/scope_api/api/defined_tasks/s21peak_api_v1_tasks_scope_s21peak_post.py +218 -0
  50. qubitclient/scope/scope_api/api/defined_tasks/s21vflux_api_v1_tasks_scope_s21vflux_post.py +218 -0
  51. qubitclient/scope/scope_api/api/defined_tasks/singleshot_api_v1_tasks_scope_singleshot_post.py +218 -0
  52. qubitclient/scope/scope_api/api/defined_tasks/spectrum2d_api_v1_tasks_scope_spectrum2d_post.py +218 -0
  53. qubitclient/scope/scope_api/api/defined_tasks/spectrum_api_v1_tasks_scope_spectrum_post.py +218 -0
  54. qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t1fit_post.py +218 -0
  55. qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t2fit_post.py +218 -0
  56. qubitclient/scope/scope_api/client.py +268 -0
  57. qubitclient/scope/scope_api/errors.py +16 -0
  58. qubitclient/scope/scope_api/models/__init__.py +31 -0
  59. qubitclient/scope/scope_api/models/body_optpipulse_api_v1_tasks_scope_optpipulse_post.py +83 -0
  60. qubitclient/scope/scope_api/models/body_powershift_api_v1_tasks_scope_powershift_post.py +83 -0
  61. qubitclient/scope/scope_api/models/body_rabi_api_v1_tasks_scope_rabi_post.py +83 -0
  62. qubitclient/scope/scope_api/models/body_rabicos_api_v1_tasks_scope_rabicospeak_post.py +83 -0
  63. qubitclient/scope/scope_api/models/body_s21_peak_api_v1_tasks_scope_s21_peak_post.py +83 -0
  64. qubitclient/scope/scope_api/models/body_s21_vflux_api_v1_tasks_scope_s21_vflux_post.py +83 -0
  65. qubitclient/scope/scope_api/models/body_singleshot_api_v1_tasks_scope_singleshot_post.py +83 -0
  66. qubitclient/scope/scope_api/models/body_spectrum_2d_api_v1_tasks_scope_spectrum_2d_post.py +83 -0
  67. qubitclient/scope/scope_api/models/body_spectrum_api_v1_tasks_scope_spectrum_post.py +83 -0
  68. qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t1_fit_post.py +83 -0
  69. qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t2_fit_post.py +83 -0
  70. qubitclient/scope/scope_api/models/http_validation_error.py +75 -0
  71. qubitclient/scope/scope_api/models/validation_error.py +88 -0
  72. qubitclient/scope/scope_api/types.py +54 -0
  73. qubitclient/scope/task.py +163 -0
  74. qubitclient/scope/utils/__init__.py +0 -0
  75. qubitclient/scope/utils/data_parser.py +20 -0
  76. qubitclient-0.1.4.dist-info/METADATA +173 -0
  77. qubitclient-0.1.4.dist-info/RECORD +81 -0
  78. qubitclient-0.1.4.dist-info/WHEEL +5 -0
  79. qubitclient-0.1.4.dist-info/licenses/LICENSE +674 -0
  80. qubitclient-0.1.4.dist-info/top_level.txt +1 -0
  81. qubitclient-0.1.4.dist-info/zip-safe +1 -0
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from .pltplotter import QuantumDataPltPlotter
4
+
5
class Spectrum2DDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for 2-D spectrum (bias vs. frequency) results.

    Every result gets two side-by-side panels drawn from the same magnitude
    map: the left panel shows the map alone, the right panel overlays the
    result's line points, each group labelled with its confidence.
    """

    def __init__(self):
        super().__init__("spectrum2d")

    def plot_result_npy(self, **kwargs):
        """Plot results whose input data came from an ``.npy`` payload.

        Keyword Args:
            results: sequence of per-qubit result dicts; each must provide
                ``"linepoints_list"`` and ``"confidence_list"``.
            data_ndarray: a dict, or an ndarray whose ``.item()`` is a dict,
                with an ``"image"`` entry mapping qubit name to a sequence
                ``(data, bias, frequency)``.

        Returns:
            The ``matplotlib.figure.Figure`` that was drawn.

        Raises:
            ValueError: if any qubit's ``data`` is not two-dimensional.
        """
        results = kwargs.get('results')
        data_ndarray = kwargs.get('data_ndarray')

        # Parse the input before creating the figure so a malformed payload
        # does not leave an empty figure registered with pyplot.
        data_dict = data_ndarray.item() if isinstance(data_ndarray, np.ndarray) else data_ndarray
        image = data_dict['image']
        # The "image" entry may itself be wrapped in an object ndarray.
        image = image.item() if isinstance(image, np.ndarray) else image

        dict_list = []
        for q_name in image.keys():
            image_q = image[q_name]
            # Convert first: a plain nested list has no ``ndim`` attribute.
            data = np.asarray(image_q[0])
            if data.ndim != 2:
                raise ValueError("数据格式无效,data不是二维数组")
            dict_list.append({
                'bias': image_q[1],
                'frequency': image_q[2],
                'iq_avg': np.abs(data),
                'name': q_name,
            })

        names = [entry['name'] for entry in dict_list]
        return self._plot_pairs(results, dict_list, names)

    def plot_result_npz(self, **kwargs):
        """Plot results whose input data was already unpacked from ``.npz`` files.

        Keyword Args:
            results: sequence of per-file result dicts; each must provide
                ``"linepoints_list"`` and ``"confidence_list"``.
            dict_list: sequence of dicts with ``"bias"``, ``"frequency"`` and
                ``"iq_avg"`` entries, aligned with ``results``.
            file_names: sequence of names used for the panel titles, aligned
                with ``results``.

        Returns:
            The ``matplotlib.figure.Figure`` that was drawn.
        """
        results = kwargs.get('results')
        dict_list = kwargs.get('dict_list')
        file_names = kwargs.get('file_names')
        return self._plot_pairs(results, dict_list, file_names)

    @staticmethod
    def _plot_pairs(results, dict_list, names):
        """Draw a (map, map + points) panel pair per result and return the figure.

        ``results[k]``, ``dict_list[k]`` and ``names[k]`` describe the same
        qubit/file; ``dict_list[k]`` must provide ``"bias"``, ``"frequency"``
        and ``"iq_avg"``.
        """
        nums = len(results) * 2
        row = nums // 2  # nums is always even: one row per result
        col = min(nums, 2)

        fig = plt.figure(figsize=(10 * col, 4 * row))
        for index in range(nums):
            ax = fig.add_subplot(row, col, index + 1)
            result = results[index // 2]
            entry = dict_list[index // 2]

            mesh = ax.pcolormesh(entry["bias"], entry["frequency"], entry["iq_avg"],
                                 shading='auto', cmap='viridis')
            fig.colorbar(mesh, ax=ax, label='IQ Average')

            # Odd indices are the right-hand panels, which carry the overlay.
            if index % 2 != 0:
                points_list = result["linepoints_list"]
                confidences = result["confidence_list"]
                colors = plt.cm.rainbow(np.linspace(0, 1, len(points_list)))
                for i, points in enumerate(points_list):
                    points = np.asarray(points)
                    if points.size == 0:
                        # An empty group has shape (0,), so [:, 0] would raise.
                        continue
                    ax.scatter(points[:, 0], points[:, 1], color=colors[i],
                               label=f'XY Points{i}-conf:{round(confidences[i], 2)}',
                               s=5, alpha=0.1)
                # Only build a legend when something was labelled; calling
                # legend() with no labelled artists just emits a warning.
                if ax.get_legend_handles_labels()[0]:
                    ax.legend()

            ax.set_title(f"File: {names[index // 2]}")
            ax.set_xlabel("Bias")
            ax.set_ylabel("Frequency (GHz)")
        fig.tight_layout()
        return fig
@@ -0,0 +1,244 @@
1
+ from .plyplotter import QuantumDataPlyPlotter
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
+
6
class Spectrum2DDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for 2-D spectrum (bias vs. frequency) results.

    Every result gets two side-by-side heatmap panels of the same magnitude
    map; the right-hand panel additionally carries the result's line points,
    each group named with its confidence.
    """

    def __init__(self):
        super().__init__("spectrum2d")

    def plot_result_npy(self, **kwargs):
        """Plot results whose input data came from an ``.npy`` payload.

        Keyword Args:
            results: sequence of per-qubit result dicts; each must provide
                ``"linepoints_list"`` and ``"confidence_list"``.
            data_ndarray: a dict, or an ndarray whose ``.item()`` is a dict,
                with an ``"image"`` entry (itself possibly an object ndarray)
                mapping qubit name to a sequence ``(data, bias, frequency)``.

        Returns:
            The plotly ``Figure`` that was built.

        Raises:
            ValueError: if a required argument is missing or any qubit's
                ``data`` is not two-dimensional.
        """
        results = kwargs.get('results')
        data_ndarray = kwargs.get('data_ndarray')

        # Validate required arguments.
        if results is None:
            raise ValueError("缺少必需的 'results' 参数")
        if data_ndarray is None:
            raise ValueError("缺少必需的 'data_ndarray' 参数")

        # Unwrap the payload: both the top level and its "image" entry may be
        # object ndarrays holding a dict.
        data_dict = data_ndarray.item() if isinstance(data_ndarray, np.ndarray) else data_ndarray
        image = data_dict['image']
        image = image.item() if isinstance(image, np.ndarray) else image

        dict_list = []
        for q_name in image.keys():
            image_q = image[q_name]
            # Convert first: a plain nested list has no ``ndim`` attribute.
            data = np.asarray(image_q[0])
            if data.ndim != 2:
                raise ValueError("数据格式无效,data不是二维数组")
            dict_list.append({
                'bias': image_q[1],
                'frequency': image_q[2],
                'iq_avg': np.abs(data),
                'name': q_name,
            })

        names = [entry['name'] for entry in dict_list]
        return self._plot_pairs(results, dict_list, names, vertical_spacing=0.01)

    def plot_result_npz(self, **kwargs):
        """Plot results whose input data was already unpacked from ``.npz`` files.

        Keyword Args:
            results: sequence of per-file result dicts; each must provide
                ``"linepoints_list"`` and ``"confidence_list"``.
            dict_list: sequence of dicts with ``"bias"``, ``"frequency"`` and
                ``"iq_avg"`` entries, aligned with ``results``.
            file_names: sequence of names used for the panel titles, aligned
                with ``results``.

        Returns:
            The plotly ``Figure`` that was built.
        """
        results = kwargs.get('results')
        dict_list = kwargs.get('dict_list')
        file_names = kwargs.get('file_names')
        return self._plot_pairs(results, dict_list, file_names, vertical_spacing=0.015)

    @staticmethod
    def _plot_pairs(results, dict_list, names, vertical_spacing):
        """Build the two-column heatmap figure shared by both entry points.

        ``results[k]``, ``dict_list[k]`` and ``names[k]`` describe the same
        qubit/file; ``vertical_spacing`` is forwarded to ``make_subplots``.
        """
        nums = len(results) * 2
        rows = nums // 2  # nums is always even: one row per result
        cols = min(nums, 2)

        subplot_titles = []
        for name in names[:len(results)]:
            # Both panels of a pair share the same title.
            subplot_titles += [f"File: {name}", f"File: {name}"]

        fig = make_subplots(
            rows=rows, cols=cols,
            subplot_titles=subplot_titles,
            vertical_spacing=vertical_spacing,
            horizontal_spacing=0.1,
            x_title="Bias",
            y_title="Frequency (GHz)",
        )

        for index in range(nums):
            row = index // cols + 1
            col = index % cols + 1
            result = results[index // 2]
            data = dict_list[index // 2]

            fig.add_trace(go.Heatmap(
                z=data["iq_avg"],
                x=data["bias"],
                y=data["frequency"],
                colorscale="Viridis",
                colorbar=dict(
                    title="IQ Average",
                    thickness=10,
                    len=0.7,
                    yanchor="middle",
                    y=0.5,
                ),
                showscale=(index == 0),  # only the first panel shows a colour bar
            ), row=row, col=col)

            # Even indices are the plain left-hand panels: no overlay.
            if index % 2 == 0:
                continue

            points_list = [np.array(points) for points in result["linepoints_list"]]
            colors = np.linspace(0, 1, len(points_list))
            for i, points in enumerate(points_list):
                if len(points) == 0:
                    continue
                fig.add_trace(go.Scatter(
                    x=points[:, 0],
                    y=points[:, 1],
                    mode="markers",
                    marker=dict(
                        # A per-point numeric array with a fixed cmin/cmax maps
                        # every group to its own hue of the Rainbow scale.
                        color=np.full(len(points), colors[i]),
                        colorscale="Rainbow",
                        cmin=0,
                        cmax=1,
                        size=5,
                        opacity=0.1,
                        showscale=False,
                    ),
                    name=f'XY Points{i}-conf:{round(result["confidence_list"][i], 2)}',
                    legendgroup=f"group{index // 2}",
                    # Index 1 is the first panel that actually receives scatter
                    # traces, so its groups are the ones listed in the legend.
                    showlegend=(index == 1),
                ), row=row, col=col)

        fig.update_layout(
            height=500 * rows,
            width=900 * cols,
            margin=dict(r=60, t=60, b=60, l=60),
            legend=dict(
                font=dict(family="Courier", size=12, color="black"),
                borderwidth=1,
            ),
        )

        fig.update_xaxes(
            title_text="Bias",
            title_font=dict(size=10),  # smaller axis-title font
            title_standoff=8,  # gap between title and axis, in pixels
        )
        fig.update_yaxes(
            title_text="Frequency (GHz)",
            title_font=dict(size=10),
            title_standoff=8,
        )
        return fig
@@ -0,0 +1,72 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from .pltplotter import QuantumDataPltPlotter
4
+
5
class Spectrum2DScopeDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for 2-D spectrum "scope" results.

    Every qubit gets two panels of the same magnitude map; the right-hand one
    additionally shows the fitted straight lines and cosine curves together
    with their confidence (and, for cosines, compression) values.
    """

    def __init__(self):
        super().__init__("spectrum2dscope")

    def plot_result_npy(self, **kwargs):
        """Plot scope results for data loaded from an ``.npy`` payload.

        Keyword Args:
            result: dict with per-qubit sequences under ``"params"``,
                ``"confs"``, ``"coscompress_list"``, ``"lines_list"`` and
                ``"lineconfs_list"``, indexed in the same order as the image
                entries. ``results`` is accepted as an alias, matching the key
                read by the other plotters.
            dict_param: a dict, or an ndarray whose ``.item()`` is a dict, with
                an ``"image"`` entry mapping qubit name to ``(s, volt, freq)``.

        Returns:
            The ``matplotlib.figure.Figure`` that was drawn.
        """
        results = kwargs.get('result')
        if results is None:
            results = kwargs.get('results')
        dict_param = kwargs.get('dict_param')

        # Accept either the raw dict or the 0-d object array np.load returns.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]

        q_name_list, volt_list, freq_list, s_list = [], [], [], []
        for q_name in image.keys():
            image_q = image[q_name]
            q_name_list.append(q_name)
            volt_list.append(image_q[1])
            freq_list.append(image_q[2])
            s_list.append(np.abs(image_q[0]))

        coslines_list = results['params']
        cosconfs_list = results['confs']
        coscompress_list = results['coscompress_list']
        lines_list = results['lines_list']
        lineconfs_list = results['lineconfs_list']

        nums = len(volt_list) * 2
        row = nums // 2  # nums is always even: one row per qubit
        col = min(nums, 2)

        fig = plt.figure(figsize=(5 * col, 4 * row))
        for ii in range(nums):
            ax = fig.add_subplot(row, col, ii + 1)
            k = ii // 2
            volt = volt_list[k]
            freq = freq_list[k]
            ax.pcolormesh(volt, freq, s_list[k], cmap='viridis', shading='auto')

            # Odd indices are the right-hand panels, which carry the fits.
            if ii % 2 != 0:
                lines = lines_list[k]
                lineconfs = lineconfs_list[k]
                coslines = coslines_list[k]
                cosconfs = cosconfs_list[k]
                coscompress = coscompress_list[k]

                if lines:
                    for j, line in enumerate(lines):
                        ax.plot([item[0] for item in line], [item[1] for item in line], c='r')
                        ax.text(volt[len(volt) // 2], freq[len(freq) // 2],
                                f'confidence: {lineconfs[j]:.2f}', c='red', size=15)

                if coslines:
                    for j, cosline in enumerate(coslines):
                        ax.plot([item[0] for item in cosline], [item[1] for item in cosline], c='r')
                        ax.text(volt[len(volt) // 2], freq[len(freq) // 2],
                                f'confidence: {cosconfs[j]:.2f}\ncompress: {coscompress[j]:.2f}',
                                c='red', size=15)

            ax.set_title(f"{q_name_list[k]}")
        fig.tight_layout()
        return fig
@@ -0,0 +1,195 @@
1
+ from .plyplotter import QuantumDataPlyPlotter
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
+
6
class Spectrum2DScopeDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for 2-D spectrum "scope" results.

    Every qubit gets two heatmap panels of the same magnitude map; the
    right-hand one additionally shows the fitted straight lines and cosine
    curves, annotated with their confidence (and compression) values.
    """

    def __init__(self):
        super().__init__("spectrum2dscope")

    def plot_result_npy(self, **kwargs):
        """Plot scope results for data loaded from an ``.npy`` payload.

        Keyword Args:
            result: dict with per-qubit sequences under ``"params"``,
                ``"confs"``, ``"coscompress_list"``, ``"lines_list"`` and
                ``"lineconfs_list"``, indexed in the same order as the image
                entries. ``results`` is accepted as an alias, matching the key
                read by the other plotters.
            dict_param: a dict, or an ndarray whose ``.item()`` is a dict, with
                an ``"image"`` entry mapping qubit name to ``(s, volt, freq)``.

        Returns:
            The plotly ``Figure`` that was built.
        """
        results = kwargs.get('result')
        if results is None:
            results = kwargs.get('results')
        dict_param = kwargs.get('dict_param')

        # Accept either the raw dict or the 0-d object array np.load returns.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]

        # Per-qubit input data.
        q_name_list, volt_list, freq_list, s_list = [], [], [], []
        for q_name in image.keys():
            image_q = image[q_name]
            q_name_list.append(q_name)
            volt_list.append(image_q[1])
            freq_list.append(image_q[2])
            s_list.append(np.abs(image_q[0]))

        # Per-qubit fit results.
        coslines_list = results['params']
        cosconfs_list = results['confs']
        coscompress_list = results['coscompress_list']
        lines_list = results['lines_list']
        lineconfs_list = results['lineconfs_list']

        nums = len(volt_list) * 2
        rows = nums // 2  # nums is always even: one row per qubit
        cols = min(nums, 2)

        # make_subplots rejects vertical_spacing > 1 / (rows - 1); clamp the
        # preferred 0.01 so figures with very many qubits can still be built.
        vertical_spacing = min(0.01, 1 / (rows - 1)) if rows > 1 else 0.01

        fig = make_subplots(
            rows=rows,
            cols=cols,
            vertical_spacing=vertical_spacing,
            horizontal_spacing=0.1,
            subplot_titles=[f"{q_name_list[ii // 2]}_Heatmap" if ii % 2 == 0
                            else f"{q_name_list[ii // 2]}_WithCurves" for ii in range(nums)],
        )

        for ii in range(nums):
            row_pos = ii // cols + 1
            col_pos = ii % cols + 1
            k = ii // 2
            volt = volt_list[k]
            freq = freq_list[k]

            # Pin each colour bar beside its own subplot. Left unpositioned,
            # every heatmap's bar is drawn at the same spot on the right edge
            # and they hide one another.
            subplot = fig.get_subplot(row_pos, col_pos)
            x_dom = subplot.xaxis.domain
            y_dom = subplot.yaxis.domain

            fig.add_trace(go.Heatmap(
                x=volt,
                y=freq,
                z=s_list[k],
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(
                    title='Intensity',
                    x=x_dom[1] + 0.005,
                    xanchor='left',
                    y=(y_dom[0] + y_dom[1]) / 2,
                    yanchor='middle',
                    len=y_dom[1] - y_dom[0],
                ),
                hovertemplate=(
                    'Volt: %{x}<br>' +
                    'Freq: %{y}<br>' +
                    'Intensity: %{z}<extra></extra>'
                ),
            ), row=row_pos, col=col_pos)

            # Even indices are the plain left-hand panels: no curves.
            if ii % 2 == 0:
                continue

            # Label anchor at the middle of each axis. volt and freq are
            # indexed by their own lengths; they need not be the same size.
            anchor = None
            if len(volt) > 0 and len(freq) > 0:
                anchor = (volt[len(volt) // 2], freq[len(freq) // 2])

            lines = lines_list[k]
            lineconfs = lineconfs_list[k]
            if lines:
                for j, line in enumerate(lines):
                    if line:
                        self._add_curve(
                            fig, line, f'Line {j + 1}',
                            f'Confidence: {lineconfs[j]:.2f}',
                            f"conf: {lineconfs[j]:.2f}",
                            anchor, row_pos, col_pos,
                        )

            coslines = coslines_list[k]
            cosconfs = cosconfs_list[k]
            coscompress = coscompress_list[k]
            if coslines:
                for j, cosline in enumerate(coslines):
                    if cosline:
                        self._add_curve(
                            fig, cosline, f'Cosine {j + 1}',
                            f'Confidence: {cosconfs[j]:.2f}<br>Compress: {coscompress[j]:.2f}',
                            f"conf: {cosconfs[j]:.2f}<br>compress: {coscompress[j]:.2f}",
                            anchor, row_pos, col_pos,
                        )

        fig.update_layout(
            height=500 * rows,
            width=900 * cols,
            margin=dict(r=60, t=60, b=60, l=60),
            legend=dict(
                font=dict(family="Courier", size=12, color="black"),
                borderwidth=1,
            ),
        )

        fig.update_xaxes(
            title_text="Bias",
            title_font=dict(size=10),  # smaller axis-title font
            title_standoff=8,  # gap between title and axis, in pixels
        )
        fig.update_yaxes(
            title_text="Frequency (GHz)",
            title_font=dict(size=10),
            title_standoff=8,
        )

        return fig

    @staticmethod
    def _add_curve(fig, curve, name, hover_text, label, anchor, row, col):
        """Draw one red curve from ``(x, y)`` pairs and, if ``anchor`` is set, its label.

        ``hover_text`` is appended to the Volt/Freq hover lines; ``label`` is
        the annotation text placed at ``anchor`` (an ``(x, y)`` pair).
        """
        fig.add_trace(go.Scatter(
            x=[item[0] for item in curve],
            y=[item[1] for item in curve],
            mode='lines',
            line=dict(color='red', width=3),
            name=name,
            showlegend=False,
            hovertemplate='Volt: %{x}<br>' + 'Freq: %{y}<br>' + hover_text + '<extra></extra>',
        ), row=row, col=col)

        if anchor is not None:
            fig.add_annotation(
                x=anchor[0],
                y=anchor[1],
                text=label,
                showarrow=False,
                font=dict(color='red', size=12),
                bgcolor='rgba(255,255,255,0.8)',
                row=row,
                col=col,
            )