asbm 0.1.4.dev3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asbm/__init__.py +18 -0
- asbm/_asbm.cp313-win_amd64.pyd +0 -0
- asbm/_version.py +24 -0
- asbm/core.py +166 -0
- asbm/input_output.py +316 -0
- asbm/py.typed +0 -0
- asbm-0.1.4.dev3.dist-info/METADATA +173 -0
- asbm-0.1.4.dev3.dist-info/RECORD +10 -0
- asbm-0.1.4.dev3.dist-info/WHEEL +5 -0
- asbm-0.1.4.dev3.dist-info/licenses/LICENSE +19 -0
asbm/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 Max Jerdee. All rights reserved.
|
|
3
|
+
|
|
4
|
+
asbm: Finding group structure in networks
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from ._version import version as __version__
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"__version__",
|
|
13
|
+
"Config",
|
|
14
|
+
"Result",
|
|
15
|
+
"fit",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
from .core import Config, Result, fit
|
|
Binary file
|
asbm/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"__version__",
|
|
7
|
+
"__version_tuple__",
|
|
8
|
+
"version",
|
|
9
|
+
"version_tuple",
|
|
10
|
+
"__commit_id__",
|
|
11
|
+
"commit_id",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
version: str
|
|
15
|
+
__version__: str
|
|
16
|
+
__version_tuple__: tuple[int | str, ...]
|
|
17
|
+
version_tuple: tuple[int | str, ...]
|
|
18
|
+
commit_id: str | None
|
|
19
|
+
__commit_id__: str | None
|
|
20
|
+
|
|
21
|
+
__version__ = version = '0.1.4.dev3'
|
|
22
|
+
__version_tuple__ = version_tuple = (0, 1, 4, 'dev3')
|
|
23
|
+
|
|
24
|
+
__commit_id__ = commit_id = 'gd2cbac3fd'
|
asbm/core.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, get_args

import numpy as np
import networkx as nx
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from ._asbm import FitResult as _NativeFitResult
|
|
11
|
+
from ._asbm import fit_backend
|
|
12
|
+
except ImportError as e: # pragma: no cover
|
|
13
|
+
_NativeFitResult = None # type: ignore[assignment]
|
|
14
|
+
fit_backend = None # type: ignore[assignment]
|
|
15
|
+
_BACKEND_IMPORT_ERROR = e
|
|
16
|
+
else:
|
|
17
|
+
_BACKEND_IMPORT_ERROR = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Every model identifier accepted by ``Config``. This Literal is the single
# source of truth: the runtime whitelist below is derived from it, so the
# static type and the runtime check cannot drift apart.
ModelName = Literal[
    "asbm", "general_canonical", "general_unified",
    "simple_asbm", "simple_ASBM",
    "general_asbm", "general_ASBM",
    "hybrid_ASBM",
    "planted_partition",
    "traditional_SBM", "traditional_DCSBM", "traditional_GDCSBM",
    "microcanonical_SBM", "microcanonical_DCSBM",
]

_ALLOWED_MODELS: frozenset[str] = frozenset(get_args(ModelName))


@dataclass(frozen=True)
class Config:
    """Immutable model selection passed to ``fit``.

    Attributes are validated once, at construction time.

    Raises
    ------
    ValueError
        If ``model`` is not one of the identifiers listed in ``ModelName``.
    """

    # Identifier of the model to fit; must be a member of ``ModelName``.
    model: ModelName = "asbm"
    # Forwarded unchanged to the backend as its ``degree_correction`` argument.
    degree_correction: bool = True

    def __post_init__(self) -> None:
        # The isinstance guard keeps unhashable values (e.g. a list) from
        # escaping as a TypeError out of the frozenset membership test.
        if not isinstance(self.model, str) or self.model not in _ALLOWED_MODELS:
            raise ValueError(
                "model must be one of: " + ", ".join(sorted(_ALLOWED_MODELS))
            )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Result:
    """Python-facing wrapper around the result object returned by the backend.

    Besides the backend object, it stores the node order used when the graph
    was converted to integer indices, so that a test graph can be translated
    to the same indexing in ``log_posterior_predictive``.
    """

    def __init__(self, native_result: _NativeFitResult, node_order: tuple[object, ...]) -> None:
        self._native_result = native_result
        self._node_order = node_order

    @property
    def mdl_value(self) -> float:
        """The backend's ``mdl_value`` converted to a Python float."""
        return float(self._native_result.mdl_value)

    @property
    def mdl_partition(self) -> np.ndarray:
        """The backend's ``mdl_partition`` as an integer NumPy array."""
        return np.asarray(self._native_result.mdl_partition, dtype=int)

    @property
    def summary(self) -> dict:
        """A fresh ``dict`` copy of the backend's ``summary`` mapping."""
        return dict(self._native_result.summary)

    def samples_df(self):
        """Return whatever the backend's ``samples_df()`` produces, unchanged."""
        return self._native_result.samples_df()

    def consensus_partition(self) -> np.ndarray:
        """The backend's ``consensus_partition()`` as an integer NumPy array."""
        return np.asarray(self._native_result.consensus_partition(), dtype=int)

    def coincidence_matrix(self) -> np.ndarray:
        """The backend's ``coincidence_matrix()`` as a float NumPy array."""
        return np.asarray(self._native_result.coincidence_matrix(), dtype=float)

    @staticmethod
    def _sorted_for_message(nodes: set) -> list:
        """Order node labels for an error message without ever raising.

        Labels of mixed types (e.g. ``1`` and ``"x"``) cannot be compared, so
        fall back to ordering by ``repr`` instead of letting ``sorted`` fail.
        """
        try:
            return sorted(nodes)
        except TypeError:
            return sorted(nodes, key=repr)

    def log_posterior_predictive(self, G_test: nx.Graph, f: float | None = None) -> float:
        """Score the edges of ``G_test`` with the backend's posterior predictive.

        Parameters
        ----------
        G_test
            Graph whose node set must equal the node set used for fitting.
            Each edge is sent to the backend as ``(index_u, index_v, weight)``
            where a missing ``"weight"`` attribute counts as ``1.0``.
        f
            Forwarded unchanged to the backend (meaning defined by the
            backend; not visible from this module).

        Raises
        ------
        ValueError
            If the node sets of the test and training graphs differ.
        """
        train_nodes = set(self._node_order)
        test_nodes = set(G_test.nodes())
        if test_nodes != train_nodes:
            # Use the non-raising sort so that unorderable labels cannot mask
            # this ValueError with a TypeError.
            missing = self._sorted_for_message(train_nodes - test_nodes)
            extra = self._sorted_for_message(test_nodes - train_nodes)
            raise ValueError(f"Test graph node set must match the training graph. Missing={missing}, extra={extra}")

        # Built only after validation: every edge endpoint is now known to be
        # a training node, so the lookups below cannot fail.
        index_of = {node: i for i, node in enumerate(self._node_order)}
        test_edges = [
            (
                index_of[u],
                index_of[v],
                float(data.get("weight", 1.0)),
            )
            for u, v, data in G_test.edges(data=True)
        ]
        return float(self._native_result.log_posterior_predictive(test_edges, f))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _require_backend() -> None:
    """Raise ``ImportError`` unless the compiled backend was imported.

    The exception captured when the import failed is attached as the cause.
    """
    if fit_backend is not None:
        return
    message = "Native backend unavailable. Build the extension module first."
    raise ImportError(message) from _BACKEND_IMPORT_ERROR
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def fit(
    config: Config,
    G: nx.Graph,
    *,
    sweeps: int = 100,
    burn_sweeps: int = 100,
    num_chains: int = 2,
    samples_per_chain: int = 1,
    num_samples: int | None = None,
    thinning: int = 1,
    seed: int = 0,
    show_progress: bool = False,
) -> Result:
    """Run the native backend on ``G`` and wrap its output in a ``Result``.

    Nodes are sorted and numbered in that order; each edge is handed to the
    backend as ``(index_u, index_v, weight)`` with a missing ``"weight"``
    attribute treated as ``1.0``.

    When ``num_samples`` is given it overrides ``num_chains`` (set to 1) and
    ``samples_per_chain`` (set to ``int(num_samples)``). All remaining keyword
    arguments are forwarded to the backend unchanged.

    Raises
    ------
    ImportError
        If the native backend could not be imported.
    TypeError
        If ``config`` is not a ``Config`` or ``G`` is not a ``networkx.Graph``.
    ValueError
        If any of the numeric settings is out of range.
    """
    _require_backend()

    if not isinstance(config, Config):
        raise TypeError("config must be an instance of Config")
    if not isinstance(G, nx.Graph):
        raise TypeError("G must be a networkx.Graph")

    if num_samples is not None:
        if num_samples <= 0:
            raise ValueError("num_samples must be positive")
        num_chains, samples_per_chain = 1, int(num_samples)

    # (value, zero_allowed, message), checked lazily and in this order so the
    # first offending setting is the one reported.
    range_checks = (
        (sweeps, False, "sweeps must be positive"),
        (burn_sweeps, True, "burn_sweeps must be non-negative"),
        (num_chains, False, "num_chains must be positive"),
        (samples_per_chain, False, "samples_per_chain must be positive"),
        (thinning, False, "thinning must be positive"),
    )
    for value, zero_allowed, complaint in range_checks:
        if zero_allowed:
            if value < 0:
                raise ValueError(complaint)
        elif value <= 0:
            raise ValueError(complaint)

    ordered_nodes = list(G.nodes())
    ordered_nodes.sort()
    position = {label: idx for idx, label in enumerate(ordered_nodes)}

    edge_triples = []
    for src, dst, attrs in G.edges(data=True):
        weight = float(attrs.get("weight", 1.0))
        edge_triples.append((position[src], position[dst], weight))

    backend_output = fit_backend(
        num_nodes=len(ordered_nodes),
        edges=edge_triples,
        model=config.model,
        degree_correction=config.degree_correction,
        sweeps=sweeps,
        burn_sweeps=burn_sweeps,
        num_chains=num_chains,
        samples_per_chain=samples_per_chain,
        thinning=thinning,
        seed=seed,
        show_progress=show_progress,
    )
    return Result(backend_output, tuple(ordered_nodes))
|
asbm/input_output.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# Handling input and output of data
|
|
2
|
+
|
|
3
|
+
import networkx as nx
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from typing import Optional, Iterable
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
# Sentinel for unset parameters
|
|
9
|
+
_UNSET = object()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _canonical_model_name(model_name: str) -> str:
    """Return the canonical spelling of ``model_name``.

    Surrounding whitespace is ignored and matching is case-insensitive, so
    ``" simple_asbm "`` resolves to ``"simple_ASBM"``. Only the canonical
    identifiers listed below are recognised; anything else is rejected.

    Raises
    ------
    TypeError
        If ``model_name`` is not a string.
    ValueError
        If ``model_name`` does not match any canonical identifier.
    """
    if not isinstance(model_name, str):
        raise TypeError("model_name must be a string containing a canonical model name")

    canonical_names = (
        "general_canonical",
        "general_unified",
        "general_ASBM",
        "simple_ASBM",
        "hybrid_ASBM",
        "traditional_SBM",
        "traditional_DCSBM",
        "traditional_GDCSBM",
        "microcanonical_SBM",
        "microcanonical_DCSBM",
        "planted_partition",
    )
    by_lowercase = {name.lower(): name for name in canonical_names}

    # Normalise once and look up once.
    match = by_lowercase.get(model_name.strip().lower())
    if match is not None:
        return match

    raise ValueError(
        "model_name must be one of the canonical names: " + ", ".join(sorted(canonical_names)) +
        f". Got '{model_name}'."
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Check that the arguments passed to the function are valid, if not raise an error
|
|
47
|
+
def _check_input_validity(
    G: nx.Graph,
    node_groups: Optional[pd.DataFrame] = None,
    *,
    model_name: str = "general_canonical",
    assortative: bool = _UNSET,
    mixing_variation: str = _UNSET,
    variation_in: Optional[float] = _UNSET,
    variation_out: Optional[float] = _UNSET,
    degree_correction: Optional[float] = _UNSET,
    num_groups: Optional[int] = _UNSET,
    initial_partition: Optional[Iterable[int]] = None,
    seed: Optional[int] = None,
    num_samples: int = 1000,
    beta: float = 1.0,
    num_tempering_chains: int = 1,
    no_cache: bool = False,
    timeout: float = 60.0,
    verbose: bool = False,
) -> None:
    """Validate public API inputs and raise informative errors for invalid values.

    Only lightweight checks are performed: types, ranges and basic consistency
    between ``G``, ``num_groups`` and ``initial_partition``. Parameters left at
    the ``_UNSET`` sentinel are replaced by a neutral value before checking.
    ``node_groups`` is accepted but not inspected here.

    Returns
    -------
    None
        When every check passes.

    Raises
    ------
    TypeError
        For a wrong type of ``G``, ``model_name``, ``assortative``,
        ``variation_in``/``variation_out``, ``degree_correction``,
        ``initial_partition`` (not iterable), ``seed``, ``no_cache`` or
        ``verbose``.
    ValueError
        For an unknown ``model_name`` or ``mixing_variation`` and for any
        value outside its allowed range.
    """

    # Graph checks
    if not isinstance(G, nx.Graph):
        raise TypeError("G must be a NetworkX Graph object.")

    # model_name: called for its validation side effect only
    _canonical_model_name(model_name)

    # Fix the unset parameters to their defaults (to pass the type checks below)
    if assortative is _UNSET:
        assortative = True
    if mixing_variation is _UNSET:
        mixing_variation = "general"
    if variation_in is _UNSET:
        variation_in = None
    if variation_out is _UNSET:
        variation_out = None
    if degree_correction is _UNSET:
        degree_correction = None
    # mean_degree_scaling has been fixed to 0.0 and is no longer configurable
    if num_groups is _UNSET:
        num_groups = None

    # assortative
    if not isinstance(assortative, bool):
        raise TypeError("assortative must be a boolean.")

    # mixing_variation
    valid_mixing = ["simple", "none", "internal", "external", "general"]
    if mixing_variation not in valid_mixing:
        raise ValueError(f"mixing_variation must be one of {valid_mixing}")

    # variation_in/out: None or 0 <= x <= 1
    for name, val in (("variation_in", variation_in), ("variation_out", variation_out)):
        if val is not None:
            if not isinstance(val, (int, float)):
                raise TypeError(f"{name} must be a float in [0, 1] or None.")
            if not (0.0 <= float(val) <= 1.0):
                raise ValueError(f"{name} must be between 0 and 1 (inclusive). Got {val}.")

    # degree_correction: None or 0 <= x <= 1
    if degree_correction is not None:
        if not isinstance(degree_correction, (int, float)):
            raise TypeError("degree_correction must be a float in [0, 1] or None.")
        if not (0.0 <= float(degree_correction) <= 1.0):
            raise ValueError("degree_correction must be between 0 and 1 (inclusive).")

    # mean_degree_scaling removed from user-facing API

    # num_groups must be positive integer if provided and not greater than number of nodes
    if num_groups is not None:
        if not isinstance(num_groups, int) or num_groups <= 0:
            raise ValueError("num_groups must be a positive integer or None.")
        if num_groups > G.number_of_nodes():
            raise ValueError("num_groups cannot be greater than the number of nodes in G.")

    # initial_partition: if provided, must be iterable of ints of length equal to number of nodes
    if initial_partition is not None:
        try:
            ip_list = list(initial_partition)
        except Exception as exc:
            # Keep the original failure attached so the root cause stays visible.
            raise TypeError("initial_partition must be an iterable of integers or None.") from exc
        if len(ip_list) != G.number_of_nodes():
            raise ValueError("initial_partition length must match number of nodes in G.")
        for v in ip_list:
            if not isinstance(v, int) or v < 0:
                raise ValueError("initial_partition must contain non-negative integer labels.")
        unique_groups = set(ip_list)
        # An empty partition (only possible for an empty graph) has no labels
        # to check; calling max() on it would raise an unrelated ValueError.
        if ip_list:
            highest_label = max(ip_list)
            if num_groups is not None and highest_label >= num_groups:
                raise ValueError("initial_partition contains labels >= num_groups.")
            # Make sure that each group has at least one member
            if len(unique_groups) < highest_label + 1:
                raise ValueError("initial_partition does not contain members for every group.")
        if num_groups is not None and len(unique_groups) < num_groups:
            raise ValueError("initial_partition has a different number of groups than num_groups.")

    # seed
    if seed is not None and not isinstance(seed, int):
        raise TypeError("seed must be an integer or None.")

    # num_samples
    if not isinstance(num_samples, int) or num_samples <= 0:
        raise ValueError("num_samples must be a positive integer.")

    # beta (non-negative float); written as "not >=" so that NaN is rejected too
    if not isinstance(beta, (int, float)) or not (float(beta) >= 0.0):
        raise ValueError("beta must be a non-negative number.")

    # num_tempering_chains
    if not isinstance(num_tempering_chains, int) or num_tempering_chains <= 0:
        raise ValueError("num_tempering_chains must be a positive integer.")

    # no_cache (boolean)
    if not isinstance(no_cache, bool):
        raise TypeError("no_cache must be a boolean.")

    # timeout; written as "not >" so that NaN is rejected too
    if not isinstance(timeout, (int, float)) or not (float(timeout) > 0.0):
        raise ValueError("timeout must be a positive number (seconds).")

    # verbose
    if not isinstance(verbose, bool):
        raise TypeError("verbose must be a boolean.")

    # Passes all checks
    return None
|
|
180
|
+
|
|
181
|
+
def _set_model_parameters(model_name: str,
                          assortative: bool = _UNSET,
                          mixing_variation: str = _UNSET,
                          variation_in: Optional[float] = _UNSET,
                          variation_out: Optional[float] = _UNSET,
                          degree_correction: Optional[float] = _UNSET,
                          num_groups: Optional[int] = _UNSET) -> tuple:
    """
    Resolve every parameter still at the ``_UNSET`` sentinel.

    Explicitly supplied values always win. Otherwise the preset of the chosen
    model applies, then the preset of the resulting ``mixing_variation``, and
    finally the global default (``assortative=True``, ``mixing_variation='general'``,
    everything else ``None``).

    Returns the tuple
    ``(assortative, mixing_variation, variation_in, variation_out, degree_correction, num_groups)``.
    """
    model_name = _canonical_model_name(model_name)

    # Per-model overrides of the global defaults. Models not listed here
    # ('general_canonical', 'general_unified') use the defaults throughout.
    # mean_degree_scaling is fixed to 0.0 and not exposed.
    model_presets = {
        'traditional_SBM': {'assortative': False, 'degree_correction': 0.0},
        'traditional_DCSBM': {'assortative': False, 'degree_correction': 0.5},
        'traditional_GDCSBM': {'assortative': False},
        'simple_ASBM': {'degree_correction': 0.0},
        'hybrid_ASBM': {'mixing_variation': 'internal', 'degree_correction': 0.0},
        'planted_partition': {'mixing_variation': 'none', 'degree_correction': 0.0},
        'general_ASBM': {'mixing_variation': 'general', 'degree_correction': 0.0},
        'microcanonical_SBM': {'assortative': False, 'degree_correction': 0.0},
        'microcanonical_DCSBM': {'assortative': False, 'degree_correction': 0.5},
    }
    preset = model_presets.get(model_name, {})

    if assortative is _UNSET:
        assortative = preset.get('assortative', True)
    if mixing_variation is _UNSET:
        mixing_variation = preset.get('mixing_variation', 'general')
    if degree_correction is _UNSET:
        degree_correction = preset.get('degree_correction', None)
    if num_groups is _UNSET:
        num_groups = None

    # (mixing_variation, variation_in, variation_out) presets. 'general' and
    # any unrecognised value leave both variations at None (both inferred).
    # Compared with == rather than looked up, so any value type is tolerated.
    fallback_in = fallback_out = None
    for preset_name, preset_in, preset_out in (
        ('simple', 0.5, 0.5),
        ('none', 0.0, 0.0),
        ('internal', 0.5, 0.0),
        ('external', 0.0, 0.5),
    ):
        if mixing_variation == preset_name:
            fallback_in, fallback_out = preset_in, preset_out
            break

    if variation_in is _UNSET:
        variation_in = fallback_in
    if variation_out is _UNSET:
        variation_out = fallback_out

    return (assortative, mixing_variation, variation_in, variation_out, degree_correction, num_groups)
|
|
283
|
+
|
|
284
|
+
# Former location of the plotting and GML export helpers:
|
|
285
|
+
# Removed unused plotting and GML helper functions (plot_groups, add_node_group_attribute,
|
|
286
|
+
# write_gml). These helpers were unused and partially implemented; callers should
|
|
287
|
+
# perform plotting or file writing externally if needed.
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def get_neighbors(G: nx.Graph) -> Iterable[dict]:
    """
    Build a weighted adjacency structure indexed by node position.

    Nodes are numbered in the iteration order of ``G.nodes``. Entry ``i`` of
    the returned list is a dictionary mapping the index of each neighbor of
    node ``i`` to the total weight of the edges between them (an edge without
    a ``"weight"`` attribute counts as ``1.0``).

    Parameters
    ----------
    G : nx.Graph
        NetworkX graph to get neighbors for.
    """
    adjacency = []
    index_by_node = {}
    for node in G.nodes:
        # The next free slot in ``adjacency`` is this node's index.
        index_by_node[node] = len(adjacency)
        adjacency.append({})

    for src, dst, attrs in G.edges(data=True):
        i = index_by_node[src]
        j = index_by_node[dst]
        w = float(attrs.get("weight", 1.0))
        if i != j:
            # Ordinary edge: record it symmetrically on both endpoints.
            adjacency[i][j] = adjacency[i].get(j, 0.0) + w
            adjacency[j][i] = adjacency[j].get(i, 0.0) + w
        else:
            # Self-loop convention: neighbors[i][i] = 2 * actual_weight (both directions).
            adjacency[i][i] = adjacency[i].get(i, 0.0) + 2.0 * w

    return adjacency
|
asbm/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: asbm
|
|
3
|
+
Version: 0.1.4.dev3
|
|
4
|
+
Summary: Find group structures in networks
|
|
5
|
+
Author-Email: Maximilian Jerdee <mjerdee@umich.edu>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Classifier: Development Status :: 1 - Planning
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Classifier: Typing :: Typed
|
|
22
|
+
Project-URL: Homepage, https://github.com/maxjerdee/asbm
|
|
23
|
+
Project-URL: Bug Tracker, https://github.com/maxjerdee/asbm/issues
|
|
24
|
+
Project-URL: Discussions, https://github.com/maxjerdee/asbm/discussions
|
|
25
|
+
Project-URL: Changelog, https://github.com/maxjerdee/asbm/releases
|
|
26
|
+
Requires-Python: >=3.10
|
|
27
|
+
Requires-Dist: matplotlib>=3.8.0
|
|
28
|
+
Requires-Dist: numpy>=1.26.0
|
|
29
|
+
Requires-Dist: pandas>=2.2.0
|
|
30
|
+
Requires-Dist: networkx>=3.0
|
|
31
|
+
Requires-Dist: tqdm>=4.65.0
|
|
32
|
+
Description-Content-Type: text/markdown
|
|
33
|
+
|
|
34
|
+
# asbm
|
|
35
|
+
|
|
36
|
+
[![Documentation Status][rtd-badge]][rtd-link]
|
|
37
|
+
[![PyPI version][pypi-version]][pypi-link]
|
|
38
|
+
[![PyPI platforms][pypi-platforms]][pypi-link]
|
|
39
|
+
|
|
40
|
+
<!-- SPHINX-START -->
|
|
41
|
+
|
|
42
|
+
<!-- prettier-ignore-start -->
|
|
43
|
+
[actions-badge]: https://github.com/maxjerdee/asbm/workflows/CI/badge.svg
|
|
44
|
+
[actions-link]: https://github.com/maxjerdee/asbm/actions
|
|
45
|
+
[conda-badge]: https://img.shields.io/conda/vn/conda-forge/asbm
|
|
46
|
+
[conda-link]: https://github.com/conda-forge/asbm-feedstock
|
|
47
|
+
[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
|
|
48
|
+
[github-discussions-link]: https://github.com/maxjerdee/asbm/discussions
|
|
49
|
+
[paper-link]: https://arxiv.org/abs/TODO
|
|
50
|
+
[pypi-link]: https://pypi.org/project/asbm/
|
|
51
|
+
[pypi-platforms]: https://img.shields.io/pypi/pyversions/asbm
|
|
52
|
+
[pypi-version]: https://img.shields.io/pypi/v/asbm
|
|
53
|
+
[rtd-badge]: https://readthedocs.org/projects/asbm/badge/?version=latest
|
|
54
|
+
[rtd-link]: https://asbm.readthedocs.io/en/latest/?badge=latest
|
|
55
|
+
|
|
56
|
+
<!-- prettier-ignore-end -->
|
|
57
|
+
|
|
58
|
+
## Infer network groups and properties with assortative stochastic block models
|
|
59
|
+
|
|
60
|
+
*Maximilian Jerdee*
|
|
61
|
+
|
|
62
|
+
This Python package uses Bayesian inference to find meaningful groupings of nodes in networks.
|
|
63
|
+
|
|
64
|
+
We implement a [general assortative SBM][paper-link] that unifies the standard SBM and the planted partition model under a common framework. Its parameters directly measure the violation of each model's assumptions: the assortative preference ρ_in/ρ_out captures departure from in/out symmetry, while per-group variation coefficients v_in and v_out measure heterogeneity across groups. The standard SBM, planted partition model, and the Zhang–Peixoto hybrid emerge as special cases, enabling exact Bayesian model comparison between them.
|
|
65
|
+
|
|
66
|
+
For each model, the package includes algorithms to:
|
|
67
|
+
- Find consensus estimates of the group structure
|
|
68
|
+
- Infer global network parameters (assortativity, group sizes)
|
|
69
|
+
- Score held-out edges via posterior predictive likelihood
|
|
70
|
+
|
|
71
|
+
## Installation
|
|
72
|
+
|
|
73
|
+
Implementations are available for Python, R, and Julia.
|
|
74
|
+
|
|
75
|
+
### Python
|
|
76
|
+
|
|
77
|
+
```bash
|
|
78
|
+
pip install asbm
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
Or build locally from the repository root:
|
|
82
|
+
|
|
83
|
+
```bash
|
|
84
|
+
pip install .
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
### R
|
|
88
|
+
|
|
89
|
+
```r
|
|
90
|
+
install.packages("asbm", repos = "https://maxjerdee.r-universe.dev")
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### Julia
|
|
94
|
+
|
|
95
|
+
```julia
|
|
96
|
+
using Pkg
|
|
97
|
+
Pkg.add(url="https://github.com/maxjerdee/asbm", subdir="bindings/julia")
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
Building from source requires CMake and a C++17 compiler. Run `Pkg.build("ASBM")` after installation to compile the native library.
|
|
101
|
+
|
|
102
|
+
## Quickstart
|
|
103
|
+
|
|
104
|
+
### Python
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
import asbm
|
|
108
|
+
import networkx as nx
|
|
109
|
+
|
|
110
|
+
G = nx.read_gml("examples/data/dolphins.gml", label="id")
|
|
111
|
+
|
|
112
|
+
config = asbm.Config(
|
|
113
|
+
model="general_asbm",
|
|
114
|
+
degree_correction=True,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
result = asbm.fit(config, G)
|
|
118
|
+
|
|
119
|
+
print(result.mdl_partition)
|
|
120
|
+
print(result.mdl_value)
|
|
121
|
+
print(result.consensus_partition())
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
To score held-out edges:
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
G_train = nx.read_gml("examples/data/train.gml")
|
|
128
|
+
G_test = nx.read_gml("examples/data/test.gml")
|
|
129
|
+
|
|
130
|
+
result = asbm.fit(asbm.Config(model="general_asbm"), G_train)
|
|
131
|
+
score = result.log_posterior_predictive(G_test)
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
### R
|
|
135
|
+
|
|
136
|
+
```r
|
|
137
|
+
library(asbm)
|
|
138
|
+
library(igraph)
|
|
139
|
+
|
|
140
|
+
G <- read_graph("examples/data/dolphins.gml", format = "gml")
|
|
141
|
+
|
|
142
|
+
result <- fit(G,
|
|
143
|
+
model = "general_asbm",
|
|
144
|
+
degree_correction = TRUE,
|
|
145
|
+
num_chains = 4,
|
|
146
|
+
samples_per_chain = 100,
|
|
147
|
+
seed = 42)
|
|
148
|
+
|
|
149
|
+
print(result$mdl_value)
|
|
150
|
+
print(result$mdl_partition)
|
|
151
|
+
print(result$consensus_partition)
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
### Julia
|
|
155
|
+
|
|
156
|
+
```julia
|
|
157
|
+
using ASBM, Graphs
|
|
158
|
+
|
|
159
|
+
g = cycle_graph(62) # or load via GraphIO
|
|
160
|
+
|
|
161
|
+
result = fit(g;
|
|
162
|
+
model = "general_asbm",
|
|
163
|
+
degree_correction = true,
|
|
164
|
+
num_chains = 4,
|
|
165
|
+
samples_per_chain = 100,
|
|
166
|
+
seed = 42)
|
|
167
|
+
|
|
168
|
+
println(result.mdl_value)
|
|
169
|
+
println(result.mdl_partition)
|
|
170
|
+
println(result.consensus_partition)
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
For the full API, including posterior predictive evaluation and the samples schema, see the [package documentation][rtd-link].
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
asbm/__init__.py,sha256=Mhl0eIOXCjR3wWpYwKTu2lVZHqkyBYNDg047JB-Um2U,313
|
|
2
|
+
asbm/_asbm.cp313-win_amd64.pyd,sha256=13xenevjUZbvaLYu3Przwzk6cxgphUt362i4JmFSZhc,413184
|
|
3
|
+
asbm/_version.py,sha256=r5MqpXVcUNROsObvaJ5rDQq1GnaqEXbreKXRuNHxjqU,565
|
|
4
|
+
asbm/core.py,sha256=qvss9eZs9IfcinHw-_9sx4n61z5eysRL7p_3BzRltCc,5246
|
|
5
|
+
asbm/input_output.py,sha256=BEmCLFmaVR1fh4uAtWpVMrB7yyFBZWmYzCDkj9AwIuA,12814
|
|
6
|
+
asbm/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
asbm-0.1.4.dev3.dist-info/METADATA,sha256=sQx6C18Sgc0l4IZYa0JNDmYvLGh2QJjnbAEsyVWN1Sc,5436
|
|
8
|
+
asbm-0.1.4.dev3.dist-info/WHEEL,sha256=UZrbbE4r80xj7Ncfa6JoeTVe-77bdXLkKUA63V8pKWQ,106
|
|
9
|
+
asbm-0.1.4.dev3.dist-info/licenses/LICENSE,sha256=nXOCnfm201QiVNd89JJAWmzk6cFctyYaVecsfK4kXk8,1076
|
|
10
|
+
asbm-0.1.4.dev3.dist-info/RECORD,,
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Copyright 2025 Maximilian Jerdee
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
the Software without restriction, including without limitation the rights to
|
|
6
|
+
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
7
|
+
of the Software, and to permit persons to whom the Software is furnished to do
|
|
8
|
+
so, subject to the following conditions:
|
|
9
|
+
|
|
10
|
+
The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
copies or substantial portions of the Software.
|
|
12
|
+
|
|
13
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
16
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
17
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
18
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
19
|
+
SOFTWARE.
|