LZGraphs 2.2.0__tar.gz → 2.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/PKG-INFO +1 -1
- lzgraphs-2.3.0/setup.py +40 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/__init__.py +1 -1
- lzgraphs-2.3.0/src/LZGraphs/_fast_walk.c +321 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/amino_acid_positional.py +0 -2
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/graph_operations.py +4 -13
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/lz_graph_base.py +128 -24
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/nucleotide_double_positional.py +0 -2
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/gene_logic.py +22 -4
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/lzpgen_distribution.py +9 -9
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/serialization.py +31 -18
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/PKG-INFO +1 -1
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/SOURCES.txt +2 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_lzpgen_distribution.py +6 -1
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/CHANGELOG.md +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/CONTRIBUTING.md +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/LICENSE +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/MANIFEST.in +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/README.md +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/pyproject.toml +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/requirements.txt +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/setup.cfg +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/bag_of_words/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/bag_of_words/bow_encoder.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/constants.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/exceptions/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/edge_data.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/naive.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/convenience.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/diversity.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/entropy.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/pgen_distribution.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/saturation.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/bayesian_posterior.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/gene_prediction.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/graph_topology.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/random_walk.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/walk_analysis.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/py.typed +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/decomposition.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/helpers.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/misc.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/visualization/__init__.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/visualization/visualize.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/dependency_links.txt +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/requires.txt +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/top_level.txt +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_aap_lzgraph.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_abundance.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_analytical_distribution.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_base_class_methods.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_bow_encoder.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_diversity_theory.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_flexible_input.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_graph_operations.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_metrics.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_naive_lzgraph.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_ndp_lzgraph.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_new_features.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_pgen_fixes.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_serialization.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_simulate.py +0 -0
- {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_utilities.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: LZGraphs
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
|
|
5
5
|
Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
|
|
6
6
|
Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
|
lzgraphs-2.3.0/setup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Build script for optional C extensions.
|
|
3
|
+
|
|
4
|
+
The _fast_walk extension accelerates LZGraph.simulate() by ~50-100x.
|
|
5
|
+
If compilation fails (no C compiler), the package still installs and
|
|
6
|
+
falls back to the pure-Python implementation automatically.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from setuptools import setup, Extension
|
|
12
|
+
|
|
13
|
+
# Ensure setuptools can resolve the dynamic version (attr = "LZGraphs.__version__")
|
|
14
|
+
# when running in an isolated build environment where src/ isn't on sys.path.
|
|
15
|
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
|
|
16
|
+
|
|
17
|
+
ext_modules = [
|
|
18
|
+
Extension(
|
|
19
|
+
"LZGraphs._fast_walk",
|
|
20
|
+
sources=[os.path.join("src", "LZGraphs", "_fast_walk.c")],
|
|
21
|
+
# No external library dependencies — pure C + Python.h
|
|
22
|
+
),
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def run_setup(extensions):
|
|
27
|
+
setup(ext_modules=extensions)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
run_setup(ext_modules)
|
|
32
|
+
except Exception:
|
|
33
|
+
print(
|
|
34
|
+
"\n"
|
|
35
|
+
"WARNING: Failed to compile C extension _fast_walk.\n"
|
|
36
|
+
" LZGraphs will use the pure-Python fallback for simulate().\n"
|
|
37
|
+
" This is fine — the package works without it, just slower.\n"
|
|
38
|
+
"\n"
|
|
39
|
+
)
|
|
40
|
+
run_setup([])
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* _fast_walk.c — CPython C extension for fast Markov chain random walks.
|
|
3
|
+
*
|
|
4
|
+
* Implements the full simulate() loop in C including string assembly,
|
|
5
|
+
* for ~100-200x speedup over the original pure-Python implementation.
|
|
6
|
+
* Uses xoshiro256++ for fast, high-quality RNG.
|
|
7
|
+
*
|
|
8
|
+
* The extension is optional: if it fails to compile (no C compiler),
|
|
9
|
+
* LZGraphs falls back to the pure-Python bisect-based implementation.
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
#define PY_SSIZE_T_CLEAN
|
|
13
|
+
#include <Python.h>
|
|
14
|
+
#include <stdint.h>
|
|
15
|
+
#include <string.h>
|
|
16
|
+
|
|
17
|
+
/* ========================================================================
 * xoshiro256++ RNG — public domain by David Blackman and Sebastiano Vigna
 * ======================================================================== */

/* Rotate x left by k bits. Only valid for 0 < k < 64: with k == 0 the right
 * shift by 64 would be undefined behaviour. Callers in this file pass 23
 * and 45. */
static inline uint64_t rotl(const uint64_t x, int k) {
    return (x << k) | (x >> (64 - k));
}

/* Generator state: four 64-bit words, updated in place by xoshiro256pp_next.
 * An all-zero state is a fixed point of the update (it only XORs, shifts
 * and rotates), so it would emit zeros forever. Initialise the state with
 * seed_xoshiro256. */
typedef struct {
    uint64_t s[4];
} xoshiro256_state;

/* Return the next 64-bit output and advance the state by one step.
 * The output is computed from the state *before* the update
 * (rotl(s0 + s3, 23) + s0). */
static inline uint64_t xoshiro256pp_next(xoshiro256_state *state) {
    const uint64_t result = rotl(state->s[0] + state->s[3], 23) + state->s[0];
    const uint64_t t = state->s[1] << 17;
    state->s[2] ^= state->s[0];
    state->s[3] ^= state->s[1];
    state->s[1] ^= state->s[2];
    state->s[0] ^= state->s[3];
    state->s[2] ^= t;
    state->s[3] = rotl(state->s[3], 45);
    return result;
}

/* Uniform double in [0, 1). It keeps the top 53 bits of one 64-bit output
 * (the width of a double significand) and scales them by 2^-53, so the
 * largest possible value is (2^53 - 1) / 2^53 < 1. Each call consumes
 * exactly one xoshiro256pp_next draw. */
static inline double xoshiro256pp_double(xoshiro256_state *state) {
    return (double)(xoshiro256pp_next(state) >> 11) * 0x1.0p-53;
}

/* splitmix64: advance *x by the fixed odd increment 0x9e3779b97f4a7c15
 * (2^64 divided by the golden ratio), then return a bit-mixed copy of the
 * new value. In this file it is used only to expand one seed into the four
 * xoshiro state words. */
static inline uint64_t splitmix64(uint64_t *x) {
    uint64_t z = (*x += 0x9e3779b97f4a7c15ULL);
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
    z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
    return z ^ (z >> 31);
}

/* Initialise *state from a 64-bit seed using four consecutive splitmix64
 * outputs. The result is deterministic: the same seed always gives the
 * same stream. `seed` is passed by value, so the caller's variable is
 * not modified. */
static void seed_xoshiro256(xoshiro256_state *state, uint64_t seed) {
    state->s[0] = splitmix64(&seed);
    state->s[1] = splitmix64(&seed);
    state->s[2] = splitmix64(&seed);
    state->s[3] = splitmix64(&seed);
}
|
|
58
|
+
|
|
59
|
+
/* ========================================================================
|
|
60
|
+
* Binary search (bisect_left) on a double array
|
|
61
|
+
* ======================================================================== */
|
|
62
|
+
|
|
63
|
+
static inline Py_ssize_t bisect_left_double(
|
|
64
|
+
const double *arr, Py_ssize_t n, double value
|
|
65
|
+
) {
|
|
66
|
+
Py_ssize_t lo = 0, hi = n;
|
|
67
|
+
while (lo < hi) {
|
|
68
|
+
Py_ssize_t mid = lo + (hi - lo) / 2;
|
|
69
|
+
if (arr[mid] < value)
|
|
70
|
+
lo = mid + 1;
|
|
71
|
+
else
|
|
72
|
+
hi = mid;
|
|
73
|
+
}
|
|
74
|
+
return lo;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/* ========================================================================
|
|
78
|
+
* simulate_walks — full simulation with string assembly in C
|
|
79
|
+
*
|
|
80
|
+
* Args:
|
|
81
|
+
* n_walks : int
|
|
82
|
+
* offsets : intp array [n_nodes+1] (buffer)
|
|
83
|
+
* neighbors : intp array [total_edges] (buffer)
|
|
84
|
+
* cumweights : float64 array [total_edges] (buffer)
|
|
85
|
+
* stop_probs : float64 array [n_nodes] (buffer)
|
|
86
|
+
* initial_ids : intp array [n_initial] (buffer)
|
|
87
|
+
* initial_cw : float64 array [n_initial] (buffer)
|
|
88
|
+
* seed : uint64
|
|
89
|
+
* clean_labels : list[str] — label for each node ID
|
|
90
|
+
* return_walks : bool — if True, return (walk, seq) tuples
|
|
91
|
+
* id_to_node : list[str] — node names (only used if return_walks)
|
|
92
|
+
*
|
|
93
|
+
* Returns:
|
|
94
|
+
* list[str] or list[tuple[list[str], str]]
|
|
95
|
+
* ======================================================================== */
|
|
96
|
+
|
|
97
|
+
static PyObject* py_simulate_walks(PyObject *self, PyObject *args) {
|
|
98
|
+
int n_walks, return_walks;
|
|
99
|
+
Py_buffer offsets_buf, neighbors_buf, cumweights_buf;
|
|
100
|
+
Py_buffer stop_probs_buf, initial_ids_buf, initial_cw_buf;
|
|
101
|
+
unsigned long long seed;
|
|
102
|
+
PyObject *clean_labels; /* Python list of str */
|
|
103
|
+
PyObject *id_to_node; /* Python list of str */
|
|
104
|
+
PyObject *result_list = NULL;
|
|
105
|
+
|
|
106
|
+
if (!PyArg_ParseTuple(args, "iy*y*y*y*y*y*KOpO",
|
|
107
|
+
&n_walks,
|
|
108
|
+
&offsets_buf, &neighbors_buf, &cumweights_buf,
|
|
109
|
+
&stop_probs_buf, &initial_ids_buf, &initial_cw_buf,
|
|
110
|
+
&seed,
|
|
111
|
+
&clean_labels,
|
|
112
|
+
&return_walks,
|
|
113
|
+
&id_to_node))
|
|
114
|
+
return NULL;
|
|
115
|
+
|
|
116
|
+
const Py_ssize_t *offsets = (const Py_ssize_t *)offsets_buf.buf;
|
|
117
|
+
const Py_ssize_t *neighbors = (const Py_ssize_t *)neighbors_buf.buf;
|
|
118
|
+
const double *cumweights = (const double *)cumweights_buf.buf;
|
|
119
|
+
const double *stop_probs = (const double *)stop_probs_buf.buf;
|
|
120
|
+
const Py_ssize_t *initial_ids = (const Py_ssize_t *)initial_ids_buf.buf;
|
|
121
|
+
const double *initial_cw = (const double *)initial_cw_buf.buf;
|
|
122
|
+
const Py_ssize_t n_initial = initial_cw_buf.len / (Py_ssize_t)sizeof(double);
|
|
123
|
+
|
|
124
|
+
if (n_initial <= 0) {
|
|
125
|
+
PyErr_SetString(PyExc_ValueError,
|
|
126
|
+
"Cannot simulate: graph has no initial states.");
|
|
127
|
+
goto cleanup;
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
/* Pre-fetch label UTF-8 data for fast string assembly */
|
|
131
|
+
const Py_ssize_t n_labels = PyList_GET_SIZE(clean_labels);
|
|
132
|
+
const char **label_ptrs = (const char **)PyMem_Malloc(n_labels * sizeof(char *));
|
|
133
|
+
Py_ssize_t *label_lens = (Py_ssize_t *)PyMem_Malloc(n_labels * sizeof(Py_ssize_t));
|
|
134
|
+
if (!label_ptrs || !label_lens) {
|
|
135
|
+
PyMem_Free(label_ptrs);
|
|
136
|
+
PyMem_Free(label_lens);
|
|
137
|
+
PyErr_NoMemory();
|
|
138
|
+
goto cleanup;
|
|
139
|
+
}
|
|
140
|
+
for (Py_ssize_t i = 0; i < n_labels; i++) {
|
|
141
|
+
PyObject *s = PyList_GET_ITEM(clean_labels, i);
|
|
142
|
+
label_ptrs[i] = PyUnicode_AsUTF8AndSize(s, &label_lens[i]);
|
|
143
|
+
if (!label_ptrs[i]) {
|
|
144
|
+
PyMem_Free(label_ptrs);
|
|
145
|
+
PyMem_Free(label_lens);
|
|
146
|
+
goto cleanup;
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
xoshiro256_state rng;
|
|
151
|
+
seed_xoshiro256(&rng, (uint64_t)seed);
|
|
152
|
+
|
|
153
|
+
result_list = PyList_New(n_walks);
|
|
154
|
+
if (!result_list) {
|
|
155
|
+
PyMem_Free(label_ptrs);
|
|
156
|
+
PyMem_Free(label_lens);
|
|
157
|
+
goto cleanup;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
/* Reusable walk buffer */
|
|
161
|
+
Py_ssize_t walk_cap = 64;
|
|
162
|
+
Py_ssize_t *walk_buf = (Py_ssize_t *)PyMem_Malloc(walk_cap * sizeof(Py_ssize_t));
|
|
163
|
+
/* Reusable string buffer */
|
|
164
|
+
Py_ssize_t str_cap = 256;
|
|
165
|
+
char *str_buf = (char *)PyMem_Malloc(str_cap);
|
|
166
|
+
if (!walk_buf || !str_buf) {
|
|
167
|
+
PyMem_Free(walk_buf);
|
|
168
|
+
PyMem_Free(str_buf);
|
|
169
|
+
PyMem_Free(label_ptrs);
|
|
170
|
+
PyMem_Free(label_lens);
|
|
171
|
+
Py_DECREF(result_list);
|
|
172
|
+
PyErr_NoMemory();
|
|
173
|
+
goto cleanup;
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
for (int i = 0; i < n_walks; i++) {
|
|
177
|
+
/* Pick initial state */
|
|
178
|
+
double r = xoshiro256pp_double(&rng);
|
|
179
|
+
Py_ssize_t init_idx = bisect_left_double(initial_cw, n_initial, r);
|
|
180
|
+
if (init_idx >= n_initial) init_idx = n_initial - 1;
|
|
181
|
+
Py_ssize_t current = initial_ids[init_idx];
|
|
182
|
+
|
|
183
|
+
Py_ssize_t walk_len = 0;
|
|
184
|
+
walk_buf[walk_len++] = current;
|
|
185
|
+
|
|
186
|
+
/* Build string incrementally */
|
|
187
|
+
Py_ssize_t str_len = 0;
|
|
188
|
+
Py_ssize_t llen = label_lens[current];
|
|
189
|
+
if (str_len + llen > str_cap) {
|
|
190
|
+
str_cap = (str_len + llen) * 2;
|
|
191
|
+
str_buf = (char *)PyMem_Realloc(str_buf, str_cap);
|
|
192
|
+
if (!str_buf) goto oom;
|
|
193
|
+
}
|
|
194
|
+
memcpy(str_buf + str_len, label_ptrs[current], llen);
|
|
195
|
+
str_len += llen;
|
|
196
|
+
|
|
197
|
+
while (1) {
|
|
198
|
+
double sp = stop_probs[current];
|
|
199
|
+
if (sp == sp) {
|
|
200
|
+
if (xoshiro256pp_double(&rng) < sp)
|
|
201
|
+
break;
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
Py_ssize_t start = offsets[current];
|
|
205
|
+
Py_ssize_t end = offsets[current + 1];
|
|
206
|
+
if (start == end)
|
|
207
|
+
break;
|
|
208
|
+
|
|
209
|
+
r = xoshiro256pp_double(&rng);
|
|
210
|
+
Py_ssize_t idx = bisect_left_double(cumweights + start, end - start, r);
|
|
211
|
+
if (idx >= end - start) idx = end - start - 1;
|
|
212
|
+
current = neighbors[start + idx];
|
|
213
|
+
|
|
214
|
+
/* Grow walk buffer if needed */
|
|
215
|
+
if (walk_len >= walk_cap) {
|
|
216
|
+
walk_cap *= 2;
|
|
217
|
+
Py_ssize_t *new_buf = (Py_ssize_t *)PyMem_Realloc(walk_buf, walk_cap * sizeof(Py_ssize_t));
|
|
218
|
+
if (!new_buf) goto oom;
|
|
219
|
+
walk_buf = new_buf;
|
|
220
|
+
}
|
|
221
|
+
walk_buf[walk_len++] = current;
|
|
222
|
+
|
|
223
|
+
/* Append label to string buffer */
|
|
224
|
+
llen = label_lens[current];
|
|
225
|
+
if (str_len + llen > str_cap) {
|
|
226
|
+
str_cap = (str_len + llen) * 2;
|
|
227
|
+
char *new_str = (char *)PyMem_Realloc(str_buf, str_cap);
|
|
228
|
+
if (!new_str) goto oom;
|
|
229
|
+
str_buf = new_str;
|
|
230
|
+
}
|
|
231
|
+
memcpy(str_buf + str_len, label_ptrs[current], llen);
|
|
232
|
+
str_len += llen;
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
/* Create Python string from buffer */
|
|
236
|
+
PyObject *seq = PyUnicode_FromStringAndSize(str_buf, str_len);
|
|
237
|
+
if (!seq) goto oom;
|
|
238
|
+
|
|
239
|
+
if (return_walks) {
|
|
240
|
+
/* Build walk list of node name strings */
|
|
241
|
+
PyObject *walk = PyList_New(walk_len);
|
|
242
|
+
if (!walk) { Py_DECREF(seq); goto oom; }
|
|
243
|
+
for (Py_ssize_t j = 0; j < walk_len; j++) {
|
|
244
|
+
PyObject *node_name = PyList_GET_ITEM(id_to_node, walk_buf[j]);
|
|
245
|
+
Py_INCREF(node_name);
|
|
246
|
+
PyList_SET_ITEM(walk, j, node_name);
|
|
247
|
+
}
|
|
248
|
+
PyObject *tup = PyTuple_Pack(2, walk, seq);
|
|
249
|
+
Py_DECREF(walk);
|
|
250
|
+
Py_DECREF(seq);
|
|
251
|
+
if (!tup) goto oom;
|
|
252
|
+
PyList_SET_ITEM(result_list, i, tup);
|
|
253
|
+
} else {
|
|
254
|
+
PyList_SET_ITEM(result_list, i, seq);
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
PyMem_Free(walk_buf);
|
|
259
|
+
PyMem_Free(str_buf);
|
|
260
|
+
PyMem_Free(label_ptrs);
|
|
261
|
+
PyMem_Free(label_lens);
|
|
262
|
+
goto cleanup;
|
|
263
|
+
|
|
264
|
+
oom:
|
|
265
|
+
PyMem_Free(walk_buf);
|
|
266
|
+
PyMem_Free(str_buf);
|
|
267
|
+
PyMem_Free(label_ptrs);
|
|
268
|
+
PyMem_Free(label_lens);
|
|
269
|
+
Py_XDECREF(result_list);
|
|
270
|
+
result_list = NULL;
|
|
271
|
+
if (!PyErr_Occurred())
|
|
272
|
+
PyErr_NoMemory();
|
|
273
|
+
|
|
274
|
+
cleanup:
|
|
275
|
+
PyBuffer_Release(&offsets_buf);
|
|
276
|
+
PyBuffer_Release(&neighbors_buf);
|
|
277
|
+
PyBuffer_Release(&cumweights_buf);
|
|
278
|
+
PyBuffer_Release(&stop_probs_buf);
|
|
279
|
+
PyBuffer_Release(&initial_ids_buf);
|
|
280
|
+
PyBuffer_Release(&initial_cw_buf);
|
|
281
|
+
|
|
282
|
+
return result_list;
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
/* ========================================================================
|
|
286
|
+
* Module definition
|
|
287
|
+
* ======================================================================== */
|
|
288
|
+
|
|
289
|
+
static PyMethodDef FastWalkMethods[] = {
|
|
290
|
+
{"simulate_walks", py_simulate_walks, METH_VARARGS,
|
|
291
|
+
"Run n random walks on a CSR-encoded graph with string assembly.\n\n"
|
|
292
|
+
"Args:\n"
|
|
293
|
+
" n_walks (int): Number of walks.\n"
|
|
294
|
+
" offsets (array): CSR row offsets [n_nodes+1], dtype=intp.\n"
|
|
295
|
+
" neighbors (array): Flat neighbor IDs, dtype=intp.\n"
|
|
296
|
+
" cumweights (array): Flat cumulative weights, dtype=float64.\n"
|
|
297
|
+
" stop_probs (array): Per-node stop probability (NaN=none), dtype=float64.\n"
|
|
298
|
+
" initial_ids (array): Initial state IDs, dtype=intp.\n"
|
|
299
|
+
" initial_cumprobs (array): Cumulative initial probs, dtype=float64.\n"
|
|
300
|
+
" seed (int): RNG seed (xoshiro256++).\n"
|
|
301
|
+
" clean_labels (list[str]): Subpattern label for each node.\n"
|
|
302
|
+
" return_walks (bool): If True, return (walk, seq) tuples.\n"
|
|
303
|
+
" id_to_node (list[str]): Node names for walk output.\n\n"
|
|
304
|
+
"Returns:\n"
|
|
305
|
+
" list[str] or list[tuple[list[str], str]]\n"},
|
|
306
|
+
{NULL, NULL, 0, NULL}
|
|
307
|
+
};
|
|
308
|
+
|
|
309
|
+
static struct PyModuleDef fast_walk_module = {
|
|
310
|
+
PyModuleDef_HEAD_INIT,
|
|
311
|
+
"_fast_walk",
|
|
312
|
+
"C-accelerated random walk simulation for LZGraphs.\n"
|
|
313
|
+
"Uses xoshiro256++ RNG for high-quality, fast random number generation.\n"
|
|
314
|
+
"This module is optional — LZGraphs falls back to pure Python if unavailable.",
|
|
315
|
+
-1,
|
|
316
|
+
FastWalkMethods
|
|
317
|
+
};
|
|
318
|
+
|
|
319
|
+
PyMODINIT_FUNC PyInit__fast_walk(void) {
|
|
320
|
+
return PyModule_Create(&fast_walk_module);
|
|
321
|
+
}
|
|
@@ -141,8 +141,6 @@ class AAPLZGraph(LZGraphBase):
|
|
|
141
141
|
self._log_step("Graph constructed.", verbose)
|
|
142
142
|
|
|
143
143
|
# Normalize and derive probability dicts
|
|
144
|
-
self.length_counts = dict(self.lengths)
|
|
145
|
-
|
|
146
144
|
total_terminal = sum(self.terminal_state_counts.values())
|
|
147
145
|
self.length_probabilities = (
|
|
148
146
|
{k: v / total_terminal for k, v in self.terminal_state_counts.items()}
|
|
@@ -91,19 +91,10 @@ def graph_union(graphA, graphB):
|
|
|
91
91
|
}
|
|
92
92
|
|
|
93
93
|
# Merge length_distribution counts
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
# Merge observed gene sets
|
|
99
|
-
if hasattr(graphB, 'observed_v_genes'):
|
|
100
|
-
graphA.observed_v_genes = list(
|
|
101
|
-
set(graphA.observed_v_genes) | set(graphB.observed_v_genes)
|
|
102
|
-
)
|
|
103
|
-
if hasattr(graphB, 'observed_j_genes'):
|
|
104
|
-
graphA.observed_j_genes = list(
|
|
105
|
-
set(graphA.observed_j_genes) | set(graphB.observed_j_genes)
|
|
106
|
-
)
|
|
94
|
+
for k, v in graphB.lengths.items():
|
|
95
|
+
graphA.lengths[k] = graphA.lengths.get(k, 0) + v
|
|
96
|
+
|
|
97
|
+
# observed_v/j_genes are now derived from marginal_v/j_genes (already merged above)
|
|
107
98
|
|
|
108
99
|
# 5. Recalculate ALL derived state from raw counts
|
|
109
100
|
graphA.recalculate()
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import re
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
+
from bisect import bisect_left as _bisect_left
|
|
4
5
|
from time import time
|
|
5
6
|
|
|
6
7
|
import networkx as nx
|
|
@@ -12,6 +13,12 @@ from ..utilities.misc import choice, window
|
|
|
12
13
|
# Shared constants
|
|
13
14
|
from ..constants import _EPS, _LOG_EPS
|
|
14
15
|
|
|
16
|
+
# Optional C extension for fast simulation
|
|
17
|
+
try:
|
|
18
|
+
from .._fast_walk import simulate_walks as _c_simulate_walks
|
|
19
|
+
except ImportError:
|
|
20
|
+
_c_simulate_walks = None
|
|
21
|
+
|
|
15
22
|
# EdgeData
|
|
16
23
|
from .edge_data import EdgeData
|
|
17
24
|
|
|
@@ -106,6 +113,19 @@ class LZGraphBase(
|
|
|
106
113
|
# Topological order cache (built lazily, invalidated on structural changes)
|
|
107
114
|
self._topo_order = None
|
|
108
115
|
|
|
116
|
+
# ------------------------------------------------------------------
|
|
117
|
+
# Derived properties (single source of truth)
|
|
118
|
+
# ------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
@property
def length_counts(self):
    """Expose ``self.lengths`` under the ``length_counts`` name.

    Both names refer to the very same dict object, so there is only one
    length histogram to keep in sync.
    """
    lengths = self.lengths
    return lengths

@length_counts.setter
def length_counts(self, value):
    """Rebind ``self.lengths``; assigning through the alias is equivalent."""
    setattr(self, "lengths", value)
|
|
128
|
+
|
|
109
129
|
@staticmethod
|
|
110
130
|
def _normalize_input(data, seq_column, abundances=None, v_genes=None, j_genes=None):
|
|
111
131
|
"""Convert flexible input to a standardised dict-of-lists.
|
|
@@ -242,7 +262,7 @@ class LZGraphBase(
|
|
|
242
262
|
if self.has_gene_data and other.has_gene_data:
|
|
243
263
|
aux += not _dicts_close(self.marginal_v_genes, other.marginal_v_genes, decimals=3)
|
|
244
264
|
aux += not _dicts_close(self.vj_probabilities, other.vj_probabilities, decimals=3)
|
|
245
|
-
aux += not _dicts_close(self.
|
|
265
|
+
aux += not _dicts_close(self.lengths, other.lengths, decimals=3)
|
|
246
266
|
|
|
247
267
|
return (aux == 0)
|
|
248
268
|
|
|
@@ -802,6 +822,10 @@ class LZGraphBase(
|
|
|
802
822
|
def _build_walk_cache(self, seed=None):
|
|
803
823
|
"""Build pre-computed numpy arrays for fast random walks.
|
|
804
824
|
|
|
825
|
+
Uses CSR (Compressed Sparse Row) format for neighbor data so
|
|
826
|
+
the entire walk can be driven by flat numpy arrays and
|
|
827
|
+
``searchsorted`` instead of per-step ``rng.choice()`` calls.
|
|
828
|
+
|
|
805
829
|
Returns a dict with the cache data, stored as ``self._walk_cache``.
|
|
806
830
|
"""
|
|
807
831
|
graph = self.graph
|
|
@@ -814,17 +838,42 @@ class LZGraphBase(
|
|
|
814
838
|
# Pre-compute clean labels for all nodes
|
|
815
839
|
clean_labels = np.array([self.extract_subpattern(name) for name in nodes], dtype=object)
|
|
816
840
|
|
|
817
|
-
#
|
|
818
|
-
|
|
819
|
-
|
|
841
|
+
# Build per-node neighbor/weight arrays.
|
|
842
|
+
# node_neighbors + node_weights: used by lzpgen_distribution
|
|
843
|
+
# node_cumweights: used by Python simulate fallback (bisect)
|
|
844
|
+
# CSR flat arrays are built only when the C extension is available.
|
|
845
|
+
node_neighbors = [None] * n # list of numpy int arrays
|
|
846
|
+
node_weights = [None] * n # list of numpy float arrays
|
|
847
|
+
node_cumweights = [None] * n # list of Python float lists (for bisect)
|
|
820
848
|
for i, name in enumerate(nodes):
|
|
821
849
|
succs = list(graph.successors(name))
|
|
822
850
|
if succs:
|
|
823
851
|
ids = np.array([node_to_id[s] for s in succs], dtype=np.intp)
|
|
852
|
+
node_neighbors[i] = ids
|
|
824
853
|
wts = np.array([graph[name][s]['data'].weight for s in succs], dtype=np.float64)
|
|
825
|
-
wts /= wts.sum()
|
|
826
|
-
|
|
827
|
-
|
|
854
|
+
wts /= wts.sum()
|
|
855
|
+
node_weights[i] = wts
|
|
856
|
+
cw = np.cumsum(wts)
|
|
857
|
+
cw[-1] = 1.0 # clamp for floating point safety
|
|
858
|
+
node_cumweights[i] = cw.tolist()
|
|
859
|
+
|
|
860
|
+
# Build flat CSR arrays only if C extension is available
|
|
861
|
+
if _c_simulate_walks is not None:
|
|
862
|
+
csr_offsets = np.empty(n + 1, dtype=np.intp)
|
|
863
|
+
csr_parts_nb = []
|
|
864
|
+
csr_parts_cw = []
|
|
865
|
+
offset = 0
|
|
866
|
+
for i in range(n):
|
|
867
|
+
csr_offsets[i] = offset
|
|
868
|
+
if node_neighbors[i] is not None:
|
|
869
|
+
csr_parts_nb.append(node_neighbors[i])
|
|
870
|
+
csr_parts_cw.append(np.array(node_cumweights[i], dtype=np.float64))
|
|
871
|
+
offset += len(node_neighbors[i])
|
|
872
|
+
csr_offsets[n] = offset
|
|
873
|
+
csr_neighbors = np.concatenate(csr_parts_nb) if csr_parts_nb else np.empty(0, dtype=np.intp)
|
|
874
|
+
csr_cumweights = np.concatenate(csr_parts_cw) if csr_parts_cw else np.empty(0, dtype=np.float64)
|
|
875
|
+
else:
|
|
876
|
+
csr_offsets = csr_neighbors = csr_cumweights = None
|
|
828
877
|
|
|
829
878
|
# Stop probabilities: NaN for non-terminal nodes
|
|
830
879
|
stop_probs = np.full(n, np.nan, dtype=np.float64)
|
|
@@ -834,12 +883,19 @@ class LZGraphBase(
|
|
|
834
883
|
|
|
835
884
|
# Initial state arrays
|
|
836
885
|
init_states = list(self.initial_state_probabilities.keys())
|
|
886
|
+
if not init_states:
|
|
887
|
+
raise ValueError(
|
|
888
|
+
"Cannot simulate: graph has no initial states. "
|
|
889
|
+
"Ensure the graph was constructed with valid sequences."
|
|
890
|
+
)
|
|
837
891
|
init_probs = np.array(
|
|
838
892
|
[self.initial_state_probabilities[s] for s in init_states],
|
|
839
893
|
dtype=np.float64,
|
|
840
894
|
)
|
|
841
895
|
init_probs = init_probs / init_probs.sum() # ensure normalization
|
|
842
896
|
initial_ids = np.array([node_to_id[s] for s in init_states], dtype=np.intp)
|
|
897
|
+
initial_cumprobs = np.cumsum(init_probs)
|
|
898
|
+
initial_cumprobs[-1] = 1.0 # clamp
|
|
843
899
|
|
|
844
900
|
rng = np.random.default_rng(seed)
|
|
845
901
|
|
|
@@ -847,11 +903,16 @@ class LZGraphBase(
|
|
|
847
903
|
'node_to_id': node_to_id,
|
|
848
904
|
'id_to_node': id_to_node,
|
|
849
905
|
'clean_labels': clean_labels,
|
|
850
|
-
'
|
|
851
|
-
'
|
|
906
|
+
'node_neighbors': node_neighbors,
|
|
907
|
+
'node_weights': node_weights,
|
|
908
|
+
'node_cumweights': node_cumweights,
|
|
909
|
+
'csr_offsets': csr_offsets,
|
|
910
|
+
'csr_neighbors': csr_neighbors,
|
|
911
|
+
'csr_cumweights': csr_cumweights,
|
|
852
912
|
'stop_probs': stop_probs,
|
|
853
913
|
'initial_ids': initial_ids,
|
|
854
914
|
'initial_probs': init_probs,
|
|
915
|
+
'initial_cumprobs': initial_cumprobs,
|
|
855
916
|
'rng': rng,
|
|
856
917
|
}
|
|
857
918
|
return self._walk_cache
|
|
@@ -879,46 +940,89 @@ class LZGraphBase(
|
|
|
879
940
|
self._build_walk_cache(seed)
|
|
880
941
|
|
|
881
942
|
cache = self._walk_cache
|
|
943
|
+
clean_labels = cache['clean_labels']
|
|
944
|
+
id_to_node = cache['id_to_node']
|
|
945
|
+
|
|
946
|
+
# ── C fast path ──────────────────────────────────────────────
|
|
947
|
+
if _c_simulate_walks is not None:
|
|
948
|
+
rng_seed = seed if seed is not None else int(cache['rng'].integers(0, 2**63))
|
|
949
|
+
return _c_simulate_walks(
|
|
950
|
+
n,
|
|
951
|
+
cache['csr_offsets'],
|
|
952
|
+
cache['csr_neighbors'],
|
|
953
|
+
cache['csr_cumweights'],
|
|
954
|
+
cache['stop_probs'],
|
|
955
|
+
cache['initial_ids'],
|
|
956
|
+
cache['initial_cumprobs'],
|
|
957
|
+
rng_seed,
|
|
958
|
+
list(clean_labels),
|
|
959
|
+
return_walks,
|
|
960
|
+
list(id_to_node),
|
|
961
|
+
)
|
|
962
|
+
|
|
963
|
+
# ── Python fallback ──────────────────────────────────────────
|
|
882
964
|
rng = cache['rng']
|
|
883
965
|
initial_ids = cache['initial_ids']
|
|
884
|
-
|
|
966
|
+
initial_cumprobs = cache['initial_cumprobs']
|
|
885
967
|
stop_probs = cache['stop_probs']
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
968
|
+
node_neighbors = cache['node_neighbors']
|
|
969
|
+
node_cumweights = cache['node_cumweights']
|
|
970
|
+
|
|
971
|
+
# Pre-generate random numbers in bulk for throughput
|
|
972
|
+
buf_size = max(n * 25, 1024)
|
|
973
|
+
rand_buf = rng.random(buf_size)
|
|
974
|
+
rand_idx = 0
|
|
890
975
|
|
|
976
|
+
# Local references to avoid repeated global/attribute lookups
|
|
977
|
+
bisect = _bisect_left
|
|
978
|
+
init_cumprobs_list = initial_cumprobs.tolist()
|
|
891
979
|
results = []
|
|
980
|
+
results_append = results.append
|
|
981
|
+
|
|
892
982
|
for _ in range(n):
|
|
893
|
-
#
|
|
894
|
-
|
|
983
|
+
# Refill buffer if running low
|
|
984
|
+
if rand_idx + 50 > len(rand_buf):
|
|
985
|
+
rand_buf = rng.random(buf_size)
|
|
986
|
+
rand_idx = 0
|
|
987
|
+
|
|
988
|
+
# Pick initial state via bisect on cumulative probs
|
|
989
|
+
current = initial_ids[bisect(init_cumprobs_list, rand_buf[rand_idx])]
|
|
990
|
+
rand_idx += 1
|
|
895
991
|
parts = [clean_labels[current]]
|
|
896
992
|
walk_ids = [current] if return_walks else None
|
|
897
993
|
|
|
898
994
|
while True:
|
|
899
995
|
# Check stop condition
|
|
900
996
|
stop_p = stop_probs[current]
|
|
901
|
-
if
|
|
902
|
-
if
|
|
997
|
+
if stop_p == stop_p: # fast NaN check (NaN != NaN)
|
|
998
|
+
if rand_buf[rand_idx] < stop_p:
|
|
999
|
+
rand_idx += 1
|
|
903
1000
|
break
|
|
1001
|
+
rand_idx += 1
|
|
904
1002
|
|
|
905
1003
|
# Check for dead-end (no outgoing edges)
|
|
906
|
-
|
|
907
|
-
if
|
|
1004
|
+
nb = node_neighbors[current]
|
|
1005
|
+
if nb is None:
|
|
908
1006
|
break
|
|
909
1007
|
|
|
910
|
-
# Take a step
|
|
911
|
-
current =
|
|
1008
|
+
# Take a step via bisect on per-node cumulative weights
|
|
1009
|
+
current = nb[bisect(node_cumweights[current], rand_buf[rand_idx])]
|
|
1010
|
+
rand_idx += 1
|
|
912
1011
|
parts.append(clean_labels[current])
|
|
913
1012
|
if return_walks:
|
|
914
1013
|
walk_ids.append(current)
|
|
915
1014
|
|
|
1015
|
+
# Refill buffer if running low
|
|
1016
|
+
if rand_idx + 50 > len(rand_buf):
|
|
1017
|
+
rand_buf = rng.random(buf_size)
|
|
1018
|
+
rand_idx = 0
|
|
1019
|
+
|
|
916
1020
|
sequence = ''.join(parts)
|
|
917
1021
|
if return_walks:
|
|
918
1022
|
walk = [id_to_node[wid] for wid in walk_ids]
|
|
919
|
-
|
|
1023
|
+
results_append((walk, sequence))
|
|
920
1024
|
else:
|
|
921
|
-
|
|
1025
|
+
results_append(sequence)
|
|
922
1026
|
|
|
923
1027
|
return results
|
|
924
1028
|
|
|
@@ -126,8 +126,6 @@ class NDPLZGraph(LZGraphBase):
|
|
|
126
126
|
self._log_step("Graph constructed.", verbose)
|
|
127
127
|
|
|
128
128
|
# Normalize and derive probability dicts
|
|
129
|
-
self.length_counts = dict(self.lengths)
|
|
130
|
-
|
|
131
129
|
total_terminal = sum(self.terminal_state_counts.values())
|
|
132
130
|
self.length_probabilities = (
|
|
133
131
|
{k: v / total_terminal for k, v in self.terminal_state_counts.items()}
|
|
@@ -17,6 +17,28 @@ class GeneLogicMixin:
|
|
|
17
17
|
with probability distribution `weights`.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
+
@property
|
|
21
|
+
def observed_v_genes(self):
|
|
22
|
+
"""Unique V genes — derived from ``marginal_v_genes`` keys."""
|
|
23
|
+
mg = getattr(self, 'marginal_v_genes', None)
|
|
24
|
+
return list(mg.keys()) if mg else []
|
|
25
|
+
|
|
26
|
+
@observed_v_genes.setter
|
|
27
|
+
def observed_v_genes(self, value):
|
|
28
|
+
# Accept sets from old pickles / JSON deserialization (no-op storage)
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def observed_j_genes(self):
|
|
33
|
+
"""Unique J genes — derived from ``marginal_j_genes`` keys."""
|
|
34
|
+
mg = getattr(self, 'marginal_j_genes', None)
|
|
35
|
+
return list(mg.keys()) if mg else []
|
|
36
|
+
|
|
37
|
+
@observed_j_genes.setter
|
|
38
|
+
def observed_j_genes(self, value):
|
|
39
|
+
# Accept sets from old pickles / JSON deserialization (no-op storage)
|
|
40
|
+
pass
|
|
41
|
+
|
|
20
42
|
def _raise_genetic_mode_error(self):
|
|
21
43
|
"""
|
|
22
44
|
Raise an error if genetic mode is off but a genetic function is called.
|
|
@@ -39,10 +61,6 @@ class GeneLogicMixin:
|
|
|
39
61
|
v_list = data['v_genes']
|
|
40
62
|
j_list = data['j_genes']
|
|
41
63
|
|
|
42
|
-
# Unique sets of V and J
|
|
43
|
-
self.observed_v_genes = list(set(v_list))
|
|
44
|
-
self.observed_j_genes = list(set(j_list))
|
|
45
|
-
|
|
46
64
|
# Marginal distributions (normalized) — stored as plain dicts
|
|
47
65
|
n = len(v_list)
|
|
48
66
|
v_counts = {}
|
|
@@ -56,8 +56,8 @@ class LZPgenDistributionMixin:
|
|
|
56
56
|
initial_ids = cache['initial_ids']
|
|
57
57
|
initial_probs = cache['initial_probs']
|
|
58
58
|
stop_probs = cache['stop_probs']
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
node_neighbors = cache['node_neighbors']
|
|
60
|
+
node_weights = cache['node_weights']
|
|
61
61
|
|
|
62
62
|
# Pre-compute log values for zero per-step overhead
|
|
63
63
|
eps = _EPS
|
|
@@ -71,9 +71,9 @@ class LZPgenDistributionMixin:
|
|
|
71
71
|
|
|
72
72
|
neighbor_log_weights = [None] * n_nodes
|
|
73
73
|
for i in range(n_nodes):
|
|
74
|
-
if
|
|
74
|
+
if node_weights[i] is not None:
|
|
75
75
|
neighbor_log_weights[i] = np.log(
|
|
76
|
-
np.maximum(
|
|
76
|
+
np.maximum(node_weights[i], eps)
|
|
77
77
|
)
|
|
78
78
|
|
|
79
79
|
log_probs = np.empty(n, dtype=np.float64)
|
|
@@ -94,16 +94,16 @@ class LZPgenDistributionMixin:
|
|
|
94
94
|
break
|
|
95
95
|
|
|
96
96
|
# Dead-end check
|
|
97
|
-
|
|
98
|
-
if
|
|
97
|
+
nb = node_neighbors[current]
|
|
98
|
+
if nb is None:
|
|
99
99
|
log_p += np.log(eps)
|
|
100
100
|
break
|
|
101
101
|
|
|
102
102
|
# Take a step
|
|
103
|
-
n_nb = len(
|
|
104
|
-
step_idx = rng.choice(n_nb, p=
|
|
103
|
+
n_nb = len(nb)
|
|
104
|
+
step_idx = rng.choice(n_nb, p=node_weights[current])
|
|
105
105
|
log_p += neighbor_log_weights[current][step_idx]
|
|
106
|
-
current =
|
|
106
|
+
current = nb[step_idx]
|
|
107
107
|
|
|
108
108
|
log_probs[seq_idx] = log_p
|
|
109
109
|
|
|
@@ -32,7 +32,7 @@ class SerializationMixin:
|
|
|
32
32
|
'subpattern_individual_probability': 'node_probability',
|
|
33
33
|
'per_node_observed_frequency': 'node_outgoing_counts',
|
|
34
34
|
'length_distribution_proba': 'length_probabilities',
|
|
35
|
-
'length_distribution': '
|
|
35
|
+
'length_distribution': 'lengths',
|
|
36
36
|
'n_subpatterns': 'num_subpatterns',
|
|
37
37
|
'n_transitions': 'num_transitions',
|
|
38
38
|
'marginal_vgenes': 'marginal_v_genes',
|
|
@@ -55,6 +55,20 @@ class SerializationMixin:
|
|
|
55
55
|
'j_call': 'J',
|
|
56
56
|
}
|
|
57
57
|
|
|
58
|
+
# Transient attributes excluded from pickle (rebuilt on demand)
|
|
59
|
+
_TRANSIENT_ATTRS = frozenset({
|
|
60
|
+
'_walk_cache',
|
|
61
|
+
'_topo_order',
|
|
62
|
+
'_edges_cache',
|
|
63
|
+
'constructor_start_time',
|
|
64
|
+
'constructor_end_time',
|
|
65
|
+
})
|
|
66
|
+
|
|
67
|
+
def __getstate__(self):
|
|
68
|
+
"""Exclude transient caches from pickle to reduce file size."""
|
|
69
|
+
return {k: v for k, v in self.__dict__.items()
|
|
70
|
+
if k not in self._TRANSIENT_ATTRS}
|
|
71
|
+
|
|
58
72
|
def __setstate__(self, state):
|
|
59
73
|
"""Restore instance from pickle, migrating old attribute names and pandas types."""
|
|
60
74
|
# Migrate old attribute names to new names
|
|
@@ -62,6 +76,17 @@ class SerializationMixin:
|
|
|
62
76
|
if old_name in state and new_name not in state:
|
|
63
77
|
state[new_name] = state.pop(old_name)
|
|
64
78
|
|
|
79
|
+
# length_counts is now a property aliasing lengths — migrate old pickles
|
|
80
|
+
if 'length_counts' in state:
|
|
81
|
+
if 'lengths' not in state:
|
|
82
|
+
state['lengths'] = state.pop('length_counts')
|
|
83
|
+
else:
|
|
84
|
+
del state['length_counts']
|
|
85
|
+
|
|
86
|
+
# observed_v/j_genes are now properties — remove stored values
|
|
87
|
+
state.pop('observed_v_genes', None)
|
|
88
|
+
state.pop('observed_j_genes', None)
|
|
89
|
+
|
|
65
90
|
# Migrate 'wsif/sep' key inside terminal_state_data dicts
|
|
66
91
|
tsd = state.get('terminal_state_data')
|
|
67
92
|
if tsd is not None and isinstance(tsd, dict):
|
|
@@ -80,7 +105,7 @@ class SerializationMixin:
|
|
|
80
105
|
for attr in ('initial_state_counts', 'terminal_state_counts',
|
|
81
106
|
'initial_state_probabilities', 'length_probabilities',
|
|
82
107
|
'marginal_v_genes', 'marginal_j_genes', 'vj_probabilities',
|
|
83
|
-
'
|
|
108
|
+
'lengths'):
|
|
84
109
|
val = getattr(self, attr, None)
|
|
85
110
|
if val is not None and hasattr(val, 'to_dict'):
|
|
86
111
|
setattr(self, attr, val.to_dict())
|
|
@@ -324,12 +349,7 @@ class SerializationMixin:
|
|
|
324
349
|
data['marginal_j_genes'] = _to_dict(self.marginal_j_genes)
|
|
325
350
|
if hasattr(self, 'vj_probabilities'):
|
|
326
351
|
data['vj_probabilities'] = _to_dict(self.vj_probabilities)
|
|
327
|
-
|
|
328
|
-
data['length_counts'] = _to_dict(self.length_counts)
|
|
329
|
-
if hasattr(self, 'observed_v_genes'):
|
|
330
|
-
data['observed_v_genes'] = list(self.observed_v_genes)
|
|
331
|
-
if hasattr(self, 'observed_j_genes'):
|
|
332
|
-
data['observed_j_genes'] = list(self.observed_j_genes)
|
|
352
|
+
data['length_counts'] = _to_dict(self.lengths)
|
|
333
353
|
|
|
334
354
|
# Terminal state data
|
|
335
355
|
if hasattr(self, 'terminal_state_data'):
|
|
@@ -421,7 +441,9 @@ class SerializationMixin:
|
|
|
421
441
|
instance.terminal_state_counts = _to_plain_dict(
|
|
422
442
|
data.get('terminal_state_counts', data.get('terminal_states'))
|
|
423
443
|
)
|
|
424
|
-
instance.lengths = data.get('lengths',
|
|
444
|
+
instance.lengths = data.get('lengths',
|
|
445
|
+
data.get('length_counts',
|
|
446
|
+
data.get('length_distribution', {})))
|
|
425
447
|
instance.vj_combination_graphs = {}
|
|
426
448
|
instance.num_neighbours = {}
|
|
427
449
|
instance.node_outgoing_counts = data.get('node_outgoing_counts',
|
|
@@ -451,15 +473,6 @@ class SerializationMixin:
|
|
|
451
473
|
instance.marginal_j_genes = _to_plain_dict(mg_j)
|
|
452
474
|
if 'vj_probabilities' in data:
|
|
453
475
|
instance.vj_probabilities = _to_plain_dict(data['vj_probabilities'])
|
|
454
|
-
lc = data.get('length_counts', data.get('length_distribution'))
|
|
455
|
-
if lc is not None:
|
|
456
|
-
instance.length_counts = _to_plain_dict(lc)
|
|
457
|
-
ov = data.get('observed_v_genes', data.get('observed_vgenes'))
|
|
458
|
-
if ov is not None:
|
|
459
|
-
instance.observed_v_genes = set(ov)
|
|
460
|
-
oj = data.get('observed_j_genes', data.get('observed_jgenes'))
|
|
461
|
-
if oj is not None:
|
|
462
|
-
instance.observed_j_genes = set(oj)
|
|
463
476
|
|
|
464
477
|
# Restore terminal state data
|
|
465
478
|
if 'terminal_state_data' in data:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: LZGraphs
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
|
|
5
5
|
Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
|
|
6
6
|
Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
|
|
@@ -58,8 +58,13 @@ class TestLZPgenDistributionBasic:
|
|
|
58
58
|
result = aap_lzgraph.lzpgen_distribution(n=0, seed=42)
|
|
59
59
|
assert len(result) == 0
|
|
60
60
|
|
|
61
|
-
def test_consistent_with_walk_log_probability(self, aap_lzgraph):
|
|
61
|
+
def test_consistent_with_walk_log_probability(self, aap_lzgraph, monkeypatch):
|
|
62
62
|
"""Values should match walk_log_probability on the same walks."""
|
|
63
|
+
# Force Python fallback so simulate() and lzpgen_distribution()
|
|
64
|
+
# use the same numpy RNG and produce identical walks for the same seed.
|
|
65
|
+
import LZGraphs.graphs.lz_graph_base as _base
|
|
66
|
+
monkeypatch.setattr(_base, '_c_simulate_walks', None)
|
|
67
|
+
|
|
63
68
|
walks_and_seqs = aap_lzgraph.simulate(20, seed=42, return_walks=True)
|
|
64
69
|
dist = aap_lzgraph.lzpgen_distribution(n=20, seed=42)
|
|
65
70
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|