moospread 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- moospread/__init__.py +3 -0
- moospread/core.py +1881 -0
- moospread/problem.py +193 -0
- moospread/tasks/__init__.py +4 -0
- moospread/tasks/dtlz_torch.py +139 -0
- moospread/tasks/mw_torch.py +274 -0
- moospread/tasks/re_torch.py +394 -0
- moospread/tasks/zdt_torch.py +112 -0
- moospread/utils/__init__.py +8 -0
- moospread/utils/constraint_utils/__init__.py +2 -0
- moospread/utils/constraint_utils/gradient.py +72 -0
- moospread/utils/constraint_utils/mgda_core.py +69 -0
- moospread/utils/constraint_utils/pmgda_solver.py +308 -0
- moospread/utils/constraint_utils/prefs.py +64 -0
- moospread/utils/ditmoo.py +127 -0
- moospread/utils/lhs.py +74 -0
- moospread/utils/misc.py +28 -0
- moospread/utils/mobo_utils/__init__.py +11 -0
- moospread/utils/mobo_utils/evolution/__init__.py +0 -0
- moospread/utils/mobo_utils/evolution/dom.py +60 -0
- moospread/utils/mobo_utils/evolution/norm.py +40 -0
- moospread/utils/mobo_utils/evolution/utils.py +97 -0
- moospread/utils/mobo_utils/learning/__init__.py +0 -0
- moospread/utils/mobo_utils/learning/model.py +40 -0
- moospread/utils/mobo_utils/learning/model_init.py +33 -0
- moospread/utils/mobo_utils/learning/model_update.py +51 -0
- moospread/utils/mobo_utils/learning/prediction.py +116 -0
- moospread/utils/mobo_utils/learning/utils.py +143 -0
- moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
- moospread/utils/mobo_utils/mobo/__init__.py +0 -0
- moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
- moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
- moospread/utils/mobo_utils/mobo/factory.py +86 -0
- moospread/utils/mobo_utils/mobo/mobo.py +132 -0
- moospread/utils/mobo_utils/mobo/selection.py +182 -0
- moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
- moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
- moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
- moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
- moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
- moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
- moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
- moospread/utils/mobo_utils/mobo/transformation.py +106 -0
- moospread/utils/mobo_utils/mobo/utils.py +65 -0
- moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
- moospread/utils/offline_utils/__init__.py +10 -0
- moospread/utils/offline_utils/handle_task.py +203 -0
- moospread/utils/offline_utils/proxies.py +338 -0
- moospread/utils/spread_utils.py +91 -0
- moospread-0.1.0.dist-info/METADATA +75 -0
- moospread-0.1.0.dist-info/RECORD +63 -0
- moospread-0.1.0.dist-info/WHEEL +5 -0
- moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
- moospread-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
import numpy as np
|
|
3
|
+
from scipy.spatial import Delaunay
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from pygco import cut_from_graph
|
|
6
|
+
|
|
7
|
+
from moospread.utils.mobo_utils.mobo.solver.pareto_discovery.utils import generate_weights_batch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BufferBase(ABC):
    '''
    Base class of performance buffer.

    The buffer is split into `cell_num` cells. Every inserted sample is stored in the cell returned by
    _find_cell_id() for its performance relative to the origin, and each cell is kept sorted by
    distance to the origin (closest sample first).
    '''
    def __init__(self, cell_num, cell_size=None, origin=None, origin_constant=1e-2, delta_b=0.2, label_cost=0):
        '''
        Input:
            cell_num: number of discretized cells
            cell_size: max sample number within each cell, None (or a non-positive value) means no limit
            origin: the origin point (minimum utopia)
            origin_constant: when the origin point is surpassed by new inserted samples, adjust the origin point and subtract this constant
            delta_b: normalization constant for calculating unary energy in sparse approximation (NOTE: in the paper they also use this to determine appending to buffer or rejection)
            label_cost: for reducing number of unique labels in sparse approximation
        '''
        self.cell_num = cell_num
        self.cell_size = cell_size if cell_size is not None and cell_size > 0 else None
        self.origin = origin
        self.origin_constant = origin_constant

        # sparse approximation related
        self.C_inf = 10
        self.delta_b = delta_b
        self.label_cost = label_cost

        # buffer element arrays
        # NOTE: below initializations without prior size could be inefficient in memory access
        self.buffer_x = [[] for _ in range(self.cell_num)]
        self.buffer_y = [[] for _ in range(self.cell_num)]
        self.buffer_dist = [[] for _ in range(self.cell_num)] # stores distance to origin for each sample, exactly the same size as self.buffer
        self.buffer_patch_id = [[] for _ in range(self.cell_num)] # stores the index of manifold (patch) that each sample belongs to

        self.sample_count = 0

    @abstractmethod
    def _find_cell_id(self, F):
        '''
        Find corresponding cell indices given normalized performance.
        Input:
            F: a batch of normalized performance
        Output:
            cell_ids: a batch of cell indices
        '''
        pass

    def insert(self, X, Y, patch_ids):
        '''
        Insert samples (X, Y) into buffer, which come from manifolds (patches) indexed by 'patch_ids'
        Input:
            X: a batch of design samples
            Y: the corresponding batch of performance values
            patch_ids: the patch index of each sample
        '''
        X, Y = np.array(X), np.array(Y)
        if len(X) == 0: return # nothing to insert, and np.min below cannot reduce an empty batch

        # normalize performance
        self.move_origin(np.min(Y, axis=0))
        F = Y - self.origin

        # calculate corresponding cell index
        dists = np.linalg.norm(F, axis=1)
        cell_ids = self._find_cell_id(F)

        # insert into buffer
        for x, y, cell_id, dist, patch_id in zip(X, Y, cell_ids, dists, patch_ids):
            self.buffer_x[cell_id].append(x)
            self.buffer_y[cell_id].append(y)
            self.buffer_dist[cell_id].append(dist)
            self.buffer_patch_id[cell_id].append(patch_id)
        self.sample_count += len(X)

        # update cells (re-sort and truncate only the cells that received samples)
        for cell_id in np.unique(cell_ids):
            self._update_cell(cell_id)

    def sample_old(self, n):
        '''
        Sample n samples in current buffer with best performance. (Deprecated)
        Raises ValueError when the buffer is empty.
        '''
        # TODO: check if it's proper to repeatedly sample the best one without considering others
        nonempty_cell_ids = [i for i in range(self.cell_num) if len(self.buffer_dist[i]) > 0]
        n_nonempty_cells = len(nonempty_cell_ids) # number of non-empty cells
        if n_nonempty_cells == 0:
            raise ValueError('cannot sample from an empty buffer')

        # while n >= n_nonempty_cells, we select all non-empty cells
        selected_cell_ids = (n // n_nonempty_cells) * nonempty_cell_ids
        # when n < n_nonempty_cells, we select non-empty cells randomly
        selected_cell_ids.extend(list(np.random.choice(nonempty_cell_ids, size=n % n_nonempty_cells, replace=False)))

        # get the best solution in each cell
        # NOTE: index the python lists directly, the cells are ragged so np.array(self.buffer_x) is not a valid array
        selected_samples = [self.buffer_x[cell_id][0] for cell_id in selected_cell_ids]

        return np.array(selected_samples)

    def sample(self, n):
        '''
        Sample n samples in current buffer with best performance. (Active)
        Input:
            n: number of samples to draw
        Output:
            selected_samples: array of n design samples; when the buffer holds fewer than n samples,
                the remainder is filled with random duplicates of the stored ones
        Raises ValueError when n > 0 and the buffer is empty.
        '''
        nonempty_cell_ids = [i for i in range(self.cell_num) if len(self.buffer_dist[i]) > 0]
        if n > 0 and len(nonempty_cell_ids) == 0:
            raise ValueError('cannot sample from an empty buffer')

        # when n is less than number of non-empty cells, randomly pick the 1st samples in cells
        if n <= len(nonempty_cell_ids):
            selected_cell_ids = np.random.choice(nonempty_cell_ids, size=n, replace=False)
            # NOTE: index the python lists directly, the cells are ragged so np.array(self.buffer_x) is not a valid array
            selected_samples = [self.buffer_x[cell_id][0] for cell_id in selected_cell_ids]

        # when n is greater, pick samples in cells round by round (1st, 2nd, ...)
        else:
            k = 0
            selected_samples = []
            while len(selected_samples) < n:
                # find cells need to be sampled in current round
                round_cell_ids = [i for i in range(self.cell_num) if len(self.buffer_dist[i]) > k]

                if len(round_cell_ids) == 0: # when total number of samples in buffer is less than sample number
                    random_indices = np.random.choice(np.arange(len(selected_samples)), size=(n - len(selected_samples)))
                    selected_samples = np.vstack([selected_samples, np.array(selected_samples)[random_indices]])
                    break

                curr_selected_samples = [self.buffer_x[cell_id][k] for cell_id in round_cell_ids]
                selected_samples.extend(np.random.permutation(curr_selected_samples))
                k += 1 # advance to the next-best sample of every cell, otherwise only the 1st samples are ever drawn

        # both branches return an array of at most n samples
        return np.array(selected_samples[:n])

    def move_origin(self, y_min):
        '''
        Move the origin point when y_min surpasses it, redistribute current buffer storage accordingly
        Input:
            y_min: component-wise minimum of the performance values about to be inserted
        '''
        # the origin stays only if y_min is strictly greater than it in every component
        if (y_min >= self.origin).all() and not (y_min == self.origin).any(): return

        self.origin = np.minimum(self.origin, y_min) - self.origin_constant

        # detach the old storage (no copy needed, fresh lists are installed right below) and re-insert
        # every stored sample so that cell ids and distances are recomputed against the new origin
        old_buffer_x, old_buffer_y, old_buffer_patch_id = self.buffer_x, self.buffer_y, self.buffer_patch_id
        self.buffer_x, self.buffer_y, self.buffer_dist, self.buffer_patch_id = [[[] for _ in range(self.cell_num)] for _ in range(4)]
        self.sample_count = 0

        for cell_x, cell_y, cell_patch_id in zip(old_buffer_x, old_buffer_y, old_buffer_patch_id):
            if len(cell_x) > 0:
                self.insert(cell_x, cell_y, cell_patch_id)

    def _update_cell(self, cell_id):
        '''
        Sort particular cell according to distance to origin, and only keep self.cell_size samples in the cell
        '''
        if len(self.buffer_dist[cell_id]) == 0: return

        idx = np.argsort(self.buffer_dist[cell_id])
        if self.cell_size is not None:
            # samples beyond the cell capacity are dropped, so they no longer count
            self.sample_count -= max(len(idx) - self.cell_size, 0)
            idx = idx[:self.cell_size]

        # TODO: check if time-consuming here
        self.buffer_x[cell_id], self.buffer_y[cell_id], self.buffer_dist[cell_id], self.buffer_patch_id[cell_id] = [
            list(np.array(values)[idx])
            for values in (self.buffer_x[cell_id], self.buffer_y[cell_id], self.buffer_dist[cell_id], self.buffer_patch_id[cell_id])
        ]

    @abstractmethod
    def _get_graph_edges(self, valid_cells):
        '''
        Get the edge information of connectivity graph of buffer cells for graph-cut.
        Used for sparse_approximation(), see section 6.4.
        Input:
            valid_cells: non-empty cells that can be formulated as vertices in the graph
        Output:
            edges: edge array of the input vertices, where an edge is represented by two vertices, shape = (n_edges, 2)
        '''
        pass

    def sparse_approximation(self):
        '''
        Use a few manifolds to sparsely approximate the pareto front by graph-cut, see section 6.4.
        NOTE: the patch ids stored in the buffer are renumbered in place to consecutive integers.
        Output:
            labels: the optimized labels (manifold index) for each non-empty cell (the cells also contain the corresponding labeled sample), shape = (n_label,)
            approx_x: the labeled design samples, shape = (n_label, n_var)
            approx_y: the labeled performance values, shape = (n_label, n_obj)
        '''
        # update patch ids, remove non-existing ids previously removed from buffer
        mapping = {}
        patch_id_count = 0
        for cell_id in range(self.cell_num):
            curr_patches = self.buffer_patch_id[cell_id]
            for i, patch_id in enumerate(curr_patches):
                if patch_id not in mapping:
                    mapping[patch_id] = patch_id_count
                    patch_id_count += 1
                curr_patches[i] = mapping[patch_id]

        # construct unary and pairwise energy (cost) matrix for graph-cut
        # NOTE: delta_b should be set properly
        valid_cells = np.where([len(self.buffer_dist[cell_id]) > 0 for cell_id in range(self.cell_num)])[0] # non-empty cells
        n_node = len(valid_cells)
        n_label = patch_id_count
        unary_cost = self.C_inf * np.ones((n_node, n_label))
        pairwise_cost = -self.C_inf * np.eye(n_label)

        for i, idx in enumerate(valid_cells):
            patches, distances = np.array(self.buffer_patch_id[idx]), np.array(self.buffer_dist[idx])
            min_dist = np.min(distances)
            costs = np.minimum((distances - min_dist) / self.delta_b, self.C_inf)
            # a patch may own several samples in one cell: keep the cost of its closest sample
            # (plain fancy-index assignment would keep the last, i.e. the farthest, one)
            np.minimum.at(unary_cost[i], patches, costs)

        # get edge information (graph structure)
        edges = self._get_graph_edges(valid_cells)

        # NOTE: pygco only supports int32 as input, due to potential numerical error
        edges, unary_cost, pairwise_cost, label_cost = \
            edges.astype(np.int32), unary_cost.astype(np.int32), pairwise_cost.astype(np.int32), np.int32(self.label_cost)

        # do graph-cut, optimize labels for each valid cell
        labels_opt = cut_from_graph(edges, unary_cost, pairwise_cost, label_cost)

        # find corresponding design and performance values of optimized labels for each valid cell
        approx_xs, approx_ys = [], []
        labels = [] # for a certain cell, there could be no sample belongs to that label, probably due to the randomness of sampling or improper energy definition
        for idx, label in zip(valid_cells, labels_opt):
            for cell_patch_id, cell_x, cell_y in zip(self.buffer_patch_id[idx], self.buffer_x[idx], self.buffer_y[idx]):
                # since each buffer element array is sorted based on distance to origin
                if cell_patch_id == label:
                    approx_xs.append(cell_x)
                    approx_ys.append(cell_y)
                    labels.append(label)
                    break
            else: # TODO: check
                # no sample of the chosen label in this cell: fall back to the closest sample of the cell
                approx_xs.append(self.buffer_x[idx][0])
                approx_ys.append(self.buffer_y[idx][0])
                labels.append(label)
        approx_xs, approx_ys = np.array(approx_xs), np.array(approx_ys)

        # NOTE: uncomment code below to show visualization of graph cut
        # import matplotlib.pyplot as plt
        # from matplotlib import cm
        # cmap = cm.get_cmap('tab20', patch_id_count)
        # fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        # buffer_ys = np.vstack([np.vstack(cell_y) for cell_y in self.buffer_y if cell_y != []])
        # buffer_patch_ids = np.concatenate([self.buffer_patch_id[cell_id] for cell_id in valid_cells])
        # colors = [cmap(patch_id) for patch_id in buffer_patch_ids]
        # axs[0].scatter(*buffer_ys.T, s=10, c=colors)
        # axs[0].set_title('Before graph cut')
        # colors = [cmap(label) for label in labels]
        # axs[1].scatter(*approx_ys.T, s=10, c=colors)
        # axs[1].set_title('After graph cut')
        # fig.suptitle(f'Sparse approximation, # patches: {patch_id_count}, # families: {len(np.unique(labels))}')
        # plt.show()

        return labels, approx_xs, approx_ys

    def flattened(self):
        '''
        Return flattened x and y arrays from all the cells.
        NOTE: np.concatenate raises ValueError when the buffer holds no sample at all.
        '''
        flattened_x = [cell_x for cell_x in self.buffer_x if len(cell_x) > 0]
        flattened_y = [cell_y for cell_y in self.buffer_y if len(cell_y) > 0]
        return np.concatenate(flattened_x), np.concatenate(flattened_y)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class Buffer2D(BufferBase):
    '''
    2D performance buffer.

    Cells are angular sectors of equal width dtheta = (pi / 2) / cell_num; the angle of a sample is
    measured from the first performance axis, around the origin.
    '''
    def __init__(self, cell_num, *args, **kwargs):
        '''
        Input:
            cell_num: number of discretized cells, None means the default of 100
            *args, **kwargs: forwarded unchanged to BufferBase.__init__
        '''
        if cell_num is None: cell_num = 100
        super().__init__(cell_num, *args, **kwargs)
        if self.origin is None:
            self.origin = np.zeros(2)
        self.dtheta = np.pi / 2.0 / self.cell_num

    def _find_cell_id(self, F):
        '''
        Find corresponding cell indices given normalized performance.
        Input:
            F: a batch of normalized performance, shape = (n_sample, 2)
        Output:
            cell_ids: a batch of cell indices, angles beyond the last sector are clamped to cell_num - 1
        '''
        dist = np.linalg.norm(F, axis=1)
        # a sample lying exactly on the origin has no direction: treat its angle as 0 instead of NaN
        cos_theta = np.ones_like(dist)
        np.divide(F[:, 0], dist, out=cos_theta, where=dist > 0)
        # guard against round-off pushing the ratio outside the domain of arccos
        theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
        cell_ids = (theta / self.dtheta).astype(int)
        return np.minimum(cell_ids, self.cell_num - 1)

    def _get_graph_edges(self, valid_cells):
        '''
        Get edges by connecting neighbor cells: consecutive non-empty cells form a chain.
        Input:
            valid_cells: non-empty cells that can be formulated as vertices in the graph
        Output:
            edges: edges between positions in valid_cells, shape = (n_edges, 2), which is (0, 2) when there are fewer than two valid cells
        '''
        n_valid = len(valid_cells)
        edges = np.column_stack([np.arange(n_valid - 1), np.arange(1, n_valid)])
        return edges
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class Buffer3D(BufferBase):
    '''
    3D performance buffer.

    Every cell is represented by a unit direction vector; a sample belongs to the cell whose
    vector has the largest dot product with the sample's normalized performance.
    '''
    def __init__(self, cell_num, *args, origin=None, **kwargs):
        '''
        Input:
            cell_num: number of discretized cells, None means the default of 1000
            origin: the origin point, None means np.zeros(3)
            *args, **kwargs: forwarded unchanged to BufferBase.__init__
        '''
        if cell_num is None:
            cell_num = 1000
        super().__init__(cell_num, *args, origin=origin, **kwargs)
        if self.origin is None:
            self.origin = np.zeros(3)

        # it's really hard to generate evenly distributed unit vectors in 3d space, use some tricks here
        n_per_edge = int(np.sqrt(2 * cell_num + 0.25) + 0.5) - 1
        directions = generate_weights_batch(n_dim=3, delta_weight=1.0 / (n_per_edge - 1))

        # top up with random directions when the regular weights do not provide enough cells
        n_missing = cell_num - len(directions)
        if n_missing > 0:
            directions = np.vstack([directions, np.random.random((n_missing, 3))])

        lengths = np.linalg.norm(directions, axis=1)
        self.cell_vecs = directions / lengths[:, None]

    def _find_cell_id(self, F):
        '''
        Find corresponding cell indices given normalized performance.
        Input:
            F: a batch of normalized performance
        Output:
            cell_ids: for each row of F, the index of the cell vector with the largest dot product
        '''
        return np.argmax(F @ self.cell_vecs.T, axis=1)

    def _get_graph_edges(self, valid_cells):
        '''
        Get the edges of the connectivity graph of non-empty cells.
        Input:
            valid_cells: non-empty cells that can be formulated as vertices in the graph
        Output:
            edges: pairs of positions in valid_cells, shape = (n_edges, 2)
        Raises Exception when there is only one non-empty cell.
        '''

        # NOTE: uncomment code below to show visualization of buffer
        # import matplotlib.pyplot as plt
        # from mpl_toolkits.mplot3d import Axes3D
        # fig = plt.figure()
        # ax = fig.add_subplot(111, projection='3d')
        # ax.view_init(azim=45)
        # for vec in self.cell_vecs:
        #     ax.plot(*np.array([self.origin, self.origin + vec]).T, color='gray', linewidth=1, alpha=0.5)
        # for cell_y in self.buffer_y:
        #     if cell_y == []: continue
        #     ax.scatter(*np.array(cell_y).T)
        # plt.title(f'# samples: {self.sample_count}, # valid cells: {len(valid_cells)}')
        # plt.show()

        # TODO: corner cases, need to check why they happen
        n_valid = len(valid_cells)
        if n_valid == 1:
            raise Exception('only 1 non-empty cell in buffer, cannot do graph cut')
        if n_valid == 2:
            return np.array([[0, 1]])
        if n_valid == 3:
            return np.array([[0, 1], [0, 2], [1, 2]])

        # triangulate endpoints of cell vectors to form a mesh, then get edges from this mesh
        vertices = self.cell_vecs[valid_cells]

        # check if vertices fall on a single line: when some coordinate is shared by all vertices,
        # chain the vertices in the order given by the first coordinate that does vary
        shared_coord = (vertices == vertices[0]).all(axis=0)
        if shared_coord.any():
            varying_axis = np.where(np.logical_not(shared_coord))[0][0]
            order = np.argsort(vertices[:, varying_axis])
            chain = np.array([order[:-1], order[1:]]).T
            return np.ascontiguousarray(chain)

        mesh = Delaunay(vertices)
        indptr, neighbor_ids = mesh.vertex_neighbor_vertices
        # every mesh edge is seen from both of its endpoints: sort each pair, then drop duplicates
        edges = [
            np.sort([i, j])
            for i in range(len(vertices))
            for j in neighbor_ids[indptr[i]:indptr[i + 1]]
        ]
        return np.unique(edges, axis=0)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_buffer(n_obj, *args, **kwargs):
    '''
    Select buffer according to n_obj.
    Input:
        n_obj: number of objectives, only 2 and 3 are supported
        *args, **kwargs: forwarded unchanged to the selected buffer class
    Output:
        a Buffer2D (n_obj = 2) or Buffer3D (n_obj = 3) instance
    Raises NotImplementedError for any other n_obj.
    '''
    buffer_map = {2: Buffer2D, 3: Buffer3D}
    if n_obj not in buffer_map:
        raise NotImplementedError(f'no performance buffer is implemented for n_obj={n_obj!r}, supported values are 2 and 3')
    return buffer_map[n_obj](*args, **kwargs)
|