relucent 0.2.1__tar.gz

@@ -0,0 +1,83 @@
1
+ Metadata-Version: 2.4
2
+ Name: relucent
3
+ Version: 0.2.1
4
+ Summary: Explore polyhedral complexes associated with the activation states of ReLU neural networks.
5
+ Author: Blake B. Gaines
6
+ License-Expression: AGPL-3.0-or-later
7
+ Requires-Dist: pandas>=2.3
8
+ Requires-Dist: gurobipy>=12.0
9
+ Requires-Dist: networkx>=3.6
10
+ Requires-Dist: matplotlib>=3.10
11
+ Requires-Dist: numpy>=2.4
12
+ Requires-Dist: pillow>=12.1
13
+ Requires-Dist: plotly>=6.3
14
+ Requires-Dist: scikit-learn>=1.8
15
+ Requires-Dist: scipy>=1.17
16
+ Requires-Dist: tqdm>=4.67
17
+ Requires-Dist: torch>=2.13 ; extra == 'cli'
18
+ Requires-Dist: torchvision ; extra == 'cli'
19
+ Requires-Dist: pyvis>=0.3 ; extra == 'cli'
20
+ Requires-Dist: kaleido ; extra == 'cli'
21
+ Requires-Python: >=3.13, <3.14
22
+ Project-URL: Repository, https://github.com/bl-ake/relucent
23
+ Project-URL: Documentation, https://bl-ake.github.io/relucent/
24
+ Provides-Extra: cli
25
+ Description-Content-Type: text/markdown
26
+
27
+ [![Usable](https://github.com/bl-ake/relucent/actions/workflows/python-package.yml/badge.svg)](https://github.com/bl-ake/relucent/actions/workflows/python-package.yml)
28
+ [![Latest Release](https://img.shields.io/github/v/tag/bl-ake/relucent?label=Latest%20Release)](https://github.com/bl-ake/relucent/releases)
29
+
30
+ # Relucent
31
+ Explore polyhedral complexes associated with the activation states of ReLU neural networks
32
+
33
+ ## Environment Setup
34
+ 1. Install Python 3.13
35
+ 2. Install [PyTorch >= 2.3.0](https://pytorch.org/get-started/locally/) together with torchvision (relucent imports both at startup)
36
+ 3. Run `pip install relucent`
37
+
38
+ ## Getting Started
39
+ To see if the installation has been successful, try plotting the complex of a randomly initialized network in 2 dimensions like this:
40
+ ```
41
+ from relucent import Complex, get_mlp_model
42
+
43
+ network = get_mlp_model(widths=[2, 10, 5, 1])
44
+ cplx = Complex(network)
45
+ cplx.bfs()
46
+ fig = cplx.plot(bound=10000)
47
+ fig.show()
48
+ ```
49
+
50
+ The `NN` object returned by `get_mlp_model` inherits from `torch.nn.Module`, so you can train and manipulate it just like you're used to :)
51
+
52
+ Given some input point, you could get a minimal H-representation of the polyhedron containing it like this:
53
+ ```
54
+ import numpy as np
55
+
56
+ input_point = np.random.random(network.input_shape)
57
+ p = cplx.point2poly(input_point)
58
+ print(p.halfspaces[p.shis, :])
59
+ ```
60
+
61
+ You could also check the average number of facets (non-redundant bounding halfspaces) per polyhedron with:
62
+ ```
63
+ sum(len(p.shis) for p in cplx) / len(cplx)
64
+ ```
65
+ Or, get the adjacency graph of top-dimensional cells in the complex as a [NetworkX Graph](https://networkx.org/documentation/stable/tutorial.html) with:
66
+ ```
67
+ print(cplx.get_dual_graph())
68
+ ```
69
+
70
+ View the documentation for this library at https://bl-ake.github.io/relucent/
71
+
72
+ ## Source Code Structure
73
+ * [model.py](src/relucent/model.py): PyTorch Module that acts as an interface between the model and the rest of the code
74
+ * [poly.py](src/relucent/poly.py): Class for calculations involving individual polyhedra (e.g. computing boundaries, neighbors, volume)
75
+ * [complex.py](src/relucent/complex.py): Class for calculations involving the polyhedral complex (e.g. polyhedron search, connectivity graph calculation)
76
+ * [convert_model.py](src/relucent/convert_model.py): Utilities for converting various `torch.nn` layers to Linear layers
77
+ * [bvs.py](src/relucent/bvs.py): Data structures for storing large numbers of sign vectors
78
+
79
+ ## Obtaining a Gurobi License
80
+ Without a [license](https://support.gurobi.com/hc/en-us/articles/12872879801105-How-do-I-retrieve-and-set-up-a-Gurobi-license), Gurobi will only work with a limited feature set. This includes a limit on the number of decision variables in the models it can solve, which limits the size of the networks this code is able to analyze. There are multiple ways to install the software, but we recommend the following steps to those eligible for an academic license:
81
+ 1. Install the [Gurobi Python library](https://pypi.org/project/gurobipy/), for example using `pip install gurobipy`
82
+ 2. [Obtain a Gurobi license](https://support.gurobi.com/hc/en-us/articles/360040541251-How-do-I-obtain-a-free-academic-license) (Note: a WLS license will limit the number of concurrent sessions across multiple devices, which can result in slowdowns when using this library on different machines simultaneously.)
83
+ 3. In your Conda environment, run `grbgetkey` followed by your license key
@@ -0,0 +1,57 @@
1
+ [![Usable](https://github.com/bl-ake/relucent/actions/workflows/python-package.yml/badge.svg)](https://github.com/bl-ake/relucent/actions/workflows/python-package.yml)
2
+ [![Latest Release](https://img.shields.io/github/v/tag/bl-ake/relucent?label=Latest%20Release)](https://github.com/bl-ake/relucent/releases)
3
+
4
+ # Relucent
5
+ Explore polyhedral complexes associated with the activation states of ReLU neural networks
6
+
7
+ ## Environment Setup
8
+ 1. Install Python 3.13
9
+ 2. Install [PyTorch >= 2.3.0](https://pytorch.org/get-started/locally/) together with torchvision (relucent imports both at startup)
10
+ 3. Run `pip install relucent`
11
+
12
+ ## Getting Started
13
+ To see if the installation has been successful, try plotting the complex of a randomly initialized network in 2 dimensions like this:
14
+ ```
15
+ from relucent import Complex, get_mlp_model
16
+
17
+ network = get_mlp_model(widths=[2, 10, 5, 1])
18
+ cplx = Complex(network)
19
+ cplx.bfs()
20
+ fig = cplx.plot(bound=10000)
21
+ fig.show()
22
+ ```
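To make concrete what the plotted cells are: each cell collects the inputs that share one activation pattern (which hidden ReLUs are on), and the network is affine on each cell. The sketch below is plain NumPy and does not use relucent. `random_mlp`, `activation_patterns` and `forward` are hypothetical helpers, the weights are Gaussian rather than PyTorch's default initialization, and patterns are 0/1 on/off states rather than relucent's {-1, 0, 1} sign sequences.

```python
import numpy as np


def random_mlp(widths, seed=0):
    """Gaussian weights and biases for an MLP with the given layer widths."""
    rng = np.random.default_rng(seed)
    return [
        (rng.standard_normal((w_out, w_in)), rng.standard_normal(w_out))
        for w_in, w_out in zip(widths[:-1], widths[1:])
    ]


def activation_patterns(layers, X):
    """One row per input: 1 where a hidden neuron's pre-activation is positive, else 0."""
    states, H = [], X
    for W, b in layers[:-1]:  # the output layer has no ReLU
        Z = H @ W.T + b
        states.append((Z > 0).astype(np.int8))
        H = np.maximum(Z, 0)
    return np.hstack(states)


def forward(layers, X):
    """Evaluate the ReLU network on a batch of inputs."""
    H = X
    for W, b in layers[:-1]:
        H = np.maximum(H @ W.T + b, 0)
    W, b = layers[-1]
    return H @ W.T + b


# Same widths as the README example: 2 inputs, hidden layers of 10 and 5, 1 output.
layers = random_mlp([2, 10, 5, 1])
xs = np.linspace(-5, 5, 100)
gx, gy = np.meshgrid(xs, xs)
X = np.column_stack([gx.ravel(), gy.ravel()])
patterns = activation_patterns(layers, X)
distinct = np.unique(patterns, axis=0)
print(f"{len(distinct)} activation patterns occur on the grid")
```

The count is only a lower bound on the number of cells meeting the square, since small cells can fall between grid points.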
23
+
24
+ The `NN` object returned by `get_mlp_model` inherits from `torch.nn.Module`, so you can train and manipulate it just like you're used to :)
25
+
26
+ Given some input point, you could get a minimal H-representation of the polyhedron containing it like this:
27
+ ```
28
+ import numpy as np
29
+
30
+ input_point = np.random.random(network.input_shape)
31
+ p = cplx.point2poly(input_point)
32
+ print(p.halfspaces[p.shis, :])
33
+ ```
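The README does not say how relucent reduces a polyhedron to its minimal H-representation. As a library-independent sketch, assuming the convention `A @ x <= b` (check the actual layout of `Polyhedron.halfspaces` in the documentation), a row is redundant when maximizing its left-hand side over the other rows cannot exceed its right-hand side. The hypothetical `minimal_h_representation` below runs one LP per row with SciPy's `linprog` (SciPy is already a relucent dependency).

```python
import numpy as np
from scipy.optimize import linprog


def minimal_h_representation(A, b, tol=1e-7):
    """Indices of the non-redundant rows of {x : A @ x <= b} (assumed nonempty)."""
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    keep = list(range(len(b)))
    for i in range(len(b)):
        # Maximize A[i] @ x over the rows still kept, with row i loosened by 1 so
        # the LP stays bounded while x may still move past the original face.
        b_loose = b[keep].copy()
        b_loose[keep.index(i)] += 1.0
        res = linprog(-A[i], A_ub=A[keep], b_ub=b_loose,
                      bounds=(None, None), method="highs")
        # tol absorbs the solver's feasibility tolerance
        if res.status == 0 and -res.fun <= b[i] + tol:
            keep.remove(i)  # row i is implied by the rows that remain
    return keep


# Unit square, plus the redundant constraint x + y <= 3.
A = [[1, 0], [-1, 0], [0, 1], [0, -1], [1, 1]]
b = [1, 0, 1, 0, 3]
print(minimal_h_representation(A, b))  # -> [0, 1, 2, 3]
```

Rows are removed one at a time and later rows are tested only against the survivors, so a duplicated inequality keeps exactly one copy instead of both copies being discarded.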
34
+
35
+ You could also check the average number of facets (non-redundant bounding halfspaces) per polyhedron with:
36
+ ```
37
+ sum(len(p.shis) for p in cplx) / len(cplx)
38
+ ```
39
+ Or, get the adjacency graph of top-dimensional cells in the complex as a [NetworkX Graph](https://networkx.org/documentation/stable/tutorial.html) with:
40
+ ```
41
+ print(cplx.get_dual_graph())
42
+ ```
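`get_dual_graph` works on the complex that relucent has already searched. As a library-independent illustration: for a network with a single hidden layer the complex is a hyperplane arrangement, and two regions share a facet exactly when their activation patterns differ in one neuron (assuming no two neurons define the same hyperplane). The hypothetical `sampled_dual_graph` below builds that graph with NetworkX from randomly sampled inputs, so regions that no sample hits are missed. For deeper networks the one-neuron rule is not a reliable adjacency test in general.

```python
import itertools

import networkx as nx
import numpy as np


def sampled_dual_graph(W, b, n_samples=20_000, bound=5.0, seed=0):
    """Approximate dual graph of the arrangement of hyperplanes W[k] @ x + b[k] = 0.

    Nodes are the activation patterns seen at random points in [-bound, bound]^d.
    """
    W, b = np.asarray(W, dtype=float), np.asarray(b, dtype=float)
    rng = np.random.default_rng(seed)
    points = rng.uniform(-bound, bound, size=(n_samples, W.shape[1]))
    patterns = {tuple(row) for row in (points @ W.T + b > 0).tolist()}
    graph = nx.Graph()
    graph.add_nodes_from(patterns)
    for p, q in itertools.combinations(patterns, 2):
        # In an arrangement of distinct hyperplanes, two regions share a facet
        # exactly when their patterns differ in a single neuron.
        if sum(u != v for u, v in zip(p, q)) == 1:
            graph.add_edge(p, q)
    return graph


# One hidden layer with pre-activations x, y and x + y - 1: three lines in general position.
graph = sampled_dual_graph([[1, 0], [0, 1], [1, 1]], [0, 0, -1])
print(graph)
```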
43
+
44
+ View the documentation for this library at https://bl-ake.github.io/relucent/
45
+
46
+ ## Source Code Structure
47
+ * [model.py](src/relucent/model.py): PyTorch Module that acts as an interface between the model and the rest of the code
48
+ * [poly.py](src/relucent/poly.py): Class for calculations involving individual polyhedra (e.g. computing boundaries, neighbors, volume)
49
+ * [complex.py](src/relucent/complex.py): Class for calculations involving the polyhedral complex (e.g. polyhedron search, connectivity graph calculation)
50
+ * [convert_model.py](src/relucent/convert_model.py): Utilities for converting various `torch.nn` layers to Linear layers
51
+ * [bvs.py](src/relucent/bvs.py): Data structures for storing large numbers of sign vectors
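The `BVManager` class in bvs.py stores sign sequences under hashable tags and keeps each sequence at a fixed index. A self-contained sketch of the same idea follows. `encode_sign_vector` and `SignVectorStore` are hypothetical names, the tag format (int8 bytes plus shape) is an assumption rather than relucent's `encode_bv`, and `torch.Tensor` inputs are not handled. This sketch also restores a deleted sequence at its old index when it is re-added.

```python
import numpy as np


def encode_sign_vector(sv):
    """Hypothetical hashable tag: the array's shape plus its int8 bytes."""
    arr = np.asarray(sv, dtype=np.int8)
    return arr.shape, arr.tobytes()


class SignVectorStore:
    """Index-stable store of sign vectors with values in {-1, 0, 1}."""

    def __init__(self):
        self.index2sv = []  # position = permanent index; None marks a deleted entry
        self.tag2index = {}
        self._len = 0

    def add(self, sv):
        tag = encode_sign_vector(sv)
        index = self.tag2index.get(tag)
        if index is None:  # unseen vector: append it at the next index
            index = self.tag2index[tag] = len(self.index2sv)
            self.index2sv.append(sv)
            self._len += 1
        elif self.index2sv[index] is None:  # previously deleted: restore in place
            self.index2sv[index] = sv
            self._len += 1
        return index

    def __contains__(self, sv):
        index = self.tag2index.get(encode_sign_vector(sv))
        return index is not None and self.index2sv[index] is not None

    def __delitem__(self, sv):
        index = self.tag2index[encode_sign_vector(sv)]
        if self.index2sv[index] is None:
            raise KeyError("sign vector was already deleted")
        self.index2sv[index] = None  # tombstone keeps the other indices stable
        self._len -= 1

    def __len__(self):
        return self._len
```

Tombstones make deletion O(1) and keep indices valid for anything that refers to a sequence by position, at the cost of never reclaiming the slot.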
52
+
53
+ ## Obtaining a Gurobi License
54
+ Without a [license](https://support.gurobi.com/hc/en-us/articles/12872879801105-How-do-I-retrieve-and-set-up-a-Gurobi-license), Gurobi will only work with a limited feature set. This includes a limit on the number of decision variables in the models it can solve, which limits the size of the networks this code is able to analyze. There are multiple ways to install the software, but we recommend the following steps to those eligible for an academic license:
55
+ 1. Install the [Gurobi Python library](https://pypi.org/project/gurobipy/), for example using `pip install gurobipy`
56
+ 2. [Obtain a Gurobi license](https://support.gurobi.com/hc/en-us/articles/360040541251-How-do-I-obtain-a-free-academic-license) (Note: a WLS license will limit the number of concurrent sessions across multiple devices, which can result in slowdowns when using this library on different machines simultaneously.)
57
+ 3. In your Conda environment, run `grbgetkey` followed by your license key
@@ -0,0 +1,40 @@
1
+ [build-system]
2
+ requires = [
3
+ "uv_build >= 0.9.5, <0.10.0",
4
+ ]
5
+ build-backend = "uv_build"
6
+
7
+ [project]
8
+ name = "relucent"
9
+ description = "Explore polyhedral complexes associated with the activation states of ReLU neural networks."
10
+ version = "0.2.1"
11
+ requires-python = ">= 3.13, <3.14"
12
+ dependencies = [
13
+ "pandas>=2.3",
14
+ "gurobipy>=12.0",
15
+ "networkx>=3.6",
16
+ "matplotlib>=3.10",
17
+ "numpy>=2.4",
18
+ "Pillow>=12.1",
19
+ "plotly>=6.3",
20
+ "scikit_learn>=1.8",
21
+ "scipy>=1.17",
22
+ "tqdm>=4.67",
23
+ ]
24
+ authors = [
25
+ { name = "Blake B. Gaines" },
26
+ ]
27
+ license = "AGPL-3.0-or-later"
28
+ readme = "README.md"
29
+
30
+ [project.optional-dependencies]
31
+ cli = [
32
+ "torch>=2.13",
33
+ "torchvision",
34
+ "pyvis>=0.3",
35
+ "kaleido",
36
+ ]
37
+
38
+ [project.urls]
39
+ Repository = "https://github.com/bl-ake/relucent"
40
+ Documentation = "https://bl-ake.github.io/relucent/"
@@ -0,0 +1,18 @@
1
+ try:
2
+     import torch  # noqa: F401
3
+     import torchvision  # noqa: F401
4
+ except ImportError as err:
5
+     raise ImportError(
6
+         "Relucent requires PyTorch and torchvision to be installed manually. "
7
+         "Please install the versions compatible with your system from: "
8
+         "https://pytorch.org/get-started/previous-versions/#:~:text=org/whl/cpu-,v2.3.0"
9
+     ) from err
10
+
11
+ from .bvs import BVManager
12
+ from .complex import Complex
13
+ from .poly import Polyhedron
14
+ from .model import NN, get_mlp_model
15
+ from .convert_model import convert
16
+ from .utils import get_env, split_sequential, set_seeds
17
+
18
+ __all__ = ["Complex", "Polyhedron", "NN", "get_mlp_model", "BVManager", "convert", "get_env", "split_sequential", "set_seeds"]
@@ -0,0 +1,257 @@
1
+ from heapq import heappop, heappush
2
+
3
+ from torch import Tensor
4
+
5
+ from relucent.poly import encode_bv
6
+
7
+
8
+ class BVManager:
9
+     """Manages storage and lookup of sign sequences.
10
+
11
+     This class provides a dictionary-like interface for storing and retrieving
12
+     sign sequences (arrays with values in {-1, 0, 1}). It maintains an index
13
+     mapping and allows efficient membership testing and retrieval.
14
+
15
+     Sequences are stored under hashable tags; deleting one leaves a None placeholder so that other indices stay stable.
16
+     """
17
+
18
+     def __init__(self):
19
+         self.index2bv = list()
20
+         self.tag2index = dict()  ## Tags are just hashable versions of bvs, should be unique
21
+         self._len = 0
22
+
23
+     def _get_tag(self, bv):
24
+         if isinstance(bv, Tensor):
25
+             bv = bv.detach().cpu().numpy()
26
+         return encode_bv(bv)
27
+
28
+     def add(self, bv):
29
+         """Add a sign sequence (torch.Tensor or np.ndarray) to the manager.
30
+
31
+         Re-adding a previously deleted sequence restores it at its old index."""
32
+         tag = self._get_tag(bv)
33
+         index = self.tag2index.setdefault(tag, len(self.index2bv))
34
+         if index == len(self.index2bv):  # unseen tag: reserve a new slot
35
+             self.index2bv.append(None)
36
+         if self.index2bv[index] is None:  # new or previously deleted
37
+             self.index2bv[index] = bv
38
+             self._len += 1
39
+
40
+     def __getitem__(self, bv):
41
+         tag = self._get_tag(bv)
42
+         index = self.tag2index[tag]
43
+         if self.index2bv[index] is None:
44
+             raise KeyError("sign sequence has been deleted")
45
+         return index
46
+
47
+     def __contains__(self, bv):
48
+         tag = self._get_tag(bv)
49
+         if tag not in self.tag2index:
50
+             return False
51
+         return self.index2bv[self.tag2index[tag]] is not None
52
+
53
+     def __delitem__(self, bv):
54
+         # __getitem__ raises KeyError if bv is unknown or already deleted
55
+         index = self[bv]
56
+         self.index2bv[index] = None
57
+         self._len -= 1
58
+
59
+     def __iter__(self):
60
+         return (bv for bv in self.index2bv if bv is not None)
61
+
62
+     def __len__(self):
63
+         return self._len
64
+
65
+
66
+ # TODO: Move to utils as general priority queue
67
+ class BVPriorityQueue:
68
+     """Priority queue for tasks with sign sequences.
69
+
70
+     A priority queue implementation that supports updating task priorities and
71
+     removing tasks. Tasks are tuples starting with a sign sequence (BV) followed
72
+     by additional data. Based on the priority queue recipe in the Python heapq documentation.
73
+
74
+     Reference: https://docs.python.org/3/library/heapq.html
75
+     """
76
+
77
+     REMOVED = "<removed-task>"  # placeholder for a removed task
78
+
79
+     def __init__(self):
80
+         self.pq = []  # list of entries arranged in a heap
81
+         self.entry_finder = {}  # mapping of tasks to entries
82
+         self.counter = 0  # unique sequence count
83
+
84
+     def push(self, task, priority=0):
85
+         """Add a new task or update the priority of an existing task.
86
+
87
+         Args:
88
+             task: A tuple (bv, *data). The data after the sign sequence identifies
89
+                 the task, must be hashable, and replaces any earlier entry.
90
+             priority: The priority value (lower = higher priority). Defaults to 0.
91
+         """
92
+         bv, *task = task
93
+         task = tuple(task)
94
+         if task in self.entry_finder:
95
+             self.remove_task(task)
96
+         entry = [priority, self.counter, bv, task]
97
+         self.entry_finder[task] = entry
98
+         heappush(self.pq, entry)
99
+         self.counter += 1
100
+
101
+     def remove_task(self, task):
102
+         """Mark an existing task (its data, without the BV) as REMOVED. Raise KeyError if not found."""
103
+         entry = self.entry_finder.pop(task)
104
+         entry[-1] = self.REMOVED
105
+
106
+     def pop(self):
107
+         """Remove and return the task with the lowest priority value.
108
+
109
+         Returns:
110
+             tuple: A tuple starting with the sign sequence (BV) followed by
111
+                 the task data.
112
+
113
+         Raises:
114
+             KeyError: If the queue is empty.
115
+         """
116
+         while self.pq:
117
+             _, _, bv, task = heappop(self.pq)
118
+             if task is not self.REMOVED:
119
+                 del self.entry_finder[task]
120
+                 return bv, *task
121
+         raise KeyError("pop from an empty priority queue")
122
+
123
+     def __len__(self):
124
+         return len(self.entry_finder)
125
+
126
+
127
+ # class BVPriorityQueue:
128
+ # """Simpler, less efficient version of the one above for debugging"""
129
+ # def __init__(self):
130
+ # self.pq = [] # list of entries arranged in a heap
131
+
132
+ # def push(self, task, priority=0):
133
+ # "Add a new task or update the priority of an existing task"
134
+ # # bv, *task = task
135
+ # self.pq.append((priority, task))
136
+ # self.pq.sort(reverse=True, key=lambda x: x[0])
137
+
138
+ # def remove_task(self, task):
139
+ # "Mark an existing task as REMOVED. Raise KeyError if not found."
140
+ # self.pq = [(p, t) for p, t in self.pq if t != task]
141
+
142
+ # def pop(self):
143
+ # "Remove and return the lowest priority task. Raise KeyError if empty."
144
+ # return self.pq.pop(-1)[1]
145
+
146
+ # def __len__(self):
147
+ # return len(self.pq)
148
+
149
+
150
+ # class BVNode:
151
+ # def __init__(self, key):
152
+ # self.key = key ## Key of bv being set
153
+ # self.left = None ## for all nodes in subtree, bv[0, key] = -1
154
+ # self.middle = None ## ...bv[0, key] = 0
155
+ # self.right = None ## ...bv[0, key] = 1
156
+
157
+ # def get_child(self, bv):
158
+ # # return either BVNode or int
159
+ # if bv[0, self.key] == -1:
160
+ # next_node = self.left
161
+ # elif bv[0, self.key] == 0:
162
+ # next_node = self.middle
163
+ # elif bv[0, self.key] == 1:
164
+ # next_node = self.right
165
+ # return next_node
166
+
167
+ # def set_child(self, bv, node):
168
+ # if bv[0, self.key] == -1:
169
+ # self.left = node
170
+ # elif bv[0, self.key] == 0:
171
+ # self.middle = node
172
+ # elif bv[0, self.key] == 1:
173
+ # self.right = node
174
+
175
+ # def print(self, level=0):
176
+ # print(" " * level + str(self.key) + ":")
177
+ # for name, k in zip(("L", "M", "R"), (self.left, self.middle, self.right)):
178
+ # if isinstance(k, BVNode):
179
+ # print(" " * (level + 2) + name)
180
+ # k.print(level=level + 4)
181
+ # elif isinstance(k, int):
182
+ # print(" " * (level + 2) + name + ": leaf " + str(k))
183
+
184
+ # # Trie
185
+ # # Each edge in the tree sets a dimension to a value
186
+ # # Leaf nodes are just indices of bvs in index2bv
187
+ # class BVManager:
188
+ # def __init__(self):
189
+ # self.root = BVNode(0)
190
+ # self.index2bv = list()
191
+
192
+ # def add(self, bv):
193
+ # assert bv.ndim == 2
194
+ # node = self.root
195
+ # child = self.root.get_child(bv)
196
+ # while isinstance(child, BVNode):
197
+ # node = child
198
+ # child = node.get_child(bv)
199
+ # if child is None:
200
+ # node.set_child(bv, len(self.index2bv))
201
+ # self.index2bv.append(bv)
202
+ # elif isinstance(child, int): ## TODO: This check should be redundant
203
+ # child_bv = self.index2bv[child]
204
+ # if not isinstance(child_bv, type(bv)):
205
+ # if isinstance(bv, Tensor):
206
+ # bv = bv.detach().cpu().numpy()
207
+ # else:
208
+ # child_bv = child_bv.detach().cpu().numpy()
209
+ # if not (child_bv == bv).all():
210
+ # if child_bv[0, node.key] == bv[0, node.key]: ## TODO: This check should be redundant
211
+ # for i in range(bv.shape[1]):
212
+ # if child_bv[0, i] != bv[0, i]:
213
+ # break
214
+
215
+ # ## Replace the existing child of the node with a new node
216
+ # new_bvnode = BVNode(i)
217
+ # node.set_child(bv, new_bvnode)
218
+
219
+ # ## Set the new node's children
220
+ # new_bvnode.set_child(child_bv, child)
221
+ # new_bvnode.set_child(bv, len(self.index2bv))
222
+ # self.index2bv.append(bv)
223
+ # else:
224
+ # raise ValueError("Something went wrong")
225
+
226
+ # def __getitem__(self, bv):
227
+ # assert bv.ndim == 2
228
+ # node = self.root
229
+ # while isinstance(node, BVNode):
230
+ # node = node.get_child(bv)
231
+ # if isinstance(node, int):
232
+ # found_bv = self.index2bv[node]
233
+ # if not isinstance(bv, type(found_bv)):
234
+ # if isinstance(bv, Tensor):
235
+ # bv = bv.detach().cpu().numpy()
236
+ # else:
237
+ # found_bv = found_bv.detach().cpu().numpy()
238
+ # if (found_bv == bv).all():
239
+ # return node
240
+ # raise KeyError
241
+
242
+ # def __contains__(self, bv):
243
+ # try:
244
+ # self[bv]
245
+ # return True
246
+ # except KeyError:
247
+ # return False
248
+
249
+ # def __iter__(self):
250
+ # return iter(self.index2bv)
251
+
252
+ # def __len__(self):
253
+ # return len(self.index2bv)
254
+
255
+ # def print(self):
256
+ # print("Number of BVs:", len(self))
257
+ # self.root.print()
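`BVPriorityQueue` in bvs.py follows the lazy-deletion recipe from the heapq documentation: updating a task pushes a fresh heap entry, and the stale entry is only marked and then skipped on pop. A standalone sketch of the same recipe, with hypothetical names (`LazyPriorityQueue`, `payload`) and an `object()` sentinel in place of the `"<removed-task>"` string so that no real task can ever collide with it:

```python
import heapq
import itertools


class LazyPriorityQueue:
    """Min-heap with priority updates via lazy deletion (heapq recipe)."""

    _REMOVED = object()  # sentinel that cannot equal any real task

    def __init__(self):
        self._heap = []
        self._entries = {}  # task -> its live heap entry
        self._counter = itertools.count()  # tie-breaker so payloads are never compared

    def push(self, task, payload=None, priority=0):
        """Insert task, or replace its payload and priority if it is already queued."""
        if task in self._entries:
            self._entries.pop(task)[-1] = self._REMOVED  # invalidate the stale entry
        entry = [priority, next(self._counter), payload, task]
        self._entries[task] = entry
        heapq.heappush(self._heap, entry)

    def pop(self):
        """Return (task, payload, priority) for the smallest priority value."""
        while self._heap:
            priority, _, payload, task = heapq.heappop(self._heap)
            if task is not self._REMOVED:
                del self._entries[task]
                return task, payload, priority
        raise KeyError("pop from an empty priority queue")

    def __len__(self):
        return len(self._entries)


q = LazyPriorityQueue()
q.push("a", payload="sv_a", priority=3)
q.push("b", payload="sv_b", priority=1)
q.push("a", payload="sv_a2", priority=0)  # update: "a" now comes out first
print(q.pop())  # -> ('a', 'sv_a2', 0)
```

The unique counter in each entry means ties in priority are broken by insertion order, so payloads such as sign sequences (which may not support ordering) are never compared.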