NepTrainKit 2.2.0__tar.gz → 2.2.2.dev23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/PKG-INFO +1 -1
  2. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/_version.py +2 -2
  3. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/base/canvas.py +10 -0
  4. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/__init__.py +6 -0
  5. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/completer.py +2 -2
  6. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/dialog.py +75 -1
  7. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/doping_rule.py +20 -17
  8. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/input.py +3 -0
  9. neptrainkit-2.2.2.dev23/src/NepTrainKit/core/custom_widget/vacancy_rule.py +139 -0
  10. neptrainkit-2.2.2.dev23/src/NepTrainKit/core/energy_shift.py +216 -0
  11. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/base.py +3 -0
  12. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/nep.py +12 -1
  13. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/structure.py +7 -2
  14. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/update.py +4 -3
  15. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/cards.py +99 -3
  16. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/nep.py +105 -5
  17. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/toolbar.py +16 -0
  18. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/src_rc.py +323 -110
  19. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/utils.py +43 -1
  20. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/PKG-INFO +1 -1
  21. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/SOURCES.txt +2 -0
  22. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/.gitattributes +0 -0
  23. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/.gitignore +0 -0
  24. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/.readthedocs.yml +0 -0
  25. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/LICENSE +0 -0
  26. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/MANIFEST.in +0 -0
  27. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/README.md +0 -0
  28. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/pyproject.toml +0 -0
  29. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/requirements.txt +0 -0
  30. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/setup.cfg +0 -0
  31. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/setup.py +0 -0
  32. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/config.sqlite +0 -0
  33. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep.json +0 -0
  34. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep89.txt +0 -0
  35. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/ptable.json +0 -0
  36. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/__init__.py +0 -0
  37. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/__init__.py +0 -0
  38. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/calculator.py +0 -0
  39. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/__init__.py +0 -0
  40. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/canvas.py +0 -0
  41. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/structure.py +0 -0
  42. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/__init__.py +0 -0
  43. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/canvas.py +0 -0
  44. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/structure.py +0 -0
  45. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/config.py +0 -0
  46. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/card_widget.py +0 -0
  47. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/docker.py +0 -0
  48. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/label.py +0 -0
  49. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/layout.py +0 -0
  50. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/search_widget.py +0 -0
  51. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/settingscard.py +0 -0
  52. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/__init__.py +0 -0
  53. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/select.py +0 -0
  54. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/utils.py +0 -0
  55. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/message.py +0 -0
  56. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/__init__.py +0 -0
  57. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/makedata.py +0 -0
  58. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/settings.py +0 -0
  59. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/show_nep.py +0 -0
  60. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/types.py +0 -0
  61. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/__init__.py +0 -0
  62. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/structure.py +0 -0
  63. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/main.py +0 -0
  64. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/version.py +0 -0
  65. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/dependency_links.txt +0 -0
  66. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/entry_points.txt +0 -0
  67. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/not-zip-safe +0 -0
  68. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/requires.txt +0 -0
  69. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/top_level.txt +0 -0
  70. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/nep_cpu/dftd3para.h +0 -0
  71. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.cpp +0 -0
  72. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.h +0 -0
  73. {neptrainkit-2.2.0 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep_cpu.cpp +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: NepTrainKit
3
- Version: 2.2.0
3
+ Version: 2.2.2.dev23
4
4
  Summary: NepTrainKit is a Python package for visualizing and manipulating training datasets for NEP.
5
5
  Author: Chen Cheng bing
6
6
  Author-email: Chen Cheng bing <1747193328@qq.com>
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '2.2.0'
21
- __version_tuple__ = version_tuple = (2, 2, 0)
20
+ __version__ = version = '2.2.2.dev23'
21
+ __version_tuple__ = version_tuple = (2, 2, 2, 'dev23')
@@ -156,6 +156,16 @@ class CanvasLayoutBase(CanvasBase):
156
156
  self.nep_result_data.select(structure_index)
157
157
 
158
158
  self.update_scatter_color(structure_index, Brushes.Selected)
159
+
160
def inverse_select(self):
    """Invert the selection over the structures that are currently active.

    All currently selected indices are passed to ``select_index(..., True)``
    (presumably "deselect") and every active index that was not selected is
    passed to ``select_index(..., False)`` (presumably "select").
    Does nothing when no result data is loaded.
    """
    data = self.nep_result_data
    if data is None:
        return

    # Snapshot both sets before mutating anything, so the second call is not
    # influenced by the first.
    visible = set(data.structure.now_indices.tolist())
    chosen = set(data.select_index)

    self.select_index(list(chosen), True)
    self.select_index(list(visible - chosen), False)
159
169
  class VispyCanvasLayoutBase(CanvasLayoutBase,QObject,metaclass=CombinedMeta):
160
170
  def __init__(self,*args,**kwargs):
161
171
  QObject.__init__(self)
@@ -9,6 +9,8 @@ from .completer import CompleterModel, JoinDelegate, ConfigCompleter
9
9
  from .dialog import (
10
10
  GetIntMessageBox,
11
11
  SparseMessageBox,
12
+ IndexSelectMessageBox,
13
+ ShiftEnergyMessageBox,
12
14
  ProgressDialog,
13
15
  PeriodicTableDialog,
14
16
  )
@@ -19,6 +21,7 @@ from .card_widget import (
19
21
  MakeDataCardWidget,
20
22
  )
21
23
  from .doping_rule import DopingRulesWidget
24
+ from .vacancy_rule import VacancyRulesWidget
22
25
 
23
26
  from .docker import MakeWorkflowArea
24
27
  from .search_widget import ConfigTypeSearchLineEdit
@@ -32,6 +35,8 @@ __all__ = [
32
35
  "ConfigCompleter",
33
36
  "GetIntMessageBox",
34
37
  "SparseMessageBox",
38
+ "IndexSelectMessageBox",
39
+ "ShiftEnergyMessageBox",
35
40
  "ProgressDialog",
36
41
  "PeriodicTableDialog",
37
42
  "SpinBoxUnitInputFrame",
@@ -43,4 +48,5 @@ __all__ = [
43
48
  "MyComboBoxSettingCard",
44
49
  "DoubleSpinBoxSettingCard",
45
50
  "DopingRulesWidget",
51
+ "VacancyRulesWidget",
46
52
  ]
@@ -33,7 +33,7 @@ class CompleterModel(QAbstractListModel):
33
33
  def data(self, index, role=Qt.ItemDataRole.DisplayRole):
34
34
 
35
35
  if not index.isValid():
36
- print("index")
36
+
37
37
  return None
38
38
  # print(role)
39
39
  # Qt.ItemDataRole.DisplayRole
@@ -44,7 +44,7 @@ class CompleterModel(QAbstractListModel):
44
44
  if role == Qt.ItemDataRole.DisplayRole or role == Qt.ItemDataRole.EditRole:
45
45
  return word
46
46
  elif role == CountRole:
47
- print(count)
47
+
48
48
  return str(count)
49
49
 
50
50
 
@@ -4,14 +4,16 @@
4
4
  # @Author : 兵
5
5
  # @email : 1747193328@qq.com
6
6
  from PySide6.QtGui import QIcon
7
- from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton
7
+ from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton, QLineEdit
8
8
  from PySide6.QtCore import Signal, Qt
9
9
  from qfluentwidgets import (
10
10
  MessageBoxBase,
11
11
  SpinBox,
12
12
  CaptionLabel,
13
13
  DoubleSpinBox,
14
+ CheckBox,
14
15
  ProgressBar,
16
+ ComboBox,
15
17
  FluentStyleSheet,
16
18
  FluentTitleBar,
17
19
  TitleLabel
@@ -73,6 +75,78 @@ class SparseMessageBox(MessageBoxBase):
73
75
 
74
76
  self.widget.setMinimumWidth(200)
75
77
 
78
+
79
class IndexSelectMessageBox(MessageBoxBase):
    """Message box asking the user for an index expression or slice.

    Public widgets:

    * ``titleLabel`` - word-wrapped caption showing ``tip``;
    * ``indexEdit`` - free-text field for the index/slice expression;
    * ``checkBox`` - "Use original indices", checked by default (how the
      flag is interpreted is up to the caller).
    """

    def __init__(self, parent=None, tip="Specify index or slice"):
        super().__init__(parent)

        # Widgets are created in the same order as they appear in the dialog.
        self.titleLabel = CaptionLabel(tip, self)
        self.titleLabel.setWordWrap(True)
        self.indexEdit = QLineEdit(self)
        self.checkBox = CheckBox("Use original indices", self)
        self.checkBox.setChecked(True)

        for widget in (self.titleLabel, self.indexEdit, self.checkBox):
            self.viewLayout.addWidget(widget)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(200)
97
+
98
+
99
class ShiftEnergyMessageBox(MessageBoxBase):
    """Dialog collecting the parameters for an energy-baseline shift.

    Public widgets:

    * ``groupEdit`` - comma-separated regex patterns used to group config
      types (the prompt text is ``tip``);
    * ``genSpinBox`` - "Max generations" (at least 1);
    * ``sizeSpinBox`` - "Population size" (at least 1);
    * ``tolSpinBox`` - "Convergence tol" (8 decimals, non-negative);
    * ``modeCombo`` - alignment mode, ``DFT_TO_NEP_ALIGNMENT`` by default.
    """

    def __init__(self, parent=None, tip="Group regex patterns (comma separated)"):
        super().__init__(parent)
        self.titleLabel = CaptionLabel(tip, self)
        self.titleLabel.setWordWrap(True)
        self.groupEdit = QLineEdit(self)

        self._frame = QFrame(self)
        self.frame_layout = QGridLayout(self._frame)
        self.frame_layout.setContentsMargins(0, 0, 0, 0)
        self.frame_layout.setSpacing(2)

        # These values presumably feed the NES baseline optimizer
        # (energy_shift.py), which needs at least one generation and at least
        # one individual. A minimum of 0 would let the user pick values that
        # crash it (empty result / division by the population size).
        self.genSpinBox = SpinBox(self)
        self.genSpinBox.setRange(1, 100000000)
        self.sizeSpinBox = SpinBox(self)
        self.sizeSpinBox.setRange(1, 999999)
        self.tolSpinBox = DoubleSpinBox(self)
        self.tolSpinBox.setDecimals(8)
        self.tolSpinBox.setMinimum(0)
        self.modeCombo = ComboBox(self)
        self.modeCombo.addItems([
            "REF_GROUP_ALIGNMENT",
            "ZERO_BASELINE_ALIGNMENT",
            "DFT_TO_NEP_ALIGNMENT",
        ])
        self.modeCombo.setCurrentText("DFT_TO_NEP_ALIGNMENT")

        # One (caption, editor) pair per grid row.
        rows = (
            ("Max generations", self.genSpinBox),
            ("Population size", self.sizeSpinBox),
            ("Convergence tol", self.tolSpinBox),
            ("Mode", self.modeCombo),
        )
        for row, (caption, editor) in enumerate(rows):
            self.frame_layout.addWidget(CaptionLabel(caption, self), row, 0)
            self.frame_layout.addWidget(editor, row, 1)

        self.viewLayout.addWidget(self.titleLabel)
        self.viewLayout.addWidget(self.groupEdit)
        self.viewLayout.addWidget(self._frame)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(250)
146
+
147
+
148
+
149
+
76
150
  class ProgressDialog(FramelessDialog):
77
151
  """进度条弹窗"""
78
152
  def __init__(self,parent=None,title=""):
@@ -15,20 +15,19 @@ from PySide6.QtWidgets import (
15
15
  QGridLayout,
16
16
  QHBoxLayout,
17
17
  QVBoxLayout,
18
- QWidget, QLineEdit,
19
-
18
+ QWidget,
19
+ QLineEdit,
20
20
  )
21
21
  from qfluentwidgets import (
22
22
  BodyLabel,
23
23
  TransparentToolButton,
24
- SpinBox,
25
- DoubleSpinBox,
26
24
  FluentIcon,
27
25
  LineEdit,
28
26
  RadioButton,
29
27
  ToolTipFilter,
30
28
  ToolTipPosition,
31
29
  )
30
+ from .input import SpinBoxUnitInputFrame
32
31
 
33
32
 
34
33
  class DopingRuleItem(QFrame):
@@ -43,16 +42,20 @@ class DopingRuleItem(QFrame):
43
42
  self.target_edit = QLineEdit(self)
44
43
  self.target_edit.setPlaceholderText("Cs")
45
44
 
46
- self.setFixedSize(300, 100)
45
+ self.setFixedSize(300, 130)
47
46
  self.dopants_edit = QLineEdit(self)
48
47
 
49
- self.concentration_spin = QLineEdit(self)
50
- self.concentration_spin.setText("1")
48
+ self.concentration_frame = SpinBoxUnitInputFrame(self)
49
+ self.concentration_frame.set_input(["-", ""], 2, "float")
50
+ self.concentration_frame.setRange(0, 1)
51
+ self.concentration_frame.set_input_value([1.0, 1.0])
51
52
 
52
53
  self.concentration_botton = RadioButton("Conc", self)
53
54
  self.concentration_botton.setChecked(True)
54
- self.count_spin = QLineEdit(self)
55
- self.count_spin.setText("10")
55
+ self.count_frame = SpinBoxUnitInputFrame(self)
56
+ self.count_frame.set_input(["-", ""], 2, "int")
57
+ self.count_frame.setRange(0, 10000)
58
+ self.count_frame.set_input_value([10, 10])
56
59
  self.count_botton = RadioButton("Count", self)
57
60
 
58
61
  self.indices_edit = QLineEdit(self)
@@ -78,15 +81,15 @@ class DopingRuleItem(QFrame):
78
81
  self.concentration_botton.setToolTip("Use concentration")
79
82
  self.concentration_botton.installEventFilter(ToolTipFilter(self.concentration_botton, 300, ToolTipPosition.TOP))
80
83
  self.layout.addWidget(self.concentration_botton, 2, 0)
81
- self.layout.addWidget(self.concentration_spin, 2, 1)
84
+ self.layout.addWidget(self.concentration_frame, 2, 1,1,4)
82
85
  self.count_botton.setToolTip("Use count")
83
86
  self.count_botton.installEventFilter(ToolTipFilter(self.count_botton, 300, ToolTipPosition.TOP))
84
- self.layout.addWidget(self.count_botton, 2, 2)
85
- self.layout.addWidget(self.count_spin, 2, 3)
87
+ self.layout.addWidget(self.count_botton, 3, 0)
88
+ self.layout.addWidget(self.count_frame, 3, 1,1,4)
86
89
 
87
90
  self.delete_button.setToolTip("Delete rule")
88
91
  self.delete_button.installEventFilter(ToolTipFilter(self.delete_button, 300, ToolTipPosition.TOP))
89
- self.layout.addWidget(self.delete_button, 0, 4, 3, 3)
92
+ self.layout.addWidget(self.delete_button, 0, 4, 3, 1)
90
93
 
91
94
  def _delete_self(self) -> None:
92
95
  self.setParent(None)
@@ -115,9 +118,9 @@ class DopingRuleItem(QFrame):
115
118
  except Exception:
116
119
  logger.error(traceback.format_exc())
117
120
 
118
- rule["concentration"] =float(self.concentration_spin.text().strip()) if self.concentration_spin.text().strip() else 0
121
+ rule["concentration"] = [float(v) for v in self.concentration_frame.get_input_value()]
119
122
 
120
- rule["count"] =float(self.count_spin.text().strip()) if self.count_spin.text().strip() else 0
123
+ rule["count"] = [int(v) for v in self.count_frame.get_input_value()]
121
124
  rule["use"] = "concentration" if self.concentration_botton.isChecked() else "count"
122
125
  indices_text = self.indices_edit.text().strip()
123
126
  if indices_text:
@@ -136,9 +139,9 @@ class DopingRuleItem(QFrame):
136
139
  if dopants is not None:
137
140
  self.dopants_edit.setText(json.dumps(dopants))
138
141
  if "concentration" in rule:
139
- self.concentration_spin.setText(str(rule["concentration"]))
142
+ self.concentration_frame.set_input_value(rule["concentration"])
140
143
  if "count" in rule:
141
- self.count_spin.setText(str(rule["count"]))
144
+ self.count_frame.set_input_value(rule["count"])
142
145
  if "group" in rule:
143
146
  self.indices_edit.setText(",".join(str(i) for i in rule["group"]))
144
147
  if "use" in rule:
@@ -53,5 +53,8 @@ class SpinBoxUnitInputFrame(QFrame):
53
53
  return [input_object.value() for input_object in self.object_list]
54
54
 
55
55
  def set_input_value(self, value_list):
56
+ if not isinstance(value_list,list):
57
+ value_list=[value_list]*len(self.object_list)
58
+
56
59
  for i, input_object in enumerate(self.object_list):
57
60
  input_object.setValue(value_list[i])
@@ -0,0 +1,139 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """Widget to edit vacancy rules."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from PySide6.QtCore import Qt
8
+ from PySide6.QtGui import QIcon
9
+ from PySide6.QtWidgets import (
10
+ QFrame,
11
+ QGridLayout,
12
+ QHBoxLayout,
13
+ QVBoxLayout,
14
+ QWidget,
15
+ QLineEdit,
16
+ )
17
+ from qfluentwidgets import (
18
+ BodyLabel,
19
+ TransparentToolButton,
20
+ FluentIcon,
21
+ ToolTipFilter,
22
+ ToolTipPosition,
23
+ )
24
+ from .input import SpinBoxUnitInputFrame
25
+
26
+
27
class VacancyRuleItem(QFrame):
    """Editable row describing one vacancy rule.

    A row holds an element symbol, an optional comma-separated list of group
    names and a two-value integer count input, plus a button that removes the
    row.
    """

    def __init__(self, parent: QWidget | None = None) -> None:
        super().__init__(parent)
        grid = QGridLayout(self)
        self.layout = grid
        grid.setContentsMargins(0, 0, 0, 0)
        grid.setSpacing(4)
        self.setStyleSheet("background-color: rgb(239, 249, 254);")

        # Creation order is kept stable because it defines the tab order.
        self.element_edit = QLineEdit(self)
        self.element_edit.setPlaceholderText("Cs")
        self.group_edit = QLineEdit(self)

        counter = SpinBoxUnitInputFrame(self)
        counter.set_input(["-", ""], 2, "int")
        counter.setRange(0, 10000)
        counter.set_input_value([1, 1])
        self.count_frame = counter

        self.delete_button = TransparentToolButton(QIcon(":/images/src/images/delete.svg"), self)
        self.delete_button.clicked.connect(self._delete_self)

        self.element_label = self._tooltip_label("Element", "Element to remove")
        self.group_label = self._tooltip_label("Group", "Optional group name")
        self.count_label = self._tooltip_label("Count", "Number of atoms to remove")

        placements = (
            (self.element_label, 0, 0),
            (self.element_edit, 0, 1),
            (self.group_label, 0, 2),
            (self.group_edit, 0, 3),
            (self.count_label, 1, 0),
            (self.count_frame, 1, 1),
        )
        for widget, row, column in placements:
            grid.addWidget(widget, row, column)
        # The delete button spans both rows at the right edge.
        grid.addWidget(self.delete_button, 0, 4, 2, 1)

    def _tooltip_label(self, text: str, tip: str) -> BodyLabel:
        """Create a body label that shows ``tip`` in a hover tooltip."""
        label = BodyLabel(text, self)
        label.setToolTip(tip)
        label.installEventFilter(ToolTipFilter(label, 300, ToolTipPosition.TOP))
        return label

    def _delete_self(self) -> None:
        """Detach this row from its parent and schedule it for deletion."""
        self.setParent(None)
        self.deleteLater()

    def to_rule(self) -> dict:
        """Serialize the row into a rule dict.

        ``"count"`` is always present (ints read from the count input);
        ``"element"`` and ``"group"`` are added only when their text fields
        are non-blank.
        """
        rule: dict[str, object] = {}
        element = self.element_edit.text().strip()
        if element:
            rule["element"] = element
        rule["count"] = list(map(int, self.count_frame.get_input_value()))
        group_text = self.group_edit.text().strip()
        if group_text:
            parts = (part.strip() for part in group_text.split(","))
            rule["group"] = [part for part in parts if part]
        return rule

    def from_rule(self, rule: dict) -> None:
        """Fill the row from ``rule``; an empty or falsy rule leaves it unchanged."""
        if not rule:
            return
        self.element_edit.setText(str(rule.get("element", "")))
        if "count" in rule:
            self.count_frame.set_input_value(rule["count"])
        if "group" in rule:
            self.group_edit.setText(",".join(map(str, rule["group"])))
89
+
90
+
91
class VacancyRulesWidget(QWidget):
    """Container managing a vertical list of :class:`VacancyRuleItem` rows.

    An "add" button above the list appends empty rows. ``to_rules`` and
    ``from_rules`` convert the rows to and from a list of rule dicts.
    """

    def __init__(self, parent: QWidget | None = None) -> None:
        super().__init__(parent)
        self.layout = QVBoxLayout(self)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.layout.setSpacing(4)

        btn_layout = QHBoxLayout()
        btn_layout.setContentsMargins(0, 0, 0, 0)
        self.add_button = TransparentToolButton(FluentIcon.ADD, self)
        # ``clicked`` carries a ``checked`` bool. Call add_rule() explicitly so
        # that bool is never passed as the ``rule`` argument.
        self.add_button.clicked.connect(lambda: self.add_rule())
        self.add_button.setToolTip("Add rule")
        self.add_button.installEventFilter(ToolTipFilter(self.add_button, 300, ToolTipPosition.TOP))
        btn_layout.addWidget(self.add_button, 0, Qt.AlignLeft)
        btn_layout.addStretch(1)
        self.layout.addLayout(btn_layout)

        self.rule_container = QWidget(self)
        self.rule_layout = QVBoxLayout(self.rule_container)
        self.rule_layout.setContentsMargins(0, 0, 0, 0)
        self.rule_layout.setSpacing(4)
        self.layout.addWidget(self.rule_container)

    def add_rule(self, rule: dict | None = None) -> VacancyRuleItem:
        """Append a new row, optionally pre-filled from ``rule``, and return it."""
        item = VacancyRuleItem(self.rule_container)
        self.rule_layout.addWidget(item)
        if rule:
            item.from_rule(rule)
        return item

    def to_rules(self) -> list[dict]:
        """Return the non-empty rule dicts of all rows, in display order."""
        rules: list[dict] = []
        for i in range(self.rule_layout.count()):
            widget = self.rule_layout.itemAt(i).widget()
            if isinstance(widget, VacancyRuleItem):
                rule = widget.to_rule()
                if rule:
                    rules.append(rule)
        return rules

    def from_rules(self, rules: list[dict]) -> None:
        """Replace all rows with one row per entry of ``rules`` (``None`` clears)."""
        while self.rule_layout.count():
            widget = self.rule_layout.takeAt(0).widget()
            if widget is not None:
                # Detach immediately so the stale row disappears now rather
                # than staying visible until the deferred deletion runs.
                widget.setParent(None)
                widget.deleteLater()
        for rule in rules or []:
            self.add_rule(rule)
@@ -0,0 +1,216 @@
1
+ """Utilities for shifting structure energies using atomic baselines.
2
+ Adapted from Dr. Chen's code (used with permission).
3
+ url: https://github.com/brucefan1983/GPUMD/tree/master/tools/Analysis_and_Processing/energy-reference-aligner
4
+ Zherui Chen Email: <chenzherui0124@foxmail.com>
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+ from collections import Counter
11
+ from typing import List, Dict
12
+ import re
13
+ from NepTrainKit import utils
14
+ from .structure import Structure
15
+ from .calculator import NepCalculator
16
+
17
+ REF_GROUP_ALIGNMENT = "REF_GROUP_ALIGNMENT"
18
+ ZERO_BASELINE_ALIGNMENT = "ZERO_BASELINE_ALIGNMENT"
19
+ DFT_TO_NEP_ALIGNMENT = "DFT_TO_NEP_ALIGNMENT"
20
+
21
def longest_common_prefix(strs: List[str]) -> str:
    """Return the longest prefix shared by every string in ``strs``.

    Only the lexicographically smallest and largest strings need to be
    compared: whatever prefix they share is shared by every string that sorts
    between them. An empty input yields ``""``.
    """
    if not strs:
        return ""
    lowest, highest = min(strs), max(strs)
    shared = 0
    for a, b in zip(lowest, highest):
        if a != b:
            break
        shared += 1
    return lowest[:shared]
29
+
30
+
31
def suggest_group_patterns(config_types: List[str], min_group_size: int = 2, min_prefix_len: int = 3) -> List[str]:
    """Suggest regex patterns that group config types by a shared prefix.

    No delimiter is assumed. Two config types belong to the same group when
    their common prefix is at least ``min_prefix_len`` characters long, i.e.
    when their first ``min_prefix_len`` characters are identical. Strings
    shorter than that always form a group of their own. Duplicate inputs are
    ignored.

    A group with at least ``min_group_size`` members becomes
    ``re.escape(<group common prefix>) + '.*'``. The members of smaller groups
    are emitted as individually escaped literals.

    Returns the patterns sorted lexicographically.
    """
    # "Common prefix >= k" is an equivalence relation (same first k chars),
    # so one bucketing pass replaces the former pairwise O(n^2) scan.
    k = max(min_prefix_len, 0)
    groups: Dict[tuple, List[str]] = {}
    for name in set(config_types):
        key = (True, name[:k]) if len(name) >= k else (False, name)
        groups.setdefault(key, []).append(name)

    patterns: List[str] = []
    for members in groups.values():
        if len(members) >= min_group_size:
            patterns.append(re.escape(longest_common_prefix(members)) + '.*')
        else:
            patterns.extend(re.escape(m) for m in members)

    return sorted(patterns)
57
def atomic_baseline_cost(param_population: np.ndarray,
                         energies: np.ndarray,
                         element_counts: np.ndarray,
                         target_energies: np.ndarray) -> np.ndarray:
    """Mean squared error of baseline-shifted energies, one value per candidate.

    Each row of ``param_population`` is a candidate vector of per-element
    reference energies. For candidate ``k`` the shifted energies are
    ``energies - element_counts @ param_population[k]``. The cost is their
    mean squared difference from ``target_energies``.

    Returns a column array with one cost per candidate.
    """
    baseline = param_population @ element_counts.T
    residual = (energies[np.newaxis, :] - baseline) - target_energies[np.newaxis, :]
    return np.mean(np.square(residual), axis=1)[:, np.newaxis]
65
+
66
@utils.timeit
def nes_optimize_atomic_baseline(num_variables: int,
                                 max_generations: int,
                                 energies: np.ndarray,
                                 element_counts: np.ndarray,
                                 targets: np.ndarray,
                                 pop_size: int = 40,
                                 tol: float = 1e-8,
                                 seed: int = 42,
                                 print_every: int = 100) -> np.ndarray:
    """Fit per-element reference energies with a separable NES optimizer.

    Searches for a vector ``ref`` (one entry per column of
    ``element_counts``) that minimises :func:`atomic_baseline_cost`, i.e. the
    MSE between ``energies - element_counts @ ref`` and ``targets``.

    Parameters
    ----------
    num_variables
        Length of the reference vector (number of element columns).
    max_generations
        Upper bound on optimizer iterations; must be >= 1.
    energies, element_counts, targets
        Forwarded to :func:`atomic_baseline_cost`.
    pop_size
        Candidates sampled per generation; must be >= 1.
    tol
        Stop early once the best cost changes by less than ``tol`` between
        two consecutive generations.
    seed
        Seed of a private random stream. The draws are identical to seeding
        NumPy's global generator, but global RNG state is left untouched.
    print_every
        Accepted for API compatibility; currently unused.

    Returns
    -------
    np.ndarray
        Best candidate of the last generation run, shape ``(num_variables,)``.

    Raises
    ------
    ValueError
        If ``max_generations`` or ``pop_size`` is smaller than 1.
    """
    if max_generations < 1:
        raise ValueError("max_generations must be at least 1")
    if pop_size < 1:
        raise ValueError("pop_size must be at least 1")
    if num_variables == 0:
        # Nothing to fit; this also avoids log(0) / 0 in the step-size formula.
        return np.zeros(0)

    # A private legacy RandomState reproduces np.random.seed(seed) + rand/randn
    # exactly, without mutating the global generator used by other code.
    rng = np.random.RandomState(seed)

    mean = -1 * rng.rand(1, num_variables)
    stddev = 0.1 * np.ones((1, num_variables))
    lr_mean = 1.0
    lr_std = (3 + np.log(num_variables)) / (5 * np.sqrt(num_variables)) / 2
    # Rank-based utilities: the better half of the population gets positive
    # weight, and the weights sum to zero after the baseline subtraction.
    weights = np.maximum(0, np.log(pop_size / 2 + 1) - np.log(np.arange(1, pop_size + 1)))
    weights = weights / np.sum(weights) - 1 / pop_size

    best_params = None
    prev_best = None
    for _ in range(max_generations):
        z = rng.randn(pop_size, num_variables)
        pop = mean + stddev * z
        fitness = atomic_baseline_cost(pop, energies, element_counts, targets).ravel()
        order = np.argsort(fitness)
        z = z[order, :]
        pop = pop[order, :]
        best = fitness[order[0]]
        best_params = pop[0, :].copy()
        mean += lr_mean * stddev * (weights @ z)
        stddev *= np.exp(lr_std * (weights @ (z ** 2 - 1)))
        if prev_best is not None and abs(best - prev_best) < tol:
            break
        prev_best = best
    return best_params
105
+
106
+
107
+
108
def _map_config_types_to_groups(config_types, group_patterns) -> Dict[str, str]:
    """Map each config type to the first pattern in ``group_patterns`` that matches it, else to itself."""
    config_to_group: Dict[str, str] = {}
    for pat in group_patterns or []:
        try:
            regex = re.compile(pat)
        except re.error:
            # Invalid user-entered patterns are skipped on purpose (best effort).
            continue
        for ct in config_types:
            # re.match anchors at the start of the config type only.
            if ct not in config_to_group and regex.match(ct):
                config_to_group[ct] = pat
    for ct in config_types:
        config_to_group.setdefault(ct, ct)
    return config_to_group


def shift_dataset_energy(
        structures: List[Structure],
        reference_structures: List[Structure] | None,
        max_generations: int = 100000,
        population_size: int = 40,
        convergence_tol: float = 1e-8,
        random_seed: int = 42,
        group_patterns: List[str] | None = None,
        alignment_mode: str = REF_GROUP_ALIGNMENT,
        nep_energy_array: np.ndarray | None = None):
    """Shift structure energies in place by fitted per-element baselines.

    This is a generator: nothing happens until it is iterated.
    - It yields ``1`` after each config-type group has been fitted, so a UI
      can advance a progress indicator.
    - It then subtracts ``element_counts @ atomic_ref`` from ``s.energy`` of
      every structure in place.
    - Its return value (``StopIteration.value``) is the
      ``{group: atomic_ref}`` mapping.

    Grouping works on each structure's ``Config_type`` field. A config type
    matched by one of ``group_patterns`` (``re.match``, first pattern wins) is
    grouped under that pattern; otherwise it forms its own group.

    Parameters
    ----------
    structures
        Structures whose energies will be shifted.
    reference_structures
        Structures whose mean energy is the target for
        ``REF_GROUP_ALIGNMENT``. Required (non-empty) in that mode.
    max_generations, population_size, convergence_tol, random_seed
        Optimizer settings forwarded to ``nes_optimize_atomic_baseline``.
    group_patterns
        Optional regex patterns used to merge config types into groups.
    alignment_mode
        One of ``REF_GROUP_ALIGNMENT`` (target = mean reference energy),
        ``ZERO_BASELINE_ALIGNMENT`` (target = 0) or ``DFT_TO_NEP_ALIGNMENT``
        (target = ``nep_energy_array[i] * n_atoms``).
    nep_energy_array
        One NEP energy per structure, required for ``DFT_TO_NEP_ALIGNMENT``.
        It is multiplied by the atom count, so it is presumably per-atom;
        confirm against the caller.

    Raises
    ------
    ValueError
        On the first iteration, if ``alignment_mode`` is unknown, a required
        input is missing, or ``nep_energy_array`` does not have one entry per
        structure.
    """
    if alignment_mode not in (REF_GROUP_ALIGNMENT, ZERO_BASELINE_ALIGNMENT, DFT_TO_NEP_ALIGNMENT):
        raise ValueError(f"Unknown alignment_mode: {alignment_mode!r}")

    frames = [
        {
            "energy": float(s.energy),
            "config_type": str(s.additional_fields.get("Config_type", "")),
            "elem_counts": Counter(s.elements),
        }
        for s in structures
    ]
    all_elements = sorted({e for f in frames for e in f["elem_counts"]})
    num_elements = len(all_elements)

    ref_mean = None
    if alignment_mode == REF_GROUP_ALIGNMENT:
        # Explicit None check: len(None) raised TypeError, and ``not`` would
        # be ambiguous for numpy object arrays.
        if reference_structures is None or len(reference_structures) == 0:
            raise ValueError("reference_structures is required for REF_GROUP_ALIGNMENT")
        ref_mean = np.mean(np.array([r.energy for r in reference_structures]))
    elif alignment_mode == DFT_TO_NEP_ALIGNMENT:
        if nep_energy_array is None:
            raise ValueError("nep_energy_array is required for DFT_TO_NEP_ALIGNMENT")
        nep_values = np.asarray(nep_energy_array, dtype=float).ravel()
        # zip() would silently misalign/truncate on a length mismatch.
        if nep_values.shape[0] != len(frames):
            raise ValueError(
                f"nep_energy_array has {nep_values.shape[0]} entries "
                f"but {len(frames)} structures were given"
            )
        for f, e in zip(frames, nep_values):
            # sum(values()) instead of Counter.total(), which needs Python 3.10+.
            f["nep_energy"] = e * sum(f["elem_counts"].values())

    config_to_group = _map_config_types_to_groups({f["config_type"] for f in frames}, group_patterns)

    group_to_atomic_ref: Dict[str, np.ndarray] = {}
    for group in sorted(set(config_to_group.values())):
        grp_frames = [f for f in frames if config_to_group[f["config_type"]] == group]
        if not grp_frames:
            continue
        energies = np.array([f["energy"] for f in grp_frames])
        counts = np.array([[f["elem_counts"].get(e, 0) for e in all_elements] for f in grp_frames], dtype=float)

        if alignment_mode == REF_GROUP_ALIGNMENT:
            targets = np.full_like(energies, ref_mean)
        elif alignment_mode == ZERO_BASELINE_ALIGNMENT:
            targets = np.zeros_like(energies)
        else:  # DFT_TO_NEP_ALIGNMENT
            targets = np.array([f["nep_energy"] for f in grp_frames])

        group_to_atomic_ref[group] = nes_optimize_atomic_baseline(
            num_elements,
            max_generations,
            energies,
            counts,
            targets,
            pop_size=population_size,
            tol=convergence_tol,
            seed=random_seed,
            print_every=100,
        )
        # Yield once per fitted group so the caller can update the UI.
        yield 1

    # Apply the fitted baselines in place.
    for s, frame in zip(structures, frames):
        atomic_ref = group_to_atomic_ref.get(config_to_group[frame["config_type"]])
        if atomic_ref is None:
            continue
        count_vec = np.array([frame["elem_counts"].get(e, 0) for e in all_elements], dtype=float)
        s.energy = frame["energy"] - np.dot(count_vec, atomic_ref)

    return group_to_atomic_ref
@@ -111,6 +111,9 @@ class NepData:
111
111
  """
112
112
  return self.data.now_data
113
113
  @property
114
def now_indices(self):
    """Indices of the current entries, as exposed by the wrapped ``self.data`` container."""
    container = self.data
    return container.now_indices
116
+ @property
114
117
  def all_data(self):
115
118
  return self.data.all_data
116
119
 
@@ -155,6 +155,17 @@ class ResultData(QObject):
155
155
  self.select_index.remove(i)
156
156
 
157
157
  self.updateInfoSignal.emit()
158
+
159
def inverse_select(self):
    """Invert the current selection state of all active structures.

    Active indices come from ``self.structure.now_indices`` (the same accessor
    the canvas uses). Every selected index is unchecked, including selected
    indices that are no longer active. Every active index that was not
    selected is then selected. Calls that would receive an empty list are
    skipped.
    """
    active_indices = set(self.structure.now_indices.tolist())
    selected_indices = set(self.select_index)
    unselect = list(selected_indices)
    select = list(active_indices - selected_indices)
    if unselect:
        self.uncheck(unselect)
    if select:
        self.select(select)
158
169
  def export_selected_xyz(self,save_file_path):
159
170
  """
160
171
  导出当前选中的结构
@@ -166,7 +177,7 @@ class ResultData(QObject):
166
177
 
167
178
  index=self.structure.convert_index(index)
168
179
 
169
- for structure in self.structure.now_data[index]:
180
+ for structure in self.structure.all_data[index]:
170
181
  structure.write(f)
171
182
 
172
183
  MessageManager.send_info_message(f"File exported to: {save_file_path}")