surface-construct 0.8__tar.gz → 0.8.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {surface_construct-0.8/surface_construct.egg-info → surface_construct-0.8.2}/PKG-INFO +1 -1
  2. {surface_construct-0.8 → surface_construct-0.8.2}/setup.py +1 -1
  3. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/sampling.py +318 -284
  4. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/surface_grid.py +9 -66
  5. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/utils.py +36 -4
  6. {surface_construct-0.8 → surface_construct-0.8.2/surface_construct.egg-info}/PKG-INFO +1 -1
  7. {surface_construct-0.8 → surface_construct-0.8.2}/LICENSE +0 -0
  8. {surface_construct-0.8 → surface_construct-0.8.2}/README.md +0 -0
  9. {surface_construct-0.8 → surface_construct-0.8.2}/setup.cfg +0 -0
  10. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/__init__.py +0 -0
  11. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/atoms.py +0 -0
  12. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/db.py +0 -0
  13. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/default_parameter.py +0 -0
  14. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/structure.py +0 -0
  15. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct/surface.py +0 -0
  16. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct.egg-info/SOURCES.txt +0 -0
  17. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct.egg-info/dependency_links.txt +0 -0
  18. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct.egg-info/requires.txt +0 -0
  19. {surface_construct-0.8 → surface_construct-0.8.2}/surface_construct.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: surface_construct
3
- Version: 0.8
3
+ Version: 0.8.2
4
4
  Summary: Surface termination construction especially for complex model, such as oxides or carbides.
5
5
  Home-page: https://gitee.com/pjren/surface_construct/
6
6
  Author: ren
@@ -18,7 +18,7 @@ install_requires = [
18
18
 
19
19
  setup(
20
20
  name='surface_construct',
21
- version='0.8',
21
+ version='0.8.2',
22
22
  packages=['surface_construct'],
23
23
  url='https://gitee.com/pjren/surface_construct/',
24
24
  license='GPL',
@@ -1,284 +1,318 @@
1
- """
2
- TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
3
- """
4
-
5
- import numpy as np
6
- from ase.geometry import get_distances
7
- from scipy.spatial import ConvexHull
8
- from scipy.spatial.distance import cdist
9
- from sklearn.cluster import KMeans as Cluster
10
- import random
11
-
12
-
13
- def addition_samples(sg_obj, size=None, probability=None, **kwargs):
14
- if 'seed' in kwargs:
15
- seed = kwargs['seed']
16
- else:
17
- seed = None
18
- if probability is None:
19
- probability = {
20
- "max_sigma": 0.2, # 采样方法的概率
21
- "max_diversity": 0.8,
22
- }
23
- if size is None:
24
- size = 1
25
-
26
- # 归一化
27
- total = sum(probability.values())
28
- if total != 1.0:
29
- probability = {k: v / total for k, v in probability.items()}
30
-
31
- rng = np.random.default_rng(seed)
32
- method_list = rng.choice(list(probability.keys()), size=size, p=list(probability.values()))
33
-
34
- point_idx = np.array([], dtype=int)
35
- for method in method_list:
36
- method_lower = method.lower()
37
- if method_lower == 'max_sigma':
38
- sampling_obj = MaxSigmaSampling(sg_obj)
39
- elif method_lower == 'max_diversity':
40
- sampling_obj = MaxDiversitySampling(sg_obj)
41
- else:
42
- raise NotImplementedError
43
- point_idx = np.concatenate([point_idx, sampling_obj.samples(size=1, **kwargs)]) # 每种方法只采一个
44
-
45
- return point_idx
46
-
47
-
48
- class SamplingBase:
49
- def __init__(self, sg_obj, **kwargs):
50
- self.sg_obj = sg_obj
51
-
52
- @property
53
- def _pop_size(self):
54
- return len(self.sg_obj.points)
55
-
56
- @property
57
- def _population(self):
58
- """
59
- 默认的全体是 sg_obj.points 的 index
60
- :return:
61
- """
62
- return range(self._pop_size)
63
-
64
- def _append_sample_to_sg(self, point_idx=None):
65
- """
66
- 将采样点加入到 sg_obj.sample_points 和相应的 vector
67
-
68
- :return:
69
- """
70
- if point_idx is not None:
71
- if self.sg_obj.sample_idx is not None:
72
- self.sg_obj.sample_idx = np.concatenate([self.sg_obj.sample_idx, point_idx])
73
- self.sg_obj._sample_vector = np.concatenate([self.sg_obj._sample_vector, self.sg_obj.vector[point_idx]])
74
- self.sg_obj.sample_points = np.concatenate([self.sg_obj.sample_points, self.sg_obj.points[point_idx]])
75
- else:
76
- self.sg_obj.sample_idx = np.array(point_idx)
77
- self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
78
- self.sg_obj.sample_points = self.sg_obj.points[point_idx]
79
-
80
- def _samples(self, size, **kwargs):
81
- raise NotImplementedError
82
-
83
- def samples(self, size=1, **kwargs):
84
- point_idx = self._samples(size=size, **kwargs)
85
- self._append_sample_to_sg(point_idx=point_idx)
86
-
87
- return point_idx
88
-
89
-
90
- class KeyPointSampling(SamplingBase):
91
- """
92
- 基本思路:先定位 hollow位,再根据图论分析定位 bridge 位,最后定位top位。第二步,筛选等价位点。
93
- 需要一个基础func,从xy坐标,找到对应的格点。
94
- """
95
- def _samples(self, size, **kwargs):
96
- pass
97
-
98
-
99
- class RandomSampling(SamplingBase):
100
- """
101
- 完全随机的选择点,仅用于测试,效率太低。
102
- """
103
-
104
- def __init__(self, sg_obj, **kwargs):
105
- super().__init__(sg_obj, **kwargs)
106
- if 'seed' in kwargs:
107
- self.seed = kwargs['seed']
108
- else:
109
- self.seed = None
110
-
111
- def _samples(self, size, **kwargs):
112
- rng = np.random.default_rng(self.seed)
113
- pop_idx = rng.choice(self._population, size=size)
114
- return pop_idx
115
-
116
-
117
- class MaxSigmaSampling(SamplingBase):
118
- """
119
- 对最大误差的点进行采样
120
- """
121
-
122
- def _samples(self, size, **kwargs):
123
- if 'energy' in self.sg_obj.grid_property:
124
- # 如果已经读入了一些能量,则返回误差最大的点
125
- idx = self.sg_obj.grid_property_sigma['energy'].argmax()
126
- return [idx]
127
- else:
128
- raise "No energy for all population, pls do initial sampling first!"
129
-
130
-
131
- class InitialSampling(SamplingBase):
132
- """
133
- 使用聚类-分层采样进行初始采样
134
- """
135
-
136
- def _samples(self, size, **kwargs):
137
- # 进行分类,然后采样。这里使用 Kmeans 方法进行聚类
138
- clusters = Cluster(n_clusters=size, random_state=0).fit(self.sg_obj.vector)
139
- self.sg_obj._clusters = clusters # 保存用于作图
140
- # 对于每一类取距离 cluster 中心最小的点的 idx
141
- centers = clusters.cluster_centers_
142
- center_dist = cdist(centers, self.sg_obj.vector) # 计算每个点到中心的距离
143
- point_idx = np.argmin(center_dist, axis=-1)
144
-
145
- # 找到 Hull 点,输出 hull 的个数
146
- # TODO: 有了 keypoint 采样,这个就是不需要的
147
- hull = ConvexHull(self.sg_obj.vector)
148
- vertices = hull.vertices
149
- points_hull_sample = self.sg_obj.points[vertices]
150
-
151
- # 计算 hull 点与 cluster_sample 的距离, 排除太近的 cluster 采样点
152
- hull_cluster_dist_array, hull_cluster_dist = get_distances(points_hull_sample,
153
- self.sg_obj.points[point_idx],
154
- cell=self.sg_obj.atoms.cell,
155
- pbc=self.sg_obj.atoms.pbc)
156
- cluster_index = vertices[np.all(hull_cluster_dist > self.sg_obj.interval * 2, axis=-1)]
157
- print("There are {:d} vertex according to ConvexHull analysis, which can be appended to the initial "
158
- "sampling set for better diversity.".format(len(cluster_index)))
159
- # 加入采样
160
- if 'include_vertex' in kwargs and kwargs['include_vertex']:
161
- print("The vertex are appended into initial sampling.")
162
- point_idx = np.concatenate([point_idx, cluster_index])
163
- else:
164
- print("The vertex weren't appended into initial sampling.")
165
-
166
- return point_idx
167
-
168
- def _append_sample_to_sg(self, point_idx=None):
169
- """
170
- 将采样点加入到 sg_obj.sample_points 和相应的 vector
171
- :return:
172
- """
173
- if point_idx is not None:
174
- self.sg_obj.sample_idx = np.asarray(point_idx)
175
- self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
176
- self.sg_obj.sample_points = self.sg_obj.points[point_idx]
177
-
178
-
179
- class MaxDiversitySampling(SamplingBase):
180
- """
181
- 对当前采样结构差异最大的点进行采样
182
- 基本思路是这样的:
183
- * 重新进行聚类,
184
- * 判断已经采样点属于的类别,找出没有点的类别,空类
185
- * 如果空类不止一个,比较这些空类中心与旧点的距离,选择距离最大的点。
186
- """
187
-
188
- def _samples(self, size, center=False, **kwargs):
189
- """
190
-
191
- :param size:
192
- :param center: 是否取中心。如果不是,则取能量最小值的点。如果没有能量则报错。
193
- :param kwargs:
194
- :return:
195
- """
196
- # 判断是否有过往的采样点,如果没有,调用 InitialSampling
197
- if self.sg_obj.sample_idx is None:
198
- raise "Please add initial samples (e.g. InitialSampling) before invoke this method!"
199
- cluster_size = len(self.sg_obj.sample_idx) + size
200
- nvirgin = 0
201
- larger_clusters = None
202
- larger_virgin = None
203
- virgin = None
204
- clusters = None
205
- # 如果等于则停止,并保存 cluster
206
- while nvirgin != size:
207
- # len(sample_idx) + size 作为新的聚类的size
208
- clusters = Cluster(n_clusters=cluster_size).fit(self.sg_obj.vector)
209
- labels = clusters.labels_[self.sg_obj.sample_idx]
210
- labels_set = set(labels)
211
- virgin = set(range(cluster_size)) - labels_set
212
- nvirgin = len(virgin)
213
- # 判断分类以后空类数目与size的大小
214
- # 如果大于size,则减小size,并记录空类的数目
215
- if nvirgin > size:
216
- cluster_size -= 1
217
- larger_clusters = clusters
218
- larger_virgin = virgin
219
- # 如果小于 size 则增大size,检查上一个size是否有记录,如果有记录则使用上个size 的记录。从中随机选择size个点作为采样点。
220
- elif nvirgin < size:
221
- cluster_size += 1
222
- if larger_clusters is not None:
223
- clusters = larger_clusters
224
- virgin = larger_virgin
225
- break
226
- # virgin 里面选取 size 个点
227
- rng = np.random.default_rng()
228
- cluster_idx = rng.choice(list(virgin), size=size)
229
- if center:
230
- # 取中心位置的格点
231
- centers = clusters.cluster_centers_[cluster_idx]
232
- center_dist = cdist(centers, self.sg_obj.vector) # 计算每个点到中心的距离
233
- point_idx = np.argmin(center_dist, axis=-1)
234
- else:
235
- # 取这些 clusters 中能量最小值点
236
- point_idx = []
237
- for c_id in cluster_idx:
238
- p_idx = np.arange(len(self.sg_obj.points))[clusters.labels_ == c_id]
239
- # 求这些点的能量最小值
240
- if 'energy' not in self.sg_obj.grid_property:
241
- raise NotImplementedError
242
- p_energy = self.sg_obj.grid_energy[p_idx]
243
- point_idx.append(p_idx[p_energy.argmin()])
244
-
245
- return point_idx
246
-
247
-
248
- class NewtonSampling(SamplingBase):
249
- """
250
- 沿着受力方向进行采样
251
- """
252
-
253
- def _samples(self, size, **kwargs):
254
- raise NotImplementedError
255
-
256
-
257
- class RandomWalk(SamplingBase):
258
- """
259
- 从给定点出发随机行走进行采样
260
- """
261
-
262
- def __init__(self, sg_obj=None, probability=1.0, **kwargs):
263
- super().__init__(sg_obj, probability, **kwargs)
264
-
265
- def _samples(self, size, **kwargs):
266
- raise NotImplementedError
267
-
268
-
269
- class SystematicSampling(SamplingBase):
270
- """
271
- 等距采样。主要用于测试。
272
- """
273
-
274
- def _samples(self, size, **kwargs):
275
- if 'start' in kwargs:
276
- start = kwargs['start']
277
- else:
278
- start = random.randint(0, self._pop_size)
279
- stop = self._pop_size
280
- indices = list(range(start, stop)) + list(range(0, start))
281
- step = int(self._pop_size / size)
282
- point_idx = [indices[i + n * step] for n, i in enumerate(range(size))]
283
-
284
- return point_idx
1
+ """
2
+ TODO: 关键点采样:top 位置、hollow位,bridge 位等等。
3
+ """
4
+ import itertools
5
+ import numpy as np
6
+ from ase.geometry import get_distances
7
+ from scipy.spatial import ConvexHull
8
+ from scipy.spatial.distance import cdist
9
+ from scipy.special import comb
10
+ from sklearn.cluster import KMeans as Cluster
11
+ import random
12
+
13
+ from surface_construct.utils import furthest_sites
14
+
15
# Two hull facets whose outward normals are less than 30 degrees apart are
# treated as (almost) coplanar by hull_vertices().
MIN_HULL_ANGLE_COS = np.cos(np.pi * 30 / 180)


def hull_vertices(hull, min_angle_cos=MIN_HULL_ANGLE_COS):
    """
    Return the "corner" vertices of a convex hull, skipping nearly flat ones.

    For every hull vertex, the facets (simplices) containing it are collected,
    and the angles between their outward normals are compared. A vertex is kept
    when at least C(d, 2) of those facet pairs have normals more than
    ``arccos(min_angle_cos)`` apart (30 degrees by default). Here d is the
    number of points per simplex, i.e. the hull dimension. In 2-D this means
    the two edges meeting at the vertex must turn by more than 30 degrees,
    i.e. the interior angle is below 150 degrees.

    The signed cosine is used on purpose. Around a sharp spike of a convex hull,
    the adjacent facet normals point in almost opposite directions (cosine
    close to -1). Such vertices are exactly the ones that must be kept, so
    taking the absolute value would wrongly classify them as flat.

    :param hull: ``scipy.spatial.ConvexHull`` instance.
    :param min_angle_cos: cosine of the minimal normal-to-normal angle.
    :return: list of indices (into ``hull.points``) of the retained vertices.
    """
    simplices = hull.simplices
    ndim = simplices.shape[1]
    normals = hull.equations[:, :-1]
    normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
    n_required = comb(ndim, 2)

    vertices = []
    for vertex in hull.vertices:
        # Facets (rows of `simplices`) that contain this vertex.
        facet_idx = np.nonzero((simplices == vertex).any(axis=1))[0]
        facet_normals = normals[facet_idx]
        # Cosines between every pair of those facet normals.
        cos_matrix = facet_normals @ facet_normals.T
        pair_cos = cos_matrix[np.triu_indices(len(facet_idx), k=1)]
        if np.count_nonzero(pair_cos < min_angle_cos) >= n_required:
            vertices.append(int(vertex))
    return vertices
34
+
35
+
36
def addition_samples(sg_obj, size=None, probability=None, **kwargs):
    """
    Draw additional samples by randomly mixing sampling strategies.

    For each of the ``size`` requested points, a strategy is drawn according to
    ``probability`` and asked for exactly one point. The strategy records that
    point on ``sg_obj`` through ``SamplingBase.samples``.

    :param sg_obj: surface-grid object handed to the sampling classes.
    :param size: number of points to draw (``None`` means 1).
    :param probability: mapping ``{method_name: weight}``. Weights are
        normalised to sum to one. Supported names (case-insensitive) are
        ``'max_sigma'`` and ``'max_diversity'``. Defaults to 0.2 / 0.8.
    :param kwargs: forwarded to every ``samples`` call. ``seed`` also seeds the
        draw of the methods.
    :return: integer numpy array of the sampled point indices, in draw order.
    :raises ValueError: if the weights do not sum to a positive number.
    :raises NotImplementedError: if an unsupported method name is drawn.
    """
    seed = kwargs.get('seed')
    if probability is None:
        probability = {
            "max_sigma": 0.2,  # probability of each sampling method
            "max_diversity": 0.8,
        }
    if size is None:
        size = 1

    methods = list(probability.keys())
    weights = np.asarray(list(probability.values()), dtype=float)
    total = weights.sum()
    if not total > 0:  # also rejects an empty mapping and NaN weights
        raise ValueError(f"Sampling probabilities must sum to a positive number, got {probability}")
    weights = weights / total  # normalise

    rng = np.random.default_rng(seed)
    method_list = rng.choice(methods, size=size, p=weights)

    point_idx = np.array([], dtype=int)
    for method in method_list:
        method_lower = method.lower()
        if method_lower == 'max_sigma':
            sampling_obj = MaxSigmaSampling(sg_obj)
        elif method_lower == 'max_diversity':
            sampling_obj = MaxDiversitySampling(sg_obj)
        else:
            raise NotImplementedError(
                f"Unknown sampling method '{method}'; expected 'max_sigma' or 'max_diversity'.")
        # Each drawn method contributes exactly one point.
        point_idx = np.concatenate([point_idx, sampling_obj.samples(size=1, **kwargs)])

    return point_idx
69
+
70
+
71
class SamplingBase:
    """
    Common machinery shared by all sampling strategies.

    Subclasses implement ``_samples``. The public ``samples`` calls it and
    records the chosen indices on the wrapped grid object ``sg_obj``.
    """

    def __init__(self, sg_obj, **kwargs):
        self.sg_obj = sg_obj

    @property
    def _pop_size(self):
        # The population size is the number of grid points.
        return len(self.sg_obj.points)

    @property
    def _population(self):
        """
        Default population: the indices of ``sg_obj.points``.
        :return: range over all point indices
        """
        return range(self._pop_size)

    def _append_sample_to_sg(self, point_idx=None):
        """
        Record ``point_idx`` on ``sg_obj``: ``sample_idx``, ``sample_points`` and
        the matching rows of ``vector`` in ``_sample_vector``.

        Nothing happens for ``None``. Existing samples are extended; if
        ``sg_obj.sample_idx`` is still ``None``, the three attributes are created.
        :return: None
        """
        if point_idx is None:
            return
        grid = self.sg_obj
        if grid.sample_idx is None:
            grid.sample_idx = np.array(point_idx)
            grid._sample_vector = grid.vector[point_idx]
            grid.sample_points = grid.points[point_idx]
            return
        grid.sample_idx = np.concatenate([grid.sample_idx, point_idx])
        grid._sample_vector = np.concatenate([grid._sample_vector, grid.vector[point_idx]])
        grid.sample_points = np.concatenate([grid.sample_points, grid.points[point_idx]])

    def _samples(self, size, **kwargs):
        raise NotImplementedError

    def samples(self, size=1, **kwargs):
        """Choose ``size`` points via ``_samples``, store them on ``sg_obj`` and return them."""
        chosen = self._samples(size=size, **kwargs)
        self._append_sample_to_sg(point_idx=chosen)
        return chosen
111
+
112
+
113
class KeyPointSampling(SamplingBase):
    """
    Key-point sampling (top, bridge, hollow sites) — not implemented yet.

    Planned approach: first locate hollow sites, then locate bridge sites via
    graph analysis, and finally top sites. In a second step, filter out
    equivalent sites. This needs a basic helper that maps an xy coordinate
    onto its grid point.
    """

    def _samples(self, size, **kwargs):
        # Fail loudly like the other unfinished strategies. Returning None would
        # make callers (e.g. np.concatenate in addition_samples) crash later
        # with a confusing error.
        raise NotImplementedError
120
+
121
+
122
class RandomSampling(SamplingBase):
    """
    Purely random selection of grid points. Only meant for testing: it is far
    less efficient than the other strategies.
    """

    def __init__(self, sg_obj, **kwargs):
        super().__init__(sg_obj, **kwargs)
        self.seed = kwargs.get('seed')
        # One generator per sampler. Repeated calls continue the same random
        # stream instead of being re-seeded and returning identical points.
        self._rng = np.random.default_rng(self.seed)

    def _samples(self, size, **kwargs):
        """
        :param size: number of distinct points to draw (at most the population size).
        :return: numpy array of point indices, without repetition.
        """
        return self._rng.choice(self._population, size=size, replace=False)
138
+
139
+
140
class MaxSigmaSampling(SamplingBase):
    """
    Sample the grid points with the largest error (sigma) of the energy prediction.
    """

    def _samples(self, size, **kwargs):
        """
        :param size: number of points to return.
        :return: list of the indices with the ``size`` largest values of
            ``grid_property_sigma['energy']``, in decreasing order. Ties keep the
            lower index first, so for finite sigma ``size=1`` equals ``argmax``.
        :raises RuntimeError: if no energy has been read into ``sg_obj`` yet.
        """
        if 'energy' not in self.sg_obj.grid_property:
            # A plain string cannot be raised in Python 3 (TypeError), so wrap it.
            raise RuntimeError("No energy for all population, pls do initial sampling first!")
        sigma = np.asarray(self.sg_obj.grid_property_sigma['energy']).ravel()
        order = np.argsort(-sigma, kind='stable')
        return [int(i) for i in order[:size]]
152
+
153
+
154
class InitialSampling(SamplingBase):
    """
    Initial sampling by clustering / stratified sampling.

    The filtered convex-hull vertices of ``sg_obj.vector`` (see
    ``hull_vertices``) are preferred. Any remaining budget is filled with the
    grid points closest to K-means centres of a coarse "vector mesh".
    """

    def _samples(self, size, **kwargs):
        """
        :param size: number of initial samples requested.
        :param kwargs: optional ``seed`` for the K-means initialisation and the
            random vertex pick. ``None`` keeps the previous unseeded behaviour.
        :return: list of point indices into ``sg_obj.points`` without repeats.
            Because repeats are removed, it may be shorter than ``size``.
        """
        seed = kwargs.get('seed')
        vector = self.sg_obj.vector
        npoints = len(vector)
        hull = ConvexHull(vector)
        vertices = [int(v) for v in hull_vertices(hull)]

        # Coarse "vector mesh": roughly one K-means centre per hull-volume cell of
        # edge _vector_unit * interval. KMeans needs n_clusters <= n_samples.
        n_vector_mesh = int(hull.volume / (self.sg_obj._vector_unit *
                                           self.sg_obj.interval) ** vector.shape[1]) + 1
        n_vector_mesh = min(n_vector_mesh, npoints)
        cluster0 = Cluster(n_clusters=n_vector_mesh, random_state=seed)
        cluster0.fit(vector)
        mesh_centers = cluster0.cluster_centers_
        self.sg_obj._mesh_centers = mesh_centers
        # Stored on sg_obj (used for plotting).
        cluster = Cluster(n_clusters=min(size, len(mesh_centers)), random_state=seed)
        cluster.fit(mesh_centers)
        self.sg_obj._clusters = cluster

        nvert = len(vertices)
        if nvert >= size:
            print(f"Warning: Sample number better be larger than {nvert}!")
            if size == 1:
                rng = np.random.default_rng(seed)
                sample_idx = [int(rng.choice(vertices))]
            elif size == nvert:
                sample_idx = vertices
            else:
                sample_idx = [vertices[i] for i in furthest_sites(vector[vertices], size)]
        else:
            # Fill the remaining budget with the grid points nearest to the
            # centres of a clustering of the mesh centres.
            n_extra = min(size - nvert, len(mesh_centers))
            cluster2 = Cluster(n_clusters=n_extra, random_state=seed)
            cluster2.fit(mesh_centers)
            center_dist = cdist(cluster2.cluster_centers_, vector)  # distance of every point to each centre
            sample_idx = vertices + np.argmin(center_dist, axis=-1).tolist()
        # Several centres can map onto the same grid point (or onto a hull
        # vertex). Drop the repeats while keeping the order.
        return list(dict.fromkeys(int(i) for i in sample_idx))

    def _append_sample_to_sg(self, point_idx=None):
        """
        Replace (rather than extend) the samples stored on ``sg_obj`` with
        ``point_idx``: ``sample_idx``, ``_sample_vector`` and ``sample_points``.
        Nothing happens for ``None``.
        :return: None
        """
        if point_idx is not None:
            self.sg_obj.sample_idx = np.asarray(point_idx)
            self.sg_obj._sample_vector = self.sg_obj.vector[point_idx]
            self.sg_obj.sample_points = self.sg_obj.points[point_idx]
211
+
212
+
213
class MaxDiversitySampling(SamplingBase):
    """
    Sample the grid points that differ most from the current samples.

    Outline:
    * re-cluster all grid vectors;
    * find the clusters that contain none of the existing samples ("virgin"
      clusters);
    * draw the new samples from distinct virgin clusters.
    """

    def _samples(self, size, center=False, **kwargs):
        """
        :param size: number of new points.
        :param center: if True, take the grid point closest to each chosen
            cluster centre. Otherwise take the lowest-energy grid point of each
            chosen cluster; this requires energies and raises
            NotImplementedError if none are available.
        :param kwargs: accepted for interface compatibility, unused.
        :return: indices of the new points, each drawn from a different cluster.
        :raises RuntimeError: if ``sg_obj`` has no samples yet.
        """
        if self.sg_obj.sample_idx is None:
            # A plain string cannot be raised in Python 3 (TypeError), so wrap it.
            raise RuntimeError("Please add initial samples (e.g. InitialSampling) before invoke this method!")
        if size <= 0:
            return []
        cluster_size = len(self.sg_obj.sample_idx) + size
        nvirgin = 0
        larger_clusters = None
        larger_virgin = None
        virgin = None
        clusters = None
        # Tune the number of clusters until exactly `size` clusters contain none
        # of the existing samples.
        while nvirgin != size:
            clusters = Cluster(n_clusters=cluster_size).fit(self.sg_obj.vector)
            labels_set = set(clusters.labels_[self.sg_obj.sample_idx])
            virgin = set(range(cluster_size)) - labels_set
            nvirgin = len(virgin)
            if nvirgin > size:
                # Too many empty clusters: remember this clustering, try fewer clusters.
                cluster_size -= 1
                larger_clusters = clusters
                larger_virgin = virgin
            elif nvirgin < size:
                # Too few: try more clusters. If a clustering with too many empty
                # clusters was already seen, fall back to it instead.
                cluster_size += 1
                if larger_clusters is not None:
                    clusters = larger_clusters
                    virgin = larger_virgin
                    break
        # Here len(virgin) >= size. Draw without replacement so that no cluster,
        # and hence no grid point, is selected twice.
        rng = np.random.default_rng()
        cluster_idx = rng.choice(sorted(virgin), size=size, replace=False)
        if center:
            # Grid point closest to each chosen cluster centre.
            centers = clusters.cluster_centers_[cluster_idx]
            center_dist = cdist(centers, self.sg_obj.vector)
            return np.argmin(center_dist, axis=-1)
        # Lowest-energy grid point of each chosen cluster.
        if 'energy' not in self.sg_obj.grid_property:
            raise NotImplementedError
        all_idx = np.arange(len(self.sg_obj.points))
        point_idx = []
        for c_id in cluster_idx:
            p_idx = all_idx[clusters.labels_ == c_id]
            p_energy = self.sg_obj.grid_energy[p_idx]
            point_idx.append(p_idx[p_energy.argmin()])
        return point_idx
280
+
281
+
282
class NewtonSampling(SamplingBase):
    """
    Sampling along the direction of the force — placeholder, not implemented.
    """

    def _samples(self, size, **kwargs):
        # No strategy yet: every call is rejected.
        raise NotImplementedError()
289
+
290
+
291
class RandomWalk(SamplingBase):
    """
    Sampling by random walks from given starting points — not implemented yet.
    """

    def __init__(self, sg_obj=None, probability=1.0, **kwargs):
        # SamplingBase.__init__ accepts a single positional argument. Passing
        # `probability` positionally made every construction raise TypeError.
        super().__init__(sg_obj, **kwargs)
        self.probability = probability

    def _samples(self, size, **kwargs):
        raise NotImplementedError
301
+
302
+
303
class SystematicSampling(SamplingBase):
    """
    Systematic (equal-step) sampling. Mainly used for testing.
    """

    def _samples(self, size, **kwargs):
        """
        :param size: number of points, 1 <= size <= number of grid points.
        :param kwargs: optional ``start`` index. A random start is drawn otherwise.
        :return: list of ``size`` point indices spaced ``pop_size // size`` apart,
            beginning at ``start`` and wrapping around the end of the grid.
        :raises ValueError: if ``size`` is outside ``[1, pop_size]``.
        """
        pop_size = self._pop_size
        if not 1 <= size <= pop_size:
            raise ValueError(f"size must be between 1 and {pop_size}, got {size}")
        if 'start' in kwargs:
            start = kwargs['start']
        else:
            # randrange excludes pop_size. randint(0, pop_size) could return it,
            # which duplicates start=0 and biases the draw.
            start = random.randrange(pop_size)
        step = pop_size // size
        # Equivalent to rotating the index list so `start` comes first and then
        # taking every step-th entry. The old `indices[i + n * step]` (with n == i)
        # ran past the end of the list.
        return [(start + n * step) % pop_size for n in range(size)]
@@ -24,12 +24,12 @@ from scipy.interpolate import griddata
24
24
  from scipy.spatial import ConvexHull
25
25
  from scipy.spatial.distance import euclidean, cdist
26
26
  from sklearn.cluster import KMeans as Cluster
27
- from sklearn.cluster import kmeans_plusplus
28
27
  from sklearn.decomposition import PCA
29
28
  from sklearn.gaussian_process import GaussianProcessRegressor
30
29
  from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
31
30
  from sklearn.preprocessing import StandardScaler
32
31
 
32
+ from surface_construct.sampling import InitialSampling, addition_samples
33
33
  from surface_construct.utils import get_calc_info, GridGenerator, get_distances, furthest_sites
34
34
 
35
35
 
@@ -321,60 +321,18 @@ class SurfaceGrid:
321
321
  k = d_vector / self.interval
322
322
  return np.min(k)
323
323
 
324
def grid_sample(self, N=1, probability=None, **kwargs):
    """
    Pick ``N`` new sample points on the grid.

    Warning: Obsoleted, replaced by Sampling class

    Before any energy is available (``'energy'`` not in ``self.grid_property``),
    ``InitialSampling`` is used. Afterwards, ``addition_samples`` mixes the
    adaptive strategies according to ``probability``.

    :param N: number of points to sample.
    :param probability: ``{method_name: weight}`` forwarded to
        ``addition_samples``. Ignored for the initial sampling.
    :param kwargs: forwarded to the sampler in both branches (e.g. ``seed``).
    :return: indices of the sampled points in ``self.points``.
    """
    if 'energy' in self.grid_property:
        return addition_samples(self, size=N, probability=probability, **kwargs)
    return InitialSampling(self).samples(size=N, **kwargs)
336
 
379
337
  # TODO: 将中心重新映射回到Cartesian坐标 :
380
338
  # 找到向量空间最紧邻的N个点,判断其实空间的距离是否小于 interval × 2,直到有三个点满足
@@ -555,15 +513,7 @@ class SurfaceGrid:
555
513
  fig, ax = plt.subplots()
556
514
  ax.set_aspect('equal')
557
515
 
558
- print("Plot {} ...".format(key))
559
- #X = self.points[:, 0].reshape((self.grid_ny, self.grid_nx))
560
- #Y = self.points[:, 1].reshape((self.grid_ny, self.grid_nx))
561
- #Z = self.grid_property[key].reshape((self.grid_ny, self.grid_nx))
562
- #if vmax is None:
563
- # vmax = Z.max() + (Z.max() - Z.min()) * 0.2
564
- #if vmin is None:
565
- # vmin = Z.min() - (Z.max() - Z.min()) * 0.2
566
- #contourf0 = ax.contourf(X, Y, Z, levels=50, cmap="jet", vmin=vmin, vmax=vmax)
516
+ print(f"Plot {key} ...")
567
517
  x = self.points[:, 0]
568
518
  y = self.points[:, 1]
569
519
  z = self.grid_property[key]
@@ -579,8 +529,7 @@ class SurfaceGrid:
579
529
  ax.scatter(sample_points[nsampled:, 0], sample_points[nsampled:, 1], marker="o", s=100, linewidths=2,
580
530
  color="w", zorder=10)
581
531
  fig.colorbar(contourf0, ax=ax)
582
- title = "{} distribution, Max(sigma)= {:.3f}".format(
583
- key.capitalize(), self.grid_property_sigma[key].max())
532
+ title = f"{key.capitalize()} distribution, Max(sigma)={self.grid_property_sigma[key].max():.3f}"
584
533
  ax.set_title(title)
585
534
  fig.set_dpi(300)
586
535
  fig.set_size_inches(10, 10)
@@ -651,11 +600,6 @@ class SurfaceGrid:
651
600
  fig, ax = plt.subplots()
652
601
  ax.set_aspect('equal')
653
602
 
654
- #X = self.points[:, 0].reshape((self.grid_ny, self.grid_nx))
655
- #Y = self.points[:, 1].reshape((self.grid_ny, self.grid_nx))
656
- #Z = self.grid_property_sigma[key].reshape((self.grid_ny, self.grid_nx))
657
- #contourf0 = ax.contourf(X, Y, Z, levels=50, cmap="RdPu", vmax=vmax, vmin=vmin)
658
-
659
603
  x = self.points[:, 0]
660
604
  y = self.points[:, 1]
661
605
  z = self.grid_property_sigma[key]
@@ -667,8 +611,7 @@ class SurfaceGrid:
667
611
  ax.scatter(max_sigma_point[0], max_sigma_point[1], marker="+", s=100, linewidths=2,
668
612
  color="w", zorder=10)
669
613
  fig.colorbar(contourf0, ax=ax)
670
- title = "Gaussian Process Error of {}, Max(sigma)= {:.3f} eV".format(
671
- key.capitalize(), self.grid_property_sigma[key].max())
614
+ title = f"Gaussian Process Error of {key.capitalize()}, Max(sigma)={self.grid_property_sigma[key].max():.3f} eV"
672
615
  ax.set_title(title)
673
616
 
674
617
  fig.set_dpi(300)
@@ -14,11 +14,11 @@ from skimage.measure import marching_cubes
14
14
  def calc_hull_vertices(v):
15
15
  shape = v.shape
16
16
  if len(shape) != 2:
17
- print("Warning: The vector should be 2D, however {}D vector was provided!)".format(len(shape)))
17
+ print(f"Warning: The vector should be 2D, however {len(shape)}D vector was provided!)")
18
18
  print("The Convex Hull Vertices won't be calculated.")
19
19
  return None
20
20
  if shape[1] > 5:
21
- print("Warning: The vector.shape[1]={} is too large to be calculated!)".format(shape[1]))
21
+ print(f"Warning: The vector.shape[1]={shape[1]} is too large to be calculated!)")
22
22
  print("The Convex Hull Vertices won't be calculated.")
23
23
  return None
24
24
  try:
@@ -117,7 +117,7 @@ class GridGenerator:
117
117
  """
118
118
  self.atoms = atoms
119
119
  self._grid = None
120
- self.atoms_num_type = set(atoms.numbers)
120
+ self.atoms_num_type = sorted(set(atoms.numbers))
121
121
  self.interval = interval
122
122
 
123
123
  if subtype is None:
@@ -261,7 +261,7 @@ class GridGenerator:
261
261
  xyz = rattle(atoms.cell.cartesian_positions(fxyz), stdev=self.interval / 3)
262
262
  grid_tree = cKDTree(xyz, copy_data=True)
263
263
 
264
- # 对atoms 在 xy 方向超胞. Adapt from ase.geometry.geometry.general_find_mic
264
+ # 对atoms 在 xyz 方向超胞. Adapt from ase.geometry.geometry.general_find_mic
265
265
  ranges = [np.arange(-1 * p, p + 1) for p in atoms.pbc]
266
266
  hkls = np.array(list(itertools.product(*ranges)))
267
267
  vrvecs = hkls @ atoms.cell
@@ -299,6 +299,38 @@ class GridGenerator:
299
299
  print("Too much grid number, it will be very slow.")
300
300
  view(self.atoms + ase.Atoms(symbols=['X'] * len(self.grid), positions=self.grid))
301
301
 
302
def get_grid_site_type(self, site_dict=None):
    """
    Classify every grid point by its first-neighbour atoms.

    A grid point counts as connected to an atom when their distance is below
    1.1 * (self.rsub + self.rads). The site type of a grid point is the tuple
    ``((atom_number, n_connected), ...)`` over ``self.atoms_num_type``.

    :param site_dict: site-type dictionary to extend, of the form
        ``{0: ((atom_num, count), (atom_num, count), ...), ..., 'next_idx': int}``.
        It is updated in place with every site type not yet present. ``None``
        starts a new dictionary. A missing ``'next_idx'`` is set to one past the
        largest integer key.
    :return: (site label of each grid point, site_dict). Both are also stored as
        ``self.grid_site_type`` and ``self.site_type_dict``.
    """
    if site_dict is None:
        site_dict = {'next_idx': 0}
    elif 'next_idx' not in site_dict:
        site_dict['next_idx'] = max((k for k in site_dict if isinstance(k, int)), default=-1) + 1
    _, Dga = get_distances(self.grid, self.atoms.positions, use_ase=True,
                           cell=self.atoms.cell, pbc=self.atoms.pbc)
    # Grid/atom connectivity: True when the grid point is closer to the atom
    # than 1.1 x (rsub + rads), i.e. the atom is a first neighbour.
    Lga = Dga - (self.rsub + self.rads) * 1.1 < 0
    # Per grid point, the number of connected atoms of each element
    # (columns ordered like self.atoms_num_type).
    LTga = np.asarray([(Lga & (self.atoms.numbers == n1)).sum(-1)
                       for n1 in self.atoms_num_type]).T
    grid_labels = [tuple((num, count) for num, count in zip(self.atoms_num_type, row))
                   for row in LTga]
    # Look labels up by the full (element, count) tuple, not by counts alone.
    # A dictionary reused for a structure with other elements must not map
    # equal count patterns onto the same site type.
    label_to_idx = {label: idx for idx, label in site_dict.items() if isinstance(idx, int)}
    for label in sorted(set(grid_labels)):
        if label not in label_to_idx:
            idx = site_dict['next_idx']
            site_dict[idx] = label
            label_to_idx[label] = idx
            site_dict['next_idx'] = idx + 1
    grid_T_label = [label_to_idx[label] for label in grid_labels]
    self.grid_site_type = grid_T_label
    self.site_type_dict = site_dict
    return grid_T_label, site_dict
333
+
302
334
 
303
335
  def rattle(positions, stdev=0.001, rng=None, seed=None):
304
336
  """Rattle the grid to make the vector distribution more smooth.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: surface_construct
3
- Version: 0.8
3
+ Version: 0.8.2
4
4
  Summary: Surface termination construction especially for complex model, such as oxides or carbides.
5
5
  Home-page: https://gitee.com/pjren/surface_construct/
6
6
  Author: ren