dense-evolution 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dense_evolution.py
ADDED
|
@@ -0,0 +1,1445 @@
|
|
|
1
|
+
import subprocess, sys, os
|
|
2
|
+
import importlib
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy import linalg as LA
|
|
5
|
+
import scipy.linalg
|
|
6
|
+
import matplotlib
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import matplotlib.gridspec as gridspec
|
|
9
|
+
from matplotlib.patches import FancyArrowPatch, Rectangle, Circle, FancyBboxPatch
|
|
10
|
+
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
|
11
|
+
from matplotlib.cm import ScalarMappable
|
|
12
|
+
import matplotlib.ticker as ticker
|
|
13
|
+
from IPython.display import display, HTML, clear_output
|
|
14
|
+
import time, re, io, warnings, hashlib, json, copy, psutil, platform
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
from typing import List, Dict, Optional, Tuple
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from enum import Enum
|
|
19
|
+
import pandas as pd
|
|
20
|
+
|
|
21
|
+
# Silence every warning category for the whole notebook session
# (equivalent to filterwarnings('ignore') with default arguments).
warnings.simplefilter('ignore')
|
|
22
|
+
|
|
23
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
24
|
+
# CELLA 1: Hardware Detection & Configuration
|
|
25
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
26
|
+
|
|
27
|
+
def install(pkg):
    """Quietly pip-install ``pkg`` into the interpreter running this module.

    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '-q', pkg]
    subprocess.check_call(pip_cmd)
|
|
29
|
+
|
|
30
|
+
# GPU Support (Optional)
|
|
31
|
+
try:
    import cupy as cp
except Exception:  # not only ImportError: a broken CUDA/CuPy install can fail differently
    HAS_CUPY = False
    cp = None  # explicit sentinel so later `cp` references never raise NameError
    print('ℹ️ CuPy non disponibile — usando NumPy CPU')
else:
    HAS_CUPY = True
    print('✅ CuPy disponibile — GPU acceleration attiva')
|
|
38
|
+
|
|
39
|
+
# JAX Support (Optional with NumPy fallback)
|
|
40
|
+
try:
    import jax
    import jax.numpy as jnp
except Exception:  # import can fail with non-ImportError errors (e.g. jax/jaxlib mismatch)
    HAS_JAX = False
    jnp = None  # Fallback: will use NumPy instead
    print('ℹ️ JAX non disponibile — fallback NumPy attivo')
else:
    HAS_JAX = True
    print('✅ JAX disponibile (JIT optimization attivo)')
|
|
49
|
+
|
|
50
|
+
# Hardware detection
|
|
51
|
+
# One psutil snapshot, so total and available RAM describe the same instant.
_mem = psutil.virtual_memory()
ram_total = _mem.total / (1024**3)
ram_avail = _mem.available / (1024**3)
# platform.processor() is frequently '' on Linux; fall back to the machine type.
print(f'\n⌨️ Sistema: {platform.processor() or platform.machine()}')
print(f'💾 RAM Totale: {ram_total:.1f} GB | Disponibile: {ram_avail:.1f} GB')

# Automatic qubit limits based on RAM
# (a complex128 statevector of n qubits takes 16 * 2**n bytes:
#  4 GiB at 28 qubits, 256 MiB at 24, 16 MiB at 20).
if ram_total >= 50:
    MAX_DENSE_QUBITS = 28
    print('🚀 High-RAM runtime → Dense SV fino a 28 qubit')
elif ram_total >= 12:
    MAX_DENSE_QUBITS = 24
    print('✅ Standard runtime → Dense SV fino a 24 qubit')
else:
    MAX_DENSE_QUBITS = 20
    print('⚠️ RAM limitata → Dense SV fino a 20 qubit')

# GPU detection: OSError covers a missing nvidia-smi binary; SubprocessError
# covers a non-zero exit status and the timeout (a hung driver must not block import).
try:
    gpu_info = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
        text=True,
        timeout=10,
    ).strip()
except (OSError, subprocess.SubprocessError):
    HAS_GPU = False
    print('ℹ️ Nessuna GPU NVIDIA rilevata')
else:
    HAS_GPU = True
    print(f'🎮 GPU: {gpu_info}')

print(f'\n📊 Configurazione: MAX_DENSE={MAX_DENSE_QUBITS}q | JAX={HAS_JAX} | GPU={HAS_GPU}')
|
|
78
|
+
|
|
79
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
80
|
+
# CELLA 2: Matplotlib Professional Styling
|
|
81
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
82
|
+
|
|
83
|
+
plt.style.use('dark_background')

# Palette: backgrounds / borders
DARK_BG, PANEL_BG, PANEL_BG2, BORDER = '#010409', '#0d1117', '#161b22', '#21262d'
# Palette: accent colours
ACC_G, ACC_B, ACC_O, ACC_P = '#00ff9d', '#00c8ff', '#ff6b35', '#b400ff'
ACC_TEAL, ACC_PINK = '#00ffff', '#ff007f'
# Palette: status and text colours
WARN, DANGER, MUTED, TEXT = '#ffd700', '#ff4444', '#7d8590', '#e6edf3'

# rcParams overrides applied on top of the dark_background style.
_rc_overrides = {
    'figure.facecolor': DARK_BG,
    'axes.facecolor': PANEL_BG,
    'axes.edgecolor': BORDER,
    'axes.labelcolor': MUTED,
    'axes.titlecolor': TEXT,
    'text.color': TEXT,
    'xtick.color': MUTED,
    'ytick.color': MUTED,
    'grid.color': BORDER,
    'grid.alpha': 0.5,
    'font.family': 'monospace',
    'font.size': 9,
    'figure.dpi': 130,
    'savefig.dpi': 200,
}
matplotlib.rcParams.update(_rc_overrides)

print('✅ Matplotlib tema dark professionale configurato')
|
|
117
|
+
|
|
118
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
119
|
+
# CELLA 3: Gate Matrices & Operators
|
|
120
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
121
|
+
|
|
122
|
+
INV2 = 1.0 / np.sqrt(2.0)

# Fixed (non-parametric) gate matrices, complex128. Multi-qubit matrices use
# the MSB-first basis order |q_first q_second ...⟩, so for controlled gates the
# first operand is the control. Key order is significant (it is printed).
GATES = {
    # --- single-qubit Cliffords and phase gates ---
    'h': INV2 * np.array([[1, 1], [1, -1]], dtype=complex),
    'x': np.eye(2, dtype=complex)[[1, 0]],
    'y': np.array([[0, -1j], [1j, 0]], dtype=complex),
    'z': np.diag([1, -1]).astype(complex),
    's': np.diag([1, 1j]),
    'sdg': np.diag([1, -1j]),
    't': np.diag([1, np.exp(1j * np.pi / 4)]),
    'tdg': np.diag([1, np.exp(-1j * np.pi / 4)]),
    'sx': 0.5 * np.array([[1 + 1j, 1 - 1j], [1 - 1j, 1 + 1j]], dtype=complex),
    'id': np.eye(2, dtype=complex),
    # --- two-qubit gates (row permutations of the identity where possible) ---
    'cx': np.eye(4, dtype=complex)[[0, 1, 3, 2]],
    'cz': np.diag([1, 1, 1, -1]).astype(complex),
    'cy': np.array([[1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, -1j],
                    [0, 0, 1j, 0]], dtype=complex),
    'swap': np.eye(4, dtype=complex)[[0, 2, 1, 3]],
    'iswap': np.array([[1, 0, 0, 0],
                       [0, 0, 1j, 0],
                       [0, 1j, 0, 0],
                       [0, 0, 0, 1]], dtype=complex),
    'ecr': INV2 * np.array([[0, 0, 1, 1j],
                            [0, 0, 1j, 1],
                            [1, -1j, 0, 0],
                            [-1j, 1, 0, 0]], dtype=complex),
    # --- three-qubit Toffoli: swap the last two basis rows ---
    'ccx': np.eye(8, dtype=complex)[[0, 1, 2, 3, 4, 5, 7, 6]],
}
|
|
150
|
+
|
|
151
|
+
# ┌─────────────────────────────────────────────────────────────────┐
|
|
152
|
+
# │ FIX #1: PARAMETRIC GATES WITH NUMPY FALLBACK (JAX-OPTIONAL) │
|
|
153
|
+
# └─────────────────────────────────────────────────────────────────┘
|
|
154
|
+
def _build_parametric_gates(use_jax: bool = HAS_JAX):
    """
    Build the parametric gate factories on a single array backend.

    The JAX module (jnp) is used when both ``use_jax`` and HAS_JAX are true,
    otherwise NumPy. Previously the two backends were two copy-pasted sets of
    functions; here one set is written against the selected module ``xp``.

    Returns:
        dict mapping 'rx', 'ry', 'rz', 'u3', 'u2', 'u1', 'p', 'cp', 'crz' to
        callables that take the gate angle(s) in radians and return a
        complex matrix (2x2 for single-qubit gates, 4x4 for cp/crz).
    """
    xp = jnp if (use_jax and HAS_JAX) else np

    def rx_gate(theta: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array([[c, -1j * s], [-1j * s, c]], dtype=complex)

    def ry_gate(theta: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array([[c, -s], [s, c]], dtype=complex)

    def rz_gate(theta: float):
        return xp.array([[xp.exp(-1j * theta / 2), 0],
                         [0, xp.exp(1j * theta / 2)]], dtype=complex)

    def u3_gate(theta: float, phi: float, lam: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array(
            [[c, -xp.exp(1j * lam) * s],
             [xp.exp(1j * phi) * s, xp.exp(1j * (phi + lam)) * c]],
            dtype=complex)

    def p_gate(lam: float):
        return xp.array([[1, 0], [0, xp.exp(1j * lam)]], dtype=complex)

    def cp_gate(lam: float):
        # Phase only on |11⟩ (first operand = control, MSB-first).
        return xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                         [0, 0, 0, xp.exp(1j * lam)]], dtype=complex)

    def crz_gate(theta: float):
        # RZ(theta) on the target block where the control (first operand) is 1.
        return xp.array([[1, 0, 0, 0], [0, 1, 0, 0],
                         [0, 0, xp.exp(-1j * theta / 2), 0],
                         [0, 0, 0, xp.exp(1j * theta / 2)]], dtype=complex)

    return {
        'rx': rx_gate, 'ry': ry_gate, 'rz': rz_gate,
        # OpenQASM 2.0: u2(phi, lam) = u3(pi/2, phi, lam); u1(lam) = p(lam).
        'u3': u3_gate, 'u2': lambda p, l: u3_gate(np.pi / 2, p, l),
        'u1': p_gate, 'p': p_gate,
        'cp': cp_gate, 'crz': crz_gate,
    }


PARAMETRIC_GATES = _build_parametric_gates(use_jax=HAS_JAX)

print('✅ Gate library caricata (Parametric gates: JAX-safe with NumPy fallback)')
print(f' Gate 1q: {list(GATES.keys())[:6]}...')
print(f' Parametrici: {list(PARAMETRIC_GATES.keys())}')
|
|
234
|
+
|
|
235
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
236
|
+
# CELLA 4: JAX JIT Compilation (Optional)
|
|
237
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
238
|
+
|
|
239
|
+
if HAS_JAX:
    def _jax_apply_gate_1q_einsum_impl(sv_array, gate_array, n_qubits, qubit_idx):
        """Apply a 2x2 gate to qubit ``qubit_idx`` of a flat statevector.

        The qubit's axis of the [2] * n tensor view is moved last, the tensor
        is flattened to (2**(n-1), 2) rows that are multiplied by gate.T, and
        the axis is moved back. n_qubits and qubit_idx are static under jit.
        """
        n, q = int(n_qubits), int(qubit_idx)
        tensor_shape = (2,) * n
        rows = jnp.moveaxis(sv_array.reshape(tensor_shape), q, -1).reshape(1 << (n - 1), 2)
        updated = jnp.dot(rows, gate_array.T).reshape(tensor_shape)
        return jnp.moveaxis(updated, -1, q).ravel()

    jax_apply_gate_1q_einsum = jax.jit(
        _jax_apply_gate_1q_einsum_impl, static_argnums=(2, 3)
    )

    def _jax_apply_gate_2q_einsum_impl(sv_array, gate_array, n_qubits, q1, q2):
        """Apply a 4x4 (or 2x2x2x2) gate to qubits (q1, q2) of a flat statevector.

        q1 becomes the more significant bit of the gate's 4-dim index.
        n_qubits, q1 and q2 are static under jit.
        """
        n = int(n_qubits)
        tensor_shape = (2,) * n
        pair_axes = (q1, q2)
        rows = jnp.moveaxis(sv_array.reshape(tensor_shape), pair_axes, (-2, -1)).reshape(1 << (n - 2), 4)
        updated = jnp.dot(rows, gate_array.reshape(4, 4).T).reshape(tensor_shape)
        return jnp.moveaxis(updated, (-2, -1), pair_axes).ravel()

    jax_apply_gate_2q_einsum = jax.jit(
        _jax_apply_gate_2q_einsum_impl, static_argnums=(2, 3, 4)
    )
    print("💎 JAX JIT compilation attiva per gate 1q e 2q")
else:
    # No JAX: DenseSVSimulator checks these for None and uses NumPy/CuPy.
    jax_apply_gate_1q_einsum = jax_apply_gate_2q_einsum = None
    print("ℹ️ JAX JIT non disponibile — fallback NumPy ottimizzato")
|
|
267
|
+
|
|
268
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
269
|
+
# CELLA 5: DenseSVSimulator Core Engine
|
|
270
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
271
|
+
|
|
272
|
+
class DenseSVSimulator:
    """
    Quantum circuit simulator using a dense statevector representation.

    Features:
    - NumPy/CuPy/JAX backend with automatic selection
    - 1q and 2q gates via back-axis dot products on the [2] * n tensor view
    - Optional JAX JIT kernels (jax_apply_gate_1q_einsum / jax_apply_gate_2q_einsum)
    - Vectorized view-slicing measurement, collapse, CX and CZ

    Endianness: MSB-first. Qubit q is axis q of the [2] * n tensor view, i.e.
    it has stride 2 ** (n - 1 - q) in the flat statevector
    (bit 0 of the basis index = qubit n-1, bit n-1 = qubit 0).
    """

    def __init__(self, n_qubits: int, use_gpu: bool = True, use_float32: bool = False):
        """
        Allocate |0...0⟩ on the selected backend.

        Args:
            n_qubits: number of qubits; the statevector has 2 ** n_qubits entries.
            use_gpu: use CuPy when HAS_CUPY is true; otherwise JAX (if HAS_JAX)
                or NumPy is used.
            use_float32: store amplitudes as complex64 instead of complex128.
        """
        self.n = n_qubits
        self.dim = 2**n_qubits
        self.use_gpu = use_gpu and HAS_CUPY
        self.dtype = np.complex64 if use_float32 else np.complex128
        dtype_str = 'float32' if use_float32 else 'float64'

        if self.use_gpu:
            import cupy as cp
            self.xp = cp
            print(f'🎮 DenseSV: CuPy GPU | n={n_qubits} | dim={self.dim:,}')
        elif HAS_JAX:
            self.xp = jnp
            print(f'⚡ DenseSV: JAX CPU/TPU | n={n_qubits} | dim={self.dim:,} | {dtype_str}')
        else:
            self.xp = np
            print(f'⌨️ DenseSV: NumPy CPU | n={n_qubits} | dim={self.dim:,} | {dtype_str}')

        # Initialize |00...0⟩
        self.set_initial_state()

        ram_mb = (self.dim * (8 if use_float32 else 16)) / (1024**2)
        print(f' RAM allocata: {ram_mb:.1f} MB')
        if ram_mb > 1000:
            print(' ⚠️ >1GB: Richiede architettura Full Vector ottimizzata a basso livello')

    def _is_jax_backend(self) -> bool:
        """Return True when amplitudes live in immutable JAX arrays.

        __init__ selects exactly one of CuPy (use_gpu), JAX or NumPy, so the
        backend is JAX precisely when it is neither the GPU nor NumPy.
        """
        return not self.use_gpu and self.xp is not np

    def set_initial_state(self, state_vector: Optional[np.ndarray] = None):
        """
        Load ``state_vector`` (normalized) or reset to |0...0⟩ when None.

        Raises:
            ValueError: if ``state_vector`` does not have 2 ** n entries.
        """
        xp = self.xp
        if state_vector is None:
            self.sv = xp.zeros(self.dim, dtype=self.dtype)
            if self._is_jax_backend():
                self.sv = self.sv.at[0].set(1.0)
            else:
                self.sv[0] = 1.0
            return

        if len(state_vector) != self.dim:
            raise ValueError(f"State vector must have length 2^n ({self.dim})")
        # xp.array (not asarray) always copies: normalize(), apply_cx(),
        # measure() etc. update self.sv in place and must never write into the
        # caller's array. The copy is also guaranteed to be contiguous.
        self.sv = xp.array(state_vector, dtype=self.dtype)
        self.normalize()

    def apply_gate_1q(self, gate: np.ndarray, qubit: int):
        """Apply a 2x2 gate to ``qubit`` (MSB-first index).

        Raises:
            ValueError: if ``qubit`` is out of range.
        """
        if not 0 <= qubit < self.n:
            raise ValueError(f"Qubit index {qubit} out of bounds for {self.n} qubits")
        self._apply_gate_fast(gate, qubit)

    def _apply_gate_fast(self, gate: np.ndarray, qubit: int):
        """Vectorized O(2^n) 1q gate application using a back-axis dot product."""
        xp = self.xp
        g = xp.asarray(gate, dtype=self.dtype)

        if self._is_jax_backend() and jax_apply_gate_1q_einsum is not None:
            self.sv = jax_apply_gate_1q_einsum(self.sv, g, self.n, qubit)
            return

        # Move the qubit's axis last so every row holds one (|0⟩, |1⟩) pair,
        # then apply the gate to all rows with a single matrix product.
        sv_nd = self.sv.reshape([2] * self.n)
        sv_moved = xp.moveaxis(sv_nd, qubit, -1)
        result_moved = xp.dot(sv_moved.reshape(1 << (self.n - 1), 2), g.T)
        result_nd = result_moved.reshape([2] * self.n)
        self.sv = xp.moveaxis(result_nd, -1, qubit).ravel()

    def apply_gate_2q(self, gate: np.ndarray, q1: int, q2: int):
        """Apply a 4x4 (or 2x2x2x2) gate to (q1, q2); q1 is the gate's high bit.

        Raises:
            ValueError: if an index is out of range or q1 == q2.
        """
        xp = self.xp
        if not (0 <= q1 < self.n and 0 <= q2 < self.n and q1 != q2):
            raise ValueError(f"Invalid qubit indices ({q1}, {q2})")

        g_2d = xp.asarray(gate, dtype=self.dtype).reshape(4, 4)

        if self._is_jax_backend() and jax_apply_gate_2q_einsum is not None:
            self.sv = jax_apply_gate_2q_einsum(self.sv, g_2d, self.n, q1, q2)
            return

        # Move (q1, q2) to the last two axes -> rows of 4 amplitudes |q1 q2⟩.
        sv_nd = self.sv.reshape([2] * self.n)
        sv_moved = xp.moveaxis(sv_nd, (q1, q2), (-2, -1))
        result_moved = xp.dot(sv_moved.reshape(1 << (self.n - 2), 4), g_2d.T)
        result_nd = result_moved.reshape([2] * self.n)
        self.sv = xp.moveaxis(result_nd, (-2, -1), (q1, q2)).ravel()

    def _pair_view(self, qa: int, qb: int):
        """View self.sv as (outer, 2, mid, 2, inner) split on the bits of qa and qb.

        Axis 1 is the bit with the larger stride, axis 3 the one with the
        smaller stride. Returns ``(view, qa_is_axis1)``. For NumPy/CuPy the
        result is a view of a contiguous statevector; callers still assign
        ``view.reshape(self.dim)`` back so a copying reshape cannot lose writes.
        """
        stride_a = 1 << (self.n - 1 - qa)
        stride_b = 1 << (self.n - 1 - qb)
        hi, lo = max(stride_a, stride_b), min(stride_a, stride_b)
        view = self.sv.reshape(-1, 2, hi // (2 * lo), 2, lo)
        return view, stride_a > stride_b

    def apply_cx(self, ctrl: int, tgt: int):
        """Controlled-X (CNOT): swap target-0/target-1 amplitudes where ctrl = 1.

        Raises:
            ValueError: if an index is out of range or ctrl == tgt.
        """
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
            raise ValueError(f"Invalid control ({ctrl}) or target ({tgt})")

        if self._is_jax_backend():
            # Immutable arrays: reuse the (JIT-able) 2q engine.
            cx_mat = self.xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
                                   dtype=self.dtype)
            self.apply_gate_2q(cx_mat, ctrl, tgt)
            return

        # NumPy/CuPy: one vectorized block swap instead of a Python loop.
        view, ctrl_is_axis1 = self._pair_view(ctrl, tgt)
        if ctrl_is_axis1:
            tgt0, tgt1 = view[:, 1, :, 0, :], view[:, 1, :, 1, :]
        else:
            tgt0, tgt1 = view[:, 0, :, 1, :], view[:, 1, :, 1, :]
        saved = tgt0.copy()
        tgt0[...] = tgt1
        tgt1[...] = saved
        self.sv = view.reshape(self.dim)

    def apply_cz(self, ctrl: int, tgt: int):
        """Controlled-Z: negate amplitudes where both qubits are 1.

        Raises:
            ValueError: if an index is out of range or ctrl == tgt.
        """
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
            raise ValueError(f"Invalid indices ({ctrl}, {tgt})")

        if self._is_jax_backend():
            cz_mat = self.xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]],
                                   dtype=self.dtype)
            self.apply_gate_2q(cz_mat, ctrl, tgt)
            return

        # CZ is symmetric, so which qubit lands on axis 1 is irrelevant.
        view, _ = self._pair_view(ctrl, tgt)
        view[:, 1, :, 1, :] *= -1
        self.sv = view.reshape(self.dim)

    def normalize(self):
        """Rescale the statevector to unit norm (no-op if the norm is ~0)."""
        norm = self.xp.linalg.norm(self.sv)
        if norm > 1e-12:
            if self._is_jax_backend():
                self.sv = self.sv / norm
            else:
                self.sv /= norm

    def get_probabilities(self) -> np.ndarray:
        """Return |ψ|² for each basis state as a float64 NumPy array."""
        probs = self.xp.abs(self.sv)**2
        if self.use_gpu:
            probs = probs.get()
        return np.asarray(probs, dtype=np.float64)

    def get_statevector(self) -> np.ndarray:
        """Return a NumPy copy of the statevector (device-to-host for CuPy)."""
        if self.use_gpu:
            return self.sv.get()
        return np.array(self.sv)

    def measure(self, qubit_idx: int) -> int:
        """
        Projectively measure one qubit and collapse the statevector.

        The flat statevector is viewed as (outer, 2, stride) with
        stride = 2 ** (n - 1 - qubit_idx) (the qubit's MSB-first stride), so
        axis 1 is the measured bit on every backend. (The former JAX branch
        used axis n-1-qubit_idx of the [2]*n tensor and measured the mirrored
        qubit.)

        Returns:
            The sampled outcome, 0 or 1.
        Raises:
            ValueError: if the qubit is out of range or the state has ~zero norm.
        """
        if not 0 <= qubit_idx < self.n:
            raise ValueError(f"Qubit {qubit_idx} out of bounds")

        xp = self.xp
        stride = 1 << (self.n - 1 - qubit_idx)
        sv3 = self.sv.reshape(-1, 2, stride)

        weights = xp.abs(sv3) ** 2
        prob_0 = float(xp.sum(weights[:, 0, :]))
        prob_1 = float(xp.sum(weights[:, 1, :]))
        total = prob_0 + prob_1
        if total <= 1e-12:
            raise ValueError("Cannot measure a statevector with zero norm")
        prob_0 /= total
        prob_1 /= total

        result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

        # Collapse: zero every amplitude inconsistent with the outcome.
        drop = 1 - result
        if self._is_jax_backend():
            sv3 = sv3.at[:, drop, :].set(0.0)
        else:
            sv3[:, drop, :] = 0.0
        self.sv = sv3.reshape(self.dim)

        self.normalize()
        return result

    def memory_mb(self) -> float:
        """Estimate statevector RAM usage in MB (10**6 bytes)."""
        elem_size = 8 if self.dtype == np.complex64 else 16
        return self.dim * elem_size / 1e6
|
|
504
|
+
|
|
505
|
+
# ┌─────────────────────────────────────────────────────────────────┐
|
|
506
|
+
# │ PARAMETRIC GATE INJECTION (VERSIONE INTEGRALE DA REPOSITORY) │
|
|
507
|
+
# └─────────────────────────────────────────────────────────────────┘
|
|
508
|
+
|
|
509
|
+
def patch_dense_parametric(cls):
    """
    Attach the OpenQASM 2.0 parametric-gate methods to ``cls``.

    Each injected method builds its matrix with the matching factory from the
    module-level PARAMETRIC_GATES table and forwards it to ``apply_gate_1q``
    (single-qubit gates) or ``apply_gate_2q`` (cp, crz). ``cls`` is modified
    in place; nothing is returned.
    """

    # Single-qubit rotations.
    def apply_rx(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['rx'](theta), qubit)

    def apply_ry(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['ry'](theta), qubit)

    def apply_rz(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['rz'](theta), qubit)

    # Generic single-qubit unitaries and phase gates.
    def apply_u3(self, qubit: int, theta: float, phi: float, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u3'](theta, phi, lam), qubit)

    def apply_u2(self, qubit: int, phi: float, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u2'](phi, lam), qubit)

    def apply_u1(self, qubit: int, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u1'](lam), qubit)

    def apply_p(self, qubit: int, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['p'](lam), qubit)

    # Controlled two-qubit gates.
    def apply_cp(self, ctrl: int, tgt: int, lam: float):
        self.apply_gate_2q(PARAMETRIC_GATES['cp'](lam), ctrl, tgt)

    def apply_crz(self, ctrl: int, tgt: int, theta: float):
        self.apply_gate_2q(PARAMETRIC_GATES['crz'](theta), ctrl, tgt)

    # Install every method under its own function name.
    for method in (apply_rx, apply_ry, apply_rz, apply_u3, apply_u2,
                   apply_u1, apply_p, apply_cp, apply_crz):
        setattr(cls, method.__name__, method)

    print("✅ All parametric methods (including u1/u2) injected into DenseSVSimulator")
|
|
560
|
+
|
|
561
|
+
# Attach apply_rx/ry/rz/u3/u2/u1/p/cp/crz to DenseSVSimulator at import time.
patch_dense_parametric(DenseSVSimulator)
|
|
562
|
+
|
|
563
|
+
import numpy as np
|
|
564
|
+
from typing import Optional, List, Dict
|
|
565
|
+
import time
|
|
566
|
+
|
|
567
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
568
|
+
# CELLA 6: Modelli di rumore con operatori Kraus (VERSIONE INTEGRALE JAX FIXED)
|
|
569
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
570
|
+
# [PROPRIETARY ALGORITHM - (c) 2026 Salvatore Pennacchio - Licensed under EUPL-1.2]
|
|
571
|
+
|
|
572
|
+
try:
    import jax
    import jax.numpy as jnp
    HAS_JAX = True
except Exception:
    # Match the first JAX detection, which tolerates any import failure:
    # catching only ImportError would let e.g. a jax/jaxlib version-mismatch
    # error (not an ImportError) crash this cell after the first one passed.
    HAS_JAX = False
    jnp = None  # keep the "no JAX" sentinel consistent with the first detection
|
|
578
|
+
|
|
579
|
+
class NoiseModel:
    """
    Five physical decoherence models described by Kraus operators.

    Noise is applied as a stochastic quantum trajectory. For each target qubit
    ONE Kraus operator is sampled and applied to the whole statevector.
    Sampling independently per amplitude pair, as before, does not implement
    the channel and can even entangle qubits.

    NumPy/CuPy inputs are copied and the copy is updated in place. JAX arrays
    are updated functionally through ``.at[]``.

    Qubit indexing is MSB-first, like DenseSVSimulator: qubit q has stride
    2 ** (n - 1 - q) in the flat statevector.
    """

    MODELS = ['ideal', 'depolarizing', 'bitflip', 'phaseflip', 'amplitude_damping', 'combined']

    @staticmethod
    def apply_to_sv(sv: np.ndarray, n: int, model: str, p: float,
                    rng: Optional[np.random.Generator] = None, qubits: Optional[List[int]] = None,
                    jax_key: Optional[object] = None) -> np.ndarray:
        """
        Apply one stochastic noise step (quantum jump unravelling) to ``sv``.

        Args:
            sv: statevector of length 2 ** n (NumPy, CuPy or JAX array).
            n: total number of qubits in the register.
            model: one of NoiseModel.MODELS; unknown names apply no error.
            p: error probability / damping parameter.
            rng: NumPy Generator used for NumPy/CuPy inputs (a fresh
                OS-seeded generator when None).
            qubits: MSB-first qubit indices to apply noise to; None means all
                qubits, an empty list means none.
            jax_key: jax.random PRNG key used for JAX inputs (a fresh random
                key when None).

        Returns:
            The renormalized noisy statevector. ``sv`` itself is returned
            unchanged when model == 'ideal' or p <= 0; otherwise the input is
            never modified.

        Raises:
            ValueError: if len(sv) != 2 ** n or a qubit index is out of range.
        """
        if model == 'ideal' or p <= 0:
            return sv

        dim = len(sv)
        if dim != 1 << n:
            raise ValueError(f"Statevector length {dim} does not match 2^{n}")
        target_qubits = list(range(n)) if qubits is None else list(qubits)
        for q in target_qubits:
            if not 0 <= q < n:
                raise ValueError(f"Qubit {q} out of range for {n} qubits")

        # JAX arrays are immutable and expose the functional `.at[]` update
        # API; NumPy/CuPy arrays do not and are mutated in place instead.
        functional = hasattr(sv, 'at') and not isinstance(sv, np.ndarray)

        # Unseeded fallbacks: the former time-based seeds gave identical noise
        # to calls made within the same second / millisecond.
        if functional:
            if jax_key is None:
                jax_key = jax.random.PRNGKey(int(np.random.default_rng().integers(2**31)))
        elif rng is None:
            rng = np.random.default_rng()
        key = jax_key

        def draw() -> float:
            """Return one U[0, 1) sample from the backend-appropriate RNG."""
            nonlocal key
            if functional:
                key, subkey = jax.random.split(key)
                return float(jax.random.uniform(subkey))
            return float(rng.random())

        sv_local = sv if functional else sv.copy()
        for q in target_qubits:
            # (outer, 2, stride) view: axis 1 is qubit q's bit.
            view = sv_local.reshape(-1, 2, 1 << (n - 1 - q))
            r = draw()
            if model == 'depolarizing':
                # K1..K3 = sqrt(p/3) X, Z, Y sampled on consecutive p/3 slots.
                pauli = 'x' if r < p / 3 else 'z' if r < 2 * p / 3 else 'y' if r < p else None
                view = NoiseModel._apply_pauli(view, pauli, functional)
            elif model == 'bitflip':
                view = NoiseModel._apply_pauli(view, 'x' if r < p else None, functional)
            elif model == 'phaseflip':
                view = NoiseModel._apply_pauli(view, 'z' if r < p else None, functional)
            elif model == 'amplitude_damping':
                view = NoiseModel._amplitude_damp(view, p, r, functional)
            elif model == 'combined':
                # Depolarizing with total Pauli probability 0.6p, then
                # amplitude damping with gamma = 0.3p (independent draw).
                pauli = ('x' if r < p * 0.2 else 'z' if r < p * 0.4
                         else 'y' if r < p * 0.6 else None)
                view = NoiseModel._apply_pauli(view, pauli, functional)
                view = NoiseModel._amplitude_damp(view, p * 0.3, draw(), functional)
            sv_local = view.reshape(dim)

        # Trajectory renormalization (the epsilon guards an all-zero state).
        norm = float((abs(sv_local) ** 2).sum()) ** 0.5
        return sv_local / (norm + 1e-15)

    @staticmethod
    def _apply_pauli(view, pauli: Optional[str], functional: bool):
        """Apply Pauli 'x', 'y' or 'z' (None = identity) to axis 1 of an (outer, 2, inner) view.

        Returns the updated array: a new array when ``functional`` (JAX),
        otherwise ``view`` itself, modified in place.
        """
        if pauli is None:
            return view
        a0, a1 = view[:, 0, :], view[:, 1, :]
        if functional:
            if pauli == 'x':
                return view.at[:, 0, :].set(a1).at[:, 1, :].set(a0)
            if pauli == 'y':
                # Y = [[0, -i], [i, 0]]
                return view.at[:, 0, :].set(-1j * a1).at[:, 1, :].set(1j * a0)
            return view.at[:, 1, :].multiply(-1)
        if pauli == 'z':
            a1 *= -1
            return view
        old0 = a0.copy()
        if pauli == 'x':
            a0[...] = a1
            a1[...] = old0
        else:  # 'y'
            a0[...] = -1j * a1
            a1[...] = 1j * old0
        return view

    @staticmethod
    def _amplitude_damp(view, gamma: float, r: float, functional: bool):
        """One amplitude-damping trajectory step on axis 1 of an (outer, 2, inner) view.

        The jump K1 = sqrt(gamma)|0⟩⟨1| is taken with probability
        gamma * P(qubit = 1); otherwise K0 = diag(1, sqrt(1 - gamma)) is
        applied. The overall scale is fixed by the caller's final
        renormalization.
        """
        w0 = float((abs(view[:, 0, :]) ** 2).sum())
        w1 = float((abs(view[:, 1, :]) ** 2).sum())
        total = w0 + w1
        p_jump = gamma * w1 / total if total > 0 else 0.0
        if r < p_jump:
            # Quantum jump: the |1⟩ amplitudes decay into |0⟩.
            if functional:
                return view.at[:, 0, :].set(view[:, 1, :]).at[:, 1, :].set(0)
            view[:, 0, :] = view[:, 1, :]
            view[:, 1, :] = 0
            return view
        damp = float(np.sqrt(1 - gamma))
        if functional:
            return view.at[:, 1, :].multiply(damp)
        view[:, 1, :] *= damp
        return view

    @staticmethod
    def kraus_description(model: str) -> Dict:
        """Return Kraus count, formula and physical meaning for ``model`` ('ideal' if unknown)."""
        desc = {
            'ideal': {'kraus': 1, 'formula': 'K\u2080=I', 'physical': 'Nessun rumore'},
            'depolarizing': {'kraus': 4, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221a(p/3)X, K\u2082=\u221a(p/3)Y, K\u2083=\u221a(p/3)Z', 'physical': 'Errore isotropo'},
            'bitflip': {'kraus': 2, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221ap\u00b7X', 'physical': 'Flip di qubit \u03c3_x'},
            'phaseflip': {'kraus': 2, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221ap\u00b7Z', 'physical': 'Dephasing puro'},
            'amplitude_damping': {'kraus': 2, 'formula': 'K\u2080=diag(1,\u221a(1-\u03b3)), K\u2081=[[0,\u221a\u03b3],[0,0]]', 'physical': 'Decadimento T\u2081 (relassazione)'},
            # apply_to_sv uses total Pauli probability 0.6p (0.2p each) + damping 0.3p.
            'combined': {'kraus': 6, 'formula': 'Dep(p*0.6) + AmpDamp(p*0.3)', 'physical': 'Worst-case NISQ'},
        }
        return desc.get(model, desc['ideal'])


print("✅ NoiseModel aggiornato (EUPL-1.2): Pieno supporto stocastico runtime JAX JIT sigillato!")
|
|
748
|
+
|
|
749
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
750
|
+
# CELLA 7: OpenQASM 2.0 Parser & Transpiler
|
|
751
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
752
|
+
|
|
753
|
+
@dataclass
|
|
754
|
+
class QASMCircuit:
|
|
755
|
+
n_qubits: int = 0
|
|
756
|
+
n_cbits: int = 0
|
|
757
|
+
ops: List[Dict[str, any]] = field(default_factory=list)
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
class QASMParser:
|
|
761
|
+
"""Parse OpenQASM 2.0 circuits - Production Grade GitHub Standard"""
|
|
762
|
+
|
|
763
|
+
# Pre-compilazione delle espressioni regolari per eliminare l'overhead di parsing a runtime
|
|
764
|
+
_REG_QUBIT = re.compile(r'\[(\d+)\]')
|
|
765
|
+
_ALIAS_MAP = {'cu1': 'cp', 'u1': 'p', 'toffoli': 'ccx', 'fredkin': 'cswap'}
|
|
766
|
+
_MATH_ENV = {'__builtins__': {}, 'np': np, 'pi': np.pi, 'sin': np.sin,
|
|
767
|
+
'cos': np.cos, 'sqrt': np.sqrt, 'exp': np.exp}
|
|
768
|
+
|
|
769
|
+
def parse(self, qasm_str: str) -> QASMCircuit:
|
|
770
|
+
n_qubits, n_cbits, ops = 0, 0, []
|
|
771
|
+
|
|
772
|
+
# Rimozione dei commenti in un unico passaggio lineare
|
|
773
|
+
clean = []
|
|
774
|
+
for raw in qasm_str.split('\n'):
|
|
775
|
+
line = raw.split('//')[0].strip()
|
|
776
|
+
if line:
|
|
777
|
+
clean.append(line)
|
|
778
|
+
|
|
779
|
+
# Tokenizzazione efficiente basata su delimitatore di istruzione standard ';'
|
|
780
|
+
for instr in "".join(clean).split(';'):
|
|
781
|
+
instr = instr.strip()
|
|
782
|
+
if not instr or any(instr.startswith(t) for t in ('OPENQASM', 'include', 'barrier')):
|
|
783
|
+
continue
|
|
784
|
+
|
|
785
|
+
if instr.startswith('qreg'):
|
|
786
|
+
m = self._REG_QUBIT.search(instr)
|
|
787
|
+
if m:
|
|
788
|
+
n_qubits = int(m.group(1))
|
|
789
|
+
continue
|
|
790
|
+
|
|
791
|
+
if instr.startswith('creg'):
|
|
792
|
+
m = self._REG_QUBIT.search(instr)
|
|
793
|
+
if m:
|
|
794
|
+
n_cbits = int(m.group(1))
|
|
795
|
+
continue
|
|
796
|
+
|
|
797
|
+
if instr.startswith('measure'):
|
|
798
|
+
continue
|
|
799
|
+
|
|
800
|
+
parts = instr.split()
|
|
801
|
+
if not parts:
|
|
802
|
+
continue
|
|
803
|
+
|
|
804
|
+
gate_raw = parts[0]
|
|
805
|
+
gate_name = gate_raw.split('(')[0].lower()
|
|
806
|
+
gate_name = self._ALIAS_MAP.get(gate_name, gate_name)
|
|
807
|
+
|
|
808
|
+
# Estrazione e risoluzione matematica deterministica dei parametri angolari delle porte
|
|
809
|
+
params: List[float] = []
|
|
810
|
+
if '(' in gate_raw:
|
|
811
|
+
try:
|
|
812
|
+
inner = gate_raw[gate_raw.index('(') + 1 : gate_raw.index(')')]
|
|
813
|
+
for tok in inner.split(','):
|
|
814
|
+
tok = tok.strip()
|
|
815
|
+
if tok:
|
|
816
|
+
params.append(float(eval(tok, self._MATH_ENV)))
|
|
817
|
+
except Exception:
|
|
818
|
+
params.append(0.0)
|
|
819
|
+
|
|
820
|
+
# Estrazione simultanea di tutti i qubit target coinvolti dall'istruzione
|
|
821
|
+
qubit_indices = [int(x) for x in self._REG_QUBIT.findall(" ".join(parts[1:]))]
|
|
822
|
+
|
|
823
|
+
if qubit_indices:
|
|
824
|
+
ops.append({'type': 'gate', 'name': gate_name,
|
|
825
|
+
'qubits': qubit_indices, 'params': params})
|
|
826
|
+
|
|
827
|
+
return QASMCircuit(n_qubits, n_cbits, ops)
|
|
828
|
+
|
|
829
|
+
def validate(self, circ: QASMCircuit) -> Tuple[bool, str]:
|
|
830
|
+
if circ.n_qubits <= 0:
|
|
831
|
+
return False, "n_qubits deve essere > 0."
|
|
832
|
+
if not circ.ops:
|
|
833
|
+
return False, "Nessuna operazione rilevata nel circuito quantistico."
|
|
834
|
+
return True, ""
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
class QuantumTranspiler:
|
|
838
|
+
"""Decompose multi-qubit gates into 1q and 2q execution primitives"""
|
|
839
|
+
|
|
840
|
+
@staticmethod
|
|
841
|
+
def decompose_toffoli(c1: int, c2: int, t: int) -> List[Tuple]:
|
|
842
|
+
"""Barenco et al. decomposition optimized for Full-Vector mapping (6 CNOT gates)"""
|
|
843
|
+
return [
|
|
844
|
+
('h', t),
|
|
845
|
+
('cx', c2, t), ('tdg', t),
|
|
846
|
+
('cx', c1, t), ('t', t),
|
|
847
|
+
('cx', c2, t), ('tdg', t),
|
|
848
|
+
('cx', c1, t),
|
|
849
|
+
('t', c2), ('t', t),
|
|
850
|
+
('cx', c1, c2), ('h', t),
|
|
851
|
+
('t', c1), ('tdg', c2),
|
|
852
|
+
('cx', c1, c2),
|
|
853
|
+
]
|
|
854
|
+
|
|
855
|
+
@staticmethod
|
|
856
|
+
def decompose_swap(q1: int, q2: int) -> List[Tuple]:
|
|
857
|
+
"""Decompose SWAP into core CNOT sequence compatible with hardware strides"""
|
|
858
|
+
return [('cx', q1, q2), ('cx', q2, q1), ('cx', q1, q2)]
|
|
859
|
+
|
|
860
|
+
@staticmethod
|
|
861
|
+
def transpile(circuit: List[Tuple]) -> List[Tuple]:
|
|
862
|
+
"""Unroll high-level structures into native operational primitives"""
|
|
863
|
+
out = []
|
|
864
|
+
for cmd in circuit:
|
|
865
|
+
name = cmd[0].lower()
|
|
866
|
+
if name == 'ccx':
|
|
867
|
+
out.extend(QuantumTranspiler.decompose_toffoli(*cmd[1:]))
|
|
868
|
+
elif name == 'swap':
|
|
869
|
+
out.extend(QuantumTranspiler.decompose_swap(*cmd[1:]))
|
|
870
|
+
else:
|
|
871
|
+
out.append(cmd)
|
|
872
|
+
return out
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
print("✅ QASMParser and QuantumTranspiler loaded (Optimized regex & clean primitives)")
|
|
876
|
+
|
|
877
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
878
|
+
# CELLA 8: Circuit Execution (MSB-aware Engine Core - CORRETTA)
|
|
879
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
880
|
+
|
|
881
|
+
def run_circuit(self, circuit: List[Tuple], transpile: bool = True):
|
|
882
|
+
"""
|
|
883
|
+
Execute circuit with automatic endianness correction.
|
|
884
|
+
FIXED: Prevents index double-flipping for multi-qubit gates.
|
|
885
|
+
"""
|
|
886
|
+
target = QuantumTranspiler.transpile(circuit) if transpile else circuit
|
|
887
|
+
is_dense = hasattr(self, 'sv')
|
|
888
|
+
|
|
889
|
+
for cmd in target:
|
|
890
|
+
gate_name = cmd[0].lower()
|
|
891
|
+
args = cmd[1:]
|
|
892
|
+
|
|
893
|
+
mat = None
|
|
894
|
+
if gate_name in GATES:
|
|
895
|
+
mat = GATES[gate_name]
|
|
896
|
+
elif gate_name in PARAMETRIC_GATES:
|
|
897
|
+
try:
|
|
898
|
+
mat = PARAMETRIC_GATES[gate_name](*[a for a in args if isinstance(a, (float, int)) and not isinstance(a, bool)])
|
|
899
|
+
args = tuple([a for a in args if isinstance(a, int) and not isinstance(a, bool)])
|
|
900
|
+
except Exception:
|
|
901
|
+
pass
|
|
902
|
+
|
|
903
|
+
if mat is None:
|
|
904
|
+
method = getattr(self, f'apply_{gate_name}', None)
|
|
905
|
+
if method is None and gate_name == 'measure':
|
|
906
|
+
method = getattr(self, 'measure', None)
|
|
907
|
+
|
|
908
|
+
if method:
|
|
909
|
+
method(*args)
|
|
910
|
+
else:
|
|
911
|
+
raise ValueError(f"Porta quantistica o istruzione '{gate_name}' non riconosciuta.")
|
|
912
|
+
continue
|
|
913
|
+
|
|
914
|
+
mat = np.asarray(mat)
|
|
915
|
+
if mat.ndim == 2 and mat.shape == (2, 2):
|
|
916
|
+
# Pass logical qubit index directly. apply_gate_1q handles internal mapping.
|
|
917
|
+
self.apply_gate_1q(mat, args[0])
|
|
918
|
+
elif mat.ndim == 2 and mat.shape == (4, 4):
|
|
919
|
+
# Pass logical qubit indices directly. apply_gate_2q handles internal mapping.
|
|
920
|
+
self.apply_gate_2q(mat, args[0], args[1])
|
|
921
|
+
elif mat.ndim == 4:
|
|
922
|
+
# Pass logical qubit indices directly. apply_gate_2q handles internal mapping.
|
|
923
|
+
self.apply_gate_2q(mat.reshape(4, 4), args[0], args[1])
|
|
924
|
+
|
|
925
|
+
DenseSVSimulator.run_circuit = run_circuit
|
|
926
|
+
print("✅ run_circuit patchato con successo: allineamento indici MSB stabilizzato!")
|
|
927
|
+
|
|
928
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
929
|
+
# CORE COMPILATION ENGINE (KERNEL FUSION LINEARE AD ALLOCAZIONE ZERO)
|
|
930
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
931
|
+
|
|
932
|
+
import time
|
|
933
|
+
import numpy as np
|
|
934
|
+
from typing import List, Tuple
|
|
935
|
+
|
|
936
|
+
try:
|
|
937
|
+
import jax
|
|
938
|
+
import jax.numpy as jnp
|
|
939
|
+
HAS_JAX = True
|
|
940
|
+
# Abilita i 64-bit nativi in JAX per evitare overflow degli indici oltre i 24 qubit
|
|
941
|
+
jax.config.update("jax_enable_x64", True)
|
|
942
|
+
except ImportError:
|
|
943
|
+
HAS_JAX = False
|
|
944
|
+
|
|
945
|
+
# Mappatura stazionaria ottimizzata (0-11 per 1Q, 20-21 per 2Q)
|
|
946
|
+
GATE_IDS = {
|
|
947
|
+
'id': 0, 'h': 1, 'x': 2, 'y': 3, 'z': 4, 's': 5, 'sdg': 6, 't': 7, 'tdg': 8,
|
|
948
|
+
'rx': 9, 'ry': 10, 'rz': 11, 'cx': 20, 'cz': 21
|
|
949
|
+
}
|
|
950
|
+
|
|
951
|
+
if HAS_JAX:
|
|
952
|
+
@jax.jit
|
|
953
|
+
def _apply_gate_fast_step(sv, operation):
|
|
954
|
+
|
|
955
|
+
g_id, q1, q2, param = operation
|
|
956
|
+
dim = sv.shape[0]
|
|
957
|
+
|
|
958
|
+
inv2 = 1.0 / jnp.sqrt(2.0)
|
|
959
|
+
cos_p = jnp.cos(param / 2.0)
|
|
960
|
+
sin_p = jnp.sin(param / 2.0)
|
|
961
|
+
|
|
962
|
+
# Clamping protettivo dell'indice virtuale per evitare errori Out-of-Bounds in XLA
|
|
963
|
+
safe_gid = jnp.where(g_id <= 11, g_id, 0).astype(jnp.int32)
|
|
964
|
+
|
|
965
|
+
g_1q = jax.lax.switch(
|
|
966
|
+
safe_gid,
|
|
967
|
+
[
|
|
968
|
+
lambda _: jnp.eye(2, dtype=jnp.complex128), # 0: id
|
|
969
|
+
lambda _: inv2 * jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex128), # 1: h
|
|
970
|
+
lambda _: jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128), # 2: x
|
|
971
|
+
lambda _: jnp.array([[0.0, -1j], [1j, 0.0]], dtype=jnp.complex128), # 3: y
|
|
972
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128), # 4: z
|
|
973
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, 1j]], dtype=jnp.complex128), # 5: s
|
|
974
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, -1j]], dtype=jnp.complex128), # 6: sdg
|
|
975
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128), # 7: t
|
|
976
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(-1j * jnp.pi / 4)]], dtype=jnp.complex128), # 8: tdg
|
|
977
|
+
lambda _: jnp.array([[cos_p, -1j * sin_p], [-1j * sin_p, cos_p]], dtype=jnp.complex128), # 9: rx
|
|
978
|
+
lambda _: jnp.array([[cos_p, -sin_p], [sin_p, cos_p]], dtype=jnp.complex128), # 10: ry
|
|
979
|
+
lambda _: jnp.array([[jnp.exp(-1j * param / 2.0), 0.0], [0.0, jnp.exp(1j * param / 2.0)]], dtype=jnp.complex128) # 11: rz
|
|
980
|
+
],
|
|
981
|
+
operand=None
|
|
982
|
+
)
|
|
983
|
+
|
|
984
|
+
# 1-QUBIT: APPLICAZIONE LINEARE MATRICE-SPAZZATA (Zero .reshape, Zero indici, Norma protetta)
|
|
985
|
+
def do_1q(_sv):
|
|
986
|
+
t_bit = q1.astype(jnp.int64)
|
|
987
|
+
stride = 1 << t_bit
|
|
988
|
+
|
|
989
|
+
# Generazione implicita delle maschere di canale a 1D contigua ammessa da XLA
|
|
990
|
+
# Isola gli stati accoppiati specchiati proiettando direttamente il vettore originale
|
|
991
|
+
idx_full = jnp.arange(dim, dtype=jnp.int64)
|
|
992
|
+
mask_0 = (idx_full & stride) == 0
|
|
993
|
+
|
|
994
|
+
# Troviamo i puntatori specchiati esatti per ciascuna cella di memoria
|
|
995
|
+
idx_0 = jnp.where(mask_0, idx_full, idx_full ^ stride)
|
|
996
|
+
idx_1 = idx_0 | stride
|
|
997
|
+
|
|
998
|
+
# Estrazione sicura dei coefficienti scalari complessi per evitare conflitti di broadcasting
|
|
999
|
+
g00, g01, g10, g11 = g_1q[0, 0], g_1q[0, 1], g_1q[1, 0], g_1q[1, 1]
|
|
1000
|
+
|
|
1001
|
+
# Calcolo simultaneo della superposizione lineare lungo i registri della CPU
|
|
1002
|
+
new_sv0 = g00 * _sv[idx_0] + g01 * _sv[idx_1]
|
|
1003
|
+
new_sv1 = g10 * _sv[idx_0] + g11 * _sv[idx_1]
|
|
1004
|
+
|
|
1005
|
+
# Ri-assemblaggio lineare continuo (Costo di allocazione scratchpad = 0)
|
|
1006
|
+
return jnp.where(mask_0, new_sv0, new_sv1)
|
|
1007
|
+
|
|
1008
|
+
# 2-QUBIT: APPLICAZIONE LINEARE CONTROLLATA (Zero .reshape, Zero indici, Anti-OOM a 24 Qubit)
|
|
1009
|
+
def do_2q(_sv):
|
|
1010
|
+
ctrl = q1.astype(jnp.int64)
|
|
1011
|
+
trgt = q2.astype(jnp.int64)
|
|
1012
|
+
|
|
1013
|
+
idx_full = jnp.arange(dim, dtype=jnp.int64)
|
|
1014
|
+
ctrl_active = (idx_full & (1 << ctrl)) != 0
|
|
1015
|
+
trgt_active = (idx_full & (1 << trgt)) != 0
|
|
1016
|
+
|
|
1017
|
+
# Caso CX: Inversione del bit target tramite operatore XOR lineare specchiato
|
|
1018
|
+
cx_sv = _sv[idx_full ^ (1 << trgt)]
|
|
1019
|
+
# Caso CZ: Inversione di fase condizionale sullo stato eccitato comune |11>
|
|
1020
|
+
cz_sv = jnp.where(trgt_active, -_sv, _sv)
|
|
1021
|
+
|
|
1022
|
+
# Selezione condizionale fusa del bersaglio mutato
|
|
1023
|
+
target_sv = jax.lax.cond(g_id == 20, lambda _: cx_sv, lambda _: cz_sv, operand=None)
|
|
1024
|
+
|
|
1025
|
+
# Restituisce il vettore modificato solo nei canali attivi, preservando intatto il resto
|
|
1026
|
+
return jnp.where(ctrl_active, target_sv, _sv)
|
|
1027
|
+
|
|
1028
|
+
# Configurazione del tracciatore statico ed esecuzione dei rami fusi
|
|
1029
|
+
exec_1q = g_id <= 11
|
|
1030
|
+
new_sv = jax.lax.cond(exec_1q, do_1q, do_2q, sv)
|
|
1031
|
+
return new_sv, None
|
|
1032
|
+
|
|
1033
|
+
@jax.jit
|
|
1034
|
+
def _compile_and_run_circuit_jit(state_vector, compiled_ops):
|
|
1035
|
+
"""Pipeline lineare fusa in XLA tramite scansione nativa hardware"""
|
|
1036
|
+
final_sv, _ = jax.lax.scan(_apply_gate_fast_step, state_vector, compiled_ops)
|
|
1037
|
+
return final_sv
|
|
1038
|
+
|
|
1039
|
+
print("💎 CORE COMPILER SIGILLATO V4 (ULTRA): Rimossi definitivamente i reshape dinamici. Struttura JIT stabile ed esatta!")
|
|
1040
|
+
|
|
1041
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1042
|
+
# CORE COMPILATION ENGINE (KERNEL FUSION LINEARE AD ALLOCAZIONE ZERO)
|
|
1043
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1044
|
+
if HAS_JAX:
|
|
1045
|
+
@jax.jit
|
|
1046
|
+
def _apply_gate_fast_step(sv, operation):
|
|
1047
|
+
g_id, q1, q2, param = operation
|
|
1048
|
+
dim = sv.shape[0]
|
|
1049
|
+
inv2 = 1.0 / jnp.sqrt(2.0)
|
|
1050
|
+
cos_p = jnp.cos(param / 2.0)
|
|
1051
|
+
sin_p = jnp.sin(param / 2.0)
|
|
1052
|
+
safe_gid = jnp.where(g_id <= 11, g_id, 0).astype(jnp.int32)
|
|
1053
|
+
|
|
1054
|
+
g_1q = jax.lax.switch(
|
|
1055
|
+
safe_gid,
|
|
1056
|
+
[
|
|
1057
|
+
lambda _: jnp.eye(2, dtype=jnp.complex128),
|
|
1058
|
+
lambda _: inv2 * jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex128),
|
|
1059
|
+
lambda _: jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128),
|
|
1060
|
+
lambda _: jnp.array([[0.0, -1j], [1j, 0.0]], dtype=jnp.complex128),
|
|
1061
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128),
|
|
1062
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, 1j]], dtype=jnp.complex128),
|
|
1063
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, -1j]], dtype=jnp.complex128),
|
|
1064
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128),
|
|
1065
|
+
lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(-1j * jnp.pi / 4)]], dtype=jnp.complex128),
|
|
1066
|
+
lambda _: jnp.array([[cos_p, -1j * sin_p], [-1j * sin_p, cos_p]], dtype=jnp.complex128),
|
|
1067
|
+
lambda _: jnp.array([[cos_p, -sin_p], [sin_p, cos_p]], dtype=jnp.complex128),
|
|
1068
|
+
lambda _: jnp.array([[jnp.exp(-1j * param / 2.0), 0.0], [0.0, jnp.exp(1j * param / 2.0)]], dtype=jnp.complex128)
|
|
1069
|
+
],
|
|
1070
|
+
operand=None
|
|
1071
|
+
)
|
|
1072
|
+
|
|
1073
|
+
def do_1q(_sv):
|
|
1074
|
+
t_bit = q1.astype(jnp.int64)
|
|
1075
|
+
stride = 1 << t_bit
|
|
1076
|
+
idx_full = jnp.arange(dim, dtype=jnp.int64)
|
|
1077
|
+
mask_0 = (idx_full & stride) == 0
|
|
1078
|
+
idx_0 = jnp.where(mask_0, idx_full, idx_full ^ stride)
|
|
1079
|
+
idx_1 = idx_0 | stride
|
|
1080
|
+
g00, g01, g10, g11 = g_1q[0, 0], g_1q[0, 1], g_1q[1, 0], g_1q[1, 1]
|
|
1081
|
+
new_sv0 = g00 * _sv[idx_0] + g01 * _sv[idx_1]
|
|
1082
|
+
new_sv1 = g10 * _sv[idx_0] + g11 * _sv[idx_1]
|
|
1083
|
+
return jnp.where(mask_0, new_sv0, new_sv1)
|
|
1084
|
+
|
|
1085
|
+
def do_2q(_sv):
|
|
1086
|
+
ctrl = q1.astype(jnp.int64)
|
|
1087
|
+
trgt = q2.astype(jnp.int64)
|
|
1088
|
+
idx_full = jnp.arange(dim, dtype=jnp.int64)
|
|
1089
|
+
ctrl_active = (idx_full & (1 << ctrl)) != 0
|
|
1090
|
+
trgt_active = (idx_full & (1 << trgt)) != 0
|
|
1091
|
+
cx_sv = _sv[idx_full ^ (1 << trgt)]
|
|
1092
|
+
cz_sv = jnp.where(trgt_active, -_sv, _sv)
|
|
1093
|
+
target_sv = jax.lax.cond(g_id == 20, lambda _: cx_sv, lambda _: cz_sv, operand=None)
|
|
1094
|
+
return jnp.where(ctrl_active, target_sv, _sv)
|
|
1095
|
+
|
|
1096
|
+
exec_1q = g_id <= 11
|
|
1097
|
+
new_sv = jax.lax.cond(exec_1q, do_1q, do_2q, sv)
|
|
1098
|
+
return new_sv, None
|
|
1099
|
+
|
|
1100
|
+
@jax.jit
|
|
1101
|
+
def _compile_and_run_circuit_jit(state_vector, compiled_ops):
|
|
1102
|
+
final_sv, _ = jax.lax.scan(_apply_gate_fast_step, state_vector, compiled_ops)
|
|
1103
|
+
return final_sv
|
|
1104
|
+
|
|
1105
|
+
print("💎 CORE COMPILER PATCHATO V4: Struttura 1D lineare stabilizzata a norma fissa.")
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1109
|
+
# INTERFACCIA: run_circuit_jit_beast_mode (Mappatura Riallineata V2)
|
|
1110
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1111
|
+
|
|
1112
|
+
def run_circuit_jit_beast_mode(self, circuit: List[Tuple]):
|
|
1113
|
+
|
|
1114
|
+
if not (HAS_JAX and self.xp is jnp):
|
|
1115
|
+
print("⚠️ JAX non attivo o istanza non JAX. Esecuzione via run_circuit standard...")
|
|
1116
|
+
return self.run_circuit(circuit, transpile=True)
|
|
1117
|
+
|
|
1118
|
+
# Scomposizione preliminare delle macro-porte (Toffoli, SWAP) nelle primitive compatibili
|
|
1119
|
+
target = QuantumTranspiler.transpile(circuit)
|
|
1120
|
+
|
|
1121
|
+
compiled_list = []
|
|
1122
|
+
for cmd in target:
|
|
1123
|
+
g_name = cmd[0].lower()
|
|
1124
|
+
args = cmd[1:]
|
|
1125
|
+
|
|
1126
|
+
if g_name in GATE_IDS:
|
|
1127
|
+
g_id = GATE_IDS[g_name]
|
|
1128
|
+
|
|
1129
|
+
# Smistamento in base alla firma della porta nella nuova mappa stazionaria
|
|
1130
|
+
if g_name in ['rx', 'ry', 'rz']:
|
|
1131
|
+
# Struttura gate parametrico standard: (name, qubit, theta)
|
|
1132
|
+
q1 = float(args[0])
|
|
1133
|
+
q2 = 0.0
|
|
1134
|
+
param = float(args[1])
|
|
1135
|
+
elif g_name in ['cx', 'cz']:
|
|
1136
|
+
# Struttura gate a due qubit standard: (name, control, target)
|
|
1137
|
+
q1 = float(args[0])
|
|
1138
|
+
q2 = float(args[1])
|
|
1139
|
+
param = 0.0
|
|
1140
|
+
else:
|
|
1141
|
+
# Struttura gate fissa a un qubit standard: (name, qubit)
|
|
1142
|
+
q1 = float(args[0])
|
|
1143
|
+
q2 = 0.0
|
|
1144
|
+
param = 0.0
|
|
1145
|
+
|
|
1146
|
+
compiled_list.append([float(g_id), q1, q2, param])
|
|
1147
|
+
|
|
1148
|
+
if not compiled_list:
|
|
1149
|
+
return
|
|
1150
|
+
|
|
1151
|
+
# Generazione della matrice di operazioni numeriche coerente [N_porte, 4] in float64
|
|
1152
|
+
compiled_ops = jnp.array(compiled_list, dtype=jnp.float64)
|
|
1153
|
+
|
|
1154
|
+
# Invocazione della pipeline fusa XLA ad alte prestazioni
|
|
1155
|
+
self.sv = _compile_and_run_circuit_jit(self.sv, compiled_ops)
|
|
1156
|
+
|
|
1157
|
+
|
|
1158
|
+
# Iniezione dell'interfaccia corretta e ripulita nella classe del simulatore
|
|
1159
|
+
DenseSVSimulator.run_circuit_jit_beast_mode = run_circuit_jit_beast_mode
|
|
1160
|
+
print("💎 INTERFACCIA RIALLINEATA: 'run_circuit_jit_beast_mode' agganciata con successo a DenseSVSimulator!")
|
|
1161
|
+
|
|
1162
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1163
|
+
# AGGANCIO RUNTIME MANCANTI: measure & memory_mb
|
|
1164
|
+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
1165
|
+
|
|
1166
|
+
def measure(self, qubit_idx: int) -> int:
|
|
1167
|
+
|
|
1168
|
+
if not 0 <= qubit_idx < self.n:
|
|
1169
|
+
raise ValueError(f"Qubit {qubit_idx} out of bounds")
|
|
1170
|
+
|
|
1171
|
+
xp = self.xp
|
|
1172
|
+
|
|
1173
|
+
# Calcolo dell'indice fisico specchiato per la convenzione MSB
|
|
1174
|
+
phys_q = self.n - 1 - qubit_idx
|
|
1175
|
+
stride = 1 << phys_q
|
|
1176
|
+
|
|
1177
|
+
if xp is jnp:
|
|
1178
|
+
# Ramo JAX: Calcolo conforme al tracciamento statico dei tensori
|
|
1179
|
+
probs = self.xp.abs(self.sv)**2
|
|
1180
|
+
sv_shape = [2] * self.n
|
|
1181
|
+
sv_nd = probs.reshape(sv_shape)
|
|
1182
|
+
prob_0 = float(jnp.sum(jnp.moveaxis(sv_nd, phys_q, 0)[0]))
|
|
1183
|
+
prob_1 = float(jnp.sum(jnp.moveaxis(sv_nd, phys_q, 0)[1]))
|
|
1184
|
+
else:
|
|
1185
|
+
# Ramo NumPy/CuPy Ultra-Performante: Somma a salti in memoria (Zero allocazione)
|
|
1186
|
+
sv_reshaped = self.sv.reshape(-1, 2, stride)
|
|
1187
|
+
prob_0 = float(xp.sum(xp.abs(sv_reshaped[:, 0, :])**2))
|
|
1188
|
+
prob_1 = float(xp.sum(xp.abs(sv_reshaped[:, 1, :])**2))
|
|
1189
|
+
|
|
1190
|
+
# Normalizzazione delle probabilità estratte
|
|
1191
|
+
total = prob_0 + prob_1
|
|
1192
|
+
if total > 1e-12:
|
|
1193
|
+
prob_0 /= total
|
|
1194
|
+
prob_1 /= total
|
|
1195
|
+
|
|
1196
|
+
# Campionamento dell'esito della misura
|
|
1197
|
+
result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))
|
|
1198
|
+
|
|
1199
|
+
# Collasso della funzione d'onda in-place (Zero allocazione di maschere giganti)
|
|
1200
|
+
if xp is jnp:
|
|
1201
|
+
sv_shape = [2] * self.n
|
|
1202
|
+
sv_nd = self.sv.reshape(sv_shape)
|
|
1203
|
+
moved_sv = jnp.moveaxis(sv_nd, phys_q, 0)
|
|
1204
|
+
moved_sv = moved_sv.at[1 if result == 0 else 0].set(0.0)
|
|
1205
|
+
self.sv = jnp.moveaxis(moved_sv, 0, phys_q).ravel()
|
|
1206
|
+
else:
|
|
1207
|
+
# Slicing chirurgico nativo: azzera metà del vettore direttamente sulla matrice di vista
|
|
1208
|
+
sv_reshaped = self.sv.reshape(-1, 2, stride)
|
|
1209
|
+
sv_reshaped[:, 1 if result == 0 else 0, :] = 0.0
|
|
1210
|
+
|
|
1211
|
+
self.normalize()
|
|
1212
|
+
return result
|
|
1213
|
+
|
|
1214
|
+
def memory_mb(self) -> float:
|
|
1215
|
+
"""Estimate RAM usage in MB"""
|
|
1216
|
+
elem_size = 8 if self.dtype == np.complex64 else 16
|
|
1217
|
+
return self.dim * elem_size / 1e6
|
|
1218
|
+
|
|
1219
|
+
# Forza l'iniezione e l'ancoraggio dei due metodi nella classe principale
|
|
1220
|
+
DenseSVSimulator.measure = measure
|
|
1221
|
+
DenseSVSimulator.memory_mb = memory_mb
|
|
1222
|
+
|
|
1223
|
+
print("🚀 Metodi 'measure' e 'memory_mb' agganciati ed iniettati con successo in DenseSVSimulator!")
|
|
1224
|
+
|
|
1225
|
+
import random # Import the standard random module
|
|
1226
|
+
|
|
1227
|
+
def measure(self, qubit_idx: int) -> int:
|
|
1228
|
+
|
|
1229
|
+
if not 0 <= qubit_idx < self.n:
|
|
1230
|
+
raise ValueError(f"Qubit {qubit_idx} out of bounds")
|
|
1231
|
+
|
|
1232
|
+
xp = self.xp
|
|
1233
|
+
phys_q = self.n - 1 - qubit_idx
|
|
1234
|
+
stride = 1 << phys_q
|
|
1235
|
+
|
|
1236
|
+
if xp is jnp:
|
|
1237
|
+
# Ramo JAX: Calcolo esatto estraendo gli indici ffetivi dell'asse tensoriale spostato
|
|
1238
|
+
probs = self.xp.abs(self.sv)**2
|
|
1239
|
+
sv_shape = [2] * self.n
|
|
1240
|
+
sv_nd = probs.reshape(sv_shape)
|
|
1241
|
+
moved_probs = jnp.moveaxis(sv_nd, phys_q, 0)
|
|
1242
|
+
prob_0 = float(jnp.sum(moved_probs[0]))
|
|
1243
|
+
prob_1 = float(jnp.sum(moved_probs[1]))
|
|
1244
|
+
else:
|
|
1245
|
+
# Ramo NumPy/CuPy Stride Slicing
|
|
1246
|
+
sv_reshaped = self.sv.reshape(-1, 2, stride)
|
|
1247
|
+
prob_0 = float(xp.sum(xp.abs(sv_reshaped[:, 0, :])**2))
|
|
1248
|
+
prob_1 = float(xp.sum(xp.abs(sv_reshaped[:, 1, :])**2))
|
|
1249
|
+
|
|
1250
|
+
total = prob_0 + prob_1
|
|
1251
|
+
if total > 1e-12:
|
|
1252
|
+
prob_0 /= total
|
|
1253
|
+
prob_1 /= total
|
|
1254
|
+
|
|
1255
|
+
# Campionamento dell'esito della misura
|
|
1256
|
+
result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))
|
|
1257
|
+
|
|
1258
|
+
if xp is jnp:
|
|
1259
|
+
sv_shape = [2] * self.n
|
|
1260
|
+
sv_nd = self.sv.reshape(sv_shape)
|
|
1261
|
+
moved_sv = jnp.moveaxis(sv_nd, phys_q, 0)
|
|
1262
|
+
# FIX: Correctly zero out the unmeasured component (1 if result is 0, 0 if result is 1)
|
|
1263
|
+
moved_sv = moved_sv.at[1 - result].set(0.0)
|
|
1264
|
+
self.sv = jnp.moveaxis(moved_sv, 0, phys_q).ravel()
|
|
1265
|
+
else:
|
|
1266
|
+
sv_reshaped = self.sv.reshape(-1, 2, stride)
|
|
1267
|
+
sv_reshaped[:, 1 if result == 0 else 0, :] = 0.0
|
|
1268
|
+
|
|
1269
|
+
self.normalize()
|
|
1270
|
+
return result
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
def apply_cx(self, ctrl: int, tgt: int):
|
|
1274
|
+
"""CNOT gate - Corregge l'allineamento degli assi fisici per JAX ed evita il double-flipping."""
|
|
1275
|
+
xp = self.xp
|
|
1276
|
+
if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
|
|
1277
|
+
raise ValueError(f"Invalid control ({ctrl}) or target ({tgt})")
|
|
1278
|
+
|
|
1279
|
+
if xp is jnp:
|
|
1280
|
+
# Pass logical qubit indices directly to apply_gate_2q as it handles internal mapping for JAX.
|
|
1281
|
+
cx_mat = xp.array([[1,0,0,0],[0,1,0,0],[0,0,0,1],[0,0,1,0]], dtype=self.dtype)
|
|
1282
|
+
self.apply_gate_2q(cx_mat, ctrl, tgt)
|
|
1283
|
+
else:
|
|
1284
|
+
# Ramo NumPy/CuPy Stride Slicing classico in-place
|
|
1285
|
+
c_stride = 1 << (self.n - 1 - ctrl)
|
|
1286
|
+
t_stride = 1 << (self.n - 1 - tgt)
|
|
1287
|
+
step = 2 * max(c_stride, t_stride)
|
|
1288
|
+
inner_step = 2 * min(c_stride, t_stride)
|
|
1289
|
+
|
|
1290
|
+
for i in range(0, self.dim, step):
|
|
1291
|
+
for j in range(0, max(c_stride, t_stride), inner_step):
|
|
1292
|
+
base_idx = i + j + c_stride
|
|
1293
|
+
idx_0 = base_idx
|
|
1294
|
+
idx_1 = base_idx + t_stride
|
|
1295
|
+
tmp = self.sv[idx_0 : idx_0 + min(c_stride, t_stride)].copy()
|
|
1296
|
+
self.sv[idx_0 : idx_0 + min(c_stride, t_stride)] = self.sv[idx_1 : idx_1 + min(c_stride, t_stride)]
|
|
1297
|
+
self.sv[idx_1 : idx_1 + min(c_stride, t_stride)] = tmp
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
def apply_cz(self, ctrl: int, tgt: int):
|
|
1301
|
+
"""Controlled-Z gate - Corregge l'allineamento degli assi fisici per JAX."""
|
|
1302
|
+
xp = self.xp
|
|
1303
|
+
if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
|
|
1304
|
+
raise ValueError(f"Invalid indices ({ctrl}, {tgt})")
|
|
1305
|
+
|
|
1306
|
+
if xp is jnp:
|
|
1307
|
+
# Pass logical qubit indices directly to apply_gate_2q as it handles internal mapping for JAX.
|
|
1308
|
+
cz_mat = xp.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,-1]], dtype=self.dtype)
|
|
1309
|
+
self.apply_gate_2q(cz_mat, ctrl, tgt)
|
|
1310
|
+
else:
|
|
1311
|
+
c_stride = 1 << (self.n - 1 - ctrl)
|
|
1312
|
+
t_stride = 1 << (self.n - 1 - tgt)
|
|
1313
|
+
step = 2 * max(c_stride, t_stride)
|
|
1314
|
+
inner_step = 2 * min(c_stride, t_stride)
|
|
1315
|
+
|
|
1316
|
+
for i in range(0, self.dim, step):
|
|
1317
|
+
for j in range(0, max(c_stride, t_stride), inner_step):
|
|
1318
|
+
idx = i + j + c_stride + t_stride
|
|
1319
|
+
self.sv[idx : idx + min(c_stride, t_stride)] *= -1
|
|
1320
|
+
|
|
1321
|
+
|
|
1322
|
+
# Iniezione strutturata finale
|
|
1323
|
+
DenseSVSimulator.measure = measure
|
|
1324
|
+
DenseSVSimulator.apply_cx = apply_cx
|
|
1325
|
+
DenseSVSimulator.apply_cz = apply_cz
|
|
1326
|
+
|
|
1327
|
+
print("💎 ENGINE CORE RIALLINEATO PERFETTAMENTE: Tutti i canali JAX e NumPy sono stabili!")
|
|
1328
|
+
|
|
1329
|
+
DenseSVSimulator.measure = measure
|
|
1330
|
+
DenseSVSimulator.apply_cx = apply_cx
|
|
1331
|
+
DenseSVSimulator.apply_cz = apply_cz
|
|
1332
|
+
|
|
1333
|
+
print("💎 ENGINE CORE RIALLINEATO PERFETTAMENTE: Tutti i canali JAX e NumPy sono stabili!")
|
|
1334
|
+
|
|
1335
|
+
import numpy as np
|
|
1336
|
+
|
|
1337
|
+
def run_circuit_with_chunking(self, circuit: list, chunk_size: int = 500, transpile: bool = True):
|
|
1338
|
+
"""
|
|
1339
|
+
Esegue circuiti quantistici di profondità estrema frammentandoli in sotto-blocchi.
|
|
1340
|
+
Previene la saturazione della cache JIT di JAX e azzera l'overhead sui circuiti NISQ.
|
|
1341
|
+
"""
|
|
1342
|
+
# 1. Transpilazione preliminare facoltativa delle macro-porte
|
|
1343
|
+
target_circuit = QuantumTranspiler.transpile(circuit) if transpile else circuit
|
|
1344
|
+
total_gates = len(target_circuit)
|
|
1345
|
+
|
|
1346
|
+
if total_gates <= chunk_size:
|
|
1347
|
+
try:
|
|
1348
|
+
self.run_circuit_jit_beast_mode(target_circuit)
|
|
1349
|
+
except Exception:
|
|
1350
|
+
self.run_circuit(target_circuit)
|
|
1351
|
+
return
|
|
1352
|
+
|
|
1353
|
+
print(f"⚙️ Circuit Chunking Attivo: {total_gates} gate totali divisi in blocchi da {chunk_size}...")
|
|
1354
|
+
|
|
1355
|
+
# 2. Suddivisione lineare del circuito in chunk protetti
|
|
1356
|
+
for i in range(0, total_gates, chunk_size):
|
|
1357
|
+
chunk = target_circuit[i : i + chunk_size]
|
|
1358
|
+
|
|
1359
|
+
try:
|
|
1360
|
+
self.run_circuit_jit_beast_mode(chunk)
|
|
1361
|
+
except Exception:
|
|
1362
|
+
self.run_circuit(chunk)
|
|
1363
|
+
|
|
1364
|
+
# Forza JAX a sincronizzare e scaricare i buffer temporanei della CPU/TPU
|
|
1365
|
+
if self.xp.__name__ == 'jax.numpy':
|
|
1366
|
+
self.sv.block_until_ready()
|
|
1367
|
+
|
|
1368
|
+
print(f"✅ Esecuzione completata con successo tramite {int(np.ceil(total_gates/chunk_size))} Chunk geometrici.")
|
|
1369
|
+
|
|
1370
|
+
# Iniezione del metodo enterprise corretto nella classe principale
|
|
1371
|
+
DenseSVSimulator.run_circuit_with_chunking = run_circuit_with_chunking
|
|
1372
|
+
|
|
1373
|
+
import jax
|
|
1374
|
+
import jax.numpy as jnp
|
|
1375
|
+
import numpy as np
|
|
1376
|
+
import time
|
|
1377
|
+
|
|
1378
|
+
def run_parametric_batch_jit(self, base_circuit: list, parameter_batch: np.ndarray) -> jnp.ndarray:
    """Simulate one parametric circuit once per parameter row of a batch.

    The circuit is transpiled once and lowered to a numeric op table with rows
    ``[gate_id, q1, q2, param]``. The per-instance simulation is vectorised
    with ``jax.vmap`` and compiled with ``jax.jit``.

    Args:
        base_circuit: Gate list accepted by ``QuantumTranspiler.transpile``.
            Each transpiled command is read as ``cmd[0]`` (gate name) and
            ``cmd[1:]`` (gate arguments; qubit indices first).
        parameter_batch: 2-D array of shape ``(batch, n_params)``. Row ``b``
            supplies, in circuit order, the angle of every ``rx``/``ry``/
            ``rz``/``p`` gate for instance ``b``. Any angle stored in the
            circuit command itself is not used. Extra trailing columns are
            ignored.

    Returns:
        The result of ``_compile_and_run_circuit_jit`` for each instance,
        stacked along a new leading batch axis.

    Raises:
        RuntimeError: if JAX is not the active backend.
        ValueError: if ``parameter_batch`` is not 2-D, or if it has fewer
            columns than the circuit has parametric gates.
    """
    if not HAS_JAX or self.xp is not jnp:
        raise RuntimeError("JAX deve essere il backend attivo per usare run_parametric_batch_jit.")

    # Decompose macro-gates (e.g. Toffoli, SWAP) with the project transpiler.
    target_circuit = QuantumTranspiler.transpile(base_circuit)

    # Lower the circuit to numeric rows [gate_id, q1, q2, param].
    compiled_list = []
    param_rows = []  # row index of every parametric gate, in circuit order
    for cmd in target_circuit:
        g_name = cmd[0].lower()
        args = cmd[1:]

        if g_name not in GATE_IDS:
            # NOTE(review): gates without a numeric id are dropped silently,
            # as before; confirm this is intended rather than an error.
            continue

        g_id = float(GATE_IDS[g_name])
        if g_name in ('rx', 'ry', 'rz', 'p'):
            # The param column is filled per batch instance below.
            param_rows.append(len(compiled_list))
            compiled_list.append([g_id, float(args[0]), 0.0, 0.0])
        elif g_name in ('cx', 'cz'):
            compiled_list.append([g_id, float(args[0]), float(args[1]), 0.0])
        else:
            compiled_list.append([g_id, float(args[0]), 0.0, 0.0])

    # reshape keeps the table 2-D (shape (0, 4)) even for an empty circuit;
    # a bare jnp.array([]) would be 1-D.
    compiled_ops_template = jnp.array(compiled_list, dtype=jnp.float64).reshape(-1, 4)
    param_row_idx = np.asarray(param_rows, dtype=np.int32)
    n_params = len(param_rows)
    dim = self.dim

    params = jnp.asarray(parameter_batch, dtype=jnp.float64)
    if params.ndim != 2:
        raise ValueError(
            f"parameter_batch deve essere 2-D (batch, n_params), ricevuto shape {tuple(params.shape)}."
        )
    if params.shape[1] < n_params:
        # Without this check JAX would clamp the out-of-range index and
        # silently reuse the last angle.
        raise ValueError(
            f"parameter_batch ha {params.shape[1]} colonne ma il circuito contiene "
            f"{n_params} porte parametriche."
        )

    def simulate_single_instance(single_params):
        """Run the circuit for one row of parameters, starting from |00...0>."""
        local_sv = jnp.zeros(dim, dtype=jnp.complex128).at[0].set(1.0)
        if n_params:
            # Scatter this instance's angles into the param column of the
            # parametric rows; row positions are static, known at trace time.
            patched_ops = compiled_ops_template.at[param_row_idx, 3].set(single_params[:n_params])
        else:
            patched_ops = compiled_ops_template
        return _compile_and_run_circuit_jit(local_sv, patched_ops)

    print(f"🚀 VMAP COMPILER: Parallelizzazione inter-circuito attiva per {len(parameter_batch)} istanze...")

    # NOTE(review): a fresh closure is jitted on each call, so the timing below
    # presumably includes XLA compilation; consider caching per circuit.
    jitted_vmap = jax.jit(jax.vmap(simulate_single_instance, in_axes=(0,)))

    t0 = time.perf_counter()
    res = jitted_vmap(params)
    res.block_until_ready()
    print(f"✅ Batch completato in {time.perf_counter() - t0:.4f} secondi!")
    return res
|
|
1439
|
+
|
|
1440
|
+
# Attach the batched parametric runner to the simulator class as a method,
# then announce the hook-up to the user.
setattr(DenseSVSimulator, "run_parametric_batch_jit", run_parametric_batch_jit)
print("💎 BATCH ENGINE AGGANCIATO: Pieno supporto QML & VQE attivo sul tuo core!")
|
|
1443
|
+
|
|
1444
|
+
|
|
1445
|
+
|