gridvoting-jax 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,613 @@
1
+ __version__ = "0.0.1"
2
+
3
+ import os
4
+ import numpy as np
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.cm as cm
8
+ from scipy.spatial.distance import cdist
9
+ from warnings import warn
10
+ import jax
11
+ import jax.numpy as jnp
12
+
13
+ # Default tolerance for float32 calculations
14
+ TOLERANCE = 5e-5
15
+
16
# Device detection with NO_GPU override.
# use_accelerator: True when JAX's default device is a GPU or TPU.
# device_type: platform name of JAX's default device ('cpu', 'gpu' or 'tpu').
use_accelerator = False
device_type = 'cpu'

if os.environ.get('NO_GPU', '0') == '1':
    # Actually restrict JAX to its CPU backend. Merely skipping detection left
    # JAX free to place arrays on a GPU/TPU despite the warning below. The
    # option only takes effect before any JAX backend has been initialized,
    # so failure is reported as a warning rather than breaking the import.
    try:
        jax.config.update('jax_platforms', 'cpu')
    except Exception as err:  # best effort: never fail the import over this
        warn(f"NO_GPU=1 but JAX could not be restricted to CPU: {err}")
    else:
        warn("NO_GPU=1: JAX forced to CPU-only mode")
else:
    # jax.devices() lists the devices of JAX's default backend, which JAX
    # chooses in priority order TPU > GPU > CPU.
    devices = jax.devices()
    if devices:
        default_device = devices[0]
        device_type = default_device.platform
        if device_type in ['gpu', 'tpu']:
            use_accelerator = True
            warn(f"JAX using {device_type.upper()}: {default_device}")
        else:
            warn("JAX using CPU (no GPU/TPU detected)")

# Use jax.numpy as the array backend
xp = jnp
# For compatibility with the former CuPy backend, expose xp.asnumpy, which
# copies any array (JAX or otherwise) into a new, writable numpy array.
# NOTE: this sets an attribute on the shared jax.numpy module object.
xp.asnumpy = lambda x: np.array(x)
38
+
39
+
40
class Grid:
    """A rectangular 2D grid of points stored in typewriter order.

    Points are listed row by row, starting at the top-left corner (x0, y1),
    moving right in steps of xstep, then down one row in steps of ystep.
    """

    def __init__(self, *, x0, x1, xstep=1, y0, y1, ystep=1):
        """initializes 2D grid with x0<=x<=x1 and y0<=y<=y1;
        Creates a 1D numpy array of grid coordinates in self.x and self.y

        Coordinates are generated from the row/column counts returned by
        self.shape(), so self.x, self.y and self.points always have exactly
        self.len entries. (np.arange(x0, x1 + xstep, xstep) can return an
        extra value for fractional steps such as 0.1.)

        Raises:
            ValueError: if x1 < x0, y1 < y0, or xstep/ystep is not positive.
        """
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1
        self.xstep = xstep
        self.ystep = ystep
        self.gshape = self.shape()
        number_of_rows, number_of_cols = self.gshape
        self.len = number_of_rows * number_of_cols
        # columns run left to right starting at x0; rows run top to bottom starting at y1
        xvals = x0 + xstep * np.arange(number_of_cols)
        yvals = y1 - ystep * np.arange(number_of_rows)
        xgrid, ygrid = np.meshgrid(xvals, yvals)
        self.x = np.ravel(xgrid)
        self.y = np.ravel(ygrid)
        self.points = np.column_stack((self.x, self.y))
        # extent should match extent=(x0,x1,y0,y1) for compatibility with matplotlib.pyplot.contour
        # see https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contour.html
        self.extent = (self.x0, self.x1, self.y0, self.y1)
        # Boundary = first/last row or column, decided by position in the grid
        # rather than by exact float equality with x0, x1, y0, y1.
        row_index, col_index = np.divmod(np.arange(self.len), number_of_cols)
        self.boundary = (
            (col_index == 0)
            | (col_index == number_of_cols - 1)
            | (row_index == 0)
            | (row_index == number_of_rows - 1)
        )
        assert self.x.shape == (self.len,)
        assert self.y.shape == (self.len,)
        assert self.points.shape == (self.len, 2)

    @staticmethod
    def _whole_steps(span, step):
        """Return how many whole steps of size step fit in span, treating a
        ratio within float rounding of an integer as that integer."""
        ratio = span / step
        nearest = int(round(ratio))
        if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
            return nearest
        return int(np.floor(ratio))

    def shape(self, *, x0=None, x1=None, xstep=None, y0=None, y1=None, ystep=None):
        """returns a tuple(number_of_rows,number_of_cols) for the natural shape of the current grid, or a subset

        Arguments left as None default to this grid's own values. A span that
        is within floating-point rounding of a whole number of steps counts as
        that whole number, e.g. (0.3 - 0) / 0.1 == 2.9999999999999996 gives
        4 points rather than 3.

        Raises:
            ValueError: if x1 < x0, y1 < y0, or xstep/ystep is not positive.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        xstep = self.xstep if xstep is None else xstep
        ystep = self.ystep if ystep is None else ystep
        if x1 < x0:
            raise ValueError(f"x1 ({x1}) must be >= x0 ({x0})")
        if y1 < y0:
            raise ValueError(f"y1 ({y1}) must be >= y0 ({y0})")
        if xstep <= 0:
            raise ValueError(f"xstep ({xstep}) must be positive")
        if ystep <= 0:
            raise ValueError(f"ystep ({ystep}) must be positive")
        number_of_rows = 1 + self._whole_steps(y1 - y0, ystep)
        number_of_cols = 1 + self._whole_steps(x1 - x0, xstep)
        return (number_of_rows, number_of_cols)

    def within_box(self, *, x0=None, x1=None, y0=None, y1=None):
        """returns a 1D numpy boolean array, suitable as an index mask, for testing whether a grid point is also in the defined box

        Box edges are inclusive; omitted edges default to the grid's own bounds.
        """
        x0 = self.x0 if x0 is None else x0
        x1 = self.x1 if x1 is None else x1
        y0 = self.y0 if y0 is None else y0
        y1 = self.y1 if y1 is None else y1
        return (self.x >= x0) & (self.x <= x1) & (self.y >= y0) & (self.y <= y1)

    def within_disk(self, *, x0, y0, r, metric="euclidean", **kwargs):
        """returns 1D numpy boolean array, suitable as an index mask, for testing whether a grid point is also in the defined disk

        A point is inside when its scipy cdist distance (using `metric` and
        any extra kwargs) from (x0, y0) is <= r.
        """
        mask = (
            cdist([[x0, y0]], self.points, metric=metric, **kwargs) <= r
        ).flatten()
        assert mask.shape == (self.len,)
        return mask

    def within_triangle(self, *, points):
        """returns 1D numpy boolean array, suitable as an index mask, for testing whether a grid point is also in the defined triangle

        points is a (3,2) array-like of triangle vertices. A grid point is
        inside (edges included) when none of its barycentric coordinates is
        below -1e-10; the small tolerance absorbs rounding from the matrix inverse.
        """
        points = np.asarray(points)
        assert points.shape == (3, 2)
        barycentric_to_cartesian_matrix = np.vstack((points[:, 0], points[:, 1], np.ones(points.shape[0])))
        assert barycentric_to_cartesian_matrix.shape == (3, 3)
        cartesian_to_barycentric_matrix = np.linalg.inv(barycentric_to_cartesian_matrix)
        homogeneous_points = np.vstack((self.x, self.y, np.ones(self.len)))
        barycentric = np.dot(cartesian_to_barycentric_matrix, homogeneous_points)
        mask = np.logical_not(np.any(barycentric < (-1e-10), axis=0))
        assert mask.shape == (self.len,)
        return mask

    def index(self, *, x, y):
        """returns the unique 1D array index for grid point (x,y)

        Uses exact equality on coordinates; asserts exactly one point matches.
        """
        isSelectedPoint = (self.x == x) & (self.y == y)
        indexes = np.flatnonzero(isSelectedPoint)
        assert len(indexes) == 1
        return indexes[0]

    def embedding(self, *, valid):
        """
        returns an embedding function efunc(z,fill=0.0) from 1D arrays z of size sum(valid)
        to arrays of size self.len

        valid is a np.array of type boolean, of size self.len

        fill is the value for indices outside the embedding. The default
        is zero (0.0). Setting fill=np.nan can be useful for
        plotting purposes as matplotlib will omit np.nan values from various
        kinds of plots.
        """

        assert self.len == len(valid)
        correct_z_len = valid.sum()

        def efunc(z, fill=0.0):
            assert len(z) == correct_z_len
            v = np.full(self.len, fill)
            v[valid] = z
            return v

        return efunc

    def extremes(self, z, *, valid=None):
        """returns (min_z, min_points, max_z, max_points) for values z on the valid grid points

        z has one value per valid point (valid defaults to every grid point).
        min_points/max_points are all valid points whose value lies within
        1e-10 of the minimum/maximum.
        """
        # missing valid defaults to all True array for grid
        valid = np.full((self.len,), True) if valid is None else valid
        assert valid.shape == (self.len,)
        assert z.shape == (valid.sum(),)
        min_z = z.min()
        min_z_mask = np.abs(z - min_z) < 1e-10
        max_z = z.max()
        max_z_mask = np.abs(z - max_z) < 1e-10
        valid_points = self.points[valid]
        return (min_z, valid_points[min_z_mask], max_z, valid_points[max_z_mask])

    def spatial_utilities(
        self, *, voter_ideal_points, metric="sqeuclidean", scale=-1, **kwargs
    ):
        """returns utility function values for each voter at each grid point

        Result has shape (number_of_voters, self.len) and equals
        scale * cdist(voter_ideal_points, self.points, metric=metric).
        """
        return scale * cdist(
            np.asarray(voter_ideal_points), self.points, metric=metric, **kwargs
        )

    def plot(
        self,
        z,
        *,
        title=None,
        cmap="gray_r",
        alpha=0.6,
        alpha_points=0.3,
        log=True,
        points=None,
        zoom=False,
        border=1,
        logbias=1e-100,
        figsize=(10, 10),
        dpi=72,
        fname=None
    ):
        """plots values z defined on the grid;
        optionally plots additional 2D points
        and zooms to fit the bounding box of the points

        Args:
            z: array of self.len values in typewriter order.
            title: optional plot title.
            cmap: matplotlib Colormap or registered colormap name.
            alpha, alpha_points: opacity of the image and of the scatter points.
            log: plot log10(logbias + z) instead of z.
            points: optional (k,2) array-like of points to scatter (e.g. voter ideal points).
            zoom: restrict the plot to the grid points inside the bounding box
                of `points` widened by `border`; requires more than 2 points.
            figsize, dpi: figure size settings.
            fname: save to this file instead of calling plt.show().

        Raises:
            ValueError: if zoom=True and the zoom box contains no grid points.
        """
        if points is not None:
            # accept plain nested lists such as voter_ideal_points=[[x, y], ...]
            points = np.asarray(points)
        z = np.asarray(z)
        plt.figure(figsize=figsize, dpi=dpi)
        plt.rcParams["font.size"] = "24"
        fmt = "%1.2f" if log else "%.2e"
        if zoom:
            assert points is not None, "zoom=True requires points"
            assert points.shape[0] > 2
            assert points.shape[1] == 2
            [min_x, min_y] = np.min(points, axis=0) - border
            [max_x, max_y] = np.max(points, axis=0) + border
            inZoom = self.within_box(x0=min_x, x1=max_x, y0=min_y, y1=max_y)
            if not inZoom.any():
                raise ValueError("zoom box contains no grid points")
            x_in = self.x[inZoom]
            y_in = self.y[inZoom]
            # Size the sub-grid from the grid points actually inside the box, so a
            # box that overhangs the grid or is not grid-aligned still reshapes.
            zshape = (np.unique(y_in).size, np.unique(x_in).size)
            extent = (x_in.min(), x_in.max(), y_in.min(), y_in.max())
            zraw = z[inZoom].reshape(zshape)
            x = x_in.reshape(zshape)
            y = y_in.reshape(zshape)
        else:
            zshape = self.gshape
            extent = self.extent
            zraw = z.reshape(zshape)
            x = self.x.reshape(zshape)
            y = self.y.reshape(zshape)
        zplot = np.log10(logbias + zraw) if log else zraw
        contours = plt.contour(x, y, zplot, extent=extent, cmap=cmap)
        plt.clabel(contours, inline=True, fontsize=12, fmt=fmt)
        plt.imshow(zplot, extent=extent, cmap=cmap, alpha=alpha)
        if points is not None:
            plt.scatter(points[:, 0], points[:, 1], alpha=alpha_points, color="black")
        if title is not None:
            plt.title(title)
        if fname is None:
            plt.show()
        else:
            plt.savefig(fname)
230
+
231
+
232
def assert_valid_transition_matrix(P, *, decimal=6):
    """asserts that jax or numpy array is square and that each row sums to 1.0
    with default tolerance of 6 decimal places (appropriate for float32)

    P may be any array-like (JAX array, numpy array, nested lists); it is
    converted to numpy before its shape is inspected.

    Raises:
        AssertionError: if P is not square or some row sum differs from 1.0
            beyond the tolerance implied by `decimal`.
    """
    # Convert first so array-likes without a .shape attribute are accepted
    P_np = np.asarray(P)
    rows, cols = P_np.shape
    assert rows == cols, f"transition matrix must be square, got shape {P_np.shape}"
    np.testing.assert_array_almost_equal(
        P_np.sum(axis=1),
        np.ones(shape=(rows,)),
        decimal
    )
244
+
245
+
246
def assert_zero_diagonal_int_matrix(M):
    """asserts that jax or numpy array is square and the diagonal is 0.0

    M may be any array-like (JAX array, numpy array, nested lists); it is
    converted to numpy before its shape is inspected.

    Raises:
        AssertionError: if M is not square or any diagonal entry is nonzero.
    """
    # Convert first so array-likes without a .shape attribute are accepted
    M_np = np.asarray(M)
    rows, cols = M_np.shape
    assert rows == cols, f"matrix must be square, got shape {M_np.shape}"
    np.testing.assert_array_equal(
        np.diag(M_np),
        np.zeros(shape=(rows,), dtype=int)
    )
255
+
256
class MarkovChainCPUGPU:
    """Finite Markov chain with transition matrix P held as a JAX array."""

    def __init__(self, *, P, computeNow=True, tolerance=None):
        """initializes a MarkovChainCPUGPU instance by copying in the transition
        matrix P and calculating chain properties

        Args:
            P: square row-stochastic matrix (any array-like).
            computeNow: when True and no state is absorbing, immediately
                compute self.stationary_distribution.
            tolerance: maximum allowed L1 norm of xP - x for the computed
                stationary distribution x (defaults to TOLERANCE).
        """
        if tolerance is None:
            tolerance = TOLERANCE
        self.P = jnp.asarray(P)  # copy transition matrix to JAX array
        assert_valid_transition_matrix(P)
        diagP = jnp.diagonal(self.P)
        # a state that stays put with probability exactly 1.0 is absorbing
        self.absorbing_points = jnp.equal(diagP, 1.0)
        # column sum equal to the diagonal entry: no other state moves here
        # (assuming non-negative entries)
        self.unreachable_points = jnp.equal(jnp.sum(self.P, axis=0), diagP)
        # attribute name spelling kept for backward compatibility
        self.has_unique_stationary_distibution = not jnp.any(self.absorbing_points)
        # defined up front so the attribute exists even when not computed
        self.stationary_distribution = None
        if computeNow and self.has_unique_stationary_distibution:
            self.find_unique_stationary_distribution(tolerance=tolerance)

    def L1_norm_of_single_step_change(self, x):
        """returns float(L1(xP-x))"""
        return float(jnp.linalg.norm(jnp.dot(x, self.P) - x, ord=1))

    def solve_for_unit_eigenvector(self):
        """This is another way to potentially find the stationary distribution,
        but can suffer from numerical irregularities like negative entries.
        Assumes eigenvalue of 1.0 exists and solves for the eigenvector by
        considering a related matrix equation Q v = b, where:
        Q is P transpose minus the identity matrix I, with the first row
        replaced by all ones for the vector scaling requirement;
        v is the eigenvector of eigenvalue 1 to be found; and
        b is the first basis vector, where b[0]=1 and 0 elsewhere.

        Raises:
            RuntimeError: if the solver raises, the solution has non-finite
                entries, or negative components remain after the repair step.
        """
        n = self.P.shape[0]
        Q = jnp.transpose(self.P).astype(jnp.float32) - jnp.eye(n, dtype=jnp.float32)
        Q = Q.at[0].set(jnp.ones(n, dtype=jnp.float32))  # JAX immutable update
        b = jnp.zeros(n, dtype=jnp.float32)
        b = b.at[0].set(1.0)  # JAX immutable update

        error_unable_msg = "unable to find unique unit eigenvector "
        try:
            unit_eigenvector = jnp.linalg.solve(Q, b)
        except Exception as err:
            warn(str(err))  # print the original exception lest it be lost for debugging purposes
            raise RuntimeError(error_unable_msg + "(solver)") from err

        # JAX's solver generally signals a singular Q with nan/inf entries
        # rather than raising, so reject any non-finite component (inf included).
        if not bool(jnp.all(jnp.isfinite(unit_eigenvector))):
            raise RuntimeError(error_unable_msg + "(nan)")

        min_component = float(unit_eigenvector.min())
        # Increased threshold for NumPy 2.0 compatibility (was -1e-7)
        if ((min_component < 0.0) and (min_component > -2e-7)):
            # tiny negatives from rounding: zero them, move their (negative)
            # mass onto the largest component, then take one step of the chain
            warn('attempting fix of neg components')
            to_zero = unit_eigenvector < 0.0
            num_zeroed = to_zero.sum()
            mass_destroyed = unit_eigenvector[to_zero].sum()
            warn('num_zeroed = ' + str(num_zeroed))
            warn('mass relocated = ' + str(mass_destroyed))

            # JAX immutable updates
            unit_eigenvector = unit_eigenvector.at[to_zero].set(0.0)
            am = unit_eigenvector.argmax()
            unit_eigenvector = unit_eigenvector.at[am].add(mass_destroyed)
            unit_eigenvector = jnp.dot(unit_eigenvector, self.P)

            min_component = float(unit_eigenvector.min())
            warn('fixed min_component ' + str(min_component))

        if (min_component < 0.0):
            neg_msg = "(negative components: " + str(min_component) + " )"
            warn(neg_msg)
            raise RuntimeError(error_unable_msg + neg_msg)

        self.unit_eigenvector = unit_eigenvector
        return self.unit_eigenvector

    def find_unique_stationary_distribution(self, *, tolerance=None, **kwargs):
        """finds the stationary distribution for a Markov Chain using algebraic method

        Returns None (and sets self.stationary_distribution to None) when any
        state is absorbing.

        Raises:
            RuntimeError: if solving fails or the L1 norm of xP - x for the
                result exceeds `tolerance`.
        """
        if tolerance is None:
            tolerance = TOLERANCE
        if jnp.any(self.absorbing_points):
            self.stationary_distribution = None
            return None
        self.stationary_distribution = self.solve_for_unit_eigenvector()
        self.check_norm = self.L1_norm_of_single_step_change(self.stationary_distribution)
        if self.check_norm > tolerance:
            raise RuntimeError(f"Stationary distribution check norm {self.check_norm} exceeds tolerance {tolerance}")
        return self.stationary_distribution

    def diagnostic_metrics(self):
        """ return Markov chain approximation metrics in mathematician-friendly format

        Keys: number of states, (sum of stationary distribution) - 1, and the
        L1 norm of the single-step change of the stationary distribution.
        """
        metrics = {
            '||F||': self.P.shape[0],
            # float() turns the 0-d JAX array into a plain Python scalar
            '(𝝨𝝿)-1': float(self.stationary_distribution.sum()) - 1.0,
            '||𝝿P-𝝿||_L1_norm': self.L1_norm_of_single_step_change(
                self.stationary_distribution
            )
        }
        return metrics
351
+
352
class VotingModel:
    """Majority-rule voting over a finite set of feasible alternatives.

    Each period a random challenger is drawn, either uniformly from all
    alternatives (zi=True) or uniformly from the status quo plus the
    alternatives that beat it (zi=False). The challenger replaces the status
    quo when at least `majority` voters strictly prefer it. The resulting
    Markov chain over the status quo is analyzed by analyze().
    """

    def __init__(
        self,
        *,
        utility_functions,
        number_of_voters,
        number_of_feasible_alternatives,
        majority,
        zi
    ):
        """initializes a VotingModel with utility_functions for each voter,
        the number_of_voters,
        the number_of_feasible_alternatives,
        the majority size, and whether to use zi fully random agenda or
        intelligent challengers random over winning set+status quo

        utility_functions must have shape
        (number_of_voters, number_of_feasible_alternatives).
        """
        assert utility_functions.shape == (
            number_of_voters,
            number_of_feasible_alternatives,
        )
        self.utility_functions = utility_functions
        self.number_of_voters = number_of_voters
        self.number_of_feasible_alternatives = number_of_feasible_alternatives
        self.majority = majority
        self.zi = zi
        self.analyzed = False

    def E_𝝿(self, z):
        """returns mean, i.e., expected value of z under the stationary distribution"""
        return np.dot(self.stationary_distribution, z)

    def analyze(self):
        """Build the transition matrix and its Markov chain, then record
        core (absorbing) points; when no core exists, also copy the
        stationary distribution to numpy."""
        self.MarkovChain = MarkovChainCPUGPU(P=self._get_transition_matrix())
        self.core_points = xp.asnumpy(self.MarkovChain.absorbing_points)
        self.core_exists = np.any(self.core_points)
        if not self.core_exists:
            self.stationary_distribution = xp.asnumpy(
                self.MarkovChain.stationary_distribution
            )
        self.analyzed = True

    def what_beats(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where alternative beats current index by some majority"""
        assert self.analyzed
        # row `index` of P: positive where the chain can move away from `index`
        points = xp.asnumpy(self.MarkovChain.P[index, :] > 0).astype("int32")
        points[index] = 0
        return points

    def what_is_beaten_by(self, *, index):
        """returns array of size number_of_feasible_alternatives
        with value 1 where current index beats alternative by some majority"""
        assert self.analyzed
        # column `index` of P: positive where the chain can move into `index`
        points = xp.asnumpy(self.MarkovChain.P[:, index] > 0).astype("int32")
        points[index] = 0
        return points

    def summarize_in_context(self, *, grid, valid=None):
        """calculate summary statistics for stationary distribution using grid's coordinates and optional subset valid

        Returns {'core_exists', 'core_points'} when a core exists; otherwise
        the stationary-weighted point mean and covariance, the min/max
        probability with their points, and the entropy in bits.
        Requires analyze() to have been called.
        """
        assert self.analyzed
        # missing valid defaults to all True array for grid
        valid = np.full((grid.len,), True) if valid is None else valid
        # check valid array shape
        assert valid.shape == (grid.len,)
        valid_points = grid.points[valid]
        if self.core_exists:
            return {
                'core_exists': self.core_exists,
                'core_points': valid_points[self.core_points]
            }
        # core does not exist, so evaluate mean, cov, min, max of stationary distribution
        # first check that the number of valid points matches the dimensionality of the stationary distribution
        assert (valid.sum(),) == self.stationary_distribution.shape
        point_mean = self.E_𝝿(valid_points)
        cov = np.cov(valid_points, rowvar=False, ddof=0, aweights=self.stationary_distribution)
        (prob_min, prob_min_points, prob_max, prob_max_points) = \
            grid.extremes(self.stationary_distribution, valid=valid)
        # zero-probability states are skipped to avoid log2(0)
        _nonzero_statd = self.stationary_distribution[self.stationary_distribution > 0]
        entropy_bits = -_nonzero_statd.dot(np.log2(_nonzero_statd))
        return {
            'core_exists': self.core_exists,
            'point_mean': point_mean,
            'point_cov': cov,
            'prob_min': prob_min,
            'prob_min_points': prob_min_points,
            'prob_max': prob_max,
            'prob_max_points': prob_max_points,
            'entropy_bits': entropy_bits
        }

    def plots(
        self,
        *,
        grid,
        voter_ideal_points,
        diagnostics=False,
        log=True,
        embedding=lambda z, fill: z,
        zoomborder=0,
        dpi=72,
        figsize=(10, 10),
        fprefix=None,
        title_core="Core (aborbing) points",
        title_sad="L1 norm of difference in two rows of P^power",
        title_diff1="L1 norm of change in corner row",
        title_diff2="L1 norm of change in center row",
        title_sum1minus1="Corner row sum minus 1.0",
        title_sum2minus1="Center row sum minus 1.0",
        title_unreachable_points="Dominated (unreachable) points",
        title_stationary_distribution_no_grid="Stationary Distribution",
        title_stationary_distribution="Stationary Distribution",
        title_stationary_distribution_zoom="Stationary Distribution (zoom)"
    ):
        """Plot the core if one exists; otherwise plot unreachable points and
        the stationary distribution (full grid and zoomed to the voters).

        When fprefix is given, each figure is saved to fprefix + <name>.png.
        diagnostics=True additionally plots power-method diagnostics when the
        Markov chain provides them.
        """
        def _fn(name):
            return None if fprefix is None else fprefix + name

        def _save(fname):
            if fprefix is not None:
                plt.savefig(fprefix + fname)

        if self.core_exists:
            grid.plot(
                embedding(self.core_points.astype("int32"), fill=np.nan),
                log=log,
                points=voter_ideal_points,
                zoom=True,
                title=title_core,
                dpi=dpi,
                figsize=figsize,
                fname=_fn("core.png"),
            )
            return None  # when core exists abort as additional plots undefined
        if diagnostics:
            self._plot_power_method_diagnostics(
                figsize=figsize,
                save=_save,
                title_sad=title_sad,
                title_diff1=title_diff1,
                title_diff2=title_diff2,
                title_sum1minus1=title_sum1minus1,
                title_sum2minus1=title_sum2minus1,
            )
        if grid is not None:
            grid.plot(
                embedding(
                    xp.asnumpy(self.MarkovChain.unreachable_points).astype("int32"),
                    fill=np.nan
                ),
                log=log,
                title=title_unreachable_points,
                dpi=dpi,
                figsize=figsize,
                fname=_fn("unreachable.png"),
            )
        z = self.stationary_distribution
        if grid is None:
            pd.Series(z).plot(
                title=title_stationary_distribution_no_grid, figsize=figsize
            )
            # file name spelling kept as-is so existing output paths do not change
            _save("stationary_distribubtion_no_grid.png")
        else:
            grid.plot(
                embedding(z, fill=np.nan),
                log=log,
                points=voter_ideal_points,
                title=title_stationary_distribution,
                figsize=figsize,
                dpi=dpi,
                fname=_fn("stationary_distribution.png"),
            )
            if voter_ideal_points is not None:
                grid.plot(
                    embedding(z, fill=np.nan),
                    log=log,
                    points=voter_ideal_points,
                    zoom=True,
                    border=zoomborder,
                    title=title_stationary_distribution_zoom,
                    figsize=figsize,
                    dpi=dpi,
                    fname=_fn("stationary_distribution_zoom.png"),
                )

    def _plot_power_method_diagnostics(
        self,
        *,
        figsize,
        save,
        title_sad,
        title_diff1,
        title_diff2,
        title_sum1minus1,
        title_sum2minus1
    ):
        """Scatter-plot power-method diagnostics from self.MarkovChain, if any.

        MarkovChainCPUGPU in this module solves algebraically and does not set
        `power_method_diagnostics`; in that case warn and plot nothing rather
        than raising AttributeError.
        """
        diagnostic_records = getattr(self.MarkovChain, "power_method_diagnostics", None)
        if diagnostic_records is None:
            warn("diagnostics=True ignored: no power method diagnostics available (algebraic solver)")
            return
        df = pd.DataFrame(diagnostic_records)
        specs = (
            ("sad", title_sad, {"loglog": True}, "diagnostic_sad.png"),
            ("diff1", title_diff1, {"loglog": True}, "diagnostic_diff1.png"),
            ("diff2", title_diff2, {"loglog": True}, "diagnostic_diff2.png"),
            ("sum1minus1", title_sum1minus1, {"logx": True}, "diagnostic_sum1minus1.png"),
            ("sum2minus1", title_sum2minus1, {"logx": True}, "diagnostic_sum2minus1.png"),
        )
        for column, title, axis_scale, fname in specs:
            df.plot.scatter("power", column, title=title, figsize=figsize, **axis_scale)
            save(fname)

    def _get_transition_matrix(self):
        """Return the float32 transition matrix P of the status quo chain.

        P[sq, ch] is the probability of moving from status quo sq to ch:
        ZI (zi=True): every alternative is an equally likely challenger, and
        the status quo is kept when the challenger loses (or is itself).
        MI (zi=False): the challenger is uniform over the status quo plus the
        alternatives that beat it.
        """
        utility_functions = self.utility_functions
        majority = self.majority
        zi = self.zi
        nfa = self.number_of_feasible_alternatives
        cU = jnp.asarray(utility_functions)

        # Vectorized computation: compare all alternatives at once
        # cU shape: (n_voters, nfa)
        # cU[:, :, jnp.newaxis] shape: (n_voters, nfa, 1)
        # cU[:, jnp.newaxis, :] shape: (n_voters, 1, nfa)
        # Result shape: (n_voters, nfa, nfa) where [v, sq, ch] = voter v prefers challenger ch over status quo sq
        preferences = jnp.greater(cU[:, jnp.newaxis, :], cU[:, :, jnp.newaxis])

        # Sum votes across voters: shape (nfa, nfa) where [sq, ch] = votes for ch when sq is status quo
        total_votes = preferences.astype("int32").sum(axis=0)

        # Determine winners: 1 if challenger gets majority, 0 otherwise
        cV = jnp.greater_equal(total_votes, majority).astype("int32")

        assert_zero_diagonal_int_matrix(cV)
        cV_sum_of_row = cV.sum(axis=1)  # sum up all col for each row

        # set up the ZI and MI transition matrices
        if zi:
            cP = jnp.divide(
                jnp.add(cV, jnp.diag(jnp.subtract(nfa, cV_sum_of_row))),
                nfa
            ).astype(jnp.float32)
        else:
            cP = jnp.divide(
                jnp.add(cV, jnp.eye(nfa)),
                (1 + cV_sum_of_row)[:, jnp.newaxis]
            ).astype(jnp.float32)

        assert_valid_transition_matrix(cP)
        return cP
593
+
594
+
595
class CondorcetCycle(VotingModel):
    """Three voters choosing among three alternatives (A, B, C) with a
    majority of two: pairwise majority rule gives A over B, B over C and
    C over A, so every alternative is beaten by some other one."""

    def __init__(self, *, zi):
        # Rows are voters, columns are alternatives (A, B, C); a larger
        # number means a more preferred alternative.
        ranking_table = np.array(
            [
                [3, 2, 1],  # voter 1: A > B > C
                [1, 3, 2],  # voter 2: B > C > A
                [2, 1, 3],  # voter 3: C > A > B
            ]
        )
        # The base class is named explicitly (rather than via super()),
        # which always resolves to VotingModel.__init__.
        VotingModel.__init__(
            self,
            utility_functions=ranking_table,
            number_of_voters=3,
            number_of_feasible_alternatives=3,
            majority=2,
            zi=zi,
        )
File without changes
@@ -0,0 +1,399 @@
1
+ Metadata-Version: 2.4
2
+ Name: gridvoting-jax
3
+ Version: 0.0.1
4
+ Summary: Spatial voting simulations on a grid with random challengers (with float32 JAX backend)
5
+ Home-page: https://github.com/drpaulbrewer/gridvoting-jax
6
+ Author: Paul Brewer
7
+ Author-email: drpaulbrewer@eaftc.com
8
+ Project-URL: Bug Tracker, https://github.com/drpaulbrewer/gridvoting-jax/issues
9
+ Project-URL: Original Project, https://github.com/drpaulbrewer/gridvoting
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: POSIX :: Linux
17
+ Classifier: Operating System :: MacOS
18
+ Classifier: Operating System :: Microsoft :: Windows
19
+ Classifier: Development Status :: 3 - Alpha
20
+ Classifier: Environment :: GPU :: NVIDIA CUDA
21
+ Classifier: Environment :: GPU
22
+ Classifier: Intended Audience :: Science/Research
23
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
24
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
25
+ Requires-Python: >=3.9
26
+ Description-Content-Type: text/markdown
27
+ License-File: LICENSE.md
28
+ Requires-Dist: numpy>=2.0.0
29
+ Requires-Dist: pandas>=2.2.0
30
+ Requires-Dist: scipy>=1.13.0
31
+ Requires-Dist: matplotlib>=3.8.0
32
+ Requires-Dist: jax>=0.4.20
33
+ Dynamic: license-file
34
+
35
+ # gridvoting-jax
36
+
37
+ **A JAX-powered derivative of the original [gridvoting](https://github.com/drpaulbrewer/gridvoting) project**
38
+
39
+ [![PyPI version](https://badge.fury.io/py/gridvoting-jax.svg)](https://badge.fury.io/py/gridvoting-jax)
40
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
41
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
42
+
43
+ This library provides GPU/TPU/CPU-accelerated spatial voting simulations using Google's JAX framework with float32 precision.
44
+
45
+ ## Origin and Development
46
+
47
+ This project is derived from the original `gridvoting` module, which was developed for the research publication:
48
+
49
+ > Brewer, P., Juybari, J. & Moberly, R.
50
+ > A comparison of zero- and minimal-intelligence agendas in majority-rule voting models.
51
+ > J Econ Interact Coord (2023). https://doi.org/10.1007/s11403-023-00387-8
52
+
53
+ **Migration to JAX**: The computational backend was refactored from NumPy/CuPy to JAX using Google's Antigravity AI assistant. This migration provides:
54
+ - ✨ Unified CPU/GPU/TPU support through JAX
55
+ - 🚀 Improved performance through JIT compilation
56
+ - 💾 Float32 precision for efficiency
57
+ - 🔗 Better compatibility with modern ML/AI workflows
58
+
59
+ **Original Project**: https://github.com/drpaulbrewer/gridvoting
60
+
61
+ ---
62
+
63
+ ## Quick Start
64
+
65
+ ```python
66
+ import gridvoting_jax as gv
67
+
68
+ # Create a grid
69
+ grid = gv.Grid(x0=-20, x1=20, y0=-20, y1=20)
70
+
71
+ # Define voter ideal points
72
+ voter_ideal_points = [[-15, -9], [0, 17], [15, -9]]
73
+
74
+ # Generate utility functions
75
+ utilities = grid.spatial_utilities(voter_ideal_points=voter_ideal_points)
76
+
77
+ # Create and analyze voting model
78
+ vm = gv.VotingModel(
79
+ utility_functions=utilities,
80
+ majority=2,
81
+ zi=False, # Minimal Intelligence agenda
82
+ number_of_voters=3,
83
+ number_of_feasible_alternatives=grid.len
84
+ )
85
+
86
+ vm.analyze()
87
+
88
+ # View results
89
+ print(f"Device: {gv.device_type}") # Shows 'gpu', 'tpu', or 'cpu'
90
+ print(f"Stationary distribution: {vm.stationary_distribution[:5]}...")
91
+ ```
92
+
93
+ ---
94
+
95
+ ## Installation
96
+
97
+ ### Google Colab (Recommended)
98
+ All dependencies are pre-installed! Just run:
99
+ ```python
100
+ !pip install gridvoting-jax
101
+ ```
102
+
103
+ ### Local Installation
104
+ ```bash
105
+ pip install gridvoting-jax
106
+ ```
107
+
108
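+ **GPU Support**: JAX automatically detects and uses NVIDIA GPUs (CUDA) when a CUDA-enabled JAX build is installed (for example `pip install -U "jax[cuda12]"`); the plain `jax` package pulled in by `pip install gridvoting-jax` runs on CPU only.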
109
+
110
+ **TPU Support**: JAX automatically detects TPUs on Google Cloud.
111
+
112
+ **CPU-Only Mode**: Set environment variable `NO_GPU=1` to force CPU-only execution:
113
+ ```bash
114
+ NO_GPU=1 python your_script.py
115
+ ```
116
+
117
+ ---
118
+
119
+ ## Requirements
120
+
121
+ - Python 3.9+
122
+ - numpy >= 2.0.0
123
+ - pandas >= 2.2.0
124
+ - scipy >= 1.13.0
125
+ - matplotlib >= 3.8.0
126
+ - jax >= 0.4.20
127
+
128
+ **Google Colab**: All dependencies are pre-installed (numpy 2.0.2, pandas 2.2.2, scipy 1.16.3, matplotlib 3.10, jax 0.7).
129
+
130
+ ---
131
+
132
+ ## Performance
133
+
134
+ gridvoting-jax uses JAX's JIT compilation for high performance:
135
+
136
+ - **First run**: ~1-2s (includes JIT compilation)
137
+ - **Subsequent runs**: ~0.03-0.05s (comparable to CuPy)
138
+ - **Vectorized operations**: All computations run on GPU/TPU when available
139
+
140
+ **Benchmark** (g=20, 1681 alternatives, Nvidia 1080Ti):
141
+ - Analysis time: 0.033s (after JIT compilation)
142
+ - Test suite: 22 tests in ~16s
143
+ - Speedup: 10-30x faster than CPU-only
144
+
145
+ ---
146
+
147
+ ## Differences from Original gridvoting
148
+
149
+ This JAX version differs from the original in several ways:
150
+
151
+ | Feature | Original gridvoting | gridvoting-jax |
152
+ |---------|-------------------|----------------|
153
+ | **Backend** | NumPy/CuPy | JAX |
154
+ | **Precision** | Float64 | Float32 |
155
+ | **Solver** | Power + Algebraic | Algebraic only |
156
+ | **Tolerance** | 1e-10 | 5e-5 |
157
+ | **Device Detection** | GPU/CPU | TPU/GPU/CPU |
158
+ | **Import** | `import gridvoting` | `import gridvoting_jax` |
159
+
160
+ **Numerical Accuracy**: Float32 provides ~7 decimal digits of precision, which is sufficient for spatial voting simulations. The computed stationary distribution 𝝿 is accepted only when the L1 norm of 𝝿P − 𝝿 is within the 5e-5 tolerance; this tolerance is robust on grids up to 60x60.
161
+
162
+ ---
163
+
164
+ ## Random Sequential Voting Simulations
165
+
166
+ This follows [section 2 of our research paper](https://link.springer.com/article/10.1007/s11403-023-00387-8#Sec4).
167
+
168
+ A simulation consists of:
169
+ - A sequence of times: `t=0,1,2,3,...`
170
+ - A finite feasible set of alternatives **F**
171
+ - A set of voters who have preferences over the alternatives and vote truthfully
172
+ - A rule for voting and selecting challengers
173
+ - A mapping of the set of alternatives **F** into a 2D grid
174
+
175
+ The active or status quo alternative at time t is called `f[t]`.
176
+
177
+ At each t, there is a majority-rule vote between alternative `f[t]` and a challenger alternative `c[t]`. The winner of that vote becomes the next status quo `f[t+1]`.
178
+
179
+ **Randomness** enters through two possible rules for choosing the challenger `c[t]`:
180
+ - **Zero Intelligence (ZI)** (`zi=True`): `c[t]` is chosen uniformly at random from **F**
181
+ - **Minimal Intelligence (MI)** (`zi=False`): `c[t]` is chosen uniformly from the status quo `f[t]` and the possible winning alternatives given `f[t]`
182
+
183
+ ---
184
+
185
+ ## API Documentation
186
+
187
+ ### class Grid
188
+
189
+ #### Constructor
190
+
191
+ ```python
192
+ gridvoting_jax.Grid(*, x0, x1, xstep=1, y0, y1, ystep=1)
193
+ ```
194
+
195
+ Constructs a 2D grid in x and y dimensions. All parameters are keyword-only, e.g. `Grid(x0=-5, x1=5, y0=-7, y1=7)`.
196
+
197
+ **Parameters:**
198
+ - `x0`: leftmost grid x-coordinate
199
+ - `x1`: rightmost grid x-coordinate
200
+ - `xstep=1`: optional, grid spacing in x dimension
201
+ - `y0`: lowest grid y-coordinate
202
+ - `y1`: highest grid y-coordinate
203
+ - `ystep=1`: optional, grid spacing in y dimension
204
+
205
+ **Example:**
206
+ ```python
207
+ import gridvoting_jax as gv
208
+ grid = gv.Grid(x0=-5, x1=5, y0=-7, y1=7)
209
+ ```
210
+
211
+ **Instance Properties:**
212
+ - `grid.x0, grid.x1, grid.xstep, grid.y0, grid.y1, grid.ystep` - constructor parameters
213
+ - `grid.points` - 2D numpy array of grid points in typewriter order `[[x0,y1],[x0+xstep,y1],...,[x1,y0]]` (row by row, starting at the top-left corner)
214
+ - `grid.x` - 1D numpy array of x-coordinates in typewriter order
215
+ - `grid.y` - 1D numpy array of y-coordinates in typewriter order
216
+ - `grid.gshape` - natural shape `(number_of_rows, number_of_cols)`
217
+ - `grid.extent` - tuple `(x0, x1, y0, y1)` for matplotlib
218
+ - `grid.len` - number of points on the grid
219
+ - `grid.boundary` - 1D boolean array indicating boundary points
220
+
221
+ #### Methods
222
+
223
+ **`grid.spatial_utilities(voter_ideal_points, metric='sqeuclidean', scale=-1)`**
224
+
225
+ Returns utility function values for each voter at each grid point as a function of distance from an ideal point.
226
+
227
+ - `voter_ideal_points`: array of 2D coordinates `[[xv1,yv1],[xv2,yv2],...]`
228
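+ - `metric`: distance metric (default `'sqeuclidean'`). See [scipy.spatial.distance.cdist](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
+ - `scale`: multiplier applied to the distances (default `-1`, so utility decreases as distance from the voter's ideal point increases)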
229
+
230
+ **`grid.within_box(x0=None, x1=None, y0=None, y1=None)`**
231
+
232
+ Returns 1D boolean array for testing whether grid points are in the defined box.
233
+
234
+ **`grid.within_disk(x0, y0, r, metric='euclidean')`**
235
+
236
+ Returns 1D boolean array for testing whether grid points are in the defined disk.
237
+
238
+ **`grid.within_triangle(points)`**
239
+
240
+ Returns 1D boolean array for testing whether grid points are in the defined triangle.
241
+ - `points`: shape `(3,2)` array of triangle vertices
242
+
243
+ **`grid.embedding(valid)`**
244
+
245
+ Returns an embedding function `efunc(z, fill=0.0)` that maps 1D arrays of size `valid.sum()` to arrays of size `grid.len`.
246
+
247
+ - `valid`: boolean array of length `grid.len` selecting valid grid points
248
+ - `fill` (argument of the returned `efunc`): value placed at grid points where `valid` is False (default 0.0; use `np.nan` for plotting)
249
+
250
+ **`grid.plot(z, title=None, log=True, points=None, zoom=False, ...)`**
251
+
252
+ Creates a contour plot of values z defined on the grid.
253
+
254
+ ---
255
+
256
+ ### class VotingModel
257
+
258
+ #### Constructor
259
+
260
+ ```python
261
+ gridvoting_jax.VotingModel(
262
+ utility_functions,
263
+ number_of_voters,
264
+ number_of_feasible_alternatives,
265
+ majority,
266
+ zi
267
+ )
268
+ ```
269
+
270
+ **Parameters:**
271
+ - `utility_functions`: 2D array of shape `(number_of_voters, number_of_feasible_alternatives)`
272
+ - `number_of_voters`: integer
273
+ - `number_of_feasible_alternatives`: integer
274
+ - `majority`: integer, number of votes needed to win
275
+ - `zi`: boolean, True for Zero Intelligence, False for Minimal Intelligence
276
+
277
+ #### Methods
278
+
279
+ **`analyze()`**
280
+
281
+ Computes the transition matrix and stationary distribution.
282
+
283
+ **`what_beats(index)`**
284
+
285
+ Returns array indicating which alternatives beat the alternative at `index`.
286
+
287
+ **`what_is_beaten_by(index)`**
288
+
289
+ Returns array indicating which alternatives are beaten by the alternative at `index`.
290
+
291
+ **`summarize_in_context(grid, valid=None)`**
292
+
293
+ Calculate summary statistics for stationary distribution using grid coordinates.
294
+
295
+ **`plots(grid, voter_ideal_points, ...)`**
296
+
297
+ Creates visualization plots of the stationary distribution.
298
+
299
+ ---
300
+
301
+ ### class MarkovChainCPUGPU
302
+
303
+ #### Constructor
304
+
305
+ ```python
306
+ gridvoting_jax.MarkovChainCPUGPU(P, computeNow=True, tolerance=5e-5)
307
+ ```
308
+
309
+ **Parameters:**
310
+ - `P`: valid transition matrix (square JAX/numpy array whose rows sum to 1.0)
311
+ - `computeNow=True`: immediately compute Markov Chain properties
312
+ - `tolerance=5e-5`: tolerance for checking convergence (appropriate for float32)
313
+
314
+ #### Methods
315
+
316
+ **`solve_for_unit_eigenvector()`**
317
+
318
+ Finds the stationary distribution by solving for the unit eigenvector.
319
+
320
+ **`find_unique_stationary_distribution(tolerance=5e-5)`**
321
+
322
+ Finds the unique stationary distribution using the algebraic method.
323
+
324
+ **`diagnostic_metrics()`**
325
+
326
+ Returns dictionary of diagnostic metrics for the Markov chain.
327
+
328
+ ---
329
+
330
+ ## Testing
331
+
332
+ ### Run Tests
333
+
334
+ ```bash
335
+ # Install development dependencies
336
+ pip install -r requirements-dev.txt
337
+
338
+ # Run all tests
339
+ pytest tests/
340
+
341
+ # Run with coverage
342
+ pytest tests/ --cov=gridvoting_jax
343
+ ```
344
+
345
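+ ### Google Colab
+
+ Note: the published 0.0.1 wheel contains only the `gridvoting_jax` module (`__init__.py` and `py.typed`), not the test files. The command below collects tests only if the test files are present in that directory. Otherwise, run the test suite from a source checkout as shown above.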
346
+
347
+ ```python
348
+ !pip install gridvoting-jax
349
+ !pytest /usr/local/lib/python3.*/dist-packages/gridvoting_jax/
350
+ ```
351
+
352
+ ---
353
+
354
+ ## License
355
+
356
+ The software is provided under the standard [MIT License](./LICENSE.md).
357
+
358
+ You are welcome to try the software, read it, copy it, adapt it to your needs, and redistribute your adaptations. If you change the software, be sure to change the module name so that others know it is not the original. See the LICENSE file for more details.
359
+
360
+ ---
361
+
362
+ ## Disclaimers
363
+
364
+ The software is provided in the hope that it may be useful to others, but it is not a full-featured turnkey system for conducting arbitrary voting simulations. Additional coding is required to define a specific simulation.
365
+
366
+ Automated tests exist and run on GitHub Actions. However, this cannot guarantee that the software is free of bugs or defects or that it will run on your computer without adjustments.
367
+
368
+ The [MIT License](./LICENSE.md) includes this disclaimer:
369
+
370
+ > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
371
+
372
+ ---
373
+
374
+ ## Research Data
375
+
376
+ Code specific to the spatial voting and budget voting portions of our research publication -- as well as output data -- is deposited at: [OSF Dataset for A comparison of zero and minimal Intelligence agendas in majority rule voting models](https://osf.io/k2phe/) and is freely available.
377
+
378
+ ---
379
+
380
+ ## Contributing
381
+
382
+ Contributions are welcome! Please feel free to submit a Pull Request.
383
+
384
+ ---
385
+
386
+ ## Citation
387
+
388
+ If you use this software in your research, please cite the original paper:
389
+
390
+ ```bibtex
391
+ @article{brewer2023comparison,
392
+ title={A comparison of zero- and minimal-intelligence agendas in majority-rule voting models},
393
+ author={Brewer, Paul and Juybari, Jeremy and Moberly, Raymond},
394
+ journal={Journal of Economic Interaction and Coordination},
395
+ year={2023},
396
+ publisher={Springer},
397
+ doi={10.1007/s11403-023-00387-8}
398
+ }
399
+ ```
@@ -0,0 +1,7 @@
1
+ gridvoting_jax/__init__.py,sha256=6EYnu9pESYYRL3de9XOCgFlrokygu0sDJeBQnVAjybE,24276
2
+ gridvoting_jax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ gridvoting_jax-0.0.1.dist-info/licenses/LICENSE.md,sha256=4es0Mvw6i0cDKIhQT4blX0pBO-bWfjxOrQY96VUxrbk,1147
4
+ gridvoting_jax-0.0.1.dist-info/METADATA,sha256=9emqibEOiVx1DPJOE9LDTCXFEC_dqduPhvTgXZI9dsc,13049
5
+ gridvoting_jax-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ gridvoting_jax-0.0.1.dist-info/top_level.txt,sha256=VH3wQfI2eGIWIsdwMwhLVEsjHQywew_58PA5NDM-Vk0,15
7
+ gridvoting_jax-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,8 @@
1
+ Copyright 2021-2025 by Contributors:
2
+ Paul Brewer <drpaulbrewer@eaftc.com> Economic and Financial Technology Consulting LLC
3
+
4
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5
+
6
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7
+
8
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1 @@
1
+ gridvoting_jax