surface-construct 0.8.4__tar.gz → 0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {surface_construct-0.8.4/surface_construct.egg-info → surface_construct-0.9}/PKG-INFO +3 -3
  2. {surface_construct-0.8.4 → surface_construct-0.9}/setup.py +1 -2
  3. surface_construct-0.9/surface_construct/__init__.py +12 -0
  4. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct/db.py +1 -1
  5. surface_construct-0.8.4/surface_construct/sampling.py → surface_construct-0.9/surface_construct/sg_sampler.py +81 -79
  6. {surface_construct-0.8.4 → surface_construct-0.9/surface_construct.egg-info}/PKG-INFO +3 -3
  7. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct.egg-info/SOURCES.txt +3 -8
  8. {surface_construct-0.8.4 → surface_construct-0.9}/tests/test_sampling1.py +6 -5
  9. {surface_construct-0.8.4 → surface_construct-0.9}/tests/test_sampling2.py +6 -6
  10. {surface_construct-0.8.4 → surface_construct-0.9}/tests/test_surface_grid.py +2 -2
  11. surface_construct-0.9/tests/test_task.py +83 -0
  12. surface_construct-0.8.4/surface_construct/__init__.py +0 -4
  13. surface_construct-0.8.4/surface_construct/atoms.py +0 -564
  14. surface_construct-0.8.4/surface_construct/structure.py +0 -536
  15. surface_construct-0.8.4/surface_construct/surface.py +0 -738
  16. surface_construct-0.8.4/surface_construct/surface_grid.py +0 -1115
  17. surface_construct-0.8.4/surface_construct/utils.py +0 -177
  18. surface_construct-0.8.4/surface_construct/weight_functions.py +0 -65
  19. {surface_construct-0.8.4 → surface_construct-0.9}/LICENSE +0 -0
  20. {surface_construct-0.8.4 → surface_construct-0.9}/README.md +0 -0
  21. {surface_construct-0.8.4 → surface_construct-0.9}/setup.cfg +0 -0
  22. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct/default_parameter.py +0 -0
  23. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct.egg-info/dependency_links.txt +0 -0
  24. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct.egg-info/requires.txt +0 -0
  25. {surface_construct-0.8.4 → surface_construct-0.9}/surface_construct.egg-info/top_level.txt +0 -0
@@ -1,13 +1,12 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: surface_construct
3
- Version: 0.8.4
3
+ Version: 0.9
4
4
  Summary: Surface termination construction especially for complex model, such as oxides or carbides.
5
5
  Home-page: https://gitee.com/pjren/surface_construct/
6
6
  Author: ren
7
7
  Author-email: 0403114076@163.com
8
8
  License: GPL
9
9
  Classifier: Programming Language :: Python :: 3
10
- Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
11
10
  Classifier: Operating System :: OS Independent
12
11
  Description-Content-Type: text/markdown
13
12
  License-File: LICENSE
@@ -25,6 +24,7 @@ Dynamic: description
25
24
  Dynamic: description-content-type
26
25
  Dynamic: home-page
27
26
  Dynamic: license
27
+ Dynamic: license-file
28
28
  Dynamic: requires-dist
29
29
  Dynamic: summary
30
30
 
@@ -15,7 +15,7 @@ install_requires = [
15
15
 
16
16
  setup(
17
17
  name='surface_construct',
18
- version='0.8.4',
18
+ version='0.9',
19
19
  packages=['surface_construct'],
20
20
  url='https://gitee.com/pjren/surface_construct/',
21
21
  license='GPL',
@@ -27,7 +27,6 @@ setup(
27
27
  install_requires=install_requires,
28
28
  classifiers=[
29
29
  "Programming Language :: Python :: 3",
30
- "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
31
30
  "Operating System :: OS Independent",
32
31
  ],
33
32
  )
@@ -0,0 +1,12 @@
1
+ from surface_construct.structures.surface import Crystal, Surface, Slab, Termination
2
+ from surface_construct.structures.surface import get_terminations_score
3
+ from surface_construct.structures.surface_grid import SurfaceGrid, GridGenerator
4
+
5
+
6
+ __all__ = ['SurfaceGrid',
7
+ 'GridGenerator',
8
+ 'Crystal',
9
+ 'Surface',
10
+ 'Slab',
11
+ 'Termination',
12
+ ]
@@ -6,7 +6,7 @@ from ase.db.core import now
6
6
  from ase.io.jsonio import write_json, read_json
7
7
  import os
8
8
  from ase.db.row import AtomsRow, row2dct
9
- from surface_construct.atoms import get_atoms_topo_id
9
+ from surface_construct.utils.atoms import get_atoms_topo_id
10
10
  from surface_construct.utils import calc_hull_vertices, get_calc_info
11
11
 
12
12
  """
@@ -1,36 +1,12 @@
1
- """
2
- TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
3
- """
4
- import itertools
5
1
  import numpy as np
6
- from ase.geometry import get_distances
7
- from scipy.spatial import ConvexHull, cKDTree
2
+ from scipy.spatial import cKDTree
8
3
  from scipy.spatial.distance import cdist
9
- from scipy.special import comb
10
4
  from sklearn.cluster import KMeans as Cluster
11
5
  import random
12
6
 
13
- from surface_construct.utils import furthest_sites
14
7
 
15
- MIN_HULL_ANGLE_COS = np.cos(np.pi * 30 / 180)
16
-
17
- def hull_vertices(hull):
18
- hsimplices = hull.simplices
19
- hvertices = hull.vertices
20
- hnorms = hull.equations[:,0:-1]
21
- ndim = hsimplices.shape[1]
22
- vertices = []
23
- # 去掉 hull 的 simplices 的角度较大的点
24
- for i in hvertices:
25
- p0_facets_idx = np.argwhere(hsimplices == i)[:,0]
26
- p0_norms = hnorms[p0_facets_idx]
27
- cosangle = lambda a,b: a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))
28
- # 求 i 凸点相邻的超平面的法向向量之间的夹角。如果存在夹角小于30度,即平面之间的夹角大于150度,则排除该点。反之,保留该点。
29
- norm_angle_cos = np.absolute([cosangle(a,b) for a,b in itertools.combinations(p0_norms, 2)])
30
- if np.sum(norm_angle_cos < MIN_HULL_ANGLE_COS) >= comb(ndim,2):
31
- vertices.append(i)
32
-
33
- return vertices
8
+ def name2sampler(name):
9
+ return globals()[name]
34
10
 
35
11
 
36
12
  def addition_samples(sg_obj, size=None, probability=None, **kwargs):
@@ -58,9 +34,9 @@ def addition_samples(sg_obj, size=None, probability=None, **kwargs):
58
34
  for method in method_list:
59
35
  method_lower = method.lower()
60
36
  if method_lower == 'max_sigma':
61
- sampling_obj = MaxSigmaSampling(sg_obj)
37
+ sampling_obj = MaxSigmaSGSampler(sg_obj)
62
38
  elif method_lower == 'max_diversity':
63
- sampling_obj = MaxDiversitySampling(sg_obj)
39
+ sampling_obj = MaxDiversitySGSampler(sg_obj)
64
40
  else:
65
41
  raise NotImplementedError
66
42
  point_idx = np.concatenate([point_idx, sampling_obj.samples(size=1, **kwargs)]) # 每种方法只采一个
@@ -68,10 +44,12 @@ def addition_samples(sg_obj, size=None, probability=None, **kwargs):
68
44
  return point_idx
69
45
 
70
46
 
71
- class SamplingBase:
47
+ class SGSamplerBase:
72
48
  def __init__(self, sg_obj, **kwargs):
73
49
  self.sg_obj = sg_obj
74
50
  self.threshold = kwargs.get('threshold', 0.37) # 0.37 is half of H-H bond
51
+ self.weight = kwargs.get('weight', 1.0) # 采样的几率,最后进行归一化处理
52
+ self.kwargs = kwargs
75
53
 
76
54
  @property
77
55
  def _pop_size(self):
@@ -91,29 +69,34 @@ class SamplingBase:
91
69
 
92
70
  :return:
93
71
  """
94
- if point_idx is not None:
95
- if self.sg_obj.sample_idx is not None:
96
- self.sg_obj.sample_idx = np.concatenate([self.sg_obj.sample_idx, point_idx])
97
- self.sg_obj._sample_vector = np.concatenate([self.sg_obj._sample_vector, self.sg_obj.vector[point_idx]])
98
- self.sg_obj.sample_points = np.concatenate([self.sg_obj.sample_points, self.sg_obj.points[point_idx]])
99
- else:
100
- self.sg_obj.sample_idx = np.array(point_idx)
101
- self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
102
- self.sg_obj.sample_points = self.sg_obj.points[point_idx]
72
+ if point_idx is None:
73
+ point_idx = []
74
+ elif type(point_idx) is int:
75
+ point_idx = [point_idx]
76
+
77
+ for p in point_idx:
78
+ self.sg_obj.sample_idx = p
103
79
 
104
80
  def _samples(self, size, **kwargs):
105
81
  raise NotImplementedError
106
82
 
107
83
  def samples(self, size=1, **kwargs):
108
- point_idx = self._samples(size=size, **kwargs)
109
- self._append_sample_to_sg(point_idx=point_idx)
110
-
111
- return point_idx
84
+ result = []
85
+ curr_size = size
86
+ loop = 0
87
+ while len(result) < size and loop < 10:
88
+ point_idx = self._samples(size=curr_size, **kwargs)
89
+ filtered_idx = self.exclude_too_close_sample(point_idx)
90
+ self._append_sample_to_sg(point_idx=filtered_idx)
91
+ result += filtered_idx
92
+ curr_size = size - len(filtered_idx)
93
+ loop += 1
94
+ return result
112
95
 
113
96
  def exclude_too_close_sample(self, idx_list, threshold=None):
114
97
  if threshold is None:
115
98
  threshold = self.threshold
116
- if self.sg_obj.sample_idx:
99
+ if self.sg_obj.sample_idx is not None:
117
100
  unique_idx_list = [i for i in idx_list if i not in self.sg_obj.sample_idx]
118
101
  points = list(self.sg_obj.sample_points)
119
102
  else:
@@ -131,11 +114,11 @@ class SamplingBase:
131
114
  points.append(p)
132
115
  new_idx_list.append(idx)
133
116
 
134
- if len(new_idx_list) != idx_list:
117
+ if len(new_idx_list) != len(idx_list):
135
118
  print(f"Exclude too close sample {set(idx_list)-set(new_idx_list)}")
136
119
  return new_idx_list
137
120
 
138
- class KeyPointSampling(SamplingBase):
121
+ class KeyPointSGSampler(SGSamplerBase):
139
122
  """
140
123
  关键点采样,使用 vip_id
141
124
  """
@@ -145,8 +128,13 @@ class KeyPointSampling(SamplingBase):
145
128
  self.sg_obj._clusters = clusters
146
129
  return sample_idx
147
130
 
131
+ def samples(self, size=None, **kwargs):
132
+ point_idx = self._samples(**kwargs)
133
+ filtered_idx = self.exclude_too_close_sample(point_idx)
134
+ self._append_sample_to_sg(point_idx=filtered_idx)
135
+ return filtered_idx
148
136
 
149
- class RandomSampling(SamplingBase):
137
+ class RandomSGSampler(SGSamplerBase):
150
138
  """
151
139
  完全随机的选择点,仅用于测试,效率太低。
152
140
  """
@@ -158,54 +146,71 @@ class RandomSampling(SamplingBase):
158
146
  self.seed = None
159
147
 
160
148
  def _samples(self, size, **kwargs):
161
- rng = np.random.default_rng(self.seed)
162
- pop_idx = rng.choice(self._population, size=size)
163
- return pop_idx
149
+ idx = random.sample(self._population, size)
150
+ return idx
164
151
 
165
152
 
166
- class MaxSigmaSampling(SamplingBase):
153
+ class MaxSigmaSGSampler(SGSamplerBase):
167
154
  """
168
155
  对最大误差的点进行采样
169
156
  """
170
157
  def _samples(self, size, **kwargs):
171
158
  if 'energy' in self.sg_obj.grid_property:
172
159
  # 如果已经读入了一些能量,则返回误差最大的点
173
- idx = self.sg_obj.grid_property_sigma['energy'].argmax()
174
- return [idx]
160
+ sigma_array = self.sg_obj.grid_property['energy']
161
+ sigma0 = sigma_array.max()
162
+ idx_list = np.argwhere(sigma_array <= sigma0-0.1).flatten().tolist()
163
+ idx = random.sample(idx_list, size)
164
+ return idx
175
165
  else:
176
166
  raise "No energy for all population, pls do initial sampling first!"
177
167
 
178
168
 
179
- class InitialSampling(SamplingBase):
169
+ class MinEnergySGSampler(SGSamplerBase):
180
170
  """
181
- 结合使用 KeyPointSampling 和 MaxDiversitySampling
171
+ 对最大误差的点进行采样
182
172
  """
183
-
184
173
  def _samples(self, size, **kwargs):
174
+ if 'energy' in self.sg_obj.grid_property:
175
+ E_array = self.sg_obj.grid_property['energy']
176
+ # 如果已经读入了一些能量,则返回能量最低的点 (<0.1eV 以内,然后随机选一个)
177
+ E0 = E_array.min()
178
+ idx_list = np.argwhere(E_array <= E0+0.1).flatten().tolist()
179
+ idx = random.sample(idx_list, size)
180
+ return idx
181
+ else:
182
+ raise "No energy for all population, pls do initial sampling first!"
183
+
184
+
185
+ class InitialSGSampler(SGSamplerBase):
186
+ """
187
+ 结合使用 KeyPointSampling 和 MaxDiversitySampling
188
+ """
189
+ def _samples(self, size=None, **kwargs):
190
+ # 先进行 KeyPoint sampling,数量不够再进行 Max diversity sampling
185
191
  vip_idx = self.sg_obj.unique_vip_id
186
- clusters = Cluster(n_clusters=len(vip_idx)).fit(self.sg_obj.vector)
187
- self.sg_obj._clusters = clusters
188
- # 如果 size 小于 vip,则从中随机选取部分
192
+ if size is None:
193
+ size = len(vip_idx)
194
+
189
195
  if size == len(vip_idx):
190
- sample_idx = vip_idx
191
- self._append_sample_to_sg(point_idx=sample_idx)
196
+ # 已经排除了距离过近的点,而且已经加入到了sg_obj
197
+ sample_idx = KeyPointSGSampler(self.sg_obj, **self.kwargs).samples(**kwargs)
192
198
  elif size < len(vip_idx):
193
- print("Warning: The initial sampling size is smaller than the number of key points")
194
- rng = np.random.default_rng()
195
- comb_vip = list(itertools.combinations(vip_idx, size))
196
- sample_idx = rng.choice(comb_vip)
199
+ print(f"The initial sampling size {size} is smaller than the number of key points {len(vip_idx)}.")
200
+ sample_idx = random.sample(vip_idx, size)
197
201
  self._append_sample_to_sg(point_idx=sample_idx)
198
- else: # 如果 size 大于 vip,则需要 MaxDiversitySampling 新增一些点
199
- self._append_sample_to_sg(point_idx=vip_idx) # 先增加进去vip 点作为已经采样的点,再进行最大多样性采样
200
- adding_sample = MaxDiversitySampling(self.sg_obj).samples(size=size-len(vip_idx), **kwargs)
201
- self._append_sample_to_sg(point_idx=adding_sample)
202
- sample_idx = np.concatenate([vip_idx, adding_sample])
202
+ else:
203
+ sample_idx = KeyPointSGSampler(self.sg_obj, **self.kwargs).samples(**kwargs)
204
+ while len(sample_idx) < size:
205
+ adding_sample = MaxDiversitySGSampler(self.sg_obj).samples(size=size-len(sample_idx),**kwargs)
206
+ sample_idx = np.concatenate([sample_idx, adding_sample])
203
207
  return sample_idx
204
208
 
205
209
  def samples(self, size=1, **kwargs):
206
210
  return self._samples(size=size, **kwargs)
207
211
 
208
- class MaxDiversitySampling(SamplingBase):
212
+
213
+ class MaxDiversitySGSampler(SGSamplerBase):
209
214
  """
210
215
  对当前采样结构差异最大的点进行采样
211
216
  基本思路是这样的:
@@ -215,14 +220,13 @@ class MaxDiversitySampling(SamplingBase):
215
220
  """
216
221
  def _samples(self, size, center=True, **kwargs):
217
222
  """
218
-
219
223
  :param size:
220
224
  :param center: 是否取中心。如果不是,则取能量最小值的点。如果没有能量则报错。
221
225
  :param kwargs:
222
226
  :return:
223
227
  """
224
228
  # 判断是否有过往的采样点,如果没有,调用 InitialSampling
225
- if self.sg_obj.sample_idx is None:
229
+ if len(self.sg_obj.sample_idx) == 0:
226
230
  clusters = Cluster(n_clusters=size).fit(self.sg_obj.vector)
227
231
  virgin = list(set(clusters.labels_))
228
232
  else:
@@ -254,9 +258,7 @@ class MaxDiversitySampling(SamplingBase):
254
258
  virgin = larger_virgin
255
259
  break
256
260
  # 从 virgin 里面选取 size 个点
257
- rng = np.random.default_rng()
258
- comb_vip = list(itertools.combinations(list(virgin), size))
259
- cluster_idx = rng.choice(comb_vip)
261
+ cluster_idx = random.sample(list(virgin), size)
260
262
  if (not center) and 'energy' not in self.sg_obj.grid_property:
261
263
  center = True
262
264
  print("Warning: Can't get cluster minimum energy, use cluster center instead!")
@@ -278,7 +280,7 @@ class MaxDiversitySampling(SamplingBase):
278
280
  return point_idx
279
281
 
280
282
 
281
- class NewtonSampling(SamplingBase):
283
+ class NewtonSGSampler(SGSamplerBase):
282
284
  """
283
285
  沿着受力方向进行采样
284
286
  """
@@ -287,7 +289,7 @@ class NewtonSampling(SamplingBase):
287
289
  raise NotImplementedError
288
290
 
289
291
 
290
- class RandomWalk(SamplingBase):
292
+ class RandomWalk(SGSamplerBase):
291
293
  """
292
294
  从给定点出发随机行走进行采样
293
295
  """
@@ -295,7 +297,7 @@ class RandomWalk(SamplingBase):
295
297
  raise NotImplementedError
296
298
 
297
299
 
298
- class SystematicSampling(SamplingBase):
300
+ class SystematicSGSampler(SGSamplerBase):
299
301
  """
300
302
  等距采样。主要用于测试。
301
303
  """
@@ -1,13 +1,12 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: surface_construct
3
- Version: 0.8.4
3
+ Version: 0.9
4
4
  Summary: Surface termination construction especially for complex model, such as oxides or carbides.
5
5
  Home-page: https://gitee.com/pjren/surface_construct/
6
6
  Author: ren
7
7
  Author-email: 0403114076@163.com
8
8
  License: GPL
9
9
  Classifier: Programming Language :: Python :: 3
10
- Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
11
10
  Classifier: Operating System :: OS Independent
12
11
  Description-Content-Type: text/markdown
13
12
  License-File: LICENSE
@@ -25,6 +24,7 @@ Dynamic: description
25
24
  Dynamic: description-content-type
26
25
  Dynamic: home-page
27
26
  Dynamic: license
27
+ Dynamic: license-file
28
28
  Dynamic: requires-dist
29
29
  Dynamic: summary
30
30
 
@@ -2,15 +2,9 @@ LICENSE
2
2
  README.md
3
3
  setup.py
4
4
  surface_construct/__init__.py
5
- surface_construct/atoms.py
6
5
  surface_construct/db.py
7
6
  surface_construct/default_parameter.py
8
- surface_construct/sampling.py
9
- surface_construct/structure.py
10
- surface_construct/surface.py
11
- surface_construct/surface_grid.py
12
- surface_construct/utils.py
13
- surface_construct/weight_functions.py
7
+ surface_construct/sg_sampler.py
14
8
  surface_construct.egg-info/PKG-INFO
15
9
  surface_construct.egg-info/SOURCES.txt
16
10
  surface_construct.egg-info/dependency_links.txt
@@ -18,4 +12,5 @@ surface_construct.egg-info/requires.txt
18
12
  surface_construct.egg-info/top_level.txt
19
13
  tests/test_sampling1.py
20
14
  tests/test_sampling2.py
21
- tests/test_surface_grid.py
15
+ tests/test_surface_grid.py
16
+ tests/test_task.py
@@ -1,8 +1,8 @@
1
1
  import pytest
2
2
  import ase.io
3
3
 
4
- from surface_construct.sampling import InitialSampling, KeyPointSampling, MaxDiversitySampling
5
- from surface_construct.surface_grid import SurfaceGrid
4
+ from surface_construct.sg_sampler import InitialSGSampler, KeyPointSGSampler, MaxDiversitySGSampler
5
+ from surface_construct import SurfaceGrid
6
6
 
7
7
 
8
8
  class TestSampling1:
@@ -19,17 +19,18 @@ class TestSampling1:
19
19
 
20
20
  def test_initial_sampling(self):
21
21
  for size in range(1, 6):
22
- sample_obj = InitialSampling(self.sg_obj)
22
+ sample_obj = InitialSGSampler(self.sg_obj)
23
23
  samples = sample_obj.samples(size=size)
24
24
  self.sg_obj.plot_cluster(figname=f'sampling_{size}')
25
+ self.sg_obj.del_sample()
25
26
 
26
27
  def test_keypoint_sampling(self):
27
- sample_obj = KeyPointSampling(self.sg_obj)
28
+ sample_obj = KeyPointSGSampler(self.sg_obj)
28
29
  samples = sample_obj.samples()
29
30
  self.sg_obj.plot_cluster(figname=f'KeyPoint_sampling')
30
31
 
31
32
  def test_max_diversity_sampling(self):
32
- sample_obj = MaxDiversitySampling(self.sg_obj)
33
+ sample_obj = MaxDiversitySGSampler(self.sg_obj)
33
34
  samples = sample_obj.samples(size=4)
34
35
  self.sg_obj.plot_cluster(figname=f'MaxDiversity_sampling')
35
36
 
@@ -1,8 +1,8 @@
1
1
  import pytest
2
2
  import ase.io
3
3
 
4
- from surface_construct.surface_grid import SurfaceGrid
5
- from surface_construct.sampling import InitialSampling, KeyPointSampling, MaxDiversitySampling
4
+ from surface_construct import SurfaceGrid
5
+ from surface_construct.sg_sampler import InitialSGSampler, KeyPointSGSampler, MaxDiversitySGSampler
6
6
 
7
7
  class TestSampling2:
8
8
  """
@@ -18,18 +18,18 @@ class TestSampling2:
18
18
 
19
19
  def test_initial_sampling(self):
20
20
  for size in [4, 8, 16]:
21
- sample_obj = InitialSampling(self.sg_obj)
21
+ sample_obj = InitialSGSampler(self.sg_obj)
22
22
  samples = sample_obj.samples(size=size)
23
23
  self.sg_obj.plot_cluster(figname=f'sampling_{size}')
24
24
  self.sg_obj.del_sample()
25
25
 
26
26
  def test_keypoint_sampling(self):
27
- sample_obj = KeyPointSampling(self.sg_obj)
27
+ sample_obj = KeyPointSGSampler(self.sg_obj)
28
28
  samples = sample_obj.samples()
29
29
  self.sg_obj.plot_cluster(figname=f'KeyPoint_sampling')
30
30
 
31
31
  def test_max_diversity_sampling(self):
32
- sample_obj = MaxDiversitySampling(self.sg_obj)
32
+ sample_obj = MaxDiversitySGSampler(self.sg_obj)
33
33
  samples = sample_obj.samples(size=10)
34
34
  self.sg_obj.plot_cluster(figname=f'MaxDiversity_sampling')
35
35
 
@@ -38,7 +38,7 @@ class TestSampling2:
38
38
 
39
39
  def test_exclude_too_close_sampling(self):
40
40
  samples = [0,1,2,3]
41
- sample_obj = KeyPointSampling(self.sg_obj)
41
+ sample_obj = KeyPointSGSampler(self.sg_obj)
42
42
  result_sample = sample_obj.exclude_too_close_sample(samples)
43
43
  print(self.sg_obj.points[samples])
44
44
  print(result_sample)
@@ -5,9 +5,9 @@ import numpy as np
5
5
  from ase.visualize import view
6
6
  from ase import Atoms
7
7
  import ase
8
- from surface_construct.surface_grid import GridGenerator
8
+ from surface_construct import GridGenerator
9
9
  from ase.cluster.cubic import FaceCenteredCubic
10
- from surface_construct.surface_grid import SurfaceGrid
10
+ from surface_construct import SurfaceGrid
11
11
 
12
12
  class TestSurfaceGrid:
13
13
  def test_Cu_cluster(self):
@@ -0,0 +1,83 @@
1
+ import os
2
+ import shutil
3
+ from random import randint
4
+
5
+ import ase.io
6
+ import pytest
7
+ from lasp_ase.lasp import Lasp
8
+ from surface_construct import SurfaceGrid
9
+ from surface_construct.structures import AdsGridCombiner
10
+ from surface_construct.structures.adsorbate import Adsorbate
11
+ from surface_construct.tasks import SurfaceSiteSampleTask
12
+ from ase.optimize import LBFGS, BFGS
13
+
14
+
15
+ class TestSurfaceSiteSampling:
16
+ """
17
+ Simple Ru 0001 suface
18
+ """
19
+ def setup_method(self):
20
+ self.task_dir = '%x' % randint(16**3, 16**4 - 1)
21
+ if not os.path.exists(self.task_dir):
22
+ os.makedirs(self.task_dir)
23
+ os.chdir(self.task_dir)
24
+
25
+ def teardown_method(self):
26
+ os.chdir("../")
27
+ # if os.path.exists(self.task_dir):
28
+ # shutil.rmtree(self.task_dir)
29
+
30
+ def test_job1(self):
31
+ """
32
+ C atom on Ru(0001) surface
33
+ :return:
34
+ """
35
+ shutil.copyfile('../atoms_files/RuCHO_lasp.in', 'lasp.in')
36
+ shutil.copyfile('../atoms_files/RuCHO_pf2.pot', 'RuCHO.pot')
37
+ atoms = ase.io.read('../atoms_files/ru_0001_POSCAR')
38
+ atoms.calc = Lasp()
39
+ ads_atoms = ase.Atoms('C',[[0.,0.,0.]])
40
+ ads_obj = Adsorbate(ads_atoms)
41
+ sg_obj = SurfaceGrid(atoms)
42
+ ads_grid_comb = AdsGridCombiner(sg_obj, ads_obj)
43
+ sampler =[
44
+ {
45
+ 'size': 3, # 采样大小
46
+ 'surface': "InitialSGSampler", # 表面采样方法
47
+ }, # 第一步采样
48
+ {
49
+ 'size': 5, # 采样大小
50
+ 'surface': ("MaxDiversitySGSampler", "MinEnergySGSampler", "MaxSigmaSGSampler"), # 表面采样方法
51
+ 'weight': (0.1, 0.45, 0.45), # 表面采样方法的权重
52
+ } # 第二步采样
53
+ ]
54
+ task_obj = SurfaceSiteSampleTask(combiner=ads_grid_comb, sampler=sampler, optimizer=BFGS)
55
+ task_obj.run()
56
+ print('Done')
57
+
58
+ def test_job2(self):
59
+ """
60
+ H atom on CuO/Cu surface
61
+ :return:
62
+ """
63
+ shutil.copyfile('../atoms_files/CuCHO_lasp.in', 'lasp.in')
64
+ shutil.copyfile('../atoms_files/CuCHO.pot', 'CuCHO.pot')
65
+ atoms = ase.io.read('../atoms_files/CuOx-Cu100-CONTCAR')
66
+ atoms.calc = Lasp()
67
+ ads_atoms = ase.Atoms('H',[[0.,0.,0.]])
68
+ ads_obj = Adsorbate(ads_atoms)
69
+ sg_obj = SurfaceGrid(atoms)
70
+ ads_grid_comb = AdsGridCombiner(sg_obj, ads_obj)
71
+ sampler =[
72
+ {
73
+ 'surface': "KeyPointSGSampler", # 表面采样方法
74
+ }, # 第一步采样
75
+ {
76
+ 'size': 3, # 采样大小
77
+ 'surface': ("MaxDiversitySGSampler", "MinEnergySGSampler", "MaxSigmaSGSampler"), # 表面采样方法
78
+ 'weight': (0.4, 0.3, 0.3), # 表面采样方法的权重
79
+ } # 第二步采样
80
+ ]
81
+ task_obj = SurfaceSiteSampleTask(combiner=ads_grid_comb, sampler=sampler, optimizer=LBFGS)
82
+ task_obj.run()
83
+ print('Done')
@@ -1,4 +0,0 @@
1
- from surface_construct.surface import Crystal, Surface, Slab, Termination
2
- from surface_construct.surface import get_terminations_score
3
-
4
- __all__ = ['Crystal', 'Surface', 'Slab', 'Termination']