emu-sv 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,195 @@
1
+ from __future__ import annotations
2
+ from collections import Counter
3
+ import math
4
+ from typing import Mapping, TypeVar, Type, Sequence
5
+ import torch
6
+ from pulser.backend import State
7
+ from emu_base import DEVICE_COUNT
8
+ from emu_sv.state_vector import StateVector
9
+ from emu_sv.utils import index_to_bitstring
10
+ from pulser.backend.state import Eigenstate
11
+
12
+ DensityMatrixType = TypeVar("DensityMatrixType", bound="DensityMatrix")
13
+
14
+ dtype = torch.complex128
15
+
16
+
17
class DensityMatrix(State[complex, torch.Tensor]):
    """Represents a density matrix in a computational basis.

    The matrix is stored in ``self.matrix`` as a ``torch.complex128`` tensor,
    placed on the GPU when requested and available, otherwise on the CPU.
    """

    # for the moment no need to check positivity and trace 1
    def __init__(
        self,
        matrix: torch.Tensor,
        *,
        gpu: bool = True,
    ):
        """Store ``matrix`` as a complex128 tensor on the selected device.

        Args:
            matrix: the density matrix entries. No validation (positivity,
                unit trace) is performed.
            gpu: if True and at least one device is reported by
                ``DEVICE_COUNT``, the matrix is moved to ``"cuda"``;
                otherwise it is kept on ``"cpu"``.
        """
        # NOTE: this accepts also zero matrices.

        device = "cuda" if gpu and DEVICE_COUNT > 0 else "cpu"
        self.matrix = matrix.to(dtype=dtype, device=device)

    @property
    def n_qudits(self) -> int:
        """The number of qudits in the state."""
        # The matrix is (2**n, 2**n), so n is log2 of the first dimension.
        nqudits = math.log2(self.matrix.shape[0])
        return int(nqudits)

    @classmethod
    def make(cls, n_atoms: int, gpu: bool = True) -> DensityMatrix:
        """Creates the density matrix of the ground state |000...0>

        Args:
            n_atoms: the number of atoms (qubits).
            gpu: whether to place the matrix on the GPU when available.

        Returns:
            The density matrix with a single 1 in the top-left entry.
        """
        result = torch.zeros(2**n_atoms, 2**n_atoms, dtype=dtype)
        result[0, 0] = 1.0
        return cls(result, gpu=gpu)

    def __add__(self, other: State) -> DensityMatrix:
        """Addition of density matrices is not supported."""
        raise NotImplementedError("Not implemented")

    def __rmul__(self, scalar: complex) -> DensityMatrix:
        """Scalar multiplication of density matrices is not supported."""
        raise NotImplementedError("Not implemented")

    def _normalize(self) -> None:
        # NOTE: use this in the callbacks
        """Normalize the density matrix state

        Divides the matrix by its trace when the trace is not already
        (numerically) equal to one.
        """
        matrix_trace = torch.trace(self.matrix)
        # Compare against a tensor with the trace's own dtype and device:
        # torch.allclose does not accept a complex128 trace against a CPU
        # float64 reference.
        if not torch.allclose(matrix_trace, torch.ones_like(matrix_trace)):
            self.matrix = self.matrix / matrix_trace

    def overlap(self, other: State) -> torch.Tensor:
        """
        Compute Tr(self^† @ other). The type of other must be DensityMatrix.

        Args:
            other: the other state

        Returns:
            the inner product

        Example:
        >>> density_bell_state = (1/2* torch.tensor([[1, 0, 0, 1], [0, 0, 0, 0],
        ... [0, 0, 0, 0], [1, 0, 0, 1]],dtype=torch.complex128))
        >>> density_c = DensityMatrix(density_bell_state, gpu=False)
        >>> density_c.overlap(density_c)
        tensor(1.+0.j, dtype=torch.complex128)
        """

        assert isinstance(
            other, DensityMatrix
        ), "Other state also needs to be a DensityMatrix"
        assert (
            self.matrix.shape == other.matrix.shape
        ), "States do not have the same number of sites"

        # vdot conjugates its first argument, so on the flattened matrices
        # this is sum_ij conj(self_ij) * other_ij = Tr(self^† @ other).
        return torch.vdot(self.matrix.flatten(), other.matrix.flatten())

    @classmethod
    def from_state_vector(cls, state: StateVector) -> DensityMatrix:
        """Convert a state vector to a density matrix.
        This function takes a state vector |ψ❭ and returns the corresponding
        density matrix ρ = |ψ❭❬ψ| representing the pure state |ψ❭.

        Args:
            state: the state vector |ψ❭.

        Returns:
            The density matrix |ψ❭❬ψ|, on the GPU if the vector was on CUDA.

        Example:
        >>> from emu_sv import StateVector
        >>> import math
        >>> bell_state_vec = 1 / math.sqrt(2) * torch.tensor(
        ... [1.0, 0.0, 0.0, 1.0j],dtype=torch.complex128)
        >>> bell_state = StateVector(bell_state_vec, gpu=False)
        >>> density = DensityMatrix.from_state_vector(bell_state)
        >>> print(density.matrix)
        tensor([[0.5000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000-0.5000j],
                [0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j],
                [0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j],
                [0.0000+0.5000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.5000+0.0000j]],
               dtype=torch.complex128)
        """

        return cls(
            torch.outer(state.vector, state.vector.conj()), gpu=state.vector.is_cuda
        )

    @classmethod
    def _from_state_amplitudes(
        cls: Type[DensityMatrixType],
        *,
        eigenstates: Sequence[Eigenstate],
        amplitudes: Mapping[str, complex],
    ) -> tuple[DensityMatrix, Mapping[str, complex]]:
        """Transforms a state given by a string into a density matrix.

        Construct a state from the pulser abstract representation
        https://pulser.readthedocs.io/en/stable/conventions.html

        The state vector is built by ``StateVector._from_state_amplitudes``
        and then converted with ``from_state_vector``.

        Args:
            eigenstates: A tuple containing the basis states (e.g., ('r', 'g')).
            amplitudes: A dictionary mapping state strings to complex or floats
                amplitudes.

        Returns:
            The resulting state, together with the amplitudes mapping returned
            by ``StateVector._from_state_amplitudes``.

        Examples:
        >>> eigenstates = ("r","g")
        >>> n = 2
        >>> dense_mat=DensityMatrix.from_state_amplitudes(eigenstates=eigenstates,
        ... amplitudes={"rr":1.0,"gg":1.0})
        >>> print(dense_mat.matrix)
        tensor([[0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j],
                [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
                [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
                [0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j]],
               dtype=torch.complex128)
        """

        state_vector, amplitudes = StateVector._from_state_amplitudes(
            eigenstates=eigenstates, amplitudes=amplitudes
        )

        # Build through ``cls`` (as from_state_vector does) so subclasses
        # get an instance of their own type.
        return cls.from_state_vector(state_vector), amplitudes

    def sample(
        self,
        num_shots: int = 1000,
        one_state: Eigenstate | None = None,
        p_false_pos: float = 0.0,
        p_false_neg: float = 0.0,
    ) -> Counter[str]:
        """
        Samples bitstrings, taking into account the specified error rates.

        Args:
            num_shots: how many bitstrings to sample
            one_state: accepted for interface compatibility; not used by
                this implementation
            p_false_pos: the rate at which a 0 is read as a 1
            p_false_neg: the rate at which a 1 is read as a 0

        Returns:
            the measured bitstrings, by count

        Example:
        >>> import math
        >>> torch.manual_seed(1234)
        >>> from emu_sv import StateVector
        >>> bell_vec = 1 / math.sqrt(2) * torch.tensor(
        ... [1.0, 0.0, 0.0, 1.0j],dtype=torch.complex128)
        >>> bell_state_vec = StateVector(bell_vec)
        >>> bell_density = DensityMatrix.from_state_vector(bell_state_vec)
        >>> bell_density.sample(1000)
        Counter({'00': 517, '11': 483})
        """

        # Readout errors are not implemented yet; only zero rates are accepted.
        assert p_false_neg == p_false_pos == 0.0, "Error rates must be 0.0"

        # The populations of the basis states sit on the diagonal.
        probabilities = torch.abs(self.matrix.diagonal())

        outcomes = torch.multinomial(probabilities, num_shots, replacement=True)

        # Evaluate the qudit count once instead of once per sampled shot.
        n_qudits = self.n_qudits

        # Convert outcomes to bitstrings and count occurrences
        counts = Counter(
            [index_to_bitstring(n_qudits, outcome) for outcome in outcomes]
        )

        return counts
190
+
191
+
192
+ if __name__ == "__main__":
193
+ import doctest
194
+
195
+ doctest.testmod()
emu_sv/hamiltonian.py CHANGED
@@ -71,15 +71,14 @@ class RydbergHamiltonian:
71
71
  the resulting state vector.
72
72
  """
73
73
  # (-∑ⱼΔⱼnⱼ + ∑ᵢ﹥ⱼUᵢⱼnᵢnⱼ)|ψ❭
74
- diag_result = self.diag * vec
74
+ result = self.diag * vec
75
75
  # ∑ⱼΩⱼ/2[cos(ϕⱼ)σˣⱼ + sin(ϕⱼ)σʸⱼ]|ψ❭
76
- sigma_result = self._apply_sigma_operators(vec)
77
- result: torch.Tensor
78
- result = diag_result + sigma_result
79
-
76
+ self._apply_sigma_operators(result, vec)
80
77
  return result
81
78
 
82
- def _apply_sigma_operators_real(self, vec: torch.Tensor) -> torch.Tensor:
79
+ def _apply_sigma_operators_real(
80
+ self, result: torch.Tensor, vec: torch.Tensor
81
+ ) -> None:
83
82
  """
84
83
  Apply the ∑ⱼ(Ωⱼ/2)σˣⱼ operator to the input vector |ψ❭.
85
84
 
@@ -89,18 +88,16 @@ class RydbergHamiltonian:
89
88
  Returns:
90
89
  the resulting state vector.
91
90
  """
92
- result = torch.zeros_like(vec)
93
-
94
91
  dim_to_act = 1
95
92
  for n, omega_n in enumerate(self.omegas):
96
93
  shape_n = (2**n, 2, 2 ** (self.nqubits - n - 1))
97
- vec = vec.reshape(shape_n)
98
- result = result.reshape(shape_n)
94
+ vec = vec.view(shape_n)
95
+ result = result.view(shape_n)
99
96
  result.index_add_(dim_to_act, self.inds, vec, alpha=omega_n)
100
97
 
101
- return result.reshape(-1)
102
-
103
- def _apply_sigma_operators_complex(self, vec: torch.Tensor) -> torch.Tensor:
98
+ def _apply_sigma_operators_complex(
99
+ self, result: torch.Tensor, vec: torch.Tensor
100
+ ) -> None:
104
101
  """
105
102
  Apply the ∑ⱼΩⱼ/2[cos(ϕⱼ)σˣⱼ + sin(ϕⱼ)σʸⱼ] operator to the input vector |ψ❭.
106
103
 
@@ -111,13 +108,12 @@ class RydbergHamiltonian:
111
108
  the resulting state vector.
112
109
  """
113
110
  c_omegas = self.omegas * torch.exp(1j * self.phis)
114
- result = torch.zeros_like(vec)
115
111
 
116
112
  dim_to_act = 1
117
113
  for n, c_omega_n in enumerate(c_omegas):
118
114
  shape_n = (2**n, 2, 2 ** (self.nqubits - n - 1))
119
- vec = vec.reshape(shape_n)
120
- result = result.reshape(shape_n)
115
+ vec = vec.view(shape_n)
116
+ result = result.view(shape_n)
121
117
  result.index_add_(
122
118
  dim_to_act, self.inds[0], vec[:, 0, :].unsqueeze(1), alpha=c_omega_n
123
119
  )
@@ -128,8 +124,6 @@ class RydbergHamiltonian:
128
124
  alpha=c_omega_n.conj(),
129
125
  )
130
126
 
131
- return result.reshape(-1)
132
-
133
127
  def _create_diagonal(self) -> torch.Tensor:
134
128
  """
135
129
  Return the diagonal elements of the Rydberg Hamiltonian matrix
@@ -139,19 +133,21 @@ class RydbergHamiltonian:
139
133
  diag = torch.zeros(2**self.nqubits, dtype=torch.complex128, device=self.device)
140
134
 
141
135
  for i in range(self.nqubits):
142
- diag = diag.reshape(2**i, 2, -1)
136
+ diag = diag.view(2**i, 2, -1)
143
137
  i_fixed = diag[:, 1, :]
144
138
  i_fixed -= self.deltas[i]
145
139
  for j in range(i + 1, self.nqubits):
146
- i_fixed = i_fixed.reshape(2**i, 2 ** (j - i - 1), 2, -1)
140
+ i_fixed = i_fixed.view(2**i, 2 ** (j - i - 1), 2, -1)
147
141
  # replacing i_j_fixed by i_fixed breaks the code :)
148
142
  i_j_fixed = i_fixed[:, :, 1, :]
149
143
  i_j_fixed += self.interaction_matrix[i, j]
150
- return diag.reshape(-1)
144
+ return diag.view(-1)
151
145
 
152
146
  def expect(self, state: StateVector) -> torch.Tensor:
153
147
  """Return the energy expectation value E=❬ψ|H|ψ❭"""
154
148
  assert isinstance(
155
149
  state, StateVector
156
150
  ), "Currently, only expectation values of StateVectors are supported"
157
- return torch.vdot(state.vector, self * state.vector)
151
+ en = torch.vdot(state.vector, self * state.vector)
152
+ assert torch.allclose(en.imag, torch.zeros_like(en.imag), atol=1e-8)
153
+ return en.real
emu_sv/state_vector.py CHANGED
@@ -1,18 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections import Counter
4
- from typing import Any, Iterable
5
3
  import math
4
+ from collections import Counter
5
+ from typing import Sequence, Type, TypeVar, Mapping
6
6
 
7
+ import torch
7
8
 
8
- from emu_base import State, DEVICE_COUNT
9
+ from emu_sv.utils import index_to_bitstring
9
10
 
10
- import torch
11
+ from emu_base import DEVICE_COUNT
12
+ from pulser.backend import State
13
+ from pulser.backend.state import Eigenstate
11
14
 
15
+ StateVectorType = TypeVar("StateVectorType", bound="StateVector")
16
+ # Default tensor data type
12
17
  dtype = torch.complex128
13
18
 
14
19
 
15
- class StateVector(State):
20
+ class StateVector(State[complex, torch.Tensor]):
16
21
  """
17
22
  Represents a quantum state vector in a computational basis.
18
23
 
@@ -41,13 +46,26 @@ class StateVector(State):
41
46
  device = "cuda" if gpu and DEVICE_COUNT > 0 else "cpu"
42
47
  self.vector = vector.to(dtype=dtype, device=device)
43
48
 
49
+ @property
50
+ def n_qudits(self) -> int:
51
+ """The number of qudits in the state."""
52
+ nqudits = math.log2(self.vector.reshape(-1).shape[0])
53
+ return int(nqudits)
54
+
44
55
  def _normalize(self) -> None:
45
- # NOTE: use this in the callbacks
46
- """Checks if the input is normalized or not"""
47
- norm_state = torch.linalg.vector_norm(self.vector)
56
+ """Normalizes the state vector to ensure it has unit norm.
48
57
 
49
- if not torch.allclose(norm_state, torch.tensor(1.0, dtype=torch.float64)):
50
- self.vector = self.vector / norm_state
58
+ If the vector norm is not 1, the method scales the vector
59
+ to enforce normalization.
60
+
61
+ Note:
62
+ This method is intended to be used in callbacks.
63
+ """
64
+
65
+ norm = torch.linalg.vector_norm(self.vector)
66
+
67
+ if not torch.allclose(norm, torch.ones_like(norm)):
68
+ self.vector = self.vector / norm
51
69
 
52
70
  @classmethod
53
71
  def zero(cls, num_sites: int, gpu: bool = True) -> StateVector:
@@ -62,7 +80,7 @@ class StateVector(State):
62
80
  The zero state
63
81
 
64
82
  Examples:
65
- >>> StateVector.zero(2)
83
+ >>> StateVector.zero(2,gpu=False)
66
84
  tensor([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=torch.complex128)
67
85
  """
68
86
 
@@ -73,8 +91,7 @@ class StateVector(State):
73
91
  @classmethod
74
92
  def make(cls, num_sites: int, gpu: bool = True) -> StateVector:
75
93
  """
76
- Returns a State vector in ground state |000..0>.
77
- The vector in the output of StateVector has the shape (2,)*number of qubits
94
+ Returns a State vector in the ground state |00..0>.
78
95
 
79
96
  Args:
80
97
  num_sites: the number of qubits
@@ -84,17 +101,19 @@ class StateVector(State):
84
101
  The described state
85
102
 
86
103
  Examples:
87
- >>> StateVector.make(2)
104
+ >>> StateVector.make(2,gpu=False)
88
105
  tensor([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=torch.complex128)
106
+
107
+
89
108
  """
90
109
 
91
110
  result = cls.zero(num_sites=num_sites, gpu=gpu)
92
111
  result.vector[0] = 1.0
93
112
  return result
94
113
 
95
- def inner(self, other: State) -> float | complex:
114
+ def inner(self, other: State) -> torch.Tensor:
96
115
  """
97
- Compute <self, other>. The type of other must be StateVector.
116
+ Compute <self|other>. The type of other must be StateVector.
98
117
 
99
118
  Args:
100
119
  other: the other state
@@ -102,17 +121,21 @@ class StateVector(State):
102
121
  Returns:
103
122
  the inner product
104
123
  """
105
- assert isinstance(
106
- other, StateVector
107
- ), "Other state also needs to be a StateVector"
124
+ assert isinstance(other, StateVector), "Other state must be a StateVector"
108
125
  assert (
109
126
  self.vector.shape == other.vector.shape
110
- ), "States do not have the same number of sites"
127
+ ), "States do not have the same shape"
111
128
 
112
- return torch.vdot(self.vector, other.vector).item()
129
+ # by our internal convention inner and norm return to cpu
130
+ return torch.vdot(self.vector, other.vector).cpu()
113
131
 
114
132
  def sample(
115
- self, num_shots: int = 1000, p_false_pos: float = 0.0, p_false_neg: float = 0.0
133
+ self,
134
+ *,
135
+ num_shots: int = 1000,
136
+ one_state: Eigenstate | None = None,
137
+ p_false_pos: float = 0.0,
138
+ p_false_neg: float = 0.0,
116
139
  ) -> Counter[str]:
117
140
  """
118
141
  Samples bitstrings, taking into account the specified error rates.
@@ -125,25 +148,20 @@ class StateVector(State):
125
148
  Returns:
126
149
  the measured bitstrings, by count
127
150
  """
151
+ assert p_false_neg == p_false_pos == 0.0, "Error rates must be 0.0"
128
152
 
129
153
  probabilities = torch.abs(self.vector) ** 2
130
- probabilities /= probabilities.sum() # multinomial does not normalize the input
131
154
 
132
155
  outcomes = torch.multinomial(probabilities, num_shots, replacement=True)
133
156
 
134
157
  # Convert outcomes to bitstrings and count occurrences
135
- counts = Counter([self._index_to_bitstring(outcome) for outcome in outcomes])
158
+ counts = Counter(
159
+ [index_to_bitstring(self.n_qudits, outcome) for outcome in outcomes]
160
+ )
136
161
 
137
162
  # NOTE: false positives and negatives
138
163
  return counts
139
164
 
140
- def _index_to_bitstring(self, index: int) -> str:
141
- """
142
- Convert an integer index into its corresponding bitstring representation.
143
- """
144
- nqubits = int(math.log2(self.vector.reshape(-1).shape[0]))
145
- return format(index, f"0{nqubits}b")
146
-
147
165
  def __add__(self, other: State) -> StateVector:
148
166
  """Sum of two state vectors
149
167
 
@@ -153,11 +171,8 @@ class StateVector(State):
153
171
  Returns:
154
172
  The summed state
155
173
  """
156
- assert isinstance(
157
- other, StateVector
158
- ), "Other state also needs to be a StateVector"
159
- result = self.vector + other.vector
160
- return StateVector(result)
174
+ assert isinstance(other, StateVector), "`Other` state can only be a StateVector"
175
+ return StateVector(self.vector + other.vector, gpu=self.vector.is_cuda)
161
176
 
162
177
  def __rmul__(self, scalar: complex) -> StateVector:
163
178
  """Scalar multiplication
@@ -168,62 +183,68 @@ class StateVector(State):
168
183
  Returns:
169
184
  The scaled state
170
185
  """
171
- result = scalar * self.vector
186
+ return StateVector(scalar * self.vector, gpu=self.vector.is_cuda)
172
187
 
173
- return StateVector(result)
174
-
175
- def norm(self) -> float | complex:
188
+ def norm(self) -> torch.Tensor:
176
189
  """Returns the norm of the state
177
190
 
178
191
  Returns:
179
192
  the norm of the state
180
193
  """
181
- norm: float | complex = torch.linalg.vector_norm(self.vector).item()
182
- return norm
194
+ nrm: torch.Tensor = torch.linalg.vector_norm(self.vector).cpu()
195
+ return nrm
183
196
 
184
197
  def __repr__(self) -> str:
185
198
  return repr(self.vector)
186
199
 
187
- @staticmethod
188
- def from_state_string(
200
+ @classmethod
201
+ def _from_state_amplitudes(
202
+ cls: Type[StateVectorType],
189
203
  *,
190
- basis: Iterable[str],
191
- nqubits: int,
192
- strings: dict[str, complex],
193
- **kwargs: Any,
194
- ) -> StateVector:
204
+ eigenstates: Sequence[Eigenstate],
205
+ amplitudes: Mapping[str, complex],
206
+ ) -> tuple[StateVector, Mapping[str, complex]]:
195
207
  """Transforms a state given by a string into a state vector.
196
208
 
197
209
  Construct a state from the pulser abstract representation
198
210
  https://pulser.readthedocs.io/en/stable/conventions.html
199
211
 
200
212
  Args:
201
- basis: A tuple containing the basis states (e.g., ('r', 'g')).
202
- nqubits: the number of qubits.
203
- strings: A dictionary mapping state strings to complex or floats amplitudes.
213
+ eigenstates: A tuple containing the basis states (e.g., ('r', 'g')).
214
+ amplitudes: A dictionary mapping state strings to complex or floats amplitudes.
204
215
 
205
216
  Returns:
206
- The resulting state.
217
+ The normalised resulting state.
207
218
 
208
219
  Examples:
209
220
  >>> basis = ("r","g")
210
221
  >>> n = 2
211
- >>> st=StateVector.from_state_string(basis=basis,nqubits=n,strings={"rr":1.0,"gg":1.0})
222
+ >>> st=StateVector.from_state_string(basis=basis,
223
+ ... nqubits=n,strings={"rr":1.0,"gg":1.0},gpu=False)
224
+ >>> st = StateVector.from_state_amplitudes(
225
+ ... eigenstates=basis,
226
+ ... amplitudes={"rr": 1.0, "gg": 1.0}
227
+ ... )
212
228
  >>> print(st)
213
- tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], dtype=torch.complex128)
229
+ tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j],
230
+ dtype=torch.complex128)
214
231
  """
215
232
 
216
- basis = set(basis)
233
+ # nqubits = len(next(iter(amplitudes.keys())))
234
+ nqubits = cls._validate_amplitudes(amplitudes=amplitudes, eigenstates=eigenstates)
235
+ basis = set(eigenstates)
217
236
  if basis == {"r", "g"}:
218
237
  one = "r"
219
238
  elif basis == {"0", "1"}:
220
- one = "1"
239
+ raise NotImplementedError(
240
+ "{'0','1'} basis is related to XY Hamiltonian, which is not implemented"
241
+ )
221
242
  else:
222
243
  raise ValueError("Unsupported basis provided")
223
244
 
224
- accum_state = StateVector.zero(num_sites=nqubits, **kwargs)
245
+ accum_state = StateVector.zero(num_sites=nqubits)
225
246
 
226
- for state, amplitude in strings.items():
247
+ for state, amplitude in amplitudes.items():
227
248
  bin_to_int = int(
228
249
  state.replace(one, "1").replace("g", "0"), 2
229
250
  ) # "0" basis is already in "0"
@@ -231,7 +252,10 @@ class StateVector(State):
231
252
 
232
253
  accum_state._normalize()
233
254
 
234
- return accum_state
255
+ return accum_state, amplitudes
256
+
257
+ def overlap(self, other: StateVector, /) -> torch.Tensor:
258
+ return self.inner(other)
235
259
 
236
260
 
237
261
  def inner(left: StateVector, right: StateVector) -> torch.Tensor:
@@ -248,21 +272,27 @@ def inner(left: StateVector, right: StateVector) -> torch.Tensor:
248
272
  Examples:
249
273
  >>> factor = math.sqrt(2.0)
250
274
  >>> basis = ("r","g")
251
- >>> nqubits = 2
252
275
  >>> string_state1 = {"gg":1.0,"rr":1.0}
253
276
  >>> state1 = StateVector.from_state_string(basis=basis,
254
- >>> nqubits=nqubits,strings=string_state1)
277
+ ... nqubits=nqubits,strings=string_state1)
255
278
  >>> string_state2 = {"gr":1.0/factor,"rr":1.0/factor}
256
279
  >>> state2 = StateVector.from_state_string(basis=basis,
257
- >>> nqubits=nqubits,strings=string_state2)
280
+ ... nqubits=nqubits,strings=string_state2)
281
+
282
+ >>> state1 = StateVector.from_state_amplitudes(eigenstates=basis,
283
+ ... amplitudes=string_state1)
284
+ >>> string_state2 = {"gr":1.0/factor,"rr":1.0/factor}
285
+ >>> state2 = StateVector.from_state_amplitudes(eigenstates=basis,
286
+ ... amplitudes=string_state2)
258
287
  >>> inner(state1,state2).item()
259
- (0.4999999999999999+0j)
288
+ (0.49999999144286444+0j)
260
289
  """
261
290
 
262
- assert (left.vector.shape == right.vector.shape) and (
263
- left.vector.dim() == 1
264
- ), "Shape of a and b should be the same and both needs to be 1D tesnor"
265
- return torch.inner(left.vector, right.vector)
291
+ assert (left.vector.shape == right.vector.shape) and (left.vector.dim() == 1), (
292
+ "Shape of left.vector and right.vector should be",
293
+ " the same and both need to be 1D tesnor",
294
+ )
295
+ return left.inner(right)
266
296
 
267
297
 
268
298
  if __name__ == "__main__":