emu-base 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_base/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
16
16
  "DEVICE_COUNT",
17
17
  ]
18
18
 
19
- __version__ = "2.0.0"
19
+ __version__ = "2.0.2"
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from typing import Callable
3
+
4
+ from emu_base.math.krylov_exp import DEFAULT_MAX_KRYLOV_DIM
5
+
6
# Iteration cap used by lanczos() below; it is looked up as a module global at
# call time and is initialised from krylov_exp's DEFAULT_MAX_KRYLOV_DIM.
max_krylov_dim = DEFAULT_MAX_KRYLOV_DIM
7
+
8
+
9
def double_krylov(
    op: Callable,
    state: torch.Tensor,
    grad: torch.Tensor,
    tolerance: float,
) -> tuple[list[torch.Tensor], torch.Tensor, list[torch.Tensor]]:
    """
    Lanczos-compressed Fréchet derivative of the matrix exponential.

    Computes a factorization (Vs, dS, Vg) of the Fréchet derivative dU of
    U = exp(op) in the direction |state❭❬grad|:

        dU(op, |state❭❬grad|) ≈ Vsᵗ @ dS @ Vg*

    where Vsᵗ stacks the vectors of Vs as columns and Vg* is the conjugate
    transpose of the column-stacked Vg.

    Args:
        op (Callable): linear map to exponentiate, applied as op(ψ),
            e.g. op(|ψ❭) = H|ψ❭.
        state (torch.Tensor): vector forming the ket side |state❭ of the
            direction.
        grad (torch.Tensor): vector forming the bra side ❬grad| of the
            direction.
        tolerance (float): tolerance forwarded to both Lanczos runs.

    Returns:
        Vs (list): Lanczos basis generated from state.
        dS (torch.Tensor): (len(Vs), len(Vg)) matrix representing the
            derivative in those bases.
        Vg (list): Lanczos basis generated from grad.

    Raises:
        RecursionError: propagated from lanczos when either run does not
            converge.

    Notes:
        dU(op, |a❭❬b|) is the top-right block of

            exp([[op, |a❭❬b|],
                 [0,  op    ]]).

        Lanczos runs started from |a❭ and |b❭ give bases Va, Vb and small
        matrices Ta, Tb with op ≈ Vaᵗ @ Ta @ Va* and op ≈ Vbᵗ @ Tb @ Vb*.
        Because the first basis vectors are |a❭/‖a‖ and |b❭/‖b‖, the
        coupling block becomes ‖a‖‖b‖ |0❭❬0|, a single nonzero entry.
        Exponentiating

            [[Ta, ‖a‖‖b‖ |0❭❬0|],
             [0,  Tb             ]]

        and keeping its top-right block yields dS, so that
        dU(op, |a❭❬b|) ≈ Vaᵗ @ dS @ Vb*.
    """
    state_basis, state_T = lanczos(op, state, tolerance)
    grad_basis, grad_T = lanczos(op, grad, tolerance)
    n_state = len(state_basis)

    generator = torch.block_diag(state_T, grad_T)
    # |state❭❬grad| expressed in the two Lanczos bases: only the entry coupling
    # the first state vector to the first grad vector is nonzero.
    generator[0, n_state] = state.norm() * grad.norm()

    exp_generator = torch.linalg.matrix_exp(generator)
    return state_basis, exp_generator[:n_state, n_state:], grad_basis
61
+
62
+
63
+ def lanczos(
64
+ op: Callable,
65
+ v: torch.Tensor,
66
+ tolerance: float,
67
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
68
+ """
69
+ Copy of the code in krylov_exp to do Laczos iteration
70
+ To decide
71
+ 1. refactor this
72
+ 2. allow krylov results to store lanczos_vectors and T
73
+ """
74
+ converged = False
75
+ lanczos_vectors = [v / v.norm()]
76
+ T = torch.zeros(max_krylov_dim + 2, max_krylov_dim + 2, dtype=v.dtype)
77
+
78
+ for j in range(max_krylov_dim):
79
+ w = op(lanczos_vectors[-1])
80
+ n = w.norm()
81
+ for k in range(max(0, j - 1), j + 1):
82
+ overlap = torch.tensordot(lanczos_vectors[k].conj(), w, dims=w.dim())
83
+ T[k, j] = overlap
84
+ w -= overlap * lanczos_vectors[k]
85
+
86
+ n2 = w.norm()
87
+ T[j + 1, j] = n2
88
+
89
+ if n2 < tolerance:
90
+ converged = True
91
+ break
92
+
93
+ lanczos_vectors.append(w / n2)
94
+ # Compute exponential of extended T matrix
95
+ T[j + 2, j + 1] = 1
96
+ expd = torch.linalg.matrix_exp(T[: j + 3, : j + 3])
97
+
98
+ # Local truncation error estimation
99
+ err1 = abs(expd[j + 1, 0])
100
+ err2 = abs(expd[j + 2, 0] * n)
101
+
102
+ err = err1 if err1 < err2 else (err1 * err2 / (err1 - err2))
103
+ if err < tolerance:
104
+ converged = True
105
+ break
106
+
107
+ if not converged:
108
+ raise RecursionError(
109
+ "Lanczos iteration did not converge to precision in allotted number of steps."
110
+ )
111
+ size = len(lanczos_vectors)
112
+ return lanczos_vectors, T[:size, :size]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emu-base
3
- Version: 2.0.0
3
+ Version: 2.0.2
4
4
  Summary: Pasqal base classes for emulators
5
5
  Project-URL: Documentation, https://pasqal-io.github.io/emulators/
6
6
  Project-URL: Repository, https://github.com/pasqal-io/emulators
@@ -26,7 +26,7 @@ Classifier: Programming Language :: Python :: Implementation :: CPython
26
26
  Classifier: Programming Language :: Python :: Implementation :: PyPy
27
27
  Requires-Python: >=3.10
28
28
  Requires-Dist: pulser-core==1.4.*
29
- Requires-Dist: torch==2.5.0
29
+ Requires-Dist: torch==2.7.0
30
30
  Description-Content-Type: text/markdown
31
31
 
32
32
  <div align="center">
@@ -1,4 +1,4 @@
1
- emu_base/__init__.py,sha256=b60wKpJR1-oUIEv68t0-WNza2IXSL7joPQVt5Hw-rj8,493
1
+ emu_base/__init__.py,sha256=dRd4TzTdTmz9IgVlisaBWBMjstVpVASVnmvjjSMbncU,493
2
2
  emu_base/aggregators.py,sha256=bB-rldoDAErxQMpL715K5lpiabGOpkCY0GyxW7mfHuc,5000
3
3
  emu_base/constants.py,sha256=41LYkKLUCz-oxPbd-j7nUDZuhIbUrnez6prT0uR0jcE,56
4
4
  emu_base/lindblad_operators.py,sha256=Nsl1YrWb8IDM9Z50ucy2Ed44p_IRETnlbr6qaqAgV50,1629
@@ -6,7 +6,8 @@ emu_base/pulser_adapter.py,sha256=dRD80z_dVXkCjDBLRIkmqNGg5M78VEKkQuk3H5JdZSM,11
6
6
  emu_base/utils.py,sha256=RM8O0qfPAJfcdqqAojwEEKV7I3ZfVDklnTisTGhUg5k,233
7
7
  emu_base/math/__init__.py,sha256=6BbIytYV5uC-e5jLMtIErkcUl_PvfSNnhmVFY9Il8uQ,97
8
8
  emu_base/math/brents_root_finding.py,sha256=AVx6L1Il6rpPJWrLJ7cn6oNmJyZOPRgEaaZaubC9lsU,3711
9
+ emu_base/math/double_krylov.py,sha256=-DUZ5R3g7CUMQWSET2MUxXZKObXgLNanwAtS5nX8T68,3677
9
10
  emu_base/math/krylov_exp.py,sha256=UCFNeq-j2ukgBsOPC9_Jiv1aqpy88SrslDLiCxIGBwk,3840
10
- emu_base-2.0.0.dist-info/METADATA,sha256=uoylMuopYijyAJ9G8iY_cxXanQlJGu1ibvkd17Soi2g,3522
11
- emu_base-2.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
12
- emu_base-2.0.0.dist-info/RECORD,,
11
+ emu_base-2.0.2.dist-info/METADATA,sha256=jqdSM-agy3W_gomgmN3dXuKw2NajCGgRtQPIDz1M6Us,3522
12
+ emu_base-2.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ emu_base-2.0.2.dist-info/RECORD,,