LZGraphs 2.2.0__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/PKG-INFO +1 -1
  2. lzgraphs-2.3.0/setup.py +40 -0
  3. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/__init__.py +1 -1
  4. lzgraphs-2.3.0/src/LZGraphs/_fast_walk.c +321 -0
  5. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/amino_acid_positional.py +0 -2
  6. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/graph_operations.py +4 -13
  7. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/lz_graph_base.py +128 -24
  8. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/nucleotide_double_positional.py +0 -2
  9. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/gene_logic.py +22 -4
  10. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/lzpgen_distribution.py +9 -9
  11. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/serialization.py +31 -18
  12. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/PKG-INFO +1 -1
  13. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/SOURCES.txt +2 -0
  14. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_lzpgen_distribution.py +6 -1
  15. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/CHANGELOG.md +0 -0
  16. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/CONTRIBUTING.md +0 -0
  17. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/LICENSE +0 -0
  18. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/MANIFEST.in +0 -0
  19. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/README.md +0 -0
  20. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/pyproject.toml +0 -0
  21. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/requirements.txt +0 -0
  22. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/setup.cfg +0 -0
  23. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/bag_of_words/__init__.py +0 -0
  24. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/bag_of_words/bow_encoder.py +0 -0
  25. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/constants.py +0 -0
  26. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/exceptions/__init__.py +0 -0
  27. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/__init__.py +0 -0
  28. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/edge_data.py +0 -0
  29. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/graphs/naive.py +0 -0
  30. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/__init__.py +0 -0
  31. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/convenience.py +0 -0
  32. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/diversity.py +0 -0
  33. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/entropy.py +0 -0
  34. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/pgen_distribution.py +0 -0
  35. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/metrics/saturation.py +0 -0
  36. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/__init__.py +0 -0
  37. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/bayesian_posterior.py +0 -0
  38. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/gene_prediction.py +0 -0
  39. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/graph_topology.py +0 -0
  40. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/random_walk.py +0 -0
  41. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/mixins/walk_analysis.py +0 -0
  42. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/py.typed +0 -0
  43. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/__init__.py +0 -0
  44. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/decomposition.py +0 -0
  45. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/helpers.py +0 -0
  46. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/utilities/misc.py +0 -0
  47. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/visualization/__init__.py +0 -0
  48. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs/visualization/visualize.py +0 -0
  49. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/dependency_links.txt +0 -0
  50. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/requires.txt +0 -0
  51. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/src/LZGraphs.egg-info/top_level.txt +0 -0
  52. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_aap_lzgraph.py +0 -0
  53. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_abundance.py +0 -0
  54. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_analytical_distribution.py +0 -0
  55. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_base_class_methods.py +0 -0
  56. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_bow_encoder.py +0 -0
  57. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_diversity_theory.py +0 -0
  58. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_flexible_input.py +0 -0
  59. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_graph_operations.py +0 -0
  60. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_metrics.py +0 -0
  61. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_naive_lzgraph.py +0 -0
  62. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_ndp_lzgraph.py +0 -0
  63. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_new_features.py +0 -0
  64. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_pgen_fixes.py +0 -0
  65. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_serialization.py +0 -0
  66. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_simulate.py +0 -0
  67. {lzgraphs-2.2.0 → lzgraphs-2.3.0}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: LZGraphs
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
5
5
  Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
6
6
  Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
@@ -0,0 +1,40 @@
"""
Build script for optional C extensions.

The _fast_walk extension accelerates LZGraph.simulate() by ~50-100x.
If compilation fails (no C compiler), the package still installs and
falls back to the pure-Python implementation automatically.
"""

import os
import sys

from setuptools import Extension, setup

# Ensure setuptools can resolve the dynamic version (attr = "LZGraphs.__version__")
# when running in an isolated build environment where src/ isn't on sys.path.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))

# ``optional=True`` is what actually provides the pure-Python fallback:
# build_ext downgrades compile/link failures of an optional extension to a
# warning ("building extension ... failed") and continues the build.
# Wrapping ``setup()`` in ``try/except Exception`` does NOT work, because
# setuptools' ``setup()`` turns CCompilerError / DistutilsError into
# ``SystemExit``, which is not an ``Exception`` subclass.
ext_modules = [
    Extension(
        "LZGraphs._fast_walk",
        sources=[os.path.join("src", "LZGraphs", "_fast_walk.c")],
        # No external library dependencies — pure C + Python.h
        optional=True,
    ),
]


def run_setup(extensions):
    """Run setuptools' ``setup()`` with the given list of extension modules."""
    setup(ext_modules=extensions)


run_setup(ext_modules)
@@ -1,4 +1,4 @@
1
- __version__ = "2.2.0"
1
+ __version__ = "2.3.0"
2
2
 
3
3
  # =============================================================================
4
4
  # Graph classes
@@ -0,0 +1,321 @@
1
+ /*
2
+ * _fast_walk.c — CPython C extension for fast Markov chain random walks.
3
+ *
4
+ * Implements the full simulate() loop in C including string assembly,
5
+ * for ~100-200x speedup over the original pure-Python implementation.
6
+ * Uses xoshiro256++ for fast, high-quality RNG.
7
+ *
8
+ * The extension is optional: if it fails to compile (no C compiler),
9
+ * LZGraphs falls back to the pure-Python bisect-based implementation.
10
+ */
11
+
12
+ #define PY_SSIZE_T_CLEAN
13
+ #include <Python.h>
14
+ #include <stdint.h>
15
+ #include <string.h>
16
+
17
/* ========================================================================
 * xoshiro256++ pseudo-random generator (public-domain algorithm by
 * David Blackman and Sebastiano Vigna), seeded through splitmix64.
 * ======================================================================== */

/* Rotate the 64-bit word `x` left by `k` bits (callers pass 0 < k < 64). */
static inline uint64_t rotl(const uint64_t x, int k) {
    const uint64_t high = x << k;
    const uint64_t low = x >> (64 - k);
    return high | low;
}

/* Generator state: four 64-bit words. An all-zero state would only ever
 * produce zeros, which the splitmix64 seeding below makes practically
 * unreachable. */
typedef struct {
    uint64_t s[4];
} xoshiro256_state;

/* Advance the generator one step and return its 64-bit output.
 * The new state words are written directly in closed form from the old
 * ones (a, b, c, d); this is algebraically the same sequence of xor/shift
 * updates as the reference implementation. */
static inline uint64_t xoshiro256pp_next(xoshiro256_state *state) {
    uint64_t *const w = state->s;
    const uint64_t a = w[0];
    const uint64_t b = w[1];
    const uint64_t c = w[2];
    const uint64_t d = w[3];
    const uint64_t out = rotl(a + d, 23) + a;

    w[0] = a ^ b ^ d;
    w[1] = a ^ b ^ c;
    w[2] = a ^ c ^ (b << 17);
    w[3] = rotl(b ^ d, 45);
    return out;
}

/* Uniform double in [0, 1): the top 53 bits of one output scaled by 2^-53. */
static inline double xoshiro256pp_double(xoshiro256_state *state) {
    const double scale = 0x1.0p-53;
    const uint64_t top53 = xoshiro256pp_next(state) >> 11;
    return (double)top53 * scale;
}

/* splitmix64: advance *x by the golden-ratio increment and return a
 * mixed 64-bit value. Used only to expand one seed into four state words. */
static inline uint64_t splitmix64(uint64_t *x) {
    uint64_t z;

    *x += 0x9e3779b97f4a7c15ULL;
    z = *x;
    z ^= z >> 30;
    z *= 0xbf58476d1ce4e5b9ULL;
    z ^= z >> 27;
    z *= 0x94d049bb133111ebULL;
    z ^= z >> 31;
    return z;
}

/* Fill all four state words from successive splitmix64 outputs of `seed`. */
static void seed_xoshiro256(xoshiro256_state *state, uint64_t seed) {
    for (int word = 0; word < 4; word++)
        state->s[word] = splitmix64(&seed);
}
58
+
59
+ /* ========================================================================
60
+ * Binary search (bisect_left) on a double array
61
+ * ======================================================================== */
62
+
63
+ static inline Py_ssize_t bisect_left_double(
64
+ const double *arr, Py_ssize_t n, double value
65
+ ) {
66
+ Py_ssize_t lo = 0, hi = n;
67
+ while (lo < hi) {
68
+ Py_ssize_t mid = lo + (hi - lo) / 2;
69
+ if (arr[mid] < value)
70
+ lo = mid + 1;
71
+ else
72
+ hi = mid;
73
+ }
74
+ return lo;
75
+ }
76
+
77
+ /* ========================================================================
78
+ * simulate_walks — full simulation with string assembly in C
79
+ *
80
+ * Args:
81
+ * n_walks : int
82
+ * offsets : intp array [n_nodes+1] (buffer)
83
+ * neighbors : intp array [total_edges] (buffer)
84
+ * cumweights : float64 array [total_edges] (buffer)
85
+ * stop_probs : float64 array [n_nodes] (buffer)
86
+ * initial_ids : intp array [n_initial] (buffer)
87
+ * initial_cw : float64 array [n_initial] (buffer)
88
+ * seed : uint64
89
+ * clean_labels : list[str] — label for each node ID
90
+ * return_walks : bool — if True, return (walk, seq) tuples
91
+ * id_to_node : list[str] — node names (only used if return_walks)
92
+ *
93
+ * Returns:
94
+ * list[str] or list[tuple[list[str], str]]
95
+ * ======================================================================== */
96
+
97
+ static PyObject* py_simulate_walks(PyObject *self, PyObject *args) {
98
+ int n_walks, return_walks;
99
+ Py_buffer offsets_buf, neighbors_buf, cumweights_buf;
100
+ Py_buffer stop_probs_buf, initial_ids_buf, initial_cw_buf;
101
+ unsigned long long seed;
102
+ PyObject *clean_labels; /* Python list of str */
103
+ PyObject *id_to_node; /* Python list of str */
104
+ PyObject *result_list = NULL;
105
+
106
+ if (!PyArg_ParseTuple(args, "iy*y*y*y*y*y*KOpO",
107
+ &n_walks,
108
+ &offsets_buf, &neighbors_buf, &cumweights_buf,
109
+ &stop_probs_buf, &initial_ids_buf, &initial_cw_buf,
110
+ &seed,
111
+ &clean_labels,
112
+ &return_walks,
113
+ &id_to_node))
114
+ return NULL;
115
+
116
+ const Py_ssize_t *offsets = (const Py_ssize_t *)offsets_buf.buf;
117
+ const Py_ssize_t *neighbors = (const Py_ssize_t *)neighbors_buf.buf;
118
+ const double *cumweights = (const double *)cumweights_buf.buf;
119
+ const double *stop_probs = (const double *)stop_probs_buf.buf;
120
+ const Py_ssize_t *initial_ids = (const Py_ssize_t *)initial_ids_buf.buf;
121
+ const double *initial_cw = (const double *)initial_cw_buf.buf;
122
+ const Py_ssize_t n_initial = initial_cw_buf.len / (Py_ssize_t)sizeof(double);
123
+
124
+ if (n_initial <= 0) {
125
+ PyErr_SetString(PyExc_ValueError,
126
+ "Cannot simulate: graph has no initial states.");
127
+ goto cleanup;
128
+ }
129
+
130
+ /* Pre-fetch label UTF-8 data for fast string assembly */
131
+ const Py_ssize_t n_labels = PyList_GET_SIZE(clean_labels);
132
+ const char **label_ptrs = (const char **)PyMem_Malloc(n_labels * sizeof(char *));
133
+ Py_ssize_t *label_lens = (Py_ssize_t *)PyMem_Malloc(n_labels * sizeof(Py_ssize_t));
134
+ if (!label_ptrs || !label_lens) {
135
+ PyMem_Free(label_ptrs);
136
+ PyMem_Free(label_lens);
137
+ PyErr_NoMemory();
138
+ goto cleanup;
139
+ }
140
+ for (Py_ssize_t i = 0; i < n_labels; i++) {
141
+ PyObject *s = PyList_GET_ITEM(clean_labels, i);
142
+ label_ptrs[i] = PyUnicode_AsUTF8AndSize(s, &label_lens[i]);
143
+ if (!label_ptrs[i]) {
144
+ PyMem_Free(label_ptrs);
145
+ PyMem_Free(label_lens);
146
+ goto cleanup;
147
+ }
148
+ }
149
+
150
+ xoshiro256_state rng;
151
+ seed_xoshiro256(&rng, (uint64_t)seed);
152
+
153
+ result_list = PyList_New(n_walks);
154
+ if (!result_list) {
155
+ PyMem_Free(label_ptrs);
156
+ PyMem_Free(label_lens);
157
+ goto cleanup;
158
+ }
159
+
160
+ /* Reusable walk buffer */
161
+ Py_ssize_t walk_cap = 64;
162
+ Py_ssize_t *walk_buf = (Py_ssize_t *)PyMem_Malloc(walk_cap * sizeof(Py_ssize_t));
163
+ /* Reusable string buffer */
164
+ Py_ssize_t str_cap = 256;
165
+ char *str_buf = (char *)PyMem_Malloc(str_cap);
166
+ if (!walk_buf || !str_buf) {
167
+ PyMem_Free(walk_buf);
168
+ PyMem_Free(str_buf);
169
+ PyMem_Free(label_ptrs);
170
+ PyMem_Free(label_lens);
171
+ Py_DECREF(result_list);
172
+ PyErr_NoMemory();
173
+ goto cleanup;
174
+ }
175
+
176
+ for (int i = 0; i < n_walks; i++) {
177
+ /* Pick initial state */
178
+ double r = xoshiro256pp_double(&rng);
179
+ Py_ssize_t init_idx = bisect_left_double(initial_cw, n_initial, r);
180
+ if (init_idx >= n_initial) init_idx = n_initial - 1;
181
+ Py_ssize_t current = initial_ids[init_idx];
182
+
183
+ Py_ssize_t walk_len = 0;
184
+ walk_buf[walk_len++] = current;
185
+
186
+ /* Build string incrementally */
187
+ Py_ssize_t str_len = 0;
188
+ Py_ssize_t llen = label_lens[current];
189
+ if (str_len + llen > str_cap) {
190
+ str_cap = (str_len + llen) * 2;
191
+ str_buf = (char *)PyMem_Realloc(str_buf, str_cap);
192
+ if (!str_buf) goto oom;
193
+ }
194
+ memcpy(str_buf + str_len, label_ptrs[current], llen);
195
+ str_len += llen;
196
+
197
+ while (1) {
198
+ double sp = stop_probs[current];
199
+ if (sp == sp) {
200
+ if (xoshiro256pp_double(&rng) < sp)
201
+ break;
202
+ }
203
+
204
+ Py_ssize_t start = offsets[current];
205
+ Py_ssize_t end = offsets[current + 1];
206
+ if (start == end)
207
+ break;
208
+
209
+ r = xoshiro256pp_double(&rng);
210
+ Py_ssize_t idx = bisect_left_double(cumweights + start, end - start, r);
211
+ if (idx >= end - start) idx = end - start - 1;
212
+ current = neighbors[start + idx];
213
+
214
+ /* Grow walk buffer if needed */
215
+ if (walk_len >= walk_cap) {
216
+ walk_cap *= 2;
217
+ Py_ssize_t *new_buf = (Py_ssize_t *)PyMem_Realloc(walk_buf, walk_cap * sizeof(Py_ssize_t));
218
+ if (!new_buf) goto oom;
219
+ walk_buf = new_buf;
220
+ }
221
+ walk_buf[walk_len++] = current;
222
+
223
+ /* Append label to string buffer */
224
+ llen = label_lens[current];
225
+ if (str_len + llen > str_cap) {
226
+ str_cap = (str_len + llen) * 2;
227
+ char *new_str = (char *)PyMem_Realloc(str_buf, str_cap);
228
+ if (!new_str) goto oom;
229
+ str_buf = new_str;
230
+ }
231
+ memcpy(str_buf + str_len, label_ptrs[current], llen);
232
+ str_len += llen;
233
+ }
234
+
235
+ /* Create Python string from buffer */
236
+ PyObject *seq = PyUnicode_FromStringAndSize(str_buf, str_len);
237
+ if (!seq) goto oom;
238
+
239
+ if (return_walks) {
240
+ /* Build walk list of node name strings */
241
+ PyObject *walk = PyList_New(walk_len);
242
+ if (!walk) { Py_DECREF(seq); goto oom; }
243
+ for (Py_ssize_t j = 0; j < walk_len; j++) {
244
+ PyObject *node_name = PyList_GET_ITEM(id_to_node, walk_buf[j]);
245
+ Py_INCREF(node_name);
246
+ PyList_SET_ITEM(walk, j, node_name);
247
+ }
248
+ PyObject *tup = PyTuple_Pack(2, walk, seq);
249
+ Py_DECREF(walk);
250
+ Py_DECREF(seq);
251
+ if (!tup) goto oom;
252
+ PyList_SET_ITEM(result_list, i, tup);
253
+ } else {
254
+ PyList_SET_ITEM(result_list, i, seq);
255
+ }
256
+ }
257
+
258
+ PyMem_Free(walk_buf);
259
+ PyMem_Free(str_buf);
260
+ PyMem_Free(label_ptrs);
261
+ PyMem_Free(label_lens);
262
+ goto cleanup;
263
+
264
+ oom:
265
+ PyMem_Free(walk_buf);
266
+ PyMem_Free(str_buf);
267
+ PyMem_Free(label_ptrs);
268
+ PyMem_Free(label_lens);
269
+ Py_XDECREF(result_list);
270
+ result_list = NULL;
271
+ if (!PyErr_Occurred())
272
+ PyErr_NoMemory();
273
+
274
+ cleanup:
275
+ PyBuffer_Release(&offsets_buf);
276
+ PyBuffer_Release(&neighbors_buf);
277
+ PyBuffer_Release(&cumweights_buf);
278
+ PyBuffer_Release(&stop_probs_buf);
279
+ PyBuffer_Release(&initial_ids_buf);
280
+ PyBuffer_Release(&initial_cw_buf);
281
+
282
+ return result_list;
283
+ }
284
+
285
+ /* ========================================================================
286
+ * Module definition
287
+ * ======================================================================== */
288
+
289
+ static PyMethodDef FastWalkMethods[] = {
290
+ {"simulate_walks", py_simulate_walks, METH_VARARGS,
291
+ "Run n random walks on a CSR-encoded graph with string assembly.\n\n"
292
+ "Args:\n"
293
+ " n_walks (int): Number of walks.\n"
294
+ " offsets (array): CSR row offsets [n_nodes+1], dtype=intp.\n"
295
+ " neighbors (array): Flat neighbor IDs, dtype=intp.\n"
296
+ " cumweights (array): Flat cumulative weights, dtype=float64.\n"
297
+ " stop_probs (array): Per-node stop probability (NaN=none), dtype=float64.\n"
298
+ " initial_ids (array): Initial state IDs, dtype=intp.\n"
299
+ " initial_cumprobs (array): Cumulative initial probs, dtype=float64.\n"
300
+ " seed (int): RNG seed (xoshiro256++).\n"
301
+ " clean_labels (list[str]): Subpattern label for each node.\n"
302
+ " return_walks (bool): If True, return (walk, seq) tuples.\n"
303
+ " id_to_node (list[str]): Node names for walk output.\n\n"
304
+ "Returns:\n"
305
+ " list[str] or list[tuple[list[str], str]]\n"},
306
+ {NULL, NULL, 0, NULL}
307
+ };
308
+
309
+ static struct PyModuleDef fast_walk_module = {
310
+ PyModuleDef_HEAD_INIT,
311
+ "_fast_walk",
312
+ "C-accelerated random walk simulation for LZGraphs.\n"
313
+ "Uses xoshiro256++ RNG for high-quality, fast random number generation.\n"
314
+ "This module is optional — LZGraphs falls back to pure Python if unavailable.",
315
+ -1,
316
+ FastWalkMethods
317
+ };
318
+
319
+ PyMODINIT_FUNC PyInit__fast_walk(void) {
320
+ return PyModule_Create(&fast_walk_module);
321
+ }
@@ -141,8 +141,6 @@ class AAPLZGraph(LZGraphBase):
141
141
  self._log_step("Graph constructed.", verbose)
142
142
 
143
143
  # Normalize and derive probability dicts
144
- self.length_counts = dict(self.lengths)
145
-
146
144
  total_terminal = sum(self.terminal_state_counts.values())
147
145
  self.length_probabilities = (
148
146
  {k: v / total_terminal for k, v in self.terminal_state_counts.items()}
@@ -91,19 +91,10 @@ def graph_union(graphA, graphB):
91
91
  }
92
92
 
93
93
  # Merge length_distribution counts
94
- if hasattr(graphA, 'length_counts') and hasattr(graphB, 'length_counts'):
95
- for k, v in graphB.length_counts.items():
96
- graphA.length_counts[k] = graphA.length_counts.get(k, 0) + v
97
-
98
- # Merge observed gene sets
99
- if hasattr(graphB, 'observed_v_genes'):
100
- graphA.observed_v_genes = list(
101
- set(graphA.observed_v_genes) | set(graphB.observed_v_genes)
102
- )
103
- if hasattr(graphB, 'observed_j_genes'):
104
- graphA.observed_j_genes = list(
105
- set(graphA.observed_j_genes) | set(graphB.observed_j_genes)
106
- )
94
+ for k, v in graphB.lengths.items():
95
+ graphA.lengths[k] = graphA.lengths.get(k, 0) + v
96
+
97
+ # observed_v/j_genes are now derived from marginal_v/j_genes (already merged above)
107
98
 
108
99
  # 5. Recalculate ALL derived state from raw counts
109
100
  graphA.recalculate()
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  import re
3
3
  from abc import ABC, abstractmethod
4
+ from bisect import bisect_left as _bisect_left
4
5
  from time import time
5
6
 
6
7
  import networkx as nx
@@ -12,6 +13,12 @@ from ..utilities.misc import choice, window
12
13
  # Shared constants
13
14
  from ..constants import _EPS, _LOG_EPS
14
15
 
16
+ # Optional C extension for fast simulation
17
+ try:
18
+ from .._fast_walk import simulate_walks as _c_simulate_walks
19
+ except ImportError:
20
+ _c_simulate_walks = None
21
+
15
22
  # EdgeData
16
23
  from .edge_data import EdgeData
17
24
 
@@ -106,6 +113,19 @@ class LZGraphBase(
106
113
  # Topological order cache (built lazily, invalidated on structural changes)
107
114
  self._topo_order = None
108
115
 
116
+ # ------------------------------------------------------------------
117
+ # Derived properties (single source of truth)
118
+ # ------------------------------------------------------------------
119
+
120
+ @property
121
+ def length_counts(self):
122
+ """Alias for ``lengths`` — avoids storing the same dict twice."""
123
+ return self.lengths
124
+
125
+ @length_counts.setter
126
+ def length_counts(self, value):
127
+ self.lengths = value
128
+
109
129
  @staticmethod
110
130
  def _normalize_input(data, seq_column, abundances=None, v_genes=None, j_genes=None):
111
131
  """Convert flexible input to a standardised dict-of-lists.
@@ -242,7 +262,7 @@ class LZGraphBase(
242
262
  if self.has_gene_data and other.has_gene_data:
243
263
  aux += not _dicts_close(self.marginal_v_genes, other.marginal_v_genes, decimals=3)
244
264
  aux += not _dicts_close(self.vj_probabilities, other.vj_probabilities, decimals=3)
245
- aux += not _dicts_close(self.length_counts, other.length_counts, decimals=3)
265
+ aux += not _dicts_close(self.lengths, other.lengths, decimals=3)
246
266
 
247
267
  return (aux == 0)
248
268
 
@@ -802,6 +822,10 @@ class LZGraphBase(
802
822
  def _build_walk_cache(self, seed=None):
803
823
  """Build pre-computed numpy arrays for fast random walks.
804
824
 
825
+ Uses CSR (Compressed Sparse Row) format for neighbor data so
826
+ the entire walk can be driven by flat numpy arrays and
827
+ ``searchsorted`` instead of per-step ``rng.choice()`` calls.
828
+
805
829
  Returns a dict with the cache data, stored as ``self._walk_cache``.
806
830
  """
807
831
  graph = self.graph
@@ -814,17 +838,42 @@ class LZGraphBase(
814
838
  # Pre-compute clean labels for all nodes
815
839
  clean_labels = np.array([self.extract_subpattern(name) for name in nodes], dtype=object)
816
840
 
817
- # Per-node neighbor IDs and weights
818
- neighbor_ids = [None] * n
819
- neighbor_weights = [None] * n
841
+ # Build per-node neighbor/weight arrays.
842
+ # node_neighbors + node_weights: used by lzpgen_distribution
843
+ # node_cumweights: used by Python simulate fallback (bisect)
844
+ # CSR flat arrays are built only when the C extension is available.
845
+ node_neighbors = [None] * n # list of numpy int arrays
846
+ node_weights = [None] * n # list of numpy float arrays
847
+ node_cumweights = [None] * n # list of Python float lists (for bisect)
820
848
  for i, name in enumerate(nodes):
821
849
  succs = list(graph.successors(name))
822
850
  if succs:
823
851
  ids = np.array([node_to_id[s] for s in succs], dtype=np.intp)
852
+ node_neighbors[i] = ids
824
853
  wts = np.array([graph[name][s]['data'].weight for s in succs], dtype=np.float64)
825
- wts /= wts.sum() # ensure normalization
826
- neighbor_ids[i] = ids
827
- neighbor_weights[i] = wts
854
+ wts /= wts.sum()
855
+ node_weights[i] = wts
856
+ cw = np.cumsum(wts)
857
+ cw[-1] = 1.0 # clamp for floating point safety
858
+ node_cumweights[i] = cw.tolist()
859
+
860
+ # Build flat CSR arrays only if C extension is available
861
+ if _c_simulate_walks is not None:
862
+ csr_offsets = np.empty(n + 1, dtype=np.intp)
863
+ csr_parts_nb = []
864
+ csr_parts_cw = []
865
+ offset = 0
866
+ for i in range(n):
867
+ csr_offsets[i] = offset
868
+ if node_neighbors[i] is not None:
869
+ csr_parts_nb.append(node_neighbors[i])
870
+ csr_parts_cw.append(np.array(node_cumweights[i], dtype=np.float64))
871
+ offset += len(node_neighbors[i])
872
+ csr_offsets[n] = offset
873
+ csr_neighbors = np.concatenate(csr_parts_nb) if csr_parts_nb else np.empty(0, dtype=np.intp)
874
+ csr_cumweights = np.concatenate(csr_parts_cw) if csr_parts_cw else np.empty(0, dtype=np.float64)
875
+ else:
876
+ csr_offsets = csr_neighbors = csr_cumweights = None
828
877
 
829
878
  # Stop probabilities: NaN for non-terminal nodes
830
879
  stop_probs = np.full(n, np.nan, dtype=np.float64)
@@ -834,12 +883,19 @@ class LZGraphBase(
834
883
 
835
884
  # Initial state arrays
836
885
  init_states = list(self.initial_state_probabilities.keys())
886
+ if not init_states:
887
+ raise ValueError(
888
+ "Cannot simulate: graph has no initial states. "
889
+ "Ensure the graph was constructed with valid sequences."
890
+ )
837
891
  init_probs = np.array(
838
892
  [self.initial_state_probabilities[s] for s in init_states],
839
893
  dtype=np.float64,
840
894
  )
841
895
  init_probs = init_probs / init_probs.sum() # ensure normalization
842
896
  initial_ids = np.array([node_to_id[s] for s in init_states], dtype=np.intp)
897
+ initial_cumprobs = np.cumsum(init_probs)
898
+ initial_cumprobs[-1] = 1.0 # clamp
843
899
 
844
900
  rng = np.random.default_rng(seed)
845
901
 
@@ -847,11 +903,16 @@ class LZGraphBase(
847
903
  'node_to_id': node_to_id,
848
904
  'id_to_node': id_to_node,
849
905
  'clean_labels': clean_labels,
850
- 'neighbor_ids': neighbor_ids,
851
- 'neighbor_weights': neighbor_weights,
906
+ 'node_neighbors': node_neighbors,
907
+ 'node_weights': node_weights,
908
+ 'node_cumweights': node_cumweights,
909
+ 'csr_offsets': csr_offsets,
910
+ 'csr_neighbors': csr_neighbors,
911
+ 'csr_cumweights': csr_cumweights,
852
912
  'stop_probs': stop_probs,
853
913
  'initial_ids': initial_ids,
854
914
  'initial_probs': init_probs,
915
+ 'initial_cumprobs': initial_cumprobs,
855
916
  'rng': rng,
856
917
  }
857
918
  return self._walk_cache
@@ -879,46 +940,89 @@ class LZGraphBase(
879
940
  self._build_walk_cache(seed)
880
941
 
881
942
  cache = self._walk_cache
943
+ clean_labels = cache['clean_labels']
944
+ id_to_node = cache['id_to_node']
945
+
946
+ # ── C fast path ──────────────────────────────────────────────
947
+ if _c_simulate_walks is not None:
948
+ rng_seed = seed if seed is not None else int(cache['rng'].integers(0, 2**63))
949
+ return _c_simulate_walks(
950
+ n,
951
+ cache['csr_offsets'],
952
+ cache['csr_neighbors'],
953
+ cache['csr_cumweights'],
954
+ cache['stop_probs'],
955
+ cache['initial_ids'],
956
+ cache['initial_cumprobs'],
957
+ rng_seed,
958
+ list(clean_labels),
959
+ return_walks,
960
+ list(id_to_node),
961
+ )
962
+
963
+ # ── Python fallback ──────────────────────────────────────────
882
964
  rng = cache['rng']
883
965
  initial_ids = cache['initial_ids']
884
- initial_probs = cache['initial_probs']
966
+ initial_cumprobs = cache['initial_cumprobs']
885
967
  stop_probs = cache['stop_probs']
886
- neighbor_ids = cache['neighbor_ids']
887
- neighbor_weights = cache['neighbor_weights']
888
- clean_labels = cache['clean_labels']
889
- id_to_node = cache['id_to_node']
968
+ node_neighbors = cache['node_neighbors']
969
+ node_cumweights = cache['node_cumweights']
970
+
971
+ # Pre-generate random numbers in bulk for throughput
972
+ buf_size = max(n * 25, 1024)
973
+ rand_buf = rng.random(buf_size)
974
+ rand_idx = 0
890
975
 
976
+ # Local references to avoid repeated global/attribute lookups
977
+ bisect = _bisect_left
978
+ init_cumprobs_list = initial_cumprobs.tolist()
891
979
  results = []
980
+ results_append = results.append
981
+
892
982
  for _ in range(n):
893
- # Pick initial state
894
- current = rng.choice(initial_ids, p=initial_probs)
983
+ # Refill buffer if running low
984
+ if rand_idx + 50 > len(rand_buf):
985
+ rand_buf = rng.random(buf_size)
986
+ rand_idx = 0
987
+
988
+ # Pick initial state via bisect on cumulative probs
989
+ current = initial_ids[bisect(init_cumprobs_list, rand_buf[rand_idx])]
990
+ rand_idx += 1
895
991
  parts = [clean_labels[current]]
896
992
  walk_ids = [current] if return_walks else None
897
993
 
898
994
  while True:
899
995
  # Check stop condition
900
996
  stop_p = stop_probs[current]
901
- if not np.isnan(stop_p):
902
- if rng.random() < stop_p:
997
+ if stop_p == stop_p: # fast NaN check (NaN != NaN)
998
+ if rand_buf[rand_idx] < stop_p:
999
+ rand_idx += 1
903
1000
  break
1001
+ rand_idx += 1
904
1002
 
905
1003
  # Check for dead-end (no outgoing edges)
906
- nb_ids = neighbor_ids[current]
907
- if nb_ids is None:
1004
+ nb = node_neighbors[current]
1005
+ if nb is None:
908
1006
  break
909
1007
 
910
- # Take a step
911
- current = rng.choice(nb_ids, p=neighbor_weights[current])
1008
+ # Take a step via bisect on per-node cumulative weights
1009
+ current = nb[bisect(node_cumweights[current], rand_buf[rand_idx])]
1010
+ rand_idx += 1
912
1011
  parts.append(clean_labels[current])
913
1012
  if return_walks:
914
1013
  walk_ids.append(current)
915
1014
 
1015
+ # Refill buffer if running low
1016
+ if rand_idx + 50 > len(rand_buf):
1017
+ rand_buf = rng.random(buf_size)
1018
+ rand_idx = 0
1019
+
916
1020
  sequence = ''.join(parts)
917
1021
  if return_walks:
918
1022
  walk = [id_to_node[wid] for wid in walk_ids]
919
- results.append((walk, sequence))
1023
+ results_append((walk, sequence))
920
1024
  else:
921
- results.append(sequence)
1025
+ results_append(sequence)
922
1026
 
923
1027
  return results
924
1028
 
@@ -126,8 +126,6 @@ class NDPLZGraph(LZGraphBase):
126
126
  self._log_step("Graph constructed.", verbose)
127
127
 
128
128
  # Normalize and derive probability dicts
129
- self.length_counts = dict(self.lengths)
130
-
131
129
  total_terminal = sum(self.terminal_state_counts.values())
132
130
  self.length_probabilities = (
133
131
  {k: v / total_terminal for k, v in self.terminal_state_counts.items()}
@@ -17,6 +17,28 @@ class GeneLogicMixin:
17
17
  with probability distribution `weights`.
18
18
  """
19
19
 
20
+ @property
21
+ def observed_v_genes(self):
22
+ """Unique V genes — derived from ``marginal_v_genes`` keys."""
23
+ mg = getattr(self, 'marginal_v_genes', None)
24
+ return list(mg.keys()) if mg else []
25
+
26
+ @observed_v_genes.setter
27
+ def observed_v_genes(self, value):
28
+ # Accept sets from old pickles / JSON deserialization (no-op storage)
29
+ pass
30
+
31
+ @property
32
+ def observed_j_genes(self):
33
+ """Unique J genes — derived from ``marginal_j_genes`` keys."""
34
+ mg = getattr(self, 'marginal_j_genes', None)
35
+ return list(mg.keys()) if mg else []
36
+
37
+ @observed_j_genes.setter
38
+ def observed_j_genes(self, value):
39
+ # Accept sets from old pickles / JSON deserialization (no-op storage)
40
+ pass
41
+
20
42
  def _raise_genetic_mode_error(self):
21
43
  """
22
44
  Raise an error if genetic mode is off but a genetic function is called.
@@ -39,10 +61,6 @@ class GeneLogicMixin:
39
61
  v_list = data['v_genes']
40
62
  j_list = data['j_genes']
41
63
 
42
- # Unique sets of V and J
43
- self.observed_v_genes = list(set(v_list))
44
- self.observed_j_genes = list(set(j_list))
45
-
46
64
  # Marginal distributions (normalized) — stored as plain dicts
47
65
  n = len(v_list)
48
66
  v_counts = {}
@@ -56,8 +56,8 @@ class LZPgenDistributionMixin:
56
56
  initial_ids = cache['initial_ids']
57
57
  initial_probs = cache['initial_probs']
58
58
  stop_probs = cache['stop_probs']
59
- neighbor_ids = cache['neighbor_ids']
60
- neighbor_weights = cache['neighbor_weights']
59
+ node_neighbors = cache['node_neighbors']
60
+ node_weights = cache['node_weights']
61
61
 
62
62
  # Pre-compute log values for zero per-step overhead
63
63
  eps = _EPS
@@ -71,9 +71,9 @@ class LZPgenDistributionMixin:
71
71
 
72
72
  neighbor_log_weights = [None] * n_nodes
73
73
  for i in range(n_nodes):
74
- if neighbor_weights[i] is not None:
74
+ if node_weights[i] is not None:
75
75
  neighbor_log_weights[i] = np.log(
76
- np.maximum(neighbor_weights[i], eps)
76
+ np.maximum(node_weights[i], eps)
77
77
  )
78
78
 
79
79
  log_probs = np.empty(n, dtype=np.float64)
@@ -94,16 +94,16 @@ class LZPgenDistributionMixin:
94
94
  break
95
95
 
96
96
  # Dead-end check
97
- nb_ids = neighbor_ids[current]
98
- if nb_ids is None:
97
+ nb = node_neighbors[current]
98
+ if nb is None:
99
99
  log_p += np.log(eps)
100
100
  break
101
101
 
102
102
  # Take a step
103
- n_nb = len(nb_ids)
104
- step_idx = rng.choice(n_nb, p=neighbor_weights[current])
103
+ n_nb = len(nb)
104
+ step_idx = rng.choice(n_nb, p=node_weights[current])
105
105
  log_p += neighbor_log_weights[current][step_idx]
106
- current = nb_ids[step_idx]
106
+ current = nb[step_idx]
107
107
 
108
108
  log_probs[seq_idx] = log_p
109
109
 
@@ -32,7 +32,7 @@ class SerializationMixin:
32
32
  'subpattern_individual_probability': 'node_probability',
33
33
  'per_node_observed_frequency': 'node_outgoing_counts',
34
34
  'length_distribution_proba': 'length_probabilities',
35
- 'length_distribution': 'length_counts',
35
+ 'length_distribution': 'lengths',
36
36
  'n_subpatterns': 'num_subpatterns',
37
37
  'n_transitions': 'num_transitions',
38
38
  'marginal_vgenes': 'marginal_v_genes',
@@ -55,6 +55,20 @@ class SerializationMixin:
55
55
  'j_call': 'J',
56
56
  }
57
57
 
58
+ # Transient attributes excluded from pickle (rebuilt on demand)
59
+ _TRANSIENT_ATTRS = frozenset({
60
+ '_walk_cache',
61
+ '_topo_order',
62
+ '_edges_cache',
63
+ 'constructor_start_time',
64
+ 'constructor_end_time',
65
+ })
66
+
67
+ def __getstate__(self):
68
+ """Exclude transient caches from pickle to reduce file size."""
69
+ return {k: v for k, v in self.__dict__.items()
70
+ if k not in self._TRANSIENT_ATTRS}
71
+
58
72
  def __setstate__(self, state):
59
73
  """Restore instance from pickle, migrating old attribute names and pandas types."""
60
74
  # Migrate old attribute names to new names
@@ -62,6 +76,17 @@ class SerializationMixin:
62
76
  if old_name in state and new_name not in state:
63
77
  state[new_name] = state.pop(old_name)
64
78
 
79
+ # length_counts is now a property aliasing lengths — migrate old pickles
80
+ if 'length_counts' in state:
81
+ if 'lengths' not in state:
82
+ state['lengths'] = state.pop('length_counts')
83
+ else:
84
+ del state['length_counts']
85
+
86
+ # observed_v/j_genes are now properties — remove stored values
87
+ state.pop('observed_v_genes', None)
88
+ state.pop('observed_j_genes', None)
89
+
65
90
  # Migrate 'wsif/sep' key inside terminal_state_data dicts
66
91
  tsd = state.get('terminal_state_data')
67
92
  if tsd is not None and isinstance(tsd, dict):
@@ -80,7 +105,7 @@ class SerializationMixin:
80
105
  for attr in ('initial_state_counts', 'terminal_state_counts',
81
106
  'initial_state_probabilities', 'length_probabilities',
82
107
  'marginal_v_genes', 'marginal_j_genes', 'vj_probabilities',
83
- 'length_counts'):
108
+ 'lengths'):
84
109
  val = getattr(self, attr, None)
85
110
  if val is not None and hasattr(val, 'to_dict'):
86
111
  setattr(self, attr, val.to_dict())
@@ -324,12 +349,7 @@ class SerializationMixin:
324
349
  data['marginal_j_genes'] = _to_dict(self.marginal_j_genes)
325
350
  if hasattr(self, 'vj_probabilities'):
326
351
  data['vj_probabilities'] = _to_dict(self.vj_probabilities)
327
- if hasattr(self, 'length_counts'):
328
- data['length_counts'] = _to_dict(self.length_counts)
329
- if hasattr(self, 'observed_v_genes'):
330
- data['observed_v_genes'] = list(self.observed_v_genes)
331
- if hasattr(self, 'observed_j_genes'):
332
- data['observed_j_genes'] = list(self.observed_j_genes)
352
+ data['length_counts'] = _to_dict(self.lengths)
333
353
 
334
354
  # Terminal state data
335
355
  if hasattr(self, 'terminal_state_data'):
@@ -421,7 +441,9 @@ class SerializationMixin:
421
441
  instance.terminal_state_counts = _to_plain_dict(
422
442
  data.get('terminal_state_counts', data.get('terminal_states'))
423
443
  )
424
- instance.lengths = data.get('lengths', {})
444
+ instance.lengths = data.get('lengths',
445
+ data.get('length_counts',
446
+ data.get('length_distribution', {})))
425
447
  instance.vj_combination_graphs = {}
426
448
  instance.num_neighbours = {}
427
449
  instance.node_outgoing_counts = data.get('node_outgoing_counts',
@@ -451,15 +473,6 @@ class SerializationMixin:
451
473
  instance.marginal_j_genes = _to_plain_dict(mg_j)
452
474
  if 'vj_probabilities' in data:
453
475
  instance.vj_probabilities = _to_plain_dict(data['vj_probabilities'])
454
- lc = data.get('length_counts', data.get('length_distribution'))
455
- if lc is not None:
456
- instance.length_counts = _to_plain_dict(lc)
457
- ov = data.get('observed_v_genes', data.get('observed_vgenes'))
458
- if ov is not None:
459
- instance.observed_v_genes = set(ov)
460
- oj = data.get('observed_j_genes', data.get('observed_jgenes'))
461
- if oj is not None:
462
- instance.observed_j_genes = set(oj)
463
476
 
464
477
  # Restore terminal state data
465
478
  if 'terminal_state_data' in data:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: LZGraphs
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
5
5
  Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
6
6
  Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
@@ -6,7 +6,9 @@ README.md
6
6
  pyproject.toml
7
7
  requirements.txt
8
8
  setup.cfg
9
+ setup.py
9
10
  src/LZGraphs/__init__.py
11
+ src/LZGraphs/_fast_walk.c
10
12
  src/LZGraphs/constants.py
11
13
  src/LZGraphs/py.typed
12
14
  src/LZGraphs.egg-info/PKG-INFO
@@ -58,8 +58,13 @@ class TestLZPgenDistributionBasic:
58
58
  result = aap_lzgraph.lzpgen_distribution(n=0, seed=42)
59
59
  assert len(result) == 0
60
60
 
61
- def test_consistent_with_walk_log_probability(self, aap_lzgraph):
61
+ def test_consistent_with_walk_log_probability(self, aap_lzgraph, monkeypatch):
62
62
  """Values should match walk_log_probability on the same walks."""
63
+ # Force Python fallback so simulate() and lzpgen_distribution()
64
+ # use the same numpy RNG and produce identical walks for the same seed.
65
+ import LZGraphs.graphs.lz_graph_base as _base
66
+ monkeypatch.setattr(_base, '_c_simulate_walks', None)
67
+
63
68
  walks_and_seqs = aap_lzgraph.simulate(20, seed=42, return_walks=True)
64
69
  dist = aap_lzgraph.lzpgen_distribution(n=20, seed=42)
65
70
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes