asbm 0.1.4.dev3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asbm/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Copyright (c) 2025 Max Jerdee. All rights reserved.
3
+
4
+ asbm: Finding group structure in networks
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from ._version import version as __version__
10
+
11
+ __all__ = [
12
+ "__version__",
13
+ "Config",
14
+ "Result",
15
+ "fit",
16
+ ]
17
+
18
+ from .core import Config, Result, fit
Binary file
asbm/_version.py ADDED
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '0.1.4.dev3'
22
+ __version_tuple__ = version_tuple = (0, 1, 4, 'dev3')
23
+
24
+ __commit_id__ = commit_id = 'gd2cbac3fd'
asbm/core.py ADDED
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ import networkx as nx
8
+
9
+ try:
10
+ from ._asbm import FitResult as _NativeFitResult
11
+ from ._asbm import fit_backend
12
+ except ImportError as e: # pragma: no cover
13
+ _NativeFitResult = None # type: ignore[assignment]
14
+ fit_backend = None # type: ignore[assignment]
15
+ _BACKEND_IMPORT_ERROR = e
16
+ else:
17
+ _BACKEND_IMPORT_ERROR = None
18
+
19
+
20
+ ModelName = Literal[
21
+ "asbm", "general_canonical", "general_unified",
22
+ "simple_asbm", "simple_ASBM",
23
+ "general_asbm", "general_ASBM",
24
+ "hybrid_ASBM",
25
+ "planted_partition",
26
+ "traditional_SBM", "traditional_DCSBM", "traditional_GDCSBM",
27
+ "microcanonical_SBM", "microcanonical_DCSBM",
28
+ ]
29
+
30
+ _ALLOWED_MODELS: frozenset[str] = frozenset({
31
+ "asbm", "general_canonical", "general_unified",
32
+ "simple_asbm", "simple_ASBM",
33
+ "general_asbm", "general_ASBM",
34
+ "hybrid_ASBM",
35
+ "planted_partition",
36
+ "traditional_SBM", "traditional_DCSBM", "traditional_GDCSBM",
37
+ "microcanonical_SBM", "microcanonical_DCSBM",
38
+ })
39
+
40
+ @dataclass(frozen=True)
41
+ class Config:
42
+ model: ModelName = "asbm"
43
+ degree_correction: bool = True
44
+
45
+ def __post_init__(self) -> None:
46
+ if self.model not in _ALLOWED_MODELS:
47
+ raise ValueError(
48
+ "model must be one of: " + ", ".join(sorted(_ALLOWED_MODELS))
49
+ )
50
+
51
+
52
class Result:
    """Wrapper around the native fit result returned by ``fit``.

    Array-valued outputs are converted to NumPy arrays (``int`` for
    partitions, ``float`` for the coincidence matrix); other outputs are
    passed through from the native object.

    Parameters
    ----------
    native_result
        Object returned by the native ``fit_backend``.
    node_order
        Node labels in the order they were numbered ``0..n-1`` when the graph
        was sent to the backend.
    """

    def __init__(self, native_result: _NativeFitResult, node_order: tuple[object, ...]) -> None:
        self._native_result = native_result
        self._node_order = node_order

    @property
    def node_order(self) -> tuple[object, ...]:
        """Node labels in backend index order: node ``i`` is ``node_order[i]``.

        NOTE(review): presumably entry ``i`` of every partition array refers
        to ``node_order[i]`` -- confirm against the native backend.
        """
        return self._node_order

    @property
    def mdl_value(self) -> float:
        """The native ``mdl_value`` as a Python ``float``."""
        return float(self._native_result.mdl_value)

    @property
    def mdl_partition(self) -> np.ndarray:
        """The native ``mdl_partition`` as an integer NumPy array."""
        return np.asarray(self._native_result.mdl_partition, dtype=int)

    @property
    def summary(self) -> dict:
        """A shallow ``dict`` copy of the native ``summary`` mapping."""
        return dict(self._native_result.summary)

    def samples_df(self):
        """Return the native ``samples_df()`` result unchanged."""
        return self._native_result.samples_df()

    def consensus_partition(self) -> np.ndarray:
        """The native consensus partition as an integer NumPy array."""
        return np.asarray(self._native_result.consensus_partition(), dtype=int)

    def coincidence_matrix(self) -> np.ndarray:
        """The native coincidence matrix as a float NumPy array."""
        return np.asarray(self._native_result.coincidence_matrix(), dtype=float)

    def log_posterior_predictive(self, G_test: nx.Graph, f: float | None = None) -> float:
        """Score the edges of ``G_test`` with the native posterior predictive.

        Parameters
        ----------
        G_test
            Graph over exactly the same node set as the training graph. Edge
            weights are read from the ``"weight"`` attribute (default 1.0).
        f
            Forwarded unchanged to the native ``log_posterior_predictive``;
            its meaning is defined by the backend.

        Returns
        -------
        float

        Raises
        ------
        ValueError
            If the node set of ``G_test`` differs from the training node set;
            the message lists the missing and extra nodes.
        """
        train_nodes = set(self._node_order)
        test_nodes = set(G_test.nodes())
        if test_nodes != train_nodes:
            missing = self._sorted_for_message(train_nodes - test_nodes)
            extra = self._sorted_for_message(test_nodes - train_nodes)
            raise ValueError(f"Test graph node set must match the training graph. Missing={missing}, extra={extra}")

        # Map labels to the same integer indices that were used during fitting.
        index_of = {node: i for i, node in enumerate(self._node_order)}
        test_edges = [
            (index_of[u], index_of[v], float(data.get("weight", 1.0)))
            for u, v, data in G_test.edges(data=True)
        ]
        return float(self._native_result.log_posterior_predictive(test_edges, f))

    @staticmethod
    def _sorted_for_message(nodes: set) -> list:
        """Sort node labels for display, tolerating labels of mixed types.

        Plain ``sorted`` raises ``TypeError`` for e.g. ``{1, "a"}``, which would
        mask the intended ``ValueError``; fall back to ordering by ``repr``.
        """
        try:
            return sorted(nodes)
        except TypeError:
            return sorted(nodes, key=repr)
96
+
97
+
98
def _require_backend() -> None:
    """Ensure the compiled ``_asbm`` extension was imported successfully.

    Raises
    ------
    ImportError
        When the module-level import of the extension failed. The original
        import error (if any) is chained as ``__cause__`` so its reason shows
        up in the traceback.
    """
    if fit_backend is not None:
        return
    message = "Native backend unavailable. Build the extension module first."
    raise ImportError(message) from _BACKEND_IMPORT_ERROR
103
+
104
+
105
def fit(
    config: Config,
    G: nx.Graph,
    *,
    sweeps: int = 100,
    burn_sweeps: int = 100,
    num_chains: int = 2,
    samples_per_chain: int = 1,
    num_samples: int | None = None,
    thinning: int = 1,
    seed: int = 0,
    show_progress: bool = False,
) -> Result:
    """Fit the model described by ``config`` to ``G`` using the native backend.

    Nodes are numbered ``0..n-1`` in sorted label order. If the labels cannot
    be ordered (e.g. a mix of ``int`` and ``str``), ``G``'s own node iteration
    order is used instead. The chosen order is stored on the returned
    ``Result``. Edge weights come from the ``"weight"`` edge attribute
    (default 1.0).

    NOTE(review): ``nx.DiGraph``/``nx.MultiGraph`` are subclasses of
    ``nx.Graph`` and are accepted; their edges are passed as-is. Whether the
    backend treats edge pairs as directed is not visible here -- confirm.

    Parameters
    ----------
    config
        Model name and degree-correction flag.
    G
        Graph to fit.
    sweeps, burn_sweeps, num_chains, samples_per_chain, thinning
        Sampler sizes, forwarded to the backend after validation. ``sweeps``,
        ``num_chains``, ``samples_per_chain`` and ``thinning`` must be
        positive; ``burn_sweeps`` must be non-negative.
    num_samples
        Convenience override. When given it must be a positive whole number;
        it forces ``num_chains=1`` and ``samples_per_chain=num_samples``,
        replacing whatever was passed for those two.
    seed, show_progress
        Forwarded unchanged to the backend.

    Returns
    -------
    Result

    Raises
    ------
    ImportError
        If the native extension could not be imported.
    TypeError
        If ``config`` is not a ``Config`` or ``G`` is not a NetworkX graph.
    ValueError
        If a sampler size is out of range or ``num_samples`` is not whole.
    """
    _require_backend()

    if not isinstance(config, Config):
        raise TypeError("config must be an instance of Config")
    if not isinstance(G, nx.Graph):
        raise TypeError("G must be a networkx.Graph")

    if num_samples is not None:
        if num_samples <= 0:
            raise ValueError("num_samples must be positive")
        # int() truncates. Without this check 2.5 would silently become 2, and
        # 0.5 would pass the positivity test only to fail as "samples_per_chain".
        if int(num_samples) != num_samples:
            raise ValueError("num_samples must be a whole number")
        num_chains = 1
        samples_per_chain = int(num_samples)

    if sweeps <= 0:
        raise ValueError("sweeps must be positive")
    if burn_sweeps < 0:
        raise ValueError("burn_sweeps must be non-negative")
    if num_chains <= 0:
        raise ValueError("num_chains must be positive")
    if samples_per_chain <= 0:
        raise ValueError("samples_per_chain must be positive")
    if thinning <= 0:
        raise ValueError("thinning must be positive")

    try:
        nodes = sorted(G.nodes())
    except TypeError:
        # sorted() cannot compare labels of different types (e.g. 1 and "a");
        # use the graph's own iteration order instead of failing.
        nodes = list(G.nodes())
    index_of = {node: i for i, node in enumerate(nodes)}
    edges = [
        (index_of[u], index_of[v], float(data.get("weight", 1.0)))
        for u, v, data in G.edges(data=True)
    ]

    native_result = fit_backend(
        num_nodes=len(nodes),
        edges=edges,
        model=config.model,
        degree_correction=config.degree_correction,
        sweeps=sweeps,
        burn_sweeps=burn_sweeps,
        num_chains=num_chains,
        samples_per_chain=samples_per_chain,
        thinning=thinning,
        seed=seed,
        show_progress=show_progress,
    )
    return Result(native_result, tuple(nodes))
asbm/input_output.py ADDED
@@ -0,0 +1,316 @@
1
+ # Handling input and output of data
2
+
3
+ import networkx as nx
4
+ import pandas as pd
5
+ from typing import Optional, Iterable
6
+ import warnings
7
+
8
+ # Sentinel for unset parameters
9
+ _UNSET = object()
10
+
11
+
12
+ def _canonical_model_name(model_name: str) -> str:
13
+ """
14
+ Validate that `model_name` is one of the canonical model identifiers.
15
+
16
+ This function only accepts canonical names (case-insensitive) and returns
17
+ the canonical-cased string. Using legacy aliases is not permitted; callers
18
+ should pass canonical model names such as "simple_ASBM" or
19
+ "general_ASBM".
20
+ """
21
+ if not isinstance(model_name, str):
22
+ raise TypeError("model_name must be a string containing a canonical model name")
23
+ name = model_name.strip()
24
+ canonical = {
25
+ "general_canonical": "general_canonical",
26
+ "general_unified": "general_unified",
27
+ "general_ASBM": "general_ASBM",
28
+ "simple_ASBM": "simple_ASBM",
29
+ "hybrid_ASBM": "hybrid_ASBM",
30
+ "traditional_SBM": "traditional_SBM",
31
+ "traditional_DCSBM": "traditional_DCSBM",
32
+ "traditional_GDCSBM": "traditional_GDCSBM",
33
+ "microcanonical_SBM": "microcanonical_SBM",
34
+ "microcanonical_DCSBM": "microcanonical_DCSBM",
35
+ "planted_partition": "planted_partition",
36
+ }
37
+ lower_map = {k.lower(): v for k, v in canonical.items()}
38
+ if name.lower() in lower_map:
39
+ return lower_map[name.lower()]
40
+ raise ValueError(
41
+ "model_name must be one of the canonical names: " + ", ".join(sorted(canonical.values())) +
42
+ f". Got '{model_name}'."
43
+ )
44
+
45
+
46
def _check_input_validity(
    G: nx.Graph,
    node_groups: Optional[pd.DataFrame] = None,
    *,
    model_name: str = "general_canonical",
    assortative: bool = _UNSET,
    mixing_variation: str = _UNSET,
    variation_in: Optional[float] = _UNSET,
    variation_out: Optional[float] = _UNSET,
    degree_correction: Optional[float] = _UNSET,
    num_groups: Optional[int] = _UNSET,
    initial_partition: Optional[Iterable[int]] = None,
    seed: Optional[int] = None,
    num_samples: int = 1000,
    beta: float = 1.0,
    num_tempering_chains: int = 1,
    no_cache: bool = False,
    timeout: float = 60.0,
    verbose: bool = False,
) -> None:
    """Validate public API inputs and raise informative errors for invalid values.

    Only lightweight type, range and consistency checks are performed (no
    graph processing). Parameters still equal to the ``_UNSET`` sentinel are
    replaced by their defaults before checking. Integer-valued parameters
    accept any ``numbers.Integral`` (e.g. NumPy integers), and real-valued
    ones accept any ``numbers.Real``.

    Parameters
    ----------
    G
        Graph the model will be fitted to; must be a NetworkX graph.
    node_groups
        Accepted for signature compatibility; not checked here.
    model_name
        Must be a canonical model name (see ``_canonical_model_name``).
    variation_in, variation_out, degree_correction
        ``None`` or a number in ``[0, 1]``.
    num_groups
        ``None`` or a positive integer no larger than the number of nodes.
    initial_partition
        ``None`` or one non-negative integer label per node. Labels must be
        contiguous (``0..max``) and consistent with ``num_groups``.

    Raises
    ------
    TypeError
        For arguments of the wrong type (graph, booleans, seed, variation or
        degree parameters, non-iterable ``initial_partition``, non-string
        ``model_name``).
    ValueError
        For out-of-range or inconsistent values, including NaN where a
        number is required.
    """
    import numbers  # local: only needed for the ABC-based numeric type checks

    # Graph checks
    if not isinstance(G, nx.Graph):
        raise TypeError("G must be a NetworkX Graph object.")

    # model_name (raises TypeError/ValueError on invalid input)
    model_name = _canonical_model_name(model_name)

    # Fix the unset parameters to their defaults (to pass the type checks below)
    if assortative is _UNSET:
        assortative = True
    if mixing_variation is _UNSET:
        mixing_variation = "general"
    if variation_in is _UNSET:
        variation_in = None
    if variation_out is _UNSET:
        variation_out = None
    if degree_correction is _UNSET:
        degree_correction = None
    # mean_degree_scaling has been fixed to 0.0 and is no longer configurable
    if num_groups is _UNSET:
        num_groups = None

    if not isinstance(assortative, bool):
        raise TypeError("assortative must be a boolean.")

    valid_mixing = ["simple", "none", "internal", "external", "general"]
    if mixing_variation not in valid_mixing:
        raise ValueError(f"mixing_variation must be one of {valid_mixing}")

    # variation_in/out: None or 0 <= x <= 1. The chained comparison is False
    # for NaN, so NaN is rejected as out of range.
    for name, val in (("variation_in", variation_in), ("variation_out", variation_out)):
        if val is None:
            continue
        if not isinstance(val, numbers.Real):
            raise TypeError(f"{name} must be a float in [0, 1] or None.")
        if not (0.0 <= float(val) <= 1.0):
            raise ValueError(f"{name} must be between 0 and 1 (inclusive). Got {val}.")

    # degree_correction: None or 0 <= x <= 1
    if degree_correction is not None:
        if not isinstance(degree_correction, numbers.Real):
            raise TypeError("degree_correction must be a float in [0, 1] or None.")
        if not (0.0 <= float(degree_correction) <= 1.0):
            raise ValueError("degree_correction must be between 0 and 1 (inclusive).")

    num_nodes = G.number_of_nodes()

    # num_groups must be a positive integer if provided and not exceed the node count
    if num_groups is not None:
        if not isinstance(num_groups, numbers.Integral) or num_groups <= 0:
            raise ValueError("num_groups must be a positive integer or None.")
        if num_groups > num_nodes:
            raise ValueError("num_groups cannot be greater than the number of nodes in G.")

    # initial_partition: one non-negative integer label per node
    if initial_partition is not None:
        try:
            ip_list = list(initial_partition)
        except TypeError as err:
            raise TypeError("initial_partition must be an iterable of integers or None.") from err
        if len(ip_list) != num_nodes:
            raise ValueError("initial_partition length must match number of nodes in G.")
        for label in ip_list:
            if not isinstance(label, numbers.Integral) or label < 0:
                raise ValueError("initial_partition must contain non-negative integer labels.")
        # An empty partition (only possible for an empty graph) has no labels
        # to check, and max() would raise on it.
        if ip_list:
            max_label = max(ip_list)
            if num_groups is not None and max_label >= num_groups:
                raise ValueError("initial_partition contains labels >= num_groups.")
            # Labels must be contiguous: every group 0..max_label needs a member.
            unique_groups = set(ip_list)
            if len(unique_groups) < max_label + 1:
                raise ValueError("initial_partition does not contain members for every group.")
            if num_groups is not None and len(unique_groups) < num_groups:
                raise ValueError("initial_partition has a different number of groups than num_groups.")

    if seed is not None and not isinstance(seed, numbers.Integral):
        raise TypeError("seed must be an integer or None.")

    if not isinstance(num_samples, numbers.Integral) or num_samples <= 0:
        raise ValueError("num_samples must be a positive integer.")

    # ``not (x >= 0)`` rather than ``x < 0`` so that NaN is rejected too.
    if not isinstance(beta, numbers.Real) or not (float(beta) >= 0.0):
        raise ValueError("beta must be a non-negative number.")

    if not isinstance(num_tempering_chains, numbers.Integral) or num_tempering_chains <= 0:
        raise ValueError("num_tempering_chains must be a positive integer.")

    if not isinstance(no_cache, bool):
        raise TypeError("no_cache must be a boolean.")

    # ``not (x > 0)`` rather than ``x <= 0`` so that NaN is rejected too.
    if not isinstance(timeout, numbers.Real) or not (float(timeout) > 0.0):
        raise ValueError("timeout must be a positive number (seconds).")

    if not isinstance(verbose, bool):
        raise TypeError("verbose must be a boolean.")

    # Passes all checks
    return None
180
+
181
def _set_model_parameters(model_name: str,
                          assortative: bool = _UNSET,
                          mixing_variation: str = _UNSET,
                          variation_in: Optional[float] = _UNSET,
                          variation_out: Optional[float] = _UNSET,
                          degree_correction: Optional[float] = _UNSET,
                          num_groups: Optional[int] = _UNSET) -> tuple:
    """
    Resolve a named model preset into concrete parameter values.

    Resolution happens in three steps:

    1. Each model implies defaults for some parameters. These are applied
       only to arguments still equal to the ``_UNSET`` sentinel, so values
       passed explicitly always win.
    2. The resulting ``mixing_variation`` then implies defaults for
       ``variation_in``/``variation_out``.
    3. Anything still unset receives a global default: ``assortative=True``,
       ``mixing_variation='general'``, and ``None`` for the rest.

    Returns
    -------
    tuple
        ``(assortative, mixing_variation, variation_in, variation_out,
        degree_correction, num_groups)``.

    Raises
    ------
    TypeError, ValueError
        Propagated from ``_canonical_model_name`` for an invalid model name.
    """
    model_name = _canonical_model_name(model_name)

    # Model-implied defaults. Models missing from this table
    # (general_canonical, general_unified) imply only the global defaults.
    # mean_degree_scaling is fixed to 0.0 and not exposed for any model.
    model_presets = {
        "traditional_SBM": {"assortative": False, "degree_correction": 0.0},
        "traditional_DCSBM": {"assortative": False, "degree_correction": 0.5},
        "traditional_GDCSBM": {"assortative": False},
        "simple_ASBM": {"degree_correction": 0.0},
        "hybrid_ASBM": {"mixing_variation": "internal", "degree_correction": 0.0},
        "planted_partition": {"mixing_variation": "none", "degree_correction": 0.0},
        "general_ASBM": {"mixing_variation": "general", "degree_correction": 0.0},
        "microcanonical_SBM": {"assortative": False, "degree_correction": 0.0},
        "microcanonical_DCSBM": {"assortative": False, "degree_correction": 0.5},
    }
    preset = model_presets.get(model_name, {})

    def _default(value, fallback):
        # Keep explicitly provided values; replace only the sentinel.
        return fallback if value is _UNSET else value

    assortative = _default(assortative, preset.get("assortative", True))
    mixing_variation = _default(mixing_variation, preset.get("mixing_variation", "general"))
    degree_correction = _default(degree_correction, preset.get("degree_correction", None))
    num_groups = _default(num_groups, None)

    # (variation_in, variation_out) implied by each mixing_variation. For
    # "general" both are left to be inferred (None). Values are compared with
    # == in this order, so an unrecognized mixing_variation matches nothing.
    mixing_presets = (
        ("general", (None, None)),
        ("simple", (0.5, 0.5)),
        ("none", (0.0, 0.0)),
        ("internal", (0.5, 0.0)),
        ("external", (0.0, 0.5)),
    )
    implied_in, implied_out = None, None
    for mode, pair in mixing_presets:
        if mixing_variation == mode:
            implied_in, implied_out = pair
            break

    variation_in = _default(variation_in, implied_in)
    variation_out = _default(variation_out, implied_out)

    return (assortative, mixing_variation, variation_in, variation_out, degree_correction, num_groups)
283
+
284
+ # Plot the groups on the network
285
+ # Removed unused plotting and GML helper functions (plot_groups, add_node_group_attribute,
286
+ # write_gml). These helpers were unused and partially implemented; callers should
287
+ # perform plotting or file writing externally if needed.
288
+
289
+
290
def get_neighbors(G: nx.Graph) -> Iterable[dict]:
    """
    Build a per-node weighted adjacency structure for the MCMC algorithm.

    Nodes are numbered by their position in ``G.nodes`` iteration order.
    ``result[i]`` maps each neighbour index ``j`` of node ``i`` to the summed
    weight of all edges between them (``"weight"`` attribute, default 1.0).
    Every non-loop edge is recorded under both endpoints. A self-loop on
    node ``i`` adds twice its weight to ``result[i][i]``.

    Parameters
    ----------
    G : nx.Graph
        NetworkX graph to get neighbors for.

    Returns
    -------
    list of dict
        One ``{neighbour_index: weight}`` dict per node.
    """
    position = {}
    for idx, node in enumerate(G.nodes):
        position[node] = idx

    adjacency: list[dict[int, float]] = [{} for _ in range(G.number_of_nodes())]

    def _accumulate(row: int, col: int, amount: float) -> None:
        bucket = adjacency[row]
        bucket[col] = bucket.get(col, 0.0) + amount

    for src, dst, attrs in G.edges(data=True):
        i, j = position[src], position[dst]
        w = float(attrs.get("weight", 1.0))
        if i != j:
            _accumulate(i, j, w)
            _accumulate(j, i, w)
        else:
            # Self-loop convention: store 2 * weight (one count per endpoint).
            _accumulate(i, i, 2.0 * w)

    return adjacency
asbm/py.typed ADDED
File without changes
@@ -0,0 +1,173 @@
1
+ Metadata-Version: 2.4
2
+ Name: asbm
3
+ Version: 0.1.4.dev3
4
+ Summary: Find group structures in networks
5
+ Author-Email: Maximilian Jerdee <mjerdee@umich.edu>
6
+ License-Expression: MIT
7
+ License-File: LICENSE
8
+ Classifier: Development Status :: 1 - Planning
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3 :: Only
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Programming Language :: Python :: 3.14
20
+ Classifier: Topic :: Scientific/Engineering
21
+ Classifier: Typing :: Typed
22
+ Project-URL: Homepage, https://github.com/maxjerdee/asbm
23
+ Project-URL: Bug Tracker, https://github.com/maxjerdee/asbm/issues
24
+ Project-URL: Discussions, https://github.com/maxjerdee/asbm/discussions
25
+ Project-URL: Changelog, https://github.com/maxjerdee/asbm/releases
26
+ Requires-Python: >=3.10
27
+ Requires-Dist: matplotlib>=3.8.0
28
+ Requires-Dist: numpy>=1.26.0
29
+ Requires-Dist: pandas>=2.2.0
30
+ Requires-Dist: networkx>=3.0
31
+ Requires-Dist: tqdm>=4.65.0
32
+ Description-Content-Type: text/markdown
33
+
34
+ # asbm
35
+
36
+ [![Documentation Status][rtd-badge]][rtd-link]
37
+ [![PyPI version][pypi-version]][pypi-link]
38
+ [![PyPI platforms][pypi-platforms]][pypi-link]
39
+
40
+ <!-- SPHINX-START -->
41
+
42
+ <!-- prettier-ignore-start -->
43
+ [actions-badge]: https://github.com/maxjerdee/asbm/workflows/CI/badge.svg
44
+ [actions-link]: https://github.com/maxjerdee/asbm/actions
45
+ [conda-badge]: https://img.shields.io/conda/vn/conda-forge/asbm
46
+ [conda-link]: https://github.com/conda-forge/asbm-feedstock
47
+ [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
48
+ [github-discussions-link]: https://github.com/maxjerdee/asbm/discussions
49
+ [paper-link]: https://arxiv.org/abs/TODO
50
+ [pypi-link]: https://pypi.org/project/asbm/
51
+ [pypi-platforms]: https://img.shields.io/pypi/pyversions/asbm
52
+ [pypi-version]: https://img.shields.io/pypi/v/asbm
53
+ [rtd-badge]: https://readthedocs.org/projects/asbm/badge/?version=latest
54
+ [rtd-link]: https://asbm.readthedocs.io/en/latest/?badge=latest
55
+
56
+ <!-- prettier-ignore-end -->
57
+
58
+ ## Infer network groups and properties with assortative stochastic block models
59
+
60
+ *Maximilian Jerdee*
61
+
62
+ This Python package uses Bayesian inference to find meaningful groupings of nodes in networks.
63
+
64
+ We implement a [general assortative SBM][paper-link] that unifies the standard SBM and the planted partition model under a common framework. Its parameters directly measure the violation of each model's assumptions: the assortative preference ρ_in/ρ_out captures departure from in/out symmetry, while per-group variation coefficients v_in and v_out measure heterogeneity across groups. The standard SBM, planted partition model, and the Zhang–Peixoto hybrid emerge as special cases, enabling exact Bayesian model comparison between them.
65
+
66
+ For each model, the package includes algorithms to:
67
+ - Find consensus estimates of the group structure
68
+ - Infer global network parameters (assortativity, group sizes)
69
+ - Score held-out edges via posterior predictive likelihood
70
+
71
+ ## Installation
72
+
73
+ Implementations are available for Python, R, and Julia.
74
+
75
+ ### Python
76
+
77
+ ```bash
78
+ pip install asbm
79
+ ```
80
+
81
+ Or build locally from the repository root:
82
+
83
+ ```bash
84
+ pip install .
85
+ ```
86
+
87
+ ### R
88
+
89
+ ```r
90
+ install.packages("asbm", repos = "https://maxjerdee.r-universe.dev")
91
+ ```
92
+
93
+ ### Julia
94
+
95
+ ```julia
96
+ using Pkg
97
+ Pkg.add(url="https://github.com/maxjerdee/asbm", subdir="bindings/julia")
98
+ ```
99
+
100
+ Building from source requires CMake and a C++17 compiler. Run `Pkg.build("ASBM")` after installation to compile the native library.
101
+
102
+ ## Quickstart
103
+
104
+ ### Python
105
+
106
+ ```python
107
+ import asbm
108
+ import networkx as nx
109
+
110
+ G = nx.read_gml("examples/data/dolphins.gml", label="id")
111
+
112
+ config = asbm.Config(
113
+ model="general_asbm",
114
+ degree_correction=True,
115
+ )
116
+
117
+ result = asbm.fit(config, G)
118
+
119
+ print(result.mdl_partition)
120
+ print(result.mdl_value)
121
+ print(result.consensus_partition())
122
+ ```
123
+
124
+ To score held-out edges (the test graph must contain exactly the same set of nodes as the training graph):
125
+
126
+ ```python
127
+ G_train = nx.read_gml("examples/data/train.gml")
128
+ G_test = nx.read_gml("examples/data/test.gml")
129
+
130
+ result = asbm.fit(asbm.Config(model="general_asbm"), G_train)
131
+ score = result.log_posterior_predictive(G_test)
132
+ ```
133
+
134
+ ### R
135
+
136
+ ```r
137
+ library(asbm)
138
+ library(igraph)
139
+
140
+ G <- read_graph("examples/data/dolphins.gml", format = "gml")
141
+
142
+ result <- fit(G,
143
+ model = "general_asbm",
144
+ degree_correction = TRUE,
145
+ num_chains = 4,
146
+ samples_per_chain = 100,
147
+ seed = 42)
148
+
149
+ print(result$mdl_value)
150
+ print(result$mdl_partition)
151
+ print(result$consensus_partition)
152
+ ```
153
+
154
+ ### Julia
155
+
156
+ ```julia
157
+ using ASBM, Graphs
158
+
159
+ g = cycle_graph(62) # or load via GraphIO
160
+
161
+ result = fit(g;
162
+ model = "general_asbm",
163
+ degree_correction = true,
164
+ num_chains = 4,
165
+ samples_per_chain = 100,
166
+ seed = 42)
167
+
168
+ println(result.mdl_value)
169
+ println(result.mdl_partition)
170
+ println(result.consensus_partition)
171
+ ```
172
+
173
+ For the full API, including posterior predictive evaluation and the samples schema, see the [package documentation][rtd-link].
@@ -0,0 +1,10 @@
1
+ asbm/__init__.py,sha256=Mhl0eIOXCjR3wWpYwKTu2lVZHqkyBYNDg047JB-Um2U,313
2
+ asbm/_asbm.cp313-win_amd64.pyd,sha256=13xenevjUZbvaLYu3Przwzk6cxgphUt362i4JmFSZhc,413184
3
+ asbm/_version.py,sha256=r5MqpXVcUNROsObvaJ5rDQq1GnaqEXbreKXRuNHxjqU,565
4
+ asbm/core.py,sha256=qvss9eZs9IfcinHw-_9sx4n61z5eysRL7p_3BzRltCc,5246
5
+ asbm/input_output.py,sha256=BEmCLFmaVR1fh4uAtWpVMrB7yyFBZWmYzCDkj9AwIuA,12814
6
+ asbm/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ asbm-0.1.4.dev3.dist-info/METADATA,sha256=sQx6C18Sgc0l4IZYa0JNDmYvLGh2QJjnbAEsyVWN1Sc,5436
8
+ asbm-0.1.4.dev3.dist-info/WHEEL,sha256=UZrbbE4r80xj7Ncfa6JoeTVe-77bdXLkKUA63V8pKWQ,106
9
+ asbm-0.1.4.dev3.dist-info/licenses/LICENSE,sha256=nXOCnfm201QiVNd89JJAWmzk6cFctyYaVecsfK4kXk8,1076
10
+ asbm-0.1.4.dev3.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: scikit-build-core 0.12.2
3
+ Root-Is-Purelib: false
4
+ Tag: cp313-cp313-win_amd64
5
+
@@ -0,0 +1,19 @@
1
+ Copyright 2025 Maximilian Jerdee
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ this software and associated documentation files (the "Software"), to deal in
5
+ the Software without restriction, including without limitation the rights to
6
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7
+ of the Software, and to permit persons to whom the Software is furnished to do
8
+ so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.