NepTrainKit 2.2.1__tar.gz → 2.2.2.dev23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/PKG-INFO +1 -1
  2. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/_version.py +2 -2
  3. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/base/canvas.py +10 -0
  4. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/__init__.py +4 -0
  5. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/completer.py +2 -2
  6. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/dialog.py +75 -1
  7. neptrainkit-2.2.2.dev23/src/NepTrainKit/core/energy_shift.py +216 -0
  8. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/base.py +3 -0
  9. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/nep.py +11 -0
  10. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/structure.py +7 -2
  11. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/update.py +4 -3
  12. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/nep.py +105 -5
  13. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/toolbar.py +16 -0
  14. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/src_rc.py +323 -110
  15. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/utils.py +43 -1
  16. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/PKG-INFO +1 -1
  17. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/SOURCES.txt +1 -0
  18. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.gitattributes +0 -0
  19. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.gitignore +0 -0
  20. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/.readthedocs.yml +0 -0
  21. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/LICENSE +0 -0
  22. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/MANIFEST.in +0 -0
  23. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/README.md +0 -0
  24. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/pyproject.toml +0 -0
  25. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/requirements.txt +0 -0
  26. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/setup.cfg +0 -0
  27. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/setup.py +0 -0
  28. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/config.sqlite +0 -0
  29. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep.json +0 -0
  30. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/nep89.txt +0 -0
  31. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/Config/ptable.json +0 -0
  32. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/__init__.py +0 -0
  33. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/__init__.py +0 -0
  34. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/calculator.py +0 -0
  35. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/__init__.py +0 -0
  36. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/canvas.py +0 -0
  37. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/pyqtgraph/structure.py +0 -0
  38. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/__init__.py +0 -0
  39. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/canvas.py +0 -0
  40. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/canvas/vispy/structure.py +0 -0
  41. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/config.py +0 -0
  42. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/card_widget.py +0 -0
  43. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/docker.py +0 -0
  44. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/doping_rule.py +0 -0
  45. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/input.py +0 -0
  46. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/label.py +0 -0
  47. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/layout.py +0 -0
  48. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/search_widget.py +0 -0
  49. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/settingscard.py +0 -0
  50. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/custom_widget/vacancy_rule.py +0 -0
  51. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/__init__.py +0 -0
  52. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/select.py +0 -0
  53. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/io/utils.py +0 -0
  54. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/message.py +0 -0
  55. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/__init__.py +0 -0
  56. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/makedata.py +0 -0
  57. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/settings.py +0 -0
  58. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/pages/show_nep.py +0 -0
  59. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/types.py +0 -0
  60. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/__init__.py +0 -0
  61. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/cards.py +0 -0
  62. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/core/views/structure.py +0 -0
  63. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/main.py +0 -0
  64. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit/version.py +0 -0
  65. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/dependency_links.txt +0 -0
  66. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/entry_points.txt +0 -0
  67. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/not-zip-safe +0 -0
  68. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/requires.txt +0 -0
  69. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/NepTrainKit.egg-info/top_level.txt +0 -0
  70. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/dftd3para.h +0 -0
  71. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.cpp +0 -0
  72. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep.h +0 -0
  73. {neptrainkit-2.2.1 → neptrainkit-2.2.2.dev23}/src/nep_cpu/nep_cpu.cpp +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: NepTrainKit
3
- Version: 2.2.1
3
+ Version: 2.2.2.dev23
4
4
  Summary: NepTrainKit is a Python package for visualizing and manipulating training datasets for NEP.
5
5
  Author: Chen Cheng bing
6
6
  Author-email: Chen Cheng bing <1747193328@qq.com>
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '2.2.1'
21
- __version_tuple__ = version_tuple = (2, 2, 1)
20
+ __version__ = version = '2.2.2.dev23'
21
+ __version_tuple__ = version_tuple = (2, 2, 2, 'dev23')
@@ -156,6 +156,16 @@ class CanvasLayoutBase(CanvasBase):
156
156
  self.nep_result_data.select(structure_index)
157
157
 
158
158
  self.update_scatter_color(structure_index, Brushes.Selected)
159
+
160
def inverse_select(self):
    """Invert the selection over the currently active structures.

    Indices present in ``nep_result_data.select_index`` are passed to
    ``self.select_index(..., True)``; active indices that are not selected are
    passed to ``self.select_index(..., False)``. NOTE(review): the boolean is
    presumably a "reverse"/deselect flag of ``select_index`` -- confirm.
    Does nothing when no result data is loaded.
    """
    result = self.nep_result_data
    if result is None:
        return

    active = set(result.structure.now_indices.tolist())
    currently_selected = set(result.select_index)

    to_unselect = list(currently_selected)
    to_select = list(active - currently_selected)
    self.select_index(to_unselect, True)
    self.select_index(to_select, False)
159
169
  class VispyCanvasLayoutBase(CanvasLayoutBase,QObject,metaclass=CombinedMeta):
160
170
  def __init__(self,*args,**kwargs):
161
171
  QObject.__init__(self)
@@ -9,6 +9,8 @@ from .completer import CompleterModel, JoinDelegate, ConfigCompleter
9
9
  from .dialog import (
10
10
  GetIntMessageBox,
11
11
  SparseMessageBox,
12
+ IndexSelectMessageBox,
13
+ ShiftEnergyMessageBox,
12
14
  ProgressDialog,
13
15
  PeriodicTableDialog,
14
16
  )
@@ -33,6 +35,8 @@ __all__ = [
33
35
  "ConfigCompleter",
34
36
  "GetIntMessageBox",
35
37
  "SparseMessageBox",
38
+ "IndexSelectMessageBox",
39
+ "ShiftEnergyMessageBox",
36
40
  "ProgressDialog",
37
41
  "PeriodicTableDialog",
38
42
  "SpinBoxUnitInputFrame",
@@ -33,7 +33,7 @@ class CompleterModel(QAbstractListModel):
33
33
  def data(self, index, role=Qt.ItemDataRole.DisplayRole):
34
34
 
35
35
  if not index.isValid():
36
- print("index")
36
+
37
37
  return None
38
38
  # print(role)
39
39
  # Qt.ItemDataRole.DisplayRole
@@ -44,7 +44,7 @@ class CompleterModel(QAbstractListModel):
44
44
  if role == Qt.ItemDataRole.DisplayRole or role == Qt.ItemDataRole.EditRole:
45
45
  return word
46
46
  elif role == CountRole:
47
- print(count)
47
+
48
48
  return str(count)
49
49
 
50
50
 
@@ -4,14 +4,16 @@
4
4
  # @Author : 兵
5
5
  # @email : 1747193328@qq.com
6
6
  from PySide6.QtGui import QIcon
7
- from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton
7
+ from PySide6.QtWidgets import QVBoxLayout, QFrame, QGridLayout, QPushButton, QLineEdit
8
8
  from PySide6.QtCore import Signal, Qt
9
9
  from qfluentwidgets import (
10
10
  MessageBoxBase,
11
11
  SpinBox,
12
12
  CaptionLabel,
13
13
  DoubleSpinBox,
14
+ CheckBox,
14
15
  ProgressBar,
16
+ ComboBox,
15
17
  FluentStyleSheet,
16
18
  FluentTitleBar,
17
19
  TitleLabel
@@ -73,6 +75,78 @@ class SparseMessageBox(MessageBoxBase):
73
75
 
74
76
  self.widget.setMinimumWidth(200)
75
77
 
78
+
79
class IndexSelectMessageBox(MessageBoxBase):
    """Dialog for selecting structures by index.

    Exposes two widgets to the caller:

    * ``indexEdit`` -- free-text field holding the index / slice expression.
    * ``checkBox`` -- "Use original indices", checked by default; the caller
      decides which index space the expression refers to based on it.
    """

    def __init__(self, parent=None, tip="Specify index or slice"):
        super().__init__(parent)

        label = CaptionLabel(tip, self)
        label.setWordWrap(True)
        self.titleLabel = label

        self.indexEdit = QLineEdit(self)

        origin_toggle = CheckBox("Use original indices", self)
        origin_toggle.setChecked(True)
        self.checkBox = origin_toggle

        # Stack the widgets vertically in the dialog body, top to bottom.
        for widget in (self.titleLabel, self.indexEdit, self.checkBox):
            self.viewLayout.addWidget(widget)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(200)
97
+
98
+
99
class ShiftEnergyMessageBox(MessageBoxBase):
    """Dialog for energy baseline shift parameters.

    Widgets read back by the caller:

    * ``groupEdit`` -- regex patterns for grouping ``Config_type`` values,
      separated by ``;`` (``,`` is avoided because it is a regex character,
      e.g. ``{1,3}``).
    * ``genSpinBox`` -- maximum number of optimizer generations.
    * ``sizeSpinBox`` -- optimizer population size.
    * ``tolSpinBox`` -- convergence tolerance.
    * ``modeCombo`` -- alignment mode name; defaults to ``DFT_TO_NEP_ALIGNMENT``.
    """

    def __init__(self, parent=None, tip="Group regex patterns (semicolon separated)"):
        super().__init__(parent)
        self.titleLabel = CaptionLabel(tip, self)
        self.titleLabel.setWordWrap(True)
        self.groupEdit = QLineEdit(self)

        self._frame = QFrame(self)
        self.frame_layout = QGridLayout(self._frame)
        self.frame_layout.setContentsMargins(0, 0, 0, 0)
        self.frame_layout.setSpacing(2)

        # The NES optimizer needs at least one generation (otherwise it has no
        # result to return) and at least two candidates per generation (with a
        # single candidate its rank weights are all zero, so it never moves).
        self.genSpinBox = SpinBox(self)
        self.genSpinBox.setRange(1, 100000000)
        self.sizeSpinBox = SpinBox(self)
        self.sizeSpinBox.setRange(2, 999999)
        self.tolSpinBox = DoubleSpinBox(self)
        self.tolSpinBox.setDecimals(8)
        self.tolSpinBox.setMinimum(0)
        self.modeCombo = ComboBox(self)
        self.modeCombo.addItems([
            "REF_GROUP_ALIGNMENT",
            "ZERO_BASELINE_ALIGNMENT",
            "DFT_TO_NEP_ALIGNMENT",
        ])
        self.modeCombo.setCurrentText("DFT_TO_NEP_ALIGNMENT")

        self.frame_layout.addWidget(CaptionLabel("Max generations", self), 0, 0)
        self.frame_layout.addWidget(self.genSpinBox, 0, 1)
        self.frame_layout.addWidget(CaptionLabel("Population size", self), 1, 0)
        self.frame_layout.addWidget(self.sizeSpinBox, 1, 1)
        self.frame_layout.addWidget(CaptionLabel("Convergence tol", self), 2, 0)
        self.frame_layout.addWidget(self.tolSpinBox, 2, 1)
        self.frame_layout.addWidget(CaptionLabel("Mode", self), 3, 0)
        self.frame_layout.addWidget(self.modeCombo, 3, 1)

        self.viewLayout.addWidget(self.titleLabel)
        self.viewLayout.addWidget(self.groupEdit)
        self.viewLayout.addWidget(self._frame)

        self.yesButton.setText('Ok')
        self.cancelButton.setText('Cancel')
        self.widget.setMinimumWidth(250)
146
+
147
+
148
+
149
+
76
150
  class ProgressDialog(FramelessDialog):
77
151
  """进度条弹窗"""
78
152
  def __init__(self,parent=None,title=""):
@@ -0,0 +1,216 @@
1
+ """Utilities for shifting structure energies using atomic baselines.
2
+ Adapted from Dr. Chen's code (used with permission).
3
+ url: https://github.com/brucefan1983/GPUMD/tree/master/tools/Analysis_and_Processing/energy-reference-aligner
4
+ Zherui Chen Email: <chenzherui0124@foxmail.com>
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+ from collections import Counter
11
+ from typing import List, Dict
12
+ import re
13
+ from NepTrainKit import utils
14
+ from .structure import Structure
15
+ from .calculator import NepCalculator
16
+
17
+ REF_GROUP_ALIGNMENT = "REF_GROUP_ALIGNMENT"
18
+ ZERO_BASELINE_ALIGNMENT = "ZERO_BASELINE_ALIGNMENT"
19
+ DFT_TO_NEP_ALIGNMENT = "DFT_TO_NEP_ALIGNMENT"
20
+
21
def longest_common_prefix(strs: List[str]) -> str:
    """Return the longest string that is a prefix of every element of *strs*.

    The lexicographically smallest and largest strings bound all the others,
    so comparing just those two is sufficient. Returns ``""`` for an empty
    input.
    """
    if not strs:
        return ""
    s1, s2 = min(strs), max(strs)
    for i, c in enumerate(s1):
        if c != s2[i]:
            return s1[:i]
    return s1


def suggest_group_patterns(config_types: List[str], min_group_size: int = 2, min_prefix_len: int = 3) -> List[str]:
    """Group strings by common prefix (no delimiter assumptions) and return regex patterns.

    Config types are visited in sorted order, so the result is deterministic
    (iterating a ``set`` would depend on the string hash seed). Each group is
    seeded by the first still-unassigned config type and collects every other
    unassigned type sharing at least ``min_prefix_len`` leading characters
    with it.

    * Groups with at least ``min_group_size`` members become
      ``re.escape(common_prefix) + ".*"``.
    * Smaller groups emit one exact pattern per member, anchored with ``$``.
      Without the anchor, ``re.match("ab", "abx")`` would succeed, and an empty
      config type would produce a pattern matching everything.

    Returns the patterns sorted alphabetically.
    """
    remaining = sorted(set(config_types))
    patterns: List[str] = []

    while remaining:
        base = remaining[0]
        group = [base]
        leftover = []
        for other in remaining[1:]:
            if len(longest_common_prefix([base, other])) >= min_prefix_len:
                group.append(other)
            else:
                leftover.append(other)
        remaining = leftover

        if len(group) >= min_group_size:
            patterns.append(re.escape(longest_common_prefix(group)) + ".*")
        else:
            patterns.extend(re.escape(g) + "$" for g in group)

    return sorted(patterns)
57
def atomic_baseline_cost(param_population: np.ndarray,
                         energies: np.ndarray,
                         element_counts: np.ndarray,
                         target_energies: np.ndarray) -> np.ndarray:
    """Score every candidate atomic baseline at once by mean-squared error.

    Parameters
    ----------
    param_population
        Candidate per-element reference energies, shape
        ``(n_candidates, n_elements)``.
    energies
        Energy of each structure, shape ``(n_structures,)``.
    element_counts
        Number of atoms of each element per structure, shape
        ``(n_structures, n_elements)``.
    target_energies
        Desired energy of each structure after subtracting the baseline,
        shape ``(n_structures,)``.

    Returns
    -------
    numpy.ndarray
        Column vector ``(n_candidates, 1)`` with the MSE of each candidate.
    """
    baseline = param_population @ element_counts.T  # (n_candidates, n_structures)
    residual = (energies[None, :] - baseline) - target_energies[None, :]
    per_candidate = np.square(residual).mean(axis=1)
    return per_candidate.reshape(-1, 1)
65
+
66
@utils.timeit
def nes_optimize_atomic_baseline(num_variables: int,
                                 max_generations: int,
                                 energies: np.ndarray,
                                 element_counts: np.ndarray,
                                 targets: np.ndarray,
                                 pop_size: int = 40,
                                 tol: float = 1e-8,
                                 seed: int = 42,
                                 print_every: int = 100) -> np.ndarray:
    """Fit per-element reference energies with a natural evolution strategy (NES).

    Searches for a vector ``b`` (one entry per element) that minimises
    ``atomic_baseline_cost``. That is the MSE between
    ``energies - element_counts @ b`` and ``targets``.

    Parameters
    ----------
    num_variables
        Number of elements, i.e. the length of the returned vector.
    max_generations
        Upper bound on the number of NES generations; must be >= 1.
    energies, element_counts, targets
        Forwarded unchanged to ``atomic_baseline_cost``.
    pop_size
        Number of candidates sampled per generation; must be >= 1.
    tol
        Stop early once the best cost of a generation differs from the
        previous generation's best by less than ``tol``.
    seed
        Seed for a private ``numpy.random.RandomState``. The global NumPy RNG
        is not reseeded.
    print_every
        Currently unused; kept for interface compatibility.

    Returns
    -------
    numpy.ndarray
        The lowest-cost candidate seen over all generations, shape
        ``(num_variables,)``.

    Raises
    ------
    ValueError
        If ``max_generations`` or ``pop_size`` is smaller than 1.
    """
    if max_generations < 1:
        raise ValueError("max_generations must be at least 1")
    if pop_size < 1:
        raise ValueError("pop_size must be at least 1")

    # A private RandomState yields the same stream as np.random.seed() followed
    # by np.random.rand/randn, without resetting global state for other code.
    rng = np.random.RandomState(seed)

    mean = -1 * rng.rand(1, num_variables)
    stddev = 0.1 * np.ones((1, num_variables))
    lr_mean = 1.0
    lr_std = (3 + np.log(num_variables)) / (5 * np.sqrt(num_variables)) / 2
    # Log-rank utility weights, centred so that they sum to zero.
    weights = np.maximum(0, np.log(pop_size / 2 + 1) - np.log(np.arange(1, pop_size + 1)))
    weights = weights / np.sum(weights) - 1 / pop_size

    best_cost = np.inf
    best_params = mean[0].copy()
    previous_gen_best = None
    for _ in range(max_generations):
        z = rng.randn(pop_size, num_variables)
        pop = mean + stddev * z
        fitness = atomic_baseline_cost(pop, energies, element_counts, targets).ravel()
        order = np.argsort(fitness)
        fitness, z, pop = fitness[order], z[order], pop[order]

        gen_best = fitness[0]
        # Track the best candidate ever seen rather than the last generation's.
        if gen_best < best_cost:
            best_cost = gen_best
            best_params = pop[0].copy()

        mean += lr_mean * stddev * (weights @ z)
        stddev *= np.exp(lr_std * (weights @ (z ** 2 - 1)))

        if previous_gen_best is not None and abs(gen_best - previous_gen_best) < tol:
            break
        previous_gen_best = gen_best

    return best_params
105
+
106
+
107
+
108
def shift_dataset_energy(
        structures: List[Structure],
        reference_structures: List[Structure] | None,
        max_generations: int = 100000,
        population_size: int = 40,
        convergence_tol: float = 1e-8,
        random_seed: int = 42,
        group_patterns: List[str] | None = None,
        alignment_mode: str = REF_GROUP_ALIGNMENT,
        nep_energy_array: np.ndarray | None = None):
    """Shift structure energies using different alignment strategies.

    This function is a generator. It yields ``1`` after fitting each group's
    atomic baseline, so a caller can drive a progress indicator. The energies
    of ``structures`` are rewritten in place (via ``Structure.energy``) only
    once the generator has been fully consumed. Validation errors are
    therefore raised on the first ``next()``, not at call time.

    Parameters
    ----------
    structures
        Structures whose energies will be shifted.
    reference_structures
        Structures whose mean energy is the target for every structure when
        ``alignment_mode`` is ``REF_GROUP_ALIGNMENT``; ignored otherwise.
    max_generations, population_size, convergence_tol, random_seed
        Forwarded to ``nes_optimize_atomic_baseline``.
    group_patterns
        Regular expressions matched with ``re.match`` (anchored at the start)
        against each structure's ``Config_type``. The first matching pattern
        assigns the group. Unmatched config types form a group of their own.
        Invalid patterns are skipped.
    alignment_mode
        One of ``REF_GROUP_ALIGNMENT``, ``ZERO_BASELINE_ALIGNMENT`` or
        ``DFT_TO_NEP_ALIGNMENT``.
    nep_energy_array
        Per-atom NEP energies, one per entry of ``structures`` (same order);
        required when ``alignment_mode`` is ``DFT_TO_NEP_ALIGNMENT``.

    Yields
    ------
    int
        ``1`` once per fitted group.

    Raises
    ------
    ValueError
        On an unknown ``alignment_mode``, or when the inputs that mode
        requires are missing or too short.
    """
    if alignment_mode not in (REF_GROUP_ALIGNMENT, ZERO_BASELINE_ALIGNMENT, DFT_TO_NEP_ALIGNMENT):
        raise ValueError(f"Unknown alignment_mode: {alignment_mode!r}")

    frames = []
    for s in structures:
        frames.append({
            "energy": float(s.energy),
            "config_type": str(s.additional_fields.get("Config_type", "")),
            "elem_counts": Counter(s.elements),
        })

    all_elements = sorted({e for f in frames for e in f["elem_counts"]})
    num_elements = len(all_elements)
    # Element-count vector of every frame, computed once for fitting and shifting.
    for f in frames:
        f["count_vec"] = np.array([f["elem_counts"].get(e, 0) for e in all_elements], dtype=float)

    ref_mean = None
    if alignment_mode == REF_GROUP_ALIGNMENT:
        # Explicit None/len checks: callers may pass a numpy object array,
        # whose truth value is ambiguous.
        if reference_structures is None or len(reference_structures) == 0:
            raise ValueError("reference_structures is required for REF_GROUP_ALIGNMENT")
        ref_energies = np.array([r.energy for r in reference_structures])
        ref_mean = np.mean(ref_energies)

    if alignment_mode == DFT_TO_NEP_ALIGNMENT:
        if nep_energy_array is None:
            raise ValueError("nep_energy_array is required for DFT_TO_NEP_ALIGNMENT")
        if len(nep_energy_array) < len(frames):
            raise ValueError("nep_energy_array must provide one energy per structure")
        for f, e in zip(frames, nep_energy_array):
            # Scale the per-atom NEP energy to a total energy comparable with f["energy"].
            f["nep_energy"] = e * sum(f["elem_counts"].values())

    all_config_types = {f["config_type"] for f in frames}

    # Map each config_type to its group name (the first matching pattern).
    config_to_group: Dict[str, str] = {}
    for pat in group_patterns or []:
        try:
            regex = re.compile(pat)
        except re.error:
            # Best effort: ignore a malformed user pattern instead of aborting.
            continue
        for ct in all_config_types:
            if ct not in config_to_group and regex.match(ct):
                config_to_group[ct] = pat
    for ct in all_config_types:
        config_to_group.setdefault(ct, ct)

    # Bucket frames by group in a single pass (frame order is preserved).
    frames_by_group: Dict[str, list] = {}
    for f in frames:
        frames_by_group.setdefault(config_to_group[f["config_type"]], []).append(f)

    group_to_atomic_ref = {}
    for group in sorted(frames_by_group):
        grp_frames = frames_by_group[group]
        energies = np.array([f["energy"] for f in grp_frames])
        counts = np.array([f["count_vec"] for f in grp_frames], dtype=float)

        if alignment_mode == REF_GROUP_ALIGNMENT:
            targets = np.full_like(energies, ref_mean)
        elif alignment_mode == ZERO_BASELINE_ALIGNMENT:
            targets = np.zeros_like(energies)
        else:  # DFT_TO_NEP_ALIGNMENT
            targets = np.array([f["nep_energy"] for f in grp_frames])

        group_to_atomic_ref[group] = nes_optimize_atomic_baseline(
            num_elements,
            max_generations,
            energies,
            counts,
            targets,
            pop_size=population_size,
            tol=convergence_tol,
            seed=random_seed,
            print_every=100,
        )
        # Yield after each group so the caller can update its progress UI.
        yield 1

    # Apply the fitted baselines: E_new = E - counts . atomic_ref
    for s, f in zip(structures, frames):
        group = config_to_group[f["config_type"]]
        if group in group_to_atomic_ref:
            s.energy = f["energy"] - np.dot(f["count_vec"], group_to_atomic_ref[group])
@@ -111,6 +111,9 @@ class NepData:
111
111
  """
112
112
  return self.data.now_data
113
113
  @property
114
+ def now_indices(self):
115
+ return self.data.now_indices
116
+ @property
114
117
  def all_data(self):
115
118
  return self.data.all_data
116
119
 
@@ -155,6 +155,17 @@ class ResultData(QObject):
155
155
  self.select_index.remove(i)
156
156
 
157
157
  self.updateInfoSignal.emit()
158
+
159
def inverse_select(self):
    """Invert the current selection state of all active structures.

    Every selected index is unchecked, and every active but unselected
    index is selected. A call is skipped when its index list is empty.
    Active indices come from the ``now_indices`` property of
    ``self.structure``, the same accessor the canvas uses.
    """
    active_indices = set(self.structure.now_indices.tolist())
    selected_indices = set(self.select_index)

    to_unselect = list(selected_indices)
    to_select = list(active_indices - selected_indices)
    if to_unselect:
        self.uncheck(to_unselect)
    if to_select:
        self.select(to_select)
158
169
  def export_selected_xyz(self,save_file_path):
159
170
  """
160
171
  导出当前选中的结构
@@ -116,8 +116,13 @@ class Structure:
116
116
 
117
117
  @property
118
118
  def per_atom_energy(self):
119
- return self.additional_fields["energy"]/self.num_atoms
120
-
119
+ return self.energy/self.num_atoms
120
+ @property
121
+ def energy(self):
122
+ return self.additional_fields["energy"]
123
+ @energy.setter
124
+ def energy(self,new_energy):
125
+ self.additional_fields["energy"] = new_energy
121
126
  @property
122
127
  def forces(self):
123
128
  return self.structure_info[self.force_label]
@@ -72,9 +72,7 @@ class UpdateWoker( QObject):
72
72
  if version_info['tag_name'][1:] == __version__:
73
73
  MessageManager.send_success_message("You are already using the latest version!")
74
74
  return
75
- if not is_nuitka_compiled:
76
- MessageManager.send_info_message("You can update via pip install NepTrainKit -U")
77
- return
75
+
78
76
  box = MessageBox("New version detected:" + version_info["name"] + version_info["tag_name"],
79
77
  version_info["body"],
80
78
  self._parent
@@ -92,6 +90,9 @@ class UpdateWoker( QObject):
92
90
  MessageManager.send_warning_message("No update package available for your system. Please download it manually!")
93
91
 
94
92
  def check_update(self):
93
+ if not is_nuitka_compiled:
94
+ MessageManager.send_info_message("You can update via pip install NepTrainKit -U --pre")
95
+ return
95
96
  self.update_thread.start_work(self._check_update)
96
97
 
97
98
 
@@ -3,17 +3,28 @@
3
3
  # @Time : 2024/10/20 22:22
4
4
  # @Author : 兵
5
5
  # @email : 1747193328@qq.com
6
+ import os
6
7
  import time
8
+ import traceback
9
+
10
+ from loguru import logger
11
+
7
12
  start=time.time()
8
13
  import numpy as np
9
14
  from PySide6.QtWidgets import QHBoxLayout, QWidget, QProgressDialog
10
15
 
11
16
 
12
- from NepTrainKit import utils
17
+ from NepTrainKit import utils, module_path
13
18
  from NepTrainKit.core import MessageManager, Config
14
- from NepTrainKit.core.custom_widget import GetIntMessageBox, SparseMessageBox
19
+ from NepTrainKit.core.custom_widget import (
20
+ GetIntMessageBox,
21
+ SparseMessageBox,
22
+ IndexSelectMessageBox,
23
+ ShiftEnergyMessageBox,
24
+ )
15
25
  from NepTrainKit.core.io.select import farthest_point_sampling
16
26
  from NepTrainKit.core.views.toolbar import NepDisplayGraphicsToolBar
27
+ from NepTrainKit.core.energy_shift import shift_dataset_energy, suggest_group_patterns
17
28
 
18
29
 
19
30
  class NepResultPlotWidget(QWidget):
@@ -63,6 +74,9 @@ class NepResultPlotWidget(QWidget):
63
74
  self.tool_bar.findMaxSignal.connect(self.find_max_error_point)
64
75
  self.tool_bar.discoverySignal.connect(self.find_non_physical_structures)
65
76
  self.tool_bar.sparseSignal.connect(self.sparse_point)
77
+ self.tool_bar.shiftEnergySignal.connect(self.shift_energy_baseline)
78
+ self.tool_bar.inverseSignal.connect(self.inverse_select)
79
+ self.tool_bar.selectIndexSignal.connect(self.select_by_index)
66
80
  self.canvas.tool_bar=self.tool_bar
67
81
 
68
82
 
@@ -142,11 +156,11 @@ class NepResultPlotWidget(QWidget):
142
156
  remaining_indices = farthest_point_sampling(dataset.now_data,n_samples=n_samples,min_dist=distance)
143
157
 
144
158
  # 获取所有索引(从 0 到 len(arr)-1)
145
- all_indices = np.arange(dataset.now_data.shape[0])
159
+ # all_indices = np.arange(dataset.now_data.shape[0])
146
160
 
147
161
  # 使用 setdiff1d 获取不在 indices_to_remove 中的索引
148
- remove_indices = np.setdiff1d(all_indices, remaining_indices)
149
- structures = dataset.group_array[remove_indices]
162
+ # remove_indices = np.setdiff1d(all_indices, remaining_indices)
163
+ structures = dataset.group_array[remaining_indices]
150
164
  self.canvas.select_index(structures.tolist(),False)
151
165
 
152
166
  def export_descriptor_data(self):
@@ -175,6 +189,92 @@ class NepResultPlotWidget(QWidget):
175
189
  with open(path, "w") as f:
176
190
  np.savetxt(f,descriptor_data,fmt='%.6g',delimiter='\t')
177
191
 
192
+
193
def shift_energy_baseline(self):
    """Ask for energy-alignment settings and shift the dataset energies.

    1. Pre-fill ``ShiftEnergyMessageBox`` with the persisted optimizer settings
       and with group patterns suggested from the dataset's config types.
    2. Persist the accepted settings and validate that the chosen mode has
       its inputs. ``REF_GROUP_ALIGNMENT`` uses the currently selected
       structures as reference. ``DFT_TO_NEP_ALIGNMENT`` needs NEP energy data.
    3. Run ``shift_dataset_energy`` on the active structures in a
       ``utils.LoadingThread`` while a progress dialog is shown.
    4. Copy the new per-atom energies into the energy data and redraw.
    """
    data = self.canvas.nep_result_data
    if data is None:
        return
    ref_index = list(data.select_index)

    max_generations = Config.getint("widget", "max_generation_value", 100000)
    population_size = Config.getint("widget", "population_size", 40)
    convergence_tol = Config.getfloat("widget", "convergence_tol", 1e-8)
    config_types = set(data.structure.get_all_config())
    suggested = suggest_group_patterns(list(config_types))

    box = ShiftEnergyMessageBox(
        self._parent,
        "Specify regex groups for Config_type (semicolon separated)"
    )
    # Patterns are joined/split on ';' because ',' is itself a regex character (e.g. {1,3}).
    box.groupEdit.setText(";".join(suggested))
    box.genSpinBox.setValue(max_generations)
    box.sizeSpinBox.setValue(population_size)
    box.tolSpinBox.setValue(convergence_tol)

    if not box.exec():
        return

    pattern_text = box.groupEdit.text().strip()
    group_patterns = [p.strip() for p in pattern_text.split(";") if p.strip()]
    alignment_mode = box.modeCombo.currentText()

    max_generations = box.genSpinBox.value()
    population_size = box.sizeSpinBox.value()
    convergence_tol = box.tolSpinBox.value()
    Config.set("widget", "max_generation_value", max_generations)
    Config.set("widget", "population_size", population_size)
    Config.set("widget", "convergence_tol", convergence_tol)

    # ``data`` may not carry an ``energy`` attribute, so look it up with a default.
    energy = getattr(data, "energy", None)

    # Validate up front; otherwise the failure only surfaces inside the worker thread.
    if alignment_mode == "REF_GROUP_ALIGNMENT" and not ref_index:
        MessageManager.send_info_message("Please select reference structures for REF_GROUP_ALIGNMENT!")
        return
    if alignment_mode == "DFT_TO_NEP_ALIGNMENT" and (energy is None or energy.num == 0):
        MessageManager.send_info_message("No NEP energy data available for DFT_TO_NEP_ALIGNMENT!")
        return

    # The worker yields once per fitted group. Groups are formed from config
    # types, so len(config_types) serves as the progress maximum.
    progress_diag = QProgressDialog("", "Cancel", 0, len(config_types), self._parent)
    thread = utils.LoadingThread(self._parent, show_tip=False)
    progress_diag.setFixedSize(300, 100)
    progress_diag.setWindowTitle("Shift energies")
    thread.progressSignal.connect(progress_diag.setValue)
    thread.finished.connect(progress_diag.accept)
    progress_diag.canceled.connect(thread.stop_work)  # stop the worker when the user cancels
    thread.start_work(
        shift_dataset_energy,
        structures=data.structure.now_data,
        reference_structures=data.structure.all_data[ref_index],
        max_generations=max_generations,
        population_size=population_size,
        convergence_tol=convergence_tol,
        group_patterns=group_patterns,
        alignment_mode=alignment_mode,
        nep_energy_array=energy.y if energy is not None else None,
    )
    progress_diag.exec()

    if energy is not None and energy.num != 0:
        for i, structure in enumerate(data.structure.all_data):
            energy.data._data[i, 1] = structure.per_atom_energy
    self.canvas.plot_nep_result()
257
+
258
def inverse_select(self):
    """Toolbar slot: delegate selection inversion to the canvas."""
    canvas = self.canvas
    canvas.inverse_select()
260
+
261
def select_by_index(self):
    """Prompt for an index/slice expression and select the matching structures.

    With "Use original indices" checked, the expression is parsed against the
    length of ``all_data`` and the resulting indices are used directly.
    Otherwise it is parsed against ``now_data``, and the positions are mapped
    through ``group_array.now_data`` to original indices before selecting.
    """
    if self.canvas.nep_result_data is None:
        return

    box = IndexSelectMessageBox(self._parent, "Select structures by index")
    if not box.exec():
        return

    expression = box.indexEdit.text().strip()
    use_origin = box.checkBox.isChecked()
    dataset = self.canvas.nep_result_data.structure

    if use_origin:
        total = dataset.all_data.shape[0]
    else:
        total = dataset.now_data.shape[0]
    indices = utils.parse_index_string(expression, total)
    if not indices:
        return

    if not use_origin:
        # Translate positions within the active view to original indices.
        indices = dataset.group_array.now_data[indices].tolist()
    self.canvas.select_index(indices, False)
277
+
178
278
  def set_dataset(self,dataset):
179
279
 
180
280
  if self.last_figure_num !=len(dataset.dataset):
@@ -42,6 +42,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
42
42
  deleteSignal=Signal()
43
43
  revokeSignal=Signal()
44
44
  exportSignal=Signal()
45
+ shiftEnergySignal=Signal()
46
+ inverseSignal=Signal()
47
+ selectIndexSignal=Signal()
45
48
 
46
49
  def init_actions(self):
47
50
  self.addButton("Reset View",QIcon(":/images/src/images/init.svg"),self.resetSignal)
@@ -50,6 +53,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
50
53
  self.pan,
51
54
  True
52
55
  )
56
+ self.addButton("Select by Index",
57
+ QIcon(":/images/src/images/index.svg"),
58
+ self.selectIndexSignal)
53
59
  find_max_action = self.addButton("Find Max Error Point",
54
60
  QIcon(":/images/src/images/find_max.svg"),
55
61
  self.findMaxSignal)
@@ -58,6 +64,8 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
58
64
  self.sparseSignal)
59
65
 
60
66
 
67
+
68
+
61
69
  pen_action=self.addButton("Mouse Selection",
62
70
  QIcon(":/images/src/images/pen.svg"),
63
71
  self.pen,
@@ -71,6 +79,9 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
71
79
  discovery_action = self.addButton("Finding non-physical structures",
72
80
  QIcon(":/images/src/images/discovery.svg"),
73
81
  self.discoverySignal)
82
+ inverse_action = self.addButton("Inverse Selection",
83
+ QIcon(":/images/src/images/inverse.svg"),
84
+ self.inverseSignal)
74
85
  revoke_action = self.addButton("Undo",
75
86
  QIcon(":/images/src/images/revoke.svg"),
76
87
  self.revokeSignal)
@@ -78,10 +89,15 @@ class NepDisplayGraphicsToolBar(KitToolBarBase):
78
89
  delete_action = self.addButton("Delete Selected Items",
79
90
  QIcon(":/images/src/images/delete.svg"),
80
91
  self.deleteSignal)
92
+
81
93
  self.addSeparator()
82
94
  export_action = self.addButton("Export structure descriptor",
83
95
  QIcon(":/images/src/images/export.svg"),
84
96
  self.exportSignal)
97
+ self.addSeparator()
98
+ self.addButton("Energy Baseline Shift",
99
+ QIcon(":/images/src/images/alignment.svg"),
100
+ self.shiftEnergySignal)
85
101
 
86
102
  def reset(self):
87
103
  if self.action_group.checkedAction():