emu-sv 1.0.1__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,195 @@
1
+ from __future__ import annotations
2
+ from collections import Counter
3
+ import math
4
+ from typing import Mapping, TypeVar, Type, Sequence
5
+ import torch
6
+ from pulser.backend import State
7
+ from emu_base import DEVICE_COUNT
8
+ from emu_sv.state_vector import StateVector
9
+ from emu_sv.utils import index_to_bitstring
10
+ from pulser.backend.state import Eigenstate
11
+
12
# Type variable bound to DensityMatrix, used to annotate the `cls` argument of
# classmethods (see `_from_state_amplitudes`) so that they also type-check for
# subclasses.
DensityMatrixType = TypeVar("DensityMatrixType", bound="DensityMatrix")

# Tensor dtype every density matrix is converted to (double-precision complex).
dtype = torch.complex128
15
+
16
+
17
class DensityMatrix(State[complex, torch.Tensor]):
    """Represents a density matrix in a computational basis.

    The matrix is stored as a ``torch.complex128`` tensor, placed on CUDA when
    ``gpu=True`` and a GPU is available, otherwise on the CPU. It is expected
    (not checked) to be square with dimension ``2**n_qudits``.
    """

    # For the moment there is no check of positivity or unit trace.
    def __init__(
        self,
        matrix: torch.Tensor,
        *,
        gpu: bool = True,
    ):
        """Wrap ``matrix`` as a density matrix.

        Args:
            matrix: the matrix entries; converted to ``torch.complex128``.
                Zero matrices are accepted as well.
            gpu: if True and a CUDA device is available, store the matrix on it.
        """
        device = "cuda" if gpu and DEVICE_COUNT > 0 else "cpu"
        self.matrix = matrix.to(dtype=dtype, device=device)

    @property
    def n_qudits(self) -> int:
        """The number of qudits in the state (log2 of the matrix dimension)."""
        nqudits = math.log2(self.matrix.shape[0])
        return int(nqudits)

    @classmethod
    def make(cls, n_atoms: int, gpu: bool = True) -> DensityMatrix:
        """Creates the density matrix of the ground state |000...0>"""
        result = torch.zeros(2**n_atoms, 2**n_atoms, dtype=dtype)
        result[0, 0] = 1.0
        return cls(result, gpu=gpu)

    def __add__(self, other: State) -> DensityMatrix:
        """Addition of density matrices is not supported."""
        raise NotImplementedError("Not implemented")

    def __rmul__(self, scalar: complex) -> DensityMatrix:
        """Scalar multiplication of density matrices is not supported."""
        raise NotImplementedError("Not implemented")

    def _normalize(self) -> None:
        """Rescale the density matrix in place so that its trace is 1.

        Note:
            This method is intended to be used in the callbacks.

        Raises:
            ValueError: if the trace is zero, since such a matrix cannot be
                rescaled to unit trace.
        """
        matrix_trace = torch.trace(self.matrix)
        # Compare against a "1" with the same dtype (complex128) and device as
        # the trace: torch.allclose rejects mixed dtypes and mixed devices.
        if not torch.allclose(matrix_trace, torch.ones_like(matrix_trace)):
            if matrix_trace == 0:
                # Dividing by zero would silently fill the matrix with NaNs.
                raise ValueError(
                    "Cannot normalize a density matrix with zero trace"
                )
            self.matrix = self.matrix / matrix_trace

    def overlap(self, other: State) -> torch.Tensor:
        """
        Compute Tr(self^† @ other). The type of other must be DensityMatrix.

        Args:
            other: the other state

        Returns:
            the inner product, as a 0-dim tensor on the CPU

        Example:
            >>> density_bell_state = (1/2* torch.tensor([[1, 0, 0, 1], [0, 0, 0, 0],
            ... [0, 0, 0, 0], [1, 0, 0, 1]],dtype=torch.complex128))
            >>> density_c = DensityMatrix(density_bell_state, gpu=False)
            >>> density_c.overlap(density_c)
            tensor(1.+0.j, dtype=torch.complex128)
        """

        assert isinstance(
            other, DensityMatrix
        ), "Other state also needs to be a DensityMatrix"
        assert (
            self.matrix.shape == other.matrix.shape
        ), "States do not have the same number of sites"

        # vdot conjugates its first argument, so summing conj(A_ij) * B_ij over
        # the flattened matrices equals Tr(A^† B). Scalar results are returned
        # on the CPU, following the convention of StateVector.inner / norm.
        return torch.vdot(self.matrix.flatten(), other.matrix.flatten()).cpu()

    @classmethod
    def from_state_vector(cls, state: StateVector) -> DensityMatrix:
        """Convert a state vector to a density matrix.

        This function takes a state vector |ψ❭ and returns the corresponding
        density matrix ρ = |ψ❭❬ψ| representing the pure state |ψ❭. The result
        lives on the same kind of device (CUDA or CPU) as the input vector.

        Example:
            >>> from emu_sv import StateVector
            >>> import math
            >>> bell_state_vec = 1 / math.sqrt(2) * torch.tensor(
            ... [1.0, 0.0, 0.0, 1.0j],dtype=torch.complex128)
            >>> bell_state = StateVector(bell_state_vec, gpu=False)
            >>> density = DensityMatrix.from_state_vector(bell_state)
            >>> print(density.matrix)
            tensor([[0.5000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000-0.5000j],
                    [0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j],
                    [0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.0000+0.0000j],
                    [0.0000+0.5000j, 0.0000+0.0000j, 0.0000+0.0000j, 0.5000+0.0000j]],
                   dtype=torch.complex128)
        """

        return cls(
            torch.outer(state.vector, state.vector.conj()), gpu=state.vector.is_cuda
        )

    @classmethod
    def _from_state_amplitudes(
        cls: Type[DensityMatrixType],
        *,
        eigenstates: Sequence[Eigenstate],
        amplitudes: Mapping[str, complex],
    ) -> tuple[DensityMatrix, Mapping[str, complex]]:
        """Transforms a state given by basis-string amplitudes into a density matrix.

        Construct a state from the pulser abstract representation
        https://pulser.readthedocs.io/en/stable/conventions.html

        The amplitudes are first turned into a (normalised) state vector by
        ``StateVector._from_state_amplitudes``; the result is its pure-state
        density matrix |ψ❭❬ψ|.

        Args:
            eigenstates: A tuple containing the basis states (e.g., ('r', 'g')).
            amplitudes: A dictionary mapping state strings to complex or float
                amplitudes.

        Returns:
            A tuple of the resulting density matrix and the amplitudes mapping
            returned by ``StateVector._from_state_amplitudes``.

        Examples:
            >>> eigenstates = ("r","g")
            >>> dense_mat=DensityMatrix.from_state_amplitudes(eigenstates=eigenstates,
            ... amplitudes={"rr":1.0,"gg":1.0})
            >>> print(dense_mat.matrix)
            tensor([[0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j],
                    [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
                    [0.0000+0.j, 0.0000+0.j, 0.0000+0.j, 0.0000+0.j],
                    [0.5000+0.j, 0.0000+0.j, 0.0000+0.j, 0.5000+0.j]],
                   dtype=torch.complex128)
        """

        state_vector, amplitudes = StateVector._from_state_amplitudes(
            eigenstates=eigenstates, amplitudes=amplitudes
        )

        # Use `cls` (not a hard-coded DensityMatrix) so subclasses get their
        # own type back, as the `Type[DensityMatrixType]` annotation intends.
        return cls.from_state_vector(state_vector), amplitudes

    def sample(
        self,
        num_shots: int = 1000,
        one_state: Eigenstate | None = None,
        p_false_pos: float = 0.0,
        p_false_neg: float = 0.0,
    ) -> Counter[str]:
        """
        Samples bitstrings from the diagonal of the density matrix.

        Args:
            num_shots: how many bitstrings to sample
            one_state: accepted for interface compatibility; currently unused
            p_false_pos: the rate at which a 0 is read as a 1 (must be 0.0)
            p_false_neg: the rate at which a 1 is read as a 0 (must be 0.0)

        Returns:
            the measured bitstrings, by count

        Example:
            >>> import math
            >>> _ = torch.manual_seed(1234)
            >>> from emu_sv import StateVector
            >>> bell_vec = 1 / math.sqrt(2) * torch.tensor(
            ... [1.0, 0.0, 0.0, 1.0j],dtype=torch.complex128)
            >>> bell_state_vec = StateVector(bell_vec, gpu=False)
            >>> bell_density = DensityMatrix.from_state_vector(bell_state_vec)
            >>> counts = bell_density.sample(1000)
            >>> sorted(counts)
            ['00', '11']
            >>> sum(counts.values())
            1000
        """

        assert p_false_neg == p_false_pos == 0.0, "Error rates must be 0.0"

        # The diagonal holds the outcome weights; abs() maps the complex
        # entries to real non-negative values as required by multinomial.
        probabilities = torch.abs(self.matrix.diagonal())

        outcomes = torch.multinomial(probabilities, num_shots, replacement=True)

        # Hoisted out of the loop: n_qudits recomputes log2 on every access.
        n_qudits = self.n_qudits

        # Convert outcomes to bitstrings and count occurrences
        return Counter(index_to_bitstring(n_qudits, outcome) for outcome in outcomes)
+
191
+
192
if __name__ == "__main__":
    # Executing this module directly runs the examples in its docstrings.
    from doctest import testmod

    testmod()
emu_sv/hamiltonian.py CHANGED
@@ -91,8 +91,8 @@ class RydbergHamiltonian:
91
91
  dim_to_act = 1
92
92
  for n, omega_n in enumerate(self.omegas):
93
93
  shape_n = (2**n, 2, 2 ** (self.nqubits - n - 1))
94
- vec = vec.reshape(shape_n)
95
- result = result.reshape(shape_n)
94
+ vec = vec.view(shape_n)
95
+ result = result.view(shape_n)
96
96
  result.index_add_(dim_to_act, self.inds, vec, alpha=omega_n)
97
97
 
98
98
  def _apply_sigma_operators_complex(
@@ -112,8 +112,8 @@ class RydbergHamiltonian:
112
112
  dim_to_act = 1
113
113
  for n, c_omega_n in enumerate(c_omegas):
114
114
  shape_n = (2**n, 2, 2 ** (self.nqubits - n - 1))
115
- vec = vec.reshape(shape_n)
116
- result = result.reshape(shape_n)
115
+ vec = vec.view(shape_n)
116
+ result = result.view(shape_n)
117
117
  result.index_add_(
118
118
  dim_to_act, self.inds[0], vec[:, 0, :].unsqueeze(1), alpha=c_omega_n
119
119
  )
@@ -133,19 +133,21 @@ class RydbergHamiltonian:
133
133
  diag = torch.zeros(2**self.nqubits, dtype=torch.complex128, device=self.device)
134
134
 
135
135
  for i in range(self.nqubits):
136
- diag = diag.reshape(2**i, 2, -1)
136
+ diag = diag.view(2**i, 2, -1)
137
137
  i_fixed = diag[:, 1, :]
138
138
  i_fixed -= self.deltas[i]
139
139
  for j in range(i + 1, self.nqubits):
140
- i_fixed = i_fixed.reshape(2**i, 2 ** (j - i - 1), 2, -1)
140
+ i_fixed = i_fixed.view(2**i, 2 ** (j - i - 1), 2, -1)
141
141
  # replacing i_j_fixed by i_fixed breaks the code :)
142
142
  i_j_fixed = i_fixed[:, :, 1, :]
143
143
  i_j_fixed += self.interaction_matrix[i, j]
144
- return diag.reshape(-1)
144
+ return diag.view(-1)
145
145
 
146
146
  def expect(self, state: StateVector) -> torch.Tensor:
147
147
  """Return the energy expectation value E=❬ψ|H|ψ❭"""
148
148
  assert isinstance(
149
149
  state, StateVector
150
150
  ), "Currently, only expectation values of StateVectors are supported"
151
- return torch.vdot(state.vector, self * state.vector)
151
+ en = torch.vdot(state.vector, self * state.vector)
152
+ assert torch.allclose(en.imag, torch.zeros_like(en.imag), atol=1e-8)
153
+ return en.real
emu_sv/state_vector.py CHANGED
@@ -1,18 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections import Counter
4
- from typing import Any, Iterable
5
3
  import math
4
+ from collections import Counter
5
+ from typing import Sequence, Type, TypeVar, Mapping
6
6
 
7
+ import torch
7
8
 
8
- from emu_base import State, DEVICE_COUNT
9
+ from emu_sv.utils import index_to_bitstring
9
10
 
10
- import torch
11
+ from emu_base import DEVICE_COUNT
12
+ from pulser.backend import State
13
+ from pulser.backend.state import Eigenstate
11
14
 
15
+ StateVectorType = TypeVar("StateVectorType", bound="StateVector")
16
+ # Default tensor data type
12
17
  dtype = torch.complex128
13
18
 
14
19
 
15
- class StateVector(State):
20
+ class StateVector(State[complex, torch.Tensor]):
16
21
  """
17
22
  Represents a quantum state vector in a computational basis.
18
23
 
@@ -41,13 +46,26 @@ class StateVector(State):
41
46
  device = "cuda" if gpu and DEVICE_COUNT > 0 else "cpu"
42
47
  self.vector = vector.to(dtype=dtype, device=device)
43
48
 
49
+ @property
50
+ def n_qudits(self) -> int:
51
+ """The number of qudits in the state."""
52
+ nqudits = math.log2(self.vector.reshape(-1).shape[0])
53
+ return int(nqudits)
54
+
44
55
  def _normalize(self) -> None:
45
- # NOTE: use this in the callbacks
46
- """Checks if the input is normalized or not"""
47
- norm_state = torch.linalg.vector_norm(self.vector)
56
+ """Normalizes the state vector to ensure it has unit norm.
48
57
 
49
- if not torch.allclose(norm_state, torch.tensor(1.0, dtype=torch.float64)):
50
- self.vector = self.vector / norm_state
58
+ If the vector norm is not 1, the method scales the vector
59
+ to enforce normalization.
60
+
61
+ Note:
62
+ This method is intended to be used in callbacks.
63
+ """
64
+
65
+ norm = torch.linalg.vector_norm(self.vector)
66
+
67
+ if not torch.allclose(norm, torch.ones_like(norm)):
68
+ self.vector = self.vector / norm
51
69
 
52
70
  @classmethod
53
71
  def zero(cls, num_sites: int, gpu: bool = True) -> StateVector:
@@ -62,7 +80,7 @@ class StateVector(State):
62
80
  The zero state
63
81
 
64
82
  Examples:
65
- >>> StateVector.zero(2)
83
+ >>> StateVector.zero(2,gpu=False)
66
84
  tensor([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=torch.complex128)
67
85
  """
68
86
 
@@ -73,8 +91,7 @@ class StateVector(State):
73
91
  @classmethod
74
92
  def make(cls, num_sites: int, gpu: bool = True) -> StateVector:
75
93
  """
76
- Returns a State vector in ground state |000..0>.
77
- The vector in the output of StateVector has the shape (2,)*number of qubits
94
+ Returns a State vector in the ground state |00..0>.
78
95
 
79
96
  Args:
80
97
  num_sites: the number of qubits
@@ -84,17 +101,19 @@ class StateVector(State):
84
101
  The described state
85
102
 
86
103
  Examples:
87
- >>> StateVector.make(2)
104
+ >>> StateVector.make(2,gpu=False)
88
105
  tensor([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=torch.complex128)
106
+
107
+
89
108
  """
90
109
 
91
110
  result = cls.zero(num_sites=num_sites, gpu=gpu)
92
111
  result.vector[0] = 1.0
93
112
  return result
94
113
 
95
- def inner(self, other: State) -> float | complex:
114
+ def inner(self, other: State) -> torch.Tensor:
96
115
  """
97
- Compute <self, other>. The type of other must be StateVector.
116
+ Compute <self|other>. The type of other must be StateVector.
98
117
 
99
118
  Args:
100
119
  other: the other state
@@ -102,17 +121,21 @@ class StateVector(State):
102
121
  Returns:
103
122
  the inner product
104
123
  """
105
- assert isinstance(
106
- other, StateVector
107
- ), "Other state also needs to be a StateVector"
124
+ assert isinstance(other, StateVector), "Other state must be a StateVector"
108
125
  assert (
109
126
  self.vector.shape == other.vector.shape
110
- ), "States do not have the same number of sites"
127
+ ), "States do not have the same shape"
111
128
 
112
- return torch.vdot(self.vector, other.vector).item()
129
+ # by our internal convention inner and norm return to cpu
130
+ return torch.vdot(self.vector, other.vector).cpu()
113
131
 
114
132
  def sample(
115
- self, num_shots: int = 1000, p_false_pos: float = 0.0, p_false_neg: float = 0.0
133
+ self,
134
+ *,
135
+ num_shots: int = 1000,
136
+ one_state: Eigenstate | None = None,
137
+ p_false_pos: float = 0.0,
138
+ p_false_neg: float = 0.0,
116
139
  ) -> Counter[str]:
117
140
  """
118
141
  Samples bitstrings, taking into account the specified error rates.
@@ -125,24 +148,20 @@ class StateVector(State):
125
148
  Returns:
126
149
  the measured bitstrings, by count
127
150
  """
151
+ assert p_false_neg == p_false_pos == 0.0, "Error rates must be 0.0"
128
152
 
129
153
  probabilities = torch.abs(self.vector) ** 2
130
154
 
131
155
  outcomes = torch.multinomial(probabilities, num_shots, replacement=True)
132
156
 
133
157
  # Convert outcomes to bitstrings and count occurrences
134
- counts = Counter([self._index_to_bitstring(outcome) for outcome in outcomes])
158
+ counts = Counter(
159
+ [index_to_bitstring(self.n_qudits, outcome) for outcome in outcomes]
160
+ )
135
161
 
136
162
  # NOTE: false positives and negatives
137
163
  return counts
138
164
 
139
- def _index_to_bitstring(self, index: int) -> str:
140
- """
141
- Convert an integer index into its corresponding bitstring representation.
142
- """
143
- nqubits = int(math.log2(self.vector.reshape(-1).shape[0]))
144
- return format(index, f"0{nqubits}b")
145
-
146
165
  def __add__(self, other: State) -> StateVector:
147
166
  """Sum of two state vectors
148
167
 
@@ -152,11 +171,8 @@ class StateVector(State):
152
171
  Returns:
153
172
  The summed state
154
173
  """
155
- assert isinstance(
156
- other, StateVector
157
- ), "Other state also needs to be a StateVector"
158
- result = self.vector + other.vector
159
- return StateVector(result)
174
+ assert isinstance(other, StateVector), "`Other` state can only be a StateVector"
175
+ return StateVector(self.vector + other.vector, gpu=self.vector.is_cuda)
160
176
 
161
177
  def __rmul__(self, scalar: complex) -> StateVector:
162
178
  """Scalar multiplication
@@ -167,62 +183,68 @@ class StateVector(State):
167
183
  Returns:
168
184
  The scaled state
169
185
  """
170
- result = scalar * self.vector
186
+ return StateVector(scalar * self.vector, gpu=self.vector.is_cuda)
171
187
 
172
- return StateVector(result)
173
-
174
- def norm(self) -> float | complex:
188
+ def norm(self) -> torch.Tensor:
175
189
  """Returns the norm of the state
176
190
 
177
191
  Returns:
178
192
  the norm of the state
179
193
  """
180
- norm: float | complex = torch.linalg.vector_norm(self.vector).item()
181
- return norm
194
+ nrm: torch.Tensor = torch.linalg.vector_norm(self.vector).cpu()
195
+ return nrm
182
196
 
183
197
  def __repr__(self) -> str:
184
198
  return repr(self.vector)
185
199
 
186
- @staticmethod
187
- def from_state_string(
200
+ @classmethod
201
+ def _from_state_amplitudes(
202
+ cls: Type[StateVectorType],
188
203
  *,
189
- basis: Iterable[str],
190
- nqubits: int,
191
- strings: dict[str, complex],
192
- **kwargs: Any,
193
- ) -> StateVector:
204
+ eigenstates: Sequence[Eigenstate],
205
+ amplitudes: Mapping[str, complex],
206
+ ) -> tuple[StateVector, Mapping[str, complex]]:
194
207
  """Transforms a state given by a string into a state vector.
195
208
 
196
209
  Construct a state from the pulser abstract representation
197
210
  https://pulser.readthedocs.io/en/stable/conventions.html
198
211
 
199
212
  Args:
200
- basis: A tuple containing the basis states (e.g., ('r', 'g')).
201
- nqubits: the number of qubits.
202
- strings: A dictionary mapping state strings to complex or floats amplitudes.
213
+ eigenstates: A tuple containing the basis states (e.g., ('r', 'g')).
214
+ amplitudes: A dictionary mapping state strings to complex or floats amplitudes.
203
215
 
204
216
  Returns:
205
- The resulting state.
217
+ The normalised resulting state.
206
218
 
207
219
  Examples:
208
220
  >>> basis = ("r","g")
209
221
  >>> n = 2
210
- >>> st=StateVector.from_state_string(basis=basis,nqubits=n,strings={"rr":1.0,"gg":1.0})
222
+ >>> st=StateVector.from_state_string(basis=basis,
223
+ ... nqubits=n,strings={"rr":1.0,"gg":1.0},gpu=False)
224
+ >>> st = StateVector.from_state_amplitudes(
225
+ ... eigenstates=basis,
226
+ ... amplitudes={"rr": 1.0, "gg": 1.0}
227
+ ... )
211
228
  >>> print(st)
212
- tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], dtype=torch.complex128)
229
+ tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j],
230
+ dtype=torch.complex128)
213
231
  """
214
232
 
215
- basis = set(basis)
233
+ # nqubits = len(next(iter(amplitudes.keys())))
234
+ nqubits = cls._validate_amplitudes(amplitudes=amplitudes, eigenstates=eigenstates)
235
+ basis = set(eigenstates)
216
236
  if basis == {"r", "g"}:
217
237
  one = "r"
218
238
  elif basis == {"0", "1"}:
219
- one = "1"
239
+ raise NotImplementedError(
240
+ "{'0','1'} basis is related to XY Hamiltonian, which is not implemented"
241
+ )
220
242
  else:
221
243
  raise ValueError("Unsupported basis provided")
222
244
 
223
- accum_state = StateVector.zero(num_sites=nqubits, **kwargs)
245
+ accum_state = StateVector.zero(num_sites=nqubits)
224
246
 
225
- for state, amplitude in strings.items():
247
+ for state, amplitude in amplitudes.items():
226
248
  bin_to_int = int(
227
249
  state.replace(one, "1").replace("g", "0"), 2
228
250
  ) # "0" basis is already in "0"
@@ -230,7 +252,10 @@ class StateVector(State):
230
252
 
231
253
  accum_state._normalize()
232
254
 
233
- return accum_state
255
+ return accum_state, amplitudes
256
+
257
+ def overlap(self, other: StateVector, /) -> torch.Tensor:
258
+ return self.inner(other)
234
259
 
235
260
 
236
261
  def inner(left: StateVector, right: StateVector) -> torch.Tensor:
@@ -247,21 +272,27 @@ def inner(left: StateVector, right: StateVector) -> torch.Tensor:
247
272
  Examples:
248
273
  >>> factor = math.sqrt(2.0)
249
274
  >>> basis = ("r","g")
250
- >>> nqubits = 2
251
275
  >>> string_state1 = {"gg":1.0,"rr":1.0}
252
276
  >>> state1 = StateVector.from_state_string(basis=basis,
253
- >>> nqubits=nqubits,strings=string_state1)
277
+ ... nqubits=nqubits,strings=string_state1)
254
278
  >>> string_state2 = {"gr":1.0/factor,"rr":1.0/factor}
255
279
  >>> state2 = StateVector.from_state_string(basis=basis,
256
- >>> nqubits=nqubits,strings=string_state2)
280
+ ... nqubits=nqubits,strings=string_state2)
281
+
282
+ >>> state1 = StateVector.from_state_amplitudes(eigenstates=basis,
283
+ ... amplitudes=string_state1)
284
+ >>> string_state2 = {"gr":1.0/factor,"rr":1.0/factor}
285
+ >>> state2 = StateVector.from_state_amplitudes(eigenstates=basis,
286
+ ... amplitudes=string_state2)
257
287
  >>> inner(state1,state2).item()
258
- (0.4999999999999999+0j)
288
+ (0.49999999144286444+0j)
259
289
  """
260
290
 
261
- assert (left.vector.shape == right.vector.shape) and (
262
- left.vector.dim() == 1
263
- ), "Shape of a and b should be the same and both needs to be 1D tesnor"
264
- return torch.inner(left.vector, right.vector)
291
+ assert (left.vector.shape == right.vector.shape) and (left.vector.dim() == 1), (
292
+ "Shape of left.vector and right.vector should be",
293
+ " the same and both need to be 1D tesnor",
294
+ )
295
+ return left.inner(right)
265
296
 
266
297
 
267
298
  if __name__ == "__main__":