emu-mps 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/mpo.py ADDED
@@ -0,0 +1,243 @@
1
+ from __future__ import annotations
2
+ import itertools
3
+ from typing import Any, List, cast, Iterable, Optional
4
+
5
+ import torch
6
+
7
+ from emu_mps.algebra import add_factors, scale_factors, zip_right
8
+ from emu_base.base_classes.operator import FullOp, QuditOp
9
+ from emu_base import Operator, State
10
+ from emu_mps.mps import MPS
11
+ from emu_mps.utils import new_left_bath, assign_devices, DEVICE_COUNT
12
+
13
+
14
+ def _validate_operator_targets(operations: FullOp, nqubits: int) -> None:
15
+ for tensorop in operations:
16
+ target_qids = (factor[1] for factor in tensorop[1])
17
+ target_qids_list = list(itertools.chain(*target_qids))
18
+ target_qids_set = set(target_qids_list)
19
+ if len(target_qids_set) < len(target_qids_list):
20
+ # Either the qubit id has been defined twice in an operation:
21
+ for qids in target_qids:
22
+ if len(set(qids)) < len(qids):
23
+ raise ValueError("Duplicate atom ids in argument list.")
24
+ # Or it was defined in two different operations
25
+ raise ValueError("Each qubit can be targeted by only one operation.")
26
+ if max(target_qids_set) >= nqubits:
27
+ raise ValueError(
28
+ "The operation targets more qubits than there are in the register."
29
+ )
30
+
31
+
32
class MPO(Operator):
    """
    Matrix Product Operator.

    Each tensor has 4 dimensions ordered as such: (left bond, output, input, right bond).

    Args:
        factors: the tensors making up the MPO
        num_gpus_to_use: when given, the factors are distributed with
            `assign_devices` over `min(DEVICE_COUNT, num_gpus_to_use)` devices
    """

    def __init__(
        self, factors: List[torch.Tensor], /, num_gpus_to_use: Optional[int] = None
    ):
        self.factors = factors
        self.num_sites = len(factors)
        if not self.num_sites > 1:
            raise ValueError("For 1 qubit states, do state vector")
        if factors[0].shape[0] != 1 or factors[-1].shape[-1] != 1:
            raise ValueError(
                "The dimension of the left (right) link of the first (last) tensor should be 1"
            )
        # The right bond of each tensor must match the left bond of the next one.
        assert all(
            factors[i - 1].shape[-1] == factors[i].shape[0]
            for i in range(1, self.num_sites)
        )

        if num_gpus_to_use is not None:
            assign_devices(self.factors, min(DEVICE_COUNT, num_gpus_to_use))

    def __repr__(self) -> str:
        return "[" + ", ".join(map(repr, self.factors)) + "]"

    def __mul__(self, other: State) -> MPS:
        """
        Applies this MPO to the given MPS.
        The returned MPS is:

        - orthogonal on the first site
        - truncated up to `other.precision` and `other.max_bond_dim`
        - distributed on the same devices as `other`

        Args:
            other: the state to apply this operator to

        Returns:
            the resulting state
        """
        assert isinstance(other, MPS), "MPO can only be multiplied with MPS"
        factors = zip_right(
            self.factors,
            other.factors,
            max_error=other.precision,
            max_rank=other.max_bond_dim,
        )
        return MPS(factors, orthogonality_center=0)

    def __add__(self, other: Operator) -> MPO:
        """
        Returns the sum of two MPOs, computed with a direct algorithm.
        The result is currently not truncated.

        Args:
            other: the other operator

        Returns:
            the summed operator
        """
        assert isinstance(other, MPO), "MPO can only be added to another MPO"
        sum_factors = add_factors(self.factors, other.factors)
        return MPO(sum_factors)

    def __rmul__(self, scalar: complex) -> MPO:
        """
        Multiply an MPO by scalar.
        The scaling is delegated to `scale_factors(..., which=0)`, i.e.
        presumably the scalar is absorbed into the first factor only.

        Args:
            scalar: the scale factor to multiply with

        Returns:
            the scaled MPO
        """
        factors = scale_factors(self.factors, scalar, which=0)
        return MPO(factors)

    def __matmul__(self, other: Operator) -> MPO:
        """
        Compose two operators. The ordering is that
        self is applied after other.

        Args:
            other: the operator to compose with self

        Returns:
            the composed operator
        """
        assert isinstance(other, MPO), "MPO can only be applied to another MPO"
        factors = zip_right(self.factors, other.factors)
        return MPO(factors)

    def expect(self, state: State) -> float | complex:
        """
        Compute the expectation value of self on the given state.

        The contraction sweeps left to right, accumulating a left bath and
        moving it to the device of the next state factor at each step.

        Args:
            state: the state with which to compute

        Returns:
            the expectation
        """
        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the next line's indentation into the message.
        assert isinstance(state, MPS), (
            "currently, only expectation values of MPSs are supported"
        )
        acc = torch.ones(
            1, 1, 1, dtype=state.factors[0].dtype, device=state.factors[0].device
        )
        n = len(self.factors) - 1
        for i in range(n):
            acc = new_left_bath(acc, state.factors[i], self.factors[i]).to(
                state.factors[i + 1].device
            )
        acc = new_left_bath(acc, state.factors[n], self.factors[n])
        return acc.item()  # type: ignore [no-any-return]

    @staticmethod
    def from_operator_string(
        basis: Iterable[str],
        nqubits: int,
        operations: FullOp,
        operators: Optional[dict[str, QuditOp]] = None,
        /,
        **kwargs: Any,
    ) -> MPO:
        """
        See the base class

        Builds the MPO of a weighted sum of tensor products of single-qubit
        operators. Neither `operations` nor `operators` is modified.

        Args:
            basis: the eigenstates in the basis to use e.g. ('r', 'g')
            nqubits: how many qubits there are in the register
            operations: a sequence of (coefficient, tensor_op) pairs; each
                tensor_op is a sequence of (operator, target qubit ids) pairs,
                where an operator is a dict mapping operator strings to weights
            operators: additional symbols to be used in operations, defined in
                terms of the basis strings (or of other additional symbols)
            **kwargs: forwarded to the `MPO` constructor (e.g. `num_gpus_to_use`)

        Returns:
            the operator in MPO form.

        Raises:
            ValueError: if the basis is unsupported, the targets are invalid,
                or `operations` is empty.
        """
        _validate_operator_targets(operations, nqubits)

        basis_set = set(basis)
        if basis_set == {"r", "g"}:
            labels = ("gg", "gr", "rg", "rr")
        elif basis_set == {"0", "1"}:
            labels = ("00", "01", "10", "11")
        else:
            raise ValueError("Unsupported basis provided")

        # Single-qubit basis operators, in the same order as `labels`.
        matrices = (
            [[1.0, 0.0], [0.0, 0.0]],
            [[0.0, 0.0], [1.0, 0.0]],
            [[0.0, 1.0], [0.0, 0.0]],
            [[0.0, 0.0], [0.0, 1.0]],
        )

        # Private symbol table: copying means neither the caller's dict nor a
        # shared default is mutated, so symbols cannot leak between calls.
        # Basis strings override user symbols of the same name.
        symbols: dict[str, Any] = dict(operators or {})
        symbols.update(
            (label, torch.tensor(matrix, dtype=torch.complex128).reshape(1, 2, 2, 1))
            for label, matrix in zip(labels, matrices)
        )

        def resolve(op: QuditOp | torch.Tensor) -> torch.Tensor:
            """Recursively expand an operator dict into a single tensor."""
            if not isinstance(op, dict):
                return op
            terms = []
            for opstr, weight in op.items():
                tensor = resolve(symbols[opstr])
                symbols[opstr] = tensor  # cache the expansion in our copy only
                terms.append(tensor * weight)
            return cast(torch.Tensor, sum(terms))

        mpos: list[MPO] = []
        for coeff, tensorop in operations:
            # Identity on every site that is not targeted by this term.
            factors = [
                torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
            ] * nqubits
            for factor in tensorop:
                tensor = resolve(factor[0])
                for qid in factor[1]:
                    factors[qid] = tensor
            mpos.append(coeff * MPO(factors, **kwargs))

        if not mpos:
            raise ValueError("At least one operation is required to build an MPO.")
        return sum(mpos[1:], start=mpos[0])