graphical-sampling 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,161 @@
1
+ import numpy as np
2
+ from numpy.typing import NDArray
3
+ from sklearn.cluster import KMeans
4
+
5
+
6
class SoftBalancedKMeans:
    """K-means clustering whose clusters are rebalanced by fractional transfers.

    Units are first hard-assigned with scikit-learn's ``KMeans``. Each unit
    then holds its probability as a fractional membership of its cluster.
    Membership is moved from heavier to lighter clusters, cheapest first
    (squared-distance increase per unit of imbalance). This continues until
    every cluster total is within ``10**-tolerance`` of 1, or no cluster can
    give membership to a lighter one.
    """

    def __init__(
        self, k: int, *, initial_centroids: NDArray = None, tolerance: int = 3
    ) -> None:
        """
        Args:
            k: Number of clusters.
            initial_centroids: Optional starting centroids passed to ``KMeans``
                as ``init``. When None, k-means++ with 10 restarts is used.
            tolerance: Decimal precision. ``10**-tolerance`` is the threshold
                for balance checks, and memberships are finally rounded to
                this many decimals.
        """
        self.k = k
        self.tolerance = tolerance
        self.coords: NDArray = None
        self.centroids = initial_centroids
        self.labels: NDArray = None
        # Entry [i, j] is the share of unit i held by cluster j.
        self.fractional_labels: NDArray = None
        # Column sums of fractional_labels (total membership of each cluster).
        self.clusters_sum: NDArray = None
        self.rng = np.random.default_rng()

    def _generate_fractional_labels(self, probs: NDArray) -> NDArray:
        """Return an ``(n_units, k)`` matrix with ``probs[i]`` in column ``labels[i]``."""
        n_units = self.labels.shape[0]
        fractional_labels = np.zeros((n_units, self.k))
        fractional_labels[np.arange(n_units), self.labels] = np.ravel(probs)
        return fractional_labels

    def _transfer_score(
        self,
        data_point: NDArray,
        current_cluster_indx: int,
        other_cluster_indx: int,
    ) -> float:
        """Score moving ``data_point``'s membership to another cluster (lower is better).

        The score is the increase in squared distance to the centroid divided
        by the surplus of the current cluster's total over the other's. It is
        ``inf`` unless the current cluster is heavier by more than
        ``10**-tolerance``.
        """
        surplus = (
            self.clusters_sum[current_cluster_indx]
            - self.clusters_sum[other_cluster_indx]
        )
        if not surplus > 10**-self.tolerance:
            return np.inf
        distance_increase = (
            np.linalg.norm(data_point - self.centroids[other_cluster_indx]) ** 2
            - np.linalg.norm(data_point - self.centroids[current_cluster_indx]) ** 2
        )
        # The 1e-9 guards against a vanishing denominator.
        return distance_increase / (surplus + 1e-9)

    def _get_transfer_records(self, data: NDArray, top_m: int) -> NDArray:
        """Return up to ``top_m`` candidate transfers, lowest score first.

        For every non-zero membership ``(unit i, cluster j)``, the target ``t``
        is the cluster minimising ``_transfer_score``. Candidates whose best
        score is not finite cannot move anything and are discarded.

        Returns:
            Integer array of shape ``(m, 3)`` with rows ``(i, j, t)`` and
            ``m <= top_m``. It has zero rows when no transfer is possible.
        """
        candidates = []
        for i in range(data.shape[0]):
            for j in np.nonzero(self.fractional_labels[i])[0]:
                scores = [self._transfer_score(data[i], j, t) for t in range(self.k)]
                best_target = int(np.argmin(scores))
                best_score = scores[best_target]
                if np.isfinite(best_score):
                    candidates.append((best_score, i, j, best_target))

        if not candidates:
            return np.empty((0, 3), dtype=int)

        candidates = np.array(candidates)
        order = np.argsort(candidates[:, 0])
        return candidates[order][:top_m, 1:].astype(int)

    def _transfer(self, data_index: int, from_index: int, to_index: int) -> None:
        """Move part of one unit's membership from ``from_index`` to ``to_index``.

        When both cluster totals are on the same side of 1 (within tolerance),
        at most half their difference is moved. Otherwise the amount is
        capped so that neither total crosses 1. The amount never exceeds the
        unit's current membership in ``from_index``.

        ``clusters_sum`` is updated in place. Later transfers in the same
        batch therefore see current totals and do not overshoot.
        """
        eps = 10**-self.tolerance
        from_sum = self.clusters_sum[from_index]
        to_sum = self.clusters_sum[to_index]
        available = self.fractional_labels[data_index, from_index]
        same_side = (from_sum >= 1 - eps and to_sum >= 1 - eps) or (
            from_sum <= 1 + eps and to_sum <= 1 + eps
        )
        if same_side:
            transfer_prob = min(available, (from_sum - to_sum) / 2)
        else:
            transfer_prob = min(available, from_sum - 1, 1 - to_sum)
        self.fractional_labels[data_index, from_index] -= transfer_prob
        self.fractional_labels[data_index, to_index] += transfer_prob
        self.clusters_sum[from_index] -= transfer_prob
        self.clusters_sum[to_index] += transfer_prob

    def _no_transfer_possible(self, transfer_records: NDArray) -> bool:
        """True when ``_get_transfer_records`` found no finite-score transfer."""
        return len(transfer_records) == 0

    def _is_transfer_possible(self, from_cluster: int, to_cluster: int) -> bool:
        """True when ``from_cluster`` is heavier than ``to_cluster`` by more than tolerance."""
        return (
            self.clusters_sum[from_cluster] - self.clusters_sum[to_cluster]
            > 10**-self.tolerance
        )

    def _stop_codition(self, tol) -> bool:
        """True when every cluster total is within ``10**-tol`` of 1."""
        return np.all(np.abs(self.clusters_sum - 1) < 10**-tol)

    def _expected_num_transfers(self) -> int:
        """Heuristic batch size: largest imbalance over twice the mean non-zero share (at least 1)."""
        max_diff_sum = np.max(self.clusters_sum - self.clusters_sum[:, None])
        mean_nonzero_probs = np.mean(
            self.fractional_labels[np.nonzero(self.fractional_labels)]
        )
        return max(int(np.floor(max_diff_sum / (2 * mean_nonzero_probs))), 1)

    def _update_centroids(self, data: NDArray) -> None:
        """Recompute each centroid as the unweighted mean of its member units.

        A unit is a member of cluster ``i`` when its membership there is
        non-zero. A cluster with no members keeps its previous centroid
        instead of becoming NaN.
        """
        new_centroids = []
        for i in range(self.k):
            members = np.nonzero(self.fractional_labels[:, i])[0]
            if members.size == 0:
                new_centroids.append(self.centroids[i])
            else:
                new_centroids.append(np.mean(data[members], axis=0))
        self.centroids = np.array(new_centroids)

    def _numerical_stabilizer(self) -> None:
        """Round memberships to ``tolerance`` decimals, then rescale each cluster to sum to 1."""
        self.fractional_labels = np.round(self.fractional_labels, self.tolerance)
        self.fractional_labels *= 1 / np.sum(self.fractional_labels, axis=0)
        self.clusters_sum = np.sum(self.fractional_labels, axis=0)

    def fit(self, data: NDArray, probs: NDArray) -> None:
        """Cluster ``data`` and balance fractional memberships across clusters.

        Args:
            data: Unit coordinates, one row per unit (passed to ``KMeans``).
            probs: Per-unit probability mass. Unit ``i`` starts with all of
                ``probs[i]`` in its KMeans cluster.

        The balancing loop stops in either of two cases:
            - every cluster total is within ``10**-tolerance`` of 1;
            - no membership can be moved to a cluster that is lighter by more
              than that threshold.

        Afterwards memberships are rounded and each cluster is rescaled to
        sum to 1.
        """
        self.coords = data

        kmeans = KMeans(
            n_clusters=self.k,
            init=self.centroids if self.centroids is not None else "k-means++",
            n_init=1 if self.centroids is not None else 10,
            tol=10**-self.tolerance,
        )
        kmeans.fit(self.coords)

        self.centroids = kmeans.cluster_centers_
        self.labels = kmeans.labels_
        self.fractional_labels = self._generate_fractional_labels(probs)
        self.clusters_sum = np.sum(self.fractional_labels, axis=0)

        while not self._stop_codition(self.tolerance):
            transfer_records = self._get_transfer_records(
                self.coords, top_m=self._expected_num_transfers()
            )
            if self._no_transfer_possible(transfer_records):
                break
            for data_index, from_cluster_index, to_cluster_index in transfer_records:
                if self._is_transfer_possible(from_cluster_index, to_cluster_index):
                    self._transfer(data_index, from_cluster_index, to_cluster_index)
            # Recompute exactly to discard rounding drift from incremental updates.
            self.clusters_sum = np.sum(self.fractional_labels, axis=0)
            self._update_centroids(self.coords)

        self._numerical_stabilizer()

    def get_clusters(self) -> list[NDArray]:
        """Return one array per cluster with rows ``[unit_id, *coords, membership]``.

        Only units with non-zero membership in that cluster are included.
        """
        clusters = []
        for i in range(self.k):
            probs = self.fractional_labels[:, i]
            ids = np.nonzero(probs)[0]
            units = np.concatenate(
                [ids.reshape(-1, 1), self.coords[ids], probs[ids].reshape(-1, 1)],
                axis=1,
            )
            clusters.append(units)
        return clusters
@@ -0,0 +1,4 @@
1
+ from .var_nht import VarNHT
2
+
3
+
4
+ __all__ = ["VarNHT"]
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass
2
+ from abc import abstractmethod
3
+
4
+ from numpy.typing import NDArray
5
+
6
+ from ..design import Design
7
+
8
+
9
from abc import ABC


@dataclass
class Criteria(ABC):
    """Abstract base for criteria that score a sampling ``Design``.

    The class derives from ``ABC``, so ``@abstractmethod`` is enforced. A
    subclass that does not override ``__call__`` cannot be instantiated.

    Attributes:
        auxiliary_variable: Per-unit auxiliary values.
        inclusion_probability: Per-unit inclusion probabilities.
    """

    auxiliary_variable: NDArray
    inclusion_probability: NDArray

    @abstractmethod
    def __call__(self, design: "Design") -> float:
        """Return the criterion value for ``design``."""
        ...
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+
3
+ from ..design import Design
4
+ from .criteria import Criteria
5
+
6
+
7
class VarNHT(Criteria):
    """Variance of the NHT estimator of the auxiliary-variable total.

    Each sample of the design yields the estimate ``t = sum(y_i / pi_i)``
    over the sample's unit ids. The returned value is
    ``E[t**2] - (sum(y))**2``, with the expectation taken over the design's
    sample probabilities. This equals ``Var(t)`` when ``t`` is unbiased for
    ``sum(y)``.
    """

    def __call__(self, design: Design) -> float:
        def estimate_for(sample) -> float:
            unit_ids = list(sample.ids)
            weighted = (
                self.auxiliary_variable[unit_ids]
                / self.inclusion_probability[unit_ids]
            )
            return np.sum(weighted)

        nht_estimator = np.array([estimate_for(sample) for sample in design])
        samples_probabilities = np.array([sample.probability for sample in design])

        second_moment = np.sum((nht_estimator**2) * samples_probabilities)
        squared_total = (np.sum(self.auxiliary_variable)) ** 2
        return second_moment - squared_total
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterator, Collection, Optional
4
+
5
+ import numpy as np
6
+ from matplotlib import pyplot as plt
7
+ from .structs import MaxHeap, Sample
8
+
9
+
10
class Design:
    """A sampling design held as a max-heap of ``Sample`` objects.

    Each ``Sample`` carries a ``probability`` and a frozenset of unit ``ids``.
    A design can be seeded from per-unit inclusion probabilities. It is then
    perturbed with :meth:`iterate`, which pulls two samples and either merges
    them (identical ids) or exchanges one unit between them.
    """

    def __init__(
        self,
        inclusions: Optional[Collection[float]] = None,
        rng: Optional[np.random.Generator] = None,
    ):
        """
        Args:
            inclusions: Optional per-unit inclusion probabilities. When given,
                the initial design is built by ``push_initial_design``.
            rng: Random generator used here and by the heap. When None, a new
                generator is created for this instance.
        """
        if rng is None:
            # Created per call. A default evaluated at definition time would
            # be a single generator shared by every Design instance.
            rng = np.random.default_rng()
        self.heap = MaxHeap[Sample](rng=rng)
        self.rng = rng
        self.changes = 0
        if inclusions is not None:
            self.push_initial_design(inclusions)

    def push_initial_design(self, inclusions: Collection[float]):
        """Build samples by laying inclusion probabilities end to end on [0, 1].

        Units are stacked as consecutive bars. A bar that would pass 1 is
        wrapped so that it continues from 0. Sweeping the bar end points,
        each stretch covered by a fixed set of bars becomes a ``Sample``
        whose probability is the stretch length (rounded to 9 decimals).

        Units with a non-positive probability are skipped. Their zero-length
        bar would otherwise sort its "end" event before its "start" and make
        ``active.remove`` raise ``KeyError``.
        """
        events: list[tuple[float, str, int]] = []
        level: float = 0
        for i, p in enumerate(inclusions):
            if p <= 0:
                continue
            next_level = level + p
            if next_level < 1 - 1e-9:
                events.append((level, "start", i))
                events.append((next_level, "end", i))
                level = next_level
            elif next_level > 1 + 1e-9:
                # Wrap: the bar runs up to 1, then continues from 0.
                events.append((level, "start", i))
                events.append((1, "end", i))
                events.append((0, "start", i))
                events.append((next_level - 1, "end", i))
                level = next_level - 1
            else:
                events.append((level, "start", i))
                events.append((1, "end", i))
                level = 0

        # Ties on position are broken by the string, and "end" < "start".
        # A bar ending at a point is therefore closed before the next one opens.
        events.sort()
        active = set()
        last_point: float = 0

        for point, event_type, bar_index in events:
            if event_type == "start":
                active.add(bar_index)
            elif event_type == "end":
                if last_point != point:
                    self.push(Sample(round(point - last_point, 9), frozenset(active)))
                active.remove(bar_index)
            last_point = point

    def copy(self) -> Design:
        """Return a new Design sharing ``rng`` with a copied heap and change count."""
        new_design = Design(
            rng=self.rng,
        )
        new_design.heap = self.heap.copy()
        new_design.changes = self.changes
        return new_design

    def pull(self, random: bool = False) -> Sample:
        """Pop the top sample, or a random one when ``random`` is True."""
        if random:
            return self.heap.randompop()
        return self.heap.pop()

    def push(self, *args: Sample) -> None:
        """Push every given sample except those reporting ``almost_zero()``."""
        for r in args:
            if not r.almost_zero():
                self.heap.push(r)

    def merge_identical(self):
        """Combine samples with identical ids by summing their probabilities."""
        merged: dict = {}
        for r in self.heap:
            merged[r.ids] = merged.get(r.ids, 0) + r.probability
        self.heap = MaxHeap[Sample](
            initial_heap=[Sample(length, ids) for ids, length in merged.items()],
            rng=self.rng,
        )

    def switch(
        self,
        r1: Sample,
        r2: Sample,
        coefficient: float = 0.5,
    ) -> tuple[Sample, Sample, Sample, Sample]:
        """Exchange one unit between ``r1`` and ``r2`` on a slice of probability.

        A random unit ``n1`` (in r1 only) and ``n2`` (in r2 only) are chosen.
        ``length = coefficient * min(p1, p2)`` of each sample's mass gets the
        swapped ids, and the rest keeps the original ids.

        Returns:
            ``(swapped r1, remainder r1, swapped r2, remainder r2)``.

        Raises:
            ValueError: from ``rng.choice`` if either id difference is empty.
        """
        length = coefficient * min(r1.probability, r2.probability)
        n1 = self.rng.choice(list(r1.ids - r2.ids))
        n2 = self.rng.choice(list(r2.ids - r1.ids))
        return (
            Sample(length, (r1.ids - {n1}) | {n2}),
            Sample(r1.probability - length, r1.ids),
            Sample(length, (r2.ids - {n2}) | {n1}),
            Sample(r2.probability - length, r2.ids),
        )

    def iterate(
        self, random_pull: bool = False, switch_coefficient: float = 0.5
    ) -> None:
        """Pull two samples and merge them if identical, otherwise switch them."""
        r1 = self.pull(random_pull)
        r2 = self.pull(random_pull)
        if r1.ids == r2.ids:
            self.push(Sample(r1.probability + r2.probability, r1.ids))
        else:
            self.push(*self.switch(r1, r2, switch_coefficient))
            # NOTE(review): the source dump lost indentation. This reading
            # counts switches only; confirm against upstream.
            self.changes += 1

    def show(self) -> None:
        """Plot each sample as vertical segments at its unit ids, stacked by probability."""
        initial_level: float = 0
        for r in self.heap:
            for i in r.ids:
                plt.plot([i, i], [initial_level, initial_level + r.probability])
            initial_level += r.probability
        plt.show()

    def __iter__(self) -> Iterator[Sample]:
        return iter(self.heap)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Design):
            return NotImplemented
        return self.heap == other.heap

    def __hash__(self) -> int:
        return hash(self.heap)
@@ -0,0 +1,4 @@
1
+ from .density import Density
2
+
3
+
4
+ __all__ = ["Density"]
@@ -0,0 +1,94 @@
1
+ from ..clustering import AggregateBalancedKMeans
2
+
3
+ import numpy as np
4
+ from numpy._typing import NDArray
5
+ from scipy.optimize import linear_sum_assignment
6
+ from sklearn.neighbors import KernelDensity
7
+
8
+
9
+ class Density:
10
+ def __init__(self, coordinates: NDArray, probabilities: NDArray, k: int):
11
+ # self.coords = (coordinates - coordinates.min(axis=0))/np.ptp(coordinates, axis=0)[0]
12
+ self.coords = coordinates
13
+ self.probs = probabilities
14
+ self.kde = self._kde(coordinates)
15
+ self.labels, self.centroids, self.clusters = self._generate_labels_centroids(k)
16
+
17
+ def _kde(self, coords: NDArray) -> KernelDensity:
18
+ kde = KernelDensity(kernel="tophat", bandwidth="scott")
19
+ kde.fit(coords)
20
+ return kde
21
+
22
+ def scale(self, scores):
23
+ scaled_scores = np.zeros_like(scores)
24
+ limit = np.sin(np.pi/8)/np.sin(np.pi/4)
25
+ for i, score in enumerate(scores):
26
+ if score >= 0:
27
+ scaled_scores[i] = min(score, limit)/limit
28
+ else:
29
+ scaled_scores[i] = max(score, -limit)/limit
30
+ return scaled_scores
31
+
32
+ def _density(self, shifted_coords: NDArray) -> float:
33
+ shifted_kde = self._kde(shifted_coords)
34
+ density = np.exp(self.kde.score_samples(self.coords))
35
+ shifted_density = np.exp(shifted_kde.score_samples(shifted_coords))
36
+ spread = np.mean(self.scale((density-shifted_density)/np.sqrt(density**2+shifted_density**2)))
37
+ var = np.mean(1 - (density+shifted_density)/(np.sqrt(2)*np.sqrt(density**2+shifted_density**2)))
38
+ scale_for_var = 1-np.cos(np.pi/8)
39
+ var_scaled = min(var, scale_for_var)/scale_for_var
40
+ measure = [
41
+ spread,
42
+ var_scaled,
43
+ spread + (np.sign(spread) - spread) * var_scaled,
44
+ # spread + (np.sign(spread) - spread) * var_scaled**2
45
+ ]
46
+ return density, shifted_density, measure
47
+
48
+ def _generate_labels_centroids(self, k):
49
+ agg = AggregateBalancedKMeans(k=k, tolerance=5)
50
+ agg.fit(self.coords, self.probs.reshape(-1, 1), np.array([1]))
51
+ labels = np.argmax(agg.membership, axis=1)
52
+ centroids = np.array(
53
+ [
54
+ np.mean(self.coords[labels == i], axis=0)
55
+ for i in range(k)
56
+ ]
57
+ )
58
+ return labels, centroids, agg.get_clusters()
59
+
60
+ def _assign_samples_to_centroids(self, samples, centroids):
61
+ cost_matrix = np.linalg.norm(samples[:, :, np.newaxis] - centroids, axis=3).transpose(0, 2, 1)
62
+ return np.array([samples[i][linear_sum_assignment(cost_matrix[i])[1]] for i in range(samples.shape[0])])
63
+
64
+ def _generate_shifted_coords(self, shifts: NDArray, labels: NDArray) -> NDArray:
65
+ shifted_coords = self.coords.copy()
66
+ for j, shift in enumerate(shifts):
67
+ shifted_coords[labels == j] += shift
68
+ return shifted_coords
69
+
70
+ def _score_sample(
71
+ self, sample: NDArray, labels: NDArray, centroids: NDArray
72
+ ) -> float:
73
+ shifts = sample - centroids
74
+ shifted_coords = self._generate_shifted_coords(shifts, labels)
75
+ return self._density(shifted_coords)
76
+
77
+ def score(self, samples: NDArray) -> NDArray:
78
+ scores = []
79
+ densities = []
80
+ samples_assigned = self._assign_samples_to_centroids(
81
+ self.coords[samples], self.centroids
82
+ )
83
+ for sample in samples_assigned:
84
+ density, shifted_density, score = self._score_sample(sample, self.labels, self.centroids)
85
+ densities.append([density, shifted_density])
86
+ scores.append(score)
87
+ return self.zippify(scores), densities
88
+
89
+ def zippify(self, scores):
90
+ zipped_scores = [[] for _ in range(len(scores[0]))]
91
+ for score in scores:
92
+ for i, den in enumerate(score):
93
+ zipped_scores[i].append(den)
94
+ return zipped_scores
@@ -0,0 +1,4 @@
1
+ from .generator import rng
2
+
3
+
4
+ __all__ = ["rng"]
@@ -0,0 +1,251 @@
1
+ import numpy as np
2
+ from numpy.typing import NDArray
3
+
4
+
5
+ class Generator:
6
+ def __init__(self, seed: int = None) -> None:
7
+ self.rng = np.random.default_rng(seed)
8
+
9
+ def grid_coordinates(self, size: int | tuple[int, ...] = None) -> NDArray:
10
+ """
11
+ Generate grid coordinates.
12
+
13
+ Args:
14
+ size (int|Tuple[int, ...]]): Specifies the grid dimensions.
15
+ - If None, generates a single grid of size `(1,)`.
16
+ - If an integer N, generates a single grid of size `(N,)`.
17
+ - If a tuple of `(N, D)`, generates a single grid with `N` points per `D` dimensions.
18
+ The size of the grids is `(N**D, D)`.
19
+ - If a tuple of `(B, N, D)`, generates a batch of `B` grids of size `(N**D, D)`.
20
+
21
+ Returns:
22
+ NDArray: An array containing the coordinates of the grid(s), normalized to [0, 1].
23
+ """
24
+ batch, N, dim = self._check_size(size)
25
+
26
+ linspace = np.linspace(0, 1, N)
27
+ grid = np.meshgrid(*[linspace] * dim)
28
+ base_coordinates = np.stack([indices.ravel() for indices in grid], axis=-1)
29
+ coordinates = np.repeat(base_coordinates[np.newaxis, :, :], batch, axis=0)
30
+
31
+ return coordinates.squeeze()
32
+
33
+ def random_coordinates(self, size: int | tuple[int, ...] = None) -> NDArray:
34
+ """
35
+ Generate random coordinates uniformly from [0, 1].
36
+
37
+ Args:
38
+ size (int|Tuple[int, ...]]): Specifies the grid dimensions.
39
+ - If None, generates a single coordinate of size `(1,)`.
40
+ - If an integer N, generates one dimensional coordinates for N points.
41
+ - If a tuple of `(N, D)`, generates a D-dimensional coordinates for N points.
42
+ - If a tuple of `(B, N, D)`, generates a batch of `B` coordinates of size `(N, D)`.
43
+
44
+ Returns:
45
+ NDArray: An array containing random coordinates in [0, 1].
46
+ """
47
+ batch, N, dim = self._check_size(size)
48
+ coordinates = self.rng.random((batch, N, dim))
49
+ return coordinates.squeeze()
50
+
51
+ def uniform_coordinates(
52
+ self, low: float = 0.0, high: float = 1.0, size: int | tuple[int, ...] = None
53
+ ) -> NDArray:
54
+ """
55
+ Generate random coordinates uniformly from [low, high].
56
+
57
+ Args:
58
+ low (float): Lower boundary of the output interval. All values generated will be greater than or equal to low.
59
+ The default value is 0.
60
+ high (float): Upper boundary of the output interval. All values generated will be less than high.
61
+ high - low must be non-negative. The default value is 1.0.
62
+ size (int|Tuple[int, ...]]): Specifies the grid dimensions.
63
+ - If None, generates a single coordinate of size `(1,)`.
64
+ - If an integer N, generates one dimensional coordinates for N points.
65
+ - If a tuple of `(N, D)`, generates a D-dimensional coordinates for N points.
66
+ - If a tuple of `(B, N, D)`, generates a batch of `B` coordinates of size `(N, D)`.
67
+
68
+ Returns:
69
+ NDArray: An array containing random coordinates in [0, 1].
70
+ """
71
+ batch, N, dim = self._check_size(size)
72
+ coordinates = self.rng.uniform(low, high, (batch, N, dim)).squeeze()
73
+ return coordinates.squeeze()
74
+
75
+ def normal_1D_coordinates(
76
+ self, mean: float = 0.0, std: float = 1.0, size: int | tuple[int, ...] = None
77
+ ) -> NDArray:
78
+ """
79
+ Generate random 1D coordinates sampled from a normal distribution.
80
+
81
+ Args:
82
+ mean (float): Mean of the normal distribution (default 0).
83
+ std (float): Standard deviation of the normal distribution (default 1).
84
+ size (int|Tuple[int, ...]]): Specifies the coordinates dimensions.
85
+ - If None, generates a single coordinate of size `(1,)`.
86
+ - If an integer N, generates one dimensional coordinates for N points.
87
+ - If a tuple of `(B, N)`, generates a batch of `B` coordinates of size `(N,)`.
88
+ - If a tuple of `(B, N, D)`, D has to be 1.
89
+
90
+ Returns:
91
+ NDArray: An array containing coordinates sampled from a normal distribution.
92
+ """
93
+ batch, N, dim = self._check_size(size, fixed_dim=1)
94
+ coordinates = self.rng.normal(mean, std, (batch, N)).squeeze()
95
+ return coordinates.squeeze()
96
+
97
+ def normal_mD_coordinates(
98
+ self, mean: NDArray, cov: NDArray, size: int | tuple[int, ...] = None
99
+ ) -> NDArray:
100
+ """
101
+ Generate random m-dimensional coordinates sampled from a multivariate normal distribution.
102
+
103
+ Args:
104
+ mean (ArrayLike): Mean of the m-dimensional distribution.
105
+ cov (ArrayLike): Covariance matrix of the distribution.
106
+ size (int|Tuple[int, ...]]): Specifies the coordinates dimensions.
107
+ - If None, generates a single coordinate of size `(1, m)`.
108
+ - If an integer N, generates m-dimensional coordinates for N points. the shape of coordinates is `(N, m)`.
109
+ - If a tuple of `(B, N)`, generates a batch of `B` coordinates of size `(N, m)`.
110
+ - If a tuple of `(B, N, D)`, D has to be m.
111
+
112
+ Returns:
113
+ NDArray: An array containing coordinates sampled from a multivariate normal distribution.
114
+ """
115
+ batch, N, dim = self._check_size(size, fixed_dim=mean.shape[0])
116
+ coordinates = self.rng.multivariate_normal(mean, cov, (batch, N)).squeeze()
117
+ return coordinates.squeeze()
118
+
119
+ def cluster_coordinates(
120
+ self,
121
+ n_clusters: int,
122
+ cluster_std: float | NDArray,
123
+ size: int | tuple[int, ...] = None,
124
+ ) -> NDArray:
125
+ """
126
+ Generate clustered coordinates by placing points around cluster centers.
127
+
128
+ Args:
129
+ n_clusters (int): Number of clusters to generate.
130
+ cluster_std (float|NDArray): Standard deviation of points around each cluster center.
131
+ size (int|Tuple[int, ...]]): Specifies the grid dimensions.
132
+ - If None, generates a single coordinate of size `(1,)`.
133
+ - If an integer N, generates one dimensional coordinates for N points.
134
+ - If a tuple of `(N, D)`, generates a D-dimensional coordinates for N points.
135
+ - If a tuple of `(B, N, D)`, generates a batch of `B` coordinates of size `(N, D)`.
136
+
137
+ Returns:
138
+ NDArray: An array containing clustered coordinates.
139
+ """
140
+ if not isinstance(n_clusters, int):
141
+ raise ValueError("n_clusters must be an integer.")
142
+ if not isinstance(cluster_std, (float, np.ndarray)):
143
+ raise ValueError("cluster_std must be a float or an array.")
144
+
145
+ batch, N, dim = self._check_size(size)
146
+
147
+ cluster_centers = self.rng.random((batch, n_clusters, dim))
148
+ cluster_assignments = self.rng.integers(0, n_clusters, size=(batch, N))
149
+ coordinates = np.zeros((batch, N, dim))
150
+
151
+ for b in range(batch):
152
+ for n in range(N):
153
+ cluster_id = cluster_assignments[b, n]
154
+ coordinates[b, n, :] = self.rng.normal(
155
+ cluster_centers[b, cluster_id],
156
+ cluster_std
157
+ if isinstance(cluster_std, float)
158
+ else cluster_std[cluster_id],
159
+ dim,
160
+ )
161
+
162
+ return coordinates.squeeze()
163
+
164
+ def equal_probabilities(
165
+ self, n: int, size: int | tuple[int, ...] = None
166
+ ) -> NDArray:
167
+ """
168
+ Generate equal probabilities that sum up to n.
169
+
170
+ Args:
171
+ n (int): The total sum of the probabilities.
172
+ size (int): The size of probabilities.
173
+ - If None, generates a single probability of size `(1,)`.
174
+ - If an integer N, generates N probabilities.
175
+ - If a tuple of `(B, N)`, generates a batch of `B` probabilities of size `(N,)`.
176
+ - If a tuple of `(B, N, D)`, D has to be 1.
177
+
178
+ Returns:
179
+ NDArray: An array of size `size` with equal probabilities summing to `n`.
180
+ """
181
+ if not isinstance(n, int) or n <= 0:
182
+ raise ValueError("n must be an positive integer.")
183
+
184
+ batch, N, dim = self._check_size(size, fixed_dim=1)
185
+ probabilities = np.full((batch, N, dim), n / N)
186
+ return probabilities.squeeze()
187
+
188
+ def unequal_probabilities(
189
+ self, n: int, size: int | tuple[int, ...] = None
190
+ ) -> NDArray:
191
+ """
192
+ Generate equal probabilities that sum up to n.
193
+
194
+ Args:
195
+ n (int): The total sum of the probabilities.
196
+ size (int): The size of probabilities.
197
+ - If None, generates a single probability of size `(1,)`.
198
+ - If an integer N, generates N probabilities.
199
+ - If a tuple of `(B, N)`, generates a batch of `B` probabilities of size `(N,)`.
200
+ - If a tuple of `(B, N, D)`, D has to be 1.
201
+
202
+ Returns:
203
+ NDArray: An array of size `size` with equal probabilities summing to `n`.
204
+ """
205
+ if not isinstance(n, int) or n <= 0:
206
+ raise ValueError("n must be an positive integer.")
207
+
208
+ batch, N, dim = self._check_size(size, fixed_dim=1)
209
+ probabilities = self.rng.random((batch, N))
210
+ probabilities *= n / probabilities.sum(axis=1)
211
+ return probabilities.squeeze()
212
+
213
+ def _check_size(
214
+ self, size: int | tuple[int, ...], fixed_dim: int = None
215
+ ) -> tuple[int, int, int]:
216
+ """
217
+ Validate and parse the size input.
218
+
219
+ Args:
220
+ size (int|Tuple[int, ...]]): Specifies the dimensions.
221
+
222
+ Returns:
223
+ Tuple[int, int, int]: A tuple `(B, N, D)` where:
224
+ - `B` is the batch size.
225
+ - `N` is the number of points per in the batch (or per dimension in some cases).
226
+ - `D` is the dimensionality of points.
227
+ """
228
+ if isinstance(size, int):
229
+ return (1, size, fixed_dim or 1)
230
+ if size is None or len(size) == 0:
231
+ return (1, 1, fixed_dim or 1)
232
+ if not isinstance(size, tuple) or len(size) > 3:
233
+ raise ValueError("Size must be a tuple of at most 3 positive integers.")
234
+ if not all(isinstance(x, int) and x > 0 for x in size):
235
+ raise ValueError("All elements of size must be positive integers.")
236
+ if len(size) == 3 and fixed_dim is not None and size[2] != fixed_dim:
237
+ raise ValueError(
238
+ "The dimension is fixed. you cannot set a different dimension."
239
+ )
240
+
241
+ match len(size):
242
+ case 1:
243
+ return (1, *size, fixed_dim or 1)
244
+ case 2:
245
+ return (1, *size) if fixed_dim is None else (*size, fixed_dim)
246
+ case 3:
247
+ return size
248
+
249
+
250
def rng(seed: int | None = None) -> "Generator":
    """Create a new :class:`Generator`; ``seed`` is forwarded unchanged (None means unseeded)."""
    generator = Generator(seed=seed)
    return generator