NepTrainKit 2.0.6.dev57__tar.gz → 2.1.1.dev37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. neptrainkit-2.1.1.dev37/.gitignore +8 -0
  2. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/PKG-INFO +1 -1
  3. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/pyproject.toml +1 -1
  4. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/requirements.txt +3 -2
  5. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/_version.py +2 -2
  6. neptrainkit-2.1.1.dev37/src/NepTrainKit/core/canvas/pyqtgraph/__init__.py +11 -0
  7. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/canvas/pyqtgraph/canvas.py +108 -123
  8. {neptrainkit-2.0.6.dev57/src/NepTrainKit/core/views → neptrainkit-2.1.1.dev37/src/NepTrainKit/core/canvas/pyqtgraph}/structure.py +5 -74
  9. neptrainkit-2.1.1.dev37/src/NepTrainKit/core/canvas/vispy/__init__.py +12 -0
  10. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/canvas/vispy/canvas.py +26 -27
  11. neptrainkit-2.1.1.dev37/src/NepTrainKit/core/canvas/vispy/structure.py +386 -0
  12. neptrainkit-2.1.1.dev37/src/NepTrainKit/core/custom_widget/__init__.py +37 -0
  13. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/card_widget.py +2 -4
  14. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/input.py +5 -5
  15. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/io/base.py +51 -43
  16. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/io/nep.py +87 -48
  17. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/io/utils.py +3 -0
  18. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/message.py +7 -2
  19. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/pages/makedata.py +1 -1
  20. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/pages/show_nep.py +16 -6
  21. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/views/__init__.py +1 -1
  22. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/views/cards.py +172 -78
  23. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/views/nep.py +6 -5
  24. neptrainkit-2.1.1.dev37/src/NepTrainKit/core/views/structure.py +66 -0
  25. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/main.py +3 -2
  26. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/utils.py +1 -1
  27. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/PKG-INFO +1 -1
  28. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/SOURCES.txt +5 -0
  29. neptrainkit-2.1.1.dev37/tests/test_base.py +44 -0
  30. neptrainkit-2.0.6.dev57/src/NepTrainKit/core/custom_widget/__init__.py +0 -14
  31. neptrainkit-2.0.6.dev57/tests/test_base.py +0 -34
  32. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/.gitattributes +0 -0
  33. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/.readthedocs.yml +0 -0
  34. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/LICENSE +0 -0
  35. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/MANIFEST.in +0 -0
  36. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/README.md +0 -0
  37. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/setup.cfg +0 -0
  38. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/setup.py +0 -0
  39. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/Config/config.sqlite +0 -0
  40. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/Config/nep.json +0 -0
  41. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/Config/nep89.txt +0 -0
  42. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/Config/ptable.json +0 -0
  43. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/__init__.py +0 -0
  44. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/__init__.py +0 -0
  45. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/calculator.py +0 -0
  46. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/canvas/base/canvas.py +0 -0
  47. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/config.py +0 -0
  48. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/completer.py +0 -0
  49. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/dialog.py +0 -0
  50. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/docker.py +0 -0
  51. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/label.py +0 -0
  52. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/layout.py +0 -0
  53. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/search_widget.py +0 -0
  54. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/custom_widget/settingscard.py +0 -0
  55. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/io/__init__.py +0 -0
  56. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/io/select.py +0 -0
  57. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/pages/__init__.py +0 -0
  58. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/pages/settings.py +0 -0
  59. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/structure.py +0 -0
  60. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/types.py +0 -0
  61. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/update.py +0 -0
  62. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/core/views/toolbar.py +0 -0
  63. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/src_rc.py +0 -0
  64. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit/version.py +0 -0
  65. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/dependency_links.txt +0 -0
  66. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/entry_points.txt +0 -0
  67. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/not-zip-safe +0 -0
  68. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/requires.txt +0 -0
  69. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/NepTrainKit.egg-info/top_level.txt +0 -0
  70. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/nep_cpu/dftd3para.h +0 -0
  71. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/nep_cpu/nep.cpp +0 -0
  72. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/nep_cpu/nep.h +0 -0
  73. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/src/nep_cpu/nep_cpu.cpp +0 -0
  74. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/dipole/nep.txt +0 -0
  75. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/dipole/train.xyz +0 -0
  76. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/descriptor.npy +0 -0
  77. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/energy.npy +0 -0
  78. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/forces.npy +0 -0
  79. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/nep.txt +0 -0
  80. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/train.xyz +0 -0
  81. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/nep/virial.npy +0 -0
  82. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/polarizability/nep.txt +0 -0
  83. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/data/polarizability/train.xyz +0 -0
  84. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/test_calculator.py +0 -0
  85. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/test_nep.py +0 -0
  86. {neptrainkit-2.0.6.dev57 → neptrainkit-2.1.1.dev37}/tests/test_structure.py +0 -0
@@ -0,0 +1,8 @@
1
+ **/__pycache__/
2
+ /src/NepTrainKit.egg-info/
3
+ /build
4
+ /dist
5
+ /docs/build
6
+ .idea/
7
+ /src/NepTrainKit/?.py
8
+ /src/NepTrainKit/Log/*.log
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: NepTrainKit
3
- Version: 2.0.6.dev57
3
+ Version: 2.1.1.dev37
4
4
  Summary: NepTrainKit is a Python package for visualizing and manipulating training datasets for NEP.
5
5
  Author: Chen Cheng bing
6
6
  Author-email: Chen Cheng bing <1747193328@qq.com>
@@ -43,7 +43,7 @@ dependencies = [
43
43
  "numpy>=1.26.0",
44
44
  "vispy>=0.14.3",
45
45
  "ase",
46
- "scipy>=1.15.0"
46
+ "scipy>=1.15.0",
47
47
  ]
48
48
  [project.optional-dependencies]
49
49
  test = [
@@ -4,6 +4,7 @@ loguru==0.7.2
4
4
  requests==2.32.3
5
5
  pyqtgraph==0.13.7
6
6
  PyOpenGL==3.1.7
7
- vispy==0.14.3
7
+ vispy==0.15.2
8
8
  ase
9
- scipy==1.15.0
9
+ scipy==1.15.0
10
+
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '2.0.6.dev57'
21
- __version_tuple__ = version_tuple = (2, 0, 6, 'dev57')
20
+ __version__ = version = '2.1.1.dev37'
21
+ __version_tuple__ = version_tuple = (2, 1, 1, 'dev37')
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # @Time : 2025/5/20 14:20
4
+ # @Author : 兵
5
+ # @email : 1747193328@qq.com
6
+ import pyqtgraph as pg
7
+ pg.setConfigOption('background', 'w') # 设置背景为白色
8
+ pg.setConfigOption('foreground', 'k') # 设置前景元素为黑色(如坐标轴)
9
+ pg.setConfigOptions(antialias=False,useOpenGL=False)
10
+ from .structure import StructurePlotWidget
11
+ from .canvas import PyqtgraphCanvas
@@ -8,57 +8,55 @@
8
8
  from functools import partial
9
9
 
10
10
  import numpy as np
11
- import pyqtgraph as pg
11
+
12
12
  from PySide6.QtCore import Qt
13
- from pyqtgraph import GraphicsLayoutWidget, ScatterPlotItem, PlotItem, ViewBox, TextItem
13
+ from pyqtgraph import GraphicsLayoutWidget, ScatterPlotItem, PlotItem, ViewBox, TextItem
14
14
 
15
15
  from NepTrainKit import utils
16
16
  from NepTrainKit.core.types import Brushes, Pens
17
17
  from ..base.canvas import CanvasLayoutBase
18
18
  from ...io import NepTrainResultData
19
19
 
20
- pg.setConfigOption('background', 'w') # 设置背景为白色
21
- pg.setConfigOption('foreground', 'k') # 设置前景元素为黑色(如坐标轴)
22
- pg.setConfigOptions(antialias=False,useOpenGL=False)
23
20
 
24
21
  class MyPlotItem(PlotItem):
25
22
  """
26
23
  自定义Item 实例化即可创建一个axes
27
24
  """
25
+
28
26
  def __init__(self, **kwargs):
29
27
  super().__init__(**kwargs)
30
28
  self.disableAutoRange()
31
29
 
32
- self._scatter=ScatterPlotItem()
30
+ self._scatter = ScatterPlotItem()
33
31
  self.addItem(self._scatter)
34
32
 
35
- self.text = TextItem( color=(231, 63, 50))
33
+ self.text = TextItem(color=(231, 63, 50))
36
34
 
37
35
  self.addItem(self.text)
38
36
 
39
- self.current_point=ScatterPlotItem()
37
+ self.current_point = ScatterPlotItem()
40
38
  self.current_point.setZValue(100)
41
39
  if "title" in kwargs:
42
40
  self.setTitle(kwargs["title"])
41
+
43
42
  def scatter(self, *args, **kargs):
44
43
  self._scatter.setData(*args, **kargs)
45
- def set_current_point(self, x,y):
46
44
 
47
- self.current_point.setData( x, y,brush=Brushes.Current ,pen=Pens.Current,
48
- symbol='star',size=15 )
49
- if self.current_point not in self.items:
45
+ def set_current_point(self, x, y):
50
46
 
47
+ self.current_point.setData(x, y, brush=Brushes.Current, pen=Pens.Current,
48
+ symbol='star', size=15)
49
+ if self.current_point not in self.items:
51
50
  self.addItem(self.current_point)
52
51
 
53
-
54
52
  def add_diagonal(self):
55
53
 
56
54
  self.addLine(angle=45, pos=(0.5, 0.5), pen=Pens.Line)
57
55
 
58
- def item_clicked(self,scatter_item,items,event):
56
+ def item_clicked(self, scatter_item, items, event):
59
57
 
60
58
  if items.any():
61
- item=items[0]
59
+ item = items[0]
62
60
 
63
61
  self.structureIndexChanged.emit(item.data())
64
62
 
@@ -70,42 +68,41 @@ class MyPlotItem(PlotItem):
70
68
  def title(self, t):
71
69
  if t == self.title:
72
70
  return
73
- self.setTitle( t)
71
+ self.setTitle(t)
74
72
  if t != "descriptor":
75
73
  self.add_diagonal()
76
74
 
77
75
 
78
-
79
76
  class CombinedMeta(type(CanvasLayoutBase), type(GraphicsLayoutWidget)):
80
77
  pass
81
78
 
82
- class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedMeta):
79
+
80
+ class PyqtgraphCanvas(CanvasLayoutBase, GraphicsLayoutWidget, metaclass=CombinedMeta):
83
81
  """
84
82
  pyqtgraph 绘图类
85
83
  """
86
- def __init__(self,*args, **kwargs):
87
- GraphicsLayoutWidget.__init__(self,*args,**kwargs)
84
+
85
+ def __init__(self, *args, **kwargs):
86
+ GraphicsLayoutWidget.__init__(self, *args, **kwargs)
88
87
 
89
88
  CanvasLayoutBase.__init__(self)
90
- self.nep_result_data=None
91
- def set_nep_result_data(self,dataset):
92
- self.nep_result_data:NepTrainResultData=dataset
89
+ self.nep_result_data = None
93
90
 
91
+ def set_nep_result_data(self, dataset):
92
+ self.nep_result_data: NepTrainResultData = dataset
94
93
 
95
94
  def clear_axes(self):
96
95
  self.clear()
97
96
 
98
97
  super().clear_axes()
99
98
 
100
-
101
-
102
- def init_axes(self,axes_num ):
99
+ def init_axes(self, axes_num):
103
100
  self.clear_axes()
104
101
 
105
102
  for r in range(axes_num):
106
103
  plot = MyPlotItem(title="")
107
104
  self.addItem(plot)
108
- plot.getViewBox().mouseDoubleClickEvent = partial(self.view_on_double_clicked,plot=plot)
105
+ plot.getViewBox().mouseDoubleClickEvent = partial(self.view_on_double_clicked, plot=plot)
109
106
  plot.getViewBox().setMouseEnabled(False, False)
110
107
  self.axes_list.append(plot)
111
108
 
@@ -113,11 +110,11 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
113
110
 
114
111
  self.set_view_layout()
115
112
 
116
- def view_on_double_clicked(self,event,plot):
113
+ def view_on_double_clicked(self, event, plot):
117
114
  self.set_current_axes(plot)
118
115
 
119
116
  def set_view_layout(self):
120
- if len(self.axes_list)==0:
117
+ if len(self.axes_list) == 0:
121
118
  return
122
119
  if self.current_axes not in self.axes_list:
123
120
  self.set_current_axes(self.axes_list[0])
@@ -134,7 +131,6 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
134
131
  for col, factor in enumerate([3, 1]):
135
132
  self.ci.layout.setRowStretchFactor(col, factor)
136
133
 
137
-
138
134
  @utils.timeit
139
135
  def plot_nep_result(self):
140
136
  """
@@ -142,174 +138,165 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
142
138
  """
143
139
  self.nep_result_data.select_index.clear()
144
140
 
145
- for index,_dataset in enumerate(self.nep_result_data.dataset):
146
- plot=self.axes_list[index]
147
- plot.title= _dataset.title
148
- plot.scatter(_dataset.x,_dataset.y,data=_dataset.structure_index,
149
- brush=Brushes.get(_dataset.title.upper()) ,pen=Pens.get(_dataset.title.upper()),
150
- symbol='o',size=7,
141
+ for index, _dataset in enumerate(self.nep_result_data.dataset):
142
+ plot = self.axes_list[index]
143
+ plot.title = _dataset.title
144
+ plot.scatter(_dataset.x, _dataset.y, data=_dataset.structure_index,
145
+ brush=Brushes.get(_dataset.title.upper()), pen=Pens.get(_dataset.title.upper()),
146
+ symbol='o', size=7,
151
147
 
152
- )
148
+ )
153
149
  # 设置视图框更新模式
154
150
  self.auto_range(plot)
155
- if _dataset.group_array.num !=0:
156
- #更新结构
151
+ if _dataset.group_array.num != 0:
152
+ # 更新结构
157
153
  if self.structure_index not in _dataset.group_array.now_data:
158
- self.structure_index=_dataset.group_array.now_data[0]
154
+ self.structure_index = _dataset.group_array.now_data[0]
159
155
  self.structureIndexChanged.emit(self.structure_index)
160
156
 
161
157
  else:
162
158
  plot.set_current_point([], [])
163
159
 
164
-
165
160
  if _dataset.title not in ["descriptor"]:
166
- #
167
- pos=self.convert_pos(plot,(0 ,1))
168
- text=f"rmse: {_dataset.get_formart_rmse()}"
161
+ #
162
+ pos = self.convert_pos(plot, (0, 1))
163
+ text = f"rmse: {_dataset.get_formart_rmse()}"
169
164
  plot.text.setText(text)
170
165
  plot.text.setPos(*pos)
171
166
 
172
-
173
-
174
- def plot_current_point(self,structure_index):
167
+ def plot_current_point(self, structure_index):
175
168
  """
176
169
  鼠标点击后 在所有子图上绘制五角星标记当前点
177
170
  """
178
- self.structure_index=structure_index
179
-
180
- for plot in self.axes_list :
181
- dataset=self.get_axes_dataset(plot)
182
- array_index=dataset.convert_index(structure_index)
183
- if dataset.now_data.size!=0:
184
- data=dataset.now_data[array_index,: ]
185
- plot.set_current_point(data[:,dataset.cols:].flatten(),
171
+ self.structure_index = structure_index
172
+
173
+ for plot in self.axes_list:
174
+ dataset = self.get_axes_dataset(plot)
175
+ array_index = dataset.convert_index(structure_index)
176
+ if dataset.now_data.size != 0:
177
+ data = dataset.all_data[array_index, :]
178
+ plot.set_current_point(data[:, dataset.cols:].flatten(),
186
179
  data[:, :dataset.cols].flatten(),
187
180
  )
188
-
189
- def item_clicked(self,scatter_item,items,event):
181
+
182
+ def item_clicked(self, scatter_item, items, event):
190
183
 
191
184
  if items.any():
192
- item=items[0]
185
+ item = items[0]
193
186
 
194
187
  self.structureIndexChanged.emit(item.data())
195
188
 
189
+ def select_point_from_polygon(self, polygon_xy, reverse):
196
190
 
197
-
198
-
199
-
200
- def select_point_from_polygon(self,polygon_xy,reverse ):
201
-
202
- index=self.is_point_in_polygon(np.column_stack([self.current_axes._scatter.data["x"],self.current_axes._scatter.data["y"]]),polygon_xy)
191
+ index = self.is_point_in_polygon(
192
+ np.column_stack([self.current_axes._scatter.data["x"], self.current_axes._scatter.data["y"]]), polygon_xy)
203
193
  index = np.where(index)[0]
204
- select_index=self.current_axes._scatter.data[index]["data"].tolist()
205
- self.select_index(select_index,reverse)
206
-
194
+ select_index = self.current_axes._scatter.data[index]["data"].tolist()
195
+ self.select_index(select_index, reverse)
207
196
 
208
- def select_point(self,pos,reverse):
197
+ def select_point(self, pos, reverse):
209
198
  """
210
199
  鼠标单击选择结构
211
200
  """
212
- items=self.current_axes._scatter.pointsAt(pos)
201
+ items = self.current_axes._scatter.pointsAt(pos)
213
202
  if len(items):
214
- item=items[0]
215
- index=item.index()
216
- structure_index =item.data()
217
- self.select_index(structure_index,reverse)
218
-
219
-
220
-
221
- def update_scatter_color(self,structure_index,color=Brushes.Selected):
203
+ item = items[0]
204
+ index = item.index()
205
+ structure_index = item.data()
206
+ self.select_index(structure_index, reverse)
207
+ @utils.timeit
208
+ def update_scatter_color(self, structure_index, color=Brushes.Selected):
222
209
  """
223
210
  当结构点的状态发生变化的时候 通过该函数更改axes中散点的颜色
224
211
  """
225
212
 
226
- for i,plot in enumerate(self.axes_list):
213
+ for i, plot in enumerate(self.axes_list):
227
214
 
228
215
  if not plot._scatter:
229
216
  continue
230
- structure_index_set= set(structure_index)
217
+ structure_index_set = set(structure_index)
231
218
  index_list = [i for i, val in enumerate(plot._scatter.data["data"]) if val in structure_index_set]
232
219
 
233
- plot._scatter.data["brush"][index_list]= color
220
+ plot._scatter.data["brush"][index_list] = color
234
221
  plot._scatter.data['sourceRect'][index_list] = (0, 0, 0, 0)
235
222
 
223
+ plot._scatter.updateSpots()
236
224
 
237
- plot._scatter.updateSpots( )
238
-
239
- def convert_pos(self,plot,pos):
225
+ def convert_pos(self, plot, pos):
240
226
  view_range = plot.viewRange()
241
227
  x_range = view_range[0] # x轴范围 [xmin, xmax]
242
228
  y_range = view_range[1] # y轴范围 [ymin, ymax]
243
229
 
244
230
  # 将百分比位置转换为坐标
245
- x_percent = pos[0] # 50% 对应 x 轴中间
246
- y_percent = pos[1] # 20% 对应 y 轴上的某个位置
231
+ x_percent = pos[0] # 50% 对应 x 轴中间
232
+ y_percent = pos[1] # 20% 对应 y 轴上的某个位置
247
233
 
248
234
  x_pos = x_range[0] + x_percent * (x_range[1] - x_range[0]) # 根据百分比计算实际位置
249
235
  y_pos = y_range[0] + y_percent * (y_range[1] - y_range[0]) # 根据百分比计算实际位置
250
- return x_pos,y_pos
251
- def auto_range(self,plot=None):
236
+ return x_pos, y_pos
237
+
238
+ def auto_range(self, plot=None):
252
239
 
253
240
  if plot is None:
254
- plot=self.current_axes
241
+ plot = self.current_axes
255
242
  if plot:
256
243
 
257
244
  view = plot.getViewBox()
258
245
 
259
- x_range=[10000,-10000]
260
- y_range=[10000,-10000]
246
+ x_range = [10000, -10000]
247
+ y_range = [10000, -10000]
261
248
  for item in view.addedItems:
262
249
  if isinstance(item, ScatterPlotItem):
263
250
 
264
- x=item.data["x"]
265
- y=item.data["y"]
251
+ x = item.data["x"]
252
+ y = item.data["y"]
266
253
 
267
- x=x[x>-10000]
268
- y=y[y>-10000]
269
- if x.size==0:
270
- x_range=[0,1]
271
- y_range=[0,1]
254
+ x = x[x > -10000]
255
+ y = y[y > -10000]
256
+ if x.size == 0:
257
+ x_range = [0, 1]
258
+ y_range = [0, 1]
272
259
  continue
273
- x_min = np.min(x )
274
- x_max = np.max(x )
275
- y_min = np.min(y )
276
- y_max = np.max(y )
260
+ x_min = np.min(x)
261
+ x_max = np.max(x)
262
+ y_min = np.min(y)
263
+ y_max = np.max(y)
277
264
  if x_min < x_range[0]:
278
- x_range[0]=x_min
265
+ x_range[0] = x_min
279
266
  if x_max > x_range[1]:
280
- x_range[1]=x_max
281
- if y_min <y_range[0]:
282
- y_range[0]=y_min
267
+ x_range[1] = x_max
268
+ if y_min < y_range[0]:
269
+ y_range[0] = y_min
283
270
  if y_max > y_range[1]:
284
- y_range[1]=y_max
285
- if plot.title!="descriptor":
271
+ y_range[1] = y_max
272
+ if plot.title != "descriptor":
286
273
 
287
- real_range=(min(x_range[0],y_range[0]),max(x_range[1],y_range[1]))
288
- view.setRange(xRange=real_range,yRange=real_range)
274
+ real_range = (min(x_range[0], y_range[0]), max(x_range[1], y_range[1]))
275
+ view.setRange(xRange=real_range, yRange=real_range)
289
276
  else:
290
- view.setRange(xRange=x_range,yRange=y_range)
277
+ view.setRange(xRange=x_range, yRange=y_range)
291
278
 
292
279
  def pan(self, checked):
293
280
 
294
281
  if self.current_axes:
295
-
296
282
  self.current_axes.setMouseEnabled(checked, checked)
297
283
  self.current_axes.getViewBox().setMouseMode(ViewBox.PanMode)
284
+
298
285
  def pen(self, checked):
299
286
  if self.current_axes is None:
300
-
301
287
  return False
302
288
 
303
289
  if checked:
304
- self.draw_mode=True
290
+ self.draw_mode = True
305
291
  # 初始化鼠标状态和轨迹数据
306
292
  self.is_drawing = False
307
293
  self.x_data = []
308
294
  self.y_data = []
309
295
 
310
296
  else:
311
- self.draw_mode=False
297
+ self.draw_mode = False
312
298
  pass
299
+
313
300
  #
314
301
  def mousePressEvent(self, event):
315
302
  if not self.draw_mode:
@@ -323,19 +310,18 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
323
310
 
324
311
  self.curve.setData([], []) # 清空绘制线条,避免对角线
325
312
 
326
-
327
313
  def mouseReleaseEvent(self, event):
328
314
 
329
315
  if not self.draw_mode:
330
316
  return super().mouseReleaseEvent(event)
331
- if event.button() == Qt.MouseButton.LeftButton or event.button() == Qt.MouseButton.RightButton:
317
+ if event.button() == Qt.MouseButton.LeftButton or event.button() == Qt.MouseButton.RightButton:
332
318
  self.is_drawing = False
333
- reverse=event.button() == Qt.MouseButton.RightButton
319
+ reverse = event.button() == Qt.MouseButton.RightButton
334
320
  self.current_axes.removeItem(self.curve)
335
321
  # 创建鼠标轨迹的多边形
336
- if len(self.x_data)>2:
322
+ if len(self.x_data) > 2:
337
323
 
338
- self.select_point_from_polygon(np.column_stack((self.x_data, self.y_data)),reverse)
324
+ self.select_point_from_polygon(np.column_stack((self.x_data, self.y_data)), reverse)
339
325
  else:
340
326
  # 右键的话 选中单个点
341
327
  pass
@@ -343,10 +329,9 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
343
329
  mouse_point = self.current_axes.getViewBox().mapSceneToView(pos)
344
330
 
345
331
  x = mouse_point.x()
346
- self.select_point(mouse_point,reverse)
332
+ self.select_point(mouse_point, reverse)
347
333
  return
348
334
 
349
-
350
335
  def mouseMoveEvent(self, event):
351
336
  if not self.draw_mode:
352
337
  return super().mouseMoveEvent(event)
@@ -355,11 +340,11 @@ class PyqtgraphCanvas(CanvasLayoutBase,GraphicsLayoutWidget, metaclass=CombinedM
355
340
  pos = event.pos()
356
341
  if self.current_axes.sceneBoundingRect().contains(pos):
357
342
  # 将场景坐标转换为视图坐标
358
- mouse_point =self.current_axes.getViewBox().mapSceneToView(pos)
343
+ mouse_point = self.current_axes.getViewBox().mapSceneToView(pos)
359
344
  x, y = mouse_point.x(), mouse_point.y()
360
345
  # 记录轨迹数据
361
346
  self.x_data.append(x)
362
347
  self.y_data.append(y)
363
348
 
364
349
  # 更新绘图
365
- self.curve.setData(self.x_data, self.y_data)
350
+ self.curve.setData(self.x_data, self.y_data)
@@ -1,27 +1,21 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8 -*-
3
- # @Time : 2025/3/15 13:44
3
+ # @Time : 2025/5/19 20:42
4
4
  # @Author : 兵
5
5
  # @email : 1747193328@qq.com
6
-
7
- import numpy as np
8
6
  #TODO: 这个导入慢 后面看能不能优化
9
7
  import pyqtgraph as pg
10
8
  import pyqtgraph.opengl as gl
11
9
 
12
-
13
- from OpenGL.GL import * # noqa
10
+ import numpy as np
11
+ from OpenGL.GL import GL_PROJECTION, glLoadMatrixf, glMatrixMode
12
+ from NepTrainKit.core import Config
14
13
 
15
14
 
16
15
  from PySide6.QtCore import Qt
17
16
  from PySide6.QtGui import QColor,QMatrix4x4
18
- from PySide6.QtWidgets import QApplication, QWidget, QGridLayout, QSizePolicy
19
17
 
20
-
21
- from NepTrainKit.core import Config
22
- from NepTrainKit.core.structure import table_info, Structure
23
- from NepTrainKit import utils
24
- from qfluentwidgets import BodyLabel
18
+ from NepTrainKit.core.structure import table_info
25
19
 
26
20
 
27
21
  class StructurePlotWidget(gl.GLViewWidget):
@@ -399,66 +393,3 @@ class StructurePlotWidget(gl.GLViewWidget):
399
393
 
400
394
  # 示例:高亮第0个原子
401
395
  # self.highlight_atom(0)
402
-
403
-
404
- class StructureInfoWidget(QWidget):
405
- def __init__(self, parent=None):
406
- super(StructureInfoWidget, self).__init__(parent)
407
- self.init_ui()
408
- def init_ui(self):
409
- self._layout = QGridLayout(self) # 创建布局
410
- self._layout.setContentsMargins(0, 0, 0, 0) # 设置边距
411
- self._layout.setSpacing(0) # 设置间距
412
- self.setLayout(self._layout) # 设置布局
413
-
414
-
415
- self.atom_label = BodyLabel(self)
416
- self.atom_label.setText("Atoms:")
417
- self.atom_num_text = BodyLabel(self)
418
-
419
- self.formula_label = BodyLabel(self)
420
- self.formula_label.setText("Formula:")
421
- self.formula_text = BodyLabel(self)
422
- self.formula_text.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
423
- self.formula_text.setWordWrap(True)
424
-
425
- self.lattice_label=BodyLabel(self)
426
- self.lattice_label.setText("Lattice:")
427
- self.lattice_text = BodyLabel(self)
428
- self.lattice_text.setWordWrap(True)
429
-
430
- self.config_label = BodyLabel(self)
431
- self.config_label.setText("Config Type:")
432
- self.config_text = BodyLabel(self)
433
- self.config_text.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
434
-
435
-
436
- self._layout.addWidget(self.atom_label, 0,0,1,1)
437
- self._layout.addWidget(self.atom_num_text, 0, 1,1,3)
438
- self._layout.addWidget(self.formula_label, 1,0,1,1)
439
- self._layout.addWidget(self.formula_text, 1, 1,1,3)
440
-
441
-
442
- self._layout.addWidget(self.config_label, 2, 0,1,1)
443
- self._layout.addWidget(self.config_text, 2, 1,1,3)
444
-
445
- self._layout.addWidget(self.lattice_label, 3, 0,1,1)
446
- self._layout.addWidget(self.lattice_text, 3, 1,1,3)
447
-
448
- def show_structure_info(self, structure):
449
- pass
450
- self.atom_num_text.setText(str(len(structure )))
451
- self.formula_text.setText(structure.html_formula)
452
- self.lattice_text.setText(str(np.round(structure.lattice,3)))
453
- self.config_text.setText(structure.Config_type)
454
-
455
- if __name__ == '__main__':
456
- app = QApplication([])
457
- view = StructurePlotWidget()
458
- view.show()
459
- import time
460
- start = time.time()
461
- atoms = Structure.read_xyz("good.xyz")
462
- view.show_structure(atoms) # 修改为show_structure,与代码一致
463
- print(time.time() - start)
464
- QApplication.instance().exec_()
@@ -0,0 +1,12 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # @Time : 2025/5/20 14:18
4
+ # @Author : 兵
5
+ # @email : 1747193328@qq.com
6
+ from vispy import use
7
+ # 不要去掉
8
+ from vispy.app.backends import _pyside6
9
+ use("PySide6", "gl2")
10
+
11
+ from .structure import StructurePlotWidget
12
+ from .canvas import VispyCanvas