tribble-clustering 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,39 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*'
7
+
8
+ jobs:
9
+ build-n-publish:
10
+ name: Build and publish Python distribution to PyPI
11
+ runs-on: ubuntu-latest
12
+ permissions:
13
+ # This permission is required for trusted publishing (OIDC)
14
+ id-token: write
15
+ contents: read
16
+
17
+ steps:
18
+ - uses: actions/checkout@v4
19
+
20
+ - name: Set up Python
21
+ uses: actions/setup-python@v5
22
+ with:
23
+ python-version: "3.10"
24
+
25
+ - name: Install build dependencies
26
+ run: |
27
+ python -m pip install --upgrade pip
28
+ pip install build
29
+
30
+ - name: Build binary wheel and source tarball
31
+ run: python -m build
32
+
33
+ - name: Publish package distributions to PyPI
34
+ uses: pypa/gh-action-pypi-publish@release/v1
35
+ # Note: For this to work with Trusted Publishing, the user must configure
36
+ # the repository on PyPI to trust this GitHub workflow.
37
+ # If using an API token, uncomment the following line and set up secrets.
38
+ # with:
39
+ # password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,32 @@
1
+ name: Test
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ pull_request:
7
+ branches: [ main ]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12","3.13","3.14"]
15
+
16
+ steps:
17
+ - uses: actions/checkout@v4
18
+
19
+ - name: Set up Python ${{ matrix.python-version }}
20
+ uses: actions/setup-python@v5
21
+ with:
22
+ python-version: ${{ matrix.python-version }}
23
+
24
+ - name: Install dependencies
25
+ run: |
26
+ python -m pip install --upgrade pip
27
+ pip install pytest
28
+ pip install .
29
+
30
+ - name: Run tests
31
+ run: |
32
+ pytest tests
@@ -0,0 +1,148 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so may as well local ignore them.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or app, you might want to share your Python version
87
+ .python-version
88
+
89
+ # pipenv
90
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
92
+ # having no cross-platform support, pipenv may install dependencies that don't work, or even
93
+ # fail to install them. In that case, the lock file should be added to .gitignore.
94
+ #Pipfile.lock
95
+
96
+ # poetry
97
+ # Similar to Pipenv, see https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
98
+ #poetry.lock
99
+
100
+ # pdm
101
+ # Similar to Pipenv, see https://pdm.fming.dev/#usage/commit-your-pdmlock-file-to-version-control
102
+ #pdm.lock
103
+
104
+ # PEP 582; used by e.g. github.com/fannheyward/coc-pyright
105
+ __pypackages__/
106
+
107
+ # Celery stuff
108
+ celerybeat-schedule
109
+ celerybeat.pid
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
140
+
141
+ # pytype static type analyzer
142
+ .pytype/
143
+
144
+ # Cython debug symbols
145
+ cython_debug/
146
+
147
+ # PyCharm
148
+ .idea/
File without changes
File without changes
File without changes
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Scott Phillips
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,41 @@
1
+ Metadata-Version: 2.4
2
+ Name: tribble-clustering
3
+ Version: 0.1.0
4
+ Summary: An optimized Visualization Assessment Tendency (VAT/IVAT) and fuzzy clustering (FCM) package
5
+ Project-URL: Homepage, https://github.com/fundthmcalculus/clustering
6
+ Project-URL: Bug Tracker, https://github.com/fundthmcalculus/clustering/issues
7
+ Author-email: Scott Phillips <polygonguru@gmail.com>
8
+ License-File: LICENSE
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.8
13
+ Requires-Dist: numba-progress==1.2.0
14
+ Requires-Dist: numba==0.65.1
15
+ Requires-Dist: numpy==2.4.6
16
+ Requires-Dist: scipy==1.17.1
17
+ Provides-Extra: dev
18
+ Requires-Dist: matplotlib; extra == 'dev'
19
+ Requires-Dist: pytest; extra == 'dev'
20
+ Requires-Dist: ucimlrepo; extra == 'dev'
21
+ Description-Content-Type: text/markdown
22
+
23
+ # Clustering Package
24
+
25
+ A simple Python package for clustering tasks.
26
+
27
+ ## Installation
28
+
29
+ ```bash
30
+ pip install tribble-clustering
31
+ ```
32
+
33
+ ## Usage
34
+
35
+ ```python
36
+ import numpy as np
37
+ from clustering import compute_ivat
38
+ from clustering.util import pairwise_distances
39
+
40
+ ivat, vat, ivat_order, vat_order = compute_ivat(pairwise_distances(np.random.rand(100, 2)))
41
+ ```
@@ -0,0 +1,19 @@
1
+ # Clustering Package
2
+
3
+ A simple Python package for clustering tasks.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install tribble-clustering
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import numpy as np
15
+ from clustering import compute_ivat
16
+ from clustering.util import pairwise_distances
17
+
18
+ ivat, vat, ivat_order, vat_order = compute_ivat(pairwise_distances(np.random.rand(100, 2)))
19
+ ```
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "tribble-clustering"
7
+ version = "0.1.0"
8
+ authors = [
9
+ { name="Scott Phillips", email="polygonguru@gmail.com" },
10
+ ]
11
+ description = "An optimized Visualization Assessment Tendency (VAT/IVAT) and fuzzy clustering (FCM) package"
12
+ readme = "README.md"
13
+ requires-python = ">=3.8"
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ dependencies = [
20
+ "numba==0.65.1",
21
+ "numpy==2.4.6",
22
+ "scipy==1.17.1",
23
+ "numba-progress==1.2.0",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ dev = [
28
+ "ucimlrepo",
29
+ "matplotlib",
30
+ "pytest"
31
+ ]
32
+
33
+ [project.urls]
34
+ "Homepage" = "https://github.com/fundthmcalculus/clustering"
35
+ "Bug Tracker" = "https://github.com/fundthmcalculus/clustering/issues"
36
+
37
+ [tool.hatch.build.targets.wheel]
38
+ packages = ["src/clustering"]
39
+
@@ -0,0 +1,6 @@
1
+ from .pvat import (
2
+ compute_ordered_dis_njit_merge,
3
+ vat_prim_mst,
4
+ vat_prim_mst_seq,
5
+ compute_ivat,
6
+ )
@@ -0,0 +1,84 @@
1
+ from typing import Optional, Literal
2
+
3
+ import numpy as np
4
+ from numpy import ndarray
5
+ from scipy.optimize import minimize
6
+
7
+
8
def _j_w_c(x: np.ndarray, c: np.ndarray, m: float) -> float:
    """Evaluate the fuzzy c-means objective for the candidate centers ``c``.

    Every squared point-to-center distance is scaled by the matching
    membership (as produced by ``_get_weights``) raised to the power ``m``;
    the scaled terms are then summed over all point/center pairs.
    """
    memberships = _get_weights(c, m, x)
    offsets = x[:, np.newaxis, :] - c[np.newaxis, :, :]
    squared_distances = np.sum(offsets**2.0, axis=2)
    return np.sum(memberships**m * squared_distances, axis=None)
17
+
18
+
19
+ def _get_weights(c: ndarray, m: float, x: ndarray) -> ndarray:
20
+ distances = np.linalg.norm(x[:, np.newaxis, :] - c[np.newaxis, :, :], axis=2)
21
+ distances_to_jj = distances[:, :, np.newaxis]
22
+ distances_to_all = distances[:, np.newaxis, :]
23
+ w_ij = 1.0 / np.sum((distances_to_jj / distances_to_all) ** (2.0 / (m - 1)), axis=2)
24
+ w_ij = np.where(np.isnan(w_ij) | np.isinf(w_ij), 0.0, w_ij)
25
+ return w_ij
26
+
27
+
28
+ def _get_v_ij(w_ij: ndarray, m: float, x: ndarray) -> ndarray:
29
+ v_ij = np.sum(w_ij[:, :, np.newaxis]**m * (x[:, np.newaxis, :]), axis=0) / np.sum(w_ij ** m, axis=0)[:, np.newaxis]
30
+ return v_ij
31
+
32
+
33
def fuzzy_c_means(
    x: np.ndarray,
    n: int,
    m: float = 2.0,
    *,
    method: Literal["gd","iter"] = "iter",
    indices: Optional[np.ndarray | list[int]] = None,
    initial_guess: Optional[np.ndarray] = None,
    max_iter: int = 100,
) -> tuple[np.ndarray, np.ndarray]:
    """Cluster ``x`` with fuzzy c-means.

    :param x: data points, one row per point
    :param n: number of clusters
    :param m: fuzzifier exponent, must be greater than 1
    :param method: ``"iter"`` alternates membership and center updates;
        ``"gd"`` minimises the objective with SciPy's BFGS optimiser
    :param indices: rows of ``x`` to use as the initial centers
    :param initial_guess: explicit initial centers of shape ``(n, x.shape[1])``
    :param max_iter: maximum number of update rounds for ``method="iter"``
    :return: tuple of the cluster centers and the membership matrix
    :raises ValueError: if both ``indices`` and ``initial_guess`` are given,
        if ``method`` is unknown, if ``m <= 1``, or if ``initial_guess`` has
        the wrong shape
    """
    if initial_guess is not None and indices is not None:
        raise ValueError("initial_guess and indices cannot both be provided")
    # Reject bad arguments before any work (or random draws) happen.
    if method not in ("gd", "iter"):
        raise ValueError(f"Invalid method: {method}. Choose 'gd' or 'iter'.")
    if m <= 1:
        # The membership formula uses the exponent 2 / (m - 1).
        raise ValueError(f"m must be greater than 1, got {m}")

    n_features = x.shape[1]

    # 1. Create the candidate centers
    if indices is not None:
        c = x[indices, :]
    elif initial_guess is not None:
        if initial_guess.shape != (n, n_features):
            raise ValueError(
                f"initial_guess must have shape ({n}, {n_features}), "
                f"got {initial_guess.shape}"
            )
        c = initial_guess
    else:
        picks = np.random.choice(x.shape[0], size=n * 2, replace=False)
        # Average pairs of sampled rows so no center exactly matches a data point
        c = x[picks, :].reshape(n, 2, n_features).mean(axis=1)

    # 2. Refine the centers
    if method == "gd":

        def optim_j_w_c(c_opt: np.ndarray) -> float:
            return _j_w_c(x, c_opt.reshape(n, n_features), m)

        result = minimize(optim_j_w_c, c.flatten(), method="BFGS")
        c = result.x.reshape(n, n_features)
    else:
        for _ in range(max_iter):
            c_new = _get_v_ij(_get_weights(c, m, x), m, x)
            if np.allclose(c_new, c, rtol=1e-5, atol=1e-8):
                break
            c = c_new

    # 3. Membership matrix for the final centers
    w_ij = _get_weights(c, m, x)
    return c, w_ij
@@ -0,0 +1,278 @@
1
+ import heapq
2
+
3
+ from numba import njit, prange
4
+ import numpy as np
5
+ from numba_progress import ProgressBar
6
+ from numpy import ndarray
7
+
8
+
9
def compute_ivat(
    matrix_of_pairwise_distance: np.ndarray, inplace: bool = False
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes the improved VAT (IVAT) for the provided dissimilarity (distance) matrix
    :param matrix_of_pairwise_distance: dissimilarity matrix, typically an L2-norm matrix, it must be symmetric and positive semi-definite
    :param inplace: whether to perform the computation in-place on the input matrix; when True the returned
        IVAT and VAT matrices are the same array
    :return: tuple of the IVAT matrix, the VAT matrix, the sequence of IVAT indices (a list with one entry
        per row from 1 to N-1), and the sequence of permutation (VAT) indices
    """
    d_star, p_seq, _ = compute_ordered_dis_njit_merge(
        matrix_of_pairwise_distance, inplace=inplace
    )
    N = d_star.shape[0]
    d_p_star = d_star if inplace else np.zeros(d_star.shape, dtype=d_star.dtype)

    argmin_seq = []
    for r in range(1, N):
        # TODO - Get from the prim-mst sequence?
        jj = np.argmin(d_star[r, :r])
        argmin_seq.append(jj)

        # Read the linking distance before row r is overwritten (matters when
        # d_p_star aliases d_star in in-place mode).
        link = d_star[r, jj]

        # Row jj (columns < r) was finalised in earlier iterations, so the
        # whole row r can be derived from it in one vectorised step.
        # TODO - Handle doing just upper-triangular matrix for memory savings?
        row = np.maximum(link, d_p_star[jj, :r])
        row[jj] = link
        d_p_star[r, :r] = row
        d_p_star[:r, r] = row

    return d_p_star, d_star, argmin_seq, p_seq
41
+
42
+
43
def compute_vat(matrix_of_pairwise_distance: np.ndarray, inplace: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """
    Reorders a dissimilarity (distance) matrix for visual assessment of cluster tendency (VAT)
    :param matrix_of_pairwise_distance: symmetric, positive semi-definite dissimilarity matrix (typically L2 distances)
    :param inplace: when True the input matrix itself is reordered instead of a copy
    :return: tuple of the reordered (VAT) matrix and the permutation (VAT) sequence
    """
    ordered = compute_ordered_dis_njit_merge(
        matrix_of_pairwise_distance, inplace=inplace
    )
    # The third element (parent sequence) is not part of the VAT result.
    return ordered[0], ordered[1]
54
+
55
+
56
@njit(cache=True, parallel=True, nogil=True)
def compute_ordered_dis_njit_merge(
    matrix_of_pairwise_distance: np.ndarray,
    inplace: bool = False,
    progress_bar: ProgressBar | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reorders a distance matrix by the sequence produced by ``vat_prim_mst``
    :param matrix_of_pairwise_distance: square dissimilarity matrix
    :param inplace: when True the input matrix is permuted in place via ``shuffle_ordered_column``,
        otherwise a new matrix is filled from the input
    :param progress_bar: optional progress bar; it is reset to 0 after the ordering step and advanced once per row
    :return: tuple of the reordered matrix and the two sequences returned by ``vat_prim_mst``
    """
    n = matrix_of_pairwise_distance.shape[0]
    if inplace:
        ordered_matrix = matrix_of_pairwise_distance
    else:
        ordered_matrix: np.ndarray = np.zeros(
            matrix_of_pairwise_distance.shape, dtype=matrix_of_pairwise_distance.dtype
        )
    # p is the visiting order, q the accompanying parent sequence
    p, q = vat_prim_mst(matrix_of_pairwise_distance, progress_bar=progress_bar)
    # Step 3 - since this is symmetric, we only have to do half
    n_bit_mask = int(np.ceil(n / 8))
    # Boolean is stored as a byte, so this is smaller
    # (one bit per matrix cell; only read by the in-place branch below)
    visited = np.zeros((n, n_bit_mask), dtype=np.uint8)

    if progress_bar is not None:
        progress_bar.set(0)

    if inplace:
        # Due to loop-walking, we cannot use the parallel operations since we cannot know a-priori which loops are different.
        # NOTE(review): the result of this branch depends entirely on shuffle_ordered_column;
        # verify it against the copying branch below.
        for ij in range(n):
            shuffle_ordered_column(n, ij, ordered_matrix, p, visited)
            if progress_bar is not None:
                progress_bar.update(1)
    else:
        # Rows are independent here: iteration ij writes only row ij (columns >= ij)
        # and the mirrored cells in column ij.
        for ij in prange(n):
            for jk in range(ij, n):
                ordered_matrix[ij, jk] = ordered_matrix[jk, ij] = (
                    matrix_of_pairwise_distance[p[ij], p[jk]]
                )
            if progress_bar is not None:
                progress_bar.update(1)

    # Step 4 - since this is symmetric, we only have to do half
    return ordered_matrix, p, q
95
+
96
+
97
@njit(cache=True)
def shuffle_ordered_column(
    n: int, ij: int, ordered_matrix: ndarray, p: ndarray, visited: ndarray
):
    """
    Applies the permutation ``p`` in place to the not-yet-visited cells of row ``ij`` (columns ``ij`` to ``n - 1``)

    After all rows have been processed, cell ``(r, c)`` holds the value that was
    originally stored at ``(p[r], p[c])``. The matrix must be symmetric.

    :param n: matrix size
    :param ij: row whose cells start the cycle walks
    :param ordered_matrix: symmetric matrix, modified in place
    :param p: permutation sequence
    :param visited: per-cell bitmask (see ``_set_bit`` / ``_get_bit``), updated in place
    """
    for jk in range(ij, n):
        if _get_bit(visited, ij, jk):
            continue
        # Walk the cycle (r, c) -> (p[r], p[c]) that starts at (ij, jk),
        # pulling each cell's new value from its successor.
        r0, c0 = ij, jk
        p0 = ordered_matrix[r0, c0]
        r1, c1 = p[r0], p[c0]
        # Stop before reading the start cell (already overwritten) or its
        # mirror (overwritten by the symmetric write); both originally held p0.
        while not ((r1 == ij and c1 == jk) or (r1 == jk and c1 == ij)):
            _set_bit(visited, r0, c0)
            _set_bit(visited, c0, r0)
            ordered_matrix[r0, c0] = ordered_matrix[c0, r0] = ordered_matrix[r1, c1]
            # Next step!
            r0, c0 = r1, c1
            r1, c1 = p[r0], p[c0]
        # Close the cycle: the last cell receives the start cell's original value
        ordered_matrix[r0, c0] = ordered_matrix[c0, r0] = p0
        _set_bit(visited, r0, c0)
        _set_bit(visited, c0, r0)
119
+
120
+
121
@njit(cache=True)
def _set_bit(bitmask: np.ndarray, row: int, col: int) -> None:
    """Set the bit for ``(row, col)`` in a bitmask that packs eight columns per element."""
    byte_index = col // 8
    bit_offset = col % 8
    bitmask[row, byte_index] |= 1 << bit_offset
124
+
125
+
126
@njit(cache=True)
def _get_bit(bitmask: np.ndarray, row: int, col: int) -> int:
    """Return 1 if the bit for ``(row, col)`` is set in the packed bitmask, else 0."""
    byte_index = col // 8
    bit_offset = col % 8
    return (bitmask[row, byte_index] >> bit_offset) & 1
129
+
130
+
131
@njit(cache=True)
def vat_prim_mst(
    adj: np.ndarray, progress_bar: ProgressBar | None = None
) -> tuple[np.ndarray, np.ndarray]:
    """
    Builds the VAT visiting order with a heap-based Prim traversal of a dense distance matrix
    :param adj: square dissimilarity matrix, indexed as ``adj[u, v]``
    :param progress_bar: optional progress bar, advanced once per vertex added to the sequence
    :return: tuple of the visiting order (``heap_seq``) and, for each visited vertex, the tag stored in its
        heap entry (``parent_seq``): the seed's column index for the first vertex, otherwise the value of the
        sequence counter at push time, i.e. one more than the position in ``heap_seq`` of the vertex that pushed it
    """
    n: int = len(adj)

    # Locate the largest entry of the matrix; its row index becomes the seed vertex.
    max_adj: np.signedinteger = np.argmax(adj)
    src_i: np.signedinteger = max_adj // n
    src_j: np.signedinteger = max_adj % n
    src_key = adj[src_i, src_j]

    # Create a list for keys and initialize all keys as infinite (INF)
    key: np.ndarray = np.full(n, np.inf, dtype=adj.dtype)

    # To store the parent array which, in turn, stores MST
    # (written below but not part of the return value)
    parent: np.ndarray = np.full(n, -1, dtype=np.int32)

    # To keep track of vertices included in MST
    in_mst: np.ndarray = np.full(n, False, dtype=np.bool_)

    # Seed the priority queue with the source vertex; its key is the maximum distance found above
    pq: list[tuple[float, np.signedinteger, np.signedinteger]] = [
        (src_key, src_i, src_j)
    ]  # Priority queue to store vertices that are being processed
    key[src_i] = src_key

    # The final sequence of vertices in MST
    heap_seq: np.ndarray = np.zeros(n, dtype=np.int32)
    heap_seq_idx: int = 0

    # Parent sequences of vertices in MST (for iVAT)
    parent_seq: np.ndarray = np.zeros(n, dtype=np.int32)
    parent_seq_idx: int = 0

    # Preallocated
    vertices: np.ndarray = np.arange(n)

    # Loop until the priority queue becomes empty
    while pq:
        # Heap entries are (key, vertex, tag); pop the entry with the smallest key
        w, u, v0 = heapq.heappop(pq)

        # Different key values for the same vertex may exist in the priority queue.
        # The one with the least key value is always processed first.
        # Therefore, ignore the rest.
        if in_mst[u]:
            continue

        in_mst[u] = True  # Include the vertex in MST
        heap_seq[heap_seq_idx] = u
        heap_seq_idx += 1

        parent_seq[parent_seq_idx] = v0
        parent_seq_idx += 1

        if progress_bar is not None:
            progress_bar.update(1)

        # Relax all vertices not yet in the MST in one vectorised step:
        # select those whose current key is not smaller than their distance to u
        mask = (vertices != u) & ~in_mst & (key[vertices] >= adj[u, vertices])
        key[mask] = adj[u, mask]
        for v in vertices[mask]:
            # heap_seq_idx was already incremented, so the tag is u's position plus one
            heapq.heappush(pq, (key[v], v, heap_seq_idx))
            parent[v] = u

    return heap_seq, parent_seq
201
+
202
+
203
@njit(cache=True)
def vat_prim_mst_seq(samples: np.ndarray) -> np.ndarray:
    """
    Builds the VAT visiting order directly from sample coordinates, without a stored distance matrix
    :param samples: data points, one row per sample
    :return: the order in which the samples are visited by the Prim traversal
    """
    n = len(samples)

    # The final sequence of vertices in MST
    heap_seq: np.ndarray = np.zeros(n, dtype=np.int32)
    if n == 0:
        return heap_seq
    heap_seq_idx = 0

    # Find the pair with the maximum distance; its first sample is the seed.
    max_adj = -np.inf
    max_idx = (-1, -1)
    for ij in range(n):
        for jk in range(ij, n):
            cur_dist = _get_dist(samples, ij, jk)
            if cur_dist > max_adj:
                max_adj = cur_dist
                max_idx = (ij, jk)

    src = max_idx[0]
    src_key = max_adj

    # Create a list for keys and initialize all keys as infinite (INF)
    key: np.ndarray = np.full(n, float("inf"))

    # To keep track of vertices included in MST
    in_mst = np.full(n, False)

    # Seed the priority queue with the source vertex
    pq: list[tuple[float, int]] = [
        (src_key, src)
    ]  # Priority queue to store vertices that are being processed
    key[src] = src_key

    # Preallocated
    vertices = np.arange(n)
    dist_u = np.empty(n, dtype=np.float64)

    # Loop until the priority queue becomes empty
    while pq:
        # Heap entries are (key, vertex); pop the entry with the smallest key
        u = heapq.heappop(pq)[1]

        # Different key values for the same vertex may exist in the priority queue.
        # The one with the least key value is always processed first.
        # Therefore, ignore the rest.
        if in_mst[u]:
            continue

        in_mst[u] = True  # Include the vertex in MST
        heap_seq[heap_seq_idx] = u
        heap_seq_idx += 1

        # Distance from u to every sample, one pair at a time. _get_dist sums
        # over all elements of its difference, so calling it with an index
        # array would give a single aggregate value instead of one per vertex.
        for v in range(n):
            dist_u[v] = _get_dist(samples, u, v)

        mask = (vertices != u) & ~in_mst & (key > dist_u)
        key[mask] = dist_u[mask]
        for v in vertices[mask]:
            heapq.heappush(pq, (key[v], v))

    return heap_seq
273
+
274
+
275
@njit(cache=True)
def _get_dist(samples: np.ndarray, idx1: int, idx2: int) -> float:
    """Return the Euclidean distance between rows ``idx1`` and ``idx2`` of ``samples``.

    The squared differences are summed over every element of the difference.
    """
    delta = samples[idx1, :] - samples[idx2, :]
    squared_total = np.sum(np.square(delta))
    return np.sqrt(squared_total)
@@ -0,0 +1,17 @@
1
+ import numpy as np
2
+ from numba import prange, njit
3
+
4
+
5
@njit(cache=True, parallel=True, nogil=True)
def pairwise_distances(data: np.ndarray) -> np.ndarray:
    """
    Computes the square matrix of pairwise distances between the rows of ``data``
    :param data: 2-D array with one row per sample
    :return: symmetric matrix whose entry ``[i, j]`` is the distance between rows ``i`` and ``j``
    """
    is_1d: bool = data.shape[1] == 1
    if is_1d:
        # Vectorized computation for 1D case
        # (a single column broadcast against its transpose gives all |a - b| pairs)
        return np.abs(data.T - data)
    else:
        dist_arr = np.zeros((data.shape[0], data.shape[0]), dtype=data.dtype)
        # Only the upper triangle is computed; each value is mirrored and the
        # diagonal keeps the zeros from the allocation above.
        for i in prange(len(data)):
            for j in range(i + 1, len(data)):
                dist_arr[i, j] = np.linalg.norm(data[i, :] - data[j, :])
                dist_arr[j, i] = dist_arr[i, j]
        return dist_arr
@@ -0,0 +1,552 @@
1
+ import time
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from matplotlib import pyplot as plt
6
+ from numpy import ndarray
7
+ from scipy.spatial import Voronoi, voronoi_plot_2d
8
+
9
+ from clustering.util import pairwise_distances
10
+ from src.clustering import vat_prim_mst_seq, compute_ivat, fcm
11
+
12
+
13
+ def _random_cities(
14
+ center_x, center_y, n_cities: int = 10, cluster_diameter: float = 3.0
15
+ ) -> np.ndarray:
16
+ if n_cities == 1:
17
+ return np.array([[center_x, center_y]])
18
+ # Randomly distribute cities in a uniform circle?
19
+ theta = np.linspace(0, 2 * np.pi, n_cities + 1, dtype=np.float32)
20
+ theta = theta[:-1]
21
+ # Add slight random scramble to locations
22
+ scramble = np.random.uniform(
23
+ -cluster_diameter * 0.05, cluster_diameter * 0.05, size=(n_cities, 2)
24
+ )
25
+ city_x = np.cos(theta) * cluster_diameter / 2.0 + center_x + scramble[:, 0]
26
+ city_y = np.sin(theta) * cluster_diameter / 2.0 + center_y + scramble[:, 1]
27
+ return np.c_[city_x, city_y]
28
+
29
+
30
def _circle_random_clusters(
    n_clusters: int = 10,
    n_cities: int = 10,
    cluster_diameter: float = 2.0,
    cluster_spacing: float = 10.0,
) -> np.ndarray:
    """Generate clusters of cities whose centers lie on a circle of radius ``cluster_spacing``."""
    # Shrinking the angles keeps the last cluster from landing on the first one.
    shrink = n_clusters / (n_clusters + 1)
    city_locations = np.zeros(shape=(0, 2), dtype=np.float32)
    for raw_angle in np.linspace(0, 2 * np.pi, n_clusters):
        angle = raw_angle * shrink
        cluster = _random_cities(
            cluster_spacing * np.cos(angle),
            cluster_spacing * np.sin(angle),
            n_cities=n_cities,
            cluster_diameter=cluster_diameter,
        )
        city_locations = np.concatenate((city_locations, cluster), axis=0)
    return city_locations
51
+
52
+
53
+ def _hierarchical_circle_clusters(
54
+ clusters_per_level: list[int],
55
+ diameters_per_level: list[float],
56
+ ) -> np.ndarray:
57
+ """
58
+ Create hierarchical clusters arranged in circles around circles recursively.
59
+
60
+ Args:
61
+ clusters_per_level: Number of clusters at each hierarchical level (e.g., [3, 4, 5] means 3 top-level clusters,
62
+ each containing 4 mid-level clusters, each containing 5 leaf clusters)
63
+ diameters_per_level: Diameter for clusters at each level
64
+
65
+ Returns:
66
+ Array of point coordinates (N, 2)
67
+
68
+ Raises:
69
+ ValueError: If configuration would create more than 16000 points
70
+ """
71
+ if len(clusters_per_level) != len(diameters_per_level):
72
+ raise ValueError("clusters_per_level and diameters_per_level must have the same length")
73
+
74
+ # Calculate total number of points
75
+ total_points = np.prod(clusters_per_level)
76
+
77
+ if total_points > 16000:
78
+ raise ValueError(
79
+ f"Configuration would create {total_points} points, which exceeds the limit of 16000. "
80
+ f"Reduce clusters_per_level, diameters_per_level, or points_per_leaf."
81
+ )
82
+
83
+ def _create_level(
84
+ center_x: float,
85
+ center_y: float,
86
+ level_idx: int,
87
+ ) -> np.ndarray:
88
+ """Recursively create clusters at the current level"""
89
+ n_clusters = clusters_per_level[level_idx]
90
+ diameter = diameters_per_level[level_idx]
91
+ # Last level gets a bit of noise
92
+ if level_idx == len(clusters_per_level)-1:
93
+ # Base case: create leaf points
94
+ return _random_cities(
95
+ center_x, center_y,
96
+ n_cities=n_clusters,
97
+ cluster_diameter=diameter
98
+ )
99
+
100
+ all_points = np.zeros(shape=(0, 2), dtype=np.float32)
101
+
102
+ for theta in np.linspace(0, 2 * np.pi, n_clusters, endpoint=False):
103
+ # Calculate position of sub-cluster center
104
+ sub_cx = center_x + diameter * np.cos(theta)
105
+ sub_cy = center_y + diameter * np.sin(theta)
106
+
107
+ # Recursively create points for this sub-cluster
108
+ sub_points = _create_level(sub_cx, sub_cy,level_idx + 1)
109
+
110
+ all_points = np.concatenate((all_points, sub_points), axis=0)
111
+
112
+ return all_points
113
+
114
+ # Start recursion from the origin
115
+ return _create_level(0.0, 0.0, 0)
116
+
117
+
118
def _test_cluster_sequencing():
    """Time ``vat_prim_mst_seq`` on the UCI letter-recognition dataset (downloads the data)."""
    from ucimlrepo import fetch_ucirepo

    # Dataset ids: 59 is letter recognition.
    # 827 (sepsis survival) allocates 80+ GB RAM and 148 (shuttle stat log) 50 GB RAM.
    dataset = fetch_ucirepo(id=59)

    # Features arrive as a pandas dataframe
    X = np.array(dataset.data.features)

    print(f"Metadata: {dataset.metadata}")
    print(f"Variable Information: {dataset.variables}")

    # Time the sequence computation
    started = time.time()
    vat_prim_mst_seq(X)
    finished = time.time()

    print(f"Elapsed time for {len(X)} data points: {finished - started:.02f}")
142
+
143
+
144
def test_merge_ivat():
    """Run iVAT on shuffled circular clusters and plot the VAT and iVAT matrices."""
    cities = _circle_random_clusters(
        n_clusters=10, n_cities=5, cluster_spacing=5.0, cluster_diameter=1
    )
    # Shuffle the rows so the input carries no cluster ordering
    shuffled = cities[np.random.permutation(len(cities))]
    distances = pairwise_distances(shuffled)

    ivat_mst, vat_mst, _ivat_order, _vat_order = compute_ivat(distances)
    plot_vat_ivat(ivat_mst, vat_mst)
155
+
156
+
157
def plot_vat_ivat(ivat_mst: np.ndarray, vat_mst: np.ndarray):
    """Show the VAT matrix (left) and the iVAT matrix (right) as heat maps."""
    fig, axes = plt.subplots(1, 2)

    panels = ((vat_mst, "VAT Matrix"), (ivat_mst, "iVAT Matrix"))
    for axis, (matrix, title) in zip(axes, panels):
        image = axis.imshow(matrix, cmap="viridis")
        axis.set_title(title)
        plt.colorbar(image, ax=axis)

    plt.tight_layout()
    plt.show()
169
+
170
+
171
def test_fcm_with_center_on_datapoint():
    """Test FCM behavior when a cluster center coincides with a data point"""
    # Five collinear points: (1,0), (3,0), (5,0), (7,0), (9,0)
    data_points = np.array([[1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [7.0, 0.0], [9.0, 0.0]])
    n_clusters = 2
    n_points = len(data_points)

    # Try every pair of starting indices, including a point paired with itself
    for idx0 in range(n_points):
        for idx1 in range(idx0, n_points):
            centers, memberships = fcm.fuzzy_c_means(
                data_points, n_clusters, m=2.0, indices=[idx0, idx1]
            )

            # Two cluster centers are expected
            expected_shape = (n_clusters, 2)
            assert (
                centers.shape == expected_shape
            ), f"Expected shape {(n_clusters, 2)}, got {centers.shape}"

            # Memberships of each data point sum to 1
            # (or 0 for duplicate points where cluster center coincides with data point)
            membership_sums = np.sum(memberships, axis=1)
            expected_values = np.where(membership_sums > 0.5, 1.0, 0.0)
            np.testing.assert_array_almost_equal(
                membership_sums,
                expected_values,
                err_msg=f"Membership weights should sum to 1 for each data point (or 0 for duplicates): {idx0, idx1}",
            )

            # Every membership lies in [0, 1]
            in_range = (memberships >= 0) & (memberships <= 1)
            assert np.all(in_range), "All membership weights should be between 0 and 1"
208
+
209
+
210
def test_heirarchy_ivat_means():
    """Test hierarchical circle clusters with iVAT and FCM"""
    # Example: 3 top-level clusters, each with 4 mid-level clusters, each with 5 leaf points (3 * 4 * 5 = 60 points)
    all_cities = _hierarchical_circle_clusters(
        clusters_per_level=[3, 4, 5],
        diameters_per_level=[15.0, 5.0, 1.0]
    )

    # Scramble the order of the cities
    scramble_order = np.random.permutation(len(all_cities))
    all_cities = all_cities[scramble_order]

    print(f"Created {len(all_cities)} hierarchical points")

    # Compute pairwise distances and iVAT
    matrix_of_pairwise_distance = pairwise_distances(all_cities)
    ivat_mst, vat_mst, ivat_order, vat_order = compute_ivat(matrix_of_pairwise_distance)

    # Get cluster information from iVAT
    abrupt_change_indices, cluster_city_ids, diagonal_values, initial_centroids, max_diff_indices, peaks_threshold, sorted_diagonal = _get_ivat_means(
        all_cities, ivat_mst, vat_order)
    # Run FCM with iVAT-derived initial guess
    # NOTE(review): n_clusters is the number of centroid sets, but the initial guess is only the
    # first set; fuzzy_c_means requires initial_guess to have n_clusters rows - confirm that
    # len(initial_centroids) == len(initial_centroids[0]) is intended.
    n_clusters = len(initial_centroids)
    meth_c, w_c = fcm.fuzzy_c_means(all_cities, n_clusters, 2, initial_guess=initial_centroids[0])

    print(f"Detected {n_clusters} clusters using iVAT")

    # Visualize results: one diagonal plot and one Voronoi plot per centroid set
    plot_vat_ivat(ivat_mst, vat_mst)
    for idx in range(len(initial_centroids)):
        # plot_membership(all_cities, cluster_city_ids[idx], meth_c, w_c)
        plot_diagonal(
            diagonal_values,
            [max_diff_indices[idx]],
            peaks_threshold[idx],
            sorted_diagonal,
            abrupt_change_indices[idx],
        )
        plot_voronoi(all_cities, initial_centroids[idx])
    plt.show()
250
+
251
+
252
def plot_voronoi(all_cities, centroids):
    """Draw the Voronoi diagram of ``centroids`` with all cities scattered on top."""
    figure = voronoi_plot_2d(Voronoi(centroids))
    axis = figure.axes[0]
    axis.set_title("Voronoi plot")
    axis.scatter(all_cities[:, 0], all_cities[:, 1])
    figure.show()
258
+
259
+
260
def test_multi_dim_pairwise_dist_perf():
    """Time ``pairwise_distances`` for several data sizes and dimensions, then plot.

    This is a benchmark rather than a pass/fail check: it records elapsed time
    per (dimension, size) pair and shows one curve per dimension.
    """
    # Do one throw-away call first so first-call overhead (original note:
    # nogil/numba randomness) does not pollute the measurements.
    pairwise_distances(np.zeros((100, 3)))

    dims = [1, 2]
    sizes = [1000, 2000, 3000, 5000, 8000, 10000, 20000]
    results = []
    for dim in dims:
        for size in sizes:
            data = np.random.rand(size, dim)
            # perf_counter is the monotonic, high-resolution clock meant for
            # interval timing; time.time() can jump with system clock changes.
            start = time.perf_counter()
            pairwise_distances(data)
            elapsed = time.perf_counter() - start
            results.append((dim, size, elapsed))

    # Plot results
    fig, ax = plt.subplots(figsize=(10, 6))

    # One colour per dimension
    colors = plt.cm.viridis(np.linspace(0, 1, len(dims)))

    for dim, color in zip(dims, colors):
        # Distinct names here: do not shadow the ``time`` module or ``sizes``.
        dim_results = [(size, elapsed) for d, size, elapsed in results if d == dim]
        dim_sizes, dim_times = zip(*dim_results)
        ax.plot(dim_sizes, dim_times, marker='o', label=f'Dim={dim}', color=color, linewidth=2)

    ax.set_xlabel('Data Size', fontsize=12)
    ax.set_ylabel('Time (seconds)', fontsize=12)
    ax.set_title('Pairwise Distance Computation Performance', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
295
+
296
+
297
def test_fuzzy_c_means():
    """Compare an elbow-style FCM sweep with iVAT-seeded FCM and check allocation.

    ``_get_ivat_means`` returns one entry per threshold level.  This test uses
    level 0 (the largest jump in the sorted iVAT diagonal) as the clustering to
    seed FCM with, to validate, and to plot.
    """
    n_total: int = 256
    n_clusters: int = 16
    n_cities: int = n_total // n_clusters
    all_cities = _circle_random_clusters(
        n_clusters=n_clusters, n_cities=n_cities, cluster_spacing=5, cluster_diameter=0.5
    )
    # Scramble the order of the cities
    scramble_order = np.random.permutation(len(all_cities))
    all_cities = all_cities[scramble_order]

    # Time the elbow method (multiple FCM calls with varying cluster counts)
    start_elbow = time.perf_counter()
    elbow_results = []
    for k in range(2, n_clusters + 1):
        centers, weights = fcm.fuzzy_c_means(all_cities, k, 2)
        elbow_results.append((k, centers, weights))
    elbow_time = time.perf_counter() - start_elbow

    # Time the iVAT pipeline: distances, iVAT, and cluster extraction.
    start_ivat = time.perf_counter()
    matrix_of_pairwise_distance = pairwise_distances(all_cities)
    ivat_mst, vat_mst, ivat_order, vat_order = compute_ivat(matrix_of_pairwise_distance)
    (
        abrupt_change_indices,
        cluster_city_ids,
        diagonal_values,
        initial_centroids,
        max_diff_indices,
        peaks_threshold,
        sorted_diagonal,
    ) = _get_ivat_means(all_cities, ivat_mst, vat_order)
    end_ivat = time.perf_counter()

    # Every value above is a per-level collection; pick level 0 consistently.
    top_centroids = initial_centroids[0]
    top_cluster_ids = cluster_city_ids[0]
    n_detected = len(top_centroids)

    # Time the single FCM calls: iVAT-seeded, default, and gradient-descent.
    start_single = time.perf_counter()
    # The cluster count must match the number of seed centroids.
    meth_c, w_c = fcm.fuzzy_c_means(all_cities, n_detected, 2, initial_guess=top_centroids)
    mid_single = time.perf_counter()
    _, _ = fcm.fuzzy_c_means(all_cities, n_clusters, 2)
    end_single = time.perf_counter()
    _, _ = fcm.fuzzy_c_means(all_cities, n_clusters, 2, method='gd')
    end_gd = time.perf_counter()
    smart_fcm_time = mid_single - start_single
    single_fcm_time = end_single - mid_single
    iter_fcm_time = end_gd - end_single
    single_ivat_time = end_ivat - start_ivat

    # Print performance comparison
    print(f"\n{'=' * 60}")
    print("Performance Comparison:")
    print(f"{'=' * 60}")
    print(f"Elbow Method (n=2 to {n_clusters}): {elbow_time:.4f} seconds")
    print(f"Single iter-FCM (n={n_clusters}): {single_fcm_time:.4f} seconds")
    print(f"Single GD-FCM (n={n_clusters}): {iter_fcm_time:.4f} seconds")
    print(f"Smart FCM (n={n_detected}): {smart_fcm_time:.4f} seconds")
    print(f"IVAT (n={n_clusters}): {single_ivat_time:.4f} seconds")
    print(f"Time difference: {elbow_time - single_ivat_time:.4f} seconds")
    print(f"Elbow method is {elbow_time/single_ivat_time:.2f}x slower")
    print(f"{'='*60}\n")

    # Assert that every city has been allocated to exactly one level-0 cluster
    all_allocated_cities = np.sort(np.concatenate(top_cluster_ids))
    assert len(all_allocated_cities) == len(
        all_cities
    ), f"Not all cities allocated: {len(all_allocated_cities)} allocated out of {len(all_cities)} total"
    assert len(np.unique(all_allocated_cities)) == len(
        all_cities
    ), "Duplicate city allocations detected"

    plot_vat_ivat(ivat_mst, vat_mst)

    # plot_diagonal formats the threshold with ':.2f', so it needs one scalar level.
    plot_diagonal(
        diagonal_values,
        [max_diff_indices[0]],
        peaks_threshold[0],
        sorted_diagonal,
        abrupt_change_indices[0],
    )

    plot_membership(all_cities, top_cluster_ids, meth_c, w_c)
    plt.show()
374
+
375
+
376
def _get_ivat_means(
    all_cities: ndarray,
    ivat_mst: ndarray,
    vat_order: ndarray,
    n_levels: int = 3,
) -> tuple[list[ndarray], list[list[ndarray]], ndarray, list[ndarray], ndarray, ndarray, ndarray]:
    """Derive candidate clusterings from the first off-diagonal of an iVAT matrix.

    The ``n_levels`` largest jumps in the sorted off-diagonal each define a
    threshold; every threshold splits the VAT ordering into contiguous segments
    that are treated as clusters.

    Args:
        all_cities: Point coordinates, indexed by the entries of ``vat_order``.
        ivat_mst: Square iVAT matrix; only its ``k=1`` diagonal is read.
            It must yield at least one off-diagonal value.
        vat_order: VAT ordering used to map segment positions back to row
            indices of ``all_cities``.
        n_levels: How many of the largest jumps to turn into thresholds
            (default 3, the previously hard-coded value).

    Returns:
        A 7-tuple ``(abrupt_change_indices, cluster_city_ids, diagonal_values,
        initial_centroids, max_diff_indices, peaks_threshold, sorted_diagonal)``.
        ``abrupt_change_indices``, ``cluster_city_ids`` and ``initial_centroids``
        are lists with one entry per threshold, in the same order as
        ``peaks_threshold`` (largest jump first).
    """
    # Look down the off-by-1 diagonal and count the number of substantial changes.
    diagonal_values = np.diag(ivat_mst, k=1)
    # Prepend a copy of the first value so positions line up with the ordering
    # (position j then describes the step from ordered point j-1 to j).
    diagonal_values = np.concatenate(
        [np.array([diagonal_values[0]]), diagonal_values], axis=0
    )
    sorted_diagonal = np.sort(diagonal_values)
    # The largest gaps between consecutive sorted values mark natural thresholds.
    diagonal_diffs = np.diff(sorted_diagonal)
    max_diff_indices = _arg_max(diagonal_diffs, n_levels)
    # The value just above each gap is the smallest distance counted as a jump.
    peaks_threshold = sorted_diagonal[max_diff_indices + 1]

    abrupt_change_indices = []
    cluster_city_ids = []
    initial_centroids = []
    n_points = len(all_cities)
    for peak_th in peaks_threshold:
        abrupt_change_idx = np.where(diagonal_values >= peak_th)[0]
        abrupt_change_indices.append(abrupt_change_idx)

        # Segment boundaries: start, every jump position, end.
        boundaries = np.concatenate(
            [np.array([0]), abrupt_change_idx, np.array([n_points])]
        )
        # Skip empty segments.  They occur when position 0 (the prepended copy)
        # is itself flagged, and would otherwise produce NaN centroids.
        cluster_city_indexs = [
            vat_order[seg_start:seg_end]
            for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:])
            if seg_end > seg_start
        ]

        # Initial guess: the centroid of each segment's cities.
        initial_centroids.append(np.array([
            np.mean(all_cities[cluster_ids], axis=0)
            for cluster_ids in cluster_city_indexs
        ]))
        cluster_city_ids.append(cluster_city_indexs)

    return (
        abrupt_change_indices,
        cluster_city_ids,
        diagonal_values,
        initial_centroids,
        max_diff_indices,
        peaks_threshold,
        sorted_diagonal,
    )
417
+
418
+
419
def get_ivat_means(all_cities: ndarray, ivat_mst: ndarray, vat_order: ndarray) -> tuple[ndarray, list[Any], ndarray]:
    """Return ``(initial_centroids, cluster_city_ids, peaks_threshold)`` for the given iVAT.

    Public wrapper around ``_get_ivat_means`` that drops the plotting-only values.
    """
    # Keep the full 7-way unpack so a change in the helper's arity still fails loudly.
    _, city_ids, _, centroids, _, thresholds, _ = _get_ivat_means(
        all_cities, ivat_mst, vat_order
    )
    return centroids, city_ids, thresholds
422
+
423
+
424
+ def _arg_max(a: ndarray, n: int = 1) -> ndarray:
425
+ """Get the indexes of the n-largest values in the array. You can assume it's a 1D array"""
426
+ if n >= len(a):
427
+ return np.argsort(a)[::-1]
428
+ # Use argpartition to find the n largest elements efficiently
429
+ partitioned_indices = np.argpartition(a, -n)[-n:]
430
+ # Sort these indices by their corresponding values in descending order
431
+ sorted_indices = partitioned_indices[np.argsort(a[partitioned_indices])[::-1]]
432
+ return sorted_indices
433
+
434
+
435
def plot_membership(all_cities: ndarray, cluster_city_ids: list[Any],
                    meth_c: ndarray, w_c: ndarray):
    """Plot points coloured by fuzzy membership, plus iVAT clusters and both sets of centres.

    Args:
        all_cities: Point coordinates; columns 0 and 1 are plotted as x and y.
        cluster_city_ids: One collection of row indexes into ``all_cities``
            per iVAT cluster.
        meth_c: FCM cluster centres, one row per cluster.
        w_c: Membership weights indexed as ``w_c[point, cluster]``.
    """
    n_points = all_cities.shape[0]
    n_clusters = meth_c.shape[0]

    # One base colour (RGBA row) per FCM cluster
    colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))

    _, ax = plt.subplots()

    # Blend the cluster colours by membership weight for all points at once,
    # then rescale each row so its largest channel is 1.
    blended_colors = w_c[:n_points, :n_clusters] @ colors
    blended_colors = blended_colors / blended_colors.max(axis=1, keepdims=True)

    # A single scatter call (one artist) instead of one call per point.
    ax.scatter(
        all_cities[:, 0],
        all_cities[:, 1],
        c=blended_colors,
        s=50,
        alpha=0.7,
        edgecolors="black",
        linewidth=0.5,
    )

    # Outline each iVAT cluster's members with "*" markers and collect its centre
    ivat_centers = []
    for idx, cluster_ids in enumerate(cluster_city_ids):
        cluster_points = all_cities[cluster_ids]
        # Wrap around if there are more iVAT clusters than FCM colours
        cluster_color = colors[idx % len(colors)]
        ax.scatter(
            cluster_points[:, 0],
            cluster_points[:, 1],
            marker="*",
            edgecolors=cluster_color,
            facecolors="none",
            label=f"Cluster {idx}",
        )
        ivat_centers.append(np.mean(cluster_points, axis=0))
    ivat_centers = np.array(ivat_centers)

    # Plot ivat cluster centers
    ax.scatter(
        ivat_centers[:, 0],
        ivat_centers[:, 1],
        c="red",
        s=150,
        marker="D",
        edgecolors="white",
        label="iVAT Cluster Centers",
    )

    # Plot FCM cluster centers
    ax.scatter(
        meth_c[:, 0],
        meth_c[:, 1],
        c="black",
        s=150,
        marker="X",
        edgecolors="white",
        label="FCM Cluster Centers",
    )

    ax.set_title("Fuzzy C-Means Clustering with Membership-based Colors")
    ax.set_xlabel("X Coordinate")
    ax.set_ylabel("Y Coordinate")
    ax.legend()
    plt.tight_layout()
507
+
508
+
509
+ def plot_diagonal(
510
+ diagonal_values: ndarray,
511
+ max_diff_indices: list[int],
512
+ peaks_threshold,
513
+ sorted_diagonal: ndarray,
514
+ abrupt_change_indices: ndarray,
515
+ ) -> ndarray:
516
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 8))
517
+ ax1.plot(diagonal_values, marker="o")
518
+ ax1.set_title("Off-by-One Diagonal of iVAT Matrix")
519
+ ax1.set_xlabel("Index")
520
+ ax1.set_ylabel("Distance Value")
521
+ ax1.grid(True)
522
+
523
+ ax2.plot(sorted_diagonal, marker="o")
524
+ for idx in max_diff_indices:
525
+ ax2.axvline(
526
+ x=idx,
527
+ color="r",
528
+ linestyle="--",
529
+ label=f"Max diff at index {idx}",
530
+ )
531
+ ax2.legend()
532
+ ax2.set_title("Sorted Off-by-One Diagonal of iVAT Matrix")
533
+ ax2.set_xlabel("Index")
534
+ ax2.set_ylabel("Distance Value")
535
+ ax2.grid(True)
536
+ plt.tight_layout()
537
+
538
+ # Count abrupt size changes using a basic stats test
539
+ ax1.axhline(
540
+ y=peaks_threshold,
541
+ color="r",
542
+ linestyle="--",
543
+ label=f"Threshold: {peaks_threshold:.2f}",
544
+ )
545
+ ax2.text(
546
+ 0.02,
547
+ 0.98,
548
+ f"Abrupt changes: {len(abrupt_change_indices)}, threshold: {peaks_threshold:.2f}",
549
+ transform=ax2.transAxes,
550
+ verticalalignment="top",
551
+ bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
552
+ )