PyQUDA-Utils 0.10.1.dev0__tar.gz → 0.10.1.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PKG-INFO +1 -1
  2. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PyQUDA_Utils.egg-info/PKG-INFO +1 -1
  3. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PyQUDA_Utils.egg-info/SOURCES.txt +1 -0
  4. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/_field_utils.py +6 -6
  5. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/chroma.py +2 -2
  6. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/milc.py +2 -2
  7. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/nersc.py +3 -3
  8. pyquda_utils-0.10.1.dev2/pyquda_utils/_version.py +1 -0
  9. pyquda_utils-0.10.1.dev2/pyquda_utils/gauge_nd_sun.py +331 -0
  10. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/gpt.py +1 -1
  11. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/io/__init__.py +4 -4
  12. pyquda_utils-0.10.1.dev0/pyquda_utils/_version.py +0 -1
  13. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/LICENSE +0 -0
  14. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/MANIFEST.in +0 -0
  15. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PyQUDA_Utils.egg-info/dependency_links.txt +0 -0
  16. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PyQUDA_Utils.egg-info/requires.txt +0 -0
  17. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/PyQUDA_Utils.egg-info/top_level.txt +0 -0
  18. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/README.md +0 -0
  19. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyproject.toml +0 -0
  20. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/__init__.py +0 -0
  21. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/_mpi_file.py +0 -0
  22. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/io_general.py +0 -0
  23. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/kyu.py +0 -0
  24. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/lime.py +0 -0
  25. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/npy.py +0 -0
  26. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/openqcd.py +0 -0
  27. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_io/xqcd.py +0 -0
  28. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/__init__.py +0 -0
  29. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/convert.py +0 -0
  30. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/core.py +0 -0
  31. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/deprecated.py +0 -0
  32. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/gamma.py +0 -0
  33. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/hmc_param.py +0 -0
  34. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/milc_rhmc_param.py +0 -0
  35. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/phase.py +0 -0
  36. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/quasi_axial_gauge_fixing.py +0 -0
  37. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/pyquda_utils/source.py +0 -0
  38. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/setup.cfg +0 -0
  39. {pyquda_utils-0.10.1.dev0 → pyquda_utils-0.10.1.dev2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: PyQUDA-Utils
3
- Version: 0.10.1.dev0
3
+ Version: 0.10.1.dev2
4
4
  Summary: Utility scripts based on PyQUDA
5
5
  Author-email: SaltyChiang <SaltyChiang@users.noreply.github.com>
6
6
  Maintainer-email: SaltyChiang <SaltyChiang@users.noreply.github.com>
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: PyQUDA-Utils
3
- Version: 0.10.1.dev0
3
+ Version: 0.10.1.dev2
4
4
  Summary: Utility scripts based on PyQUDA
5
5
  Author-email: SaltyChiang <SaltyChiang@users.noreply.github.com>
6
6
  Maintainer-email: SaltyChiang <SaltyChiang@users.noreply.github.com>
@@ -26,6 +26,7 @@ pyquda_utils/convert.py
26
26
  pyquda_utils/core.py
27
27
  pyquda_utils/deprecated.py
28
28
  pyquda_utils/gamma.py
29
+ pyquda_utils/gauge_nd_sun.py
29
30
  pyquda_utils/gpt.py
30
31
  pyquda_utils/hmc_param.py
31
32
  pyquda_utils/milc_rhmc_param.py
@@ -173,7 +173,7 @@ def gaugeProject(gauge: numpy.ndarray):
173
173
  pass
174
174
 
175
175
 
176
- def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
176
+ def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: float):
177
177
  gauge = numpy.ascontiguousarray(gauge.transpose(5, 6, 0, 1, 2, 3, 4))
178
178
  row0_abs = numpy.linalg.norm(gauge[0], axis=0)
179
179
  gauge[0] /= row0_abs
@@ -182,7 +182,7 @@ def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
182
182
  row1_abs = numpy.linalg.norm(gauge[1], axis=0)
183
183
  gauge[1] /= row1_abs
184
184
  row2 = numpy.cross(gauge[0], gauge[1], axis=0).conjugate()
185
- if reunitarize_sigma:
185
+ if reunitarize_sigma > 0:
186
186
  assert (
187
187
  MPI.COMM_WORLD.allreduce(
188
188
  numpy.sqrt(
@@ -193,13 +193,13 @@ def gaugeReunitarize(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
193
193
  ).max(),
194
194
  MPI.MAX,
195
195
  )
196
- < 2e-7 # sqrt(Nc) * fp32 machine epsilon
196
+ < reunitarize_sigma
197
197
  )
198
198
  gauge[2] = row2
199
199
  return gauge.transpose(2, 3, 4, 5, 6, 0, 1)
200
200
 
201
201
 
202
- def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: bool = True):
202
+ def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: float):
203
203
  """gauge shape (Nd, Lt, Lz, Ly, Lx, Nc - 1, Nc)"""
204
204
  gauge_ = gauge.transpose(5, 6, 0, 1, 2, 3, 4)
205
205
  gauge = numpy.empty((Nc, *gauge_.shape[1:]), "<c16")
@@ -211,13 +211,13 @@ def gaugeReunitarizeReconstruct12(gauge: numpy.ndarray, reunitarize_sigma: bool
211
211
  row1_abs = numpy.linalg.norm(gauge[1], axis=0)
212
212
  gauge[1] /= row1_abs
213
213
  row2 = numpy.cross(gauge[0], gauge[1], axis=0).conjugate()
214
- if reunitarize_sigma:
214
+ if reunitarize_sigma > 0:
215
215
  assert (
216
216
  MPI.COMM_WORLD.allreduce(
217
217
  numpy.sqrt((1 - row0_abs) ** 2 + numpy.abs(row0_row1) ** 2 + (1 - row1_abs) ** 2).max(),
218
218
  MPI.MAX,
219
219
  )
220
- < 2e-7 # sqrt(Nc) * fp32 machine epsilon
220
+ < reunitarize_sigma
221
221
  )
222
222
  gauge[2] = row2
223
223
  return gauge.transpose(2, 3, 4, 5, 6, 0, 1)
@@ -34,7 +34,7 @@ def checksum_qio(latt_size: List[int], grid_size: List[int], data):
34
34
  return sum29, sum31
35
35
 
36
36
 
37
- def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True):
37
+ def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: float = 2e-7):
38
38
  from .lime import Lime
39
39
 
40
40
  lime = Lime(filename)
@@ -68,7 +68,7 @@ def readQIOGauge(filename: str, grid_size: List[int], checksum: bool = True):
68
68
  ), f"Bad checksum for {filename}"
69
69
  gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
70
70
  if precision == 4:
71
- gauge = gaugeReunitarize(gauge)
71
+ gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
72
72
  return latt_size, gauge
73
73
 
74
74
 
@@ -55,7 +55,7 @@ def checksum_qio(latt_size: List[int], grid_size: List[int], data):
55
55
  return sum29, sum31
56
56
 
57
57
 
58
- def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: bool = True):
58
+ def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunitarize_sigma: float = 2e-7):
59
59
  filename = path.expanduser(path.expandvars(filename))
60
60
  with open(filename, "rb") as f:
61
61
  magic = f.read(4)
@@ -79,7 +79,7 @@ def readGauge(filename: str, grid_size: List[int], checksum: bool = True, reunit
79
79
  sum31,
80
80
  ), f"Bad checksum for {filename}"
81
81
  gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
82
- gauge = gaugeReunitarize(gauge, reunitarize_sigma)
82
+ gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
83
83
  return latt_size, gauge
84
84
 
85
85
 
@@ -27,7 +27,7 @@ def readGauge(
27
27
  checksum: bool = True,
28
28
  plaquette: bool = True,
29
29
  link_trace: bool = True,
30
- reunitarize_sigma: bool = True,
30
+ reunitarize_sigma: float = 2e-7,
31
31
  ):
32
32
  filename = path.expanduser(path.expandvars(filename))
33
33
  header: Dict[str, str] = {}
@@ -63,7 +63,7 @@ def readGauge(
63
63
  assert checksum_nersc(gauge.reshape(-1)) == int(header["CHECKSUM"], 16), f"Bad checksum for {filename}"
64
64
  gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
65
65
  if float_nbytes == 4:
66
- gauge = gaugeReunitarize(gauge, reunitarize_sigma)
66
+ gauge = gaugeReunitarize(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
67
67
  elif header["DATATYPE"] == "4D_SU3_GAUGE":
68
68
  gauge = readMPIFile(filename, dtype, offset, (Lt, Lz, Ly, Lx, Nd, Nc - 1, Nc), (3, 2, 1, 0), grid_size)
69
69
  gauge = gauge.astype(f"<c{2 * float_nbytes}")
@@ -71,7 +71,7 @@ def readGauge(
71
71
  assert checksum_nersc(gauge.reshape(-1)) == int(header["CHECKSUM"], 16), f"Bad checksum for {filename}"
72
72
  gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
73
73
  if float_nbytes == 4:
74
- gauge = gaugeReunitarizeReconstruct12(gauge, reunitarize_sigma)
74
+ gauge = gaugeReunitarizeReconstruct12(gauge, reunitarize_sigma) # 2e-7: Nc**0.5 * 1.1920929e-07
75
75
  elif float_nbytes == 8:
76
76
  gauge = gaugeReconstruct12(gauge)
77
77
  else:
@@ -0,0 +1 @@
1
+ __version__ = "0.10.1.dev2"
@@ -0,0 +1,331 @@
1
+ from os import environ
2
+ from typing import Sequence
3
+
4
+ import numpy
5
+ from mpi4py import MPI
6
+
7
# Rank of this process in MPI.COMM_WORLD, queried once when the module is imported.
_RANK = MPI.COMM_WORLD.Get_rank()
# GPU index recorded by initGPU(); stays negative until initGPU() has assigned it.
_GPUID: int = -1
9
+
10
+
11
def initGPU(gpuid: int = -1):
    """Bind this MPI rank to a CUDA device and record the choice in ``_GPUID``.

    Args:
        gpuid: Device index to use. A negative value (the default) asks for
            automatic selection: the rank takes the device whose index equals
            the number of lower-numbered ranks reporting the same hostname.

    Raises:
        RuntimeError: If automatic selection finds no device, or finds fewer
            devices than ranks on this host while the environment variable
            ``QUDA_ENABLE_MPS`` is not set to ``"1"``.
    """
    global _GPUID

    from platform import node as gethostname
    import cupy

    rank = _RANK

    # Follows the device-selection scheme of quda/include/communicator_quda.h.
    # The allgather is collective, so it runs even when gpuid was given explicitly.
    hostname = gethostname()
    all_hostnames = MPI.COMM_WORLD.allgather(hostname)

    if gpuid < 0:
        device_count = cupy.cuda.runtime.getDeviceCount()
        if device_count == 0:
            raise RuntimeError("No devices found")

        # Local index: how many lower ranks share this rank's hostname.
        gpuid = sum(1 for peer in all_hostnames[:rank] if peer == hostname)

        if gpuid >= device_count:
            if environ.get("QUDA_ENABLE_MPS") != "1":
                raise RuntimeError(f"Too few GPUs available on {hostname}")
            # With MPS enabled several ranks may share one device.
            gpuid %= device_count
            print(f"MPS enabled, rank={rank} -> gpu={gpuid}")

    cupy.cuda.Device(gpuid).use()
    print(f"Rank {rank} uses GPU {gpuid}")

    _GPUID = gpuid
45
+
46
+
47
def getSublatticeSize(latt_size: Sequence[int], grid_size: Sequence[int]):
    """Return the per-process sublattice extents as a list.

    Each global extent in ``latt_size`` is divided by the matching process-grid
    extent in ``grid_size``.

    Args:
        latt_size: Global lattice extent in every dimension.
        grid_size: Number of processes along every dimension.

    Returns:
        A list with ``latt_size[i] // grid_size[i]`` for every dimension.

    Raises:
        AssertionError: If the two sequences differ in length or an extent is
            not an exact multiple of its grid extent.
    """
    assert len(latt_size) == len(grid_size), (
        f"latt_size {tuple(latt_size)} and grid_size {tuple(grid_size)} must have the same length"
    )
    sublatt_size = []
    for GL, G in zip(latt_size, grid_size):
        # The lattice must tile evenly over the process grid.
        assert GL % G == 0, f"lattice extent {GL} is not divisible by grid extent {G}"
        sublatt_size.append(GL // G)
    return sublatt_size
52
+
53
+
54
def getGridCoord(grid_size: Sequence[int]):
    """Decompose this process's MPI rank into coordinates on the process grid.

    The rank is read as a mixed-radix number whose digits are bounded by
    ``grid_size``; the last dimension is the fastest-varying one.

    Args:
        grid_size: Number of processes along every dimension.

    Returns:
        A list of grid coordinates in the same dimension order as ``grid_size``.
    """
    remaining = _RANK
    coord_reversed = []
    for extent in reversed(grid_size):
        # Peel off the fastest-varying digit first.
        remaining, coord = divmod(remaining, extent)
        coord_reversed.append(coord)
    coord_reversed.reverse()
    return coord_reversed
61
+
62
+
63
def getShiftedRank(grid_coord: Sequence[int], grid_size: Sequence[int], delta: Sequence[int]):
    """Return the rank of the process located ``delta`` grid steps away.

    Coordinates wrap around periodically in every dimension, and the result is
    the shifted coordinate re-encoded with the last dimension varying fastest.

    Args:
        grid_coord: Grid coordinates of the starting process.
        grid_size: Number of processes along every dimension.
        delta: Signed offset, in grid steps, along every dimension.

    Returns:
        The linear rank of the shifted coordinate; ``0`` for an empty grid.

    Raises:
        AssertionError: If the three sequences differ in length.
    """
    assert len(grid_coord) == len(grid_size) == len(delta), (
        "grid_coord, grid_size and delta must have the same length"
    )
    rank = 0
    for g, G, d in zip(grid_coord, grid_size, delta):
        # Horner-style accumulation of the periodically wrapped coordinate.
        rank = rank * G + (g + d) % G
    return rank
70
+
71
+
72
def gaugeSendRecv(extended, gauge, dest, source):
    """Fill ``extended`` with the data that rank ``source`` sends to this rank.

    This rank sends a copy of ``gauge`` to ``dest`` and writes what it receives
    from ``source`` into ``extended``. When both partners are this rank itself,
    ``gauge`` is copied into ``extended`` directly without any MPI call.

    Args:
        extended: Destination array (or view) that is overwritten in place.
        gauge: Data to send; it is copied first and never modified.
        dest: Rank that receives this process's data.
        source: Rank whose data this process receives.
    """
    is_self_exchange = _RANK == dest and _RANK == source
    if is_self_exchange:
        extended[:] = gauge
        return
    # Sendrecv_replace overwrites its buffer, so work on a private copy.
    halo = gauge.copy()
    MPI.COMM_WORLD.Sendrecv_replace(halo, dest=dest, source=source)
    extended[:] = halo
80
+
81
+
82
class LatticeLink:
    """Field of ``Nc x Nc`` complex matrices on this process's sublattice.

    The data lives in ``self.matrix`` with shape
    ``(*sublatt_size[::-1], Nc, Nc)``, i.e. the lattice axes are stored in
    reverse order of ``latt_size``. ``self.mu`` optionally records a direction
    index in ``[0, 2 * Nd)``; values ``>= Nd`` denote the negative direction of
    dimension ``mu % Nd``.
    """

    def __init__(
        self,
        latt_size: Sequence[int],
        grid_size: Sequence[int],
        color: int,
        matrix: numpy.ndarray = None,
        mu: int = None,
    ):
        """Create a link field.

        Args:
            latt_size: Global lattice extent in every dimension.
            grid_size: Number of processes along every dimension.
            color: Matrix dimension ``Nc``.
            matrix: Existing data, reshaped to the local field shape. When
                omitted an uninitialized complex128 array is allocated.
            mu: Optional direction index attached to this field.
        """
        assert len(latt_size) == len(grid_size)
        self.Nd = len(latt_size)
        self.Nc = color
        self.latt_size = tuple(latt_size)
        self.grid_size = tuple(grid_size)
        self.grid_coord = getGridCoord(grid_size)
        self.sublatt_size = getSublatticeSize(latt_size, grid_size)
        if matrix is None:
            self.matrix = numpy.empty((*self.sublatt_size[::-1], self.Nc, self.Nc), numpy.complex128)
        else:
            self.matrix = matrix.reshape(*self.sublatt_size[::-1], self.Nc, self.Nc)
        self.mu = mu

    def __getitem__(self, key):
        """Index straight into the underlying array."""
        return self.matrix[key]

    def __matmul__(self, other: "LatticeLink"):
        """Return the site-wise matrix product as a raw array (no shift involved)."""
        return self.matrix @ other.matrix

    @property
    def backend(self):
        """Array module (``numpy`` or ``cupy``) that owns ``self.matrix``."""
        if type(self.matrix).__module__ == "numpy":
            return numpy
        elif type(self.matrix).__module__ == "cupy":
            import cupy

            return cupy
        else:
            raise RuntimeError(f"Unknown array type {type(self.matrix)}")

    def _hop_plan(self, mu: int):
        """Return neighbour ranks and index tuples for a one-site hop along ``mu``.

        Returns ``(dest, source, edge_dst, edge_src, bulk_dst, bulk_src)``:
        ``dest``/``source`` are the ranks this process sends to / receives
        from; the ``edge_*`` tuples select the single boundary slice and the
        ``bulk_*`` tuples the remaining sites, on the result (``dst``) and on
        the shifted operand (``src``) respectively.
        """
        Nd = self.Nd
        sign = 1 - 2 * (mu // Nd)
        axis = mu % Nd
        offset = [sign if nu == axis else 0 for nu in range(Nd)]
        dest = getShiftedRank(self.grid_coord, self.grid_size, [-o for o in offset])
        source = getShiftedRank(self.grid_coord, self.grid_size, offset)

        if sign == 1:
            edge_dst, edge_src = slice(-1, None), slice(None, 1)
            bulk_dst, bulk_src = slice(None, -1), slice(1, None)
        else:
            edge_dst, edge_src = slice(None, 1), slice(-1, None)
            bulk_dst, bulk_src = slice(1, None), slice(None, -1)

        def at(axis_slice):
            # Array axes are stored in reverse lattice order, so dimension
            # ``axis`` sits at position Nd - 1 - axis. A tuple index is used
            # instead of ``x[*slices]``, which only parses on Python >= 3.11.
            index = [slice(None, None)] * Nd
            index[Nd - 1 - axis] = axis_slice
            return tuple(index)

        return dest, source, at(edge_dst), at(edge_src), at(bulk_dst), at(bulk_src)

    def shift(self, mu: int, dagger: bool = False):
        """Return the field displaced by one site along direction ``mu``.

        With ``sign = +1`` for ``mu < Nd`` and ``-1`` otherwise, the result at
        site ``x`` holds this field's value at ``x + sign`` along dimension
        ``mu % Nd``. The boundary slice is obtained from the neighbouring
        process unless this process is its own neighbour.

        Args:
            mu: Direction index in ``[0, 2 * Nd)``.
            dagger: When true, every matrix is replaced by its conjugate
                transpose as part of the shift.

        Returns:
            A new ``LatticeLink`` carrying the same ``mu`` attribute as ``self``.
        """
        assert 0 <= mu < 2 * self.Nd
        backend = self.backend
        field = self.matrix
        result = backend.empty_like(field)
        dest, source, edge_dst, edge_src, bulk_dst, bulk_src = self._hop_plan(mu)
        is_local = _RANK == source and _RANK == dest

        def oriented(block):
            return block.swapaxes(-2, -1).conjugate() if dagger else block

        sendbuf = oriented(field[edge_src])
        if not is_local:
            # MPI needs a contiguous buffer; start the send before the bulk copy.
            sendbuf = backend.ascontiguousarray(sendbuf)
            request = MPI.COMM_WORLD.Isend(sendbuf, dest)

        result[bulk_dst] = oriented(field[bulk_src])

        if is_local:
            recvbuf = sendbuf
        else:
            recvbuf = backend.empty_like(sendbuf)
            MPI.COMM_WORLD.Recv(recvbuf, source)
            request.Wait()
        result[edge_dst] = recvbuf

        return LatticeLink(self.latt_size, self.grid_size, self.Nc, result, self.mu)

    def link(self, right: "LatticeLink"):
        """Multiply this field by ``right`` taken one site ahead along ``self.mu``.

        With ``sign`` derived from ``self.mu`` as in :meth:`shift`, the result
        at site ``x`` is ``self[x] @ right[x + sign]`` along dimension
        ``self.mu % Nd``.

        Args:
            right: Field providing the right-hand factor.

        Returns:
            A new ``LatticeLink`` without a ``mu`` attribute.

        Raises:
            AssertionError: If ``self.mu`` is ``None``.
        """
        assert self.mu is not None, "Ambiguous dimension and direction"
        backend = self.backend
        left = self.matrix
        right_field = right.matrix
        result = backend.empty_like(left)
        dest, source, edge_dst, edge_src, bulk_dst, bulk_src = self._hop_plan(self.mu)
        is_local = _RANK == source and _RANK == dest

        sendbuf = right_field[edge_src]
        if not is_local:
            # MPI needs a contiguous buffer; start the send before the bulk product.
            sendbuf = backend.ascontiguousarray(sendbuf)
            request = MPI.COMM_WORLD.Isend(sendbuf, dest)

        result[bulk_dst] = left[bulk_dst] @ right_field[bulk_src]

        if is_local:
            recvbuf = sendbuf
        else:
            recvbuf = backend.empty_like(sendbuf)
            MPI.COMM_WORLD.Recv(recvbuf, source)
            request.Wait()
        result[edge_dst] = left[edge_dst] @ recvbuf

        return LatticeLink(self.latt_size, self.grid_size, self.Nc, result)

    def dagger(self):
        """Return a new field holding the conjugate transpose of every matrix."""
        return LatticeLink(
            self.latt_size,
            self.grid_size,
            self.Nc,
            self.matrix.swapaxes(-2, -1).conjugate(),
        )

    def toDevice(self):
        """Move the data to the GPU, calling ``initGPU()`` first if no device is recorded."""
        import cupy

        if _GPUID < 0:
            initGPU()
        self.matrix = cupy.asarray(self.matrix)

    def toHost(self):
        """Copy the data back to host memory via the array's ``get()`` method."""
        self.matrix = self.matrix.get()
222
+
223
+
224
class LatticeGauge:
    """Gauge field: one ``Nc x Nc`` complex matrix per site and dimension.

    ``self.gauge`` has shape ``(Nd, *sublatt_size[::-1], Nc, Nc)``. When a
    positive border is requested, ``self.extended`` additionally holds the
    field padded by ``Lb`` sites on every side of every lattice axis, with the
    padding filled from neighbouring processes.
    """

    def __init__(
        self,
        latt_size: Sequence[int],
        grid_size: Sequence[int],
        color: int,
        border: int = 0,
        gauge: numpy.ndarray = None,
        extended: numpy.ndarray = None,
    ):
        """Create a gauge field.

        Args:
            latt_size: Global lattice extent in every dimension.
            grid_size: Number of processes along every dimension.
            color: Matrix dimension ``Nc``.
            border: Halo width; values ``<= 0`` disable the extended field.
            gauge: Existing data, reshaped to the local field shape. When
                omitted an uninitialized complex128 array is allocated.
            extended: Optional preallocated storage for the extended field.
        """
        assert len(latt_size) == len(grid_size)
        self.Nd = len(latt_size)
        self.Nc = color
        self.latt_size = tuple(latt_size)
        self.grid_size = tuple(grid_size)
        self.grid_coord = getGridCoord(grid_size)
        self.sublatt_size = getSublatticeSize(latt_size, grid_size)
        shape = (self.Nd, *self.sublatt_size[::-1], self.Nc, self.Nc)
        if gauge is None:
            self.gauge = numpy.empty(shape, numpy.complex128)
        else:
            self.gauge = gauge.reshape(shape)
        self.extend(border, extended)

    def __getitem__(self, mu):
        """Return the link field for direction ``mu`` in ``[0, 2 * Nd)``.

        For ``mu < Nd`` this wraps ``gauge[mu]``. For ``mu >= Nd`` it wraps
        ``gauge[mu % Nd]`` shifted along the negative direction with every
        matrix conjugate-transposed.
        """
        assert 0 <= mu < 2 * self.Nd
        gauge_mu = LatticeLink(self.latt_size, self.grid_size, self.Nc, self.gauge[mu % self.Nd], mu)
        return gauge_mu if mu < self.Nd else gauge_mu.shift(mu, True)

    @property
    def backend(self):
        """Array module (``numpy`` or ``cupy``) that owns ``self.gauge``."""
        if type(self.gauge).__module__ == "numpy":
            return numpy
        elif type(self.gauge).__module__ == "cupy":
            import cupy

            return cupy
        else:
            raise RuntimeError(f"Unknown array type {type(self.gauge)}")

    def extend(self, border: int, extended: numpy.ndarray = None):
        """Set the halo width and (re)build the extended field.

        Args:
            border: Halo width ``Lb``. For ``border <= 0`` the extended field
                is dropped and ``extlatt_size`` equals ``sublatt_size``.
            extended: Optional storage to reuse; it is reshaped to the extended
                shape and then overwritten by :meth:`exchange`.
        """
        if border <= 0:
            self.Lb = 0
            self.extlatt_size = self.sublatt_size
            self.extended = None
        else:
            self.Lb = border
            self.extlatt_size = [L + 2 * border for L in self.sublatt_size]
            shape = (self.Nd, *self.extlatt_size[::-1], self.Nc, self.Nc)
            if extended is None:
                self.extended = self.backend.empty(shape, self.gauge.dtype)
            else:
                self.extended = extended.reshape(shape)
            self.exchange()

    def exchange(self):
        """Fill ``self.extended`` from the local field and all ``3**Nd - 1`` neighbours.

        Every offset ``delta`` in ``{-1, 0, 1}**Nd`` is visited; the all-zero
        offset copies the local field into the interior of the extended array.
        """
        assert self.extended is not None
        Nd = self.Nd
        Lb = self.Lb
        stride = [3 ** (Nd - 1 - mu) for mu in range(Nd)]
        for tag in range(3**Nd):
            # Base-3 digits of tag, shifted to the range -1..1.
            delta = [(tag // stride[mu] % 3 - 1) for mu in range(Nd)]
            extended_slice = []
            gauge_slice = []
            for d in delta:
                if d == -1:
                    # Leading local sites land in the trailing halo of the receiver.
                    extended_slice.append(slice(-Lb, None))
                    gauge_slice.append(slice(None, Lb))
                elif d == 1:
                    # Trailing local sites land in the leading halo of the receiver.
                    extended_slice.append(slice(None, Lb))
                    gauge_slice.append(slice(-Lb, None))
                else:
                    extended_slice.append(slice(Lb, -Lb))
                    gauge_slice.append(slice(None, None))
            # Array axes are stored in reverse lattice order. Tuple indices are
            # used instead of ``x[:, *slices]``, which needs Python >= 3.11.
            gaugeSendRecv(
                self.extended[(slice(None, None), *extended_slice[::-1])],
                self.gauge[(slice(None, None), *gauge_slice[::-1])],
                getShiftedRank(self.grid_coord, self.grid_size, delta),
                getShiftedRank(self.grid_coord, self.grid_size, [-d for d in delta]),
            )

    def shift(self, delta: Sequence[int]):
        """Return the gauge field displaced by ``delta`` sites, read from the halo data.

        Args:
            delta: Signed displacement per dimension; each magnitude must not
                exceed the halo width ``Lb``.

        Returns:
            A border-less ``LatticeGauge`` built from the matching window of
            the extended field. Without a halo (``Lb == 0``) only the zero
            shift is admissible and the window is taken from ``self.gauge``.

        Raises:
            AssertionError: If any displacement exceeds ``Lb``.
        """
        assert numpy.abs(delta).max() <= self.Lb
        Lb = self.Lb
        # Without a halo there is no extended array; the zero shift is then
        # simply the local field itself.
        padded = self.gauge if self.extended is None else self.extended
        window = [slice(Lb + d, None if Lb == d else -(Lb - d)) for d in delta[::-1]]
        return LatticeGauge(
            self.latt_size, self.grid_size, self.Nc, 0, padded[(slice(None, None), *window)], None
        )

    def toDevice(self):
        """Move the fields to the GPU, calling ``initGPU()`` first if no device is recorded."""
        import cupy

        if _GPUID < 0:
            initGPU()
        self.gauge = cupy.asarray(self.gauge)
        if self.extended is not None:
            self.extended = cupy.asarray(self.extended)

    def toHost(self):
        """Copy the fields back to host memory via the arrays' ``get()`` method."""
        self.gauge = self.gauge.get()
        if self.extended is not None:
            self.extended = self.extended.get()
325
+
326
+
327
def link(*color_matrices: "LatticeLink"):
    """Chain several link fields into one product, folding from the right.

    ``link(a, b, c)`` evaluates ``a.link(b.link(c))``; a single argument is
    returned unchanged.

    Args:
        *color_matrices: Fields to multiply, leftmost factor first.

    Returns:
        The result of the nested ``link`` calls.

    Raises:
        IndexError: If called without arguments.
    """
    if not color_matrices:
        raise IndexError("link() requires at least one LatticeLink")
    *leading, linked = color_matrices
    # Start from the rightmost factor and absorb the others right-to-left.
    for color_matrix in reversed(leading):
        linked = color_matrix.link(linked)
    return linked
@@ -82,7 +82,7 @@ def LatticePropagatorGPT(lattice: g.lattice, gen_simd_width: int, propagator: La
82
82
  return propagator
83
83
  else:
84
84
  assert latt_info.size == propagator.latt_info.size
85
- gpt_shape = [i for sl in zip(gpt_simd, gpt_latt) for i in sl]
85
+ gpt_shape = [i for sl in zip(gpt_simd[::-1], gpt_latt[::-1]) for i in sl]
86
86
  lattice.mview()[0][:] = (
87
87
  propagator.lexico()
88
88
  .astype(f"<c{2 * gpt_prec}")
@@ -69,11 +69,11 @@ def rotateToDeGrandRossi(propagator: LatticePropagator):
69
69
  )
70
70
 
71
71
 
72
- def readChromaQIOGauge(filename: str, checksum: bool = True):
72
+ def readChromaQIOGauge(filename: str, checksum: bool = True, reunitarize_sigma: float = 2e-7):
73
73
  from pyquda import getGridSize
74
74
  from pyquda_io.chroma import readQIOGauge as read
75
75
 
76
- latt_size, gauge_raw = read(filename, getGridSize(), checksum)
76
+ latt_size, gauge_raw = read(filename, getGridSize(), checksum, reunitarize_sigma)
77
77
  return LatticeGauge(LatticeInfo(latt_size), evenodd(gauge_raw, [1, 2, 3, 4]))
78
78
 
79
79
 
@@ -96,7 +96,7 @@ def readChromaQIOPropagator(filename: str, checksum: bool = True):
96
96
  return LatticeStaggeredPropagator(LatticeInfo(latt_size), evenodd(propagator_raw, [0, 1, 2, 3]))
97
97
 
98
98
 
99
- def readMILCGauge(filename: str, checksum: bool = True, reunitarize_sigma: bool = True):
99
+ def readMILCGauge(filename: str, checksum: bool = True, reunitarize_sigma: float = 2e-7):
100
100
  from pyquda import getGridSize
101
101
  from pyquda_io.milc import readGauge as read
102
102
 
@@ -244,7 +244,7 @@ def readNERSCGauge(
244
244
  checksum: bool = True,
245
245
  plaquette: bool = True,
246
246
  link_trace: bool = True,
247
- reunitarize_sigma: bool = True,
247
+ reunitarize_sigma: float = 2e-7,
248
248
  ):
249
249
  from pyquda import getGridSize
250
250
  from pyquda_io.nersc import readGauge as read
@@ -1 +0,0 @@
1
- __version__ = "0.10.1.dev0"