qubitclient 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qubitclient/__init__.py +5 -0
- qubitclient/draw/__init__.py +0 -0
- qubitclient/draw/optpipulsepltplotter.py +75 -0
- qubitclient/draw/optpipulseplyplotter.py +114 -0
- qubitclient/draw/pltmanager.py +50 -0
- qubitclient/draw/pltplotter.py +20 -0
- qubitclient/draw/plymanager.py +57 -0
- qubitclient/draw/plyplotter.py +21 -0
- qubitclient/draw/powershiftpltplotter.py +108 -0
- qubitclient/draw/powershiftplyplotter.py +194 -0
- qubitclient/draw/rabicospltplotter.py +74 -0
- qubitclient/draw/rabicosplyplotter.py +90 -0
- qubitclient/draw/rabipltplotter.py +66 -0
- qubitclient/draw/rabiplyplotter.py +86 -0
- qubitclient/draw/s21peakpltplotter.py +67 -0
- qubitclient/draw/s21peakplyplotter.py +124 -0
- qubitclient/draw/s21vfluxpltplotter.py +84 -0
- qubitclient/draw/s21vfluxplyplotter.py +163 -0
- qubitclient/draw/singleshotpltplotter.py +149 -0
- qubitclient/draw/singleshotplyplotter.py +324 -0
- qubitclient/draw/spectrum2dpltplotter.py +107 -0
- qubitclient/draw/spectrum2dplyplotter.py +244 -0
- qubitclient/draw/spectrum2dscopepltplotter.py +72 -0
- qubitclient/draw/spectrum2dscopeplyplotter.py +195 -0
- qubitclient/draw/spectrumpltplotter.py +106 -0
- qubitclient/draw/spectrumplyplotter.py +133 -0
- qubitclient/draw/t1fitpltplotter.py +76 -0
- qubitclient/draw/t1fitplyplotter.py +109 -0
- qubitclient/draw/t2fitpltplotter.py +70 -0
- qubitclient/draw/t2fitplyplotter.py +111 -0
- qubitclient/nnscope/nnscope.py +51 -0
- qubitclient/nnscope/nnscope_api/curve/__init__.py +0 -0
- qubitclient/nnscope/nnscope_api/curve/curve_type.py +15 -0
- qubitclient/nnscope/task.py +170 -0
- qubitclient/nnscope/utils/data_convert.py +114 -0
- qubitclient/nnscope/utils/data_parser.py +41 -0
- qubitclient/nnscope/utils/request_tool.py +41 -0
- qubitclient/nnscope/utils/result_parser.py +55 -0
- qubitclient/scope/scope.py +50 -0
- qubitclient/scope/scope_api/__init__.py +8 -0
- qubitclient/scope/scope_api/api/__init__.py +1 -0
- qubitclient/scope/scope_api/api/defined_tasks/__init__.py +1 -0
- qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_demo_pk_get.py +155 -0
- qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_scope_pk_get.py +155 -0
- qubitclient/scope/scope_api/api/defined_tasks/optpipulse_api_v1_tasks_scope_optpipulse_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/powershift_api_v1_tasks_scope_powershift_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/rabi_api_v1_tasks_scope_rabi_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/rabicos_api_v1_tasks_scope_rabicospeak_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/s21peak_api_v1_tasks_scope_s21peak_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/s21vflux_api_v1_tasks_scope_s21vflux_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/singleshot_api_v1_tasks_scope_singleshot_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/spectrum2d_api_v1_tasks_scope_spectrum2d_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/spectrum_api_v1_tasks_scope_spectrum_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t1fit_post.py +218 -0
- qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t2fit_post.py +218 -0
- qubitclient/scope/scope_api/client.py +268 -0
- qubitclient/scope/scope_api/errors.py +16 -0
- qubitclient/scope/scope_api/models/__init__.py +31 -0
- qubitclient/scope/scope_api/models/body_optpipulse_api_v1_tasks_scope_optpipulse_post.py +83 -0
- qubitclient/scope/scope_api/models/body_powershift_api_v1_tasks_scope_powershift_post.py +83 -0
- qubitclient/scope/scope_api/models/body_rabi_api_v1_tasks_scope_rabi_post.py +83 -0
- qubitclient/scope/scope_api/models/body_rabicos_api_v1_tasks_scope_rabicospeak_post.py +83 -0
- qubitclient/scope/scope_api/models/body_s21_peak_api_v1_tasks_scope_s21_peak_post.py +83 -0
- qubitclient/scope/scope_api/models/body_s21_vflux_api_v1_tasks_scope_s21_vflux_post.py +83 -0
- qubitclient/scope/scope_api/models/body_singleshot_api_v1_tasks_scope_singleshot_post.py +83 -0
- qubitclient/scope/scope_api/models/body_spectrum_2d_api_v1_tasks_scope_spectrum_2d_post.py +83 -0
- qubitclient/scope/scope_api/models/body_spectrum_api_v1_tasks_scope_spectrum_post.py +83 -0
- qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t1_fit_post.py +83 -0
- qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t2_fit_post.py +83 -0
- qubitclient/scope/scope_api/models/http_validation_error.py +75 -0
- qubitclient/scope/scope_api/models/validation_error.py +88 -0
- qubitclient/scope/scope_api/types.py +54 -0
- qubitclient/scope/task.py +163 -0
- qubitclient/scope/utils/__init__.py +0 -0
- qubitclient/scope/utils/data_parser.py +20 -0
- qubitclient-0.1.4.dist-info/METADATA +173 -0
- qubitclient-0.1.4.dist-info/RECORD +81 -0
- qubitclient-0.1.4.dist-info/WHEEL +5 -0
- qubitclient-0.1.4.dist-info/licenses/LICENSE +674 -0
- qubitclient-0.1.4.dist-info/top_level.txt +1 -0
- qubitclient-0.1.4.dist-info/zip-safe +1 -0
qubitclient/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from .pltplotter import QuantumDataPltPlotter
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OptPiPulseDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for the "optpipulse" task.

    Draws one subplot per qubit (at most three per row) containing the
    qubit's waveforms plus a dashed red vertical line and a text annotation
    for every (peak, confidence) pair found in the result.
    """

    def __init__(self):
        super().__init__("optpipulse")

    @staticmethod
    def _placeholder_figure(message):
        """Return a closed figure that only shows *message* on its axes."""
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, message, ha='center', transform=ax.transAxes)
        plt.close(fig)
        return fig

    def plot_result_npy(self, **kwargs):
        """Build the Opt-Pi-Pulse figure.

        Keyword Args:
            result: Mapping read for the ``"params"`` and ``"confs"`` lists.
                Entry ``i`` of each list belongs to the ``i``-th qubit (in
                the iteration order of the image mapping); ``params``
                entries are used as x positions of the peak markers.
            dict_param: Mapping, or a NumPy array wrapping one (unwrapped
                with ``.item()``), whose ``"image"`` entry maps each qubit
                name to a list/tuple ``(waveforms, x_axis, ...)``.

        Returns:
            The matplotlib figure. ``plt.close`` has already been called on
            it, so it is no longer tracked by pyplot. A placeholder figure
            is returned when the inputs are missing or contain no qubits.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        if not result or not dict_param:
            return self._placeholder_figure("No data")

        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image_dict = data.get("image", {})
        qubit_names = list(image_dict.keys())
        if not qubit_names:
            return self._placeholder_figure("No qubits")

        cols = min(3, len(qubit_names))
        rows = (len(qubit_names) + cols - 1) // cols  # ceiling division

        fig = plt.figure(figsize=(5.8 * cols, 4.5 * rows))
        fig.suptitle("Opt-Pi-Pulse", fontsize=14, y=0.96)

        wave_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                       '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

        params_list = result.get("params", [])
        confs_list = result.get("confs", [])

        for q_idx, q_name in enumerate(qubit_names):
            ax = fig.add_subplot(rows, cols, q_idx + 1)
            item = image_dict[q_name]
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                # Malformed entry: leave this subplot empty.
                continue

            waveforms = np.asarray(item[0])
            x_axis = np.asarray(item[1])

            for w_idx, wave in enumerate(waveforms):
                # Only the first waveform gets a legend entry. Labelling by
                # artist (instead of positionally in ax.legend) keeps a
                # second waveform from being shown as "peak".
                ax.plot(x_axis, wave,
                        color=wave_colors[w_idx % len(wave_colors)],
                        linewidth=1.2,
                        label='wave' if w_idx == 0 else '_nolegend_')

            if q_idx < len(params_list):
                peaks = params_list[q_idx]
                confs = confs_list[q_idx] if q_idx < len(confs_list) else []
                # zip() stops at the shorter list, so a peak without a
                # matching confidence is not drawn.
                for p_idx, (peak, conf) in enumerate(zip(peaks, confs)):
                    ax.axvline(peak, color='red', linestyle='--', linewidth=1.8,
                               label='peak' if p_idx == 0 else '_nolegend_')
                    ax.annotate(f"x={peak:.4f}\nconf:{conf:.3f}",
                                (peak, ax.get_ylim()[1]),
                                xytext=(0, 8), textcoords='offset points',
                                ha='center', va='bottom',
                                fontsize=8, color='red',
                                bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.8))

            ax.set_title(q_name, fontsize=11, pad=10)
            ax.set_xlabel("Time")
            ax.set_ylabel("Amp")
            ax.grid(True, linestyle='--', alpha=0.5)

            handles, labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend(handles, labels, fontsize=8, loc='upper right', framealpha=0.9)

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        plt.close(fig)
        return fig
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from .plyplotter import QuantumDataPlyPlotter
|
|
2
|
+
import plotly.graph_objects as go
|
|
3
|
+
from plotly.subplots import make_subplots
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class OptPiPulseDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for the "optpipulse" task.

    Produces a grid of subplots (three columns at most), one per qubit,
    each showing the qubit's waveforms and a dashed red marker with a text
    label for every (peak, confidence) pair in the result.
    """

    # Line colours cycled over the waveforms of a single qubit.
    _PALETTE = ('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                '#9467bd', '#8c564b', '#e377c2', '#7f7f7f')

    def __init__(self):
        super().__init__("optpipulse")

    @staticmethod
    def _message_figure(message):
        """Return an otherwise empty figure with *message* in its centre."""
        fig = go.Figure()
        fig.add_annotation(text=message, xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        return fig

    def plot_result_npy(self, **kwargs):
        """Build the Opt-Pi-Pulse plotly figure.

        Keyword Args:
            result: Mapping read for the ``"params"`` and ``"confs"`` lists,
                both indexed by qubit position.
            dict_param: Mapping, or a NumPy array wrapping one, whose
                ``"image"`` entry maps qubit names to a list/tuple
                ``(waveforms, x_axis, ...)``.

        Returns:
            A plotly figure; a message-only figure when the inputs are
            missing or the image mapping is empty.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')
        if not result or not dict_param:
            return self._message_figure("Missing data")

        if isinstance(dict_param, np.ndarray):
            payload = dict_param.item()
        else:
            payload = dict_param

        image_dict = payload.get("image", {})
        if not image_dict:
            return self._message_figure("No image data")

        qubit_names = list(image_dict.keys())
        n_cols = min(3, len(qubit_names))
        n_rows = (len(qubit_names) + n_cols - 1) // n_cols

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=qubit_names,
            vertical_spacing=0.08,
            horizontal_spacing=0.08,
        )

        params_list = result.get("params", [])
        confs_list = result.get("confs", [])

        # Each legend group ("wave", "peak") is shown once for the figure.
        wave_in_legend = False
        peak_in_legend = False

        for q_idx, q_name in enumerate(qubit_names):
            grid_row, grid_col = divmod(q_idx, n_cols)
            grid_row += 1
            grid_col += 1

            entry = image_dict[q_name]
            if not (isinstance(entry, (list, tuple)) and len(entry) >= 2):
                continue

            waveforms = np.asarray(entry[0])
            x_axis = np.asarray(entry[1])

            for w_idx, wave in enumerate(waveforms):
                wave_trace = go.Scatter(
                    x=x_axis,
                    y=wave,
                    mode='lines',
                    line=dict(color=self._PALETTE[w_idx % len(self._PALETTE)]),
                    name='wave',
                    legendgroup='wave',
                    showlegend=not wave_in_legend,
                )
                fig.add_trace(wave_trace, row=grid_row, col=grid_col)
                wave_in_legend = True

            if q_idx < len(params_list):
                peaks = params_list[q_idx]
                confs = confs_list[q_idx] if q_idx < len(confs_list) else []
                for p_idx, (peak, conf) in enumerate(zip(peaks, confs)):
                    first_peak_entry = p_idx == 0 and not peak_in_legend
                    y_low = waveforms.min()
                    y_high = waveforms.max()

                    # Vertical dashed marker spanning the waveform range.
                    marker_trace = go.Scatter(
                        x=[peak, peak],
                        y=[y_low, y_high],
                        mode='lines',
                        line=dict(color='red', width=2, dash='dash'),
                        name='peak',
                        legendgroup='peak',
                        showlegend=first_peak_entry,
                    )
                    fig.add_trace(marker_trace, row=grid_row, col=grid_col)
                    peak_in_legend = True

                    # Text label placed slightly above the waveform maximum.
                    label_trace = go.Scatter(
                        x=[peak],
                        y=[y_high * 1.08],
                        mode='text',
                        text=[f"x={peak:.4f}<br>conf:{conf:.3f}"],
                        textposition="top center",
                        showlegend=False,
                        textfont=dict(size=10, color="red"),
                    )
                    fig.add_trace(label_trace, row=grid_row, col=grid_col)

            # Axis titles only on the bottom row / first column.
            if grid_row == n_rows:
                fig.update_xaxes(title_text="Time", row=grid_row, col=grid_col)
            if grid_col == 1:
                fig.update_yaxes(title_text="Amp", row=grid_row, col=grid_col)

        fig.update_layout(
            height=400 * n_rows,
            width=520 * n_cols,
            title_text="Opt-Pi-Pulse",
            title_x=0.5,
            legend=dict(font=dict(size=10), bgcolor="rgba(255,255,255,0.8)"),
            margin=dict(l=60, r=60, t=80, b=60),
        )
        return fig
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Dict, List
|
|
2
|
+
from .pltplotter import QuantumDataPltPlotter
|
|
3
|
+
from .spectrum2dpltplotter import Spectrum2DDataPltPlotter
|
|
4
|
+
from .s21vfluxpltplotter import S21VfluxDataPltPlotter
|
|
5
|
+
from .singleshotpltplotter import SingleShotDataPltPlotter
|
|
6
|
+
from .spectrum2dscopepltplotter import Spectrum2DScopeDataPltPlotter
|
|
7
|
+
from .spectrumpltplotter import SpectrumDataPltPlotter
|
|
8
|
+
from .s21peakpltplotter import S21PeakDataPltPlotter
|
|
9
|
+
|
|
10
|
+
from .optpipulsepltplotter import OptPiPulseDataPltPlotter
|
|
11
|
+
from .rabipltplotter import RabiDataPltPlotter
|
|
12
|
+
from .t1fitpltplotter import T1FitDataPltPlotter
|
|
13
|
+
from .t2fitpltplotter import T2FitDataPltPlotter
|
|
14
|
+
from .rabicospltplotter import RabiCosDataPltPlotter
|
|
15
|
+
from .powershiftpltplotter import PowerShiftDataPltPlotter
|
|
16
|
+
|
|
17
|
+
class QuantumPlotPltManager:
    """Registry that maps task-type names to matplotlib plotters."""

    def __init__(self):
        self.plotters: Dict[str, QuantumDataPltPlotter] = {}
        self.register_plotters()

    def register_plotters(self):
        """Create one plotter instance per supported task type."""
        self.plotters.update({
            "spectrum2d": Spectrum2DDataPltPlotter(),
            "s21vflux": S21VfluxDataPltPlotter(),
            "singleshot": SingleShotDataPltPlotter(),
            "spectrum2dscope": Spectrum2DScopeDataPltPlotter(),
            "spectrum": SpectrumDataPltPlotter(),
            "optpipulse": OptPiPulseDataPltPlotter(),
            "rabicos": RabiCosDataPltPlotter(),
            "t1fit": T1FitDataPltPlotter(),
            "t2fit": T2FitDataPltPlotter(),
            "rabi": RabiDataPltPlotter(),
            "s21peak": S21PeakDataPltPlotter(),
            "powershift": PowerShiftDataPltPlotter(),
        })

    def get_plotter(self, task_type: str) -> QuantumDataPltPlotter:
        """Return the plotter registered for *task_type*.

        Raises:
            ValueError: If no plotter is registered under *task_type*.
        """
        if task_type not in self.plotters:
            raise ValueError(f"未找到任务 '{task_type}' 的绘图器")
        return self.plotters[task_type]

    def list_available_tasks(self) -> List[str]:
        """Return the registered task-type names in registration order."""
        return list(self.plotters.keys())

    def plot_quantum_data(self, data_type: str, task_type: str, save_path, **kwargs):
        """Plot with the plotter for *task_type* and hand the figure to ``save_plot``.

        Args:
            data_type: ``'npy'`` or ``'npz'``; selects ``plot_result_npy``
                or ``plot_result_npz`` on the plotter.
            task_type: Key of the registered plotter to use.
            save_path: Destination passed to the plotter's ``save_plot``.
            **kwargs: Forwarded unchanged to the plotting method.

        Raises:
            ValueError: If *task_type* is not registered or *data_type* is
                not one of the supported values.
        """
        plotter = self.get_plotter(task_type)
        if data_type == 'npy':
            fig = plotter.plot_result_npy(**kwargs)
        elif data_type == 'npz':
            fig = plotter.plot_result_npz(**kwargs)
        else:
            # Without this branch `fig` would be unbound below and the
            # caller would see an unhelpful UnboundLocalError.
            raise ValueError(
                f"Unsupported data_type '{data_type}'; expected 'npy' or 'npz'"
            )
        plotter.save_plot(fig, save_path)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
class QuantumDataPltPlotter(ABC):
    """Abstract base class for matplotlib-based task plotters."""

    def __init__(self, task_type: str):
        self.task_type = task_type

    @abstractmethod
    def plot_result_npy(self, **kwargs):
        """Build and return a figure; every concrete plotter must implement this."""
        pass

    def plot_result_npz(self, **kwargs):
        """Optional counterpart of ``plot_result_npy``; the default returns None."""
        pass

    def save_plot(self, fig, save_path: str):
        """Save *fig* to *save_path* when the target directory exists.

        Args:
            fig: Object exposing ``savefig(path)``.
            save_path: Final output path, e.g.
                ``./tmp/client/result_s21peak_tmp***.png``. A bare file name
                is saved relative to the current working directory. A falsy
                value skips saving.

        Returns:
            *fig*, whether or not it was written. Nothing is written (and no
            error is raised) when the target directory does not exist.
        """
        if save_path:
            directory = os.path.dirname(save_path)
            # dirname() is "" for a bare file name, and os.path.exists("")
            # is False, so treat the empty string as the current directory.
            if not directory or os.path.exists(directory):
                fig.savefig(save_path)
        return fig
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
from .plyplotter import QuantumDataPlyPlotter
|
|
4
|
+
from .spectrum2dplyplotter import Spectrum2DDataPlyPlotter
|
|
5
|
+
from .s21vfluxplyplotter import S21VfluxDataPlyPlotter
|
|
6
|
+
from .singleshotplyplotter import SingleShotDataPlyPlotter
|
|
7
|
+
from .spectrum2dscopeplyplotter import Spectrum2DScopeDataPlyPlotter
|
|
8
|
+
from .spectrumplyplotter import SpectrumDataPlyPlotter
|
|
9
|
+
|
|
10
|
+
from .optpipulseplyplotter import OptPiPulseDataPlyPlotter
|
|
11
|
+
from .rabiplyplotter import RabiDataPlyPlotter
|
|
12
|
+
from .t1fitplyplotter import T1FitDataPlyPlotter
|
|
13
|
+
from .t2fitplyplotter import T2FitDataPlyPlotter
|
|
14
|
+
from .rabicosplyplotter import RabiCosDataPlyPlotter
|
|
15
|
+
from .s21peakplyplotter import S21PeakDataPlyPlotter
|
|
16
|
+
from .powershiftplyplotter import PowerShiftDataPlyPlotter
|
|
17
|
+
|
|
18
|
+
class QuantumPlotPlyManager:
    """Registry that maps task-type names to plotly plotters."""

    def __init__(self):
        self.plotters: Dict[str, QuantumDataPlyPlotter] = {}
        self.register_plotters()

    def register_plotters(self):
        """Create one plotter instance per supported task type."""
        self.plotters.update({
            "spectrum2d": Spectrum2DDataPlyPlotter(),
            "s21vflux": S21VfluxDataPlyPlotter(),
            "singleshot": SingleShotDataPlyPlotter(),
            "spectrum2dscope": Spectrum2DScopeDataPlyPlotter(),
            "spectrum": SpectrumDataPlyPlotter(),
            "optpipulse": OptPiPulseDataPlyPlotter(),
            "rabicos": RabiCosDataPlyPlotter(),
            "t1fit": T1FitDataPlyPlotter(),
            "t2fit": T2FitDataPlyPlotter(),
            "rabi": RabiDataPlyPlotter(),
            "s21peak": S21PeakDataPlyPlotter(),
            "powershift": PowerShiftDataPlyPlotter(),
        })

    def get_plotter(self, task_type: str) -> QuantumDataPlyPlotter:
        """Return the plotter registered for *task_type*.

        Raises:
            ValueError: If no plotter is registered under *task_type*.
        """
        if task_type not in self.plotters:
            raise ValueError(f"未找到任务 '{task_type}' 的绘图器")
        return self.plotters[task_type]

    def list_available_tasks(self) -> List[str]:
        """Return the registered task-type names in registration order."""
        return list(self.plotters.keys())

    def plot_quantum_data(self, data_type: str, task_type: str, save_path: str, **kwargs):
        """Plot with the plotter for *task_type* and hand the figure to ``save_plot``.

        Args:
            data_type: ``'npy'`` or ``'npz'``; selects ``plot_result_npy``
                or ``plot_result_npz`` on the plotter.
            task_type: Key of the registered plotter to use.
            save_path: Destination passed to the plotter's ``save_plot``.
            **kwargs: Forwarded unchanged to the plotting method.

        Raises:
            ValueError: If *task_type* is not registered or *data_type* is
                not one of the supported values.
        """
        plotter = self.get_plotter(task_type)
        if data_type == 'npy':
            fig = plotter.plot_result_npy(**kwargs)
        elif data_type == 'npz':
            fig = plotter.plot_result_npz(**kwargs)
        else:
            # Without this branch `fig` would be unbound below and the
            # caller would see an unhelpful UnboundLocalError.
            raise ValueError(
                f"Unsupported data_type '{data_type}'; expected 'npy' or 'npz'"
            )
        plotter.save_plot(fig, save_path)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
class QuantumDataPlyPlotter(ABC):
    """Abstract base class for plotly-based task plotters."""

    def __init__(self, task_type: str):
        self.task_type = task_type

    @abstractmethod
    def plot_result_npy(self, **kwargs):
        """Build and return a figure; every concrete plotter must implement this."""
        pass

    def plot_result_npz(self, **kwargs):
        """Optional counterpart of ``plot_result_npy``; the default returns None."""
        pass

    def save_plot(self, fig, save_path: str):
        """Write *fig* as HTML to *save_path* when the target directory exists.

        Args:
            fig: Object exposing ``write_html(path)``.
            save_path: Final output path, e.g.
                ``./tmp/client/result_s21peak_tmp***.html``. A bare file
                name is written relative to the current working directory.
                A falsy value skips saving.

        Returns:
            *fig*, whether or not it was written. Nothing is written (and no
            error is raised) when the target directory does not exist.
        """
        if save_path:
            directory = os.path.dirname(save_path)
            # dirname() is "" for a bare file name, and os.path.exists("")
            # is False, so treat the empty string as the current directory.
            if not directory or os.path.exists(directory):
                fig.write_html(save_path)
        return fig
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from .pltplotter import QuantumDataPltPlotter
|
|
4
|
+
|
|
5
|
+
class PowerShiftDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for the "powershift" task.

    Draws one pcolormesh subplot per qubit (at most four per row) and
    overlays the qubit's key points, class number and confidence.
    """

    def __init__(self):
        super().__init__("powershift")

    def plot_result_npy(self, **kwargs):
        """Build the power-shift figure.

        Keyword Args:
            result: Mapping with the lists ``'keypoints_list'``,
                ``'class_num_list'`` and ``'confs'``; entry ``i`` of each
                list belongs to the ``i``-th qubit of the image mapping.
            dict_param: Mapping, or a NumPy array wrapping one (unwrapped
                with ``.item()``), whose ``"image"`` entry maps each qubit
                name to a sequence ``(x, y, value, ...)`` that is passed to
                ``pcolormesh``.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If the image mapping contains no qubits.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        # Accept both an array-wrapped dict (as loaded from .npy) and a
        # plain dict, like the other plotters in this package do.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]
        q_list = list(image.keys())
        if not q_list:
            raise ValueError("没有可合并的item数据")

        keypoints_list = result['keypoints_list']
        class_num_list = result['class_num_list']
        confs = result['confs']

        # Collect the per-qubit data; missing result entries fall back to
        # "no key points" / "unknown".
        items = []
        for idx, q_name in enumerate(q_list):
            image_q = image[q_name]
            items.append({
                'x': image_q[0],
                'y': image_q[1],
                'value': image_q[2],
                'keypoints': keypoints_list[idx] if idx < len(keypoints_list) else [],
                'q_name': q_name,
                'class_num': class_num_list[idx] if idx < len(class_num_list) else None,
                'conf': confs[idx] if idx < len(confs) else None,
            })

        # Grid layout: at most `max_cols` subplots per row.
        num_items = len(items)
        max_cols = 4
        rows = (num_items + max_cols - 1) // max_cols  # ceiling division
        cols = min(num_items, max_cols)

        # squeeze=False always yields a 2-D array of axes, so a single
        # subplot needs no special case.
        fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 6 * rows), squeeze=False)
        axes = axes.flatten()

        for ax, item in zip(axes, items):
            keypoints = item["keypoints"]
            class_num = item["class_num"]
            conf = item["conf"]

            # Raw image.
            im = ax.pcolormesh(item["x"], item["y"], item["value"], cmap='viridis', shading='auto')
            fig.colorbar(im, ax=ax)

            # Key points, connected from the highest y to the lowest.
            # len() is used because truth-testing a NumPy array is ambiguous.
            if keypoints is not None and len(keypoints) > 0:
                sorted_keypoints = sorted(keypoints, key=lambda p: (-p[1], p[0]))
                kp_x = [p[0] for p in sorted_keypoints]
                kp_y = [p[1] for p in sorted_keypoints]
                ax.scatter(kp_x, kp_y, color='red', s=50, marker='*', label='Key Points')
                ax.plot(kp_x, kp_y, 'r--', linewidth=2)

            # Info box with qubit name, class and confidence (two decimals).
            info_text = f"Qubit: {item['q_name']}\n"
            if class_num is not None:
                info_text += f"Class: {class_num}\n"
            if conf is not None:
                info_text += f"Confidence: {conf:.2f}"
            ax.text(0.05, 0.95, info_text, transform=ax.transAxes,
                    verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            ax.set_title('Original Image')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            # Only draw a legend when something is labelled; otherwise
            # matplotlib warns about an empty legend.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

        # Hide the unused cells of the grid.
        for spare_ax in axes[num_items:]:
            spare_ax.axis('off')

        fig.tight_layout()
        return fig
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import plotly.graph_objects as go
|
|
3
|
+
from plotly.subplots import make_subplots
|
|
4
|
+
from .plyplotter import QuantumDataPlyPlotter
|
|
5
|
+
|
|
6
|
+
class PowerShiftDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for the "powershift" task.

    Draws one heatmap subplot per qubit (at most four per row) with a
    shared colour range, and overlays each qubit's key points.
    """

    def __init__(self):
        super().__init__("powershift")

    def plot_result_npy(self, **kwargs):
        """Build the power-shift plotly figure.

        Keyword Args:
            result: Mapping with the lists ``'keypoints_list'``,
                ``'class_num_list'`` and ``'confs'``; entry ``i`` of each
                list belongs to the ``i``-th qubit of the image mapping.
            dict_param: Mapping, or a NumPy array wrapping one (unwrapped
                with ``.item()``), whose ``"image"`` entry maps each qubit
                name to a sequence ``(x, y, value, ...)``.

        Returns:
            The plotly figure.

        Raises:
            ValueError: If the image mapping contains no qubits.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        # Accept both an array-wrapped dict (as loaded from .npy) and a
        # plain dict, like the other plotters in this package do.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]
        q_list = list(image.keys())
        if not q_list:
            raise ValueError("没有找到量子比特数据")

        keypoints_list = result['keypoints_list']
        class_num_list = result['class_num_list']
        confs = result['confs']

        items = []
        for idx, q_name in enumerate(q_list):
            image_q = image[q_name]
            # Drop singleton dimensions from the inputs.
            x = np.squeeze(image_q[0])
            y = np.squeeze(image_q[1])
            value = np.squeeze(image_q[2])

            # Keep the grid as is when x/y look like cell edges (one more
            # entry than cells per axis); otherwise crop it to the
            # coordinate lengths.
            if not (value.shape[0] == len(y) - 1 and value.shape[1] == len(x) - 1):
                value = value[:len(y), :len(x)]

            items.append({
                'x': x,
                'y': y,
                'value': value,
                'keypoints': keypoints_list[idx] if idx < len(keypoints_list) else [],
                'q_name': q_name,
                'class_num': class_num_list[idx] if idx < len(class_num_list) else None,
                'conf': confs[idx] if idx < len(confs) else None,
            })

        num_items = len(items)
        max_cols = 4  # at most four subplots per row
        rows = (num_items + max_cols - 1) // max_cols
        cols = min(num_items, max_cols)

        # Per-subplot size in pixels, reduced slightly for large grids.
        if num_items > 30:
            base_size = 280
        elif num_items > 15:
            base_size = 300
        else:
            base_size = 320
        fig_height = base_size * rows
        fig_width = base_size * cols

        # Gaps between subplots, tightened a little for very large grids.
        if num_items > 35:
            vertical_spacing = horizontal_spacing = 0.06
        else:
            vertical_spacing = horizontal_spacing = 0.08

        # Subplot title: qubit name plus class / confidence when available.
        subplot_titles = []
        for item in items:
            title_parts = [f"Qubit: {item['q_name']}"]
            if item['class_num'] is not None:
                title_parts.append(f"class: {item['class_num']}")
            if item['conf'] is not None:
                title_parts.append(f"conf: {item['conf']:.2f}")
            subplot_titles.append("_".join(title_parts))

        fig = make_subplots(
            rows=rows, cols=cols,
            subplot_titles=subplot_titles,
            vertical_spacing=vertical_spacing,
            horizontal_spacing=horizontal_spacing,
            shared_yaxes=False,
            shared_xaxes=False
        )

        # One colour range for all heatmaps so they are comparable.
        all_values = np.concatenate([item['value'].flatten() for item in items])
        z_min, z_max = np.min(all_values), np.max(all_values)

        last_index = num_items - 1
        for i, item in enumerate(items):
            row = (i // cols) + 1
            col = (i % cols) + 1
            x = item["x"]
            y = item["y"]
            keypoints = item["keypoints"]

            # Only the last heatmap carries the (shared) colour bar.
            heatmap = go.Heatmap(
                z=item["value"],
                x=x,
                y=y,
                zmin=z_min,
                zmax=z_max,
                colorscale='Viridis',
                colorbar=dict(
                    thickness=12,
                    title=dict(
                        text="Value",
                        side="right",
                        font=dict(size=10)
                    )
                ) if i == last_index else None,
                showscale=(i == last_index),
                transpose=False
            )
            fig.add_trace(heatmap, row=row, col=col)

            # Key points, connected from the highest y to the lowest.
            # len() is used because truth-testing a NumPy array is ambiguous.
            if keypoints is not None and len(keypoints) > 0:
                keypoints = np.array(keypoints).reshape(-1, 2)
                sorted_keypoints = sorted(keypoints, key=lambda p: (-p[1], p[0]))
                kp_x = [p[0] for p in sorted_keypoints]
                kp_y = [p[1] for p in sorted_keypoints]

                scatter = go.Scatter(
                    x=kp_x, y=kp_y,
                    mode='markers',
                    marker=dict(color='red', size=11, symbol='star', line=dict(width=1.2, color='white')),
                    name='Key Points',
                    showlegend=False
                )
                fig.add_trace(scatter, row=row, col=col)

                if len(kp_x) > 1:
                    line = go.Scatter(
                        x=kp_x, y=kp_y,
                        mode='lines',
                        line=dict(color='red', dash='dash', width=1.8),
                        showlegend=False
                    )
                    fig.add_trace(line, row=row, col=col)

            # Clamp both axes to the data range and set readable fonts.
            fig.update_xaxes(
                title_text="X",
                row=row, col=col,
                range=[np.min(x), np.max(x)],
                title_font=dict(size=11),
                tickfont=dict(size=9),
                ticklen=4
            )
            fig.update_yaxes(
                title_text="Y",
                row=row, col=col,
                range=[np.min(y), np.max(y)],
                title_font=dict(size=11),
                tickfont=dict(size=9),
                ticklen=4
            )

        # Margins leave room for the colour bar and the figure title.
        fig.update_layout(
            height=fig_height,
            width=fig_width,
            title_text="Power Shift Data Visualization",
            title_font=dict(size=16, weight='bold'),
            margin=dict(l=40, r=60, t=60, b=40)
        )

        return fig
|