GroupMultiNeSS 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,72 @@
1
+ import numpy as np
2
+ from typing import List
3
+ from utils import truncated_svd, truncated_eigen_decomposition
4
+
5
+
6
def MASE(As: np.array, d_shared: int, d_individs: List[int]):
    """
    Multiple Adjacency Spectral Embedding (MASE) for a collection of graph layers.

    Each layer is first embedded on its own via a truncated eigendecomposition;
    the per-layer eigenvector bases are then concatenated and compressed with a
    truncated SVD into a single shared basis, onto which every layer is projected.

    Parameters
    ----------
    As : np.array
        Adjacency matrices (one per layer) to embed jointly.
    d_shared : int
        Dimension of the latent space common to all layers.
    d_individs : List[int]
        Per-layer dimensions used for the individual eigendecompositions.

    Returns
    -------
    ps : List[np.ndarray]
        Rank-reduced reconstructions of each layer in the shared subspace.
    u_joint : np.ndarray
        Orthonormal basis of the shared latent space.
    rs : List[np.ndarray]
        Score matrices (shared-space representations) of each layer.

    References
    ----------
    [1] Arroyo, J., Athreya, A., Cape, J., Chen, G., Priebe, C. E., & Vogelstein, J. T. (2021).
    Inference for multiple heterogeneous networks with a common invariant subspace.
    Journal of Machine Learning Research, 22, 1-49.
    """
    assert len(As) == len(d_individs)
    # Per-layer spectral bases, one truncated eigendecomposition per layer.
    layer_bases = [truncated_eigen_decomposition(A, max_rank=d)[1]
                   for A, d in zip(As, d_individs)]
    # Shared invariant subspace: SVD of the horizontally stacked bases.
    u_joint, _, _ = truncated_svd(np.hstack(layer_bases), max_rank=d_shared)
    rs = []
    ps = []
    for A in As:
        score = u_joint.T @ A @ u_joint
        rs.append(score)
        ps.append(u_joint @ score @ u_joint.T)
    return ps, u_joint, rs
43
+
44
+
45
def ASE(A: np.array, d: int, check_if_symmetric=True):
    """
    Adjacency Spectral Embedding (ASE) of a symmetric adjacency matrix.

    Computes a rank-`d` truncated SVD of `A` and scales the singular vectors by
    the square roots of the singular values, yielding latent positions.

    Parameters
    ----------
    A : np.array
        Symmetric adjacency matrix to embed.
    d : int
        Embedding dimension.
    check_if_symmetric : bool, optional
        Whether to verify that A is symmetric before embedding. Default is True.

    Returns
    -------
    np.ndarray
        Latent position matrix of shape (n_nodes, d).

    References
    ----------
    Sussman, D. L., Tang, M., Fishkind, D. E., & Priebe, C. E. (2012). A consistent adjacency spectral embedding
    for stochastic blockmodel graphs and some of its applications. Journal of Computational and Graphical Statistics
    """
    if check_if_symmetric:
        assert np.allclose(A, A.T), "A should be a symmetric matrix"
    left_vecs, sing_vals, _ = truncated_svd(A, max_rank=d)
    # Scale each singular vector by sqrt of its singular value.
    return left_vecs @ np.diag(np.sqrt(sing_vals))
@@ -0,0 +1,7 @@
1
+ from .base import *
2
+ from .data_generation import *
3
+ from .group_multiness import *
4
+ from .MASE import *
5
+ from .multiness import *
6
+ from .shared_space_hunting import *
7
+ from .utils import *
@@ -0,0 +1,337 @@
1
+ import numpy as np
2
+ from numpy.linalg import norm
3
+ from typing import List, Union
4
+ from warnings import warn
5
+ from typing import Iterable
6
+ from multiprocessing import shared_memory
7
+ import statsmodels.api as sm
8
+ from more_itertools import zip_equal
9
+
10
+ from .utils import sigmoid, leading_left_eigenvectors
11
+
12
+
13
class SharedMemoryMatrixFitter:
    """
    Base class for managing parameter matrices in shared memory.

    Allows storing, retrieving, and updating matrices in shared memory,
    which can be used for parallel computations without duplicating data.
    """

    def _init_param_matrices(self, vals: Union[List[np.ndarray], np.ndarray] = None,
                             shape: Iterable[int] = None) -> None:
        """
        Initialize shared memory for parameter matrices.

        Parameters
        ----------
        vals : list of np.ndarray or np.ndarray, optional
            Initial values for the matrices. If provided, shape is inferred.
        shape : iterable of int, optional
            Shape of the matrices to initialize if vals is not provided.

        Notes
        -----
        - If both vals and shape are provided, vals take precedence and shape is ignored.
        - Allocates shared memory of type float64 for storing matrices.
        """
        if shape is None:
            assert vals is not None, "One of vals or shape arguments should be provided!"
            vals = np.array(vals)
            shape = vals.shape
        elif vals is not None:
            warn("shape argument is not used as vals argument is provided!")

        # One float64 (8 bytes) per entry. Keep the single creating handle:
        # the previous implementation re-attached a second SharedMemory by
        # name and dropped the creating handle, leaking it.
        self._param_matrices_shm = shared_memory.SharedMemory(create=True, size=int(np.prod(shape)) * 8)
        self.param_matrices = np.ndarray(shape, dtype=np.float64, buffer=self._param_matrices_shm.buf)
        if vals is not None:
            self.param_matrices[:] = vals

    def _set_matrix(self, val, idx):
        """
        Set a specific matrix in the shared memory.

        Parameters
        ----------
        val : np.ndarray
            Matrix value to set.
        idx : int
            Index of the matrix to update.
        """
        param_matrices = self.get_all_fitted_matrices()
        param_matrices[idx] = val

    def _set_matrices(self, vals, indices: Union[List[int], np.ndarray] = None):
        """
        Set multiple matrices in shared memory.

        Parameters
        ----------
        vals : list[np.ndarray] or np.ndarray
            List of matrix values to set.
        indices : list or np.ndarray, optional
            Indices of matrices to update. If None, all matrices are updated.
        """
        if indices is None:
            indices = list(range(len(self.get_all_fitted_matrices())))
        assert len(vals) == len(indices)
        for idx, val in zip_equal(indices, vals):
            self._set_matrix(val, idx)

    def get_all_fitted_matrices(self):
        """
        Retrieve all matrices stored in shared memory.

        Returns
        -------
        np.ndarray
            Array backed by shared memory containing all parameter matrices;
            writes through this view are visible to other attached processes.
        """
        return np.ndarray(self.param_matrices.shape, dtype=np.float64, buffer=self._param_matrices_shm.buf)
92
+
93
+
94
class BaseMultiplexNetworksModel:
    """
    Base class for multiplex network models.

    Stores the edge-distribution choice and self-loop policy, and provides
    shared validation for stacks of per-layer adjacency matrices.
    """

    def __init__(self, edge_distrib: str = "normal", loops_allowed: bool = True):
        """
        Initialize a multiplex network model.

        Parameters
        ----------
        edge_distrib : str, default='normal'
            Type of edge distribution: 'normal' (identity link) or
            'bernoulli' (sigmoid link).
        loops_allowed : bool, default=True
            Whether self-loops are allowed in adjacency matrices.
        """
        self.edge_distrib = edge_distrib
        self.loops_allowed = loops_allowed
        if edge_distrib == "normal":
            self.link_ = lambda value: value
        elif edge_distrib == "bernoulli":
            self.link_ = sigmoid
        else:
            raise NotImplementedError("Edge distribution should be either normal or bernoulli!")

    def _validate_input(self, As: List[np.array]) -> np.ndarray:
        """
        Validate input adjacency matrices and record basic dimensions.

        Checks that every layer is square, that all layers share the same
        shape, and (when loops_allowed is False) that no layer carries a
        non-zero diagonal.

        Parameters
        ----------
        As : list of np.ndarray
            List of adjacency matrices for each layer.

        Returns
        -------
        np.ndarray
            Stack of adjacency matrices as a 3D array.
        """
        first = As[0]
        assert first.shape[0] == first.shape[1], "adjacency matrix should have a square form"
        assert np.all([first.shape == other.shape for other in As[1:]]), "networks should share the same vertex set"
        if not self.loops_allowed:
            diagonals = np.hstack([np.diag(layer) for layer in As])
            assert np.all(diagonals == 0), \
                "loops present in one of adjacency matrices while loops_allowed == False"
        self.n_nodes_ = first.shape[0]
        self.n_layers_ = len(As)
        self._param_participance_masks = self._get_param_participance_masks()
        return np.stack(As)

    def _get_param_participance_masks(self) -> np.ndarray:
        """
        Hook for subclasses: masks of which parameters participate per layer.

        Returns
        -------
        np.ndarray
            Boolean mask array for parameter participation (None here; the
            base class defines no parameters).
        """
        pass
162
+
163
+
164
class BaseRefitting(BaseMultiplexNetworksModel, SharedMemoryMatrixFitter):
    """
    Base class for refitting latent matrices in multiplex network models.

    Keeps the eigenvector bases of the currently fitted matrices fixed and
    re-estimates their eigenvalues via a GLM on the observed upper-triangular
    adjacency entries, with the remaining (non-refitted) matrices folded into
    the GLM offset.
    """

    def __init__(self, edge_distrib="normal", max_rank: int = None, loops_allowed: bool = True,
                 refit_threshold: float = 1e-8):
        """
        Initialize the BaseRefitting object.

        Parameters
        ----------
        edge_distrib : str, default="normal"
            Type of edge distribution ('normal' or 'bernoulli').
        max_rank : int, optional
            Maximum rank for eigenvector-based refitting.
        loops_allowed : bool, default=True
            Whether self-loops are allowed in adjacency matrices.
        refit_threshold : float, default=1e-8
            Threshold to determine which matrices should be refitted.
        """
        BaseMultiplexNetworksModel.__init__(self, edge_distrib=edge_distrib, loops_allowed=loops_allowed)
        self.max_rank = max_rank
        self.refit_threshold = refit_threshold

    @property
    def n_fitted_matrices(self):
        """
        Number of matrices currently fitted in shared memory.

        Returns
        -------
        int
            Number of fitted matrices.
        """
        return len(self.get_all_fitted_matrices())

    def _compute_offset(self, refit_matrix_indices) -> np.ndarray:
        """
        Compute offset vector for refitting.

        For each observed layer, the offset is the sum of the fitted matrices
        that participate in that layer but are NOT being refitted, restricted
        to the upper-triangular entries (diagonal included iff loops_allowed).

        Parameters
        ----------
        refit_matrix_indices : array-like
            Indices of matrices to refit.

        Returns
        -------
        np.ndarray
            Offset vector for GLM refitting.
        """
        offsets_over_obs = []
        # Layers touched by at least one matrix being refitted.
        obs_mask = self._param_participance_masks[refit_matrix_indices].any(0)
        triu_indices = np.triu_indices(self.n_nodes_, k=0 if self.loops_allowed else 1)
        # NOTE(review): assumes _param_participance_masks has shape
        # (n_matrices, n_layers), as set by the subclass — confirm there.
        for layer_particip_mask in self._param_participance_masks[:, obs_mask].T:
            # Matrices contributing to this layer, excluding those being refitted.
            offset_indices = np.setdiff1d(np.arange(self.n_fitted_matrices)[layer_particip_mask], refit_matrix_indices)
            obs_offset = self.get_all_fitted_matrices()[offset_indices].sum(0)[triu_indices]
            offsets_over_obs.append(obs_offset)
        return np.hstack(offsets_over_obs)

    def _construct_refit_design_response_and_offset(self, As: np.ndarray, refit_matrix_indices: np.ndarray,
                                                    refit_matrix_eigvecs: List[np.ndarray]):
        """
        Construct the design matrix, response vector, and offset for GLM refitting.

        Each design column is the vectorized outer product of one eigenvector
        with itself, so the GLM coefficients play the role of eigenvalues;
        columns are zeroed for layers the corresponding matrix does not
        participate in. NaN responses (missing edges) are dropped.

        Parameters
        ----------
        As : np.ndarray
            Stacked adjacency matrices.
        refit_matrix_indices : np.ndarray
            Indices of matrices to refit.
        refit_matrix_eigvecs : list of np.ndarray
            Eigenvectors for the refit matrices.

        Returns
        -------
        tuple of np.ndarray
            Design matrix, response vector, and offset vector for GLM fitting.
        """
        triu_indices = np.triu_indices(self.n_nodes_, k=0 if self.loops_allowed else 1)
        obs_mask = self._param_participance_masks[refit_matrix_indices].any(0)
        # Response: upper-triangular entries of every observed layer, concatenated.
        response = np.hstack([A[triu_indices] for A in As[obs_mask]])

        # One column per eigenvector: vec(upper-tri of v v^T).
        param_designs = [np.stack([np.outer(evec, evec)[triu_indices] for evec in eigvecs.T], axis=-1)
                         for eigvecs in refit_matrix_eigvecs]
        offset = self._compute_offset(refit_matrix_indices)
        refit_matrix_participance_masks = self._param_participance_masks[refit_matrix_indices][:, obs_mask]

        design_mat = []
        for param_design, participance_mask in zip_equal(param_designs, refit_matrix_participance_masks):
            # Zero the design block for layers this matrix does not enter.
            design_mat.append(np.vstack([param_design if participate else np.zeros_like(param_design)
                                         for participate in participance_mask]))

        # Drop missing observations (NaNs) from all three pieces consistently.
        present_vals_mask = ~np.isnan(response)
        return np.hstack(design_mat)[present_vals_mask], response[present_vals_mask], offset[present_vals_mask]

    @staticmethod
    def _construct_refitted_matrices(refit_matrix_eigvecs: List[np.ndarray], refit_matrix_eigvals: np.ndarray):
        """
        Reconstruct matrices from eigenvectors and refitted eigenvalues.

        The flat eigenvalue vector is consumed in contiguous slices, one slice
        per matrix, in the same order the eigenvector bases were given.

        Parameters
        ----------
        refit_matrix_eigvecs : list of np.ndarray
            Eigenvectors for each refit matrix.
        refit_matrix_eigvals : np.ndarray
            Refitted eigenvalues from GLM.

        Returns
        -------
        list of np.ndarray
            Refitted matrices.
        """
        assert np.sum([mat.shape[1] for mat in refit_matrix_eigvecs]) == len(refit_matrix_eigvals)
        refitted_matrices = []
        start_idx = 0
        for matrix_eigvec in refit_matrix_eigvecs:
            matrix_dim = matrix_eigvec.shape[1]
            matrix_eigvals = refit_matrix_eigvals[start_idx: start_idx + matrix_dim]
            # Low-rank reconstruction: V diag(lambda) V^T.
            refit_matrix = matrix_eigvec @ np.diag(matrix_eigvals) @ matrix_eigvec.T
            refitted_matrices.append(refit_matrix)
            start_idx += matrix_dim
        return refitted_matrices

    def _preprocess_refit_matrix_indices(self, refit_matrix_indices=None) -> np.ndarray:
        """
        Preprocess the indices of matrices to refit by filtering out near-zero matrices.

        Parameters
        ----------
        refit_matrix_indices : iterable of int, optional
            Indices of matrices to consider for refitting. Defaults to all
            fitted matrices.

        Returns
        -------
        np.ndarray
            Filtered array of matrix indices to refit.
        """
        if refit_matrix_indices is None:
            refit_matrix_indices = range(self.n_fitted_matrices)
        refit_matrix_indices = list(filter(lambda idx:
                                           norm(self.get_all_fitted_matrices()[idx], ord=2) > self.refit_threshold,
                                           refit_matrix_indices))  # don't refit rank zero matrices after hard-threshold
        return np.array(refit_matrix_indices, dtype=int)

    def refit(self, As: Union[List[np.array], np.array], refit_matrix_indices=None):
        """
        Refit parameter matrices using GLM based on current adjacency matrices.

        Writes the refitted matrices back into shared memory in place; returns
        nothing. No-op when every candidate matrix is below refit_threshold.

        Parameters
        ----------
        As : list or np.ndarray
            Input adjacency matrices.
        refit_matrix_indices : iterable of int, optional
            Indices of matrices to refit. If None, all non-zero matrices are refitted.
        """
        As = self._validate_input(As)
        refit_matrix_indices = self._preprocess_refit_matrix_indices(refit_matrix_indices)
        if len(refit_matrix_indices) != 0:

            # Fixed bases: only eigenvalues are re-estimated below.
            refit_matrix_eigvecs = [leading_left_eigenvectors(mat, k=self.max_rank,
                                                              eigval_threshold=self.refit_threshold)
                                    for mat in self.get_all_fitted_matrices()[refit_matrix_indices]]

            design_mat, response, offset = self._construct_refit_design_response_and_offset(
                As=As, refit_matrix_indices=refit_matrix_indices, refit_matrix_eigvecs=refit_matrix_eigvecs)

            # Gaussian GLM = least squares for 'normal' edges; Binomial for 'bernoulli'.
            family = sm.families.Gaussian() if self.edge_distrib == "normal" else sm.families.Binomial()
            refit_model = sm.GLM(response, design_mat, family=family, offset=offset)
            result = refit_model.fit()
            refit_matrix_eigvals = result.params

            refitted_matrices = self._construct_refitted_matrices(refit_matrix_eigvecs, refit_matrix_eigvals)
            self._set_matrices(refitted_matrices, indices=refit_matrix_indices)
@@ -0,0 +1,205 @@
1
+ import numpy as np
2
+ from typing import Union, List
3
+ from warnings import warn
4
+
5
+ from .utils import sigmoid, generate_matrices_given_pairwise_max_cosines
6
+
7
+
8
class LatentPositionGenerator:
    """
    Generator of multiplex networks from shared + individual latent positions.

    Draws a shared latent space V and per-layer individual spaces Us, builds
    layer expectations link(V V^T + U_k U_k^T), and samples symmetric
    adjacency matrices under a normal or Bernoulli edge model.
    """

    def __init__(self, n_nodes, n_layers, *,
                 edge_distrib: str = "normal",
                 noise_sigma: float = 1.,
                 loops_allowed=True,
                 d_shared: int = 2,
                 d_individs: Union[List, int] = 2,
                 s_vu: float = 0.,
                 s_uu: float = 0.,
                 comps_max_cosine_mat: np.array = None,
                 min_V_max_U_eigval_ratio=None):
        """
        Parameters
        ----------
        n_nodes : int
            Number of nodes shared by all layers.
        n_layers : int
            Number of layers.
        edge_distrib : str, default='normal'
            'normal' (identity link, additive Gaussian noise) or 'bernoulli'
            (sigmoid link, binary edges).
        noise_sigma : float, default=1.
            Std of the symmetric Gaussian edge noise (normal model only).
        loops_allowed : bool, default=True
            If False, diagonals of generated matrices are zeroed.
        d_shared : int, default=2
            Dimension of the shared latent space V.
        d_individs : list or int, default=2
            Per-layer individual dimensions (scalar is broadcast to all layers).
        s_vu, s_uu : float, default=0.
            Maximal cosines between V/U and U/U components, used to build the
            default pairwise-cosine matrix.
        comps_max_cosine_mat : np.array, optional
            Explicit symmetric matrix of maximal pairwise cosines in [0, 1]
            with unit diagonal; overrides s_vu / s_uu.
        min_V_max_U_eigval_ratio : float, optional
            If given, rescales each U so the ratio between V's d_shared-th
            singular value and U's top singular value equals this value.
        """
        self.n_nodes = n_nodes
        self.n_layers = n_layers
        self.edge_distrib = edge_distrib
        self.noise_sigma = noise_sigma
        self.loops_allowed = loops_allowed
        self.d_shared = d_shared
        self.d_individs = d_individs
        self.s_vu = s_vu
        self.s_uu = s_uu
        self.comps_max_cosine_mat = comps_max_cosine_mat
        self.min_V_max_U_eigval_ratio = min_V_max_U_eigval_ratio

    def _validate_input(self):
        """Resolve the link function, broadcast d_individs, check cosine matrix."""
        if self.edge_distrib == "normal":
            self.link_fun = lambda x: x
        elif self.edge_distrib == "bernoulli":
            self.link_fun = sigmoid
        else:
            raise NotImplementedError("Link function should be normal or bernoulli")

        if isinstance(self.d_individs, (list, np.ndarray)):
            assert len(self.d_individs) == self.n_layers, \
                "d_individs should have the length equal to the layers number"
            self.d_individs_ = self.d_individs
        else:
            self.d_individs_ = self.d_individs * np.ones(self.n_layers, dtype=int)

        if self.comps_max_cosine_mat is not None:
            assert self.comps_max_cosine_mat.ndim == 2
            assert np.allclose(self.comps_max_cosine_mat, self.comps_max_cosine_mat.T), \
                "comps_max_cosine_mat matrix should be symmetric"
            assert np.all((self.comps_max_cosine_mat >= 0) & (self.comps_max_cosine_mat <= 1))
            # BUGFIX: previously `np.allclose(np.all(np.diag(...)), 1)`, which
            # reduced the diagonal to a single truthy bool and accepted any
            # all-nonzero diagonal; compare the diagonal entries to 1 directly.
            assert np.allclose(np.diag(self.comps_max_cosine_mat), 1)

    def _generate_latent_spaces(self):
        """Draw V and Us, honoring angle constraints when dimensions permit."""
        all_dims = [self.d_shared] + list(self.d_individs_)
        if np.sum(all_dims) <= self.n_nodes:
            if self.comps_max_cosine_mat is None:
                # Default component-wise cosine matrix from s_vu / s_uu.
                vu_block = self.s_vu * np.ones((1, self.n_layers))
                uu_block = self.s_uu * np.ones((self.n_layers, self.n_layers))
                self.comps_max_cosine_mat = np.block([[np.eye(1), vu_block],
                                                      [vu_block.T, uu_block]])
                np.fill_diagonal(self.comps_max_cosine_mat, 1)
            else:
                # NOTE(review): user-supplied matrices are checked against the
                # total dimension count, while the auto-built one is per
                # component — confirm the expected convention with
                # generate_matrices_given_pairwise_max_cosines.
                assert self.comps_max_cosine_mat.shape[0] == np.sum(all_dims)
            all_lat_spaces = generate_matrices_given_pairwise_max_cosines(self.n_nodes, ds=all_dims,
                                                                          pairwise_cos_mat=self.comps_max_cosine_mat)
        else:
            warn("When n_nodes < sum of all components' dimensions, all angle constraints cannot be satisfied.")
            all_lat_spaces = [np.random.randn(self.n_nodes, d) for d in all_dims]
        self.V, self.Us = all_lat_spaces[0], all_lat_spaces[1:]

        if self.min_V_max_U_eigval_ratio is not None:
            # Scale every U so its top singular value equals V's d_shared-th
            # singular value divided by the requested ratio.
            max_eigval_U = np.linalg.svd(self.V, compute_uv=False)[self.d_shared - 1] / self.min_V_max_U_eigval_ratio
            self.Us = [U / np.linalg.norm(U, ord=2) * max_eigval_U for U in self.Us]

    def _compute_shared_latent_position(self):
        """Shared Gram matrix S = V V^T."""
        self.S = self.V @ self.V.T

    def _compute_individual_latent_positions(self):
        """Per-layer Gram matrices Rs[k] = U_k U_k^T."""
        self.Rs = np.stack([U @ U.T for U in self.Us])

    def _compute_latent_positions(self):
        """Total pre-link expectations S + Rs, one matrix per layer."""
        self._compute_shared_latent_position()
        self._compute_individual_latent_positions()
        return self.S + self.Rs

    def generate(self, random_seed=None):
        """
        Sample the multiplex network; results land in self.As (and self.Ps).

        Parameters
        ----------
        random_seed : int, optional
            Seed for np.random for reproducible draws.
        """
        self._validate_input()
        np.random.seed(random_seed)
        self._generate_latent_spaces()
        self.Ps = self.link_fun(self._compute_latent_positions())
        As = []

        triu_x, triu_y = np.triu_indices(self.n_nodes, k=1)
        for P in self.Ps:
            if self.edge_distrib == "normal":
                noise = self.noise_sigma * np.random.randn(self.n_nodes, self.n_nodes)
                # Mirror the lower triangle to keep the noise symmetric.
                noise[triu_x, triu_y] = noise[triu_y, triu_x]
                A = P + noise
            elif self.edge_distrib == "bernoulli":
                A = (np.random.rand(self.n_nodes, self.n_nodes) <= P).astype(float)
                A[triu_x, triu_y] = A[triu_y, triu_x]
            else:
                raise NotImplementedError()

            if not self.loops_allowed:
                np.fill_diagonal(A, 0.)
            As.append(A)
        self.As = np.stack(As)
110
+
111
+
112
class GroupLatentPositionGenerator(LatentPositionGenerator):
    """
    Generator of grouped multiplex networks.

    Extends LatentPositionGenerator with group-level latent spaces Ws: each
    layer's expectation combines the shared space V, its group's space W, and
    its individual space U.
    """

    def __init__(self, n_nodes: int, n_layers: int, *,
                 group_indices: List[int],
                 edge_distrib: str = "normal",
                 noise_sigma: float = 1.,
                 loops_allowed: bool = True,
                 d_shared: int = 2,
                 d_individs: Union[List, int] = 2,
                 d_groups: Union[List, int] = 2,
                 s_vw: float = 0.,
                 s_vu: float = 0.,
                 s_ww: float = 0.,
                 s_wu: float = 0.,
                 s_uu: float = 0.,
                 comps_max_cosine_mat=None,
                 min_V_max_W_eigval_ratio=None):
        # group_indices: per-layer group label, expected to be ints in
        # [0, n_groups) — enforced in _validate_input.
        # s_* values bound maximal cosines between shared (V), group (W), and
        # individual (U) components when auto-building the cosine matrix.
        super().__init__(n_nodes, n_layers, edge_distrib=edge_distrib, noise_sigma=noise_sigma,
                         loops_allowed=loops_allowed, d_shared=d_shared, d_individs=d_individs,
                         comps_max_cosine_mat=comps_max_cosine_mat)

        self.group_indices = group_indices
        self.d_groups = d_groups
        self.unique_groups = np.sort(np.unique(group_indices))
        self.s_vw = s_vw
        self.s_vu = s_vu
        self.s_ww = s_ww
        self.s_wu = s_wu
        self.s_uu = s_uu
        self.min_V_max_W_eigval_ratio = min_V_max_W_eigval_ratio

    def _validate_input(self):
        """Validate group structure on top of the parent's checks."""
        super()._validate_input()
        assert (0 <= self.s_vw <= 1) & (0 <= self.s_uu <= 1) & (0 <= self.s_ww <= 1)
        assert len(self.group_indices) == self.n_layers, "Number of group indices should = the number of layers"
        assert np.all(np.sort(self.unique_groups) == np.arange(len(self.unique_groups))), \
            "group_indices should contain ints from 0 to n_groups - 1, where n_groups is the number of distinct groups"
        self.n_groups_ = len(self.unique_groups)
        assert self.n_groups_ >= 2, "Number of groups should be at least 2"
        if isinstance(self.d_groups, (list, np.ndarray)):
            assert self.n_groups_ == len(self.d_groups), \
                "Number of distinct groups should be the same as the length of d_groups"
            self.d_groups_ = self.d_groups
        else:
            # Scalar d_groups is broadcast to every group.
            self.d_groups_ = self.d_groups * np.ones(self.n_groups_, dtype=int)

    def _generate_latent_spaces(self):
        """
        Draw V, Ws, and Us.

        Three regimes depending on n_nodes: full angle constraints on all
        components; constraints only on shared+group components (individual
        spaces drawn i.i.d. Gaussian); or no constraints at all (with a
        warning).
        """
        all_dims = [self.d_shared] + list(self.d_groups_) + list(self.d_individs_)
        if self.n_nodes >= np.sum(all_dims):
            if self.comps_max_cosine_mat is None:
                # Component-blocked cosine matrix: ordering is [V | Ws | Us].
                uu_block = self.s_uu * np.ones((self.n_layers, self.n_layers))
                ww_block = self.s_ww * np.ones((self.n_groups_, self.n_groups_))
                vw_block = self.s_vw * np.ones((1, self.n_groups_))
                vu_block = self.s_vu * np.ones((1, self.n_layers))
                wu_block = self.s_wu * np.ones((self.n_groups_, self.n_layers))
                self.comps_max_cosine_mat = np.block([[np.eye(1), vw_block, vu_block],
                                                      [vw_block.T, ww_block, wu_block],
                                                      [vu_block.T, wu_block.T, uu_block]])
                np.fill_diagonal(self.comps_max_cosine_mat, 1)
            else:
                assert self.comps_max_cosine_mat.shape[0] == np.sum(all_dims)

            all_lat_spaces = generate_matrices_given_pairwise_max_cosines(self.n_nodes, ds=all_dims,
                                                                          pairwise_cos_mat=self.comps_max_cosine_mat)
        elif self.n_nodes >= self.d_shared + np.sum(self.d_groups_):
            # Not enough nodes for all constraints: constrain V and Ws only.
            shared_group_dims = [self.d_shared] + list(self.d_groups_)
            ww_block = self.s_ww * np.ones((self.n_groups_, self.n_groups_))
            vw_block = self.s_vw * np.ones((1, self.n_groups_))
            cos_mat = np.block([[np.eye(1), vw_block],
                                [vw_block.T, ww_block]])
            np.fill_diagonal(cos_mat, 1)
            shared_group_lat_spaces = generate_matrices_given_pairwise_max_cosines(self.n_nodes, ds=shared_group_dims,
                                                                                  pairwise_cos_mat=cos_mat)
            all_lat_spaces = shared_group_lat_spaces + [np.random.randn(self.n_nodes, d) for d in self.d_individs_]
        else:
            warn("When n_nodes < sum of all components' dimensions, all angle constraints cannot be satisfied.")
            all_lat_spaces = [np.random.randn(self.n_nodes, d) for d in all_dims]

        # Split the drawn spaces back into shared / group / individual parts.
        self.V = all_lat_spaces[0]
        self.Ws = all_lat_spaces[1: self.n_groups_ + 1]
        self.Us = all_lat_spaces[self.n_groups_ + 1:]

        if self.min_V_max_W_eigval_ratio is not None:
            # Rescale each W so its top singular value equals V's d_shared-th
            # singular value divided by the requested ratio.
            max_eigval_W = np.linalg.svd(self.V, compute_uv=False)[self.d_shared - 1] / self.min_V_max_W_eigval_ratio
            self.Ws = [W / np.linalg.norm(W, ord=2) * max_eigval_W for W in self.Ws]

    def _compute_group_latent_positions(self):
        """Per-group Gram matrices Qs[g] = W_g W_g^T."""
        self.Qs = np.stack([W @ W.T for W in self.Ws])

    def _compute_latent_positions(self):
        """Pre-link expectations: shared + (layer's group) + individual parts."""
        self._compute_shared_latent_position()
        self._compute_group_latent_positions()
        self._compute_individual_latent_positions()
        # Qs is indexed by each layer's group label to broadcast group terms
        # onto their layers.
        return self.S + self.Qs[self.group_indices] + self.Rs