NepTrainKit 2.2.1__tar.gz → 2.2.2.dev23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/PKG-INFO +1 -1
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/_version.py +2 -2
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/base/canvas.py +10 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/__init__.py +4 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/completer.py +2 -2
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/dialog.py +75 -1
- neptrainkit-2.2.2.dev23/src/NepTrainKit/core/energy_shift.py +216 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/base.py +3 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/nep.py +11 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/structure.py +7 -2
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/update.py +4 -3
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/nep.py +105 -5
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/toolbar.py +16 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/src_rc.py +323 -110
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/utils.py +43 -1
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/PKG-INFO +1 -1
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/SOURCES.txt +1 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.gitattributes +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.gitignore +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.readthedocs.yml +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/LICENSE +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/MANIFEST.in +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/README.md +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/pyproject.toml +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/requirements.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/setup.cfg +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/setup.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/config.sqlite +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep.json +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep89.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/ptable.json +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/calculator.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/canvas.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/structure.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/canvas.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/structure.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/config.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/card_widget.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/docker.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/doping_rule.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/input.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/label.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/layout.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/search_widget.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/settingscard.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/vacancy_rule.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/select.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/utils.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/message.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/makedata.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/settings.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/show_nep.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/types.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/__init__.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/cards.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/structure.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/main.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/version.py +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/dependency_links.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/entry_points.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/not-zip-safe +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/requires.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/top_level.txt +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/dftd3para.h +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.cpp +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.h +0 -0
- {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep_cpu.cpp +0 -0
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '2.2.
|
|
21
|
-
__version_tuple__ = version_tuple = (2, 2,
|
|
20
|
+
__version__ = version = '2.2.2.dev23'
|
|
21
|
+
__version_tuple__ = version_tuple = (2, 2, 2, 'dev23')
|
|
@@ -156,6 +156,16 @@ class CanvasLayoutBase(CanvasBase):
|
|
|
156
156
|
self.nep_result_data.select(structure_index)
|
|
157
157
|
|
|
158
158
|
self.update_scatter_color(structure_index, Brushes.Selected)
|
|
159
|
+
|
|
160
|
+
def inverse_select(self):
    """Flip the selection over the structures that are currently active.

    The indices in ``nep_result_data.select_index`` are passed to
    ``select_index`` with a second argument of ``True`` (presumably the
    "reverse"/unselect flag — TODO confirm), then every active index
    (``structure.now_indices``) that was not selected is passed with
    ``False``. Nothing happens when no result data is loaded.
    """
    result = self.nep_result_data
    if result is None:
        return
    visible = set(result.structure.now_indices.tolist())
    previously_selected = set(result.select_index)
    newly_selected = visible - previously_selected
    self.select_index(list(previously_selected), True)
    self.select_index(list(newly_selected), False)
|
|
159
169
|
class VispyCanvasLayoutBase(CanvasLayoutBase,QObject,metaclass=CombinedMeta):
|
|
160
170
|
def __init__(self,*args,**kwargs):
|
|
161
171
|
QObject.__init__(self)
|
{neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/__init__.py
RENAMED
|
@@ -9,6 +9,8 @@ from .completer import CompleterModel, JoinDelegate, ConfigCompleter
|
|
|
9
9
|
from .dialog import (
|
|
10
10
|
GetIntMessageBox,
|
|
11
11
|
SparseMessageBox,
|
|
12
|
+
IndexSelectMessageBox,
|
|
13
|
+
ShiftEnergyMessageBox,
|
|
12
14
|
ProgressDialog,
|
|
13
15
|
PeriodicTableDialog,
|
|
14
16
|
)
|
|
@@ -33,6 +35,8 @@ __all__ = [
|
|
|
33
35
|
"ConfigCompleter",
|
|
34
36
|
"GetIntMessageBox",
|
|
35
37
|
"SparseMessageBox",
|
|
38
|
+
"IndexSelectMessageBox",
|
|
39
|
+
"ShiftEnergyMessageBox",
|
|
36
40
|
"ProgressDialog",
|
|
37
41
|
"PeriodicTableDialog",
|
|
38
42
|
"SpinBoxUnitInputFrame",
|
{neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/completer.py
RENAMED
|
@@ -33,7 +33,7 @@ class CompleterModel(QAbstractListModel):
|
|
|
33
33
|
def data(self, index, role=Qt.ItemDataRole.DisplayRole):
|
|
34
34
|
|
|
35
35
|
if not index.isValid():
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
return None
|
|
38
38
|
# print(role)
|
|
39
39
|
# Qt.ItemDataRole.DisplayRole
|
|
@@ -44,7 +44,7 @@ class CompleterModel(QAbstractListModel):
|
|
|
44
44
|
if role == Qt.ItemDataRole.DisplayRole or role == Qt.ItemDataRole.EditRole:
|
|
45
45
|
return word
|
|
46
46
|
elif role == CountRole:
|
|
47
|
-
|
|
47
|
+
|
|
48
48
|
return str(count)
|
|
49
49
|
|
|
50
50
|
|
|
@@ -4,14 +4,16 @@
|
|
|
4
4
|
# @Author : 兵
|
|
5
5
|
# @email : 1747193328@qq.com
|
|
6
6
|
from PySide6.QtGui import QIcon
|
|
7
|
-
from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton
|
|
7
|
+
from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton, QLineEdit
|
|
8
8
|
from PySide6.QtCore import Signal, Qt
|
|
9
9
|
from qfluentwidgets import (
|
|
10
10
|
MessageBoxBase,
|
|
11
11
|
SpinBox,
|
|
12
12
|
CaptionLabel,
|
|
13
13
|
DoubleSpinBox,
|
|
14
|
+
CheckBox,
|
|
14
15
|
ProgressBar,
|
|
16
|
+
ComboBox,
|
|
15
17
|
FluentStyleSheet,
|
|
16
18
|
FluentTitleBar,
|
|
17
19
|
TitleLabel
|
|
@@ -73,6 +75,78 @@ class SparseMessageBox(MessageBoxBase):
|
|
|
73
75
|
|
|
74
76
|
self.widget.setMinimumWidth(200)
|
|
75
77
|
|
|
78
|
+
|
|
79
|
+
class IndexSelectMessageBox(MessageBoxBase):
    """Modal dialog asking for a structure index or slice expression.

    Widgets read back by the caller:

    - ``indexEdit``: free-text field holding the index/slice expression.
    - ``checkBox``: "Use original indices" toggle, checked by default.
    """

    def __init__(self, parent=None, tip="Specify index or slice"):
        super().__init__(parent)

        # Create the input widgets ...
        self.titleLabel = CaptionLabel(tip, self)
        self.titleLabel.setWordWrap(True)
        self.indexEdit = QLineEdit(self)
        self.checkBox = CheckBox("Use original indices", self)
        self.checkBox.setChecked(True)

        # ... and stack them top-to-bottom in the dialog body.
        for widget in (self.titleLabel, self.indexEdit, self.checkBox):
            self.viewLayout.addWidget(widget)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(200)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class ShiftEnergyMessageBox(MessageBoxBase):
    """Dialog collecting the parameters of an energy-baseline shift.

    Widgets read back by the caller after ``exec()``:

    - ``groupEdit``: regex patterns grouping ``Config_type`` values,
      separated by ``;``.
    - ``genSpinBox``: maximum number of optimizer generations (>= 1).
    - ``sizeSpinBox``: optimizer population size (>= 2).
    - ``tolSpinBox``: convergence tolerance (>= 0, 8 decimals).
    - ``modeCombo``: alignment mode name; defaults to
      ``"DFT_TO_NEP_ALIGNMENT"``.
    """

    def __init__(self, parent=None, tip="Group regex patterns (semicolon separated)"):
        super().__init__(parent)
        self.titleLabel = CaptionLabel(tip, self)
        self.titleLabel.setWordWrap(True)
        self.groupEdit = QLineEdit(self)

        self._frame = QFrame(self)
        self.frame_layout = QGridLayout(self._frame)
        self.frame_layout.setContentsMargins(0, 0, 0, 0)
        self.frame_layout.setSpacing(2)

        # Zero generations leaves the optimizer with nothing to return, and
        # with a population of 1 its rank weights are all zero so the search
        # distribution never moves; forbid both in the UI.
        self.genSpinBox = SpinBox(self)
        self.genSpinBox.setRange(1, 100000000)
        self.sizeSpinBox = SpinBox(self)
        self.sizeSpinBox.setRange(2, 999999)
        self.tolSpinBox = DoubleSpinBox(self)
        self.tolSpinBox.setDecimals(8)
        self.tolSpinBox.setMinimum(0)
        self.modeCombo = ComboBox(self)
        self.modeCombo.addItems([
            "REF_GROUP_ALIGNMENT",
            "ZERO_BASELINE_ALIGNMENT",
            "DFT_TO_NEP_ALIGNMENT",
        ])
        self.modeCombo.setCurrentText("DFT_TO_NEP_ALIGNMENT")

        # Two-column grid: caption on the left, editor on the right.
        self.frame_layout.addWidget(CaptionLabel("Max generations", self), 0, 0)
        self.frame_layout.addWidget(self.genSpinBox, 0, 1)
        self.frame_layout.addWidget(CaptionLabel("Population size", self), 1, 0)
        self.frame_layout.addWidget(self.sizeSpinBox, 1, 1)
        self.frame_layout.addWidget(CaptionLabel("Convergence tol", self), 2, 0)
        self.frame_layout.addWidget(self.tolSpinBox, 2, 1)
        self.frame_layout.addWidget(CaptionLabel("Mode", self), 3, 0)
        self.frame_layout.addWidget(self.modeCombo, 3, 1)

        self.viewLayout.addWidget(self.titleLabel)
        self.viewLayout.addWidget(self.groupEdit)
        self.viewLayout.addWidget(self._frame)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(250)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
|
|
76
150
|
class ProgressDialog(FramelessDialog):
|
|
77
151
|
"""进度条弹窗"""
|
|
78
152
|
def __init__(self,parent=None,title=""):
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Utilities for shifting structure energies using atomic baselines.
|
|
2
|
+
Adapted from Dr. Chen's code (used with permission).
|
|
3
|
+
url: https://github.com/brucefan1983/GPUMD/tree/master/tools/Analysis_and_Processing/energy-reference-aligner
|
|
4
|
+
Zherui Chen Email: <chenzherui0124@foxmail.com>
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from collections import Counter
|
|
11
|
+
from typing import List, Dict
|
|
12
|
+
import re
|
|
13
|
+
from NepTrainKit import utils
|
|
14
|
+
from .structure import Structure
|
|
15
|
+
from .calculator import NepCalculator
|
|
16
|
+
|
|
17
|
+
REF_GROUP_ALIGNMENT = "REF_GROUP_ALIGNMENT"
|
|
18
|
+
ZERO_BASELINE_ALIGNMENT = "ZERO_BASELINE_ALIGNMENT"
|
|
19
|
+
DFT_TO_NEP_ALIGNMENT = "DFT_TO_NEP_ALIGNMENT"
|
|
20
|
+
|
|
21
|
+
def longest_common_prefix(strs: List[str]) -> str:
    """Return the longest prefix shared by every string in *strs*.

    Only the lexicographically smallest and largest strings need comparing:
    whatever prefix they share is shared by every string that sorts between
    them. An empty input yields ``""``.
    """
    if not strs:
        return ""
    lo = min(strs)
    hi = max(strs)
    n = 0
    # lo <= hi, so hi can only be shorter than lo if it differs earlier;
    # bounding by len(lo) therefore keeps hi[n] in range.
    while n < len(lo) and lo[n] == hi[n]:
        n += 1
    return lo[:n]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def suggest_group_patterns(config_types: List[str], min_group_size: int = 2, min_prefix_len: int = 3) -> List[str]:
    """Suggest regex patterns that group ``Config_type`` labels by shared prefix.

    Labels are grouped without relying on any delimiter: a label joins the
    group of a base label when their common prefix is at least
    *min_prefix_len* characters long.

    Parameters
    ----------
    config_types
        Labels to group; duplicates are ignored.
    min_group_size
        Groups with at least this many members become one
        ``<escaped common prefix>.*`` pattern.
    min_prefix_len
        Minimum shared-prefix length for two labels to be grouped.

    Returns
    -------
    list[str]
        Sorted regex patterns. Members of groups smaller than
        *min_group_size* are emitted as exact patterns anchored with ``$``.
        The patterns are later applied with :func:`re.match`, which only
        anchors at the start, so an unanchored ``"Si"`` would also capture
        ``"SiC_1"``.
    """
    unused = set(config_types)
    patterns = []

    while unused:
        base = unused.pop()
        # Every remaining label sharing a long-enough prefix with base.
        group = [base] + [
            other for other in unused
            if len(longest_common_prefix([base, other])) >= min_prefix_len
        ]
        unused.difference_update(group[1:])

        if len(group) >= min_group_size:
            patterns.append(re.escape(longest_common_prefix(group)) + '.*')
        else:
            # Anchor exact labels so they cannot prefix-match longer ones.
            patterns.extend(re.escape(g) + '$' for g in group)

    return sorted(patterns)
|
|
57
|
+
def atomic_baseline_cost(param_population: np.ndarray,
                         energies: np.ndarray,
                         element_counts: np.ndarray,
                         target_energies: np.ndarray) -> np.ndarray:
    """Vectorized mean-squared error of every candidate baseline.

    Row ``k`` of *param_population* is a candidate per-element reference
    vector. Its baseline for each structure is ``element_counts @ row``. The
    cost is the MSE between ``energies - baseline`` and *target_energies*.

    Returns
    -------
    np.ndarray
        Column vector of shape ``(n_candidates, 1)``.
    """
    baselines = np.dot(param_population, element_counts.T)
    residual = (energies[None, :] - baselines) - target_energies[None, :]
    per_candidate = np.mean(residual ** 2, axis=1)
    return per_candidate.reshape(-1, 1)
|
|
65
|
+
|
|
66
|
+
@utils.timeit
def nes_optimize_atomic_baseline(num_variables: int,
                                 max_generations: int,
                                 energies: np.ndarray,
                                 element_counts: np.ndarray,
                                 targets: np.ndarray,
                                 pop_size: int = 40,
                                 tol: float = 1e-8,
                                 seed: int = 42,
                                 print_every: int = 100) -> np.ndarray:
    """Fit per-element reference energies with a separable NES optimizer.

    Each generation samples *pop_size* candidates around ``mean`` with a
    per-variable ``stddev``. It ranks them by :func:`atomic_baseline_cost`
    and moves ``mean`` / rescales ``stddev`` using rank-based utility weights.
    The loop stops after *max_generations*, or as soon as the best cost
    changes by less than *tol* between consecutive generations.

    Parameters
    ----------
    num_variables
        Length of the fitted reference vector (number of elements).
    max_generations
        Upper bound on the number of generations; must be >= 1.
    energies, element_counts, targets
        Forwarded to :func:`atomic_baseline_cost`.
    pop_size
        Candidates sampled per generation; must be >= 1.
    tol
        Convergence threshold on the change of the generation-best cost.
    seed
        Seed of a private random generator (results are reproducible).
    print_every
        Accepted for API compatibility; currently unused.

    Returns
    -------
    np.ndarray
        Best candidate of the last generation evaluated, shape
        ``(num_variables,)``. This is not necessarily the best candidate seen
        across all generations.

    Raises
    ------
    ValueError
        If *max_generations* or *pop_size* is smaller than 1.
    """
    if max_generations < 1:
        raise ValueError("max_generations must be >= 1")
    if pop_size < 1:
        raise ValueError("pop_size must be >= 1")

    # A private RandomState seeded with `seed` produces the same rand/randn
    # stream as the former global ``np.random.seed(seed)`` call, without
    # resetting NumPy's global random state for the rest of the program.
    rng = np.random.RandomState(seed)

    best_fitness = np.ones((max_generations, 1))
    elite = np.zeros((max_generations, num_variables))
    mean = -1 * rng.rand(1, num_variables)
    stddev = 0.1 * np.ones((1, num_variables))
    lr_mean = 1.0
    lr_std = (3 + np.log(num_variables)) / (5 * np.sqrt(num_variables)) / 2
    # Rank-based utilities: log-decreasing for the top half, zero below,
    # normalised and then centred by subtracting 1 / pop_size.
    weights = np.maximum(0, np.log(pop_size / 2 + 1) - np.log(np.arange(1, pop_size + 1)))
    weights = weights / np.sum(weights) - 1 / pop_size

    for gen in range(max_generations):
        z = rng.randn(pop_size, num_variables)
        pop = mean + stddev * z
        fitness = atomic_baseline_cost(pop, energies, element_counts, targets)
        # Sort candidates from best (lowest cost) to worst.
        idx = np.argsort(fitness.flatten())
        fitness = fitness[idx]
        z = z[idx, :]
        pop = pop[idx, :]
        best_fitness[gen] = fitness[0]
        elite[gen, :] = pop[0, :]
        mean += lr_mean * stddev * (weights @ z)
        stddev *= np.exp(lr_std * (weights @ (z ** 2 - 1)))
        if gen > 0 and abs(best_fitness[gen] - best_fitness[gen - 1]) < tol:
            break
    # `gen` is bound because max_generations >= 1; it is either the
    # generation that converged or the final one.
    return elite[gen]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def shift_dataset_energy(
        structures: List[Structure],
        reference_structures: List[Structure] | None,
        max_generations: int = 100000,
        population_size: int = 40,
        convergence_tol: float = 1e-8,
        random_seed: int = 42,
        group_patterns: List[str] | None = None,
        alignment_mode: str = REF_GROUP_ALIGNMENT,
        nep_energy_array: np.ndarray | None = None):
    """Shift structure energies using different alignment strategies.

    This is a generator.
    - It yields ``1`` each time one group's atomic baseline has been fitted,
      for a UI progress bar.
    - It writes the shifted energies back to ``structure.energy`` only after
      every group is processed, so it must be exhausted for the shift to take
      effect.
    - Validation errors are raised on the first ``next()``.

    Grouping and fitting:
    - Structures are grouped by ``additional_fields["Config_type"]``; a
      missing value becomes ``""``.
    - Each config type is assigned to the first pattern in *group_patterns*
      that :func:`re.match`-es it.
    - Config types matched by no pattern form their own group, and patterns
      that fail to compile are skipped.
    - For each group, a per-element reference vector is fitted with
      :func:`nes_optimize_atomic_baseline`.
    - Each structure's energy is then reduced by
      ``sum(count[element] * reference[element])``.

    Parameters
    ----------
    structures
        Structures whose energies will be shifted (modified in place).
    reference_structures
        Structures used to compute the reference mean energy when
        ``alignment_mode`` is ``REF_GROUP_ALIGNMENT``; ignored otherwise.
    max_generations, population_size, convergence_tol, random_seed
        Forwarded to :func:`nes_optimize_atomic_baseline` for every group.
    group_patterns
        Regex patterns used to merge config types into shift groups.
    alignment_mode
        One of the following:
        - ``REF_GROUP_ALIGNMENT``: fit towards the mean reference energy.
        - ``ZERO_BASELINE_ALIGNMENT``: fit towards zero.
        - ``DFT_TO_NEP_ALIGNMENT``: fit towards NEP energies.
    nep_energy_array
        One value per entry of *structures*, in the same order, required for
        ``DFT_TO_NEP_ALIGNMENT``. Each value is treated as a per-atom energy
        and multiplied by the structure's atom count.

    Raises
    ------
    ValueError
        Raised in any of these cases:
        - *alignment_mode* is unknown.
        - *reference_structures* is missing or empty in
          ``REF_GROUP_ALIGNMENT`` mode.
        - *nep_energy_array* is missing or wrongly sized in
          ``DFT_TO_NEP_ALIGNMENT`` mode.
    """
    valid_modes = (REF_GROUP_ALIGNMENT, ZERO_BASELINE_ALIGNMENT, DFT_TO_NEP_ALIGNMENT)
    if alignment_mode not in valid_modes:
        raise ValueError(f"Unknown alignment_mode {alignment_mode!r}; expected one of {valid_modes}")

    # Snapshot what the optimizer needs from each structure.
    frames = []
    for s in structures:
        frames.append({
            "energy": float(s.energy),
            "config_type": str(s.additional_fields.get("Config_type", "")),
            "elem_counts": Counter(s.elements),
        })

    all_elements = sorted({e for f in frames for e in f["elem_counts"]})
    num_elements = len(all_elements)

    ref_mean = None
    if alignment_mode == REF_GROUP_ALIGNMENT:
        # Explicit None/len() checks: callers may pass a NumPy object array,
        # whose truth value is ambiguous.
        if reference_structures is None or len(reference_structures) == 0:
            raise ValueError("reference_structures is required for REF_GROUP_ALIGNMENT")
        ref_energies = np.array([r.energy for r in reference_structures])
        ref_mean = np.mean(ref_energies)

    if alignment_mode == DFT_TO_NEP_ALIGNMENT:
        if nep_energy_array is None:
            raise ValueError("nep_energy_array is required for DFT_TO_NEP_ALIGNMENT")
        # zip() would silently truncate and misalign on a length mismatch.
        if len(nep_energy_array) != len(frames):
            raise ValueError(
                f"nep_energy_array has {len(nep_energy_array)} entries "
                f"but {len(frames)} structures were given")
        for f, e in zip(frames, nep_energy_array):
            # Per-atom value -> total energy. sum(values()) instead of
            # Counter.total(), which requires Python 3.10+.
            f["nep_energy"] = e * sum(f["elem_counts"].values())

    all_config_types = {f["config_type"] for f in frames}

    # Map each config type to its shift group: the first compiling pattern
    # that matches it, otherwise the config type itself.
    config_to_group: Dict[str, str] = {}
    for pat in group_patterns or []:
        try:
            regex = re.compile(pat)
        except re.error:
            continue  # best effort: ignore patterns that do not compile
        for ct in all_config_types:
            if ct not in config_to_group and regex.match(ct):
                config_to_group[ct] = pat
    for ct in all_config_types:
        config_to_group.setdefault(ct, ct)

    shift_groups = sorted(set(config_to_group.values()))

    # Frame indices per group, collected in one pass (keeps frame order).
    group_members: Dict[str, List[int]] = {g: [] for g in shift_groups}
    for i, f in enumerate(frames):
        group_members[config_to_group[f["config_type"]]].append(i)

    group_to_atomic_ref = {}
    for group in shift_groups:
        grp_frames = [frames[i] for i in group_members[group]]
        if not grp_frames:
            continue
        energies = np.array([f["energy"] for f in grp_frames])
        counts = np.array([[f["elem_counts"].get(e, 0) for e in all_elements] for f in grp_frames], dtype=float)

        if alignment_mode == REF_GROUP_ALIGNMENT:
            targets = np.full_like(energies, ref_mean)
        elif alignment_mode == ZERO_BASELINE_ALIGNMENT:
            targets = np.zeros_like(energies)
        else:  # DFT_TO_NEP_ALIGNMENT
            targets = np.array([f["nep_energy"] for f in grp_frames])

        group_to_atomic_ref[group] = nes_optimize_atomic_baseline(
            num_elements,
            max_generations,
            energies,
            counts,
            targets,
            pop_size=population_size,
            tol=convergence_tol,
            seed=random_seed,
            print_every=100,
        )
        # Progress tick for the UI worker thread.
        yield 1

    # Apply the fitted baselines only once every group has been optimized.
    for s, frame in zip(structures, frames):
        group = config_to_group[frame["config_type"]]
        if group in group_to_atomic_ref:
            count_vec = np.array([frame["elem_counts"].get(e, 0) for e in all_elements], dtype=float)
            shift = np.dot(count_vec, group_to_atomic_ref[group])
            s.energy = frame["energy"] - shift
|
|
@@ -155,6 +155,17 @@ class ResultData(QObject):
|
|
|
155
155
|
self.select_index.remove(i)
|
|
156
156
|
|
|
157
157
|
self.updateInfoSignal.emit()
|
|
158
|
+
|
|
159
|
+
def inverse_select(self):
    """Invert the selection over the active structures.

    Every index currently in ``select_index`` is passed to ``uncheck``. Every
    active index (``structure.data.now_indices``) that was not selected is
    passed to ``select``. Empty batches are skipped.
    """
    visible = set(self.structure.data.now_indices.tolist())
    already = set(self.select_index)
    to_uncheck = list(already)
    to_select = list(visible - already)
    if to_uncheck:
        self.uncheck(to_uncheck)
    if to_select:
        self.select(to_select)
|
|
158
169
|
def export_selected_xyz(self,save_file_path):
|
|
159
170
|
"""
|
|
160
171
|
导出当前选中的结构
|
|
@@ -116,8 +116,13 @@ class Structure:
|
|
|
116
116
|
|
|
117
117
|
@property
def per_atom_energy(self):
    """Total energy divided by the atom count."""
    total = self.energy
    return total / self.num_atoms

@property
def energy(self):
    """Total energy, kept in ``additional_fields`` under the ``"energy"`` key."""
    fields = self.additional_fields
    return fields["energy"]

@energy.setter
def energy(self, new_energy):
    fields = self.additional_fields
    fields["energy"] = new_energy
|
|
121
126
|
@property
|
|
122
127
|
def forces(self):
|
|
123
128
|
return self.structure_info[self.force_label]
|
|
@@ -72,9 +72,7 @@ class UpdateWoker( QObject):
|
|
|
72
72
|
if version_info['tag_name'][1:] == __version__:
|
|
73
73
|
MessageManager.send_success_message("You are already using the latest version!")
|
|
74
74
|
return
|
|
75
|
-
|
|
76
|
-
MessageManager.send_info_message("You can update via pip install NepTrainKit -U")
|
|
77
|
-
return
|
|
75
|
+
|
|
78
76
|
box = MessageBox("New version detected:" + version_info["name"] + version_info["tag_name"],
|
|
79
77
|
version_info["body"],
|
|
80
78
|
self._parent
|
|
@@ -92,6 +90,9 @@ class UpdateWoker( QObject):
|
|
|
92
90
|
MessageManager.send_warning_message("No update package available for your system. Please download it manually!")
|
|
93
91
|
|
|
94
92
|
def check_update(self):
    """Start an update check, or tell non-compiled installs how to upgrade.

    When ``is_nuitka_compiled`` is truthy, ``self._check_update`` is started
    on ``self.update_thread``. Otherwise an info message pointing at pip is
    shown and no check runs.
    """
    if is_nuitka_compiled:
        self.update_thread.start_work(self._check_update)
    else:
        MessageManager.send_info_message("You can update via pip install NepTrainKit -U --pre")
|
|
96
97
|
|
|
97
98
|
|
|
@@ -3,17 +3,28 @@
|
|
|
3
3
|
# @Time : 2024/10/20 22:22
|
|
4
4
|
# @Author : 兵
|
|
5
5
|
# @email : 1747193328@qq.com
|
|
6
|
+
import os
|
|
6
7
|
import time
|
|
8
|
+
import traceback
|
|
9
|
+
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
7
12
|
start=time.time()
|
|
8
13
|
import numpy as np
|
|
9
14
|
from PySide6.QtWidgets import QHBoxLayout, QWidget, QProgressDialog
|
|
10
15
|
|
|
11
16
|
|
|
12
|
-
from NepTrainKit import utils
|
|
17
|
+
from NepTrainKit import utils, module_path
|
|
13
18
|
from NepTrainKit.core import MessageManager, Config
|
|
14
|
-
from NepTrainKit.core.custom_widget import
|
|
19
|
+
from NepTrainKit.core.custom_widget import (
|
|
20
|
+
GetIntMessageBox,
|
|
21
|
+
SparseMessageBox,
|
|
22
|
+
IndexSelectMessageBox,
|
|
23
|
+
ShiftEnergyMessageBox,
|
|
24
|
+
)
|
|
15
25
|
from NepTrainKit.core.io.select import farthest_point_sampling
|
|
16
26
|
from NepTrainKit.core.views.toolbar import NepDisplayGraphicsToolBar
|
|
27
|
+
from NepTrainKit.core.energy_shift import shift_dataset_energy, suggest_group_patterns
|
|
17
28
|
|
|
18
29
|
|
|
19
30
|
class NepResultPlotWidget(QWidget):
|
|
@@ -63,6 +74,9 @@ class NepResultPlotWidget(QWidget):
|
|
|
63
74
|
self.tool_bar.findMaxSignal.connect(self.find_max_error_point)
|
|
64
75
|
self.tool_bar.discoverySignal.connect(self.find_non_physical_structures)
|
|
65
76
|
self.tool_bar.sparseSignal.connect(self.sparse_point)
|
|
77
|
+
self.tool_bar.shiftEnergySignal.connect(self.shift_energy_baseline)
|
|
78
|
+
self.tool_bar.inverseSignal.connect(self.inverse_select)
|
|
79
|
+
self.tool_bar.selectIndexSignal.connect(self.select_by_index)
|
|
66
80
|
self.canvas.tool_bar=self.tool_bar
|
|
67
81
|
|
|
68
82
|
|
|
@@ -142,11 +156,11 @@ class NepResultPlotWidget(QWidget):
|
|
|
142
156
|
remaining_indices = farthest_point_sampling(dataset.now_data,n_samples=n_samples,min_dist=distance)
|
|
143
157
|
|
|
144
158
|
# 获取所有索引(从 0 到 len(arr)-1)
|
|
145
|
-
all_indices = np.arange(dataset.now_data.shape[0])
|
|
159
|
+
# all_indices = np.arange(dataset.now_data.shape[0])
|
|
146
160
|
|
|
147
161
|
# 使用 setdiff1d 获取不在 indices_to_remove 中的索引
|
|
148
|
-
remove_indices = np.setdiff1d(all_indices, remaining_indices)
|
|
149
|
-
structures = dataset.group_array[
|
|
162
|
+
# remove_indices = np.setdiff1d(all_indices, remaining_indices)
|
|
163
|
+
structures = dataset.group_array[remaining_indices]
|
|
150
164
|
self.canvas.select_index(structures.tolist(),False)
|
|
151
165
|
|
|
152
166
|
def export_descriptor_data(self):
|
|
@@ -175,6 +189,92 @@ class NepResultPlotWidget(QWidget):
|
|
|
175
189
|
with open(path, "w") as f:
|
|
176
190
|
np.savetxt(f,descriptor_data,fmt='%.6g',delimiter='\t')
|
|
177
191
|
|
|
192
|
+
|
|
193
|
+
def shift_energy_baseline(self):
    """Interactively shift energies with fitted per-element baselines.

    Steps:
    1. Open a ShiftEnergyMessageBox pre-filled with suggested ``Config_type``
       group patterns and the optimizer settings stored in Config.
    2. Persist the chosen settings.
    3. Run shift_dataset_energy on the active structures in a worker thread
       behind a progress dialog.
    4. Write the shifted per-atom energies into column 1 of the energy data
       and redraw the plot.

    The currently selected structures (indices into
    ``structure.all_data``) are the reference group for REF_GROUP_ALIGNMENT.
    """
    data = self.canvas.nep_result_data
    if data is None:
        return
    ref_index = list(data.select_index)

    max_generations = Config.getint("widget", "max_generation_value", 100000)
    population_size = Config.getint("widget", "population_size", 40)
    convergence_tol = Config.getfloat("widget", "convergence_tol", 1e-8)
    config_set = set(data.structure.get_all_config())
    suggested = suggest_group_patterns(list(config_set))

    box = ShiftEnergyMessageBox(
        self._parent,
        "Specify regex groups for Config_type (semicolon separated)"
    )
    # Patterns are ';'-separated, matching the split below.
    box.groupEdit.setText(";".join(suggested))
    box.genSpinBox.setValue(max_generations)
    box.sizeSpinBox.setValue(population_size)
    box.tolSpinBox.setValue(convergence_tol)

    if not box.exec():
        return

    pattern_text = box.groupEdit.text().strip()
    group_patterns = [p.strip() for p in pattern_text.split(';') if p.strip()]
    alignment_mode = box.modeCombo.currentText()

    max_generations = box.genSpinBox.value()
    population_size = box.sizeSpinBox.value()
    convergence_tol = box.tolSpinBox.value()
    # Remember the optimizer settings for the next invocation.
    Config.set("widget", "max_generation_value", max_generations)
    Config.set("widget", "population_size", population_size)
    Config.set("widget", "convergence_tol", convergence_tol)

    # Without a selection, shift_dataset_energy would raise inside the worker
    # thread (presumably not surfaced to the user); reject it up front.
    if alignment_mode == "REF_GROUP_ALIGNMENT" and not ref_index:
        MessageManager.send_info_message("REF_GROUP_ALIGNMENT requires selected reference structures!")
        return

    # shift_dataset_energy yields once per shift group; the number of config
    # types is an upper bound on the number of groups.
    progress_diag = QProgressDialog("", "Cancel", 0, len(config_set), self._parent)
    thread = utils.LoadingThread(self._parent, show_tip=False)
    progress_diag.setFixedSize(300, 100)
    progress_diag.setWindowTitle("Shift energies")
    thread.progressSignal.connect(progress_diag.setValue)
    thread.finished.connect(progress_diag.accept)
    progress_diag.canceled.connect(thread.stop_work)  # stop the worker when the user cancels
    thread.start_work(
        shift_dataset_energy,
        structures=data.structure.now_data,
        reference_structures=data.structure.all_data[ref_index],
        max_generations=max_generations,
        population_size=population_size,
        convergence_tol=convergence_tol,
        group_patterns=group_patterns,
        alignment_mode=alignment_mode,
        nep_energy_array=data.energy.y,
    )
    progress_diag.exec()

    # Copy the shifted per-atom energies into column 1 of the energy data.
    if hasattr(data, "energy") and data.energy.num != 0:
        for i, s in enumerate(data.structure.all_data):
            data.energy.data._data[i, 1] = s.per_atom_energy
    self.canvas.plot_nep_result()
|
|
257
|
+
|
|
258
|
+
def inverse_select(self):
    """Invert the structure selection by delegating to the canvas."""
    canvas = self.canvas
    canvas.inverse_select()
|
|
260
|
+
|
|
261
|
+
def select_by_index(self):
    """Ask for an index/slice expression and select the matching structures.

    The expression is parsed with ``utils.parse_index_string`` against one of
    two pools:
    - the full dataset (``all_data``) when "Use original indices" is checked;
    - otherwise the active structures (``now_data``).

    For the active pool, the positions are mapped back through
    ``group_array.now_data`` before being passed to ``canvas.select_index``.
    """
    if self.canvas.nep_result_data is None:
        return
    dialog = IndexSelectMessageBox(self._parent, "Select structures by index")
    if not dialog.exec():
        return

    expression = dialog.indexEdit.text().strip()
    original_numbering = dialog.checkBox.isChecked()
    dataset = self.canvas.nep_result_data.structure
    pool = dataset.all_data if original_numbering else dataset.now_data

    chosen = utils.parse_index_string(expression, pool.shape[0])
    if not chosen:
        return
    if not original_numbering:
        # Translate positions within the active view into dataset indices.
        chosen = dataset.group_array.now_data[chosen].tolist()
    self.canvas.select_index(chosen, False)
|
|
277
|
+
|
|
178
278
|
def set_dataset(self,dataset):
|
|
179
279
|
|
|
180
280
|
if self.last_figure_num !=len(dataset.dataset):
|
|
@@ -42,6 +42,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
|
|
|
42
42
|
deleteSignal=Signal()
|
|
43
43
|
revokeSignal=Signal()
|
|
44
44
|
exportSignal=Signal()
|
|
45
|
+
shiftEnergySignal=Signal()
|
|
46
|
+
inverseSignal=Signal()
|
|
47
|
+
selectIndexSignal=Signal()
|
|
45
48
|
|
|
46
49
|
def init_actions(self):
|
|
47
50
|
self.addButton("Reset View",QIcon(":/images/src/images/init.svg"),self.resetSignal)
|
|
@@ -50,6 +53,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
|
|
|
50
53
|
self.pan,
|
|
51
54
|
True
|
|
52
55
|
)
|
|
56
|
+
self.addButton("Select by Index",
|
|
57
|
+
QIcon(":/images/src/images/index.svg"),
|
|
58
|
+
self.selectIndexSignal)
|
|
53
59
|
find_max_action = self.addButton("Find Max Error Point",
|
|
54
60
|
QIcon(":/images/src/images/find_max.svg"),
|
|
55
61
|
self.findMaxSignal)
|
|
@@ -58,6 +64,8 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
|
|
|
58
64
|
self.sparseSignal)
|
|
59
65
|
|
|
60
66
|
|
|
67
|
+
|
|
68
|
+
|
|
61
69
|
pen_action=self.addButton("Mouse Selection",
|
|
62
70
|
QIcon(":/images/src/images/pen.svg"),
|
|
63
71
|
self.pen,
|
|
@@ -71,6 +79,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
|
|
|
71
79
|
discovery_action = self.addButton("Finding non-physical structures",
|
|
72
80
|
QIcon(":/images/src/images/discovery.svg"),
|
|
73
81
|
self.discoverySignal)
|
|
82
|
+
inverse_action = self.addButton("Inverse Selection",
|
|
83
|
+
QIcon(":/images/src/images/inverse.svg"),
|
|
84
|
+
self.inverseSignal)
|
|
74
85
|
revoke_action = self.addButton("Undo",
|
|
75
86
|
QIcon(":/images/src/images/revoke.svg"),
|
|
76
87
|
self.revokeSignal)
|
|
@@ -78,10 +89,15 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
|
|
|
78
89
|
delete_action = self.addButton("Delete Selected Items",
|
|
79
90
|
QIcon(":/images/src/images/delete.svg"),
|
|
80
91
|
self.deleteSignal)
|
|
92
|
+
|
|
81
93
|
self.addSeparator()
|
|
82
94
|
export_action = self.addButton("Export structure descriptor",
|
|
83
95
|
QIcon(":/images/src/images/export.svg"),
|
|
84
96
|
self.exportSignal)
|
|
97
|
+
self.addSeparator()
|
|
98
|
+
self.addButton("Energy Baseline Shift",
|
|
99
|
+
QIcon(":/images/src/images/alignment.svg"),
|
|
100
|
+
self.shiftEnergySignal)
|
|
85
101
|
|
|
86
102
|
def reset(self):
|
|
87
103
|
if self.action_group.checkedAction():
|