surface-construct 0.8__tar.gz → 0.8.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {surface_construct-0.8/surface_construct.egg-info → surface_construct-0.8.1}/PKG-INFO +1 -1
- {surface_construct-0.8 → surface_construct-0.8.1}/setup.py +1 -1
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/sampling.py +295 -284
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/surface_grid.py +9 -66
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/utils.py +3 -3
- {surface_construct-0.8 → surface_construct-0.8.1/surface_construct.egg-info}/PKG-INFO +1 -1
- {surface_construct-0.8 → surface_construct-0.8.1}/LICENSE +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/README.md +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/setup.cfg +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/__init__.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/atoms.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/db.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/default_parameter.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/structure.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct/surface.py +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct.egg-info/SOURCES.txt +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct.egg-info/dependency_links.txt +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct.egg-info/requires.txt +0 -0
- {surface_construct-0.8 → surface_construct-0.8.1}/surface_construct.egg-info/top_level.txt +0 -0
|
@@ -1,284 +1,295 @@
|
|
|
1
|
-
"""
|
|
2
|
-
TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
from
|
|
8
|
-
from scipy.spatial
|
|
9
|
-
from
|
|
10
|
-
import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
"""
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
self.sg_obj.
|
|
77
|
-
self.sg_obj.
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
"""
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
:
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
#
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
"""
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
1
|
+
"""
|
|
2
|
+
TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
|
|
3
|
+
"""
|
|
4
|
+
from logging import warning
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from ase.geometry import get_distances
|
|
8
|
+
from scipy.spatial import ConvexHull
|
|
9
|
+
from scipy.spatial.distance import cdist
|
|
10
|
+
from sklearn.cluster import KMeans as Cluster
|
|
11
|
+
import random
|
|
12
|
+
|
|
13
|
+
from surface_construct.utils import furthest_sites
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def addition_samples(sg_obj, size=None, probability=None, **kwargs):
    """
    Draw additional sample points from ``sg_obj`` by randomly mixing sampling methods.

    For each of the ``size`` requested points a sampling method is drawn at random
    according to ``probability``, and that method is asked for exactly one point.
    Each sampler's ``samples`` call also records its point on ``sg_obj``.

    :param sg_obj: surface-grid object handed to the sampling classes.
    :param size: number of points to draw; defaults to 1.
    :param probability: mapping of method name ('max_sigma' or 'max_diversity',
        case-insensitive) to a relative weight. Weights are normalised to sum to 1.
        Defaults to {'max_sigma': 0.2, 'max_diversity': 0.8}.
    :param kwargs: forwarded to every sampler's ``samples`` call; ``seed`` (if given)
        also seeds the method-selection random generator.
    :return: 1-D integer array with the sampled point indices, in draw order.
    :raises ValueError: if the probability weights do not sum to a positive value.
    :raises NotImplementedError: if an unknown method name is drawn.
    """
    seed = kwargs.get('seed')
    if probability is None:
        probability = {
            "max_sigma": 0.2,  # weight of each sampling method
            "max_diversity": 0.8,
        }
    if size is None:
        size = 1

    # Normalise the weights so they form a valid probability vector.
    total = sum(probability.values())
    if not total > 0:
        raise ValueError(f"The sum of sampling probabilities must be positive, got {total}!")
    probability = {k: v / total for k, v in probability.items()}

    rng = np.random.default_rng(seed)
    method_list = rng.choice(list(probability.keys()), size=size, p=list(probability.values()))

    # Name -> sampler class; resolved at call time so all classes of the module are defined.
    samplers = {
        'max_sigma': MaxSigmaSampling,
        'max_diversity': MaxDiversitySampling,
    }
    point_idx = np.array([], dtype=int)
    for method in method_list:
        sampler_cls = samplers.get(str(method).lower())
        if sampler_cls is None:
            raise NotImplementedError(f"Unknown sampling method: {method}")
        # Each drawn method contributes exactly one point.
        new_idx = sampler_cls(sg_obj).samples(size=1, **kwargs)
        point_idx = np.concatenate([point_idx, new_idx])

    return point_idx
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SamplingBase:
    """
    Common base class of the samplers in this module.

    Subclasses implement ``_samples(size, **kwargs)`` and return point indices.
    The public ``samples`` method calls it and records the chosen indices, vectors
    and points on the wrapped ``sg_obj``.
    """

    def __init__(self, sg_obj, **kwargs):
        # Extra keyword arguments are accepted for subclass compatibility and ignored here.
        self.sg_obj = sg_obj

    @property
    def _pop_size(self):
        """Number of candidate points, i.e. ``len(sg_obj.points)``."""
        n_points = len(self.sg_obj.points)
        return n_points

    @property
    def _population(self):
        """
        The default population: all indices into ``sg_obj.points``.
        :return: ``range`` over ``0 .. _pop_size - 1``
        """
        return range(0, self._pop_size)

    def _append_sample_to_sg(self, point_idx=None):
        """
        Record ``point_idx`` on ``sg_obj``.

        The indices go to ``sample_idx``, the matching rows of ``vector`` go to
        ``_sample_vector``, and the matching rows of ``points`` go to ``sample_points``.
        Existing samples are extended; when ``sample_idx`` is None they are created.
        ``point_idx=None`` is a no-op.
        :return: None
        """
        if point_idx is None:
            return
        grid = self.sg_obj
        if grid.sample_idx is None:
            # First batch: create the sample arrays.
            grid.sample_idx = np.array(point_idx)
            grid._sample_vector = grid.vector[point_idx]
            grid.sample_points = grid.points[point_idx]
            return
        # Later batches: append to what is already stored.
        grid.sample_idx = np.concatenate([grid.sample_idx, point_idx])
        grid._sample_vector = np.concatenate([grid._sample_vector, grid.vector[point_idx]])
        grid.sample_points = np.concatenate([grid.sample_points, grid.points[point_idx]])

    def _samples(self, size, **kwargs):
        raise NotImplementedError

    def samples(self, size=1, **kwargs):
        """Draw ``size`` indices with ``_samples``, record them on ``sg_obj`` and return them."""
        chosen = self._samples(size=size, **kwargs)
        self._append_sample_to_sg(point_idx=chosen)
        return chosen
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class KeyPointSampling(SamplingBase):
    """
    Planned key-site sampling (top, bridge and hollow sites).

    Intended approach: first locate hollow sites, then locate bridge sites by graph
    analysis, and finally locate top sites; as a second step, filter out equivalent sites.
    This needs a basic function that maps an xy coordinate to its grid point.

    Not implemented yet.
    """

    def _samples(self, size, **kwargs):
        # Previously a bare ``pass``: ``samples()`` then silently returned None and
        # recorded nothing. Fail loudly like the other unimplemented samplers.
        raise NotImplementedError("KeyPointSampling is not implemented yet.")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class RandomSampling(SamplingBase):
    """
    Pick points completely at random. Meant for testing only; too inefficient for real use.
    """

    def __init__(self, sg_obj, **kwargs):
        """
        :param sg_obj: surface-grid object to sample from.
        :param kwargs: optional ``seed`` for the random generator.
        """
        super().__init__(sg_obj, **kwargs)
        self.seed = kwargs.get('seed')
        # One generator for the object's lifetime. Re-creating a seeded generator on
        # every call would replay exactly the same points each time.
        self._rng = np.random.default_rng(self.seed)

    def _samples(self, size, **kwargs):
        """
        :param size: number of distinct points to draw (must not exceed the population).
        :return: array of point indices.
        """
        # Without replacement, so one batch never contains the same point twice.
        pop_idx = self._rng.choice(self._population, size=size, replace=False)
        return pop_idx
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class MaxSigmaSampling(SamplingBase):
    """
    Sample the point with the largest energy uncertainty (sigma).
    """

    def _samples(self, size, **kwargs):
        """
        :param size: ignored; exactly one point is returned.
        :return: one-element list holding the index of the max-sigma point.
        :raises RuntimeError: if no energy has been read into ``sg_obj`` yet.
        """
        if 'energy' not in self.sg_obj.grid_property:
            # Raising a plain string is illegal in Python 3 (it produced a TypeError).
            raise RuntimeError("No energy for all population, pls do initial sampling first!")
        # Some energies are known: return the point whose energy sigma is largest.
        idx = self.sg_obj.grid_property_sigma['energy'].argmax()
        return [idx]
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class InitialSampling(SamplingBase):
    """
    Initial sampling via clustering followed by stratified selection.
    """

    def _samples(self, size, **kwargs):
        """
        Choose ``size`` initial sample points.

        Corner-like convex-hull vertices of ``sg_obj.vector`` are collected first.

        - If there are at least ``size`` of them, the samples come from those vertices
          only: one at random when ``size == 1``, otherwise via ``furthest_sites``.
        - Otherwise every corner vertex is kept, and the remaining ``size - nvert``
          points are the grid points nearest to k-means centres of a coarse vector mesh.

        Side effects: stores ``_mesh_centers`` and ``_clusters`` on ``sg_obj``.

        :param size: number of points to sample.
        :return: list of point indices (an array when ``size == 1`` and vertices suffice).
        """
        hull = ConvexHull(self.sg_obj.vector)
        # Hull vertices whose two edges meet at >= 150 degrees are nearly flat and are
        # dropped; only sharper corners are kept (cos(angle) > cos(150 deg)).
        cos_flat = np.cos(np.pi * 150 / 180)
        vertices = []
        for i in hull.vertices:
            # NOTE(review): unpacking exactly two matches assumes every vertex belongs to
            # two hull facets, i.e. a 2-D vector space -- confirm before using higher dims.
            p1_idx, p2_idx = np.argwhere(hull.simplices == i)
            p0 = hull.points[i]
            p1 = hull.points[hull.simplices[p1_idx[0], 1 - p1_idx[1]]]
            p2 = hull.points[hull.simplices[p2_idx[0], 1 - p2_idx[1]]]
            a = p1 - p0
            b = p2 - p0
            cosangle = a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))
            if cosangle > cos_flat:
                vertices.append(i)

        # Cluster the vectors into a "vector mesh". The number of cells is the hull volume
        # divided by the cell size (_vector_unit * interval) raised to the dimension.
        n_vector_mesh = int(hull.volume / (self.sg_obj._vector_unit *
                                           self.sg_obj.interval) ** self.sg_obj.vector.shape[1]) + 1
        cluster0 = Cluster(n_clusters=n_vector_mesh)
        cluster0.fit(self.sg_obj.vector)
        mesh_centers = cluster0.cluster_centers_
        self.sg_obj._mesh_centers = mesh_centers
        cluster = Cluster(n_clusters=size)
        cluster.fit(mesh_centers)
        self.sg_obj._clusters = cluster

        nvert = len(vertices)
        if nvert >= size:
            # f-string so the actual vertex count is reported (it used to print "{nvert}").
            warning(f"Sample number better be larger than {nvert}!")
            if size == 1:
                sample_idx = np.random.choice(vertices, 1)
            else:
                sample_idx = [vertices[i] for i in
                              furthest_sites(self.sg_obj.vector[vertices], size)]
        else:
            # Keep all corner vertices; fill the rest with points nearest to cluster centres.
            cluster2 = Cluster(n_clusters=size - nvert)
            cluster2.fit(mesh_centers)
            center_dist = cdist(cluster2.cluster_centers_, self.sg_obj.vector)  # distance of every point to each centre
            sample_idx = vertices + np.argmin(center_dist, axis=-1).tolist()
        return sample_idx

    def _append_sample_to_sg(self, point_idx=None):
        """
        Store ``point_idx`` on ``sg_obj``.

        Unlike the base class, this REPLACES any existing ``sample_idx``,
        ``_sample_vector`` and ``sample_points`` instead of extending them, because
        initial sampling starts from scratch. ``point_idx=None`` is a no-op.
        :return: None
        """
        if point_idx is not None:
            self.sg_obj.sample_idx = np.asarray(point_idx)
            self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
            self.sg_obj.sample_points = self.sg_obj.points[point_idx]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class MaxDiversitySampling(SamplingBase):
    """
    Sample the points that differ most from the current samples.

    Approach:
    * re-cluster the whole vector space;
    * find which clusters already hold a sample point; clusters without one are "virgin";
    * adjust the number of clusters until the virgin count equals ``size`` (or fall back
      to the last clustering with more virgin clusters), then take one point from each
      of ``size`` distinct virgin clusters.

    NOTE(review): the original design note said to pick, among several virgin clusters,
    those whose centres are furthest from the existing samples. The code instead picks
    virgin clusters at random.
    """

    def _samples(self, size, center=False, **kwargs):
        """
        :param size: number of new points; ``<= 0`` returns an empty list.
        :param center: if True, take the grid point closest to each chosen cluster centre;
            otherwise take the lowest-energy point of each chosen cluster (energy required).
        :param kwargs: optional ``seed`` for choosing among the virgin clusters.
        :return: indices of the new points.
        :raises RuntimeError: if there are no previous samples (run InitialSampling first).
        :raises NotImplementedError: if ``center`` is False and no energy is available.
        """
        if self.sg_obj.sample_idx is None:
            # Raising a plain string is illegal in Python 3 (it produced a TypeError).
            raise RuntimeError("Please add initial samples (e.g. InitialSampling) before invoke this method!")
        if size <= 0:
            return []
        # Check before the expensive clustering loop instead of after it.
        if not center and 'energy' not in self.sg_obj.grid_property:
            raise NotImplementedError("Energy is required to sample with center=False.")

        cluster_size = len(self.sg_obj.sample_idx) + size
        nvirgin = 0
        larger_clusters = None
        larger_virgin = None
        virgin = None
        clusters = None
        # Stop once the number of virgin clusters equals size, keeping that clustering.
        while nvirgin != size:
            # Use len(sample_idx) + size (then adjusted) as the number of clusters.
            clusters = Cluster(n_clusters=cluster_size).fit(self.sg_obj.vector)
            labels = clusters.labels_[self.sg_obj.sample_idx]
            virgin = set(range(cluster_size)) - set(labels)
            nvirgin = len(virgin)
            if nvirgin > size:
                # Too many virgin clusters: use fewer clusters, but remember this result.
                cluster_size -= 1
                larger_clusters = clusters
                larger_virgin = virgin
            elif nvirgin < size:
                # Too few: use more clusters, unless a clustering with more virgin
                # clusters was already seen -- then fall back to it.
                cluster_size += 1
                if larger_clusters is not None:
                    clusters = larger_clusters
                    virgin = larger_virgin
                    break

        # At this point len(virgin) >= size, so sampling without replacement is safe and
        # guarantees size distinct clusters (hence distinct points).
        rng = np.random.default_rng(kwargs.get('seed'))
        cluster_idx = rng.choice(sorted(virgin), size=size, replace=False)
        if center:
            # Grid point closest to each chosen cluster centre.
            centers = clusters.cluster_centers_[cluster_idx]
            center_dist = cdist(centers, self.sg_obj.vector)  # distance of every point to each centre
            point_idx = np.argmin(center_dist, axis=-1)
        else:
            # Lowest-energy point inside each chosen cluster.
            point_idx = []
            for c_id in cluster_idx:
                p_idx = np.arange(len(self.sg_obj.points))[clusters.labels_ == c_id]
                # NOTE(review): availability is checked via grid_property['energy'] but the
                # values are read from grid_energy -- confirm both refer to the same data.
                p_energy = self.sg_obj.grid_energy[p_idx]
                point_idx.append(p_idx[p_energy.argmin()])

        return point_idx
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class NewtonSampling(SamplingBase):
    """
    Sample along the direction of the acting force.

    Placeholder: no sampling strategy is implemented yet.
    """

    def _samples(self, size, **kwargs):
        # Not written yet; callers get an explicit error rather than empty results.
        raise NotImplementedError()
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class RandomWalk(SamplingBase):
    """
    Sample by a random walk starting from a given point.

    Placeholder: the walk itself is not implemented yet.
    """

    def __init__(self, sg_obj=None, probability=1.0, **kwargs):
        """
        :param sg_obj: surface-grid object to sample from.
        :param probability: stored as ``self.probability``; not used yet.
        :param kwargs: forwarded to ``SamplingBase``.
        """
        # SamplingBase.__init__ takes only ``sg_obj`` positionally. Forwarding
        # ``probability`` as a second positional argument made every construction
        # raise TypeError, so it is kept on this instance instead.
        super().__init__(sg_obj, **kwargs)
        self.probability = probability

    def _samples(self, size, **kwargs):
        raise NotImplementedError
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class SystematicSampling(SamplingBase):
    """
    Systematic (equal-interval) sampling. Mainly used for testing.
    """

    def _samples(self, size, **kwargs):
        """
        Pick ``size`` indices spaced ``pop_size // size`` apart.

        Sampling starts at ``kwargs['start']`` (taken modulo the population size) or at a
        random index, and wraps around the end of the population.

        :param size: number of points; ``<= 0`` returns an empty list.
        :param kwargs: optional ``start`` index.
        :return: list of point indices.
        :raises ValueError: if ``size`` exceeds the population size.
        """
        pop_size = self._pop_size
        if size <= 0:
            return []
        if size > pop_size:
            # With step == 0 every sample would be the same point.
            raise ValueError(f"Cannot draw {size} systematic samples from {pop_size} points!")
        if 'start' in kwargs:
            start = kwargs['start'] % pop_size
        else:
            # randrange excludes pop_size (randint(0, pop_size) did not).
            start = random.randrange(pop_size)
        # Rotate the index list so that it begins at ``start``.
        indices = list(range(start, pop_size)) + list(range(0, start))
        step = pop_size // size
        # The n-th sample sits n steps from the start. The old ``i + n * step`` with
        # i == n overshot the list and raised IndexError.
        point_idx = [indices[n * step] for n in range(size)]

        return point_idx
|
|
@@ -24,12 +24,12 @@ from scipy.interpolate import griddata
|
|
|
24
24
|
from scipy.spatial import ConvexHull
|
|
25
25
|
from scipy.spatial.distance import euclidean, cdist
|
|
26
26
|
from sklearn.cluster import KMeans as Cluster
|
|
27
|
-
from sklearn.cluster import kmeans_plusplus
|
|
28
27
|
from sklearn.decomposition import PCA
|
|
29
28
|
from sklearn.gaussian_process import GaussianProcessRegressor
|
|
30
29
|
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
|
|
31
30
|
from sklearn.preprocessing import StandardScaler
|
|
32
31
|
|
|
32
|
+
from surface_construct.sampling import InitialSampling, addition_samples
|
|
33
33
|
from surface_construct.utils import get_calc_info, GridGenerator, get_distances, furthest_sites
|
|
34
34
|
|
|
35
35
|
|
|
@@ -321,60 +321,18 @@ class SurfaceGrid:
|
|
|
321
321
|
k = d_vector / self.interval
|
|
322
322
|
return np.min(k)
|
|
323
323
|
|
|
324
|
-
def grid_sample(self, N=
|
|
324
|
+
def grid_sample(self, probability=None, N=1, **kwargs):
|
|
325
325
|
"""
|
|
326
326
|
Warning: Obsoleted, replaced by Sampling class
|
|
327
|
+
:param probability:
|
|
327
328
|
:param N:
|
|
328
329
|
:return:
|
|
329
330
|
"""
|
|
330
|
-
|
|
331
331
|
if 'energy' in self.grid_property:
|
|
332
|
-
|
|
333
|
-
idx = self.grid_property_sigma['energy'].argmax()
|
|
334
|
-
max_sigma_point = np.array([self.points[idx]])
|
|
335
|
-
self.sample_idx = np.concatenate([self.sample_idx, [idx]])
|
|
336
|
-
self._sample_vector = np.concatenate([self._sample_vector, [self.vector[idx]]])
|
|
337
|
-
self.sample_points = np.concatenate([self.sample_points, max_sigma_point])
|
|
338
|
-
return [idx]
|
|
339
|
-
|
|
340
|
-
assert N > 1
|
|
341
|
-
hull = ConvexHull(self.vector)
|
|
342
|
-
vertices = []
|
|
343
|
-
# 去掉 hull 的 simplices 的角度较大的点
|
|
344
|
-
for i in hull.vertices:
|
|
345
|
-
p1_idx, p2_idx = np.argwhere(hull.simplices == i)
|
|
346
|
-
p0 = hull.points[i]
|
|
347
|
-
p1 = hull.points[hull.simplices[p1_idx[0],1-p1_idx[1]]]
|
|
348
|
-
p2 = hull.points[hull.simplices[p2_idx[0],1-p2_idx[1]]]
|
|
349
|
-
a = p1 - p0
|
|
350
|
-
b = p2 - p0
|
|
351
|
-
cosangle = a.dot(b)/(np.linalg.norm(a) * np.linalg.norm(b))
|
|
352
|
-
if cosangle > np.cos(np.pi*150/180):
|
|
353
|
-
vertices.append(i)
|
|
354
|
-
# 聚类,vector_mesh
|
|
355
|
-
n_vector_mesh = int(hull.volume / (self._vector_unit * self.interval)**self.vector.shape[1]) + 1
|
|
356
|
-
cluster0 = Cluster(n_clusters=n_vector_mesh)
|
|
357
|
-
cluster0.fit(self.vector)
|
|
358
|
-
mesh_centers = cluster0.cluster_centers_
|
|
359
|
-
self._mesh_centers = mesh_centers
|
|
360
|
-
nvert = len(vertices)
|
|
361
|
-
cluster = Cluster(n_clusters=N)
|
|
362
|
-
cluster.fit(mesh_centers)
|
|
363
|
-
if nvert >= N:
|
|
364
|
-
warning("Sample number should be larger than {nvert}")
|
|
365
|
-
sample_idx = [vertices[i] for i in furthest_sites(self.vector[vertices], N)]
|
|
332
|
+
points_idx = addition_samples(self, size=N, probability=probability, **kwargs)
|
|
366
333
|
else:
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
cluster2.fit(mesh_centers)
|
|
370
|
-
center_dist = cdist(cluster2.cluster_centers_, self.vector) # 计算每个点到中心的距离
|
|
371
|
-
sample_idx = vertices + np.argmin(center_dist, axis=-1).tolist()
|
|
372
|
-
|
|
373
|
-
self._clusters = cluster
|
|
374
|
-
self.sample_idx = sample_idx
|
|
375
|
-
self.sample_points = self.points[sample_idx]
|
|
376
|
-
self._sample_vector = self.vector[sample_idx] # 保存用于作图
|
|
377
|
-
return self.sample_idx
|
|
334
|
+
points_idx = InitialSampling(self).samples(size=N)
|
|
335
|
+
return points_idx
|
|
378
336
|
|
|
379
337
|
# TODO: 将中心重新映射回到Cartesian坐标 :
|
|
380
338
|
# 找到向量空间最紧邻的N个点,判断其实空间的距离是否小于 interval × 2,直到有三个点满足
|
|
@@ -555,15 +513,7 @@ class SurfaceGrid:
|
|
|
555
513
|
fig, ax = plt.subplots()
|
|
556
514
|
ax.set_aspect('equal')
|
|
557
515
|
|
|
558
|
-
print("Plot {} ..."
|
|
559
|
-
#X = self.points[:, 0].reshape((self.grid_ny, self.grid_nx))
|
|
560
|
-
#Y = self.points[:, 1].reshape((self.grid_ny, self.grid_nx))
|
|
561
|
-
#Z = self.grid_property[key].reshape((self.grid_ny, self.grid_nx))
|
|
562
|
-
#if vmax is None:
|
|
563
|
-
# vmax = Z.max() + (Z.max() - Z.min()) * 0.2
|
|
564
|
-
#if vmin is None:
|
|
565
|
-
# vmin = Z.min() - (Z.max() - Z.min()) * 0.2
|
|
566
|
-
#contourf0 = ax.contourf(X, Y, Z, levels=50, cmap="jet", vmin=vmin, vmax=vmax)
|
|
516
|
+
print(f"Plot {key} ...")
|
|
567
517
|
x = self.points[:, 0]
|
|
568
518
|
y = self.points[:, 1]
|
|
569
519
|
z = self.grid_property[key]
|
|
@@ -579,8 +529,7 @@ class SurfaceGrid:
|
|
|
579
529
|
ax.scatter(sample_points[nsampled:, 0], sample_points[nsampled:, 1], marker="o", s=100, linewidths=2,
|
|
580
530
|
color="w", zorder=10)
|
|
581
531
|
fig.colorbar(contourf0, ax=ax)
|
|
582
|
-
title = "{} distribution, Max(sigma)=
|
|
583
|
-
key.capitalize(), self.grid_property_sigma[key].max())
|
|
532
|
+
title = f"{key.capitalize()} distribution, Max(sigma)={self.grid_property_sigma[key].max():.3f}"
|
|
584
533
|
ax.set_title(title)
|
|
585
534
|
fig.set_dpi(300)
|
|
586
535
|
fig.set_size_inches(10, 10)
|
|
@@ -651,11 +600,6 @@ class SurfaceGrid:
|
|
|
651
600
|
fig, ax = plt.subplots()
|
|
652
601
|
ax.set_aspect('equal')
|
|
653
602
|
|
|
654
|
-
#X = self.points[:, 0].reshape((self.grid_ny, self.grid_nx))
|
|
655
|
-
#Y = self.points[:, 1].reshape((self.grid_ny, self.grid_nx))
|
|
656
|
-
#Z = self.grid_property_sigma[key].reshape((self.grid_ny, self.grid_nx))
|
|
657
|
-
#contourf0 = ax.contourf(X, Y, Z, levels=50, cmap="RdPu", vmax=vmax, vmin=vmin)
|
|
658
|
-
|
|
659
603
|
x = self.points[:, 0]
|
|
660
604
|
y = self.points[:, 1]
|
|
661
605
|
z = self.grid_property_sigma[key]
|
|
@@ -667,8 +611,7 @@ class SurfaceGrid:
|
|
|
667
611
|
ax.scatter(max_sigma_point[0], max_sigma_point[1], marker="+", s=100, linewidths=2,
|
|
668
612
|
color="w", zorder=10)
|
|
669
613
|
fig.colorbar(contourf0, ax=ax)
|
|
670
|
-
title = "Gaussian Process Error of {}, Max(sigma)=
|
|
671
|
-
key.capitalize(), self.grid_property_sigma[key].max())
|
|
614
|
+
title = f"Gaussian Process Error of {key.capitalize()}, Max(sigma)={self.grid_property_sigma[key].max():.3f} eV"
|
|
672
615
|
ax.set_title(title)
|
|
673
616
|
|
|
674
617
|
fig.set_dpi(300)
|
|
@@ -14,11 +14,11 @@ from skimage.measure import marching_cubes
|
|
|
14
14
|
def calc_hull_vertices(v):
|
|
15
15
|
shape = v.shape
|
|
16
16
|
if len(shape) != 2:
|
|
17
|
-
print("Warning: The vector should be 2D, however {}D vector was provided!)"
|
|
17
|
+
print(f"Warning: The vector should be 2D, however {len(shape)}D vector was provided!)")
|
|
18
18
|
print("The Convex Hull Vertices won't be calculated.")
|
|
19
19
|
return None
|
|
20
20
|
if shape[1] > 5:
|
|
21
|
-
print("Warning: The vector.shape[1]={} is too large to be calculated!)"
|
|
21
|
+
print(f"Warning: The vector.shape[1]={shape[1]} is too large to be calculated!)")
|
|
22
22
|
print("The Convex Hull Vertices won't be calculated.")
|
|
23
23
|
return None
|
|
24
24
|
try:
|
|
@@ -261,7 +261,7 @@ class GridGenerator:
|
|
|
261
261
|
xyz = rattle(atoms.cell.cartesian_positions(fxyz), stdev=self.interval / 3)
|
|
262
262
|
grid_tree = cKDTree(xyz, copy_data=True)
|
|
263
263
|
|
|
264
|
-
# 对atoms 在
|
|
264
|
+
# 对atoms 在 xyz 方向超胞. Adapt from ase.geometry.geometry.general_find_mic
|
|
265
265
|
ranges = [np.arange(-1 * p, p + 1) for p in atoms.pbc]
|
|
266
266
|
hkls = np.array(list(itertools.product(*ranges)))
|
|
267
267
|
vrvecs = hkls @ atoms.cell
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{surface_construct-0.8 → surface_construct-0.8.1}/surface_construct.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|