gridvoting-jax 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,672 @@
1
# Package version string (this wheel is gridvoting-jax 0.2.0).
__version__ = "0.2.0"
2
+
3
+ import os
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.cm as cm
7
+ from warnings import warn
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import chex
11
+
12
# Default tolerance for floating-point calculations.
# 5e-5 is sized for JAX's default float32 precision. NOTE: enable_float64()
# only flips JAX's x64 flag and does NOT change this value; when running in
# float64, pass a tighter `tolerance=` (e.g. 1e-10) explicitly where supported.
TOLERANCE = 5e-5
15
+
16
def enable_float64():
    """Switch JAX to 64-bit floating point for all later operations.

    JAX defaults to 32-bit floats, which are faster on GPUs. This sets the
    global ``jax_enable_x64`` configuration flag, so it affects every JAX
    computation performed afterwards. It touches nothing else in this module
    (in particular, ``TOLERANCE`` keeps its value).

    See: https://docs.jax.dev/en/latest/default_dtypes.html

    Example:
        >>> import gridvoting_jax as gv
        >>> gv.enable_float64()
        >>> # All subsequent JAX operations will use float64
    """
    option_name, option_value = "jax_enable_x64", True
    jax.config.update(option_name, option_value)
31
+
32
# Device detection with GV_FORCE_CPU override.
# Module-level results: `device_type` is the platform name of JAX's first
# default device ('cpu' when forced or when no device is listed) and
# `use_accelerator` is True only for 'gpu' or 'tpu'.
use_accelerator = False
device_type = 'cpu'

if os.environ.get('GV_FORCE_CPU', '0') == '1':
    warn("GV_FORCE_CPU=1: JAX forced to CPU-only mode")
else:
    # jax.devices() lists the devices of JAX's default backend; the first
    # one tells us which platform computations will run on.
    devices = jax.devices()
    if devices:
        default_device = devices[0]
        device_type = default_device.platform
        use_accelerator = device_type in ['gpu', 'tpu']
        if use_accelerator:
            warn(f"JAX using {device_type.upper()}: {default_device}")
        else:
            warn("JAX using CPU (no GPU/TPU detected)")
49
+
50
@jax.jit
def dist_sqeuclidean(XA, XB):
    """JAX-based squared Euclidean pairwise distance calculation.

    Args:
        XA: array of shape (m, n)
        XB: array of shape (p, n)

    Returns:
        Distance matrix of shape (m, p); entry [i, j] is sum((XA[i] - XB[j])**2).
    """
    XA = jnp.asarray(XA)
    XB = jnp.asarray(XB)
    # Square the coordinate differences directly instead of expanding to
    # ||a||^2 + ||b||^2 - 2*a.b. The expanded form cancels catastrophically
    # when points are far from the origin relative to their separation. This
    # matters especially in float32, where it can even return negative
    # "squared" distances. Such errors can break utility ties and flip
    # disk-membership tests. Uses the same broadcasting as dist_manhattan:
    # (m, 1, n) - (1, p, n).
    diff = XA[:, None, :] - XB[None, :, :]
    return jnp.sum(diff * diff, axis=2)
67
+
68
@jax.jit
def dist_manhattan(XA, XB):
    """Pairwise Manhattan (L1) distances computed with JAX.

    Args:
        XA: array of shape (m, n)
        XB: array of shape (p, n)

    Returns:
        Distance matrix of shape (m, p) where entry [i, j] = sum(|XA[i] - XB[j]|).
    """
    rows = jnp.expand_dims(jnp.asarray(XA), axis=1)  # (m, 1, n)
    cols = jnp.expand_dims(jnp.asarray(XB), axis=0)  # (1, p, n)
    return jnp.abs(rows - cols).sum(axis=-1)
83
+
84
+ @jax.jit
85
+ def _is_in_triangle_single(p, a, b, c):
86
+ """
87
+ Returns True if point p is in triangle (a, b, c).
88
+ Robust for arbitrary vertex winding (CW or CCW).
89
+
90
+ Args:
91
+ p: Point as [x, y]
92
+ a, b, c: Triangle vertices as [x, y]
93
+
94
+ Returns:
95
+ Boolean indicating if p is inside triangle
96
+
97
+ See also: computational geometry, half-plane test;
98
+ Stack Overflow answer to https://stackoverflow.com/questions/2049582/how-to-determine-if-a-point-is-in-a-2d-triangle
99
+ https://stackoverflow.com/a/2049593/103081
100
+ by https://stackoverflow.com/users/233522/kornel-kisielewicz
101
+ """
102
+ def cross(o, a, b):
103
+ return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])
104
+
105
+ s1 = cross(p, a, b)
106
+ s2 = cross(p, b, c)
107
+ s3 = cross(p, c, a)
108
+
109
+ # Small epsilon for numerical tolerance on edges/vertices
110
+ eps = 1e-10
111
+ has_neg = (s1 < -eps) | (s2 < -eps) | (s3 < -eps)
112
+ has_pos = (s1 > eps) | (s2 > eps) | (s3 > eps)
113
+
114
+ return ~(has_neg & has_pos)
115
+
116
@jax.jit
def _move_neg_prob_to_max(pvector):
    """Zero out negative probability components, offsetting the change at the maxima.

    The sum of the negative entries (a value <= 0) is split equally among
    every index whose value, after zeroing, lies within TOLERANCE of the
    maximum, and added to those entries. This lowers the maxima slightly and
    keeps the vector's total sum unchanged.

    Args:
        pvector: JAX array that may contain small negative values

    Returns:
        JAX array with negative entries set to 0.0 and the maximum-value
        entries reduced by an equal share of the removed negative mass.
    """
    # jnp.where keeps everything shape-static, so the function stays jittable
    negative = pvector < 0.0
    cleaned = jnp.where(negative, 0.0, pvector)
    removed_mass = jnp.sum(jnp.where(negative, pvector, 0.0))

    peak = jnp.max(cleaned)
    at_peak = jnp.abs(cleaned - peak) < TOLERANCE
    share = removed_mass / jnp.sum(at_peak)

    return jnp.where(at_peak, cleaned + share, cleaned)
148
+
149
+
150
+
151
class Grid:
    """Regular 2D grid of points with x0 <= x <= x1 and y0 <= y <= y1.

    Points are stored as flat 1D JAX arrays in row-major order. Rows run
    from y1 down to y0; within each row, x runs from x0 up to x1.
    """

    def __init__(self, *, x0, x1, xstep=1, y0, y1, ystep=1):
        """initializes 2D grid with x0<=x<=x1 and y0<=y<=y1;
        Creates a 1D JAX array of grid coordinates in self.x and self.y

        Also sets:
            self.points: (self.len, 2) array of (x, y) pairs
            self.extent: (x0, x1, y0, y1) for matplotlib
            self.gshape: (number_of_rows, number_of_cols), see shape()
            self.boundary: boolean mask of points on the outer edge
            self.len: number_of_rows * number_of_cols
        """
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1
        self.xstep = xstep
        self.ystep = ystep
        # NOTE(review): assumes (x1-x0)/xstep and (y1-y0)/ystep are whole
        # numbers; otherwise arange can overshoot x1/y0 and len(self.x) will
        # not equal self.len.
        xvals = jnp.arange(x0, x1 + xstep, xstep)
        # y descends so that reshaping to gshape puts y1 in the first row
        yvals = jnp.arange(y1, y0 - ystep, -ystep)
        xgrid, ygrid = jnp.meshgrid(xvals, yvals)
        self.x = jnp.ravel(xgrid)
        self.y = jnp.ravel(ygrid)
        self.points = jnp.column_stack((self.x, self.y))
        # extent should match extent=(x0,x1,y0,y1) for compatibility with matplotlib.pyplot.contour
        # see https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contour.html
        self.extent = (self.x0, self.x1, self.y0, self.y1)
        self.gshape = self.shape()
        self.boundary = (self.x == x0) | (self.x == x1) | (self.y == y0) | (self.y == y1)
        self.len = self.gshape[0] * self.gshape[1]

    def shape(self, *, x0=None, x1=None, xstep=None, y0=None, y1=None, ystep=None):
        """Return (number_of_rows, number_of_cols) for this grid or a sub-box.

        Any argument left as None defaults to the grid's own value.

        Raises:
            ValueError: if x1 < x0, y1 < y0, or a step is not positive.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        xstep = self.xstep if xstep is None else xstep
        ystep = self.ystep if ystep is None else ystep
        if x1 < x0:
            raise ValueError(f"x1 ({x1}) must not be less than x0 ({x0})")
        if y1 < y0:
            raise ValueError(f"y1 ({y1}) must not be less than y0 ({y0})")
        if xstep <= 0:
            raise ValueError(f"xstep must be positive, got {xstep}")
        if ystep <= 0:
            raise ValueError(f"ystep must be positive, got {ystep}")
        number_of_rows = 1 + int((y1 - y0) / ystep)
        number_of_cols = 1 + int((x1 - x0) / xstep)
        return (number_of_rows, number_of_cols)

    def within_box(self, *, x0=None, x1=None, y0=None, y1=None):
        """Return a 1D JAX boolean mask of grid points inside the closed box.

        Any bound left as None defaults to the grid's own bound.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        return (self.x >= x0) & (self.x <= x1) & (self.y >= y0) & (self.y <= y1)

    def within_disk(self, *, x0, y0, r, metric="euclidean", **kwargs):
        """Return a 1D JAX boolean mask of grid points within distance r of (x0, y0).

        metric is 'euclidean' or 'manhattan'; extra kwargs are ignored.

        Raises:
            ValueError: for any other metric.
        """
        center = jnp.array([[x0, y0]])

        if metric == "euclidean":
            # For Euclidean distance, use squared Euclidean and compare r^2
            distances_sq = dist_sqeuclidean(center, self.points)
            mask = (distances_sq <= r**2).flatten()
        elif metric == "manhattan":
            distances = dist_manhattan(center, self.points)
            mask = (distances <= r).flatten()
        else:
            raise ValueError(f"Unsupported metric: {metric}. Use 'euclidean' or 'manhattan'.")

        return mask

    def within_triangle(self, *, points):
        """Return a 1D JAX boolean mask of grid points inside the triangle `points`.

        points holds the three vertices as [x, y] rows; edges count as inside.
        """
        points = jnp.asarray(points)
        a, b, c = points[0], points[1], points[2]

        # Vectorized cross-product triangle containment test
        # Use vmap to apply the single-point test to all grid points
        mask = jax.vmap(
            lambda p: _is_in_triangle_single(p, a, b, c)
        )(self.points)

        return mask

    def index(self, *, x, y):
        """Return the 1D array index of grid point (x, y).

        Raises:
            IndexError: if (x, y) is not exactly a grid point.
        """
        is_selected = (self.x == x) & (self.y == y)
        indexes = jnp.flatnonzero(is_selected)
        if indexes.size == 0:
            raise IndexError(f"point ({x}, {y}) is not on the grid")
        return int(indexes[0])

    def embedding(self, *, valid):
        """
        returns an embedding function efunc(z,fill=0.0) from 1D arrays z of size sum(valid)
        to arrays of size self.len

        valid is a jnp.array of type boolean, of size self.len

        fill is the value for indices outside the embedding. The default
        is zero (0.0). Setting fill=jnp.nan can be useful for
        plotting purposes as matplotlib will omit jnp.nan values from various
        kinds of plots.

        efunc raises ValueError when given a 1D z whose length is not sum(valid).
        """

        correct_z_len = int(valid.sum())

        def efunc(z, fill=0.0):
            # Catch a mismatched 1D z up front with a clear message; scalars
            # still broadcast over the valid positions as before.
            if jnp.ndim(z) == 1 and jnp.shape(z)[0] != correct_z_len:
                raise ValueError(
                    f"z has length {jnp.shape(z)[0]}, expected {correct_z_len} (= sum(valid))"
                )
            v = jnp.full(self.len, fill)
            return v.at[valid].set(z)

        return efunc

    def extremes(self, z, *, valid=None):
        """Return (min_z, points_at_min, max_z, points_at_max).

        z holds one value per valid grid point; valid defaults to every grid
        point. A point is reported at an extreme when its z value is within
        1e-10 of that extreme.
        """
        valid = jnp.full((self.len,), True) if valid is None else valid
        min_z = z.min()
        min_z_mask = jnp.abs(z - min_z) < 1e-10  # Strict tolerance for exact min/max
        max_z = z.max()
        max_z_mask = jnp.abs(z - max_z) < 1e-10  # Strict tolerance for exact min/max
        return (min_z, self.points[valid][min_z_mask], max_z, self.points[valid][max_z_mask])

    def spatial_utilities(
        self, *, voter_ideal_points, metric="sqeuclidean", scale=-1, **kwargs
    ):
        """Return scale * distance(voter ideal point, grid point), shape (n_voters, self.len).

        metric is 'sqeuclidean' or 'manhattan'; extra kwargs are ignored.

        Raises:
            ValueError: for any other metric.
        """
        voter_ideal_points = jnp.asarray(voter_ideal_points)

        if metric == "sqeuclidean":
            distances = dist_sqeuclidean(voter_ideal_points, self.points)
        elif metric == "manhattan":
            distances = dist_manhattan(voter_ideal_points, self.points)
        else:
            raise ValueError(f"Unsupported metric: {metric}. Use 'sqeuclidean' or 'manhattan'.")

        return scale * distances

    def plot(
        self,
        z,
        *,
        title=None,
        cmap=cm.gray_r,
        alpha=0.6,
        alpha_points=0.3,
        log=True,
        points=None,
        zoom=False,
        border=1,
        logbias=1e-100,
        figsize=(10, 10),
        dpi=72,
        fname=None
    ):
        """Plot values z defined on the grid as labelled contours over an image.

        Args:
            z: one value per grid point (length self.len).
            title: optional plot title.
            cmap, alpha: colormap and image transparency.
            alpha_points: transparency of the scattered points.
            log: plot log10(logbias + z) instead of z.
            points: optional (k, 2) array-like of extra points to scatter.
            zoom: restrict the plot to the grid points inside the bounding box
                of `points`, widened by `border` on every side.
            figsize, dpi: passed to plt.figure.
            fname: save the figure to this file instead of calling plt.show().

        Raises:
            ValueError: if zoom=True without points, or if the zoom box
                contains no grid points.
        """
        # Convert JAX arrays to NumPy for matplotlib compatibility
        z = np.array(z)
        grid_x = np.array(self.x)
        grid_y = np.array(self.y)
        if points is not None:
            # points[:, 0] below needs an array even when callers pass a list
            points = np.asarray(points)

        if zoom:
            if points is None:
                raise ValueError("zoom=True requires points")
            [min_x, min_y] = np.min(points, axis=0) - border
            [max_x, max_y] = np.max(points, axis=0) + border
            in_zoom = np.array(self.within_box(x0=min_x, x1=max_x, y0=min_y, y1=max_y))
            if not in_zoom.any():
                raise ValueError("zoom box around points contains no grid points")
            x_sel = grid_x[in_zoom]
            y_sel = grid_y[in_zoom]
            # Size the sub-grid from the grid points actually selected. Using
            # shape() on the raw box breaks the reshape when the box edges fall
            # between grid lines or extend past the grid.
            zshape = (np.unique(y_sel).size, np.unique(x_sel).size)
            extent = (x_sel.min(), x_sel.max(), y_sel.min(), y_sel.max())
            zraw = z[in_zoom].reshape(zshape)
            x = x_sel.reshape(zshape)
            y = y_sel.reshape(zshape)
        else:
            zshape = self.gshape
            extent = self.extent
            zraw = z.reshape(zshape)
            x = grid_x.reshape(zshape)
            y = grid_y.reshape(zshape)

        plt.figure(figsize=figsize, dpi=dpi)
        plt.rcParams["font.size"] = "24"
        fmt = "%1.2f" if log else "%.2e"
        zplot = np.log10(logbias + zraw) if log else zraw
        contours = plt.contour(x, y, zplot, extent=extent, cmap=cmap)
        plt.clabel(contours, inline=True, fontsize=12, fmt=fmt)
        plt.imshow(zplot, extent=extent, cmap=cmap, alpha=alpha)
        if points is not None:
            plt.scatter(points[:, 0], points[:, 1], alpha=alpha_points, color="black")
        if title is not None:
            plt.title(title)
        if fname is None:
            plt.show()
        else:
            plt.savefig(fname)
339
+
340
+
341
def assert_valid_transition_matrix(P, *, decimal=6):
    """Assert that P is a square 2-D matrix whose rows each sum to 1.0.

    Args:
        P: array-like transition matrix (converted with jnp.asarray).
        decimal: each row sum must be within 1.1 * 10**(-decimal) of 1.0.
            The default of 6 applies regardless of dtype; pass a larger value
            (e.g. 10) for a stricter check under float64.

    Raises:
        AssertionError: if P is not 2-D, is not square, or a row sum is
            outside the tolerance.
    """
    P = jnp.asarray(P)
    # Explicit raises rather than `assert` statements, so the checks still
    # run under `python -O`.
    if P.ndim != 2:
        raise AssertionError(f"Matrix must be 2-D, got shape {P.shape}")
    rows, cols = P.shape
    if rows != cols:
        raise AssertionError(f"Matrix must be square, got shape {P.shape}")

    row_sums = np.asarray(P.sum(axis=1))
    expected = np.ones(rows)
    tolerance = 10 ** (-decimal) * 1.1  # Slightly increased for numerical stability

    # Absolute-only comparison (rtol=0); raises AssertionError on failure.
    np.testing.assert_allclose(row_sums, expected, atol=tolerance, rtol=0)
354
+
355
+
356
def assert_zero_diagonal_int_matrix(M):
    """Assert that M is a square 2-D matrix whose diagonal is all zeros.

    Args:
        M: array-like matrix (converted with jnp.asarray), typically integer.

    Raises:
        AssertionError: if M is not 2-D, is not square, or has a nonzero
            diagonal entry.
    """
    M = jnp.asarray(M)
    # Explicit raises rather than `assert` statements, so the checks still
    # run under `python -O`.
    if M.ndim != 2:
        raise AssertionError(f"Matrix must be 2-D, got shape {M.shape}")
    rows, cols = M.shape
    if rows != cols:
        raise AssertionError(f"Matrix must be square, got shape {M.shape}")

    diagonal = np.asarray(jnp.diag(M))
    expected = np.zeros(rows, dtype=int)

    # Value comparison; raises AssertionError on any mismatch.
    np.testing.assert_array_equal(diagonal, expected)
367
+
368
class MarkovChainCPUGPU:
    """Finite Markov chain given by a row-stochastic transition matrix P (JAX array)."""

    def __init__(self, *, P, computeNow=True, tolerance=None):
        """initializes a MarkovChainCPUGPU instance by copying in the transition
        matrix P and calculating chain properties

        Sets:
            absorbing_points: mask of states with P[i, i] == 1.0
            unreachable_points: mask of states whose column sum equals the
                diagonal entry (no other state transitions into them)
            has_unique_stationary_distribution: True when no state is absorbing
            stationary_distribution: computed here when computeNow is True and
                no state is absorbing; otherwise None

        Raises:
            AssertionError: if P is not a valid transition matrix.
            RuntimeError: if the stationary distribution cannot be found or
                fails the tolerance check.
        """
        if tolerance is None:
            tolerance = TOLERANCE
        self.P = jnp.asarray(P)  # copy transition matrix to JAX array
        assert_valid_transition_matrix(P)
        diagP = jnp.diagonal(self.P)
        self.absorbing_points = jnp.equal(diagP, 1.0)
        self.unreachable_points = jnp.equal(jnp.sum(self.P, axis=0), diagP)
        # NOTE(review): "unique" is inferred from the absence of absorbing
        # states only; irreducibility is not checked here.
        self.has_unique_stationary_distribution = not jnp.any(self.absorbing_points)
        # Always define the attribute so callers get None, not AttributeError,
        # when it is not computed.
        self.stationary_distribution = None
        if computeNow and self.has_unique_stationary_distribution:
            self.find_unique_stationary_distribution(tolerance=tolerance)

    def L1_norm_of_single_step_change(self, x):
        """returns float(L1(xP-x))"""
        return float(jnp.linalg.norm(jnp.dot(x, self.P) - x, ord=1))

    def solve_for_unit_eigenvector(self):
        """This is another way to potentially find the stationary distribution,
        but can suffer from numerical irregularities like negative entries.
        Assumes eigenvalue of 1.0 exists and solves for the eigenvector by
        considering a related matrix equation Q v = b, where:
        Q is P transpose minus the identity matrix I, with the first row
        replaced by all ones for the vector scaling requirement;
        v is the eigenvector of eigenvalue 1 to be found; and
        b is the first basis vector, where b[0]=1 and 0 elsewhere.

        Raises:
            RuntimeError: if the solve fails, yields non-finite values, or
                leaves negative components that cannot be repaired.
        """
        n = self.P.shape[0]
        Q = jnp.transpose(self.P) - jnp.eye(n)
        Q = Q.at[0].set(jnp.ones(n))  # JAX immutable update
        b = jnp.zeros(n)
        b = b.at[0].set(1.0)  # JAX immutable update

        error_unable_msg = "unable to find unique unit eigenvector "
        try:
            unit_eigenvector = jnp.linalg.solve(Q, b)
        except Exception as err:
            warn(str(err))  # print the original exception lest it be lost for debugging purposes
            raise RuntimeError(error_unable_msg + "(solver)") from err

        # Unlike numpy, jnp.linalg.solve generally does not raise on a singular
        # Q; failure shows up as nan/inf entries instead. Test every entry for
        # finiteness, because a sum containing inf need not be nan. The
        # "(nan)" message is kept for existing callers.
        if not bool(jnp.all(jnp.isfinite(unit_eigenvector))):
            raise RuntimeError(error_unable_msg + "(nan)")

        min_component = float(unit_eigenvector.min())
        # Increased threshold for NumPy 2.0 compatibility (was -1e-7)
        if (min_component < 0.0) and (min_component > -2e-7):
            # Tiny negatives: zero them, then take one chain step to re-smooth
            unit_eigenvector = _move_neg_prob_to_max(unit_eigenvector)
            unit_eigenvector = jnp.dot(unit_eigenvector, self.P)
            min_component = float(unit_eigenvector.min())

        if min_component < 0.0:
            neg_msg = "(negative components: " + str(min_component) + " )"
            warn(neg_msg)
            raise RuntimeError(error_unable_msg + neg_msg)

        self.unit_eigenvector = unit_eigenvector
        return self.unit_eigenvector

    def find_unique_stationary_distribution(self, *, tolerance=None, **kwargs):
        """finds the stationary distribution for a Markov Chain using algebraic method

        Returns None (and stores None) when any state is absorbing.

        Raises:
            RuntimeError: if L1(piP - pi) exceeds tolerance (default TOLERANCE),
                or if solve_for_unit_eigenvector fails.
        """
        if tolerance is None:
            tolerance = TOLERANCE
        if jnp.any(self.absorbing_points):
            self.stationary_distribution = None
            return None
        self.stationary_distribution = self.solve_for_unit_eigenvector()
        self.check_norm = self.L1_norm_of_single_step_change(self.stationary_distribution)
        if self.check_norm > tolerance:
            raise RuntimeError(f"Stationary distribution check norm {self.check_norm} exceeds tolerance {tolerance}")
        return self.stationary_distribution

    def diagnostic_metrics(self):
        """ return Markov chain approximation metrics in mathematician-friendly format """
        metrics = {
            '||F||': self.P.shape[0],
            '(𝝨𝝿)-1': float(self.stationary_distribution.sum()) - 1.0,  # cast to float to avoid a 0-d JAX array
            '||𝝿P-𝝿||_L1_norm': self.L1_norm_of_single_step_change(
                self.stationary_distribution
            )
        }
        return metrics
451
+
452
class VotingModel:
    """Majority-rule voting over a finite set of alternatives, analyzed as a Markov chain."""

    def __init__(
        self,
        *,
        utility_functions,
        number_of_voters,
        number_of_feasible_alternatives,
        majority,
        zi
    ):
        """initializes a VotingModel with utility_functions for each voter,
        the number_of_voters,
        the number_of_feasible_alternatives,
        the majority size, and whether to use zi fully random agenda or
        intelligent challengers random over winning set+status quo

        utility_functions may be any array-like of shape
        (number_of_voters, number_of_feasible_alternatives).

        Raises:
            AssertionError: if utility_functions has the wrong shape.
        """
        expected_shape = (number_of_voters, number_of_feasible_alternatives)
        # jnp.shape accepts nested lists as well as arrays. Raise explicitly
        # (not `assert`) so the check survives `python -O`.
        actual_shape = jnp.shape(utility_functions)
        if actual_shape != expected_shape:
            raise AssertionError(
                f"utility_functions must have shape {expected_shape}, got {actual_shape}"
            )
        self.utility_functions = utility_functions
        self.number_of_voters = number_of_voters
        self.number_of_feasible_alternatives = number_of_feasible_alternatives
        self.majority = majority
        self.zi = zi
        self.analyzed = False

    def E_𝝿(self, z):
        """returns mean, i.e., expected value of z under the stationary distribution"""
        return jnp.dot(self.stationary_distribution, z)

    def analyze(self):
        """Build the Markov chain for this model and record its core/stationary distribution.

        Sets MarkovChain, core_points (absorbing states), core_exists and,
        when no core exists, stationary_distribution; then sets analyzed=True.
        """
        self.MarkovChain = MarkovChainCPUGPU(P=self._get_transition_matrix())
        self.core_points = self.MarkovChain.absorbing_points
        self.core_exists = jnp.any(self.core_points)
        if not self.core_exists:
            self.stationary_distribution = self.MarkovChain.stationary_distribution
        self.analyzed = True

    def _require_analyzed(self):
        # Shared guard for methods that read self.MarkovChain; explicit raise survives `python -O`.
        if not self.analyzed:
            raise AssertionError("analyze() must be called first")

    def what_beats(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where alternative beats current index by some majority"""
        self._require_analyzed()
        points = (self.MarkovChain.P[index, :] > 0).astype("int32")
        points = points.at[index].set(0)
        return points

    def what_is_beaten_by(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where current index beats alternative by some majority"""
        self._require_analyzed()
        points = (self.MarkovChain.P[:, index] > 0).astype("int32")
        points = points.at[index].set(0)
        return points

    def summarize_in_context(self, *, grid, valid=None):
        """calculate summary statistics for stationary distribution using grid's coordinates and optional subset valid

        valid is an optional boolean mask of length grid.len (default: all
        True) selecting the grid points that correspond to the model's
        alternatives.

        Raises:
            AssertionError: if valid has the wrong shape or its count of True
                entries differs from the stationary distribution's length.
        """
        # missing valid defaults to all True array for grid
        valid = jnp.full((grid.len,), True) if valid is None else jnp.asarray(valid)
        if valid.shape != (grid.len,):
            raise AssertionError(f"valid must have shape {(grid.len,)}, got {valid.shape}")
        valid_points = grid.points[valid]
        if self.core_exists:
            return {
                'core_exists': self.core_exists,
                'core_points': valid_points[self.core_points]
            }
        # core does not exist, so evaluate mean, cov, min, max of stationary distribution;
        # first check that the number of valid points matches the dimensionality of the stationary distribution
        n_valid = int(valid.sum())
        if (n_valid,) != self.stationary_distribution.shape:
            raise AssertionError(
                f"valid selects {n_valid} points but the stationary distribution "
                f"has shape {self.stationary_distribution.shape}"
            )
        point_mean = self.E_𝝿(valid_points)
        cov = jnp.cov(valid_points, rowvar=False, ddof=0, aweights=self.stationary_distribution)
        (prob_min, prob_min_points, prob_max, prob_max_points) = \
            grid.extremes(self.stationary_distribution, valid=valid)
        # Entropy in bits; zero-probability states contribute nothing (0*log 0 := 0)
        _nonzero_statd = self.stationary_distribution[self.stationary_distribution > 0]
        entropy_bits = -_nonzero_statd.dot(jnp.log2(_nonzero_statd))
        return {
            'core_exists': self.core_exists,
            'point_mean': point_mean,
            'point_cov': cov,
            'prob_min': prob_min,
            'prob_min_points': prob_min_points,
            'prob_max': prob_max,
            'prob_max_points': prob_max_points,
            'entropy_bits': entropy_bits
        }

    def plots(
        self,
        *,
        grid,
        voter_ideal_points,
        diagnostics=False,
        log=True,
        embedding=lambda z, fill: z,
        zoomborder=0,
        dpi=72,
        figsize=(10, 10),
        fprefix=None,
        title_core="Core (aborbing) points",
        title_sad="L1 norm of difference in two rows of P^power",
        title_diff1="L1 norm of change in corner row",
        title_diff2="L1 norm of change in center row",
        title_sum1minus1="Corner row sum minus 1.0",
        title_sum2minus1="Center row sum minus 1.0",
        title_unreachable_points="Dominated (unreachable) points",
        title_stationary_distribution_no_grid="Stationary Distribution",
        title_stationary_distribution="Stationary Distribution",
        title_stationary_distribution_zoom="Stationary Distribution (zoom)"
    ):
        """Plot the core (if one exists) or the stationary distribution.

        embedding maps model-sized arrays to grid-sized arrays (see
        Grid.embedding). When fprefix is given, figures are saved under
        fprefix + <name>.png. NOTE(review): diagnostics and the
        title_sad/diff/sum/unreachable arguments are accepted but unused here.
        """
        def _fn(name):
            return None if fprefix is None else fprefix + name

        def _save(fname):
            if fprefix is not None:
                plt.savefig(fprefix + fname)

        if self.core_exists:
            grid.plot(
                embedding(self.core_points.astype("int32"), fill=np.nan),
                log=log,
                points=voter_ideal_points,
                zoom=True,
                title=title_core,
                dpi=dpi,
                figsize=figsize,
                fname=_fn("core.png"),
            )
            return None  # when core exists abort as additional plots undefined
        z = self.stationary_distribution
        if grid is None:
            plt.figure(figsize=figsize)
            plt.plot(z)
            plt.title(title_stationary_distribution_no_grid)
            _save("stationary_distribution_no_grid.png")
        else:
            grid.plot(
                embedding(z, fill=np.nan),
                log=log,
                points=voter_ideal_points,
                title=title_stationary_distribution,
                figsize=figsize,
                dpi=dpi,
                fname=_fn("stationary_distribution.png"),
            )
            if voter_ideal_points is not None:
                grid.plot(
                    embedding(z, fill=np.nan),
                    log=log,
                    points=voter_ideal_points,
                    zoom=True,
                    border=zoomborder,
                    title=title_stationary_distribution_zoom,
                    figsize=figsize,
                    dpi=dpi,
                    fname=_fn("stationary_distribution_zoom.png"),
                )

    def _get_transition_matrix(self):
        """Return the (nfa, nfa) transition matrix over alternatives.

        A challenger ch replaces status quo sq when at least `majority` voters
        strictly prefer ch. With zi=True every alternative is drawn as a
        challenger with probability 1/nfa (losers leave sq in place); otherwise
        the next state is uniform over sq and the challengers that beat it.
        """
        utility_functions = self.utility_functions
        majority = self.majority
        zi = self.zi
        nfa = self.number_of_feasible_alternatives
        cU = jnp.asarray(utility_functions)

        # Vectorized computation: compare all alternatives at once
        # cU shape: (n_voters, nfa)
        # cU[:, :, jnp.newaxis] shape: (n_voters, nfa, 1)
        # cU[:, jnp.newaxis, :] shape: (n_voters, 1, nfa)
        # Result shape: (n_voters, nfa, nfa) where [v, sq, ch] = voter v prefers challenger ch over status quo sq
        preferences = jnp.greater(cU[:, jnp.newaxis, :], cU[:, :, jnp.newaxis])

        # Sum votes across voters: shape (nfa, nfa) where [sq, ch] = votes for ch when sq is status quo
        total_votes = preferences.astype("int32").sum(axis=0)

        # Determine winners: 1 if challenger gets majority, 0 otherwise
        cV = jnp.greater_equal(total_votes, majority).astype("int32")

        assert_zero_diagonal_int_matrix(cV)
        cV_sum_of_row = cV.sum(axis=1)  # sum up all col for each row

        # set up the ZI and MI transition matrices
        if zi:
            cP = jnp.divide(
                jnp.add(cV, jnp.diag(jnp.subtract(nfa, cV_sum_of_row))),
                nfa
            )
        else:
            cP = jnp.divide(
                jnp.add(cV, jnp.eye(nfa)),
                (1 + cV_sum_of_row)[:, jnp.newaxis]
            )

        assert_valid_transition_matrix(cP)
        return cP
649
+
650
+
651
class CondorcetCycle(VotingModel):
    """Three voters, three alternatives (A, B, C), simple majority of 2.

    The preference profile makes pairwise majority rule cyclic:
    A beats B, B beats C, and C beats A.
    """

    def __init__(self, *, zi):
        # Rows are voters; columns are alternatives A, B, C.
        cyclic_utilities = jnp.array(
            [
                [3, 2, 1],  # first agent prefers A>B>C
                [1, 3, 2],  # second agent prefers B>C>A
                [2, 1, 3],  # third agents prefers C>A>B
            ]
        )
        # docs suggest to call superclass directly
        # instead of using super()
        # https://docs.python.org/3/tutorial/classes.html#inheritance
        VotingModel.__init__(
            self,
            utility_functions=cyclic_utilities,
            number_of_voters=3,
            number_of_feasible_alternatives=3,
            majority=2,
            zi=zi,
        )
670
+
671
+ # Import benchmarks submodule
672
+ from . import benchmarks
@@ -0,0 +1,6 @@
1
+ """Benchmarking utilities for gridvoting-jax."""
2
+
3
+ from .performance import performance
4
+ from .osf_comparison import run_comparison_report
5
+
6
# Public names exported by this subpackage on `from ... import *`.
__all__ = ['performance', 'run_comparison_report']