surface-construct 0.8.2__tar.gz → 0.8.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {surface_construct-0.8.2/surface_construct.egg-info → surface_construct-0.8.4}/PKG-INFO +11 -5
- {surface_construct-0.8.2 → surface_construct-0.8.4}/setup.py +1 -4
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/sampling.py +91 -97
- surface_construct-0.8.4/surface_construct/surface_grid.py +1115 -0
- surface_construct-0.8.4/surface_construct/utils.py +177 -0
- surface_construct-0.8.4/surface_construct/weight_functions.py +65 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4/surface_construct.egg-info}/PKG-INFO +11 -5
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct.egg-info/SOURCES.txt +5 -1
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct.egg-info/requires.txt +0 -3
- surface_construct-0.8.4/tests/test_sampling1.py +37 -0
- surface_construct-0.8.4/tests/test_sampling2.py +44 -0
- surface_construct-0.8.4/tests/test_surface_grid.py +105 -0
- surface_construct-0.8.2/surface_construct/surface_grid.py +0 -705
- surface_construct-0.8.2/surface_construct/utils.py +0 -361
- {surface_construct-0.8.2 → surface_construct-0.8.4}/LICENSE +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/README.md +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/setup.cfg +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/__init__.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/atoms.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/db.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/default_parameter.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/structure.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct/surface.py +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct.egg-info/dependency_links.txt +0 -0
- {surface_construct-0.8.2 → surface_construct-0.8.4}/surface_construct.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: surface_construct
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.4
|
|
4
4
|
Summary: Surface termination construction especially for complex model, such as oxides or carbides.
|
|
5
5
|
Home-page: https://gitee.com/pjren/surface_construct/
|
|
6
6
|
Author: ren
|
|
@@ -13,14 +13,20 @@ Description-Content-Type: text/markdown
|
|
|
13
13
|
License-File: LICENSE
|
|
14
14
|
Requires-Dist: ase
|
|
15
15
|
Requires-Dist: networkx
|
|
16
|
-
Requires-Dist: numpy
|
|
17
16
|
Requires-Dist: spglib
|
|
18
17
|
Requires-Dist: pandas
|
|
19
18
|
Requires-Dist: tqdm
|
|
20
|
-
Requires-Dist: matplotlib
|
|
21
|
-
Requires-Dist: scipy
|
|
22
19
|
Requires-Dist: scikit-learn
|
|
23
20
|
Requires-Dist: scikit-image
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: author-email
|
|
23
|
+
Dynamic: classifier
|
|
24
|
+
Dynamic: description
|
|
25
|
+
Dynamic: description-content-type
|
|
26
|
+
Dynamic: home-page
|
|
27
|
+
Dynamic: license
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: summary
|
|
24
30
|
|
|
25
31
|
# 基于分层采样策略的催化剂表面位点全局分析
|
|
26
32
|
|
|
@@ -6,19 +6,16 @@ with open("README.md", "r", encoding='utf-8') as f:
|
|
|
6
6
|
install_requires = [
|
|
7
7
|
'ase',
|
|
8
8
|
'networkx',
|
|
9
|
-
'numpy',
|
|
10
9
|
'spglib',
|
|
11
10
|
'pandas',
|
|
12
11
|
'tqdm',
|
|
13
|
-
'matplotlib',
|
|
14
|
-
'scipy',
|
|
15
12
|
'scikit-learn',
|
|
16
13
|
'scikit-image'
|
|
17
14
|
]
|
|
18
15
|
|
|
19
16
|
setup(
|
|
20
17
|
name='surface_construct',
|
|
21
|
-
version='0.8.
|
|
18
|
+
version='0.8.4',
|
|
22
19
|
packages=['surface_construct'],
|
|
23
20
|
url='https://gitee.com/pjren/surface_construct/',
|
|
24
21
|
license='GPL',
|
|
@@ -4,7 +4,7 @@ TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
|
|
|
4
4
|
import itertools
|
|
5
5
|
import numpy as np
|
|
6
6
|
from ase.geometry import get_distances
|
|
7
|
-
from scipy.spatial import ConvexHull
|
|
7
|
+
from scipy.spatial import ConvexHull, cKDTree
|
|
8
8
|
from scipy.spatial.distance import cdist
|
|
9
9
|
from scipy.special import comb
|
|
10
10
|
from sklearn.cluster import KMeans as Cluster
|
|
@@ -71,6 +71,7 @@ def addition_samples(sg_obj, size=None, probability=None, **kwargs):
|
|
|
71
71
|
class SamplingBase:
|
|
72
72
|
def __init__(self, sg_obj, **kwargs):
|
|
73
73
|
self.sg_obj = sg_obj
|
|
74
|
+
self.threshold = kwargs.get('threshold', 0.37) # 0.37 is half of H-H bond
|
|
74
75
|
|
|
75
76
|
@property
|
|
76
77
|
def _pop_size(self):
|
|
@@ -109,21 +110,46 @@ class SamplingBase:
|
|
|
109
110
|
|
|
110
111
|
return point_idx
|
|
111
112
|
|
|
113
|
+
def exclude_too_close_sample(self, idx_list, threshold=None):
|
|
114
|
+
if threshold is None:
|
|
115
|
+
threshold = self.threshold
|
|
116
|
+
if self.sg_obj.sample_idx:
|
|
117
|
+
unique_idx_list = [i for i in idx_list if i not in self.sg_obj.sample_idx]
|
|
118
|
+
points = list(self.sg_obj.sample_points)
|
|
119
|
+
else:
|
|
120
|
+
unique_idx_list = idx_list[:]
|
|
121
|
+
points = []
|
|
122
|
+
new_idx_list = []
|
|
123
|
+
for idx in unique_idx_list:
|
|
124
|
+
p = self.sg_obj.points[idx]
|
|
125
|
+
if len(points) == 0:
|
|
126
|
+
points.append(p)
|
|
127
|
+
new_idx_list.append(idx)
|
|
128
|
+
continue
|
|
129
|
+
tree = cKDTree(points)
|
|
130
|
+
if len(tree.query_ball_point(x=p, r=threshold,p=2))==0:
|
|
131
|
+
points.append(p)
|
|
132
|
+
new_idx_list.append(idx)
|
|
133
|
+
|
|
134
|
+
if len(new_idx_list) != idx_list:
|
|
135
|
+
print(f"Exclude too close sample {set(idx_list)-set(new_idx_list)}")
|
|
136
|
+
return new_idx_list
|
|
112
137
|
|
|
113
138
|
class KeyPointSampling(SamplingBase):
|
|
114
139
|
"""
|
|
115
|
-
|
|
116
|
-
需要一个基础func,从xy坐标,找到对应的格点。
|
|
140
|
+
关键点采样,使用 vip_id
|
|
117
141
|
"""
|
|
118
|
-
def _samples(self,
|
|
119
|
-
|
|
142
|
+
def _samples(self, **kwargs):
|
|
143
|
+
sample_idx = self.sg_obj.unique_vip_id
|
|
144
|
+
clusters = Cluster(n_clusters=len(sample_idx)).fit(self.sg_obj.vector)
|
|
145
|
+
self.sg_obj._clusters = clusters
|
|
146
|
+
return sample_idx
|
|
120
147
|
|
|
121
148
|
|
|
122
149
|
class RandomSampling(SamplingBase):
|
|
123
150
|
"""
|
|
124
151
|
完全随机的选择点,仅用于测试,效率太低。
|
|
125
152
|
"""
|
|
126
|
-
|
|
127
153
|
def __init__(self, sg_obj, **kwargs):
|
|
128
154
|
super().__init__(sg_obj, **kwargs)
|
|
129
155
|
if 'seed' in kwargs:
|
|
@@ -141,7 +167,6 @@ class MaxSigmaSampling(SamplingBase):
|
|
|
141
167
|
"""
|
|
142
168
|
对最大误差的点进行采样
|
|
143
169
|
"""
|
|
144
|
-
|
|
145
170
|
def _samples(self, size, **kwargs):
|
|
146
171
|
if 'energy' in self.sg_obj.grid_property:
|
|
147
172
|
# 如果已经读入了一些能量,则返回误差最大的点
|
|
@@ -153,62 +178,32 @@ class MaxSigmaSampling(SamplingBase):
|
|
|
153
178
|
|
|
154
179
|
class InitialSampling(SamplingBase):
|
|
155
180
|
"""
|
|
156
|
-
|
|
181
|
+
结合使用 KeyPointSampling 和 MaxDiversitySampling
|
|
157
182
|
"""
|
|
158
183
|
|
|
159
184
|
def _samples(self, size, **kwargs):
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
#
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
#
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
cluster0.fit(self.sg_obj.vector)
|
|
179
|
-
mesh_centers = cluster0.cluster_centers_
|
|
180
|
-
self.sg_obj._mesh_centers = mesh_centers
|
|
181
|
-
cluster = Cluster(n_clusters=size)
|
|
182
|
-
cluster.fit(mesh_centers)
|
|
183
|
-
self.sg_obj._clusters = cluster
|
|
184
|
-
nvert = len(vertices)
|
|
185
|
-
if nvert >= size:
|
|
186
|
-
print(f"Warning: Sample number better be larger than {nvert}!")
|
|
187
|
-
if size == 1:
|
|
188
|
-
sample_idx = np.random.choice(vertices,1)
|
|
189
|
-
elif size==nvert:
|
|
190
|
-
sample_idx = vertices
|
|
191
|
-
else:
|
|
192
|
-
sample_idx = [vertices[i] for i in
|
|
193
|
-
furthest_sites(self.sg_obj.vector[vertices], size)]
|
|
194
|
-
else:
|
|
195
|
-
# 聚类
|
|
196
|
-
cluster2 = Cluster(n_clusters=size-nvert)
|
|
197
|
-
cluster2.fit(mesh_centers)
|
|
198
|
-
center_dist = cdist(cluster2.cluster_centers_, self.sg_obj.vector) # 计算每个点到中心的距离
|
|
199
|
-
sample_idx = vertices + np.argmin(center_dist, axis=-1).tolist()
|
|
185
|
+
vip_idx = self.sg_obj.unique_vip_id
|
|
186
|
+
clusters = Cluster(n_clusters=len(vip_idx)).fit(self.sg_obj.vector)
|
|
187
|
+
self.sg_obj._clusters = clusters
|
|
188
|
+
# 如果 size 小于 vip,则从中随机选取部分
|
|
189
|
+
if size == len(vip_idx):
|
|
190
|
+
sample_idx = vip_idx
|
|
191
|
+
self._append_sample_to_sg(point_idx=sample_idx)
|
|
192
|
+
elif size < len(vip_idx):
|
|
193
|
+
print("Warning: The initial sampling size is smaller than the number of key points")
|
|
194
|
+
rng = np.random.default_rng()
|
|
195
|
+
comb_vip = list(itertools.combinations(vip_idx, size))
|
|
196
|
+
sample_idx = rng.choice(comb_vip)
|
|
197
|
+
self._append_sample_to_sg(point_idx=sample_idx)
|
|
198
|
+
else: # 如果 size 大于 vip,则需要 MaxDiversitySampling 新增一些点
|
|
199
|
+
self._append_sample_to_sg(point_idx=vip_idx) # 先增加进去vip 点作为已经采样的点,再进行最大多样性采样
|
|
200
|
+
adding_sample = MaxDiversitySampling(self.sg_obj).samples(size=size-len(vip_idx), **kwargs)
|
|
201
|
+
self._append_sample_to_sg(point_idx=adding_sample)
|
|
202
|
+
sample_idx = np.concatenate([vip_idx, adding_sample])
|
|
200
203
|
return sample_idx
|
|
201
204
|
|
|
202
|
-
def
|
|
203
|
-
|
|
204
|
-
将采样点加入到 sg_obj.sample_points 和相应的 vector
|
|
205
|
-
:return:
|
|
206
|
-
"""
|
|
207
|
-
if point_idx is not None:
|
|
208
|
-
self.sg_obj.sample_idx = np.asarray(point_idx)
|
|
209
|
-
self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
|
|
210
|
-
self.sg_obj.sample_points = self.sg_obj.points[point_idx]
|
|
211
|
-
|
|
205
|
+
def samples(self, size=1, **kwargs):
|
|
206
|
+
return self._samples(size=size, **kwargs)
|
|
212
207
|
|
|
213
208
|
class MaxDiversitySampling(SamplingBase):
|
|
214
209
|
"""
|
|
@@ -218,8 +213,7 @@ class MaxDiversitySampling(SamplingBase):
|
|
|
218
213
|
* 判断已经采样点属于的类别,找出没有点的类别,空类
|
|
219
214
|
* 如果空类不止一个,比较这些空类中心与旧点的距离,选择距离最大的点。
|
|
220
215
|
"""
|
|
221
|
-
|
|
222
|
-
def _samples(self, size, center=False, **kwargs):
|
|
216
|
+
def _samples(self, size, center=True, **kwargs):
|
|
223
217
|
"""
|
|
224
218
|
|
|
225
219
|
:param size:
|
|
@@ -229,37 +223,43 @@ class MaxDiversitySampling(SamplingBase):
|
|
|
229
223
|
"""
|
|
230
224
|
# 判断是否有过往的采样点,如果没有,调用 InitialSampling
|
|
231
225
|
if self.sg_obj.sample_idx is None:
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
#
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
226
|
+
clusters = Cluster(n_clusters=size).fit(self.sg_obj.vector)
|
|
227
|
+
virgin = list(set(clusters.labels_))
|
|
228
|
+
else:
|
|
229
|
+
cluster_size = len(self.sg_obj.sample_idx) + size
|
|
230
|
+
nvirgin = 0
|
|
231
|
+
larger_clusters = None
|
|
232
|
+
larger_virgin = None
|
|
233
|
+
virgin = None
|
|
234
|
+
clusters = None
|
|
235
|
+
# 如果等于则停止,并保存 cluster
|
|
236
|
+
while nvirgin != size:
|
|
237
|
+
# 以 len(sample_idx) + size 作为新的聚类的size
|
|
238
|
+
clusters = Cluster(n_clusters=cluster_size).fit(self.sg_obj.vector)
|
|
239
|
+
labels = clusters.labels_[self.sg_obj.sample_idx]
|
|
240
|
+
labels_set = set(labels)
|
|
241
|
+
virgin = set(range(cluster_size)) - labels_set
|
|
242
|
+
nvirgin = len(virgin)
|
|
243
|
+
# 判断分类以后空类数目与size的大小
|
|
244
|
+
# 如果大于size,则减小size,并记录空类的数目
|
|
245
|
+
if nvirgin > size:
|
|
246
|
+
cluster_size -= 1
|
|
247
|
+
larger_clusters = clusters
|
|
248
|
+
larger_virgin = virgin
|
|
249
|
+
# 如果小于 size 则增大size,检查上一个size是否有记录,如果有记录则使用上个size 的记录。从中随机选择size个点作为采样点。
|
|
250
|
+
elif nvirgin < size:
|
|
251
|
+
cluster_size += 1
|
|
252
|
+
if larger_clusters is not None:
|
|
253
|
+
clusters = larger_clusters
|
|
254
|
+
virgin = larger_virgin
|
|
255
|
+
break
|
|
260
256
|
# 从 virgin 里面选取 size 个点
|
|
261
257
|
rng = np.random.default_rng()
|
|
262
|
-
|
|
258
|
+
comb_vip = list(itertools.combinations(list(virgin), size))
|
|
259
|
+
cluster_idx = rng.choice(comb_vip)
|
|
260
|
+
if (not center) and 'energy' not in self.sg_obj.grid_property:
|
|
261
|
+
center = True
|
|
262
|
+
print("Warning: Can't get cluster minimum energy, use cluster center instead!")
|
|
263
263
|
if center:
|
|
264
264
|
# 取中心位置的格点
|
|
265
265
|
centers = clusters.cluster_centers_[cluster_idx]
|
|
@@ -271,11 +271,10 @@ class MaxDiversitySampling(SamplingBase):
|
|
|
271
271
|
for c_id in cluster_idx:
|
|
272
272
|
p_idx = np.arange(len(self.sg_obj.points))[clusters.labels_ == c_id]
|
|
273
273
|
# 求这些点的能量最小值
|
|
274
|
-
if 'energy' not in self.sg_obj.grid_property:
|
|
275
|
-
raise NotImplementedError
|
|
276
274
|
p_energy = self.sg_obj.grid_energy[p_idx]
|
|
277
275
|
point_idx.append(p_idx[p_energy.argmin()])
|
|
278
|
-
|
|
276
|
+
# assign cluster to sg_obj
|
|
277
|
+
self.sg_obj._clusters = clusters
|
|
279
278
|
return point_idx
|
|
280
279
|
|
|
281
280
|
|
|
@@ -292,10 +291,6 @@ class RandomWalk(SamplingBase):
|
|
|
292
291
|
"""
|
|
293
292
|
从给定点出发随机行走进行采样
|
|
294
293
|
"""
|
|
295
|
-
|
|
296
|
-
def __init__(self, sg_obj=None, probability=1.0, **kwargs):
|
|
297
|
-
super().__init__(sg_obj, probability, **kwargs)
|
|
298
|
-
|
|
299
294
|
def _samples(self, size, **kwargs):
|
|
300
295
|
raise NotImplementedError
|
|
301
296
|
|
|
@@ -304,7 +299,6 @@ class SystematicSampling(SamplingBase):
|
|
|
304
299
|
"""
|
|
305
300
|
等距采样。主要用于测试。
|
|
306
301
|
"""
|
|
307
|
-
|
|
308
302
|
def _samples(self, size, **kwargs):
|
|
309
303
|
if 'start' in kwargs:
|
|
310
304
|
start = kwargs['start']
|