quant-met 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024-present Tjark <tsievers@physnet.uni-hamburg.de>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,61 @@
1
+ Metadata-Version: 2.1
2
+ Name: quant-met
3
+ Version: 0.0.1
4
+ Summary:
5
+ Author: Tjark Sievers
6
+ Author-email: tsievers@physnet.uni-hamburg.de
7
+ Requires-Python: >=3.10,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Dist: click (>=8.1.7,<9.0.0)
13
+ Requires-Dist: matplotlib (>=3.8.4,<4.0.0)
14
+ Requires-Dist: numpy (>=1.26.4,<2.0.0)
15
+ Requires-Dist: pandas (>=2.2.2,<3.0.0)
16
+ Requires-Dist: scipy (>=1.13.0,<2.0.0)
17
+ Requires-Dist: sympy (>=1.12,<2.0)
18
+ Description-Content-Type: text/markdown
19
+
20
+ # quant-met
21
+
22
+ [![Test](https://github.com/Ruberhauptmann/quant-met/actions/workflows/test.yml/badge.svg)](https://github.com/Ruberhauptmann/quant-met/actions/workflows/test.yml)
23
+ [![Coverage Status](https://coveralls.io/repos/github/Ruberhauptmann/quant-met/badge.svg?branch=main)](https://coveralls.io/github/Ruberhauptmann/quant-met?branch=main)
24
+
25
+ * Documentation: [quant-met.readthedocs.io](https://quant-met.readthedocs.io/en/latest/)
26
+
27
+ ## Installation
28
+
29
+ The package can be installed via `pip install quant-met`.
30
+
31
+ ## Usage
32
+
33
+ ## Contributing
34
+
35
+ You are welcome to open an issue if you want something changed or added in the software, or if you encounter bugs.
36
+
37
+ ### Developing
38
+
39
+ You can also help develop this software further.
40
+ The steps below will help you get started.
41
+
42
+ Prerequisites:
43
+ * make
44
+ * python
45
+ * conda
46
+
47
+ Set up the development environment:
48
+ * clone the repository
49
+ * run `make environment`
50
+ * now activate the conda environment `conda activate quant-met-dev`
51
+
52
+ Now you can create a separate branch to work on the project.
53
+
54
+ You can manually run tests using for example `tox -e py312` (for running against python 3.12).
55
+ After pushing your branch, all tests will also be run via GitHub Actions.
56
+
57
+ Using `pre-commit`, automatic linting and formatting is done before every commit, which may cause the first commit to fail.
58
+ A second try should then succeed.
59
+
60
+ After you are done working on an issue and all tests pass, add a changelog fragment via `scriv create` and open a pull request.
61
+
@@ -0,0 +1,41 @@
1
+ # quant-met
2
+
3
+ [![Test](https://github.com/Ruberhauptmann/quant-met/actions/workflows/test.yml/badge.svg)](https://github.com/Ruberhauptmann/quant-met/actions/workflows/test.yml)
4
+ [![Coverage Status](https://coveralls.io/repos/github/Ruberhauptmann/quant-met/badge.svg?branch=main)](https://coveralls.io/github/Ruberhauptmann/quant-met?branch=main)
5
+
6
+ * Documentation: [quant-met.readthedocs.io](https://quant-met.readthedocs.io/en/latest/)
7
+
8
+ ## Installation
9
+
10
+ The package can be installed via `pip install quant-met`.
11
+
12
+ ## Usage
13
+
14
+ ## Contributing
15
+
16
+ You are welcome to open an issue if you want something changed or added in the software, or if you encounter bugs.
17
+
18
+ ### Developing
19
+
20
+ You can also help develop this software further.
21
+ The steps below will help you get started.
22
+
23
+ Prerequisites:
24
+ * make
25
+ * python
26
+ * conda
27
+
28
+ Set up the development environment:
29
+ * clone the repository
30
+ * run `make environment`
31
+ * now activate the conda environment `conda activate quant-met-dev`
32
+
33
+ Now you can create a separate branch to work on the project.
34
+
35
+ You can manually run tests using for example `tox -e py312` (for running against python 3.12).
36
+ After pushing your branch, all tests will also be run via GitHub Actions.
37
+
38
+ Using `pre-commit`, automatic linting and formatting is done before every commit, which may cause the first commit to fail.
39
+ A second try should then succeed.
40
+
41
+ After you are done working on an issue and all tests pass, add a changelog fragment via `scriv create` and open a pull request.
@@ -0,0 +1,35 @@
1
+ [tool.poetry]
2
+ name = "quant-met"
3
+ version = "0.0.1"
4
+ description = ""
5
+ authors = ["Tjark Sievers <tsievers@physnet.uni-hamburg.de>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.scripts]
9
+ quantmet = "quant_met.cli:cli"
10
+
11
+ [tool.poetry.dependencies]
12
+ python = "^3.10"
13
+ numpy = "^1.26.4"
14
+ scipy = "^1.13.0"
15
+ click = "^8.1.7"
16
+ matplotlib = "^3.8.4"
17
+ pandas = "^2.2.2"
18
+ sympy = "^1.12"
19
+
20
+ [tool.poetry.group.dev.dependencies]
21
+ pre-commit = "^3.7.0"
22
+ scriv = "^1.5.1"
23
+ tox = "^4.15.0"
24
+ jupyter = "^1.0.0"
25
+ ipympl = "^0.9.4"
26
+ black = "^24.4.2"
27
+ sphinx = "^7.3.7"
28
+ sphinx-rtd-theme = "^2.0.0"
29
+ myst-parser = "^3.0.1"
30
+ nbsphinx = "^0.9.4"
31
+ sphinx-gallery = "^0.16.0"
32
+
33
+ [build-system]
34
+ requires = ["poetry-core"]
35
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,4 @@
1
+ # SPDX-FileCopyrightText: 2024-present Tjark <tjarksievers@icloud.com>
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
__version__ = "0.0.1"  # Package version string; mirrors `version` in pyproject.toml.
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2024-present Tjark <tjarksievers@icloud.com>
2
+ #
3
+ # SPDX-License-Identifier: MIT
File without changes
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ from scipy import interpolate, optimize
4
+
5
+ from quant_met.bcs.gap_equation import gap_equation_real
6
+ from quant_met.configuration import DeltaVector
7
+ from quant_met.hamiltonians import BaseHamiltonian
8
+
9
+
10
def generate_k_space_grid(nx, nrows, corner_1, corner_2):
    """Generate a parallelogram-shaped grid of k-points.

    The grid consists of ``nrows`` rows of ``nx`` equally spaced points.
    Row ``i`` runs from ``f_i * corner_2`` to ``corner_1 + f_i * corner_2``
    with ``f_i = i / (nrows - 1)``, so the rows sweep from the origin up to
    ``corner_2``.

    Parameters
    ----------
    nx : int
        Number of points per row (end points included).
    nrows : int
        Number of rows; must be at least 1. With a single row only the row
        starting at the origin is generated.
    corner_1 : array_like
        Vector along each row.
    corner_2 : array_like
        Vector across the rows.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nrows * nx, len(corner_1))`` with the rows stacked
        in order.

    Raises
    ------
    ValueError
        If ``nrows`` is smaller than 1.
    """
    if nrows < 1:
        raise ValueError(f"nrows must be at least 1, got {nrows}")

    # Accept plain sequences as well as arrays.
    corner_1 = np.asarray(corner_1, dtype=float)
    corner_2 = np.asarray(corner_2, dtype=float)

    # A single row has no spacing between rows; avoid the 0 / 0 division
    # that the formula i / (nrows - 1) would otherwise hit.
    row_fractions = [i / (nrows - 1) if nrows > 1 else 0.0 for i in range(nrows)]

    return np.concatenate(
        [
            np.linspace(
                fraction * corner_2,
                corner_1 + fraction * corner_2,
                num=nx,
            )
            for fraction in row_fractions
        ]
    )
23
+
24
+
25
def solve_gap_equation(
    hamiltonian: BaseHamiltonian, k_points: npt.NDArray, beta: float = 0
) -> DeltaVector:
    """Solve the self-consistent gap equation by fixed-point iteration.

    The iteration starts from a uniform gap of ``0.1`` in every band and
    iterates ``gap_equation_real`` with ``scipy.optimize.fixed_point``.

    Parameters
    ----------
    hamiltonian : BaseHamiltonian
        Provides the band energies / eigenvectors (``generate_bloch``), the
        interaction values ``U`` and ``number_bands``.
    k_points : numpy.ndarray
        k-points on which the gap is solved.
    beta : float
        Forwarded unchanged to ``gap_equation_real``.

    Returns
    -------
    DeltaVector
        The converged gap. If the fixed-point iteration does not converge, a
        ``RuntimeWarning`` is emitted and a gap of zero in every band is
        returned instead.
    """
    import warnings

    # NOTE(review): generate_bloch returns the (complex) eigenvector matrices,
    # not their absolute values, despite the variable name -- confirm whether
    # np.abs(...) was intended here.
    energies, bloch_absolute = hamiltonian.generate_bloch(k_points=k_points)

    delta_vector = DeltaVector(
        k_points=k_points, initial=0.1, number_bands=hamiltonian.number_bands
    )
    try:
        solution = optimize.fixed_point(
            gap_equation_real,
            delta_vector.as_1d_vector,
            args=(hamiltonian.U, beta, bloch_absolute, energies, len(k_points)),
        )
    except RuntimeError as err:
        # Deliberate best-effort fallback: report the failure with its cause
        # instead of a bare print, then continue with the trivial solution.
        warnings.warn(
            f"Gap equation did not converge ({err}); returning a zero gap.",
            RuntimeWarning,
            stacklevel=2,
        )
        solution = np.zeros_like(delta_vector.as_1d_vector)

    delta_vector.update_from_1d_vector(solution)

    return delta_vector
48
+
49
+
50
def interpolate_gap(
    delta_vector_on_grid: DeltaVector, bandpath: npt.NDArray
) -> DeltaVector:
    """Interpolate a gap known on a k-grid onto the points of ``bandpath``.

    Each band's ``delta_{n}`` column is interpolated independently with
    ``scipy.interpolate.griddata`` (cubic). Points of ``bandpath`` outside
    the convex hull of the grid receive ``griddata``'s default fill value.

    Parameters
    ----------
    delta_vector_on_grid : DeltaVector
        Gap values together with the k-points they were computed on.
    bandpath : numpy.ndarray
        Target k-points.

    Returns
    -------
    DeltaVector
        A new vector on ``bandpath`` with the interpolated ``delta_*`` columns.
    """
    source_points = delta_vector_on_grid.k_points
    result = DeltaVector(
        k_points=bandpath, number_bands=delta_vector_on_grid.number_bands
    )

    for band_index in range(result.number_bands):
        column = f"delta_{band_index}"
        interpolated = interpolate.griddata(
            source_points,
            delta_vector_on_grid.data.loc[:, column],
            bandpath,
            method="cubic",
        )
        result.data.loc[:, column] = interpolated

    return result
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+
4
+
5
def gap_equation_real(
    delta_k: npt.NDArray,
    U: npt.NDArray,
    beta: float,
    bloch_absolute: npt.NDArray,
    energies: npt.NDArray,
    number_k_points: int,
) -> npt.NDArray:
    """Evaluate the right-hand side of the gap equation once.

    With ``N = number_k_points``, ``b = bloch_absolute`` and
    ``C = 2.5980762113533156`` (numerically ``3 * sqrt(3) / 2``) this computes::

        Delta_n(k') = 1 / (C N) * sum_{alpha, m, k} U[alpha]
                      * b[k', alpha, n] * b[k, alpha, m] * Delta_m(k)
                      / (2 * sqrt(energies[k, m]**2 + |Delta_m(k)|**2))

    The inner (k, m) sum does not depend on (k', n), so it is computed once
    per ``alpha``. This is algebraically identical to the explicit five-fold
    loop but costs O(N * bands**2) instead of O(N**2 * bands**3).

    Parameters
    ----------
    delta_k : numpy.ndarray
        Flat, band-major gap vector: entry ``m * N + k`` is ``Delta_m(k)``.
        Its length must be a multiple of ``number_k_points``.
    U : array_like
        One interaction value per orbital index ``alpha``.
    beta : float
        Currently unused by this function.
    bloch_absolute : numpy.ndarray
        Indexed as ``[k, alpha, n]``.
    energies : numpy.ndarray
        Indexed as ``[k, m]``.
    number_k_points : int
        ``N``, the number of k-points.

    Returns
    -------
    numpy.ndarray
        Real vector with the same band-major layout as ``delta_k``. Any
        imaginary part (complex ``bloch_absolute``) is discarded, just as the
        previous implementation did when storing into a float array.

    Raises
    ------
    ValueError
        If ``len(delta_k)`` is not a multiple of ``number_k_points``.
    """
    delta_k = np.asarray(delta_k)
    if len(delta_k) % number_k_points:
        raise ValueError(
            f"len(delta_k)={len(delta_k)} is not a multiple of "
            f"number_k_points={number_k_points}"
        )
    number_bands = len(delta_k) // number_k_points

    # Only the first N k-points and `number_bands` orbitals/bands take part.
    bloch = np.asarray(bloch_absolute)[:number_k_points, :number_bands, :number_bands]
    band_energies = np.asarray(energies)[:number_k_points, :number_bands]
    interaction = np.asarray(U)[:number_bands]

    # (N, bands) view of the band-major vector: delta[k, m] = Delta_m(k).
    delta = delta_k.reshape(number_bands, number_k_points).T

    # Per-(k, m) weight Delta_m(k) / (2 * sqrt(eps_m(k)^2 + |Delta_m(k)|^2)).
    weight = delta / (2 * np.sqrt(band_energies**2 + np.abs(delta) ** 2))

    # sum_{k, m} U[alpha] * b[k, alpha, m] * weight[k, m] -> one value per alpha.
    orbital_sum = interaction * np.einsum("kam,km->a", bloch, weight)

    # Project back onto the bands at every k': gap[k', n].
    gap = np.einsum("pan,a->pn", bloch, orbital_sum)

    normalisation = 2.5980762113533156 * number_k_points
    return np.real(gap.T.reshape(-1)) / normalisation
@@ -0,0 +1,16 @@
1
+ import click
2
+
3
+ from quant_met.bcs import find_fixpoint
4
+
5
+
6
@click.command()
def cli():
    """Entry point of the ``quantmet`` console script.

    ``find_fixpoint.solve_gap_equation`` requires a Hamiltonian and a set of
    k-points. This command defines no options to supply them, so calling the
    solver with no arguments always failed with a ``TypeError`` traceback.
    The problem is reported as a clean command-line error instead.
    """
    # TODO: add options (model parameters, k-grid) and forward them to
    # find_fixpoint.solve_gap_equation once the CLI is designed.
    raise click.ClickException(
        "No calculation configured: solve_gap_equation needs a Hamiltonian "
        "and k-points, which this command cannot provide yet."
    )
9
+
10
+
11
def hello_world():
    """Return the fixed greeting string ``"Hello World!"``."""
    greeting = "Hello World!"
    return greeting
13
+
14
+
15
# Allow running the module directly, in addition to the `quantmet` console
# script declared in pyproject.toml.
if __name__ == "__main__":
    cli()
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ import pandas as pd
4
+
5
+
6
+ class DeltaVector:
7
+ def __init__(
8
+ self,
9
+ number_bands: int,
10
+ hdf_file=None,
11
+ k_points: npt.NDArray | None = None,
12
+ initial: float | None = None,
13
+ ):
14
+ self.number_bands = number_bands
15
+ if hdf_file is not None:
16
+ pass
17
+ # self.data = pd.DataFrame(pd.read_hdf(hdf_file, key="table"))
18
+ # self.k_points = np.column_stack(
19
+ # (np.array(self.data.loc[:, "kx"]), np.array(self.data.loc[:, "ky"]))
20
+ # )
21
+ else:
22
+ self.k_points = k_points
23
+ self.data = pd.DataFrame(
24
+ # columns=["kx", "ky", "delta_1", "delta_2", "delta_3"],
25
+ index=range(len(k_points)),
26
+ dtype=np.float64,
27
+ )
28
+ self.data.loc[:, "kx"] = self.k_points[:, 0]
29
+ self.data.loc[:, "ky"] = self.k_points[:, 1]
30
+ if initial is not None:
31
+ for i in range(number_bands):
32
+ self.data.loc[:, f"delta_{i}"] = initial
33
+
34
+ def __repr__(self):
35
+ return self.data.to_string(index=False)
36
+
37
+ def update_from_1d_vector(self, delta: npt.NDArray):
38
+ for n in range(self.number_bands):
39
+ offset = int(n * len(delta) / self.number_bands)
40
+ self.data.loc[:, f"delta_{n}"] = delta[offset : offset + len(self.k_points)]
41
+
42
+ def save(self, path):
43
+ pass
44
+ # self.data.to_hdf(path, key="table", format="table", data_columns=True)
45
+
46
+ @property
47
+ def as_1d_vector(self) -> npt.NDArray:
48
+ return np.concatenate(
49
+ [
50
+ np.array(self.data.loc[:, f"delta_{n}"].values)
51
+ for n in range(self.number_bands)
52
+ ]
53
+ )
@@ -0,0 +1,174 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ import pandas as pd
6
+
7
+
8
class BaseHamiltonian(ABC):
    """Abstract Bloch Hamiltonian evaluated on k-points.

    Subclasses provide ``number_bands``, ``U`` and
    ``_k_space_matrix_one_point``. This base class builds the stacked
    k-space matrices and diagonalises them.
    """

    @property
    @abstractmethod
    def number_bands(self) -> int:
        """Number of bands, i.e. the size of the square k-space matrix."""
        raise NotImplementedError

    @property
    @abstractmethod
    def U(self) -> list[float]:
        """Interaction values, one per orbital."""
        raise NotImplementedError

    @abstractmethod
    def _k_space_matrix_one_point(self, k: npt.NDArray, h: npt.NDArray) -> npt.NDArray:
        """Fill ``h`` with the matrix at the single k-point ``k`` and return it.

        ``h`` is a zero-initialised complex ``(number_bands, number_bands)``
        array. The returned array is stored as the matrix for ``k``.
        """
        raise NotImplementedError

    def k_space_matrix(self, k: npt.NDArray) -> npt.NDArray:
        """Return the stacked k-space matrices for one or many k-points.

        Parameters
        ----------
        k : array_like
            A single k-point (1-D) or an array with one k-point per row.

        Returns
        -------
        numpy.ndarray
            Complex array of shape ``(n_points, number_bands, number_bands)``;
            ``n_points`` is 1 for a single 1-D k-point.
        """
        k = np.asarray(k)
        if k.ndim == 1:
            h = np.zeros((1, self.number_bands, self.number_bands), dtype=complex)
            h[0] = self._k_space_matrix_one_point(k, h[0])
            return h

        h = np.zeros((k.shape[0], self.number_bands, self.number_bands), dtype=complex)
        for k_index, k_point in enumerate(k):
            h[k_index] = self._k_space_matrix_one_point(k_point, h[k_index])
        return h

    def calculate_bandstructure(self, k_point_list: npt.NDArray):
        """Return the band energies at every k-point as a DataFrame.

        Parameters
        ----------
        k_point_list : array_like
            One k-point per row; a single 1-D k-point yields one row.

        Returns
        -------
        pandas.DataFrame
            One row per k-point and columns ``band_0`` ... with the eigenvalues
            in ascending order (as returned by ``numpy.linalg.eigh``).
        """
        # Treat a single 1-D k-point as a one-row list, so the DataFrame index
        # matches the number of matrices (it previously crashed).
        k_points = np.atleast_2d(k_point_list)
        energies, _ = np.linalg.eigh(self.k_space_matrix(k_points))

        return pd.DataFrame(
            {
                f"band_{band_index}": energies[:, band_index]
                for band_index in range(self.number_bands)
            },
            index=range(len(k_points)),
        )

    def generate_bloch(self, k_points: npt.NDArray):
        """Diagonalise the Hamiltonian at the given k-points.

        Returns
        -------
        tuple of numpy.ndarray
            ``(energies, bloch)`` from ``numpy.linalg.eigh``. For a single
            1-D k-point the shapes are ``(number_bands,)`` and
            ``(number_bands, number_bands)``. Otherwise they are
            ``(n_points, number_bands)`` and
            ``(n_points, number_bands, number_bands)``, with eigenvectors as
            the columns of each matrix.
        """
        k_points = np.asarray(k_points)
        k_point_matrix = self.k_space_matrix(k_points)

        if k_points.ndim == 1:
            energies, bloch = np.linalg.eigh(k_point_matrix[0])
        else:
            # eigh works on stacks of matrices, one diagonalisation per k-point.
            energies, bloch = np.linalg.eigh(k_point_matrix)

        return energies, bloch
66
+
67
+
68
class GrapheneHamiltonian(BaseHamiltonian):
    """Two-orbital tight-binding Hamiltonian with a graphene-type hopping.

    Parameters
    ----------
    t_nn : float
        Hopping amplitude, the prefactor of the off-diagonal element.
    a : float
        Lattice constant entering the phase factors.
    mu : float
        Chemical potential, subtracted on the diagonal.
    U_gr : float
        Interaction value used for both orbitals.
    """

    def __init__(self, t_nn: float, a: float, mu: float, U_gr: float):
        self.t_nn, self.a, self.mu, self.U_gr = t_nn, a, mu, U_gr

    @property
    def U(self) -> list[float]:
        """The same interaction value ``U_gr`` for each of the two orbitals."""
        return [self.U_gr] * 2

    @property
    def number_bands(self) -> int:
        """Always two orbitals/bands."""
        return 2

    def _k_space_matrix_one_point(self, k: npt.NDArray, h: npt.NDArray) -> npt.NDArray:
        """Build the 2x2 matrix at ``k = (kx, ky)``.

        The off-diagonal entries are written into ``h`` in place. The
        returned array is the new array ``h - mu * I``.
        """
        kx, ky = k[0], k[1]
        lattice = self.a

        hopping = self.t_nn * (
            np.exp(1j * ky * lattice / np.sqrt(3))
            + 2 * np.exp(-0.5j * lattice / np.sqrt(3) * ky) * (np.cos(0.5 * lattice * kx))
        )
        h[0, 1] = hopping
        h[1, 0] = h[0, 1].conjugate()

        return h - self.mu * np.eye(2)
96
+
97
+
98
class EGXHamiltonian(BaseHamiltonian):
    """Three-orbital model: a graphene-type pair (orbitals 0, 1) plus orbital 2.

    Orbitals 0 and 1 are coupled by the graphene-type hopping ``t_gr``.
    Orbital 2 has its own dispersion set by ``t_x`` and couples to orbital 0
    through ``V``. The chemical potential ``mu`` is subtracted on the diagonal.

    Parameters
    ----------
    t_gr : float
        Hopping between orbitals 0 and 1.
    t_x : float
        Hopping amplitude in the dispersion of orbital 2.
    V : float
        Coupling between orbitals 0 and 2.
    a : float
        Lattice constant.
    mu : float
        Chemical potential.
    U_gr : float
        Interaction value for orbitals 0 and 1.
    U_x : float
        Interaction value for orbital 2.
    """

    def __init__(
        self,
        t_gr: float,
        t_x: float,
        V: float,
        a: float,
        mu: float,
        U_gr: float,
        U_x: float,
    ):
        self.t_gr = t_gr
        self.t_x = t_x
        self.V = V
        self.a = a
        self.mu = mu
        self.U_gr = U_gr
        self.U_x = U_x

    @property
    def U(self) -> list[float]:
        """Interaction values ordered as orbitals 0, 1, 2."""
        return [self.U_gr, self.U_gr, self.U_x]

    @property
    def number_bands(self) -> int:
        """Always three orbitals/bands."""
        return 3

    def _k_space_matrix_one_point(self, k: npt.NDArray, h: npt.NDArray) -> npt.NDArray:
        """Fill the 3x3 matrix at ``k = (kx, ky)`` and return ``h - mu * I``."""
        t_gr = self.t_gr
        t_x = self.t_x
        a = self.a
        V = self.V
        mu = self.mu

        h[0, 1] = t_gr * (
            np.exp(1j * k[1] * a / np.sqrt(3))
            + 2 * np.exp(-0.5j * a / np.sqrt(3) * k[1]) * (np.cos(0.5 * a * k[0]))
        )
        h[1, 0] = h[0, 1].conjugate()

        h[2, 0] = V
        h[0, 2] = V

        h[2, 2] = (
            -2
            * t_x
            * (
                np.cos(a * k[0])
                + 2 * np.cos(0.5 * a * k[0]) * np.cos(0.5 * np.sqrt(3) * a * k[1])
            )
        )
        return h - mu * np.eye(3)

    def calculate_bandstructure(self, k_point_list: npt.NDArray):
        """Return band energies and orbital-weight contrasts per k-point.

        Returns
        -------
        pandas.DataFrame
            One row per k-point with columns ``band_{n}`` (eigenvalue) and
            ``wx_{n}`` = ``|v_n[2]|**2 - |v_n[0]|**2``, where ``v_n`` is the
            n-th eigenvector. ``wx`` is +1 for pure orbital 2 and -1 for pure
            orbital 0. Columns are interleaved ``band_0, wx_0, band_1, ...``.
        """
        # A single 1-D k-point becomes a one-row list (previously crashed).
        k_points = np.atleast_2d(k_point_list)
        energies, eigenvectors = np.linalg.eigh(self.k_space_matrix(k_points))

        # eigenvectors[i, :, n] is the n-th eigenvector at k-point i.
        weight_orbital_2 = np.abs(eigenvectors[:, 2, :]) ** 2
        weight_orbital_0 = np.abs(eigenvectors[:, 0, :]) ** 2

        columns = {}
        for band_index in range(self.number_bands):
            columns[f"band_{band_index}"] = energies[:, band_index]
            columns[f"wx_{band_index}"] = (
                weight_orbital_2[:, band_index] - weight_orbital_0[:, band_index]
            )

        return pd.DataFrame(columns, index=range(len(k_points)))
@@ -0,0 +1,2 @@
1
+ from . import _plotting
2
+ from ._plotting import *
@@ -0,0 +1,180 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ from matplotlib.collections import LineCollection
5
+
6
+
7
def plot_into_bz(
    bz_corners, k_points, fig: plt.Figure | None = None, ax: plt.Axes | None = None
):
    """Scatter k-points together with the Brillouin-zone corners.

    Parameters
    ----------
    bz_corners : sequence of (kx, ky)
        Corner points, drawn semi-transparent on top.
    k_points : sequence of (kx, ky)
        Points to show.
    fig, ax : optional
        Target figure/axes. If only ``ax`` is given, its figure is used. If
        only ``fig`` is given, its current axes are used. If neither is given,
        a new figure is created.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Previously a given `ax` was discarded when `fig` was None, and a `fig`
    # without `ax` crashed on `None.scatter`.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    ax.scatter(*zip(*k_points))
    ax.scatter(*zip(*bz_corners), alpha=0.8)

    ax.set_aspect("equal", adjustable="box")

    return fig
19
+
20
+
21
def scatter_into_bz(
    bz_corners,
    k_points,
    data,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
):
    """Scatter k-points coloured by ``data``, plus the Brillouin-zone corners.

    Parameters
    ----------
    bz_corners : sequence of (kx, ky)
        Corner points, drawn semi-transparent on top.
    k_points : sequence of (kx, ky)
        Points to show.
    data : array_like
        One value per k-point, mapped to colour (viridis) with a colour bar.
    fig, ax : optional
        Target figure/axes. If only ``ax`` is given, its figure is used. If
        only ``fig`` is given, its current axes are used. If neither is given,
        a new figure is created.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Previously a given `ax` was discarded when `fig` was None, and a `fig`
    # without `ax` crashed on `None.scatter`.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    scatter = ax.scatter(*zip(*k_points), c=data, cmap="viridis")
    ax.scatter(*zip(*bz_corners), alpha=0.8)
    fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.04)

    ax.set_aspect("equal", adjustable="box")

    return fig
38
+
39
+
40
def plot_bcs_bandstructure(
    non_interacting_bands,
    deltas,
    k_point_list,
    ticks,
    labels,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
):
    """Plot the BCS quasiparticle bands ``±sqrt(eps_n(k)**2 + |Delta_n(k)|**2)``.

    Parameters
    ----------
    non_interacting_bands : sequence of array_like
        Band energies ``eps_n`` along the path, one entry per band.
    deltas : sequence of array_like
        Gap values ``Delta_n`` along the path, paired with the bands.
    k_point_list : array_like
        x coordinates of the path (e.g. the normalised path position).
    ticks, labels : sequence
        x tick positions and their labels.
    fig, ax : optional
        Target figure/axes. If only ``ax`` is given, its figure is used. If
        only ``fig`` is given, its current axes are used. If neither is given,
        a new figure is created.

    Returns
    -------
    tuple
        ``(fig, ax)``.
    """
    # Previously a given `ax` was discarded when `fig` was None, and a `fig`
    # without `ax` crashed on `None.axhline`.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    ax.axhline(y=0, alpha=0.7, linestyle="--", color="black")

    for index, (band, delta) in enumerate(zip(non_interacting_bands, deltas)):
        quasiparticle_energy = np.sqrt(band**2 + np.abs(delta) ** 2)
        ax.plot(k_point_list, quasiparticle_energy, label=f"band {index}, +")
        ax.plot(k_point_list, -quasiparticle_energy, label=f"band {index}, -")
    ax.set_box_aspect(1)

    ax.set_xticks(ticks, labels)
    ax.set_yticks(range(-5, 6))
    ax.set_facecolor("lightgray")
    ax.grid(visible=True)
    ax.tick_params(
        axis="both", direction="in", bottom=True, top=True, left=True, right=True
    )

    ax.legend()

    return fig, ax
79
+
80
+
81
def plot_nonint_bandstructure(
    bands,
    k_point_list,
    ticks,
    labels,
    overlaps: npt.NDArray | None = None,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
):
    """Plot non-interacting bands, optionally coloured by an orbital weight.

    Parameters
    ----------
    bands : sequence of array_like
        Band energies along the path, one entry per band.
    k_point_list : array_like
        x coordinates of the path (e.g. the normalised path position).
    ticks, labels : sequence
        x tick positions and their labels.
    overlaps : sequence of array_like, optional
        Per-band weight in [-1, 1] used to colour each band with the
        ``seismic`` colormap. The colour bar ends are labelled ``w_Gr1``
        (-1) and ``w_X`` (+1). If omitted, the bands are drawn as plain lines.
    fig, ax : optional
        Target figure/axes. If only ``ax`` is given, its figure is used. If
        only ``fig`` is given, its current axes are used. If neither is given,
        a new figure is created.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Previously a given `ax` was discarded when `fig` was None, and a `fig`
    # without `ax` crashed on `None.axhline`.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    ax.axhline(y=0, alpha=0.7, linestyle="--", color="black")

    if overlaps is None:
        for band in bands:
            ax.plot(k_point_list, band)
    else:
        norm = plt.Normalize(-1, 1)  # fixed weight range, shared by all bands
        line = None

        for band, wx in zip(bands, overlaps):
            # Split the band into segments so each can get its own colour.
            points = np.array([k_point_list, band]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            lc = LineCollection(segments, cmap="seismic", norm=norm)
            lc.set_array(wx)
            lc.set_linewidth(2)
            line = ax.add_collection(lc)

        # Without any band there is nothing to attach a colour bar to
        # (fig.colorbar(None) would fail).
        if line is not None:
            colorbar = fig.colorbar(line, fraction=0.046, pad=0.04, ax=ax)
            color_ticks = [-1, 1]
            colorbar.set_ticks(
                ticks=color_ticks, labels=[r"$w_{\mathrm{Gr}_1}$", r"$w_X$"]
            )

    ax.set_box_aspect(1)
    ax.set_xticks(ticks, labels)
    ax.set_yticks(range(-5, 6))
    ax.set_facecolor("lightgray")
    ax.grid(visible=True)
    ax.tick_params(
        axis="both", direction="in", bottom=True, top=True, left=True, right=True
    )

    return fig
126
+
127
+
128
+ def _generate_part_of_path(p_0, p_1, n, length_whole_path):
129
+ distance = np.linalg.norm(p_1 - p_0)
130
+ number_of_points = int(n * distance / length_whole_path) + 1
131
+
132
+ k_space_path = np.vstack(
133
+ [
134
+ np.linspace(p_0[0], p_1[0], number_of_points),
135
+ np.linspace(p_0[1], p_1[1], number_of_points),
136
+ ]
137
+ ).T[:-1]
138
+
139
+ return k_space_path
140
+
141
+
142
def generate_bz_path(points=None, number_of_points=1000):
    """Build a closed path through the given Brillouin-zone points.

    The path visits the points in order and returns from the last point to
    the first. Samples are distributed over the segments in proportion to
    their length.

    Parameters
    ----------
    points : sequence of (coordinates, label)
        At least two path corners. ``coordinates`` must support subtraction
        and ``numpy.linalg.norm`` (e.g. 2-D numpy arrays). ``label`` is
        wrapped in ``$...$`` for the tick labels.
    number_of_points : int
        Approximate total number of samples along the path.

    Returns
    -------
    whole_path : numpy.ndarray
        The sampled k-points, shape ``(M, 2)``.
    whole_path_plot : numpy.ndarray
        Normalised position along the path for each sample, in ``[0, 1)``.
    ticks : list of float
        Normalised positions of the corners, starting at 0 and ending at 1
        (the return to the first corner).
    labels : list of str
        One LaTeX label per tick; the first label is repeated at the end.

    Raises
    ------
    ValueError
        If fewer than two points are given or all points coincide.
    """
    if points is None or len(points) < 2:
        raise ValueError("generate_bz_path needs at least two points.")

    n = number_of_points
    corners = [point[0] for point in points]
    # Closed cycle: corner i -> corner i + 1, wrapping the last back to the first.
    closed = corners + corners[:1]

    cycle = [np.linalg.norm(closed[i] - closed[i + 1]) for i in range(len(corners))]
    length_whole_path = np.sum(cycle)
    if length_whole_path == 0:
        raise ValueError("generate_bz_path needs points that do not all coincide.")

    ticks = [0]
    ticks.extend(
        np.sum(cycle[0 : i + 1]) / length_whole_path for i in range(len(cycle) - 1)
    )
    ticks.append(1)

    labels = [rf"${point[1]}$" for point in points]
    labels.append(rf"${points[0][1]}$")

    # Plot coordinate for every sample; each segment gets the same number of
    # samples as the corresponding k-space segment below.
    whole_path_plot = np.concatenate(
        [
            np.linspace(
                ticks[i],
                ticks[i + 1],
                num=int(n * cycle[i] / length_whole_path),
                endpoint=False,
            )
            for i in range(len(ticks) - 1)
        ]
    )

    whole_path = np.concatenate(
        [
            _generate_part_of_path(closed[i], closed[i + 1], n, length_whole_path)
            for i in range(len(corners))
        ]
    )

    return whole_path, whole_path_plot, ticks, labels