gridvoting-jax 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridvoting_jax/__init__.py +672 -0
- gridvoting_jax/benchmarks/__init__.py +6 -0
- gridvoting_jax/benchmarks/osf_comparison.py +440 -0
- gridvoting_jax/benchmarks/performance.py +124 -0
- gridvoting_jax/py.typed +0 -0
- gridvoting_jax-0.2.0.dist-info/METADATA +473 -0
- gridvoting_jax-0.2.0.dist-info/RECORD +10 -0
- gridvoting_jax-0.2.0.dist-info/WHEEL +5 -0
- gridvoting_jax-0.2.0.dist-info/licenses/LICENSE.md +8 -0
- gridvoting_jax-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,672 @@
|
|
|
1
|
+
# Package version string for gridvoting_jax.
__version__ = "0.2.0"
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import numpy as np
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import matplotlib.cm as cm
|
|
7
|
+
from warnings import warn
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
import chex
|
|
11
|
+
|
|
12
|
+
# Default absolute tolerance for floating-point checks in this module: the
# stationary-distribution L1 check in MarkovChainCPUGPU and the "equal to the
# maximum" test in _move_neg_prob_to_max.
# 5e-5 is sized for JAX's default float32. NOTE: enable_float64() only switches
# JAX to 64-bit floats; it does NOT change this value. For a tighter bound under
# float64 (e.g. 1e-10), pass `tolerance=` to MarkovChainCPUGPU /
# find_unique_stationary_distribution explicitly. _move_neg_prob_to_max is
# jax.jit-compiled and reads this value when it is traced, so reassigning it
# later may not reach already-compiled versions.
TOLERANCE = 5e-5
|
|
15
|
+
|
|
16
|
+
def enable_float64():
    """Switch JAX to 64-bit floating point (a global, process-wide setting).

    JAX defaults to 32-bit floats because they are faster on accelerators.
    Calling this turns on the ``jax_enable_x64`` configuration flag, so arrays
    created afterwards default to 64-bit precision.

    Only the JAX configuration is touched; module-level values such as
    ``TOLERANCE`` keep whatever value they currently have.

    See: https://docs.jax.dev/en/latest/default_dtypes.html

    Example:
        >>> import gridvoting_jax as gv
        >>> gv.enable_float64()
        >>> # All subsequent JAX operations will use float64
    """
    x64_flag = "jax_enable_x64"
    jax.config.update(x64_flag, True)
|
|
31
|
+
|
|
32
|
+
# Device detection. Setting the environment variable GV_FORCE_CPU=1 skips
# probing entirely; otherwise the platform of JAX's first (default) device
# decides whether an accelerator is in use. Results are exposed as the
# module-level names `device_type` and `use_accelerator`.
use_accelerator = False
device_type = 'cpu'

if os.environ.get('GV_FORCE_CPU', '0') == '1':
    warn("GV_FORCE_CPU=1: JAX forced to CPU-only mode")
else:
    # jax.devices() lists the default backend's devices (TPU > GPU > CPU).
    devices = jax.devices()
    if devices:
        default_device = devices[0]
        device_type = default_device.platform
        use_accelerator = device_type in ('gpu', 'tpu')
        if use_accelerator:
            warn(f"JAX using {device_type.upper()}: {default_device}")
        else:
            warn("JAX using CPU (no GPU/TPU detected)")
|
|
49
|
+
|
|
50
|
+
@jax.jit
def dist_sqeuclidean(XA, XB):
    """JAX-based squared Euclidean pairwise distance calculation.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which needs a
    single matrix product. Floating-point cancellation in that expansion can
    leave tiny negative values (for example where a == b), so the result is
    clipped at zero: a squared distance is never negative.

    Args:
        XA: array of shape (m, n)
        XB: array of shape (p, n)

    Returns:
        Distance matrix of shape (m, p); entry [i, j] is ||XA[i] - XB[j]||^2.
    """
    XA = jnp.asarray(XA)
    XB = jnp.asarray(XB)
    XA_sq = jnp.sum(XA**2, axis=1, keepdims=True)  # (m, 1)
    XB_sq = jnp.sum(XB**2, axis=1, keepdims=True)  # (p, 1)
    sq_dist = XA_sq + XB_sq.T - 2 * jnp.dot(XA, XB.T)
    # Clip round-off negatives. The integer literal 0 keeps integer inputs
    # integer-typed (a float 0.0 would promote them).
    return jnp.maximum(sq_dist, 0)
|
|
67
|
+
|
|
68
|
+
@jax.jit
def dist_manhattan(XA, XB):
    """Pairwise Manhattan (L1, city-block) distances computed with JAX.

    Args:
        XA: array of shape (m, n)
        XB: array of shape (p, n)

    Returns:
        Array of shape (m, p) whose [i, j] entry is sum_k |XA[i, k] - XB[j, k]|.
    """
    left = jnp.asarray(XA)[:, jnp.newaxis, :]   # (m, 1, n)
    right = jnp.asarray(XB)[jnp.newaxis, :, :]  # (1, p, n)
    return jnp.abs(left - right).sum(axis=2)
|
|
83
|
+
|
|
84
|
+
@jax.jit
|
|
85
|
+
def _is_in_triangle_single(p, a, b, c):
|
|
86
|
+
"""
|
|
87
|
+
Returns True if point p is in triangle (a, b, c).
|
|
88
|
+
Robust for arbitrary vertex winding (CW or CCW).
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
p: Point as [x, y]
|
|
92
|
+
a, b, c: Triangle vertices as [x, y]
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Boolean indicating if p is inside triangle
|
|
96
|
+
|
|
97
|
+
See also: computational geometry, half-plane test;
|
|
98
|
+
Stack Overflow answer to https://stackoverflow.com/questions/2049582/how-to-determine-if-a-point-is-in-a-2d-triangle
|
|
99
|
+
https://stackoverflow.com/a/2049593/103081
|
|
100
|
+
by https://stackoverflow.com/users/233522/kornel-kisielewicz
|
|
101
|
+
"""
|
|
102
|
+
def cross(o, a, b):
|
|
103
|
+
return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])
|
|
104
|
+
|
|
105
|
+
s1 = cross(p, a, b)
|
|
106
|
+
s2 = cross(p, b, c)
|
|
107
|
+
s3 = cross(p, c, a)
|
|
108
|
+
|
|
109
|
+
# Small epsilon for numerical tolerance on edges/vertices
|
|
110
|
+
eps = 1e-10
|
|
111
|
+
has_neg = (s1 < -eps) | (s2 < -eps) | (s3 < -eps)
|
|
112
|
+
has_pos = (s1 > eps) | (s2 > eps) | (s3 > eps)
|
|
113
|
+
|
|
114
|
+
return ~(has_neg & has_pos)
|
|
115
|
+
|
|
116
|
+
@jax.jit
def _move_neg_prob_to_max(pvector):
    """Clip negative probabilities to zero and charge the deficit to the maxima.

    Zeroing the negative entries would raise the vector's total by their
    magnitude. To keep the total unchanged, that (negative) amount is split
    evenly across every entry lying within TOLERANCE of the clipped vector's
    maximum.

    Args:
        pvector: JAX array that may contain small negative values

    Returns:
        JAX array in which the originally negative entries are 0 and the
        maximal entries are each lowered by an equal share of the removed
        negative mass, so the sum matches pvector's.
    """
    negative = pvector < 0.0
    # jnp.where rather than boolean indexing keeps shapes static under jit.
    deficit = jnp.sum(jnp.where(negative, pvector, 0.0))  # <= 0
    clipped = jnp.where(negative, 0.0, pvector)

    peak = jnp.max(clipped)
    at_peak = jnp.abs(clipped - peak) < TOLERANCE
    share = deficit / jnp.sum(at_peak)

    return jnp.where(at_peak, clipped + share, clipped)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class Grid:
    """Regular rectangular 2D grid of points.

    Points are flattened row by row: the first row has y == y1 and each
    following row is one ystep lower; within a row x increases from x0 in
    steps of xstep. ``self.x``, ``self.y`` and ``self.points`` all use this
    order, and ``self.gshape`` is the (rows, cols) shape to reshape into.
    """

    def __init__(self, *, x0, x1, xstep=1, y0, y1, ystep=1):
        """initializes 2D grid with x0<=x<=x1 and y0<=y<=y1;
        Creates a 1D JAX array of grid coordinates in self.x and self.y

        Coordinates are generated as x0 + k*xstep and y1 - k*ystep, with the
        counts taken from ``self.shape()``, so the number of points always
        equals ``self.len``. (Building them with ``jnp.arange`` and an
        extended stop can overshoot x1 / undershoot y0 when a step does not
        evenly divide its range, or gain an extra point from float round-off.)

        Raises:
            ValueError: from ``self.shape()`` if x1 < x0, y1 < y0, or a step
                is not positive.
        """
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1
        self.xstep = xstep
        self.ystep = ystep
        self.gshape = self.shape()
        number_of_rows, number_of_cols = self.gshape
        xvals = x0 + xstep * jnp.arange(number_of_cols)
        yvals = y1 - ystep * jnp.arange(number_of_rows)
        xgrid, ygrid = jnp.meshgrid(xvals, yvals)
        self.x = jnp.ravel(xgrid)
        self.y = jnp.ravel(ygrid)
        self.points = jnp.column_stack((self.x, self.y))
        # extent should match extent=(x0,x1,y0,y1) for compatibility with matplotlib.pyplot.contour
        # see https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contour.html
        self.extent = (self.x0, self.x1, self.y0, self.y1)
        # Points in the outermost generated columns/rows. These coincide with
        # x == x0 / x1 and y == y0 / y1 whenever the steps divide the ranges.
        self.boundary = (
            (self.x == xvals[0])
            | (self.x == xvals[-1])
            | (self.y == yvals[-1])
            | (self.y == yvals[0])
        )
        self.len = number_of_rows * number_of_cols

    def shape(self, *, x0=None, x1=None, xstep=None, y0=None, y1=None, ystep=None):
        """returns a tuple(number_of_rows,number_of_cols) for the natural shape of the current grid, or a subset

        Each count is 1 + floor((hi - lo) / step). A slack of 1e-9 steps is
        added before flooring so float ranges such as 0.3 / 0.1 (which
        evaluates to 2.9999999999999996) still include their endpoint.

        Raises:
            ValueError: if x1 < x0, y1 < y0, or a step is not positive.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        xstep = self.xstep if xstep is None else xstep
        ystep = self.ystep if ystep is None else ystep
        if x1 < x0:
            raise ValueError(f"x1 ({x1}) must be >= x0 ({x0})")
        if y1 < y0:
            raise ValueError(f"y1 ({y1}) must be >= y0 ({y0})")
        if xstep <= 0:
            raise ValueError(f"xstep must be positive, got {xstep}")
        if ystep <= 0:
            raise ValueError(f"ystep must be positive, got {ystep}")
        # The quotients are >= 0 here, so int() truncation is a floor.
        number_of_rows = 1 + int((y1 - y0) / ystep + 1e-9)
        number_of_cols = 1 + int((x1 - x0) / xstep + 1e-9)
        return (number_of_rows, number_of_cols)

    def within_box(self, *, x0=None, x1=None, y0=None, y1=None):
        """returns a 1D JAX boolean array, suitable as an index mask, for testing whether a grid point is also in the defined box

        Bounds are inclusive; omitted bounds default to the grid's own.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        return (self.x >= x0) & (self.x <= x1) & (self.y >= y0) & (self.y <= y1)

    def within_disk(self, *, x0, y0, r, metric="euclidean", **kwargs):
        """returns 1D JAX boolean array, suitable as an index mask, for testing whether a grid point is also in the defined disk

        The disk is centered at (x0, y0) with inclusive radius r under the
        chosen metric ("euclidean" or "manhattan"). Extra keyword arguments
        are accepted and ignored.
        """
        center = jnp.array([[x0, y0]])

        if metric == "euclidean":
            # For Euclidean distance, use squared Euclidean and compare r^2
            distances_sq = dist_sqeuclidean(center, self.points)
            mask = (distances_sq <= r**2).flatten()
        elif metric == "manhattan":
            distances = dist_manhattan(center, self.points)
            mask = (distances <= r).flatten()
        else:
            raise ValueError(f"Unsupported metric: {metric}. Use 'euclidean' or 'manhattan'.")

        return mask

    def within_triangle(self, *, points):
        """returns 1D JAX boolean array, suitable as an index mask, for testing whether a grid point is also in the defined triangle

        ``points`` holds the three vertices as rows [x, y]; points on an edge
        or vertex count as inside.
        """
        points = jnp.asarray(points)
        a, b, c = points[0], points[1], points[2]

        # Vectorized cross-product triangle containment test
        # Use vmap to apply the single-point test to all grid points
        mask = jax.vmap(
            lambda p: _is_in_triangle_single(p, a, b, c)
        )(self.points)

        return mask

    def index(self, *, x, y):
        """returns the unique 1D array index for grid point (x,y)

        Raises:
            IndexError: if (x, y) is not a point of this grid, so callers get
                a clear error instead of depending on how JAX handles
                indexing into an empty match array.
        """
        matches = jnp.flatnonzero((self.x == x) & (self.y == y))
        if matches.size == 0:
            raise IndexError(f"({x}, {y}) is not a point on this grid")
        return int(matches[0])

    def embedding(self, *, valid):
        """
        returns an embedding function efunc(z,fill=0.0) from 1D arrays z of size sum(valid)
        to arrays of size self.len

        valid is a jnp.array of type boolean, of size self.len

        fill is the value for indices outside the embedding. The default
        is zero (0.0). Setting fill=jnp.nan can be useful for
        plotting purposes as matplotlib will omit jnp.nan values from various
        kinds of plots.

        efunc raises ValueError if z is not 1D with exactly sum(valid)
        entries.
        """
        valid = jnp.asarray(valid)
        correct_z_len = int(valid.sum())

        def efunc(z, fill=0.0):
            z = jnp.asarray(z)
            # Catch length mismatches up front rather than relying on
            # broadcasting inside the scatter below.
            if z.shape != (correct_z_len,):
                raise ValueError(
                    f"z must have shape ({correct_z_len},) to match valid, got {z.shape}"
                )
            v = jnp.full(self.len, fill)
            return v.at[valid].set(z)

        return efunc

    def extremes(self, z, *, valid=None):
        """Return (min_z, points_at_min, max_z, points_at_max) for values z.

        z holds one value per grid point selected by ``valid`` (all grid
        points when valid is None). Every point whose value is within 1e-10
        of the extreme is returned, so ties are all reported.
        """
        valid = jnp.full((self.len,), True) if valid is None else valid
        candidates = self.points[valid]
        min_z = z.min()
        max_z = z.max()
        min_z_mask = jnp.abs(z - min_z) < 1e-10  # Strict tolerance for exact min/max
        max_z_mask = jnp.abs(z - max_z) < 1e-10  # Strict tolerance for exact min/max
        return (min_z, candidates[min_z_mask], max_z, candidates[max_z_mask])

    def spatial_utilities(
        self, *, voter_ideal_points, metric="sqeuclidean", scale=-1, **kwargs
    ):
        """returns utility function values for each voter at each grid point

        Result has shape (number of voters, self.len) and equals
        scale * distance(voter ideal point, grid point) for the chosen
        metric ("sqeuclidean" or "manhattan"). Extra keyword arguments are
        accepted and ignored.
        """
        voter_ideal_points = jnp.asarray(voter_ideal_points)

        if metric == "sqeuclidean":
            distances = dist_sqeuclidean(voter_ideal_points, self.points)
        elif metric == "manhattan":
            distances = dist_manhattan(voter_ideal_points, self.points)
        else:
            raise ValueError(f"Unsupported metric: {metric}. Use 'sqeuclidean' or 'manhattan'.")

        return scale * distances

    def plot(
        self,
        z,
        *,
        title=None,
        cmap=cm.gray_r,
        alpha=0.6,
        alpha_points=0.3,
        log=True,
        points=None,
        zoom=False,
        border=1,
        logbias=1e-100,
        figsize=(10, 10),
        dpi=72,
        fname=None
    ):
        """plots values z defined on the grid;
        optionally plots additional 2D points
        and zooms to fit the bounding box of the points

        Args:
            z: one value per grid point (length self.len).
            log: plot log10(logbias + z) instead of z.
            points: optional (k, 2) array-like of points to scatter.
            zoom: restrict the plot to the grid points inside the bounding box
                of ``points`` widened by ``border``; requires ``points``.
            fname: save the figure to this path instead of calling plt.show().

        Raises:
            ValueError: if zoom is requested without points, or the zoom box
                contains no grid points.
        """
        # Convert JAX arrays to NumPy for matplotlib compatibility
        z = np.array(z)
        grid_x = np.array(self.x)
        grid_y = np.array(self.y)
        if points is not None:
            points = np.asarray(points)  # lists allowed; enables points[:, 0]
        if zoom and points is None:
            raise ValueError("zoom=True requires points to define the zoom box")

        if zoom:
            [min_x, min_y] = np.min(points, axis=0) - border
            [max_x, max_y] = np.max(points, axis=0) + border
            inZoom = np.array(self.within_box(x0=min_x, x1=max_x, y0=min_y, y1=max_y))
            if not inZoom.any():
                raise ValueError("zoom box contains no grid points")
            x_sel = grid_x[inZoom]
            y_sel = grid_y[inZoom]
            # The selected points form a rectangular sub-grid; count its
            # distinct coordinates rather than recomputing a shape from the
            # box, which fails when box edges fall between grid lines or
            # extend past the grid.
            zshape = (np.unique(y_sel).size, np.unique(x_sel).size)
            extent = (x_sel.min(), x_sel.max(), y_sel.min(), y_sel.max())
            zraw = z[inZoom].reshape(zshape)
            x = x_sel.reshape(zshape)
            y = y_sel.reshape(zshape)
        else:
            zshape = self.gshape
            extent = self.extent
            zraw = z.reshape(zshape)
            x = grid_x.reshape(zshape)
            y = grid_y.reshape(zshape)
        zplot = np.log10(logbias + zraw) if log else zraw

        plt.figure(figsize=figsize, dpi=dpi)
        plt.rcParams["font.size"] = "24"
        fmt = "%1.2f" if log else "%.2e"
        contours = plt.contour(x, y, zplot, extent=extent, cmap=cmap)
        plt.clabel(contours, inline=True, fontsize=12, fmt=fmt)
        plt.imshow(zplot, extent=extent, cmap=cmap, alpha=alpha)
        if points is not None:
            plt.scatter(points[:, 0], points[:, 1], alpha=alpha_points, color="black")
        if title is not None:
            plt.title(title)
        if fname is None:
            plt.show()
        else:
            plt.savefig(fname)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def assert_valid_transition_matrix(P, *, decimal=6):
    """Assert that P is a square 2D array whose rows each sum to 1.0.

    Args:
        P: array-like transition matrix (converted with jnp.asarray).
        decimal: each row sum must be within an absolute tolerance of
            1.1 * 10**(-decimal) of 1.0. The default of 6 does not depend
            on the active float precision; pass a larger value (e.g. 10)
            for a stricter check when float64 is enabled.

    Raises:
        AssertionError: if P is not a square 2D array, or any row sum
            (including a NaN one) is farther than the tolerance from 1.0.
            Explicit raises are used instead of ``assert`` statements so the
            check still runs under ``python -O``.
    """
    P = jnp.asarray(P)
    if P.ndim != 2 or P.shape[0] != P.shape[1]:
        raise AssertionError(f"Matrix must be square, got shape {P.shape}")

    tolerance = 10 ** (-decimal) * 1.1  # Slightly increased for numerical stability
    deviation = jnp.abs(P.sum(axis=1) - 1.0)
    # Phrased as "not within tolerance" so NaN row sums count as failures.
    bad_rows = ~(deviation <= tolerance)
    if bool(jnp.any(bad_rows)):
        worst = int(jnp.argmax(jnp.where(jnp.isnan(deviation), jnp.inf, deviation)))
        raise AssertionError(
            f"{int(bad_rows.sum())} of {P.shape[0]} row sums differ from 1.0 by more "
            f"than {tolerance:g}; worst is row {worst} with sum {float(P[worst].sum())}"
        )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def assert_zero_diagonal_int_matrix(M):
    """Assert that M is a square 2D array whose diagonal entries are all 0.

    Only the shape and the diagonal values are checked; the dtype is not.

    Raises:
        AssertionError: if M is not a square 2D array or has a nonzero (or
            NaN) diagonal entry. Explicit raises are used instead of
            ``assert`` statements so the check still runs under
            ``python -O``.
    """
    M = jnp.asarray(M)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise AssertionError(f"Matrix must be square, got shape {M.shape}")

    diagonal = jnp.diag(M)
    offending = jnp.flatnonzero(diagonal != 0)  # NaN != 0 is True, so NaN fails too
    if offending.size:
        first = int(offending[0])
        raise AssertionError(
            f"{offending.size} nonzero diagonal entries; first at index {first} "
            f"with value {diagonal[first].item()}"
        )
|
|
367
|
+
|
|
368
|
+
class MarkovChainCPUGPU:
    """Finite Markov chain given by a row-stochastic transition matrix P.

    Attributes set in ``__init__``:
        P: the transition matrix as a JAX array; rows sum to 1, and the
            stationary distribution is a row vector x with x P = x.
        absorbing_points: boolean mask of states with P[i, i] exactly 1.0.
        unreachable_points: boolean mask of states whose column sum equals
            their diagonal entry, i.e. no other state moves into them.
        has_unique_stationary_distribution: True when no absorbing state
            exists (this is the only condition the flag checks).
    ``stationary_distribution`` and ``check_norm`` are set by
    find_unique_stationary_distribution.
    """

    def __init__(self, *, P, computeNow=True, tolerance=None):
        """initializes a MarkovChainCPUGPU instance by copying in the transition
        matrix P and calculating chain properties

        Args:
            P: square array-like whose rows sum to 1.
            computeNow: when True and no absorbing state exists, solve for
                the stationary distribution immediately.
            tolerance: L1 tolerance for the stationary-distribution check;
                defaults to the module-level TOLERANCE.

        Raises:
            AssertionError: from assert_valid_transition_matrix if the rows
                of P do not sum to 1.
            RuntimeError: if computeNow and no acceptable stationary
                distribution is found.
        """
        if tolerance is None:
            tolerance = TOLERANCE
        self.P = jnp.asarray(P)  # copy transition matrix to JAX array
        # Validate the array that is actually stored and used below.
        assert_valid_transition_matrix(self.P)
        diagP = jnp.diagonal(self.P)
        self.absorbing_points = jnp.equal(diagP, 1.0)
        # Column sum equal to the diagonal means no other state feeds in.
        self.unreachable_points = jnp.equal(jnp.sum(self.P, axis=0), diagP)
        self.has_unique_stationary_distribution = not bool(jnp.any(self.absorbing_points))
        if computeNow and self.has_unique_stationary_distribution:
            self.find_unique_stationary_distribution(tolerance=tolerance)

    def L1_norm_of_single_step_change(self, x):
        """returns float(L1(xP-x))"""
        return float(jnp.linalg.norm(jnp.dot(x, self.P) - x, ord=1))

    def solve_for_unit_eigenvector(self):
        """This is another way to potentially find the stationary distribution,
        but can suffer from numerical irregularities like negative entries.
        Assumes eigenvalue of 1.0 exists and solves for the eigenvector by
        considering a related matrix equation Q v = b, where:
        Q is P transpose minus the identity matrix I, with the first row
        replaced by all ones for the vector scaling requirement;
        v is the eigenvector of eigenvalue 1 to be found; and
        b is the first basis vector, where b[0]=1 and 0 elsewhere.

        Raises:
            RuntimeError: if the solve fails, yields a non-finite component,
                or leaves negative components that cannot be repaired.
        """
        n = self.P.shape[0]
        Q = jnp.transpose(self.P) - jnp.eye(n)
        Q = Q.at[0].set(jnp.ones(n))  # JAX immutable update
        b = jnp.zeros(n)
        b = b.at[0].set(1.0)  # JAX immutable update

        error_unable_msg = "unable to find unique unit eigenvector "
        try:
            unit_eigenvector = jnp.linalg.solve(Q, b)
        except Exception as err:
            warn(str(err))  # print the original exception lest it be lost for debugging purposes
            raise RuntimeError(error_unable_msg + "(solver)") from err

        # A singular Q typically surfaces as NaN/inf entries rather than an
        # exception, so reject any non-finite component (a NaN test on the
        # sum alone would let an all-+inf result through).
        if not bool(jnp.all(jnp.isfinite(unit_eigenvector))):
            raise RuntimeError(error_unable_msg + "(nan)")

        min_component = float(unit_eigenvector.min())
        # Tiny negatives are treated as round-off: clip them, then apply one
        # step x <- xP. Threshold increased for NumPy 2.0 compatibility (was -1e-7).
        if (min_component < 0.0) and (min_component > -2e-7):
            unit_eigenvector = _move_neg_prob_to_max(unit_eigenvector)
            unit_eigenvector = jnp.dot(unit_eigenvector, self.P)
            min_component = float(unit_eigenvector.min())

        if min_component < 0.0:
            neg_msg = "(negative components: " + str(min_component) + " )"
            warn(neg_msg)
            raise RuntimeError(error_unable_msg + neg_msg)

        self.unit_eigenvector = unit_eigenvector
        return self.unit_eigenvector

    def find_unique_stationary_distribution(self, *, tolerance=None, **kwargs):
        """finds the stationary distribution for a Markov Chain using algebraic method

        Returns None (and stores None) when any absorbing state exists.
        Otherwise stores and returns the solution after checking that
        L1(xP - x) <= tolerance (default: module-level TOLERANCE). Extra
        keyword arguments are accepted and ignored.

        Raises:
            RuntimeError: if no solution is found or the check norm exceeds
                the tolerance.
        """
        if tolerance is None:
            tolerance = TOLERANCE
        if jnp.any(self.absorbing_points):
            self.stationary_distribution = None
            return None
        self.stationary_distribution = self.solve_for_unit_eigenvector()
        self.check_norm = self.L1_norm_of_single_step_change(self.stationary_distribution)
        if self.check_norm > tolerance:
            raise RuntimeError(f"Stationary distribution check norm {self.check_norm} exceeds tolerance {tolerance}")
        return self.stationary_distribution

    def diagnostic_metrics(self):
        """ return Markov chain approximation metrics in mathematician-friendly format """
        metrics = {
            '||F||': self.P.shape[0],
            # float() turns the 0-d JAX result into a plain Python number
            '(𝝨𝝿)-1': float(self.stationary_distribution.sum())-1.0,
            '||𝝿P-𝝿||_L1_norm': self.L1_norm_of_single_step_change(
                self.stationary_distribution
            )
        }
        return metrics
|
|
451
|
+
|
|
452
|
+
class VotingModel:
    """Majority-rule voting over a finite set of alternatives, analyzed as a
    Markov chain on the status quo.

    A challenger replaces the status quo when at least ``majority`` voters
    strictly prefer it. With ``zi=True`` the challenger is drawn uniformly
    from all alternatives; otherwise it is drawn uniformly from the winning
    challengers plus the status quo itself.
    """

    def __init__(
        self,
        *,
        utility_functions,
        number_of_voters,
        number_of_feasible_alternatives,
        majority,
        zi
    ):
        """initializes a VotingModel with utility_functions for each voter,
        the number_of_voters,
        the number_of_feasible_alternatives,
        the majority size, and whether to use zi fully random agenda or
        intelligent challengers random over winning set+status quo

        utility_functions must have shape
        (number_of_voters, number_of_feasible_alternatives).
        """
        assert utility_functions.shape == (
            number_of_voters,
            number_of_feasible_alternatives,
        )
        self.utility_functions = utility_functions
        self.number_of_voters = number_of_voters
        self.number_of_feasible_alternatives = number_of_feasible_alternatives
        self.majority = majority
        self.zi = zi
        self.analyzed = False

    def E_𝝿(self,z):
        """returns mean, i.e., expected value of z under the stationary distribution"""
        return jnp.dot(self.stationary_distribution,z)

    def analyze(self):
        """Build the transition matrix, find core (absorbing) points, and,
        when no core exists, store the chain's stationary distribution."""
        self.MarkovChain = MarkovChainCPUGPU(P=self._get_transition_matrix())
        self.core_points = self.MarkovChain.absorbing_points
        self.core_exists = jnp.any(self.core_points)
        if not self.core_exists:
            self.stationary_distribution = self.MarkovChain.stationary_distribution
        self.analyzed = True

    def what_beats(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where alternative beats current index by some majority"""
        assert self.analyzed
        # Row `index` holds the moves away from status quo `index`.
        points = (self.MarkovChain.P[index, :] > 0).astype("int32")
        points = points.at[index].set(0)
        return points

    def what_is_beaten_by(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where current index beats alternative by some majority"""
        assert self.analyzed
        # Column `index` holds the moves into alternative `index`.
        points = (self.MarkovChain.P[:, index] > 0).astype("int32")
        points = points.at[index].set(0)
        return points

    def summarize_in_context(self,*,grid,valid=None):
        """calculate summary statistics for stationary distribution using grid's coordinates and optional subset valid

        Returns a dict. When a core exists it holds 'core_exists' and
        'core_points'; otherwise it holds the stationary-distribution mean
        point, covariance, min/max probabilities with their points, and the
        entropy in bits.
        """
        # missing valid defaults to all True array for grid
        valid = jnp.full((grid.len,), True) if valid is None else jnp.asarray(valid)
        # check valid array shape
        assert valid.shape == (grid.len,)
        valid_points = grid.points[valid]
        if self.core_exists:
            return {
                'core_exists': self.core_exists,
                'core_points': valid_points[self.core_points]
            }
        # core does not exist, so evaluate mean, cov, min, max of stationary distribution
        # first check that the number of valid points matches the dimensionality of the stationary distribution
        assert (valid.sum(),) == self.stationary_distribution.shape
        point_mean = self.E_𝝿(valid_points)
        cov = jnp.cov(valid_points, rowvar=False, ddof=0, aweights=self.stationary_distribution)
        (prob_min,prob_min_points,prob_max,prob_max_points) = \
            grid.extremes(self.stationary_distribution,valid=valid)
        # 0 * log2(0) is taken as 0, so only positive probabilities contribute.
        _nonzero_statd = self.stationary_distribution[self.stationary_distribution>0]
        entropy_bits = -_nonzero_statd.dot(jnp.log2(_nonzero_statd))
        return {
            'core_exists': self.core_exists,
            'point_mean': point_mean,
            'point_cov': cov,
            'prob_min': prob_min,
            'prob_min_points': prob_min_points,
            'prob_max': prob_max,
            'prob_max_points': prob_max_points,
            'entropy_bits': entropy_bits
        }

    def plots(
        self,
        *,
        grid,
        voter_ideal_points,
        diagnostics=False,
        log=True,
        embedding=lambda z, fill: z,
        zoomborder=0,
        dpi=72,
        figsize=(10, 10),
        fprefix=None,
        title_core="Core (aborbing) points",
        title_sad="L1 norm of difference in two rows of P^power",
        title_diff1="L1 norm of change in corner row",
        title_diff2="L1 norm of change in center row",
        title_sum1minus1="Corner row sum minus 1.0",
        title_sum2minus1="Center row sum minus 1.0",
        title_unreachable_points="Dominated (unreachable) points",
        title_stationary_distribution_no_grid="Stationary Distribution",
        title_stationary_distribution="Stationary Distribution",
        title_stationary_distribution_zoom="Stationary Distribution (zoom)"
    ):
        """Plot core points, or the stationary distribution (full view plus a
        zoom around the voter ideal points), saving files when fprefix is set.

        ``embedding`` maps per-alternative arrays onto the grid (see
        Grid.embedding). ``diagnostics`` and the title_sad / title_diff* /
        title_sum* / title_unreachable_points arguments are accepted but not
        used by this method.
        """
        def _fn(name):
            return None if fprefix is None else fprefix + name

        def _save(fname):
            if fprefix is not None:
                plt.savefig(fprefix + fname)

        if self.core_exists:
            grid.plot(
                embedding(self.core_points.astype("int32"), fill=np.nan),
                log=log,
                points=voter_ideal_points,
                zoom=True,
                title=title_core,
                dpi=dpi,
                figsize=figsize,
                fname=_fn("core.png"),
            )
            return None  # when core exists abort as additional plots undefined
        z = self.stationary_distribution
        if grid is None:
            plt.figure(figsize=figsize)
            plt.plot(z)
            plt.title(title_stationary_distribution_no_grid)
            _save("stationary_distribution_no_grid.png")
        else:
            grid.plot(
                embedding(z, fill=np.nan),
                log=log,
                points=voter_ideal_points,
                title=title_stationary_distribution,
                figsize=figsize,
                dpi=dpi,
                fname=_fn("stationary_distribution.png"),
            )
            if voter_ideal_points is not None:
                grid.plot(
                    embedding(z, fill=np.nan),
                    log=log,
                    points=voter_ideal_points,
                    zoom=True,
                    border=zoomborder,
                    title=title_stationary_distribution_zoom,
                    figsize=figsize,
                    dpi=dpi,
                    fname=_fn("stationary_distribution_zoom.png"),
                )

    def _get_transition_matrix(self):
        """Build the row-stochastic transition matrix over alternatives.

        Entry [sq, ch] is the probability that alternative ch is the next
        status quo when sq is the current one.
        """
        majority = self.majority
        nfa = self.number_of_feasible_alternatives
        cU = jnp.asarray(self.utility_functions)  # (n_voters, nfa)

        # preferences[v, sq, ch] is True when voter v strictly prefers
        # challenger ch over status quo sq.
        preferences = jnp.greater(cU[:, jnp.newaxis, :], cU[:, :, jnp.newaxis])

        # Votes for ch against sq, shape (nfa, nfa). Summing the booleans
        # straight into int32 avoids an explicit int32 copy of the
        # (n_voters, nfa, nfa) preference array.
        total_votes = jnp.sum(preferences, axis=0, dtype=jnp.int32)

        # cV[sq, ch] = 1 if challenger ch wins against sq by the required majority
        cV = jnp.greater_equal(total_votes, majority).astype("int32")

        assert_zero_diagonal_int_matrix(cV)
        cV_sum_of_row = cV.sum(axis=1)  # number of winning challengers per status quo

        if self.zi:
            # Fully random agenda: challenger uniform over all nfa alternatives;
            # the status quo keeps the probability of every losing challenger.
            cP = jnp.divide(
                jnp.add(cV, jnp.diag(jnp.subtract(nfa, cV_sum_of_row))),
                nfa
            )
        else:
            # Intelligent challengers: uniform over the winning set plus the
            # status quo itself.
            cP = jnp.divide(
                jnp.add(cV, jnp.eye(nfa)),
                (1 + cV_sum_of_row)[:, jnp.newaxis]
            )

        assert_valid_transition_matrix(cP)
        return cP
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
class CondorcetCycle(VotingModel):
    """Classic three-voter, three-alternative Condorcet paradox.

    With a majority of 2 of 3, A beats B, B beats C, and C beats A, so no
    alternative is unbeaten.
    """

    def __init__(self, *, zi):
        # Rows are voters, columns are alternatives A, B, C; larger = preferred.
        cyclic_utilities = jnp.array(
            [
                [3, 2, 1],  # first agent prefers A>B>C
                [1, 3, 2],  # second agent prefers B>C>A
                [2, 1, 3],  # third agents prefers C>A>B
            ]
        )
        # Explicit base-class call (equivalent to super() here).
        VotingModel.__init__(
            self,
            utility_functions=cyclic_utilities,
            number_of_voters=3,
            number_of_feasible_alternatives=3,
            majority=2,
            zi=zi,
        )
|
|
670
|
+
|
|
671
|
+
# Import benchmarks submodule
|
|
672
|
+
from . import benchmarks
|