qubitclient 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. qubitclient/__init__.py +5 -0
  2. qubitclient/draw/__init__.py +0 -0
  3. qubitclient/draw/optpipulsepltplotter.py +75 -0
  4. qubitclient/draw/optpipulseplyplotter.py +114 -0
  5. qubitclient/draw/pltmanager.py +50 -0
  6. qubitclient/draw/pltplotter.py +20 -0
  7. qubitclient/draw/plymanager.py +57 -0
  8. qubitclient/draw/plyplotter.py +21 -0
  9. qubitclient/draw/powershiftpltplotter.py +108 -0
  10. qubitclient/draw/powershiftplyplotter.py +194 -0
  11. qubitclient/draw/rabicospltplotter.py +74 -0
  12. qubitclient/draw/rabicosplyplotter.py +90 -0
  13. qubitclient/draw/rabipltplotter.py +66 -0
  14. qubitclient/draw/rabiplyplotter.py +86 -0
  15. qubitclient/draw/s21peakpltplotter.py +67 -0
  16. qubitclient/draw/s21peakplyplotter.py +124 -0
  17. qubitclient/draw/s21vfluxpltplotter.py +84 -0
  18. qubitclient/draw/s21vfluxplyplotter.py +163 -0
  19. qubitclient/draw/singleshotpltplotter.py +149 -0
  20. qubitclient/draw/singleshotplyplotter.py +324 -0
  21. qubitclient/draw/spectrum2dpltplotter.py +107 -0
  22. qubitclient/draw/spectrum2dplyplotter.py +244 -0
  23. qubitclient/draw/spectrum2dscopepltplotter.py +72 -0
  24. qubitclient/draw/spectrum2dscopeplyplotter.py +195 -0
  25. qubitclient/draw/spectrumpltplotter.py +106 -0
  26. qubitclient/draw/spectrumplyplotter.py +133 -0
  27. qubitclient/draw/t1fitpltplotter.py +76 -0
  28. qubitclient/draw/t1fitplyplotter.py +109 -0
  29. qubitclient/draw/t2fitpltplotter.py +70 -0
  30. qubitclient/draw/t2fitplyplotter.py +111 -0
  31. qubitclient/nnscope/nnscope.py +51 -0
  32. qubitclient/nnscope/nnscope_api/curve/__init__.py +0 -0
  33. qubitclient/nnscope/nnscope_api/curve/curve_type.py +15 -0
  34. qubitclient/nnscope/task.py +170 -0
  35. qubitclient/nnscope/utils/data_convert.py +114 -0
  36. qubitclient/nnscope/utils/data_parser.py +41 -0
  37. qubitclient/nnscope/utils/request_tool.py +41 -0
  38. qubitclient/nnscope/utils/result_parser.py +55 -0
  39. qubitclient/scope/scope.py +50 -0
  40. qubitclient/scope/scope_api/__init__.py +8 -0
  41. qubitclient/scope/scope_api/api/__init__.py +1 -0
  42. qubitclient/scope/scope_api/api/defined_tasks/__init__.py +1 -0
  43. qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_demo_pk_get.py +155 -0
  44. qubitclient/scope/scope_api/api/defined_tasks/get_task_result_api_v1_tasks_scope_pk_get.py +155 -0
  45. qubitclient/scope/scope_api/api/defined_tasks/optpipulse_api_v1_tasks_scope_optpipulse_post.py +218 -0
  46. qubitclient/scope/scope_api/api/defined_tasks/powershift_api_v1_tasks_scope_powershift_post.py +218 -0
  47. qubitclient/scope/scope_api/api/defined_tasks/rabi_api_v1_tasks_scope_rabi_post.py +218 -0
  48. qubitclient/scope/scope_api/api/defined_tasks/rabicos_api_v1_tasks_scope_rabicospeak_post.py +218 -0
  49. qubitclient/scope/scope_api/api/defined_tasks/s21peak_api_v1_tasks_scope_s21peak_post.py +218 -0
  50. qubitclient/scope/scope_api/api/defined_tasks/s21vflux_api_v1_tasks_scope_s21vflux_post.py +218 -0
  51. qubitclient/scope/scope_api/api/defined_tasks/singleshot_api_v1_tasks_scope_singleshot_post.py +218 -0
  52. qubitclient/scope/scope_api/api/defined_tasks/spectrum2d_api_v1_tasks_scope_spectrum2d_post.py +218 -0
  53. qubitclient/scope/scope_api/api/defined_tasks/spectrum_api_v1_tasks_scope_spectrum_post.py +218 -0
  54. qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t1fit_post.py +218 -0
  55. qubitclient/scope/scope_api/api/defined_tasks/t1fit_api_v1_tasks_scope_t2fit_post.py +218 -0
  56. qubitclient/scope/scope_api/client.py +268 -0
  57. qubitclient/scope/scope_api/errors.py +16 -0
  58. qubitclient/scope/scope_api/models/__init__.py +31 -0
  59. qubitclient/scope/scope_api/models/body_optpipulse_api_v1_tasks_scope_optpipulse_post.py +83 -0
  60. qubitclient/scope/scope_api/models/body_powershift_api_v1_tasks_scope_powershift_post.py +83 -0
  61. qubitclient/scope/scope_api/models/body_rabi_api_v1_tasks_scope_rabi_post.py +83 -0
  62. qubitclient/scope/scope_api/models/body_rabicos_api_v1_tasks_scope_rabicospeak_post.py +83 -0
  63. qubitclient/scope/scope_api/models/body_s21_peak_api_v1_tasks_scope_s21_peak_post.py +83 -0
  64. qubitclient/scope/scope_api/models/body_s21_vflux_api_v1_tasks_scope_s21_vflux_post.py +83 -0
  65. qubitclient/scope/scope_api/models/body_singleshot_api_v1_tasks_scope_singleshot_post.py +83 -0
  66. qubitclient/scope/scope_api/models/body_spectrum_2d_api_v1_tasks_scope_spectrum_2d_post.py +83 -0
  67. qubitclient/scope/scope_api/models/body_spectrum_api_v1_tasks_scope_spectrum_post.py +83 -0
  68. qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t1_fit_post.py +83 -0
  69. qubitclient/scope/scope_api/models/body_t1_fit_api_v1_tasks_scope_t2_fit_post.py +83 -0
  70. qubitclient/scope/scope_api/models/http_validation_error.py +75 -0
  71. qubitclient/scope/scope_api/models/validation_error.py +88 -0
  72. qubitclient/scope/scope_api/types.py +54 -0
  73. qubitclient/scope/task.py +163 -0
  74. qubitclient/scope/utils/__init__.py +0 -0
  75. qubitclient/scope/utils/data_parser.py +20 -0
  76. qubitclient-0.1.4.dist-info/METADATA +173 -0
  77. qubitclient-0.1.4.dist-info/RECORD +81 -0
  78. qubitclient-0.1.4.dist-info/WHEEL +5 -0
  79. qubitclient-0.1.4.dist-info/licenses/LICENSE +674 -0
  80. qubitclient-0.1.4.dist-info/top_level.txt +1 -0
  81. qubitclient-0.1.4.dist-info/zip-safe +1 -0
@@ -0,0 +1,5 @@
1
+ from .nnscope.nnscope_api.curve.curve_type import CurveType # noqa: F401
2
+ from .scope.scope import QubitScopeClient
3
+ from .nnscope.nnscope import QubitNNScopeClient
4
+ from .scope.task import TaskName
5
+ from .nnscope.task import NNTaskName
File without changes
@@ -0,0 +1,75 @@
1
+ from .pltplotter import QuantumDataPltPlotter
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
+
6
class OptPiPulseDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for "optpipulse" results: waveforms plus detected peaks."""

    def __init__(self):
        super().__init__("optpipulse")

    @staticmethod
    def _message_figure(message):
        """Return a closed single-axes figure with *message* centred on it."""
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, message, ha='center', transform=ax.transAxes)
        plt.close(fig)
        return fig

    def plot_result_npy(self, **kwargs):
        """Plot each qubit's waveforms together with its predicted peak positions.

        Keyword Args:
            result: mapping read via ``result.get("params", [])`` (per-qubit
                lists of peak x-positions) and ``result.get("confs", [])``
                (per-qubit lists of confidences). Both are indexed by the
                qubit's position in ``dict_param["image"]``.
            dict_param: mapping, or ``np.ndarray`` unwrapped with ``.item()``,
                whose ``"image"`` entry maps qubit name -> ``[waveforms, x_axis]``.

        Returns:
            A matplotlib figure. It is closed before returning so pyplot does
            not keep a reference to it; it can still be saved. When the input
            is missing or has no qubits, a placeholder figure with a message
            is returned instead.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        if not result or not dict_param:
            return self._message_figure("No data")

        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image_dict = data.get("image", {})
        qubit_names = list(image_dict.keys())
        if not qubit_names:
            return self._message_figure("No qubits")

        cols = min(3, len(qubit_names))
        rows = (len(qubit_names) + cols - 1) // cols

        fig = plt.figure(figsize=(5.8 * cols, 4.5 * rows))
        fig.suptitle("Opt-Pi-Pulse", fontsize=14, y=0.96)

        wave_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                       '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

        params_list = result.get("params", [])
        confs_list = result.get("confs", [])

        for q_idx, q_name in enumerate(qubit_names):
            ax = fig.add_subplot(rows, cols, q_idx + 1)
            item = image_dict[q_name]
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                continue

            waveforms = np.asarray(item[0])
            x_axis = np.asarray(item[1])

            for w_idx, wave in enumerate(waveforms):
                # Label only the first curve so the legend shows a single
                # "wave" entry; '_nolegend_' hides the remaining curves.
                ax.plot(x_axis, wave,
                        color=wave_colors[w_idx % len(wave_colors)],
                        linewidth=1.2,
                        label='wave' if w_idx == 0 else '_nolegend_')

            peaks = params_list[q_idx] if q_idx < len(params_list) else []
            confs = confs_list[q_idx] if q_idx < len(confs_list) else []
            for p_idx, peak in enumerate(peaks):
                # Draw every peak, even one without a matching confidence.
                conf = confs[p_idx] if p_idx < len(confs) else None
                ax.axvline(peak, color='red', linestyle='--', linewidth=1.8,
                           label='peak' if p_idx == 0 else '_nolegend_')
                text = f"x={peak:.4f}"
                if conf is not None:
                    text += f"\nconf:{conf:.3f}"
                ax.annotate(text,
                            (peak, ax.get_ylim()[1]),
                            xytext=(0, 8), textcoords='offset points',
                            ha='center', va='bottom',
                            fontsize=8, color='red',
                            bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.8))

            ax.set_title(q_name, fontsize=11, pad=10)
            ax.set_xlabel("Time")
            ax.set_ylabel("Amp")
            ax.grid(True, linestyle='--', alpha=0.5)
            # Build the legend from the labelled artists only. This avoids
            # mislabelled entries and the "no artists" warning.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=8, loc='upper right', framealpha=0.9)

        plt.tight_layout(rect=[0, 0, 1, 0.94])
        # Release pyplot's reference; the figure object can still be saved.
        plt.close(fig)
        return fig
@@ -0,0 +1,114 @@
1
+ from .plyplotter import QuantumDataPlyPlotter
2
+ import plotly.graph_objects as go
3
+ from plotly.subplots import make_subplots
4
+ import numpy as np
5
+
6
+
7
class OptPiPulseDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for "optpipulse" results: waveforms plus detected peaks."""

    def __init__(self):
        super().__init__("optpipulse")

    @staticmethod
    def _message_figure(message):
        """Return an empty figure with *message* centred in paper coordinates."""
        fig = go.Figure()
        fig.add_annotation(text=message, xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        return fig

    def plot_result_npy(self, **kwargs):
        """Plot each qubit's waveforms together with its predicted peak positions.

        Keyword Args:
            result: mapping read via ``result.get("params", [])`` (per-qubit
                lists of peak x-positions) and ``result.get("confs", [])``
                (per-qubit lists of confidences). Both are indexed by the
                qubit's position in ``dict_param["image"]``.
            dict_param: mapping, or ``np.ndarray`` unwrapped with ``.item()``,
                whose ``"image"`` entry maps qubit name -> ``[waveforms, x_axis]``.

        Returns:
            A ``plotly.graph_objects.Figure``. When the input is missing or has
            no image data, a placeholder figure with a message is returned.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        if not result or not dict_param:
            return self._message_figure("Missing data")

        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image_dict = data.get("image", {})
        if not image_dict:
            return self._message_figure("No image data")

        qubit_names = list(image_dict.keys())
        cols = min(3, len(qubit_names))
        rows = (len(qubit_names) + cols - 1) // cols

        fig = make_subplots(
            rows=rows, cols=cols,
            subplot_titles=qubit_names,
            vertical_spacing=0.08,
            horizontal_spacing=0.08,
        )

        wave_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                       '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

        params_list = result.get("params", [])
        confs_list = result.get("confs", [])

        # Each legend group ("wave", "peak") is shown once for the whole figure.
        wave_legend_shown = False
        peak_legend_shown = False

        for q_idx, q_name in enumerate(qubit_names):
            row = q_idx // cols + 1
            col = q_idx % cols + 1

            item = image_dict[q_name]
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                continue

            waveforms = np.asarray(item[0])
            x_axis = np.asarray(item[1])

            for w_idx, wave in enumerate(waveforms):
                fig.add_trace(
                    go.Scatter(x=x_axis, y=wave,
                               mode='lines',
                               line=dict(color=wave_colors[w_idx % len(wave_colors)]),
                               name='wave',
                               legendgroup='wave',
                               showlegend=not wave_legend_shown),
                    row=row, col=col
                )
                wave_legend_shown = True

            # Vertical extent of the peak markers. Guard against empty
            # waveforms, where min()/max() would raise.
            if waveforms.size:
                y_min, y_max = waveforms.min(), waveforms.max()
            else:
                y_min = y_max = 0.0
            # Place labels 8% of |max| above the data. This equals max * 1.08
            # for positive data but stays above the curve when max < 0.
            label_y = y_max + 0.08 * abs(y_max)

            peaks = params_list[q_idx] if q_idx < len(params_list) else []
            confs = confs_list[q_idx] if q_idx < len(confs_list) else []
            for p_idx, peak in enumerate(peaks):
                # Draw every peak, even one without a matching confidence.
                conf = confs[p_idx] if p_idx < len(confs) else None

                fig.add_trace(
                    go.Scatter(x=[peak, peak],
                               y=[y_min, y_max],
                               mode='lines',
                               line=dict(color='red', width=2, dash='dash'),
                               name='peak',
                               legendgroup='peak',
                               showlegend=not peak_legend_shown),
                    row=row, col=col
                )
                peak_legend_shown = True

                label = f"x={peak:.4f}"
                if conf is not None:
                    label += f"<br>conf:{conf:.3f}"
                fig.add_trace(
                    go.Scatter(x=[peak],
                               y=[label_y],
                               mode='text',
                               text=[label],
                               textposition="top center",
                               showlegend=False,
                               textfont=dict(size=10, color="red")),
                    row=row, col=col
                )

            if row == rows:
                fig.update_xaxes(title_text="Time", row=row, col=col)
            if col == 1:
                fig.update_yaxes(title_text="Amp", row=row, col=col)

        fig.update_layout(
            height=400 * rows,
            width=520 * cols,
            title_text="Opt-Pi-Pulse",
            title_x=0.5,
            legend=dict(font=dict(size=10), bgcolor="rgba(255,255,255,0.8)"),
            margin=dict(l=60, r=60, t=80, b=60)
        )
        return fig
@@ -0,0 +1,50 @@
1
+ from typing import Dict, List
2
+ from .pltplotter import QuantumDataPltPlotter
3
+ from .spectrum2dpltplotter import Spectrum2DDataPltPlotter
4
+ from .s21vfluxpltplotter import S21VfluxDataPltPlotter
5
+ from .singleshotpltplotter import SingleShotDataPltPlotter
6
+ from .spectrum2dscopepltplotter import Spectrum2DScopeDataPltPlotter
7
+ from .spectrumpltplotter import SpectrumDataPltPlotter
8
+ from .s21peakpltplotter import S21PeakDataPltPlotter
9
+
10
+ from .optpipulsepltplotter import OptPiPulseDataPltPlotter
11
+ from .rabipltplotter import RabiDataPltPlotter
12
+ from .t1fitpltplotter import T1FitDataPltPlotter
13
+ from .t2fitpltplotter import T2FitDataPltPlotter
14
+ from .rabicospltplotter import RabiCosDataPltPlotter
15
+ from .powershiftpltplotter import PowerShiftDataPltPlotter
16
+
17
class QuantumPlotPltManager:
    """Registry mapping task names to matplotlib plotters and driving them."""

    def __init__(self):
        self.plotters: Dict[str, QuantumDataPltPlotter] = {}
        self.register_plotters()

    def register_plotters(self):
        """Instantiate and register one plotter per supported task name."""
        self.plotters["spectrum2d"] = Spectrum2DDataPltPlotter()
        self.plotters["s21vflux"] = S21VfluxDataPltPlotter()
        self.plotters["singleshot"] = SingleShotDataPltPlotter()
        self.plotters["spectrum2dscope"] = Spectrum2DScopeDataPltPlotter()
        self.plotters["spectrum"] = SpectrumDataPltPlotter()
        self.plotters["optpipulse"] = OptPiPulseDataPltPlotter()
        self.plotters["rabicos"] = RabiCosDataPltPlotter()
        self.plotters["t1fit"] = T1FitDataPltPlotter()
        self.plotters["t2fit"] = T2FitDataPltPlotter()
        self.plotters["rabi"] = RabiDataPltPlotter()
        self.plotters["s21peak"] = S21PeakDataPltPlotter()
        self.plotters["powershift"] = PowerShiftDataPltPlotter()

    def get_plotter(self, task_type: str) -> QuantumDataPltPlotter:
        """Return the plotter registered for *task_type*.

        Raises:
            ValueError: if no plotter is registered for *task_type*.
        """
        if task_type not in self.plotters:
            # Message (Chinese): "no plotter found for task '<task_type>'".
            raise ValueError(f"未找到任务 '{task_type}' 的绘图器")
        return self.plotters[task_type]

    def list_available_tasks(self) -> List[str]:
        """Return the registered task names in registration order."""
        return list(self.plotters.keys())

    def plot_quantum_data(self, data_type: str, task_type: str, save_path, **kwargs):
        """Plot *task_type* data and save the figure to *save_path*.

        Args:
            data_type: ``'npy'`` or ``'npz'``; selects ``plot_result_npy`` or
                ``plot_result_npz`` on the plotter.
            task_type: registered task name (see :meth:`list_available_tasks`).
            save_path: destination passed to the plotter's ``save_plot``.
            **kwargs: forwarded unchanged to the plotting method.

        Returns:
            The produced figure.

        Raises:
            ValueError: for an unknown *task_type* or *data_type*, or when the
                plotter returns no figure (e.g. npz plotting not implemented).
        """
        plotter = self.get_plotter(task_type)
        if data_type == 'npy':
            fig = plotter.plot_result_npy(**kwargs)
        elif data_type == 'npz':
            fig = plotter.plot_result_npz(**kwargs)
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"Unsupported data_type {data_type!r}; expected 'npy' or 'npz'")
        if fig is None:
            raise ValueError(
                f"Plotter for task '{task_type}' produced no figure for data_type {data_type!r}"
            )
        plotter.save_plot(fig, save_path)
        return fig
@@ -0,0 +1,20 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+
4
class QuantumDataPltPlotter(ABC):
    """Base class for matplotlib plotters of a single task type.

    Subclasses implement :meth:`plot_result_npy` (and optionally
    :meth:`plot_result_npz`) to build a figure; :meth:`save_plot` writes it.
    """

    def __init__(self, task_type: str):
        # Task name this plotter handles, e.g. "optpipulse".
        self.task_type = task_type

    @abstractmethod
    def plot_result_npy(self, **kwargs):
        """Build and return a figure from ``npy``-style keyword inputs."""

    def plot_result_npz(self, **kwargs):
        """Build a figure from ``npz``-style inputs.

        The base implementation returns ``None``; subclasses that support
        ``npz`` input override it.
        """
        return None

    def save_plot(self, fig, save_path: str):
        """Write *fig* to *save_path* with ``fig.savefig`` and return *fig*.

        The figure is written only if the target directory already exists;
        it is never created. A bare filename (no directory part) is written
        relative to the current working directory. Example path:
        ``./tmp/client/result_s21peak_tmp***.png``.
        """
        directory = os.path.dirname(save_path)
        # os.path.dirname("plot.png") is "" and os.path.exists("") is False,
        # so an empty directory must be treated as "current directory".
        if not directory or os.path.exists(directory):
            fig.savefig(save_path)
        return fig
@@ -0,0 +1,57 @@
1
+
2
+ from typing import Dict, List
3
+ from .plyplotter import QuantumDataPlyPlotter
4
+ from .spectrum2dplyplotter import Spectrum2DDataPlyPlotter
5
+ from .s21vfluxplyplotter import S21VfluxDataPlyPlotter
6
+ from .singleshotplyplotter import SingleShotDataPlyPlotter
7
+ from .spectrum2dscopeplyplotter import Spectrum2DScopeDataPlyPlotter
8
+ from .spectrumplyplotter import SpectrumDataPlyPlotter
9
+
10
+ from .optpipulseplyplotter import OptPiPulseDataPlyPlotter
11
+ from .rabiplyplotter import RabiDataPlyPlotter
12
+ from .t1fitplyplotter import T1FitDataPlyPlotter
13
+ from .t2fitplyplotter import T2FitDataPlyPlotter
14
+ from .rabicosplyplotter import RabiCosDataPlyPlotter
15
+ from .s21peakplyplotter import S21PeakDataPlyPlotter
16
+ from .powershiftplyplotter import PowerShiftDataPlyPlotter
17
+
18
class QuantumPlotPlyManager:
    """Registry mapping task names to Plotly plotters and driving them."""

    def __init__(self):
        self.plotters: Dict[str, QuantumDataPlyPlotter] = {}
        self.register_plotters()

    def register_plotters(self):
        """Instantiate and register one plotter per supported task name."""
        self.plotters["spectrum2d"] = Spectrum2DDataPlyPlotter()
        self.plotters["s21vflux"] = S21VfluxDataPlyPlotter()
        self.plotters["singleshot"] = SingleShotDataPlyPlotter()
        self.plotters["spectrum2dscope"] = Spectrum2DScopeDataPlyPlotter()
        self.plotters["spectrum"] = SpectrumDataPlyPlotter()
        self.plotters["optpipulse"] = OptPiPulseDataPlyPlotter()
        self.plotters["rabicos"] = RabiCosDataPlyPlotter()
        self.plotters["t1fit"] = T1FitDataPlyPlotter()
        self.plotters["t2fit"] = T2FitDataPlyPlotter()
        self.plotters["rabi"] = RabiDataPlyPlotter()
        self.plotters["s21peak"] = S21PeakDataPlyPlotter()
        self.plotters["powershift"] = PowerShiftDataPlyPlotter()

    def get_plotter(self, task_type: str) -> QuantumDataPlyPlotter:
        """Return the plotter registered for *task_type*.

        Raises:
            ValueError: if no plotter is registered for *task_type*.
        """
        if task_type not in self.plotters:
            # Message (Chinese): "no plotter found for task '<task_type>'".
            raise ValueError(f"未找到任务 '{task_type}' 的绘图器")
        return self.plotters[task_type]

    def list_available_tasks(self) -> List[str]:
        """Return the registered task names in registration order."""
        return list(self.plotters.keys())

    def plot_quantum_data(self, data_type: str, task_type: str, save_path: str, **kwargs):
        """Plot *task_type* data and save the figure to *save_path*.

        Args:
            data_type: ``'npy'`` or ``'npz'``; selects ``plot_result_npy`` or
                ``plot_result_npz`` on the plotter.
            task_type: registered task name (see :meth:`list_available_tasks`).
            save_path: destination passed to the plotter's ``save_plot``.
            **kwargs: forwarded unchanged to the plotting method.

        Returns:
            The produced figure.

        Raises:
            ValueError: for an unknown *task_type* or *data_type*, or when the
                plotter returns no figure (e.g. npz plotting not implemented).
        """
        plotter = self.get_plotter(task_type)
        if data_type == 'npy':
            fig = plotter.plot_result_npy(**kwargs)
        elif data_type == 'npz':
            fig = plotter.plot_result_npz(**kwargs)
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"Unsupported data_type {data_type!r}; expected 'npy' or 'npz'")
        if fig is None:
            raise ValueError(
                f"Plotter for task '{task_type}' produced no figure for data_type {data_type!r}"
            )
        plotter.save_plot(fig, save_path)
        return fig
@@ -0,0 +1,21 @@
1
+
2
+ from abc import ABC, abstractmethod
3
+ import json
4
+ import os
5
+
6
class QuantumDataPlyPlotter(ABC):
    """Base class for Plotly plotters of a single task type.

    Subclasses implement :meth:`plot_result_npy` (and optionally
    :meth:`plot_result_npz`) to build a figure; :meth:`save_plot` writes it
    as HTML.
    """

    def __init__(self, task_type: str):
        # Task name this plotter handles, e.g. "powershift".
        self.task_type = task_type

    @abstractmethod
    def plot_result_npy(self, **kwargs):
        """Build and return a figure from ``npy``-style keyword inputs."""

    def plot_result_npz(self, **kwargs):
        """Build a figure from ``npz``-style inputs.

        The base implementation returns ``None``; subclasses that support
        ``npz`` input override it.
        """
        return None

    def save_plot(self, fig, save_path: str):
        """Write *fig* to *save_path* with ``fig.write_html`` and return *fig*.

        The file is written only if the target directory already exists;
        it is never created. A bare filename (no directory part) is written
        relative to the current working directory. Example path:
        ``./tmp/client/result_s21peak_tmp***.html``.
        """
        directory = os.path.dirname(save_path)
        # os.path.dirname("plot.html") is "" and os.path.exists("") is False,
        # so an empty directory must be treated as "current directory".
        if not directory or os.path.exists(directory):
            fig.write_html(save_path)
        return fig
@@ -0,0 +1,108 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from .pltplotter import QuantumDataPltPlotter
4
+
5
class PowerShiftDataPltPlotter(QuantumDataPltPlotter):
    """Matplotlib plotter for "powershift" results: one 2-D map per qubit with key points."""

    def __init__(self):
        super().__init__("powershift")

    def plot_result_npy(self, **kwargs):
        """Draw every qubit's 2-D map with its key points in a grid (max 4 columns).

        Keyword Args:
            result: mapping with ``'keypoints_list'``, ``'class_num_list'`` and
                ``'confs'`` lists, indexed by each qubit's position in
                ``dict_param["image"]``. Missing trailing entries are tolerated.
            dict_param: mapping, or ``np.ndarray`` unwrapped with ``.item()``,
                whose ``"image"`` entry maps qubit name -> ``(x, y, value)``
                passed to ``pcolormesh``.

        Returns:
            A matplotlib figure. It is closed before returning so pyplot does
            not keep a reference to it; it can still be saved.

        Raises:
            KeyError: if ``result`` lacks one of the keys above.
            ValueError: if there are no qubits to plot.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        # Accept a plain mapping as well as a 0-d object array, as the
        # OptPiPulse plotters do.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]

        keypoints_list = result['keypoints_list']
        class_num_list = result['class_num_list']
        confs = result['confs']

        # enumerate() keeps per-qubit result lookups O(1) (no list.index()).
        items = []
        for idx, (q_name, image_q) in enumerate(image.items()):
            items.append({
                'x': image_q[0],
                'y': image_q[1],
                'value': image_q[2],
                'keypoints': keypoints_list[idx] if idx < len(keypoints_list) else [],
                'q_name': q_name,
                'class_num': class_num_list[idx] if idx < len(class_num_list) else None,
                'conf': confs[idx] if idx < len(confs) else None,
            })

        num_items = len(items)
        if num_items == 0:
            # Message (Chinese): "no item data to combine".
            raise ValueError("没有可合并的item数据")

        # Grid layout: at most max_cols subplots per row.
        max_cols = 4
        rows = (num_items + max_cols - 1) // max_cols
        cols = min(num_items, max_cols)

        # squeeze=False always yields a 2-D array of axes, so ravel() gives a
        # flat list regardless of the grid shape.
        fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 6 * rows), squeeze=False)
        axes = axes.ravel()

        for ax, item in zip(axes, items):
            im = ax.pcolormesh(item["x"], item["y"], item["value"], cmap='viridis', shading='auto')
            fig.colorbar(im, ax=ax)

            keypoints = item["keypoints"]
            # len() rather than truthiness: keypoints may be a NumPy array.
            has_keypoints = keypoints is not None and len(keypoints) > 0
            if has_keypoints:
                # Connect points from highest y to lowest (ties broken by x).
                ordered = sorted(keypoints, key=lambda p: (-p[1], p[0]))
                kp_x = [p[0] for p in ordered]
                kp_y = [p[1] for p in ordered]
                ax.scatter(kp_x, kp_y, color='red', s=50, marker='*', label='Key Points')
                ax.plot(kp_x, kp_y, 'r--', linewidth=2)

            info_text = f"Qubit: {item['q_name']}\n"
            if item["class_num"] is not None:
                info_text += f"Class: {item['class_num']}\n"
            if item["conf"] is not None:
                info_text += f"Confidence: {item['conf']:.2f}"

            ax.text(0.05, 0.95, info_text, transform=ax.transAxes,
                    verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            ax.set_title('Original Image')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            # Only build a legend when a labelled artist exists; otherwise
            # matplotlib warns about an empty legend.
            if has_keypoints:
                ax.legend()

        # Hide the unused cells at the end of the grid.
        for ax in axes[num_items:]:
            ax.axis('off')

        plt.tight_layout()
        # Release pyplot's reference to avoid accumulating open figures across
        # calls; the figure object can still be saved.
        plt.close(fig)
        return fig
@@ -0,0 +1,194 @@
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+ from plotly.subplots import make_subplots
4
+ from .plyplotter import QuantumDataPlyPlotter
5
+
6
class PowerShiftDataPlyPlotter(QuantumDataPlyPlotter):
    """Plotly plotter for "powershift" results: one heatmap per qubit with key points."""

    def __init__(self):
        super().__init__("powershift")

    @staticmethod
    def _subplot_title(item):
        """Build the title "Qubit: <name>[_class: <n>][_conf: <c:.2f>]" for one item."""
        parts = [f"Qubit: {item['q_name']}"]
        if item['class_num'] is not None:
            parts.append(f"class: {item['class_num']}")
        if item['conf'] is not None:
            parts.append(f"conf: {item['conf']:.2f}")
        return "_".join(parts)

    def plot_result_npy(self, **kwargs):
        """Render every qubit's 2-D map as a heatmap grid (max 4 columns) with key points.

        Keyword Args:
            result: mapping with ``'keypoints_list'``, ``'class_num_list'`` and
                ``'confs'`` lists, indexed by each qubit's position in
                ``dict_param["image"]``. Missing trailing entries are tolerated.
            dict_param: mapping, or ``np.ndarray`` unwrapped with ``.item()``,
                whose ``"image"`` entry maps qubit name -> ``(x, y, value)``.

        Returns:
            A ``plotly.graph_objects.Figure`` whose heatmaps share one colour
            range. Only the last heatmap draws a colour bar.

        Raises:
            ValueError: if ``dict_param["image"]`` has no qubits.
            KeyError: if ``result`` lacks one of the keys above.
        """
        result = kwargs.get('result')
        dict_param = kwargs.get('dict_param')

        # Accept a plain mapping as well as a 0-d object array.
        data = dict_param.item() if isinstance(dict_param, np.ndarray) else dict_param
        image = data["image"]
        if not image:
            # Message (Chinese): "no qubit data found".
            raise ValueError("没有找到量子比特数据")

        keypoints_list = result['keypoints_list']
        class_num_list = result['class_num_list']
        confs = result['confs']

        items = []
        for idx, (q_name, image_q) in enumerate(image.items()):
            x = np.squeeze(image_q[0])
            y = np.squeeze(image_q[1])
            value = np.squeeze(image_q[2])

            # Keep z as-is when it is exactly one shorter than the axes on
            # both dimensions; otherwise trim it to the axis lengths.
            if not (value.shape[0] == len(y) - 1 and value.shape[1] == len(x) - 1):
                value = value[:len(y), :len(x)]

            items.append({
                'x': x,
                'y': y,
                'value': value,
                'keypoints': keypoints_list[idx] if idx < len(keypoints_list) else [],
                'q_name': q_name,
                'class_num': class_num_list[idx] if idx < len(class_num_list) else None,
                'conf': confs[idx] if idx < len(confs) else None,
            })

        num_items = len(items)
        max_cols = 4  # at most 4 subplots per row
        rows = (num_items + max_cols - 1) // max_cols
        cols = min(num_items, max_cols)

        # Per-subplot size in pixels, reduced slightly for large grids.
        if num_items > 30:
            base_size = 280
        elif num_items > 15:
            base_size = 300
        else:
            base_size = 320
        fig_height = base_size * rows
        fig_width = base_size * cols

        # Spacing between subplots (fraction of the figure), tighter for very large grids.
        spacing = 0.06 if num_items > 35 else 0.08

        fig = make_subplots(
            rows=rows, cols=cols,
            subplot_titles=[self._subplot_title(item) for item in items],
            vertical_spacing=spacing,
            horizontal_spacing=spacing,
            shared_yaxes=False,
            shared_xaxes=False
        )

        # One colour range for every heatmap. The NaN-aware reductions keep a
        # few missing pixels from turning the whole range into NaN.
        all_values = np.concatenate([item['value'].ravel() for item in items])
        z_min, z_max = np.nanmin(all_values), np.nanmax(all_values)

        last = num_items - 1
        for i, item in enumerate(items):
            row = i // cols + 1
            col = i % cols + 1
            x = item["x"]
            y = item["y"]

            heatmap = go.Heatmap(
                z=item["value"],
                x=x,
                y=y,
                zmin=z_min,
                zmax=z_max,
                colorscale='Viridis',
                colorbar=dict(
                    thickness=12,
                    title=dict(
                        text="Value",
                        side="right",
                        font=dict(size=10)
                    )
                ) if i == last else None,
                showscale=(i == last),
                transpose=False
            )
            fig.add_trace(heatmap, row=row, col=col)

            keypoints = item["keypoints"]
            # len() rather than truthiness: keypoints may be a NumPy array,
            # whose truth value is ambiguous.
            if keypoints is not None and len(keypoints) > 0:
                points = np.array(keypoints).reshape(-1, 2)
                # Connect points from highest y to lowest (ties broken by x).
                ordered = sorted(points, key=lambda p: (-p[1], p[0]))
                kp_x = [p[0] for p in ordered]
                kp_y = [p[1] for p in ordered]

                fig.add_trace(
                    go.Scatter(
                        x=kp_x, y=kp_y,
                        mode='markers',
                        marker=dict(color='red', size=11, symbol='star', line=dict(width=1.2, color='white')),
                        name='Key Points',
                        showlegend=False
                    ),
                    row=row, col=col
                )

                if len(kp_x) > 1:
                    fig.add_trace(
                        go.Scatter(
                            x=kp_x, y=kp_y,
                            mode='lines',
                            line=dict(color='red', dash='dash', width=1.8),
                            showlegend=False
                        ),
                        row=row, col=col
                    )

            fig.update_xaxes(
                title_text="X",
                row=row, col=col,
                range=[np.min(x), np.max(x)],
                title_font=dict(size=11),
                tickfont=dict(size=9),
                ticklen=4
            )
            fig.update_yaxes(
                title_text="Y",
                row=row, col=col,
                range=[np.min(y), np.max(y)],
                title_font=dict(size=11),
                tickfont=dict(size=9),
                ticklen=4
            )

        # Margins leave room for the colour bar and the figure title.
        fig.update_layout(
            height=fig_height,
            width=fig_width,
            title_text="Power Shift Data Visualization",
            title_font=dict(size=16, weight='bold'),
            margin=dict(l=40, r=60, t=60, b=40)
        )

        return fig