quant-met 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quant_met/__init__.py +2 -4
 - quant_met/mean_field/__init__.py +61 -0
 - quant_met/mean_field/_utils.py +27 -0
 - quant_met/mean_field/base_hamiltonian.py +310 -0
 - quant_met/mean_field/eg_x.py +173 -0
 - quant_met/mean_field/free_energy.py +130 -0
 - quant_met/mean_field/graphene.py +142 -0
 - quant_met/mean_field/quantum_metric.py +108 -0
 - quant_met/mean_field/superfluid_weight.py +146 -0
 - quant_met/plotting/__init__.py +22 -1
 - quant_met/plotting/plotting.py +230 -0
 - quant_met/utils.py +45 -2
 - quant_met-0.0.4.dist-info/LICENSES/MIT.txt +9 -0
 - {quant_met-0.0.2.dist-info → quant_met-0.0.4.dist-info}/METADATA +11 -7
 - quant_met-0.0.4.dist-info/RECORD +17 -0
 - quant_met/hamiltonians/__init__.py +0 -14
 - quant_met/hamiltonians/_base_hamiltonian.py +0 -172
 - quant_met/hamiltonians/_eg_x.py +0 -124
 - quant_met/hamiltonians/_free_energy.py +0 -39
 - quant_met/hamiltonians/_graphene.py +0 -93
 - quant_met/hamiltonians/_superfluid_weight.py +0 -130
 - quant_met/hamiltonians/_utils.py +0 -10
 - quant_met/plotting/_plotting.py +0 -156
 - quant_met-0.0.2.dist-info/RECORD +0 -15
 - {quant_met-0.0.2.dist-info → quant_met-0.0.4.dist-info}/LICENSE.txt +0 -0
 - {quant_met-0.0.2.dist-info → quant_met-0.0.4.dist-info}/WHEEL +0 -0
 
| 
         @@ -0,0 +1,142 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # SPDX-FileCopyrightText: 2024 Tjark Sievers
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # SPDX-License-Identifier: MIT
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            """Provides the implementation for Graphene."""
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            from typing import Any
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 10 
     | 
    
         
            +
            import numpy.typing as npt
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            from ._utils import _check_valid_array, _validate_float
         
     | 
| 
      
 13 
     | 
    
         
            +
            from .base_hamiltonian import BaseHamiltonian
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            class GrapheneHamiltonian(BaseHamiltonian):
                """Two-band nearest-neighbour tight-binding Hamiltonian for graphene."""

                def __init__(
                    self,
                    t_nn: float,
                    a: float,
                    mu: float,
                    coulomb_gr: float,
                    delta: npt.NDArray[np.float64] | None = None,
                    *args: tuple[Any, ...],
                    **kwargs: tuple[dict[str, Any], ...],
                ) -> None:
                    """Set up the graphene model parameters.

                    Parameters
                    ----------
                    t_nn : float
                        Nearest-neighbour hopping amplitude.
                    a : float
                        Lattice constant; must be positive.
                    mu : float
                        Chemical potential, subtracted on both diagonal entries.
                    coulomb_gr : float
                        Interaction strength, used for both orbitals.
                    delta : :class:`numpy.ndarray`, optional
                        Initial order parameter in orbital basis. Defaults to
                        ``np.zeros(2)``; a given array is stored as-is (no copy).
                    *args, **kwargs
                        Accepted and discarded.

                    Raises
                    ------
                    ValueError
                        If ``a <= 0``. ``_validate_float`` may raise for other
                        invalid parameter values.
                    """
                    del args, kwargs  # extra arguments are intentionally ignored

                    self.t_nn = _validate_float(t_nn, "Hopping")
                    if a <= 0:
                        msg = "Lattice constant must be positive"
                        raise ValueError(msg)
                    self.a = _validate_float(a, "Lattice constant")
                    self.mu = _validate_float(mu, "Chemical potential")
                    self.coulomb_gr = _validate_float(coulomb_gr, "Coloumb interaction")
                    # Same interaction strength on each of the two orbitals.
                    self._coloumb_orbital_basis = np.array([self.coulomb_gr, self.coulomb_gr])
                    self._number_of_bands = 2
                    self._delta_orbital_basis = np.zeros(2) if delta is None else delta

                @property
                def number_of_bands(self) -> int:
                    """Number of bands (orbitals) of the model."""
                    return self._number_of_bands

                @property
                def coloumb_orbital_basis(self) -> npt.NDArray[np.float64]:
                    """Interaction strength per orbital."""
                    return self._coloumb_orbital_basis

                @property
                def delta_orbital_basis(self) -> npt.NDArray[np.float64]:
                    """Order parameter in orbital basis."""
                    return self._delta_orbital_basis

                @delta_orbital_basis.setter
                def delta_orbital_basis(self, new_delta: npt.NDArray[np.float64]) -> None:
                    self._delta_orbital_basis = new_delta

                def hamiltonian(self, k: npt.NDArray[np.float64]) -> npt.NDArray[np.complex64]:
                    """
                    Return the normal-state Hamiltonian in orbital basis.

                    Parameters
                    ----------
                    k : :class:`numpy.ndarray`
                        A single k point ``(kx, ky)`` as a 1-D array, or several
                        k points with one ``(kx, ky)`` pair per row.

                    Returns
                    -------
                    :class:`numpy.ndarray`
                        Complex64 array of shape ``(n_k, 2, 2)`` with singleton
                        axes squeezed (a single k point gives a ``(2, 2)`` matrix).

                    """
                    assert _check_valid_array(k)
                    hopping, lattice_const, chem_pot = self.t_nn, self.a, self.mu
                    # Promote a single k point to a one-row batch of k points.
                    k_points = k[np.newaxis, :] if k.ndim == 1 else k
                    kx, ky = k_points[:, 0], k_points[:, 1]
                    size = self.number_of_bands
                    matrix = np.zeros((k_points.shape[0], size, size), dtype=np.complex64)

                    # Off-diagonal hopping between the two orbitals; the lower
                    # triangle is the complex conjugate so the matrix is Hermitian.
                    matrix[:, 0, 1] = -hopping * (
                        np.exp(1j * ky * lattice_const / np.sqrt(3))
                        + 2 * np.exp(-0.5j * lattice_const / np.sqrt(3) * ky) * (np.cos(0.5 * lattice_const * kx))
                    )
                    matrix[:, 1, 0] = matrix[:, 0, 1].conjugate()
                    for orbital in (0, 1):
                        matrix[:, orbital, orbital] -= chem_pot

                    return matrix.squeeze()

                def hamiltonian_derivative(
                    self, k: npt.NDArray[np.float64], direction: str
                ) -> npt.NDArray[np.complex64]:
                    """
                    Return the derivative of the Hamiltonian with respect to k.

                    Parameters
                    ----------
                    k : :class:`numpy.ndarray`
                        A single k point ``(kx, ky)`` as a 1-D array, or several
                        k points with one ``(kx, ky)`` pair per row.
                    direction : str
                        Component to differentiate by, either ``'x'`` or ``'y'``.

                    Returns
                    -------
                    :class:`numpy.ndarray`
                        Complex64 derivative matrices of shape ``(n_k, 2, 2)`` with
                        singleton axes squeezed.

                    """
                    assert _check_valid_array(k)
                    assert direction in ["x", "y"]

                    hopping, lattice_const = self.t_nn, self.a
                    k_points = k[np.newaxis, :] if k.ndim == 1 else k
                    kx, ky = k_points[:, 0], k_points[:, 1]
                    size = self.number_of_bands
                    derivative = np.zeros((k_points.shape[0], size, size), dtype=np.complex64)

                    if direction == "x":
                        upper = (
                            hopping * lattice_const * np.exp(-0.5j * lattice_const / np.sqrt(3) * ky) * np.sin(0.5 * lattice_const * kx)
                        )
                    else:
                        upper = (
                            -hopping
                            * 1j
                            * lattice_const
                            / np.sqrt(3)
                            * (
                                np.exp(1j * lattice_const / np.sqrt(3) * ky)
                                - np.exp(-0.5j * lattice_const / np.sqrt(3) * ky) * np.cos(0.5 * lattice_const * kx)
                            )
                        )
                    derivative[:, 0, 1] = upper
                    # k is real, so the derivative of the Hermitian Hamiltonian is Hermitian too.
                    derivative[:, 1, 0] = derivative[:, 0, 1].conjugate()

                    return derivative.squeeze()
         
     | 
| 
         @@ -0,0 +1,108 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # SPDX-FileCopyrightText: 2024 Tjark Sievers
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # SPDX-License-Identifier: MIT
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            """Functions to calculate the quantum metric."""
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 8 
     | 
    
         
            +
            import numpy.typing as npt
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            from .base_hamiltonian import BaseHamiltonian
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def quantum_metric(
                h: BaseHamiltonian, k_grid: npt.NDArray[np.float64], band: int
            ) -> npt.NDArray[np.float64]:
                """Calculate the quantum metric in the normal state.

                For the selected band this evaluates

                    g_ij = Re sum_k sum_{n != band}
                           <band|dH/dk_i|n> <n|dH/dk_j|band> / (E_band - E_n)^2 / N_k

                with eigenstates and energies from ``h.diagonalize_nonint`` and the
                derivatives from ``h.hamiltonian_derivative``.

                Parameters
                ----------
                h : :class:`~quant_met.BaseHamiltonian`
                    Hamiltonian object.
                k_grid : :class:`numpy.ndarray`
                    List of k points.
                band : int
                    Index of band for which the quantum metric is calculated.

                Returns
                -------
                :class:`numpy.ndarray`
                    2x2 float64 array (rows/columns ordered x, y) holding the quantum
                    metric in the normal state.

                Notes
                -----
                If ``band`` is degenerate with another band at some k point, the
                denominator ``(E_band - E_n)**2`` is zero there.

                """
                energies, bloch = h.diagonalize_nonint(k_grid)

                number_k_points = len(k_grid)

                # The derivative only depends on the direction, so compute it once per
                # direction instead of once per (i, j) pair.
                h_derivatives = [
                    h.hamiltonian_derivative(k=k_grid, direction=direction) for direction in ("x", "y")
                ]
                other_bands = [n for n in range(h.number_of_bands) if n != band]

                # Accumulate in double precision: summing many k-point contributions in
                # complex64 loses accuracy, and the declared return type is float64.
                quantum_geom_tensor = np.zeros(shape=(2, 2), dtype=np.complex128)

                for i, h_derivative_direction_1 in enumerate(h_derivatives):
                    for j, h_derivative_direction_2 in enumerate(h_derivatives):
                        for energies_k, bloch_k, dh_1_k, dh_2_k in zip(
                            energies, bloch, h_derivative_direction_1, h_derivative_direction_2, strict=True
                        ):
                            state_band = bloch_k[:, band]
                            for n in other_bands:
                                state_n = bloch_k[:, n]
                                quantum_geom_tensor[i, j] += (
                                    (state_band.conjugate() @ dh_1_k @ state_n)
                                    * (state_n.conjugate() @ dh_2_k @ state_band)
                                    / (energies_k[band] - energies_k[n]) ** 2
                                )

                return np.real(quantum_geom_tensor) / number_k_points
         
     | 
| 
      
 60 
     | 
    
         
            +
             
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
            def quantum_metric_bdg(
                h: BaseHamiltonian, k_grid: npt.NDArray[np.float64], band: int
            ) -> npt.NDArray[np.float64]:
                """Calculate the quantum metric in the BdG state.

                For the selected BdG band this evaluates

                    g_ij = Re sum_k sum_{n != band}
                           <band|dH/dk_i|n> <n|dH/dk_j|band> / (E_band - E_n)^2 / N_k

                with eigenstates and energies from ``h.diagonalize_bdg`` and the
                derivatives from ``h.bdg_hamiltonian_derivative``. The sum over ``n``
                runs over every eigenstate returned by ``h.diagonalize_bdg`` at that
                k point.

                Parameters
                ----------
                h : :class:`~quant_met.BaseHamiltonian`
                    Hamiltonian object.
                k_grid : :class:`numpy.ndarray`
                    List of k points.
                band : int
                    Index of band for which the quantum metric is calculated.

                Returns
                -------
                :class:`numpy.ndarray`
                    2x2 float64 array (rows/columns ordered x, y) holding the quantum
                    metric in the BdG state.

                Notes
                -----
                If ``band`` is degenerate with another BdG state at some k point, the
                denominator ``(E_band - E_n)**2`` is zero there.

                """
                energies, bdg_functions = h.diagonalize_bdg(k_grid)

                number_k_points = len(k_grid)

                # The derivative only depends on the direction, so compute it once per
                # direction instead of once per (i, j) pair.
                h_derivatives = [
                    h.bdg_hamiltonian_derivative(k=k_grid, direction=direction) for direction in ("x", "y")
                ]

                # Accumulate in double precision: summing many k-point contributions in
                # complex64 loses accuracy, and the declared return type is float64.
                quantum_geom_tensor = np.zeros(shape=(2, 2), dtype=np.complex128)

                for i, h_derivative_direction_1 in enumerate(h_derivatives):
                    for j, h_derivative_direction_2 in enumerate(h_derivatives):
                        for energies_k, bdg_functions_k, dh_1_k, dh_2_k in zip(
                            energies, bdg_functions, h_derivative_direction_1, h_derivative_direction_2, strict=True
                        ):
                            state_band = bdg_functions_k[:, band]
                            # Sum over all other returned BdG eigenstates. The BdG matrix
                            # presumably has 2 * h.number_of_bands states (particle and hole
                            # sectors), so bounding by h.number_of_bands would truncate the sum.
                            for n in range(len(energies_k)):
                                if n == band:
                                    continue
                                state_n = bdg_functions_k[:, n]
                                quantum_geom_tensor[i, j] += (
                                    (state_band.conjugate() @ dh_1_k @ state_n)
                                    * (state_n.conjugate() @ dh_2_k @ state_band)
                                    / (energies_k[band] - energies_k[n]) ** 2
                                )

                return np.real(quantum_geom_tensor) / number_k_points
         
     | 
| 
         @@ -0,0 +1,146 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # SPDX-FileCopyrightText: 2024 Tjark Sievers
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # SPDX-License-Identifier: MIT
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            """Functions to calculate the superfluid weight."""
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 8 
     | 
    
         
            +
            import numpy.typing as npt
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            from .base_hamiltonian import BaseHamiltonian
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def superfluid_weight(
                h: BaseHamiltonian,
                k_grid: npt.NDArray[np.float64],
            ) -> tuple[npt.NDArray[np.complex64], npt.NDArray[np.complex64]]:
                """Calculate the superfluid weight.

                The weight is split by band indices: terms with ``m == n`` and
                ``p == q`` go into the conventional part, all remaining band
                combinations go into the geometric part.

                Parameters
                ----------
                h : :class:`~quant_met.mean_field.base_hamiltonian.BaseHamiltonian`
                    Hamiltonian.
                k_grid : :class:`numpy.ndarray`
                    List of k points.

                Returns
                -------
                :class:`numpy.ndarray`
                    Conventional contribution to the superfluid weight, a 2x2 array whose
                    rows/columns correspond to the directions ``"x"`` and ``"y"``.
                :class:`numpy.ndarray`
                    Geometric contribution to the superfluid weight, same layout.

                """
                directions = ("x", "y")
                s_weight_conv = np.zeros(shape=(2, 2), dtype=np.complex64)
                s_weight_geom = np.zeros(shape=(2, 2), dtype=np.complex64)

                for k in k_grid:
                    # The C factor depends only on k, not on the direction pair, so it is
                    # computed once per k point instead of once per (i, j) combination.
                    c_mnpq = _c_factor(h, k)
                    # The first direction index uses j(k), the second uses j(-k).
                    j_up = [_current_operator(h, direction, k) for direction in directions]
                    j_down = [_current_operator(h, direction, -k) for direction in directions]

                    for i, j in np.ndindex(len(directions), len(directions)):
                        for m, n, p, q in np.ndindex(*(h.number_of_bands,) * 4):
                            s_weight = c_mnpq[m, n, p, q] * j_up[i][m, n] * j_down[j][q, p]
                            if m == n and p == q:
                                s_weight_conv[i, j] += s_weight
                            else:
                                s_weight_geom[i, j] += s_weight

                return s_weight_conv, s_weight_geom
         
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
             
     | 
| 
      
 56 
     | 
    
         
            +
            def _current_operator(
                h: BaseHamiltonian, direction: str, k: npt.NDArray[np.float64]
            ) -> npt.NDArray[np.complex64]:
                """Current operator in the band basis of the non-interacting Hamiltonian.

                Element ``[m, n]`` is ``conj(bloch[:, m]) @ dH @ bloch[:, n]``, where
                ``bloch`` holds the eigenvectors returned by ``h.diagonalize_nonint`` and
                ``dH`` is ``h.hamiltonian_derivative(direction=direction, k=k)``.

                Parameters
                ----------
                h : :class:`~quant_met.mean_field.base_hamiltonian.BaseHamiltonian`
                    Hamiltonian.
                direction : str
                    Direction of the derivative, passed through to
                    ``h.hamiltonian_derivative`` (``"x"`` or ``"y"`` in this module).
                k : :class:`numpy.ndarray`
                    k point.

                Returns
                -------
                :class:`numpy.ndarray`
                    ``(h.number_of_bands, h.number_of_bands)`` complex64 array.

                """
                _, bloch = h.diagonalize_nonint(k=k)
                # Only the first number_of_bands eigenvectors enter the operator.
                bands = bloch[:, : h.number_of_bands]

                # The derivative does not depend on the band indices; evaluate it once
                # instead of once per (m, n) pair.
                h_derivative = h.hamiltonian_derivative(direction=direction, k=k)

                # Matrix form of bloch[:, m]^dagger @ dH @ bloch[:, n] for all m, n.
                j = bands.conjugate().T @ h_derivative @ bands

                return np.asarray(j).astype(np.complex64)
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
            def _w_matrix(
                h: BaseHamiltonian, k: npt.NDArray[np.float64]
            ) -> tuple[npt.NDArray[np.complex64], npt.NDArray[np.complex64]]:
                """Overlaps of BdG eigenvectors with particle/hole-embedded Bloch vectors.

                For BdG state ``i`` and band ``m``:

                - ``w_plus[i, m]`` is ``(bloch[:, m] (x) [1, 0]) @ bdg_functions[:, i]``
                - ``w_minus[i, m]`` is ``(conj(bloch[:, m]) (x) [0, 1]) @ bdg_functions[:, i]``

                Here ``(x)`` is the flattened outer product, so the Bloch components land
                on the even (``w_plus``) or odd (``w_minus``) entries. ``bdg_functions``
                itself is not conjugated.

                NOTE(review): this layout presumes the BdG basis of BaseHamiltonian keeps
                the two Nambu components of each orbital adjacent — confirm.

                Returns
                -------
                tuple of :class:`numpy.ndarray`
                    ``(w_plus, w_minus)``, each of shape
                    ``(2 * h.number_of_bands, h.number_of_bands)``.

                """
                _, bloch = h.diagonalize_nonint(k=k)
                _, bdg_functions = h.diagonalize_bdg(k=k)

                n_bands = h.number_of_bands
                n_states = 2 * n_bands
                particle_selector = np.array([1, 0])
                hole_selector = np.array([0, 1])

                w_plus = np.zeros((n_states, n_bands), dtype=np.complex64)
                w_minus = np.zeros((n_states, n_bands), dtype=np.complex64)

                for band in range(n_bands):
                    # The embedded vectors depend only on the band; build them once per band.
                    particle_vector = np.tensordot(
                        bloch[:, band], particle_selector, axes=0
                    ).reshape(-1)
                    hole_vector = np.tensordot(
                        bloch[:, band].conjugate(), hole_selector, axes=0
                    ).reshape(-1)
                    for state in range(n_states):
                        w_plus[state, band] = particle_vector @ bdg_functions[:, state]
                        w_minus[state, band] = hole_vector @ bdg_functions[:, state]

                return w_plus, w_minus
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
             
     | 
| 
      
 99 
     | 
    
         
            +
            def _c_factor(h: BaseHamiltonian, k: npt.NDArray[np.float64]) -> npt.NDArray[np.complex64]:
                """Band-resolved factor ``C[m, n, p, q]`` entering the superfluid weight.

                C[m, n, p, q] = 2 * sum_{i, j} F[i, j] * conj(w_minus[i, m]) * w_plus[j, n]
                                              * conj(w_minus[j, p]) * w_minus[i, q]

                ``w_plus`` and ``w_minus`` come from ``_w_matrix``. The energy factor is

                F[i, j] = (f(E_i) - f(E_j)) / (E_j - E_i)   if E_i != E_j
                F[i, j] = -f'                               otherwise,

                where ``E`` are the BdG energies, ``f`` is ``_fermi_dirac`` and ``f'`` is
                ``_fermi_dirac_derivative``.

                NOTE(review): published forms of this expression (e.g. Liang et al.,
                PRB 95, 024515) are usually written with ``conj(w_plus[i, m])`` as the
                first factor; this code uses ``conj(w_minus[i, m])`` — confirm.

                Returns
                -------
                :class:`numpy.ndarray`
                    complex64 array of shape ``(N, N, N, N)`` with ``N = h.number_of_bands``.

                """
                bdg_energies, _ = h.diagonalize_bdg(k)
                w_plus, w_minus = _w_matrix(h, k)
                n_states = 2 * h.number_of_bands

                # The energy factor does not depend on the band indices m, n, p, q,
                # so compute it once for all (i, j) pairs.
                energy_factor = np.empty((n_states, n_states), dtype=np.complex128)
                for i, j in np.ndindex(n_states, n_states):
                    if bdg_energies[i] != bdg_energies[j]:
                        energy_factor[i, j] = (
                            _fermi_dirac(bdg_energies[i]) - _fermi_dirac(bdg_energies[j])
                        ) / (bdg_energies[j] - bdg_energies[i])
                    else:
                        # Degenerate limit of the difference quotient above.
                        energy_factor[i, j] = -_fermi_dirac_derivative()

                # Each (i, j) term is energy factor times w-product, and the terms are
                # summed. The previous loop multiplied the running sum by each
                # w-product instead of multiplying each term.
                c_mnpq = 2 * np.einsum(
                    "ij,im,jn,jp,iq->mnpq",
                    energy_factor,
                    w_minus.conjugate(),
                    w_plus,
                    w_minus.conjugate(),
                    w_minus,
                    optimize=True,
                )

                return c_mnpq.astype(np.complex64)
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
            def _fermi_dirac_derivative() -> float:
                """Return the derivative of the occupation used by ``_fermi_dirac``.

                This is a constant placeholder: it always returns ``0``. That matches the
                step-function occupation everywhere except exactly at zero energy.
                """
                derivative = 0
                return derivative
         
     | 
| 
      
 140 
     | 
    
         
            +
             
     | 
| 
      
 141 
     | 
    
         
            +
             
     | 
| 
      
 142 
     | 
    
         
            +
            def _fermi_dirac(energy: np.float64) -> np.float64:
                """Zero-temperature occupation of a state with the given energy.

                Positive energies are empty (``0``). Everything else, including
                ``energy == 0``, is occupied (``1``). The result is an ``np.float64``.
                """
                occupation = 0 if energy > 0 else 1
                return np.float64(occupation)
         
     | 
    
        quant_met/plotting/__init__.py
    CHANGED
    
    | 
         @@ -1,4 +1,25 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
             
     | 
| 
      
 1 
     | 
    
         
            +
            # SPDX-FileCopyrightText: 2024 Tjark Sievers
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # SPDX-License-Identifier: MIT
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            """
         
     | 
| 
      
 6 
     | 
    
         
            +
            Plotting
         
     | 
| 
      
 7 
     | 
    
         
            +
            ========
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            .. currentmodule:: quant_met.plotting
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            Functions
         
     | 
| 
      
 12 
     | 
    
         
            +
            ---------
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            .. autosummary::
         
     | 
| 
      
 15 
     | 
    
         
            +
               :toctree: generated/
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
                scatter_into_bz
         
     | 
| 
      
 18 
     | 
    
         
            +
                plot_bandstructure
         
     | 
| 
      
 19 
     | 
    
         
            +
                generate_bz_path
         
     | 
| 
      
 20 
     | 
    
         
            +
            """  # noqa: D205, D400
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
            from .plotting import generate_bz_path, plot_bandstructure, scatter_into_bz
         
     | 
| 
       2 
23 
     | 
    
         | 
| 
       3 
24 
     | 
    
         
             
            __all__ = [
         
     | 
| 
       4 
25 
     | 
    
         
             
                "scatter_into_bz",
         
     | 
| 
         @@ -0,0 +1,230 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # SPDX-FileCopyrightText: 2024 Tjark Sievers
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # SPDX-License-Identifier: MIT
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            """Methods for plotting data."""
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            from typing import Any
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            import matplotlib.axes
         
     | 
| 
      
 10 
     | 
    
         
            +
            import matplotlib.colors
         
     | 
| 
      
 11 
     | 
    
         
            +
            import matplotlib.figure
         
     | 
| 
      
 12 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
      
 13 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 14 
     | 
    
         
            +
            import numpy.typing as npt
         
     | 
| 
      
 15 
     | 
    
         
            +
            from matplotlib.collections import Collection, LineCollection
         
     | 
| 
      
 16 
     | 
    
         
            +
            from numpy import dtype, generic, ndarray
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            def scatter_into_bz(
                bz_corners: list[npt.NDArray[np.float64]],
                k_points: npt.NDArray[np.float64],
                data: npt.NDArray[np.float64] | None = None,
                data_label: str | None = None,
                fig_in: matplotlib.figure.Figure | None = None,
                ax_in: matplotlib.axes.Axes | None = None,
            ) -> matplotlib.figure.Figure:
                """Scatter a list of points into the Brillouin zone.

                Parameters
                ----------
                bz_corners : list[:class:`numpy.ndarray`]
                    Corner points defining the Brillouin zone, as ``(k_x, k_y)`` pairs.
                k_points : :class:`numpy.ndarray`
                    List of k points, as ``(k_x, k_y)`` pairs.
                data : :class:`numpy.ndarray`, optional
                    Data to put on the k points. If given, the points are colored by it
                    (``viridis`` colormap) and a colorbar is added.
                data_label : :class:`str`, optional
                    Label for the colorbar of the data.
                fig_in : :class:`matplotlib.figure.Figure`, optional
                    Figure that holds the axes. A new figure and ax are created unless
                    both ``fig_in`` and ``ax_in`` are provided.
                ax_in : :class:`matplotlib.axes.Axes`, optional
                    Ax to plot the data in. A new figure and ax are created unless both
                    ``fig_in`` and ``ax_in`` are provided.

                Returns
                -------
                :obj:`matplotlib.figure.Figure`
                    Figure with the data plotted onto the axis.

                """
                if fig_in is None or ax_in is None:
                    fig, ax = plt.subplots()
                else:
                    fig, ax = fig_in, ax_in

                # Coordinates are needed whether or not the points are colored by data,
                # so split them once here.
                x_coords, y_coords = zip(*k_points, strict=True)
                if data is None:
                    ax.scatter(x=x_coords, y=y_coords)
                else:
                    scatter = ax.scatter(x=x_coords, y=y_coords, c=data, cmap="viridis")
                    fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.04, label=data_label)

                # Mark the Brillouin-zone corners on top of the k points.
                bz_corners_x, bz_corners_y = zip(*bz_corners, strict=True)
                ax.scatter(x=bz_corners_x, y=bz_corners_y, alpha=0.8)
                ax.set_aspect("equal", adjustable="box")
                ax.set_xlabel(r"$k_x\ [1/a_0]$")
                ax.set_ylabel(r"$k_y\ [1/a_0]$")

                return fig
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
             
     | 
| 
      
 72 
     | 
    
         
            +
            def plot_bandstructure(
         
     | 
| 
      
 73 
     | 
    
         
            +
                bands: npt.NDArray[np.float64],
         
     | 
| 
      
 74 
     | 
    
         
            +
                k_point_list: npt.NDArray[np.float64],
         
     | 
| 
      
 75 
     | 
    
         
            +
                ticks: list[float],
         
     | 
| 
      
 76 
     | 
    
         
            +
                labels: list[str],
         
     | 
| 
      
 77 
     | 
    
         
            +
                overlaps: npt.NDArray[np.float64] | None = None,
         
     | 
| 
      
 78 
     | 
    
         
            +
                overlap_labels: list[str] | None = None,
         
     | 
| 
      
 79 
     | 
    
         
            +
                fig_in: matplotlib.figure.Figure | None = None,
         
     | 
| 
      
 80 
     | 
    
         
            +
                ax_in: matplotlib.axes.Axes | None = None,
         
     | 
| 
      
 81 
     | 
    
         
            +
            ) -> matplotlib.figure.Figure:
         
     | 
| 
      
 82 
     | 
    
         
            +
                """Plot bands along a k space path.
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
                To have a plot that respects the distances in k space and generate everything that is needed for
         
     | 
| 
      
 85 
     | 
    
         
            +
                plotting, use the function :func:`~quant_met.plotting.generate_bz_path`.
         
     | 
| 
      
 86 
     | 
    
         
            +
             
     | 
| 
      
 87 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 88 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 89 
     | 
    
         
            +
                bands : :class:`numpy.ndarray`
         
     | 
| 
      
 90 
     | 
    
         
            +
                k_point_list : :class:`numpy.ndarray`
         
     | 
| 
      
 91 
     | 
    
         
            +
                    List of points to plot against. This is not a list of two-dimensional k-points!
         
     | 
| 
      
 92 
     | 
    
         
            +
                ticks : list(float)
         
     | 
| 
      
 93 
     | 
    
         
            +
                    Position for ticks.
         
     | 
| 
      
 94 
     | 
    
         
            +
                labels : list(str)
         
     | 
| 
      
 95 
     | 
    
         
            +
                    Labels on ticks.
         
     | 
| 
      
 96 
     | 
    
         
            +
                overlaps : :class:`numpy.ndarray`, optional
         
     | 
| 
      
 97 
     | 
    
         
            +
                    Overlaps.
         
     | 
| 
      
 98 
     | 
    
         
            +
                overlap_labels : list(str), optional
         
     | 
| 
      
 99 
     | 
    
         
            +
                    Labels to put on overlaps.
         
     | 
| 
      
 100 
     | 
    
         
            +
                fig_in : :class:`matplotlib.figure.Figure`, optional
         
     | 
| 
      
 101 
     | 
    
         
            +
                    Figure that holds the axes. If not provided, a new figure and ax is created.
         
     | 
| 
      
 102 
     | 
    
         
            +
                ax_in : :class:`matplotlib.axes.Axes`, optional
         
     | 
| 
      
 103 
     | 
    
         
            +
                    Ax to plot the data in. If not provided, a new figure and ax is created.
         
     | 
| 
      
 104 
     | 
    
         
            +
             
     | 
| 
      
 105 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 106 
     | 
    
         
            +
                -------
         
     | 
| 
      
 107 
     | 
    
         
            +
                :obj:`matplotlib.figure.Figure`
         
     | 
| 
      
 108 
     | 
    
         
            +
                    Figure with the data plotted onto the axis.
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
             
     | 
| 
      
 111 
     | 
    
         
            +
                """
         
     | 
| 
      
 112 
     | 
    
         
            +
                if fig_in is None or ax_in is None:
         
     | 
| 
      
 113 
     | 
    
         
            +
                    fig, ax = plt.subplots()
         
     | 
| 
      
 114 
     | 
    
         
            +
                else:
         
     | 
| 
      
 115 
     | 
    
         
            +
                    fig, ax = fig_in, ax_in
         
     | 
| 
      
 116 
     | 
    
         
            +
             
     | 
| 
      
 117 
     | 
    
         
            +
                ax.axhline(y=0, alpha=0.7, linestyle="--", color="black")
         
     | 
| 
      
 118 
     | 
    
         
            +
             
     | 
| 
      
 119 
     | 
    
         
            +
                if overlaps is None:
         
     | 
| 
      
 120 
     | 
    
         
            +
                    for band in bands:
         
     | 
| 
      
 121 
     | 
    
         
            +
                        ax.plot(k_point_list, band)
         
     | 
| 
      
 122 
     | 
    
         
            +
                else:
         
     | 
| 
      
 123 
     | 
    
         
            +
                    line = Collection()
         
     | 
| 
      
 124 
     | 
    
         
            +
                    for band, wx in zip(bands, overlaps, strict=True):
         
     | 
| 
      
 125 
     | 
    
         
            +
                        points = np.array([k_point_list, band]).T.reshape(-1, 1, 2)
         
     | 
| 
      
 126 
     | 
    
         
            +
                        segments = np.concatenate([points[:-1], points[1:]], axis=1)
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
      
 128 
     | 
    
         
            +
                        norm = matplotlib.colors.Normalize(-1, 1)
         
     | 
| 
      
 129 
     | 
    
         
            +
                        lc = LineCollection(segments, cmap="seismic", norm=norm)
         
     | 
| 
      
 130 
     | 
    
         
            +
                        lc.set_array(wx)
         
     | 
| 
      
 131 
     | 
    
         
            +
                        lc.set_linewidth(2)
         
     | 
| 
      
 132 
     | 
    
         
            +
                        line = ax.add_collection(lc)
         
     | 
| 
      
 133 
     | 
    
         
            +
             
     | 
| 
      
 134 
     | 
    
         
            +
                    colorbar = fig.colorbar(line, fraction=0.046, pad=0.04, ax=ax)
         
     | 
| 
      
 135 
     | 
    
         
            +
                    color_ticks = [-1, 1]
         
     | 
| 
      
 136 
     | 
    
         
            +
                    colorbar.set_ticks(ticks=color_ticks, labels=overlap_labels)
         
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
                ax.set_ylim(
         
     | 
| 
      
 139 
     | 
    
         
            +
                    top=float(np.max(bands) + 0.1 * np.max(bands)),
         
     | 
| 
      
 140 
     | 
    
         
            +
                    bottom=float(np.min(bands) - 0.1 * np.abs(np.min(bands))),
         
     | 
| 
      
 141 
     | 
    
         
            +
                )
         
     | 
| 
      
 142 
     | 
    
         
            +
                ax.set_box_aspect(1)
         
     | 
| 
      
 143 
     | 
    
         
            +
                ax.set_xticks(ticks, labels)
         
     | 
| 
      
 144 
     | 
    
         
            +
                ax.set_ylabel(r"$E\ [t]$")
         
     | 
| 
      
 145 
     | 
    
         
            +
                ax.set_facecolor("lightgray")
         
     | 
| 
      
 146 
     | 
    
         
            +
                ax.grid(visible=True)
         
     | 
| 
      
 147 
     | 
    
         
            +
                ax.tick_params(axis="both", direction="in", bottom=True, top=True, left=True, right=True)
         
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
                return fig
         
     | 
| 
      
 150 
     | 
    
         
            +
             
     | 
| 
      
 151 
     | 
    
         
            +
             
     | 
| 
      
 152 
     | 
    
         
            +
            def _generate_part_of_path(
                p_0: npt.NDArray[np.float64],
                p_1: npt.NDArray[np.float64],
                n: int,
                length_whole_path: float,
            ) -> npt.NDArray[np.float64]:
                """Sample the straight segment from ``p_0`` towards ``p_1``.

                The segment receives a share of the ``n`` path points proportional to its
                length relative to ``length_whole_path``. The end point ``p_1`` itself is
                excluded, so consecutive segments can be concatenated without duplicating
                the shared vertex.

                Parameters
                ----------
                p_0 : :class:`numpy.ndarray`
                    Start point of the segment.
                p_1 : :class:`numpy.ndarray`
                    End point of the segment; must have the same shape as ``p_0``.
                n : int
                    Number of points along the whole path.
                length_whole_path : float
                    Total length of the whole path; must be positive.

                Returns
                -------
                :class:`numpy.ndarray`
                    Array of shape ``(int(n * |p_1 - p_0| / length_whole_path), d)``
                    holding the sampled points, where ``d`` is the dimension of ``p_0``.

                Raises
                ------
                ValueError
                    If ``length_whole_path`` is not positive (or is NaN).
                """
                # Written as "not > 0" so that NaN is rejected as well.
                if not length_whole_path > 0:
                    msg = "length_whole_path must be positive."
                    raise ValueError(msg)

                distance = np.linalg.norm(np.asarray(p_1) - np.asarray(p_0))
                # One extra sample, because the end point is dropped below.
                number_of_points = int(n * distance / length_whole_path) + 1

                # linspace with array end points interpolates every coordinate at once,
                # giving shape (number_of_points, d); the last row (== p_1) is dropped.
                return np.linspace(p_0, p_1, number_of_points)[:-1]
         
     | 
| 
      
 167 
     | 
    
         
            +
             
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
            def generate_bz_path(
                points: list[tuple[npt.NDArray[np.float64], str]], number_of_points: int = 1000
            ) -> tuple[
                ndarray[Any, dtype[generic | Any]],
                ndarray[Any, dtype[generic | Any]],
                list[int | Any],
                list[str],
            ]:
                """Generate a closed path through high symmetry points.

                The path visits the points in the given order and finally returns to the
                first one. The samples are distributed over the segments proportionally
                to their lengths.

                Parameters
                ----------
                points : list[tuple[:class:`numpy.ndarray`, str]]
                    High symmetry points in visiting order, each given as a pair of its
                    coordinates and its label. Labels are wrapped in ``$...$`` so they are
                    rendered as LaTeX math.
                number_of_points : int
                    Approximate number of points in the whole path. Each segment receives
                    ``int(number_of_points * segment_length / total_length)`` points, so
                    the actual total may be slightly smaller.

                Returns
                -------
                :class:`numpy.ndarray`
                    List of k points along the path (the closing end point is excluded).
                :class:`numpy.ndarray`
                    Path for plotting purposes: points between 0 and 1 with appropriate
                    spacing, one per k point.
                list[float]
                    Positions of the high symmetry points on the plotting path, from 0 to 1.
                list[str]
                    Labels for the ticks; the first label is repeated at the end.

                Raises
                ------
                ValueError
                    If the path has zero total length, e.g. only a single point or only
                    coincident points are given.
                """
                n = number_of_points

                # Consecutive (start, end) pairs, including the closing segment back to
                # the first point.
                segments = list(zip(points, [*points[1:], points[0]]))

                cycle = [np.linalg.norm(start[0] - end[0]) for start, end in segments]
                length_whole_path = np.sum(cycle)
                # Written as "not > 0" so that NaN is rejected as well.
                if not length_whole_path > 0:
                    msg = "The path through the given points must have a non-zero length."
                    raise ValueError(msg)

                # Relative position of every high symmetry point on the plotting axis.
                cumulative_length = np.cumsum(cycle)
                ticks = [0, *(cumulative_length[:-1] / length_whole_path), 1]
                labels = [rf"${point[1]}$" for point in points]
                labels.append(rf"${points[0][1]}$")

                # The per-segment count int(n * length / total) must match the one used in
                # _generate_part_of_path, so that the plotting path and the k points have
                # equal length.
                whole_path_plot = np.concatenate(
                    [
                        np.linspace(
                            tick_start,
                            tick_stop,
                            num=int(n * segment_length / length_whole_path),
                            endpoint=False,
                        )
                        for tick_start, tick_stop, segment_length in zip(ticks, ticks[1:], cycle)
                    ]
                )

                whole_path = np.concatenate(
                    [
                        _generate_part_of_path(start[0], end[0], n, length_whole_path)
                        for start, end in segments
                    ]
                )

                return whole_path, whole_path_plot, ticks, labels
         
     |