dense-evolution 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dense_evolution.py ADDED
@@ -0,0 +1,1445 @@
1
+ import subprocess, sys, os
2
+ import importlib
3
+ import numpy as np
4
+ from numpy import linalg as LA
5
+ import scipy.linalg
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.gridspec as gridspec
9
+ from matplotlib.patches import FancyArrowPatch, Rectangle, Circle, FancyBboxPatch
10
+ from matplotlib.colors import LinearSegmentedColormap, Normalize
11
+ from matplotlib.cm import ScalarMappable
12
+ import matplotlib.ticker as ticker
13
+ from IPython.display import display, HTML, clear_output
14
+ import time, re, io, warnings, hashlib, json, copy, psutil, platform
15
+ from datetime import datetime
16
+ from typing import List, Dict, Optional, Tuple
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+ import pandas as pd
20
+
21
+ warnings.filterwarnings('ignore')
22
+
23
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
24
+ # CELLA 1: Hardware Detection & Configuration
25
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
26
+
27
def install(pkg):
    """Install the package *pkg* into the current interpreter's environment.

    Runs ``<sys.executable> -m pip install -q <pkg>``, so the package lands in
    the same environment that is executing this module. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    command = [sys.executable, '-m', 'pip', 'install', '-q']
    command.append(pkg)
    subprocess.check_call(command)
29
+
30
# GPU support (optional): CuPy is used as a NumPy drop-in when it can be imported.
try:
    import cupy as cp
    HAS_CUPY = True
    print('✅ CuPy disponibile — GPU acceleration attiva')
except Exception:
    # Any import-time failure (missing package, or whatever else CuPy may raise
    # on a broken CUDA setup) falls back to NumPy. Unlike a bare ``except:``,
    # KeyboardInterrupt / SystemExit still propagate.
    HAS_CUPY = False
    print('ℹ️ CuPy non disponibile — usando NumPy CPU')
38
+
39
# JAX support (optional, NumPy fallback). When unavailable, ``jnp`` is bound to
# None so later ``xp is jnp`` identity checks are simply False.
try:
    import jax
    import jax.numpy as jnp
    HAS_JAX = True
    print('✅ JAX disponibile (JIT optimization attivo)')
except Exception:
    # Any import-time failure of jax/jaxlib falls back to NumPy; unlike a bare
    # ``except:``, KeyboardInterrupt / SystemExit still propagate.
    HAS_JAX = False
    jnp = None  # Fallback: NumPy is used instead
    print('ℹ️ JAX non disponibile — fallback NumPy attivo')
49
+
50
# Hardware detection: total and currently available RAM (GiB) plus CPU name.
ram_total, ram_avail = (
    psutil.virtual_memory().total / (1024**3),
    psutil.virtual_memory().available / (1024**3),
)
print(f'\n⌨️ Sistema: {platform.processor()}')
print(f'💾 RAM Totale: {ram_total:.1f} GB | Disponibile: {ram_avail:.1f} GB')

# Largest dense statevector (in qubits) allowed for this machine, by total RAM:
#   < 12 GB -> 20 qubits, 12..50 GB -> 24 qubits, >= 50 GB -> 28 qubits.
if ram_total < 12:
    MAX_DENSE_QUBITS = 20
    print(f'⚠️ RAM limitata → Dense SV fino a 20 qubit')
elif ram_total < 50:
    MAX_DENSE_QUBITS = 24
    print(f'✅ Standard runtime → Dense SV fino a 24 qubit')
else:
    MAX_DENSE_QUBITS = 28
    print(f'🚀 High-RAM runtime → Dense SV fino a 28 qubit')
66
+
67
# GPU detection: ask the NVIDIA driver tool for the GPU name and total memory.
try:
    gpu_info = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
        text=True,
        # A wedged driver can make nvidia-smi hang; never block module import on it.
        timeout=10,
    ).strip()
    print(f'🎮 GPU: {gpu_info}')
    HAS_GPU = True
except (OSError, subprocess.SubprocessError):
    # OSError: nvidia-smi is not installed / not executable.
    # SubprocessError: non-zero exit (CalledProcessError) or timeout (TimeoutExpired).
    HAS_GPU = False
    print('ℹ️ Nessuna GPU NVIDIA rilevata')
76
+
77
+ print(f'\n📊 Configurazione: MAX_DENSE={MAX_DENSE_QUBITS}q | JAX={HAS_JAX} | GPU={HAS_GPU}')
78
+
79
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
80
+ # CELLA 2: Matplotlib Professional Styling
81
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
82
+
83
# Start from matplotlib's built-in dark style, then apply the project palette.
plt.style.use('dark_background')

# Palette: backgrounds and borders
DARK_BG, PANEL_BG, PANEL_BG2, BORDER = '#010409', '#0d1117', '#161b22', '#21262d'
# Palette: accent colours
ACC_G, ACC_B, ACC_O, ACC_P = '#00ff9d', '#00c8ff', '#ff6b35', '#b400ff'
ACC_TEAL, ACC_PINK = '#00ffff', '#ff007f'
# Palette: status and text colours
WARN, DANGER, MUTED, TEXT = '#ffd700', '#ff4444', '#7d8590', '#e6edf3'

matplotlib.rcParams.update(
    {
        # figure / axes colours
        'figure.facecolor': DARK_BG,
        'axes.facecolor': PANEL_BG,
        'axes.edgecolor': BORDER,
        'axes.labelcolor': MUTED,
        'axes.titlecolor': TEXT,
        'text.color': TEXT,
        'xtick.color': MUTED,
        'ytick.color': MUTED,
        'grid.color': BORDER,
        'grid.alpha': 0.5,
        # typography and resolution
        'font.family': 'monospace',
        'font.size': 9,
        'figure.dpi': 130,
        'savefig.dpi': 200,
    }
)

print('✅ Matplotlib tema dark professionale configurato')
117
+
118
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
119
+ # CELLA 3: Gate Matrices & Operators
120
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
121
+
122
INV2 = 1.0 / np.sqrt(2.0)  # 1/sqrt(2): normalisation shared by H and ECR

# Fixed (angle-independent) gate matrices, all complex128. Multi-qubit matrices
# are written in MSB-first basis order |q_first q_second ...>, so for the
# controlled gates the first qubit is the control. Permutation gates are built
# as identity matrices with reordered rows.
GATES = {
    # --- single-qubit gates ---
    'h': INV2 * np.array([[1, 1], [1, -1]], dtype=complex),
    'x': np.eye(2, dtype=complex)[[1, 0]],
    'y': np.array([[0, -1j], [1j, 0]], dtype=complex),
    'z': np.diag(np.array([1, -1], dtype=complex)),
    's': np.diag(np.array([1, 1j], dtype=complex)),
    'sdg': np.diag(np.array([1, -1j], dtype=complex)),
    't': np.diag(np.array([1, np.exp(1j * np.pi / 4)], dtype=complex)),
    'tdg': np.diag(np.array([1, np.exp(-1j * np.pi / 4)], dtype=complex)),
    'sx': 0.5 * np.array([[1 + 1j, 1 - 1j], [1 - 1j, 1 + 1j]], dtype=complex),
    'id': np.eye(2, dtype=complex),
    # --- two-qubit gates ---
    'cx': np.eye(4, dtype=complex)[[0, 1, 3, 2]],
    'cz': np.diag(np.array([1, 1, 1, -1], dtype=complex)),
    'cy': np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, -1j], [0, 0, 1j, 0]], dtype=complex),
    'swap': np.eye(4, dtype=complex)[[0, 2, 1, 3]],
    'iswap': np.array([[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=complex),
    'ecr': INV2 * np.array([[0, 0, 1, 1j], [0, 0, 1j, 1], [1, -1j, 0, 0], [-1j, 1, 0, 0]], dtype=complex),
    # --- three-qubit Toffoli: exchanges |110> and |111> ---
    'ccx': np.eye(8, dtype=complex)[[0, 1, 2, 3, 4, 5, 7, 6]],
}
150
+
151
+ # ┌─────────────────────────────────────────────────────────────────┐
152
+ # │ FIX #1: PARAMETRIC GATES WITH NUMPY FALLBACK (JAX-OPTIONAL) │
153
+ # └─────────────────────────────────────────────────────────────────┘
154
def _build_parametric_gates(use_jax: bool = HAS_JAX):
    """
    Build factories for the parametric (angle-dependent) gates.

    A single set of formulas serves both backends: the array namespace is
    ``jax.numpy`` when ``use_jax`` is true *and* JAX was importable, NumPy
    otherwise. (Previously the two backends were two copy-pasted branches
    that had to be kept in sync by hand.)

    Args:
        use_jax: prefer the JAX backend. The default is the value of
            ``HAS_JAX`` at the time this function was defined.

    Returns:
        Dict mapping gate name -> factory returning a complex matrix:
        1-qubit (2x2): 'rx', 'ry', 'rz', 'u3', 'u2', 'u1', 'p';
        2-qubit controlled (4x4, first/high-order qubit is the control): 'cp', 'crz'.
    """
    # Both namespaces provide cos/sin/exp/array with NumPy semantics.
    xp = jnp if (use_jax and HAS_JAX) else np

    def rx_gate(theta: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array([[c, -1j * s], [-1j * s, c]], dtype=complex)

    def ry_gate(theta: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array([[c, -s], [s, c]], dtype=complex)

    def rz_gate(theta: float):
        return xp.array([[xp.exp(-1j * theta / 2), 0],
                         [0, xp.exp(1j * theta / 2)]], dtype=complex)

    def u3_gate(theta: float, phi: float, lam: float):
        c, s = xp.cos(theta / 2), xp.sin(theta / 2)
        return xp.array(
            [[c, -xp.exp(1j * lam) * s],
             [xp.exp(1j * phi) * s, xp.exp(1j * (phi + lam)) * c]],
            dtype=complex)

    def p_gate(lam: float):
        return xp.array([[1, 0], [0, xp.exp(1j * lam)]], dtype=complex)

    def cp_gate(lam: float):
        return xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                         [0, 0, 0, xp.exp(1j * lam)]], dtype=complex)

    def crz_gate(theta: float):
        return xp.array([[1, 0, 0, 0], [0, 1, 0, 0],
                         [0, 0, xp.exp(-1j * theta / 2), 0],
                         [0, 0, 0, xp.exp(1j * theta / 2)]], dtype=complex)

    return {
        'rx': rx_gate, 'ry': ry_gate, 'rz': rz_gate,
        'u3': u3_gate,
        'u2': lambda p, l: u3_gate(np.pi / 2, p, l),  # u2(phi, lam) == u3(pi/2, phi, lam)
        'u1': lambda l: p_gate(l),                     # u1(lam) == p(lam)
        'p': p_gate,
        'cp': cp_gate, 'crz': crz_gate,
    }
228
+
229
# Module-level parametric gate factories (JAX-backed when JAX is available).
PARAMETRIC_GATES = _build_parametric_gates(use_jax=HAS_JAX)

# Startup report: the first six fixed gates and every parametric gate name.
print('✅ Gate library caricata (Parametric gates: JAX-safe with NumPy fallback)')
print(f' Gate 1q: {list(GATES.keys())[:6]}...')
print(f' Parametrici: {list(PARAMETRIC_GATES.keys())}')
234
+
235
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
236
+ # CELLA 4: JAX JIT Compilation (Optional)
237
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
238
+
239
# Optional JIT-compiled gate kernels, defined only when JAX imported successfully.
# Both follow the same recipe as the NumPy path of DenseSVSimulator: move the
# target axis/axes last, contract with gate.T in one matrix product, move the
# axes back and flatten. (Despite the "einsum" names, jnp.dot is used.)
if HAS_JAX:
    def _jax_apply_gate_1q_einsum_impl(sv_array, gate_array, n_qubits, qubit_idx):
        """Apply a 2x2 gate to qubit ``qubit_idx`` of a flat statevector of 2**n_qubits amplitudes."""
        n, q = int(n_qubits), int(qubit_idx)
        sv_nd = sv_array.reshape([2] * n)
        sv_moved = jnp.moveaxis(sv_nd, q, -1)
        flat_shape = (1 << (n - 1), 2)
        result_moved = jnp.dot(sv_moved.reshape(flat_shape), gate_array.T)
        result_nd = result_moved.reshape([2] * n)
        return jnp.moveaxis(result_nd, -1, q).ravel()

    # n_qubits and qubit_idx are static: they fix array shapes and axis
    # positions, so JAX compiles one specialisation per distinct value pair.
    jax_apply_gate_1q_einsum = jax.jit(_jax_apply_gate_1q_einsum_impl, static_argnums=(2, 3))

    def _jax_apply_gate_2q_einsum_impl(sv_array, gate_array, n_qubits, q1, q2):
        """Apply a 4x4 gate to qubits (q1, q2); q1 is the high-order qubit of the gate's basis."""
        n = int(n_qubits)
        sv_nd = sv_array.reshape([2] * n)
        sv_moved = jnp.moveaxis(sv_nd, (q1, q2), (-2, -1))
        flat_shape = (1 << (n - 2), 4)
        gate_2d = gate_array.reshape(4, 4)
        result_moved = jnp.dot(sv_moved.reshape(flat_shape), gate_2d.T)
        result_nd = result_moved.reshape([2] * n)
        return jnp.moveaxis(result_nd, (-2, -1), (q1, q2)).ravel()

    # n_qubits, q1 and q2 are static for the same reason as above.
    jax_apply_gate_2q_einsum = jax.jit(_jax_apply_gate_2q_einsum_impl, static_argnums=(2, 3, 4))
    print("💎 JAX JIT compilation attiva per gate 1q e 2q")
else:
    # None sentinels: DenseSVSimulator checks these before taking the JIT path.
    jax_apply_gate_1q_einsum = None
    jax_apply_gate_2q_einsum = None
    print("ℹ️ JAX JIT non disponibile — fallback NumPy ottimizzato")
267
+
268
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
269
+ # CELLA 5: DenseSVSimulator Core Engine
270
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
271
+
272
class DenseSVSimulator:
    """
    Quantum circuit simulator using a dense statevector representation.

    Features:
      - NumPy / CuPy / JAX backend, selected in ``__init__``
      - 1q and 2q gates: the target axes are moved last and contracted with the
        gate in one matrix product (optionally via the JAX JIT kernels)
      - CX / CZ / measurement / collapse operate on strided views of the state
        on NumPy/CuPy (no Python loops over amplitude blocks, no index masks)
      - Noise is applied externally on the statevector (e.g. NoiseModel.apply_to_sv)

    Endianness: MSB-first (bit 0 = qubit n-1, bit n-1 = qubit 0). Equivalently,
    qubit ``q`` is axis ``q`` of ``sv.reshape([2] * n)`` and has stride
    ``2**(n-1-q)`` in the flat statevector.
    """

    def __init__(self, n_qubits: int, use_gpu: bool = True, use_float32: bool = False):
        """Allocate a statevector for ``n_qubits`` qubits initialised to |0...0>.

        Args:
            n_qubits: number of qubits; the statevector has 2**n_qubits entries.
            use_gpu: use CuPy when it is available (``HAS_CUPY``).
            use_float32: store amplitudes as complex64 instead of complex128.
        """
        self.n = n_qubits
        self.dim = 2**n_qubits
        self.use_gpu = use_gpu and HAS_CUPY
        self.dtype = np.complex64 if use_float32 else np.complex128

        if self.use_gpu:
            import cupy as cp
            self.xp = cp
            print(f'🎮 DenseSV: CuPy GPU | n={n_qubits} | dim={self.dim:,}')
        elif HAS_JAX:
            self.xp = jnp
            dtype_str = 'float32' if use_float32 else 'float64'
            print(f'⚡ DenseSV: JAX CPU/TPU | n={n_qubits} | dim={self.dim:,} | {dtype_str}')
        else:
            self.xp = np
            dtype_str = 'float32' if use_float32 else 'float64'
            print(f'⌨️ DenseSV: NumPy CPU | n={n_qubits} | dim={self.dim:,} | {dtype_str}')

        # Initialise |00...0>
        self.set_initial_state()

        ram_mb = (self.dim * (8 if use_float32 else 16)) / (1024**2)
        print(f' RAM allocata: {ram_mb:.1f} MB')
        if ram_mb > 1000:
            print(f' ⚠️ >1GB: Richiede architettura Full Vector ottimizzata a basso livello')

    def set_initial_state(self, state_vector: Optional[np.ndarray] = None):
        """Load ``state_vector`` (normalised on load) or reset to |0...0> when None.

        The input is copied: the caller's array is never rescaled in place by
        ``normalize()``, and the stored statevector is a fresh contiguous
        buffer, as required by the in-place strided views used by
        ``apply_cx`` / ``apply_cz`` / ``measure``.

        Raises:
            ValueError: if ``state_vector`` does not have length 2**n.
        """
        xp = self.xp
        if state_vector is None:
            self.sv = xp.zeros(self.dim, dtype=self.dtype)
            if xp is jnp:
                self.sv = self.sv.at[0].set(1.0)  # JAX arrays are immutable
            else:
                self.sv[0] = 1.0
            return

        if len(state_vector) != self.dim:
            raise ValueError(f"State vector must have length 2^n ({self.dim})")
        self.sv = xp.array(state_vector, dtype=self.dtype)  # copy, never alias
        self.normalize()

    def _ensure_contiguous(self):
        """Make ``self.sv`` C-contiguous (NumPy/CuPy) so ``reshape`` yields writable views."""
        if not self.sv.flags.c_contiguous:
            self.sv = self.xp.ascontiguousarray(self.sv)

    def _nd_view(self):
        """Return ``self.sv`` as a writable ``(2,)*n`` view (axis q == qubit q). NumPy/CuPy only."""
        self._ensure_contiguous()
        return self.sv.reshape((2,) * self.n)

    def _qubit_selector(self, fixed: Dict[int, int]) -> tuple:
        """Basic-indexing tuple selecting the sub-tensor where qubit ``q`` has bit ``fixed[q]``.

        Basic indexing (ints and slices only) returns a view, so assignments
        through the selector modify the statevector in place.
        """
        selector = [slice(None)] * self.n
        for qubit, bit in fixed.items():
            selector[qubit] = bit
        return tuple(selector)

    def apply_gate_1q(self, gate: np.ndarray, qubit: int):
        """Apply a 2x2 gate to ``qubit`` (MSB convention).

        Raises:
            ValueError: if ``qubit`` is out of range.
        """
        if not 0 <= qubit < self.n:
            raise ValueError(f"Qubit index {qubit} out of bounds for {self.n} qubits")
        self._apply_gate_fast(gate, qubit)

    def _apply_gate_fast(self, gate: np.ndarray, qubit: int):
        """Unchecked O(2^n) 1q gate application: move the qubit axis last, one matrix product."""
        xp = self.xp
        g = xp.asarray(gate, dtype=self.dtype)

        if HAS_JAX and xp is jnp and jax_apply_gate_1q_einsum is not None:
            self.sv = jax_apply_gate_1q_einsum(self.sv, g, self.n, qubit)
        else:
            # With the target axis last, a (2^(n-1), 2) @ (2, 2)^T product applies
            # the gate to every amplitude pair at once.
            sv_nd = self.sv.reshape([2] * self.n)
            sv_moved = xp.moveaxis(sv_nd, qubit, -1)
            flat_shape = (1 << (self.n - 1), 2)
            result_moved = xp.dot(sv_moved.reshape(flat_shape), g.T)
            result_nd = result_moved.reshape([2] * self.n)
            self.sv = xp.moveaxis(result_nd, -1, qubit).ravel()

    def apply_gate_2q(self, gate: np.ndarray, q1: int, q2: int):
        """Apply a 2-qubit gate (4x4 or 2x2x2x2) to (q1, q2); q1 is the gate's high-order qubit.

        Raises:
            ValueError: if an index is out of range or ``q1 == q2``.
        """
        xp = self.xp
        if not (0 <= q1 < self.n and 0 <= q2 < self.n and q1 != q2):
            raise ValueError(f"Invalid qubit indices ({q1}, {q2})")

        g_2d = xp.asarray(gate, dtype=self.dtype).reshape(4, 4)

        if HAS_JAX and xp is jnp and jax_apply_gate_2q_einsum is not None:
            self.sv = jax_apply_gate_2q_einsum(self.sv, g_2d, self.n, q1, q2)
        else:
            # Move (q1, q2) to the last two axes and contract with one BLAS product.
            sv_nd = self.sv.reshape([2] * self.n)
            sv_moved = xp.moveaxis(sv_nd, (q1, q2), (-2, -1))
            flat_shape = (1 << (self.n - 2), 4)
            result_moved = xp.dot(sv_moved.reshape(flat_shape), g_2d.T)
            result_nd = result_moved.reshape([2] * self.n)
            self.sv = xp.moveaxis(result_nd, (-2, -1), (q1, q2)).ravel()

    def apply_cx(self, ctrl: int, tgt: int):
        """Controlled-X (CNOT): flip ``tgt`` on every basis state where ``ctrl`` is 1.

        NumPy/CuPy: swaps the (ctrl=1, tgt=0) and (ctrl=1, tgt=1) sub-tensors in
        place through strided views (one dim/4 temporary, no Python loop over
        blocks). JAX (immutable arrays): goes through ``apply_gate_2q``.

        Raises:
            ValueError: if an index is out of range or ``ctrl == tgt``.
        """
        xp = self.xp
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
            raise ValueError(f"Invalid control ({ctrl}) or target ({tgt})")

        if xp is jnp:
            cx_mat = xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=self.dtype)
            self.apply_gate_2q(cx_mat, ctrl, tgt)
            return

        sv_nd = self._nd_view()
        sel_10 = self._qubit_selector({ctrl: 1, tgt: 0})
        sel_11 = self._qubit_selector({ctrl: 1, tgt: 1})
        tmp = xp.array(sv_nd[sel_10])  # explicit copy (also handles 0-d results)
        sv_nd[sel_10] = sv_nd[sel_11]
        sv_nd[sel_11] = tmp

    def apply_cz(self, ctrl: int, tgt: int):
        """Controlled-Z: negate every amplitude where both ``ctrl`` and ``tgt`` are 1.

        Raises:
            ValueError: if an index is out of range or ``ctrl == tgt``.
        """
        xp = self.xp
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
            raise ValueError(f"Invalid indices ({ctrl}, {tgt})")

        if xp is jnp:
            cz_mat = xp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]], dtype=self.dtype)
            self.apply_gate_2q(cz_mat, ctrl, tgt)
            return

        sv_nd = self._nd_view()
        sv_nd[self._qubit_selector({ctrl: 1, tgt: 1})] *= -1

    def normalize(self):
        """Rescale the statevector to unit norm (skipped when the norm is ~0)."""
        norm = self.xp.linalg.norm(self.sv)
        if norm > 1e-12:
            if self.xp is jnp:
                self.sv = self.sv / norm
            else:
                self.sv /= norm

    def get_probabilities(self) -> np.ndarray:
        """Return |amplitude|^2 for every basis state as a float64 NumPy array."""
        probs = self.xp.abs(self.sv) ** 2
        if self.use_gpu:
            probs = probs.get()  # device -> host
        return np.array(probs, dtype=np.float64)

    def get_statevector(self) -> np.ndarray:
        """Return a NumPy copy of the statevector (host memory for every backend)."""
        if self.use_gpu:
            return self.sv.get()
        return np.array(self.sv)

    def measure(self, qubit_idx: int) -> int:
        """
        Projectively measure one qubit in the computational basis and collapse the state.

        The outcome is sampled with ``np.random.choice`` (global NumPy RNG);
        amplitudes inconsistent with it are zeroed and the state is renormalised.

        Returns:
            The measured bit (0 or 1).

        Raises:
            ValueError: if ``qubit_idx`` is out of range or the state has zero norm.
        """
        if not 0 <= qubit_idx < self.n:
            raise ValueError(f"Qubit {qubit_idx} out of bounds")

        xp = self.xp
        # MSB convention: qubit q sits at bit position n-1-q, so after
        # reshape(-1, 2, stride) the middle axis is exactly that qubit's bit.
        # Every backend uses this same view so they agree on which qubit is measured.
        stride = 1 << (self.n - 1 - qubit_idx)

        if xp is jnp:
            probs3 = (jnp.abs(self.sv) ** 2).reshape(-1, 2, stride)
            prob_0 = float(jnp.sum(probs3[:, 0, :]))
            prob_1 = float(jnp.sum(probs3[:, 1, :]))
        else:
            # Strided view over the contiguous buffer: no copy, no index masks.
            self._ensure_contiguous()
            sv3 = self.sv.reshape(-1, 2, stride)
            prob_0 = float(xp.sum(xp.abs(sv3[:, 0, :]) ** 2))
            prob_1 = float(xp.sum(xp.abs(sv3[:, 1, :]) ** 2))

        total = prob_0 + prob_1
        if total <= 1e-12:
            raise ValueError("Cannot measure a statevector with zero norm")
        prob_0 /= total
        prob_1 /= total

        result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

        # Collapse: zero the half of the amplitudes with the other outcome.
        discarded = 1 - result
        if xp is jnp:
            sv3 = self.sv.reshape(-1, 2, stride).at[:, discarded, :].set(0.0)
            self.sv = sv3.reshape(-1)
        else:
            sv3[:, discarded, :] = 0.0  # writes through the view into self.sv

        self.normalize()
        return result

    def memory_mb(self) -> float:
        """Estimate the statevector size in MB (10^6 bytes)."""
        elem_size = 8 if self.dtype == np.complex64 else 16
        return self.dim * elem_size / 1e6
504
+
505
+ # ┌─────────────────────────────────────────────────────────────────┐
506
+ # │ PARAMETRIC GATE INJECTION (VERSIONE INTEGRALE DA REPOSITORY) │
507
+ # └─────────────────────────────────────────────────────────────────┘
508
+
509
def patch_dense_parametric(cls):
    """Attach the OpenQASM 2.0 parametric gate methods to *cls*.

    Every injected method builds its matrix from ``PARAMETRIC_GATES`` at call
    time and hands it to ``self.apply_gate_1q`` (single-qubit gates) or
    ``self.apply_gate_2q`` (controlled two-qubit gates).
    """

    # --- single-qubit rotations and phases ---------------------------------
    def apply_rx(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['rx'](theta), qubit)

    def apply_ry(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['ry'](theta), qubit)

    def apply_rz(self, qubit: int, theta: float):
        self.apply_gate_1q(PARAMETRIC_GATES['rz'](theta), qubit)

    def apply_u3(self, qubit: int, theta: float, phi: float, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u3'](theta, phi, lam), qubit)

    def apply_u2(self, qubit: int, phi: float, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u2'](phi, lam), qubit)

    def apply_u1(self, qubit: int, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['u1'](lam), qubit)

    def apply_p(self, qubit: int, lam: float):
        self.apply_gate_1q(PARAMETRIC_GATES['p'](lam), qubit)

    # --- controlled two-qubit phases ----------------------------------------
    def apply_cp(self, ctrl: int, tgt: int, lam: float):
        self.apply_gate_2q(PARAMETRIC_GATES['cp'](lam), ctrl, tgt)

    def apply_crz(self, ctrl: int, tgt: int, theta: float):
        self.apply_gate_2q(PARAMETRIC_GATES['crz'](theta), ctrl, tgt)

    injected = (apply_rx, apply_ry, apply_rz, apply_u3, apply_u2,
                apply_u1, apply_p, apply_cp, apply_crz)
    for method in injected:
        setattr(cls, method.__name__, method)

    print("✅ All parametric methods (including u1/u2) injected into DenseSVSimulator")
560
+
561
+ patch_dense_parametric(DenseSVSimulator)
562
+
563
+ import numpy as np
564
+ from typing import Optional, List, Dict
565
+ import time
566
+
567
+ # ═══════════════════════════════════════════════════════════════════════════════
568
+ # CELLA 6: Modelli di rumore con operatori Kraus (VERSIONE INTEGRALE JAX FIXED)
569
+ # ═══════════════════════════════════════════════════════════════════════════════
570
+ # [PROPRIETARY ALGORITHM - (c) 2026 Salvatore Pennacchio - Licensed under EUPL-1.2]
571
+
572
# Re-check JAX availability for the noise-model cell. This mirrors the Cell 1
# detection: any import-time failure (not only ImportError, e.g. a broken
# jaxlib) falls back to NumPy instead of crashing module import, and ``jnp``
# is always bound (None when unavailable).
try:
    import jax
    import jax.numpy as jnp
    HAS_JAX = True
except Exception:
    HAS_JAX = False
    jnp = None
578
+
579
class NoiseModel:
    """
    Five physical decoherence models (plus 'ideal') defined by Kraus operators,
    simulated as stochastic quantum trajectories ("quantum jumps") on a pure
    statevector.

    For each target qubit a single Kraus branch is sampled for the *whole*
    state and applied, so one call yields one valid trajectory; averaging many
    calls reproduces the channel. Works on NumPy/CuPy arrays and, eagerly, on
    JAX arrays. The input array is never modified.

    Qubit convention matches DenseSVSimulator (MSB-first): qubit ``q`` has
    stride ``2**(n-1-q)`` in the flat statevector.
    """

    MODELS = ['ideal', 'depolarizing', 'bitflip', 'phaseflip', 'amplitude_damping', 'combined']

    # 'combined' model: each of X, Z, Y fires with probability p*0.2, then
    # amplitude damping with gamma = p*0.3.
    _COMBINED_PAULI = 0.2
    _COMBINED_GAMMA = 0.3

    @staticmethod
    def apply_to_sv(sv: np.ndarray, n: int, model: str, p: float,
                    rng: Optional[np.random.Generator] = None, qubits: Optional[List[int]] = None,
                    jax_key: Optional[object] = None) -> np.ndarray:
        """
        Apply one stochastic noise trajectory to the statevector ``sv``.

        Args:
            sv: statevector of length 2**n (NumPy/CuPy array or JAX array).
            n: total number of qubits in the register.
            model: one of ``NoiseModel.MODELS``; an unrecognised name applies
                no error (the state is only renormalised).
            p: error probability / damping parameter, in [0, 1].
            rng: NumPy Generator used for NumPy/CuPy inputs (default: a fresh
                generator seeded from OS entropy).
            qubits: qubits to apply noise to (None or empty: all qubits).
            jax_key: ``jax.random.PRNGKey`` used for JAX inputs (default: random).

        Returns:
            A new renormalised statevector; ``sv`` itself when ``model`` is
            'ideal' or ``p <= 0``.

        Raises:
            ValueError: if ``p > 1``, ``len(sv) != 2**n`` or a qubit index is
                out of range.
        """
        if model == 'ideal' or p <= 0:
            return sv
        if p > 1:
            raise ValueError(f"Noise parameter p must be in [0, 1], got {p}")
        if len(sv) != 1 << n:
            raise ValueError(f"Statevector length {len(sv)} does not match 2^{n}")

        target_qubits = qubits if qubits else list(range(n))
        for q in target_qubits:
            if not 0 <= q < n:
                raise ValueError(f"Qubit {q} out of range for {n} qubits")

        is_jax_array = HAS_JAX and isinstance(sv, jnp.ndarray)
        draw = NoiseModel._uniform_source(is_jax_array, rng, jax_key)

        state = sv
        for q in target_qubits:
            stride = 1 << (n - 1 - q)  # MSB convention
            if model == 'depolarizing':
                state = NoiseModel._pauli_trajectory(state, stride, draw(), p / 3, is_jax_array)
            elif model == 'bitflip':
                if draw() < p:
                    state = NoiseModel._apply_pauli(state, stride, 'x', is_jax_array)
            elif model == 'phaseflip':
                if draw() < p:
                    state = NoiseModel._apply_pauli(state, stride, 'z', is_jax_array)
            elif model == 'amplitude_damping':
                state = NoiseModel._damping_trajectory(state, stride, p, draw(), is_jax_array)
            elif model == 'combined':
                state = NoiseModel._pauli_trajectory(
                    state, stride, draw(), p * NoiseModel._COMBINED_PAULI, is_jax_array)
                state = NoiseModel._damping_trajectory(
                    state, stride, p * NoiseModel._COMBINED_GAMMA, draw(), is_jax_array)

        # Protected trajectory renormalisation.
        norm = float((abs(state) ** 2).sum()) ** 0.5
        return state / (norm + 1e-15)

    @staticmethod
    def _uniform_source(is_jax: bool, rng, jax_key):
        """Return a zero-argument callable yielding independent U[0, 1) floats."""
        if is_jax:
            key = jax_key if jax_key is not None else jax.random.PRNGKey(
                int(np.random.default_rng().integers(2**31 - 1)))

            def draw() -> float:
                nonlocal key
                key, subkey = jax.random.split(key)
                return float(jax.random.uniform(subkey))
            return draw

        # OS-entropy seeding: time-based seeds repeat for calls within the same second.
        generator = rng if rng is not None else np.random.default_rng()
        return lambda: float(generator.random())

    @staticmethod
    def _assemble(v3, new0, new1, is_jax: bool):
        """Build a new array from the qubit=0 / qubit=1 halves of ``v3`` (shape (-1, 2, stride)) and flatten it."""
        if is_jax:
            out = v3.at[:, 0, :].set(new0).at[:, 1, :].set(new1)
        else:
            out = v3.copy()  # never write into the caller's buffer
            out[:, 0, :] = new0
            out[:, 1, :] = new1
        return out.reshape(-1)

    @staticmethod
    def _apply_pauli(sv, stride: int, pauli: str, is_jax: bool):
        """Return a copy of ``sv`` with Pauli 'x', 'y' or 'z' applied to the qubit of the given stride."""
        v3 = sv.reshape(-1, 2, stride)
        a0, a1 = v3[:, 0, :], v3[:, 1, :]
        if pauli == 'x':
            new0, new1 = a1, a0
        elif pauli == 'y':
            new0, new1 = -1j * a1, 1j * a0
        else:  # 'z'
            new0, new1 = a0, -a1
        return NoiseModel._assemble(v3, new0, new1, is_jax)

    @staticmethod
    def _pauli_trajectory(sv, stride: int, r: float, p_each: float, is_jax: bool):
        """Apply X if r < p_each, Z if r < 2*p_each, Y if r < 3*p_each, otherwise nothing."""
        if r < p_each:
            return NoiseModel._apply_pauli(sv, stride, 'x', is_jax)
        if r < 2 * p_each:
            return NoiseModel._apply_pauli(sv, stride, 'z', is_jax)
        if r < 3 * p_each:
            return NoiseModel._apply_pauli(sv, stride, 'y', is_jax)
        return sv

    @staticmethod
    def _damping_trajectory(sv, stride: int, gamma: float, r: float, is_jax: bool):
        """One amplitude-damping quantum jump on the qubit with the given stride.

        The jump K1 = sqrt(gamma)|0><1| occurs with probability
        gamma * P(qubit=1); otherwise the no-jump operator
        K0 = diag(1, sqrt(1 - gamma)) is applied. The result is renormalised.
        """
        v3 = sv.reshape(-1, 2, stride)
        a0, a1 = v3[:, 0, :], v3[:, 1, :]
        total = float((abs(v3) ** 2).sum())
        p_excited = float((abs(a1) ** 2).sum()) / total if total > 0 else 0.0
        if r < gamma * p_excited:
            out = NoiseModel._assemble(v3, a1, 0.0, is_jax)
        else:
            # Python float keeps the array dtype (no complex64 -> complex128 upcast).
            out = NoiseModel._assemble(v3, a0, a1 * float(1 - gamma) ** 0.5, is_jax)
        norm = float((abs(out) ** 2).sum()) ** 0.5
        return out / norm if norm > 0 else out

    @staticmethod
    def kraus_description(model: str) -> Dict:
        """Return the Kraus count, formula and physical meaning of ``model`` ('ideal' if unknown)."""
        desc = {
            'ideal': {'kraus': 1, 'formula': 'K\u2080=I', 'physical': 'Nessun rumore'},
            'depolarizing': {'kraus': 4, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221a(p/3)X, K\u2082=\u221a(p/3)Y, K\u2083=\u221a(p/3)Z', 'physical': 'Errore isotropo'},
            'bitflip': {'kraus': 2, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221ap\u00b7X', 'physical': 'Flip di qubit \u03c3_x'},
            'phaseflip': {'kraus': 2, 'formula': 'K\u2080=\u221a(1-p)I, K\u2081=\u221ap\u00b7Z', 'physical': 'Dephasing puro'},
            'amplitude_damping': {'kraus': 2, 'formula': 'K\u2080=diag(1,\u221a(1-\u03b3)), K\u2081=[[0,\u221a\u03b3],[0,0]]', 'physical': 'Decadimento T\u2081 (relassazione)'},
            # X, Z, Y at p*0.2 each == depolarizing with total probability p*0.6.
            'combined': {'kraus': 6, 'formula': 'Dep(p*0.6) + AmpDamp(p*0.3)', 'physical': 'Worst-case NISQ'},
        }
        return desc.get(model, desc['ideal'])

print("✅ NoiseModel aggiornato (EUPL-1.2): Pieno supporto stocastico runtime JAX JIT sigillato!")
748
+
749
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
750
+ # CELLA 7: OpenQASM 2.0 Parser & Transpiler
751
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
752
+
753
@dataclass
class QASMCircuit:
    """Result of parsing an OpenQASM 2.0 program.

    ``n_qubits`` / ``n_cbits`` are the sizes of the last ``qreg`` / ``creg``
    declared. ``ops`` lists gate operations in program order as
    ``{'type': 'gate', 'name': str, 'qubits': List[int], 'params': List[float]}``.
    """
    n_qubits: int = 0
    n_cbits: int = 0
    ops: List[Dict[str, object]] = field(default_factory=list)


class QASMParser:
    """Parse a flat subset of OpenQASM 2.0 into a QASMCircuit.

    Supported: ``qreg``/``creg`` declarations and gate applications with an
    optional parenthesised list of angle expressions (numbers, ``pi``,
    ``+ - * / **``, ``sin cos tan exp ln sqrt``). Aliases: ``cu1``->``cp``,
    ``u1``->``p``, ``toffoli``->``ccx``, ``fredkin``->``cswap``.
    ``OPENQASM``, ``include``, ``barrier`` and ``measure`` statements are
    skipped, as are statements without any ``[index]`` operand (such as the
    bodies of custom ``gate`` definitions). Qubit indices are collected
    without regard to register names.
    """

    _REG_QUBIT = re.compile(r'\[(\d+)\]')
    # gate name, optional "(parameter text)", remaining operand text
    _STATEMENT = re.compile(r'^([A-Za-z_]\w*)\s*(?:\((.*)\))?\s*(.*)$', re.DOTALL)
    _SKIP_PREFIXES = ('OPENQASM', 'include', 'barrier', 'measure')
    _ALIAS_MAP = {'cu1': 'cp', 'u1': 'p', 'toffoli': 'ccx', 'fredkin': 'cswap'}
    # Retained for backward compatibility only: parameters are no longer passed
    # to eval(); _eval_param walks the expression AST against the whitelists below.
    _MATH_ENV = {'__builtins__': {}, 'np': np, 'pi': np.pi, 'sin': np.sin,
                 'cos': np.cos, 'sqrt': np.sqrt, 'exp': np.exp}
    _SAFE_CONSTANTS = {'pi': np.pi}
    _SAFE_FUNCTIONS = {'sin': np.sin, 'cos': np.cos, 'tan': np.tan, 'exp': np.exp,
                       'ln': np.log, 'log': np.log, 'sqrt': np.sqrt}

    def parse(self, qasm_str: str) -> QASMCircuit:
        """Parse ``qasm_str`` into a QASMCircuit.

        Raises:
            ValueError: if a gate's parameter list contains something other than
                a supported arithmetic expression (the QASM text is untrusted,
                so it is never eval()'d).
        """
        n_qubits, n_cbits, ops = 0, 0, []

        # Strip '//' comments per line, then join with a space so a statement
        # spanning several lines keeps its tokens separated ("cx\nq[0],q[1]").
        code = ' '.join(line.split('//')[0].strip() for line in qasm_str.split('\n'))

        for instr in code.split(';'):
            instr = instr.strip()
            # A custom gate body ends with '}' but no ';', which glues it onto
            # the next statement: keep only what follows the brace.
            if '}' in instr:
                instr = instr.rsplit('}', 1)[1].strip()
            if not instr or instr.startswith(self._SKIP_PREFIXES):
                continue

            if instr.startswith('qreg'):
                m = self._REG_QUBIT.search(instr)
                if m:
                    n_qubits = int(m.group(1))
                continue

            if instr.startswith('creg'):
                m = self._REG_QUBIT.search(instr)
                if m:
                    n_cbits = int(m.group(1))
                continue

            m = self._STATEMENT.match(instr)
            if not m:
                continue
            raw_name, param_text, operands = m.groups()

            qubit_indices = [int(x) for x in self._REG_QUBIT.findall(operands)]
            if not qubit_indices:
                continue  # e.g. statements inside a custom gate definition

            gate_name = raw_name.lower()
            gate_name = self._ALIAS_MAP.get(gate_name, gate_name)
            params = self._parse_params(param_text) if param_text is not None else []

            ops.append({'type': 'gate', 'name': gate_name,
                        'qubits': qubit_indices, 'params': params})

        return QASMCircuit(n_qubits, n_cbits, ops)

    @classmethod
    def _parse_params(cls, text: str) -> List[float]:
        """Evaluate a comma-separated list of angle expressions (empty items are skipped)."""
        items, current, depth = [], [], 0
        for ch in text:
            if ch == ',' and depth == 0:  # split only on top-level commas
                items.append(''.join(current))
                current = []
                continue
            depth += (ch == '(') - (ch == ')')
            current.append(ch)
        items.append(''.join(current))
        return [cls._eval_param(tok) for tok in (item.strip() for item in items) if tok]

    @classmethod
    def _eval_param(cls, expr: str) -> float:
        """Safely evaluate one angle expression by walking its AST against a whitelist.

        Raises:
            ValueError: on syntax errors, unsupported names/operators, or math errors.
        """
        import ast

        def name_of(node):
            # Accept bare names and the ``np.<name>`` spelling.
            if isinstance(node, ast.Name):
                return node.id
            if (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name)
                    and node.value.id == 'np'):
                return node.attr
            return None

        def ev(node):
            if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                    and not isinstance(node.value, bool)):
                return float(node.value)
            if name_of(node) in cls._SAFE_CONSTANTS:
                return float(cls._SAFE_CONSTANTS[name_of(node)])
            if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
                value = ev(node.operand)
                return -value if isinstance(node.op, ast.USub) else value
            if isinstance(node, ast.BinOp):
                left, right = ev(node.left), ev(node.right)
                if isinstance(node.op, ast.Add):
                    return left + right
                if isinstance(node.op, ast.Sub):
                    return left - right
                if isinstance(node.op, ast.Mult):
                    return left * right
                if isinstance(node.op, ast.Div):
                    return left / right
                if isinstance(node.op, ast.Pow):
                    return float(left ** right)
            if (isinstance(node, ast.Call) and name_of(node.func) in cls._SAFE_FUNCTIONS
                    and len(node.args) == 1 and not node.keywords):
                return float(cls._SAFE_FUNCTIONS[name_of(node.func)](ev(node.args[0])))
            raise ValueError(f"unsupported element {type(node).__name__}")

        try:
            return float(ev(ast.parse(expr, mode='eval').body))
        except (SyntaxError, ValueError, TypeError, ZeroDivisionError, OverflowError) as err:
            raise ValueError(f"Invalid gate parameter expression {expr!r}: {err}") from err

    def validate(self, circ: QASMCircuit) -> Tuple[bool, str]:
        """Return ``(True, "")`` if ``circ`` is runnable, else ``(False, reason)``."""
        if circ.n_qubits <= 0:
            return False, "n_qubits deve essere > 0."
        if not circ.ops:
            return False, "Nessuna operazione rilevata nel circuito quantistico."
        for op in circ.ops:
            out_of_range = [q for q in op.get('qubits', ()) if q >= circ.n_qubits]
            if out_of_range:
                return False, (f"Qubit {out_of_range[0]} fuori range per '{op.get('name')}' "
                               f"(n_qubits={circ.n_qubits}).")
        return True, ""
835
+
836
+
837
class QuantumTranspiler:
    """Lower multi-qubit macro gates into 1- and 2-qubit primitives.

    Circuits are lists of tuples ``(name, *qubits)``. The following are expanded:

    * ``ccx``/``toffoli``
    * ``cswap``/``fredkin`` (the parser's alias map emits ``cswap``)
    * ``swap``

    Every other command is passed through unchanged.
    """

    @staticmethod
    def decompose_toffoli(c1: int, c2: int, t: int) -> List[Tuple]:
        """Exact Clifford+T Toffoli: 6 CNOT, 7 T/T-dagger, 2 H.

        This is the qelib1.inc ``ccx`` body, with ``h t`` commuted past the
        ``cx c1, c2`` that acts on other qubits.
        """
        return [
            ('h', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t), ('t', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t),
            ('t', c2), ('t', t),
            ('cx', c1, c2), ('h', t),
            ('t', c1), ('tdg', c2),
            ('cx', c1, c2),
        ]

    @staticmethod
    def decompose_swap(q1: int, q2: int) -> List[Tuple]:
        """SWAP as three alternating CNOTs."""
        return [('cx', q1, q2), ('cx', q2, q1), ('cx', q1, q2)]

    @staticmethod
    def decompose_cswap(c: int, t1: int, t2: int) -> List[Tuple]:
        """Fredkin gate as CX(t2 -> t1), Toffoli(c, t1 -> t2), CX(t2 -> t1)."""
        return [('cx', t2, t1),
                *QuantumTranspiler.decompose_toffoli(c, t1, t2),
                ('cx', t2, t1)]

    @staticmethod
    def transpile(circuit: List[Tuple]) -> List[Tuple]:
        """Return a new list with macro gates expanded into primitives.

        Gate names are matched case-insensitively.
        """
        out = []
        for cmd in circuit:
            name = cmd[0].lower()
            if name in ('ccx', 'toffoli'):
                out.extend(QuantumTranspiler.decompose_toffoli(*cmd[1:]))
            elif name in ('cswap', 'fredkin'):
                out.extend(QuantumTranspiler.decompose_cswap(*cmd[1:]))
            elif name == 'swap':
                out.extend(QuantumTranspiler.decompose_swap(*cmd[1:]))
            else:
                out.append(cmd)
        return out


print("✅ QASMParser and QuantumTranspiler loaded (Optimized regex & clean primitives)")
876
+
877
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
878
+ # CELLA 8: Circuit Execution (MSB-aware Engine Core - CORRETTA)
879
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
880
+
881
def run_circuit(self, circuit: List[Tuple], transpile: bool = True):
    """
    Execute a gate-tuple circuit on this simulator, one gate at a time.

    Each command is ``(name, *args)``. Each gate is resolved in this order:

    1. ``GATES[name]``: a fixed matrix.
    2. ``PARAMETRIC_GATES[name](*angles)``: a matrix factory. Integer
       arguments are taken as qubit indices and float arguments as angles.
       If the factory rejects the call, the gate falls through to step 3.
    3. A simulator method ``apply_<name>`` (or ``measure``), called with the
       original arguments.

    2x2 matrices go to ``apply_gate_1q``; 4x4 or (2, 2, 2, 2) matrices go to
    ``apply_gate_2q``. Both receive logical qubit indices and no index flipping
    is done here. NOTE(review): this assumes those methods do the
    logical-to-physical (MSB) mapping themselves; confirm.

    Args:
        circuit: list of gate tuples.
        transpile: lower macro gates with ``QuantumTranspiler.transpile`` first.

    Raises:
        ValueError: for an unknown gate, or for a gate matrix that is neither
            single- nor two-qubit.
    """
    target = QuantumTranspiler.transpile(circuit) if transpile else circuit

    for cmd in target:
        gate_name = cmd[0].lower()
        args = cmd[1:]

        mat = None
        if gate_name in GATES:
            mat = GATES[gate_name]
        elif gate_name in PARAMETRIC_GATES:
            # Qubit indices are ints and angles are floats. Keeping the two sets
            # disjoint stops qubit indices from reaching the matrix factory.
            angles = [a for a in args if isinstance(a, float)]
            qubits = tuple(a for a in args if isinstance(a, int) and not isinstance(a, bool))
            try:
                mat = PARAMETRIC_GATES[gate_name](*angles)
            except Exception:
                # Best effort: a factory that rejects these arguments falls
                # through to the apply_<name> method lookup below.
                mat = None
            else:
                args = qubits

        if mat is None:
            method = getattr(self, f'apply_{gate_name}', None)
            if method is None and gate_name == 'measure':
                method = getattr(self, 'measure', None)

            if method:
                method(*args)
            else:
                raise ValueError(f"Porta quantistica o istruzione '{gate_name}' non riconosciuta.")
            continue

        mat = np.asarray(mat)
        if mat.shape == (2, 2):
            self.apply_gate_1q(mat, args[0])
        elif mat.shape == (4, 4) or (mat.ndim == 4 and mat.size == 16):
            self.apply_gate_2q(mat.reshape(4, 4), args[0], args[1])
        else:
            # Previously such gates were skipped silently, corrupting results.
            raise ValueError(f"Matrice di forma {mat.shape} non supportata per la porta '{gate_name}'.")

DenseSVSimulator.run_circuit = run_circuit
print("✅ run_circuit patchato con successo: allineamento indici MSB stabilizzato!")
927
+
928
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
929
+ # CORE COMPILATION ENGINE (KERNEL FUSION LINEARE AD ALLOCAZIONE ZERO)
930
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
931
+
932
+ import time
933
+ import numpy as np
934
+ from typing import List, Tuple
935
+
936
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    HAS_JAX = False
else:
    HAS_JAX = True
    # Turn on native 64-bit types in JAX (int64 / float64 / complex128). The
    # kernels below build int64 index arrays and complex128 matrices.
    jax.config.update("jax_enable_x64", True)

# Gate name -> integer opcode for the fused JAX kernel.
# Opcodes up to 11 take the single-qubit path; 20 and 21 are the two
# supported two-qubit gates.
GATE_IDS = {
    # fixed single-qubit gates
    'id': 0, 'h': 1, 'x': 2, 'y': 3, 'z': 4,
    's': 5, 'sdg': 6, 't': 7, 'tdg': 8,
    # single-qubit rotations (take an angle)
    'rx': 9, 'ry': 10, 'rz': 11,
    # two-qubit gates
    'cx': 20, 'cz': 21,
}
950
+
951
if HAS_JAX:
    @jax.jit
    def _apply_gate_fast_step(sv, operation):
        """Apply one encoded gate to a flat state vector (``jax.lax.scan`` step).

        ``operation`` is one row ``(g_id, q1, q2, param)`` of the float64 op table:

        * ``g_id``: opcode from ``GATE_IDS``. Values ``<= 11`` take the
          single-qubit path. Any other value takes the two-qubit path, where
          ``20`` selects CX and every other opcode is treated as CZ.
        * ``q1`` / ``q2``: bit positions in the amplitude index. The kernel uses
          ``stride = 1 << q``, so position 0 is the least-significant bit.
          For CX/CZ, ``q1`` is the control and ``q2`` the target.
        * ``param``: rotation angle for rx/ry/rz; unused by the other gates.

        NOTE(review): ``measure`` in this file maps logical qubit ``k`` to bit
        ``n - 1 - k``. A caller that puts logical indices straight into
        ``q1``/``q2`` addresses the mirrored qubit; confirm the caller converts.

        Returns:
            ``(new_sv, None)``, the ``(carry, output)`` pair ``lax.scan`` expects.
        """

        g_id, q1, q2, param = operation
        dim = sv.shape[0]

        # Scalars shared by the gate-matrix branches below.
        inv2 = 1.0 / jnp.sqrt(2.0)
        cos_p = jnp.cos(param / 2.0)
        sin_p = jnp.sin(param / 2.0)

        # Two-qubit opcodes (> 11) are routed to branch 0, so the switch index is
        # always one of the 12 single-qubit branches. That matrix is not used
        # on the two-qubit path.
        safe_gid = jnp.where(g_id <= 11, g_id, 0).astype(jnp.int32)

        # 2x2 matrix for the single-qubit opcode; the branch index equals the GATE_IDS value.
        g_1q = jax.lax.switch(
            safe_gid,
            [
                lambda _: jnp.eye(2, dtype=jnp.complex128),  # 0: id
                lambda _: inv2 * jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex128),  # 1: h
                lambda _: jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128),  # 2: x
                lambda _: jnp.array([[0.0, -1j], [1j, 0.0]], dtype=jnp.complex128),  # 3: y
                lambda _: jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128),  # 4: z
                lambda _: jnp.array([[1.0, 0.0], [0.0, 1j]], dtype=jnp.complex128),  # 5: s
                lambda _: jnp.array([[1.0, 0.0], [0.0, -1j]], dtype=jnp.complex128),  # 6: sdg
                lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128),  # 7: t
                lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(-1j * jnp.pi / 4)]], dtype=jnp.complex128),  # 8: tdg
                lambda _: jnp.array([[cos_p, -1j * sin_p], [-1j * sin_p, cos_p]], dtype=jnp.complex128),  # 9: rx
                lambda _: jnp.array([[cos_p, -sin_p], [sin_p, cos_p]], dtype=jnp.complex128),  # 10: ry
                lambda _: jnp.array([[jnp.exp(-1j * param / 2.0), 0.0], [0.0, jnp.exp(1j * param / 2.0)]], dtype=jnp.complex128)  # 11: rz
            ],
            operand=None
        )

        # Single-qubit gate. For every index i, gather the partner amplitudes
        # (bit q1 cleared / set) and keep row 0 or row 1 of the 2x2 product.
        def do_1q(_sv):
            t_bit = q1.astype(jnp.int64)
            stride = 1 << t_bit

            # mask_0 marks the indices whose bit q1 is 0.
            idx_full = jnp.arange(dim, dtype=jnp.int64)
            mask_0 = (idx_full & stride) == 0

            # For each i: idx_0 is i with bit q1 cleared; idx_1 is the same index with it set.
            idx_0 = jnp.where(mask_0, idx_full, idx_full ^ stride)
            idx_1 = idx_0 | stride

            # Individual scalar matrix entries.
            g00, g01, g10, g11 = g_1q[0, 0], g_1q[0, 1], g_1q[1, 0], g_1q[1, 1]

            # Both output rows, computed for every index.
            new_sv0 = g00 * _sv[idx_0] + g01 * _sv[idx_1]
            new_sv1 = g10 * _sv[idx_0] + g11 * _sv[idx_1]

            # Bit q1 == 0 takes row 0 of the product; bit q1 == 1 takes row 1.
            return jnp.where(mask_0, new_sv0, new_sv1)

        # Two-qubit gate (CX or CZ) with control bit q1 and target bit q2.
        # Amplitudes whose control bit is 0 are returned unchanged.
        def do_2q(_sv):
            ctrl = q1.astype(jnp.int64)
            trgt = q2.astype(jnp.int64)

            idx_full = jnp.arange(dim, dtype=jnp.int64)
            ctrl_active = (idx_full & (1 << ctrl)) != 0
            trgt_active = (idx_full & (1 << trgt)) != 0

            # CX candidate: each amplitude taken from its target-bit partner.
            cx_sv = _sv[idx_full ^ (1 << trgt)]
            # CZ candidate: sign flipped wherever the target bit is 1. Combined
            # with the control mask below, this is the -1 phase on |11>.
            cz_sv = jnp.where(trgt_active, -_sv, _sv)

            # Opcode 20 selects CX; any other two-qubit opcode gets CZ.
            target_sv = jax.lax.cond(g_id == 20, lambda _: cx_sv, lambda _: cz_sv, operand=None)

            # Apply the chosen update only where the control bit is 1.
            return jnp.where(ctrl_active, target_sv, _sv)

        # Route on the opcode: <= 11 is single-qubit, anything else is two-qubit.
        exec_1q = g_id <= 11
        new_sv = jax.lax.cond(exec_1q, do_1q, do_2q, sv)
        return new_sv, None

    @jax.jit
    def _compile_and_run_circuit_jit(state_vector, compiled_ops):
        """Apply every row of ``compiled_ops`` in order via ``lax.scan``; return the final state.

        NOTE(review): the step builds complex128 values, and ``lax.scan``
        requires a fixed carry type, so ``state_vector`` presumably must
        already be complex128. Confirm for complex64 states.
        """
        final_sv, _ = jax.lax.scan(_apply_gate_fast_step, state_vector, compiled_ops)
        return final_sv

print("💎 CORE COMPILER SIGILLATO V4 (ULTRA): Rimossi definitivamente i reshape dinamici. Struttura JIT stabile ed esatta!")
1040
+
1041
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1042
+ # CORE COMPILATION ENGINE (KERNEL FUSION LINEARE AD ALLOCAZIONE ZERO)
1043
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1044
# Redefinition of `_apply_gate_fast_step` and `_compile_and_run_circuit_jit`
# with the same logic as the earlier definitions in this file, minus the
# comments. Being later, these are the bindings in effect from here on.
if HAS_JAX:
    @jax.jit
    def _apply_gate_fast_step(sv, operation):
        """Apply one encoded gate ``(g_id, q1, q2, param)`` to ``sv``; ``lax.scan`` step.

        ``q1``/``q2`` are bit positions (stride ``1 << q``, LSB-first). ``g_id``
        values ``<= 11`` are single-qubit; otherwise ``20`` is CX and any other
        opcode is CZ, with ``q1`` the control and ``q2`` the target. ``param``
        is the rx/ry/rz angle. Returns ``(new_sv, None)``.
        """
        g_id, q1, q2, param = operation
        dim = sv.shape[0]
        inv2 = 1.0 / jnp.sqrt(2.0)
        cos_p = jnp.cos(param / 2.0)
        sin_p = jnp.sin(param / 2.0)
        # Two-qubit opcodes are routed to branch 0; that matrix goes unused on the 2q path.
        safe_gid = jnp.where(g_id <= 11, g_id, 0).astype(jnp.int32)

        g_1q = jax.lax.switch(
            safe_gid,
            [
                lambda _: jnp.eye(2, dtype=jnp.complex128),  # 0: id
                lambda _: inv2 * jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex128),  # 1: h
                lambda _: jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128),  # 2: x
                lambda _: jnp.array([[0.0, -1j], [1j, 0.0]], dtype=jnp.complex128),  # 3: y
                lambda _: jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128),  # 4: z
                lambda _: jnp.array([[1.0, 0.0], [0.0, 1j]], dtype=jnp.complex128),  # 5: s
                lambda _: jnp.array([[1.0, 0.0], [0.0, -1j]], dtype=jnp.complex128),  # 6: sdg
                lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128),  # 7: t
                lambda _: jnp.array([[1.0, 0.0], [0.0, jnp.exp(-1j * jnp.pi / 4)]], dtype=jnp.complex128),  # 8: tdg
                lambda _: jnp.array([[cos_p, -1j * sin_p], [-1j * sin_p, cos_p]], dtype=jnp.complex128),  # 9: rx
                lambda _: jnp.array([[cos_p, -sin_p], [sin_p, cos_p]], dtype=jnp.complex128),  # 10: ry
                lambda _: jnp.array([[jnp.exp(-1j * param / 2.0), 0.0], [0.0, jnp.exp(1j * param / 2.0)]], dtype=jnp.complex128)  # 11: rz
            ],
            operand=None
        )

        # Single-qubit update: pair each index with its bit-q1 partner and keep row 0 / row 1.
        def do_1q(_sv):
            t_bit = q1.astype(jnp.int64)
            stride = 1 << t_bit
            idx_full = jnp.arange(dim, dtype=jnp.int64)
            mask_0 = (idx_full & stride) == 0
            idx_0 = jnp.where(mask_0, idx_full, idx_full ^ stride)
            idx_1 = idx_0 | stride
            g00, g01, g10, g11 = g_1q[0, 0], g_1q[0, 1], g_1q[1, 0], g_1q[1, 1]
            new_sv0 = g00 * _sv[idx_0] + g01 * _sv[idx_1]
            new_sv1 = g10 * _sv[idx_0] + g11 * _sv[idx_1]
            return jnp.where(mask_0, new_sv0, new_sv1)

        # Controlled update: CX swaps target-bit partners, CZ negates where the
        # target bit is 1; both apply only where the control bit is 1.
        def do_2q(_sv):
            ctrl = q1.astype(jnp.int64)
            trgt = q2.astype(jnp.int64)
            idx_full = jnp.arange(dim, dtype=jnp.int64)
            ctrl_active = (idx_full & (1 << ctrl)) != 0
            trgt_active = (idx_full & (1 << trgt)) != 0
            cx_sv = _sv[idx_full ^ (1 << trgt)]
            cz_sv = jnp.where(trgt_active, -_sv, _sv)
            target_sv = jax.lax.cond(g_id == 20, lambda _: cx_sv, lambda _: cz_sv, operand=None)
            return jnp.where(ctrl_active, target_sv, _sv)

        exec_1q = g_id <= 11
        new_sv = jax.lax.cond(exec_1q, do_1q, do_2q, sv)
        return new_sv, None

    @jax.jit
    def _compile_and_run_circuit_jit(state_vector, compiled_ops):
        """Scan ``_apply_gate_fast_step`` over the rows of ``compiled_ops`` and return the final state.

        NOTE(review): presumably ``state_vector`` must be complex128 to keep the
        scan carry type fixed; confirm.
        """
        final_sv, _ = jax.lax.scan(_apply_gate_fast_step, state_vector, compiled_ops)
        return final_sv

print("💎 CORE COMPILER PATCHATO V4: Struttura 1D lineare stabilizzata a norma fissa.")
1106
+
1107
+
1108
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1109
+ # INTERFACCIA: run_circuit_jit_beast_mode (Mappatura Riallineata V2)
1110
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1111
+
1112
def run_circuit_jit_beast_mode(self, circuit: List[Tuple]):
    """Execute ``circuit`` through the fused JAX/XLA scan kernel.

    The circuit is lowered with ``QuantumTranspiler.transpile`` and encoded as a
    float64 table with rows ``(gate_id, q1, q2, param)`` for
    ``_compile_and_run_circuit_jit``. Gate tuples follow these conventions:

    * ``(name, qubit)`` for fixed single-qubit gates;
    * ``(name, qubit, theta)`` for rx/ry/rz;
    * ``(name, control, target)`` for cx/cz.

    Falls back to ``run_circuit`` (with a printed warning) when JAX is not the
    active backend, or when the lowered circuit contains a gate the kernel does
    not implement (anything outside ``GATE_IDS``). Such gates are never dropped.
    """
    if not (HAS_JAX and self.xp is jnp):
        print("⚠️ JAX non attivo o istanza non JAX. Esecuzione via run_circuit standard...")
        return self.run_circuit(circuit, transpile=True)

    # Lower macro gates (Toffoli, SWAP, ...) to primitives the kernel understands.
    target = QuantumTranspiler.transpile(circuit)

    unsupported = sorted({cmd[0].lower() for cmd in target} - set(GATE_IDS))
    if unsupported:
        print(f"⚠️ Gate non supportati dal kernel JIT {unsupported}: esecuzione via run_circuit standard...")
        return self.run_circuit(target, transpile=False)

    # The kernel addresses amplitude bit q directly (stride 1 << q). The
    # simulator maps logical qubit k to bit n - 1 - k (see `measure`), so
    # every qubit index is mirrored while encoding.
    # NOTE(review): confirm apply_gate_1q/apply_gate_2q use the same layout on JAX.
    last = self.n - 1
    compiled_list = []
    for cmd in target:
        g_name = cmd[0].lower()
        args = cmd[1:]
        q1 = float(last - int(args[0]))
        q2 = 0.0
        param = 0.0
        if g_name in ('rx', 'ry', 'rz'):
            param = float(args[1])
        elif g_name in ('cx', 'cz'):
            q2 = float(last - int(args[1]))
        compiled_list.append([float(GATE_IDS[g_name]), q1, q2, param])

    if not compiled_list:
        return

    # Op table of shape [n_gates, 4] in float64.
    compiled_ops = jnp.array(compiled_list, dtype=jnp.float64)

    # The kernel produces complex128 amplitudes and lax.scan needs a fixed carry
    # type, so run in complex128 and cast back to the state's own dtype.
    state_dtype = self.sv.dtype
    final_sv = _compile_and_run_circuit_jit(self.sv.astype(jnp.complex128), compiled_ops)
    self.sv = final_sv.astype(state_dtype)


# Attach the interface to the simulator class.
DenseSVSimulator.run_circuit_jit_beast_mode = run_circuit_jit_beast_mode
print("💎 INTERFACCIA RIALLINEATA: 'run_circuit_jit_beast_mode' agganciata con successo a DenseSVSimulator!")
1161
+
1162
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1163
+ # AGGANCIO RUNTIME MANCANTI: measure & memory_mb
1164
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1165
+
1166
def measure(self, qubit_idx: int) -> int:
    """Projectively measure one qubit in the computational basis and collapse the state.

    Logical qubit ``k`` is bit ``n - 1 - k`` of the amplitude index (MSB-first).
    Both backends read that bit through the same ``(-1, 2, stride)`` view, so
    JAX and NumPy/CuPy agree on which qubit is measured.

    Args:
        qubit_idx: logical qubit index in ``[0, self.n)``.

    Returns:
        The sampled outcome, 0 or 1.

    Raises:
        ValueError: if ``qubit_idx`` is out of range, or the state has zero norm.
    """
    if not 0 <= qubit_idx < self.n:
        raise ValueError(f"Qubit {qubit_idx} out of bounds")

    xp = self.xp
    # Compare by module name so this works when JAX is not installed (no `jnp` name then).
    use_jax = getattr(xp, '__name__', '') == 'jax.numpy'

    # Physical (mirrored) bit position for the MSB convention.
    phys_q = self.n - 1 - qubit_idx
    stride = 1 << phys_q

    # View the state as (higher bits, measured bit, lower bits): axis 1 is bit phys_q.
    sv_view = self.sv.reshape(-1, 2, stride)
    prob_0 = float(xp.sum(xp.abs(sv_view[:, 0, :]) ** 2))
    prob_1 = float(xp.sum(xp.abs(sv_view[:, 1, :]) ** 2))

    total = prob_0 + prob_1
    if total <= 1e-12:
        raise ValueError("Cannot measure a state with zero norm")
    prob_0 /= total
    prob_1 /= total

    result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

    # Zero the branch that was not observed.
    discarded = 1 - result
    if use_jax:
        # JAX arrays are immutable: build the collapsed state functionally.
        self.sv = sv_view.at[:, discarded, :].set(0.0).reshape(-1)
    else:
        # NumPy/CuPy: for a contiguous state vector reshape returns a view, so this writes in place.
        sv_view[:, discarded, :] = 0.0

    self.normalize()
    return result
1213
+
1214
def memory_mb(self) -> float:
    """Estimate the state-vector RAM usage in MB (10**6 bytes).

    Uses the element size of ``self.dtype``. A dtype NumPy cannot interpret
    falls back to 8 bytes for complex64 and 16 bytes otherwise.
    """
    try:
        elem_size = np.dtype(self.dtype).itemsize
    except TypeError:
        elem_size = 8 if self.dtype == np.complex64 else 16
    return self.dim * elem_size / 1e6
1218
+
1219
# Bind the module-level `measure` and `memory_mb` functions onto the simulator class.
# NOTE: `measure` is redefined and re-bound further down in this file.
DenseSVSimulator.measure = measure
DenseSVSimulator.memory_mb = memory_mb

print("🚀 Metodi 'measure' e 'memory_mb' agganciati ed iniettati con successo in DenseSVSimulator!")
1224
+
1225
+ import random # Import the standard random module
1226
+
1227
def measure(self, qubit_idx: int) -> int:
    """Projectively measure one qubit in the computational basis and collapse the state.

    Logical qubit ``k`` is bit ``n - 1 - k`` of the amplitude index (MSB-first).
    Both backends read that bit through the same ``(-1, 2, stride)`` view, so
    JAX and NumPy/CuPy agree on which qubit is measured.

    Args:
        qubit_idx: logical qubit index in ``[0, self.n)``.

    Returns:
        The sampled outcome, 0 or 1.

    Raises:
        ValueError: if ``qubit_idx`` is out of range, or the state has zero norm.
    """
    if not 0 <= qubit_idx < self.n:
        raise ValueError(f"Qubit {qubit_idx} out of bounds")

    xp = self.xp
    # Compare by module name so this works when JAX is not installed (no `jnp` name then).
    use_jax = getattr(xp, '__name__', '') == 'jax.numpy'

    # Physical (mirrored) bit position for the MSB convention.
    phys_q = self.n - 1 - qubit_idx
    stride = 1 << phys_q

    # View the state as (higher bits, measured bit, lower bits): axis 1 is bit phys_q.
    sv_view = self.sv.reshape(-1, 2, stride)
    prob_0 = float(xp.sum(xp.abs(sv_view[:, 0, :]) ** 2))
    prob_1 = float(xp.sum(xp.abs(sv_view[:, 1, :]) ** 2))

    total = prob_0 + prob_1
    if total <= 1e-12:
        raise ValueError("Cannot measure a state with zero norm")
    prob_0 /= total
    prob_1 /= total

    result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

    # Zero the branch that was not observed.
    discarded = 1 - result
    if use_jax:
        # JAX arrays are immutable: build the collapsed state functionally.
        self.sv = sv_view.at[:, discarded, :].set(0.0).reshape(-1)
    else:
        # NumPy/CuPy: for a contiguous state vector reshape returns a view, so this writes in place.
        sv_view[:, discarded, :] = 0.0

    self.normalize()
    return result
1271
+
1272
+
1273
def apply_cx(self, ctrl: int, tgt: int):
    """Apply a CNOT (controlled-X) gate to the state vector.

    On JAX the 4x4 CX matrix is delegated to ``apply_gate_2q`` with logical
    indices. On NumPy/CuPy the target-bit partners inside the ``ctrl = 1``
    subspace are swapped in place through a one-axis-per-qubit tensor view.
    With the MSB-first layout (logical qubit k is bit n - 1 - k), axis k is
    logical qubit k.

    Raises:
        ValueError: if either index is out of range or ``ctrl == tgt``.
    """
    xp = self.xp
    if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
        raise ValueError(f"Invalid control ({ctrl}) or target ({tgt})")

    # Compare by module name so this works when JAX is not installed (no `jnp` name then).
    if getattr(xp, '__name__', '') == 'jax.numpy':
        cx_mat = xp.array([[1,0,0,0],[0,1,0,0],[0,0,0,1],[0,0,1,0]], dtype=self.dtype)
        self.apply_gate_2q(cx_mat, ctrl, tgt)
        return

    # Assumes a C-contiguous state vector, so reshape returns a writable view.
    psi = self.sv.reshape((2,) * self.n)
    sel_0 = [slice(None)] * self.n
    sel_0[ctrl] = 1
    sel_1 = list(sel_0)
    sel_0[tgt] = 0
    sel_1[tgt] = 1
    sel_0, sel_1 = tuple(sel_0), tuple(sel_1)

    # Swap |c=1, t=0> and |c=1, t=1> amplitudes with vectorized slice assignment.
    tmp = psi[sel_0].copy()
    psi[sel_0] = psi[sel_1]
    psi[sel_1] = tmp
1297
+ self.sv[idx_1 : idx_1 + min(c_stride, t_stride)] = tmp
1298
+
1299
+
1300
def apply_cz(self, ctrl: int, tgt: int):
    """Apply a controlled-Z gate: negate amplitudes where both qubits are 1.

    On JAX the 4x4 CZ matrix is delegated to ``apply_gate_2q`` with logical
    indices. On NumPy/CuPy the ``|11>`` subspace is negated in place through a
    one-axis-per-qubit tensor view; under the MSB-first layout, axis k is
    logical qubit k.

    Raises:
        ValueError: if either index is out of range or ``ctrl == tgt``.
    """
    xp = self.xp
    if not (0 <= ctrl < self.n and 0 <= tgt < self.n and ctrl != tgt):
        raise ValueError(f"Invalid indices ({ctrl}, {tgt})")

    # Compare by module name so this works when JAX is not installed (no `jnp` name then).
    if getattr(xp, '__name__', '') == 'jax.numpy':
        cz_mat = xp.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,-1]], dtype=self.dtype)
        self.apply_gate_2q(cz_mat, ctrl, tgt)
        return

    # Assumes a C-contiguous state vector, so reshape returns a writable view.
    psi = self.sv.reshape((2,) * self.n)
    sel = [slice(None)] * self.n
    sel[ctrl] = 1
    sel[tgt] = 1
    psi[tuple(sel)] *= -1
1320
+
1321
+
1322
# Bind the final measure / apply_cx / apply_cz implementations onto the simulator
# class. Rebinding is idempotent, so one pass and one banner are enough.
DenseSVSimulator.measure = measure
DenseSVSimulator.apply_cx = apply_cx
DenseSVSimulator.apply_cz = apply_cz

print("💎 ENGINE CORE RIALLINEATO PERFETTAMENTE: Tutti i canali JAX e NumPy sono stabili!")
1334
+
1335
+ import numpy as np
1336
+
1337
def run_circuit_with_chunking(self, circuit: list, chunk_size: int = 500, transpile: bool = True):
    """
    Execute a (possibly very deep) circuit as a sequence of consecutive chunks.

    On a ``jax.numpy`` backend each chunk is first tried on the fused kernel
    (``run_circuit_jit_beast_mode``). If that raises, the chunk is executed by
    ``run_circuit``. On any other backend ``run_circuit`` is used directly, so a
    failing gate is never re-applied by a retry.

    Args:
        circuit: list of gate tuples ``(name, *args)``.
        chunk_size: maximum number of gates per chunk; must be positive.
        transpile: lower macro gates with ``QuantumTranspiler`` once, before
            chunking. When False, the fallback executor does not transpile either.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size deve essere > 0 (ricevuto {chunk_size}).")

    # 1. Optional up-front lowering of macro gates.
    target_circuit = QuantumTranspiler.transpile(circuit) if transpile else circuit
    total_gates = len(target_circuit)
    # Compare by module name so this works when JAX is not installed (no `jnp` name then).
    use_jax = getattr(self.xp, '__name__', '') == 'jax.numpy'

    def _run_chunk(chunk):
        """Run one chunk, preferring the JIT kernel on a JAX backend."""
        if use_jax:
            try:
                self.run_circuit_jit_beast_mode(chunk)
                return
            except Exception:
                # Best effort: fall through to the gate-by-gate executor below.
                pass
        # The chunk is already lowered (or the caller opted out), so don't transpile again.
        self.run_circuit(chunk, transpile=False)

    if total_gates <= chunk_size:
        _run_chunk(target_circuit)
        return

    print(f"⚙️ Circuit Chunking Attivo: {total_gates} gate totali divisi in blocchi da {chunk_size}...")

    # 2. Consecutive, non-overlapping chunks.
    for start in range(0, total_gates, chunk_size):
        _run_chunk(target_circuit[start:start + chunk_size])
        if use_jax:
            # Wait for JAX's asynchronous dispatch before queuing the next chunk.
            self.sv.block_until_ready()

    n_chunks = -(-total_gates // chunk_size)  # ceiling division
    print(f"✅ Esecuzione completata con successo tramite {n_chunks} Chunk geometrici.")
1369
+
1370
# Attach the chunked executor to the main simulator class.
DenseSVSimulator.run_circuit_with_chunking = run_circuit_with_chunking
1372
+
1373
import time

import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    # JAX is optional. Without it, HAS_JAX is False (set by the guarded import
    # earlier in this file) and run_parametric_batch_jit raises RuntimeError,
    # rather than the whole module failing to import here.
    pass


def run_parametric_batch_jit(self, base_circuit: list, parameter_batch: np.ndarray) -> "jnp.ndarray":
    """Simulate ``base_circuit`` for a batch of rotation angles in one vmapped XLA call.

    Every ``rx``/``ry``/``rz`` gate in the transpiled circuit becomes a free
    parameter; the angles written in ``base_circuit`` for those gates are
    ignored. Row ``b`` of ``parameter_batch`` supplies the angles, in circuit
    order, for batch instance ``b``. Each instance starts from ``|0...0>``
    independently of ``self.sv``, which is not modified.

    Args:
        base_circuit: gate tuples using only gates in ``GATE_IDS`` (after
            Toffoli/SWAP lowering): ``(name, q)``, ``(name, q, theta)`` or
            ``(name, control, target)``.
        parameter_batch: array of shape ``(batch, n_rotations)``.

    Returns:
        Array of shape ``(batch, self.dim)`` with the final state vectors.

    Raises:
        RuntimeError: if JAX is not the active backend.
        ValueError: on a gate the fused kernel cannot execute, or when
            ``parameter_batch`` does not have shape ``(batch, n_rotations)``.
    """
    if not HAS_JAX or self.xp is not jnp:
        raise RuntimeError("JAX deve essere il backend attivo per usare run_parametric_batch_jit.")

    # Lower macro gates (Toffoli, SWAP) to kernel primitives.
    target_circuit = QuantumTranspiler.transpile(base_circuit)

    n_qubits = self.n
    dim = self.dim
    last = n_qubits - 1

    # Encode rows (gate_id, q1, q2, param) for the fused kernel. Qubit indices are
    # mirrored (k -> n - 1 - k): the kernel addresses bit k directly (stride
    # 1 << k), while the simulator maps logical qubit k to bit n - 1 - k (see `measure`).
    compiled_list = []
    param_slots = []  # row index of every rotation whose angle comes from the batch
    for cmd in target_circuit:
        g_name = cmd[0].lower()
        args = cmd[1:]
        if g_name not in GATE_IDS:
            raise ValueError(
                f"Porta '{g_name}' non supportata dal batch engine JIT (ammesse: {sorted(GATE_IDS)})."
            )
        g_id = float(GATE_IDS[g_name])
        q1 = float(last - int(args[0]))
        if g_name in ('rx', 'ry', 'rz'):
            param_slots.append(len(compiled_list))
            compiled_list.append([g_id, q1, 0.0, 0.0])
        elif g_name in ('cx', 'cz'):
            compiled_list.append([g_id, q1, float(last - int(args[1])), 0.0])
        else:
            compiled_list.append([g_id, q1, 0.0, 0.0])

    params = jnp.asarray(parameter_batch, dtype=jnp.float64)
    if params.ndim != 2 or params.shape[1] != len(param_slots):
        raise ValueError(
            f"parameter_batch deve avere forma (batch, {len(param_slots)}); ricevuta {tuple(params.shape)}."
        )

    # reshape keeps shape (0, 4) for an empty circuit, so lax.scan runs zero steps.
    compiled_ops_template = jnp.asarray(np.asarray(compiled_list, dtype=np.float64).reshape(-1, 4))
    slots = jnp.asarray(param_slots, dtype=jnp.int32)

    def simulate_single_instance(single_params):
        """Run one batch instance from |0...0>, with its angles written into the op table."""
        local_sv = jnp.zeros(dim, dtype=jnp.complex128).at[0].set(1.0)
        # Slot positions are static, so the angles are scattered in one indexed update.
        patched_ops = compiled_ops_template.at[slots, 3].set(single_params)
        return _compile_and_run_circuit_jit(local_sv, patched_ops)

    print(f"🚀 VMAP COMPILER: Parallelizzazione inter-circuito attiva per {len(parameter_batch)} istanze...")

    # Vectorize over the batch axis and compile the whole batched graph.
    vmap_sim = jax.vmap(simulate_single_instance, in_axes=(0,))
    jitted_vmap = jax.jit(vmap_sim)

    t0 = time.perf_counter()
    res = jitted_vmap(params)
    res.block_until_ready()
    print(f"✅ Batch completato in {time.perf_counter() - t0:.4f} secondi!")
    return res


# Attach the batch engine to the simulator class.
DenseSVSimulator.run_parametric_batch_jit = run_parametric_batch_jit
print("💎 BATCH ENGINE AGGANGIATO: Pieno supporto QML & VQE attivo sul tuo core!")
1443
+
1444
+
1445
+