diffusion-cartogram 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,534 @@
1
+ """
2
+ Visualization tools for 2-D VDERM exports.
3
+
4
+ Mirrors visualization.py for the 2-D pipeline:
5
+
6
+ plot_map_2d — static scatter plot of 2-D map points
7
+ plot_density_field_2d — heatmap of the grid density field
8
+ animate_map_deformation_2d — GIF / MP4 from exported map CSVs
9
+ animate_grid_deformation_2d — GIF / MP4 from exported grid CSVs
10
+ plot_density_evolution_2d — density statistics over iterations
11
+
12
+ All animation functions read the CSV files written by
13
+ run_VDERM_2d_with_tracking().
14
+ """
15
+
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib.animation import FuncAnimation, PillowWriter
19
+ import glob
20
+ import os
21
+ from tqdm import tqdm
22
+
23
+ from .core_2d import read_csv_2d
24
+
25
+
26
+ # ─── Static plots ─────────────────────────────────────────────────────────────
27
+
28
def plot_map_2d(positions, densities=None, title='Map Points',
                cmap='plasma', figsize=(8, 7), point_size=2, alpha=0.7,
                save_file=None):
    """
    Plot a 2-D point set with optional density colour coding.

    Parameters
    ----------
    positions : array_like, shape (n, 2)
        [x, y] coordinates. Any array-like (e.g. a list of pairs) is
        accepted; it is converted with ``np.asarray``.
    densities : array_like, shape (n,), optional
        Per-point density values. If None, points are drawn in a
        uniform colour.
    title : str, default='Map Points'
    cmap : str, default='plasma'
    figsize : tuple, default=(8, 7)
    point_size : float, default=2
    alpha : float, default=0.7
    save_file : str, optional
        Path to save the figure. If None the figure is shown interactively.

    Returns
    -------
    fig : matplotlib Figure
        When ``save_file`` is given the figure has already been closed.

    Examples
    --------
    >>> pts, crs = vd.read_geojson('countries.geojson')
    >>> vd.plot_map_2d(pts, title='World Countries')

    >>> # After deformation
    >>> deformed = vd.interpolate_to_map_2d(pts, gp, grid.get_displacement_field())
    >>> dens = vd.interpolate_densities_2d(pts, grid)
    >>> vd.plot_map_2d(deformed, densities=dens, title='Cartogram')
    """
    # Accept lists / tuples as well as ndarrays; column slicing needs an array.
    positions = np.asarray(positions)

    fig, ax = plt.subplots(figsize=figsize)

    if densities is not None:
        sc = ax.scatter(positions[:, 0], positions[:, 1],
                        c=densities, cmap=cmap, s=point_size, alpha=alpha)
        # Bind the colorbar to this figure explicitly instead of relying on
        # pyplot's "current figure" state.
        fig.colorbar(sc, ax=ax, label='Density')
    else:
        ax.scatter(positions[:, 0], positions[:, 1],
                   c='dodgerblue', s=point_size, alpha=alpha)

    ax.set_aspect('equal')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    if save_file:
        fig.savefig(save_file, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"Plot saved to: {save_file}")
    else:
        plt.show()

    return fig
87
+
88
+
89
def plot_density_field_2d(grid, title='Density Field',
                          cmap='plasma', figsize=(8, 7), save_file=None):
    """
    Render the current grid density field as a 2-D heatmap.

    Each grid node is drawn as one pixel centred on the node's (x, y)
    position.

    Parameters
    ----------
    grid : VDERMGrid2D
        Object exposing ``min_bounds`` (x0, y0), node counts ``L`` (x) and
        ``M`` (y), spacing ``h`` and a density array ``rho`` indexed
        ``rho[x_index, y_index]``.
    title : str, default='Density Field'
    cmap : str, default='plasma'
    figsize : tuple, default=(8, 7)
    save_file : str, optional
        Path to save the figure. If None the figure is shown interactively.

    Returns
    -------
    fig : matplotlib Figure
        When ``save_file`` is given the figure has already been closed.

    Examples
    --------
    >>> grid.set_density(lambda x, y: 1 + np.exp(-((x-0)**2+(y-0)**2)))
    >>> vd.plot_density_field_2d(grid, title='Gaussian Density')
    """
    x0, y0 = grid.min_bounds
    h = grid.h
    # Node i sits at x0 + i*h, so the L (resp. M) nodes span [x0, x1] ([y0, y1]).
    x1 = x0 + (grid.L - 1) * h
    y1 = y0 + (grid.M - 1) * h

    # imshow's ``extent`` is the outer *edges* of the pixel block. Pad by half
    # a cell so each pixel is centred on its grid node; using the node span
    # directly would squeeze L pixels into (L-1)*h and misplace them.
    half = 0.5 * h
    extent = [x0 - half, x1 + half, y0 - half, y1 + half]

    fig, ax = plt.subplots(figsize=figsize)
    # rho shape is (L, M) = (x_index, y_index); transpose for imshow (row=y, col=x)
    im = ax.imshow(
        grid.rho.T,
        origin='lower',
        cmap=cmap,
        extent=extent,
        aspect='equal',
    )
    fig.colorbar(im, ax=ax, label='Density')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title, fontsize=14, fontweight='bold')

    if save_file:
        fig.savefig(save_file, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"Plot saved to: {save_file}")
    else:
        plt.show()

    return fig
137
+
138
+
139
def plot_map_before_after(original_points, deformed_points,
                          densities=None, title='Before / After Deformation',
                          cmap='plasma', figsize=(14, 6), point_size=2,
                          alpha=0.7, save_file=None):
    """
    Side-by-side comparison of original and deformed map points.

    Parameters
    ----------
    original_points : array_like, shape (n, 2)
        Any array-like is accepted; it is converted with ``np.asarray``.
    deformed_points : array_like, shape (n, 2)
        Any array-like is accepted; it is converted with ``np.asarray``.
    densities : array_like, shape (n,), optional
        Density values to colour the deformed panel.
    title : str
    cmap : str, default='plasma'
    figsize : tuple, default=(14, 6)
    point_size : float, default=2
    alpha : float, default=0.7
    save_file : str, optional
        Path to save the figure. If None the figure is shown interactively.

    Returns
    -------
    fig : matplotlib Figure
        When ``save_file`` is given the figure has already been closed.
    """
    # Accept lists / tuples as well as ndarrays; column slicing needs an array.
    original_points = np.asarray(original_points)
    deformed_points = np.asarray(deformed_points)

    fig, (ax_orig, ax_deform) = plt.subplots(1, 2, figsize=figsize)

    ax_orig.scatter(original_points[:, 0], original_points[:, 1],
                    c='steelblue', s=point_size, alpha=alpha)
    ax_orig.set_title('Original', fontsize=12)

    if densities is not None:
        sc = ax_deform.scatter(deformed_points[:, 0], deformed_points[:, 1],
                               c=densities, cmap=cmap, s=point_size,
                               alpha=alpha)
        fig.colorbar(sc, ax=ax_deform, label='Density')
    else:
        ax_deform.scatter(deformed_points[:, 0], deformed_points[:, 1],
                          c='coral', s=point_size, alpha=alpha)
    ax_deform.set_title('Deformed (Cartogram)', fontsize=12)

    # Styling shared by both panels.
    for ax in (ax_orig, ax_deform):
        ax.set_aspect('equal')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()

    if save_file:
        fig.savefig(save_file, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"Plot saved to: {save_file}")
    else:
        plt.show()

    return fig
198
+
199
+
200
+ # ─── Animations ──────────────────────────────────────────────────────────────
201
+
202
def animate_map_deformation_2d(export_folder='vderm_2d_exports',
                               subfolder='vderm_map',
                               output_file='map_animation.gif',
                               fps=5,
                               subsample=5000,
                               cmap='plasma',
                               figsize=(8, 7),
                               alpha=0.7):
    """
    Create an animated GIF / MP4 of 2-D map point deformation.

    Reads the CSV files written by run_VDERM_2d_with_tracking() with
    ``export_map=True``. The ``map_iteration_*.csv`` frames are played in
    numeric iteration order, followed by the first ``map_final_*.csv`` file
    if one exists.

    Parameters
    ----------
    export_folder : str, default='vderm_2d_exports'
    subfolder : str, default='vderm_map'
    output_file : str, default='map_animation.gif'
        Output path (.gif or .mp4).
    fps : int, default=5
        Frames per second; must be positive.
    subsample : int or None, default=5000
        Downsample to this many points per frame. The same randomly chosen
        point indices (drawn once from the first frame) are used in every
        frame.
    cmap : str, default='plasma'
    figsize : tuple, default=(8, 7)
    alpha : float, default=0.7

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``fps`` is not positive, or ``output_file`` is not .gif / .mp4.
    FileNotFoundError
        If no map CSV files are found.
    RuntimeError
        If writing an MP4 fails.

    Examples
    --------
    >>> vd.animate_map_deformation_2d('vderm_2d_exports',
    ...                               output_file='cartogram.gif')
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got: {fps}")

    def _iteration_of(path):
        # 'map_iteration_12.csv' -> 12
        return int(os.path.basename(path).split('_')[-1].replace('.csv', ''))

    pattern = os.path.join(export_folder, subfolder, 'map_iteration_*.csv')
    # Sort numerically: a plain string sort plays iteration 10 before 2 when
    # the exported file names are not zero-padded.
    files = sorted(glob.glob(pattern), key=_iteration_of)

    final_pattern = os.path.join(export_folder, subfolder, 'map_final_*.csv')
    final_files = glob.glob(final_pattern)
    if final_files:
        files.append(sorted(final_files)[0])

    if not files:
        raise FileNotFoundError(
            f"No map CSV files found matching {pattern}\n"
            "Run run_VDERM_2d_with_tracking with export_map=True."
        )

    print(f"Found {len(files)} frames")

    # Determine fixed subsample indices from first frame so the same points
    # are followed through every frame.
    pos0, _ = read_csv_2d(files[0])
    if subsample and len(pos0) > subsample:
        sub_idx = np.sort(np.random.choice(len(pos0), subsample, replace=False))
        print(f"Subsampling {len(pos0)} → {subsample} points")
    else:
        sub_idx = None

    all_positions, all_densities = [], []
    for f in tqdm(files, desc="Loading"):
        pos, dens = read_csv_2d(f)
        if sub_idx is not None:
            pos = pos[sub_idx]
            dens = dens[sub_idx] if dens is not None else None
        all_positions.append(pos)
        all_densities.append(dens)

    # Fixed axis limits across all frames so the view does not jump around.
    all_pos = np.vstack(all_positions)
    pos_min = all_pos.min(axis=0)
    pos_max = all_pos.max(axis=0)

    has_density = all_densities[0] is not None
    if has_density:
        all_dens_flat = np.hstack([d for d in all_densities if d is not None])
        # One shared colour scale so colours are comparable between frames.
        colour_kw = dict(cmap=cmap, vmin=all_dens_flat.min(),
                         vmax=all_dens_flat.max())
    else:
        # Plain colour: passing cmap/vmin/vmax here would only make
        # matplotlib warn that they are ignored.
        colour_kw = {}

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Initial scatter exists to give the colorbar a mappable.
        sc_init = ax.scatter(
            all_positions[0][:, 0], all_positions[0][:, 1],
            c=(all_densities[0] if has_density else 'dodgerblue'),
            s=1, alpha=alpha, **colour_kw
        )
        if has_density:
            fig.colorbar(sc_init, ax=ax, label='Density')

        def update(frame):
            ax.clear()
            pos = all_positions[frame]
            ax.scatter(pos[:, 0], pos[:, 1],
                       c=(all_densities[frame] if has_density
                          else 'dodgerblue'),
                       s=1, alpha=alpha, **colour_kw)
            ax.set_xlim(pos_min[0], pos_max[0])
            ax.set_ylim(pos_min[1], pos_max[1])
            ax.set_aspect('equal')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.grid(True, alpha=0.3)

            fname = os.path.basename(files[frame])
            if 'final' in fname:
                ax.set_title('Map: Final (Converged)', fontsize=13,
                             fontweight='bold')
            else:
                ax.set_title(f'Map: Iteration {_iteration_of(fname)}',
                             fontsize=13)
            return ax,

        print(f"Creating animation ({fps} fps)...")
        anim = FuncAnimation(fig, update, frames=len(files),
                             interval=1000 // fps, blit=False)

        _save_animation(anim, output_file, fps)
    finally:
        # Release the figure even when saving fails, so repeated failed calls
        # do not accumulate open figures.
        plt.close(fig)
    print(f"Animation saved to: {output_file}")
327
+
328
+
329
def animate_grid_deformation_2d(export_folder='vderm_2d_exports',
                                subfolder='vderm_grid',
                                output_file='grid_animation_2d.gif',
                                fps=5,
                                subsample=5000,
                                cmap='plasma',
                                figsize=(8, 7),
                                alpha=0.5):
    """
    Animate the 2-D grid node positions coloured by density.

    Reads CSV files written by run_VDERM_2d_with_tracking() with
    ``export_grid=True`` (5-column format: x y v_x v_y rho). The
    ``grid_iteration_*.csv`` frames are played in numeric iteration order,
    followed by the first ``grid_final_*.csv`` file if one exists.

    Parameters
    ----------
    export_folder : str, default='vderm_2d_exports'
    subfolder : str, default='vderm_grid'
    output_file : str, default='grid_animation_2d.gif'
        Output path (.gif or .mp4).
    fps : int, default=5
        Frames per second; must be positive.
    subsample : int or None, default=5000
        Downsample to this many nodes per frame. The same randomly chosen
        indices (drawn once from the first frame) are used in every frame.
    cmap : str, default='plasma'
    figsize : tuple, default=(8, 7)
    alpha : float, default=0.5

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``fps`` is not positive, or ``output_file`` is not .gif / .mp4.
    FileNotFoundError
        If no grid CSV files are found.
    RuntimeError
        If writing an MP4 fails.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got: {fps}")

    def _iteration_of(path):
        # 'grid_iteration_12.csv' -> 12
        return int(os.path.basename(path).split('_')[-1].replace('.csv', ''))

    pattern = os.path.join(export_folder, subfolder, 'grid_iteration_*.csv')
    # Sort numerically: a plain string sort plays iteration 10 before 2 when
    # the exported file names are not zero-padded.
    files = sorted(glob.glob(pattern), key=_iteration_of)
    final_files = glob.glob(
        os.path.join(export_folder, subfolder, 'grid_final_*.csv')
    )
    if final_files:
        files.append(sorted(final_files)[0])

    if not files:
        raise FileNotFoundError(
            f"No grid CSV files found at {pattern}\n"
            "Run run_VDERM_2d_with_tracking with export_grid=True."
        )

    print(f"Found {len(files)} frames")

    # Grid CSVs have 5 columns: x y v_x v_y rho
    def _load_grid_csv(path):
        data = np.loadtxt(path)
        if data.ndim == 1:
            # A single-row file loads as 1-D; restore the row axis.
            data = data.reshape(1, -1)
        pos = data[:, :2]
        rho = data[:, 4] if data.shape[1] >= 5 else None
        return pos, rho

    # Fixed subsample indices from the first frame, reused for every frame.
    pos0, _ = _load_grid_csv(files[0])
    if subsample and len(pos0) > subsample:
        sub_idx = np.sort(np.random.choice(len(pos0), subsample, replace=False))
    else:
        sub_idx = None

    all_positions, all_densities = [], []
    for f in tqdm(files, desc="Loading"):
        pos, rho = _load_grid_csv(f)
        if sub_idx is not None:
            pos = pos[sub_idx]
            rho = rho[sub_idx] if rho is not None else None
        all_positions.append(pos)
        all_densities.append(rho)

    # Fixed axis limits across all frames so the view does not jump around.
    all_pos = np.vstack(all_positions)
    pos_min = all_pos.min(axis=0)
    pos_max = all_pos.max(axis=0)

    has_density = all_densities[0] is not None
    if has_density:
        all_d = np.hstack([d for d in all_densities if d is not None])
        # One shared colour scale so colours are comparable between frames.
        colour_kw = dict(cmap=cmap, vmin=all_d.min(), vmax=all_d.max())
    else:
        # Plain colour: passing cmap/vmin/vmax here would only make
        # matplotlib warn that they are ignored.
        colour_kw = {}

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Initial scatter exists to give the colorbar a mappable.
        sc_init = ax.scatter(
            all_positions[0][:, 0], all_positions[0][:, 1],
            c=(all_densities[0] if has_density else 'steelblue'),
            s=1, alpha=alpha, **colour_kw
        )
        if has_density:
            fig.colorbar(sc_init, ax=ax, label='Density')

        def update(frame):
            ax.clear()
            pos = all_positions[frame]
            ax.scatter(pos[:, 0], pos[:, 1],
                       c=(all_densities[frame] if has_density
                          else 'steelblue'),
                       s=1, alpha=alpha, **colour_kw)
            ax.set_xlim(pos_min[0], pos_max[0])
            ax.set_ylim(pos_min[1], pos_max[1])
            ax.set_aspect('equal')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.grid(True, alpha=0.3)
            fname = os.path.basename(files[frame])
            if 'final' in fname:
                ax.set_title('Grid: Final (Converged)', fontsize=13,
                             fontweight='bold')
            else:
                ax.set_title(f'Grid: Iteration {_iteration_of(fname)}',
                             fontsize=13)
            return ax,

        print(f"Creating animation ({fps} fps)...")
        anim = FuncAnimation(fig, update, frames=len(files),
                             interval=1000 // fps, blit=False)
        _save_animation(anim, output_file, fps)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    print(f"Animation saved to: {output_file}")
450
+
451
+
452
def plot_density_evolution_2d(export_folder='vderm_2d_exports',
                              grid_folder='vderm_grid',
                              output_file='density_evolution_2d.png'):
    """
    Plot mean, min, max, and std-dev of the grid density over iterations.

    Reads CSV files written by run_VDERM_2d_with_tracking() with
    ``export_grid=True``. The density is taken from the 5th column; files
    with fewer than 5 columns are skipped. The figure is written to
    ``output_file`` and closed (it is not shown interactively).

    Parameters
    ----------
    export_folder : str
    grid_folder : str
    output_file : str
        Path of the image to write.

    Raises
    ------
    FileNotFoundError
        If no ``grid_iteration_*.csv`` files are found.

    Examples
    --------
    >>> vd.plot_density_evolution_2d('vderm_2d_exports')
    """
    def _iteration_of(path):
        # 'grid_iteration_12.csv' -> 12
        return int(os.path.basename(path).split('_')[-1].replace('.csv', ''))

    pattern = os.path.join(export_folder, grid_folder, 'grid_iteration_*.csv')
    # Sort numerically so the iteration axis is monotonic; a string sort puts
    # 10 before 2 for non-zero-padded names and makes the lines zig-zag.
    files = sorted(glob.glob(pattern), key=_iteration_of)
    if not files:
        raise FileNotFoundError(f"No grid CSV files at {pattern}")

    iterations, means, maxs, mins, stds = [], [], [], [], []

    for f in tqdm(files, desc="Analysing"):
        data = np.loadtxt(f)
        if data.ndim == 1:
            # A single-row file loads as 1-D; restore the row axis.
            data = data.reshape(1, -1)
        if data.shape[1] < 5:
            continue
        dens = data[:, 4]
        iterations.append(_iteration_of(f))
        means.append(dens.mean())
        maxs.append(dens.max())
        mins.append(dens.min())
        stds.append(dens.std())

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
    try:
        ax1.plot(iterations, means, 'b-', label='Mean', linewidth=2)
        ax1.plot(iterations, maxs, 'r--', label='Max', linewidth=1.5)
        ax1.plot(iterations, mins, 'g--', label='Min', linewidth=1.5)
        ax1.fill_between(iterations, mins, maxs, alpha=0.2)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Density')
        ax1.set_title('Density Statistics (2-D)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.plot(iterations, stds, 'purple', linewidth=2)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Std Dev')
        ax2.set_title('Density Variation Over Time')
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_file, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"Density evolution plot saved to: {output_file}")
515
+
516
+
517
+ # ─── Helpers ──────────────────────────────────────────────────────────────────
518
+
519
+ def _save_animation(anim, output_file, fps):
520
+ """Save a FuncAnimation to .gif or .mp4."""
521
+ if output_file.endswith('.gif'):
522
+ anim.save(output_file, writer=PillowWriter(fps=fps))
523
+ elif output_file.endswith('.mp4'):
524
+ try:
525
+ from matplotlib.animation import FFMpegWriter
526
+ anim.save(output_file, writer=FFMpegWriter(fps=fps, bitrate=1800))
527
+ except Exception as exc:
528
+ raise RuntimeError(
529
+ f"Failed to save MP4 — is ffmpeg installed?\n{exc}"
530
+ ) from exc
531
+ else:
532
+ raise ValueError(
533
+ f"output_file must end with .gif or .mp4, got: {output_file}"
534
+ )