PyGhostID 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
PyGhostID/core.py ADDED
@@ -0,0 +1,1330 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Core functions of PyGhostID
4
+
5
+ @author: Daniel Koch, 2026
6
+ """
7
+
8
+ # Import packages
9
+ import numpy as np
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from scipy.signal import find_peaks
13
+ from scipy.spatial import cKDTree
14
+ from scipy.optimize import (
15
+ minimize,
16
+ differential_evolution,
17
+ dual_annealing,
18
+ basinhopping,
19
+ )
20
+ from scipy.stats import qmc
21
+
22
+ from scipy.integrate import solve_ivp
23
+ import matplotlib.pyplot as plt
24
+ import networkx as nx
25
+ from networkx.drawing.nx_agraph import graphviz_layout
26
+ from ._utils import *
27
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
28
+ import os
29
+ import sys
30
+ from tqdm import tqdm
31
+
32
+ # Functions
33
+
34
def ghostID(model, params, dt, trajectory, epsilon_gid=0.05, **kwargs):
    """
    Identify ghost states along a simulated trajectory by detecting Q-minima and
    evaluating the eigenvalue spectrum of the Jacobian along trajectory segments
    in their vicinity.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics.
    params : list or array-like
        Parameters passed as arguments to model.
    dt : float
        Step-size of the numerical integration used to simulate trajectory.
    trajectory : array-like
        A trajectory simulated by the system to be analysed for ghost states,
        with one row per time step and one column per state variable.
    epsilon_gid : float, optional
        Radius of the epsilon-sphere around Q-minima. Determines which trajectory
        segments are used for eigenvalue evaluation: only points within this distance
        of a Q-minimum are included. Values in the range 0.01--0.1 are reasonable
        for many models (default is 0.05). The value should be large enough that
        segments contain sufficient points for reliable eigenvalue estimation, but
        small enough that eigenvalues remain representative of the local phase-space
        topology around the candidate ghost. A practical tuning strategy is to start
        with small values and increase until eigenvalue control plots look reasonably
        smooth.
    **kwargs
        Optional keyword arguments:

        delta_gid : float
            Phase-space distance below which two ghosts identified by GhostID are
            considered the same and assigned the same identifier. Default is 0.1.
        peak_kwargs : dict
            Additional keyword arguments passed to scipy.signal.find_peaks to
            improve detection of Q-minima from peaks in the pQ time-series.
            If no "width" entry is supplied, ``5 * dt`` is used so that peak
            widths (needed for the ghost duration) are always computed. The
            dictionary supplied by the caller is not modified.
        evLimit : float
            Default is 0 (disabled). If set to a value greater than 0, enables the
            indirect method of ghost identification: a trajectory segment is
            considered to pass near a ghost if (1) the absolute value of the median
            of eigenvalues along the segment is below evLimit, (2) the linear fit of
            eigenvalues has R² >= 0.99, and (3) the slope of the fit lies within the
            range given by slopeLimits.
        slopeLimits : array-like of length 2
            Lower and upper bounds (in that order) for the eigenvalue slope used
            in the indirect identification method. Ignored if evLimit = 0.
            Default is [0, inf].
        batchModel : callable
            Vectorized (batch) version of model able to handle batch inputs. Can be
            provided to improve performance when ghostID is called repeatedly with the
            same model, avoiding repeated calls to make_batch_model. If not provided,
            ghostID constructs a batch model internally via make_batch_model.

        Eigenvalue cleaning
        -------------------
        Because eigenvalue indexing along a trajectory segment is not guaranteed to
        be consistent across consecutive time steps (i.e. lambda_1 at time t may be
        lambda_2 at time t+1), two complementary correction methods are available:

        ev_outlier_removal : bool
            If True, removes eigenvalue outliers detected within a sliding window.
            An eigenvalue is considered an outlier if it falls above q3 + k*(q3-q1)
            or below q1 - k*(q3-q1), where q1 and q3 are the 25th and 75th
            percentiles and k is set by ev_outlier_removal_k. Default is False.
        ev_outlier_removal_k : float
            Controls the width of the non-outlier range (see ev_outlier_removal).
            Default is 1.5.
        ev_outlier_removal_ws : float
            Size of the sliding window used for outlier detection. Default is 7.
        eigval_NN_sorting : bool
            If True, sorts eigenvalues across time for each index i using a
            nearest-neighbour prediction. Useful when eigenvalue time-series appear
            scattered or discontinuous. Default is False.

        Control outputs
        ---------------
        display_warnings : bool
            Show or suppress warning messages from GhostID.
        ctrlOutputs : dict
            Controls diagnostic plots of the algorithm's two core quantities
            (pQ-values and eigenvalues). Recognised keys:

            ctrl_qplot (bool)   : plot pQ-values and detected Q-minima along
                                  the trajectory.
            qplot_xscale (str)  : x-axis scale for Q-plot, 'linear' (default)
                                  or 'log'.
            qplot_yscale (str)  : y-axis scale for Q-plot, 'linear' (default)
                                  or 'log'.
            ctrl_evplot (bool)  : plot eigenvalues along each trajectory segment
                                  around identified Q-minima, including the
                                  evaluation criteria listed in the plot heading.
            evplot_xscale (str) : x-axis scale for eigenvalue plot, 'linear'
                                  (default) or 'log'.
            evplot_yscale (str) : y-axis scale for eigenvalue plot, 'linear'
                                  (default) or 'log'.
        return_ctrl_figs : bool
            If True, returns control plot figures for manual customisation instead
            of displaying them inline. Default is False.

    Returns
    -------
    ghostSeq : list of dict
        List of identified ghost states in order of visit, each represented as a
        Python dictionary with the keys "id", "time", "duration", "position",
        "dimension", "q-value", "crossing_eigenvalues", "qualifying_slopes" and
        "eigenvalues_qmin". Empty if no Q-minimum is detected, or if fewer than
        five trajectory points lie within epsilon_gid of a Q-minimum.
    control_figures : list of (fig, axes) tuples, optional
        Control plot figures, returned only if return_ctrl_figs is True.
    """
    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    display_warnings = config['display_warnings']
    delta_gid = config['delta_gid']
    # Work on a copy so the caller's dictionary is never mutated.
    peak_kwargs = dict(config['peak_kwargs'])
    peak_kwargs.setdefault("width", 5 * dt)  # guarantees pk_props["widths"] exists
    batchModel = config['batchModel']
    eigval_NN_sorting = config['eigval_NN_sorting']
    ev_outlier_removal = config['ev_outlier_removal']
    ev_outlier_removal_ws = config['ev_outlier_removal_ws']
    ev_outlier_removal_k = config['ev_outlier_removal_k']
    evLimit = config['evLimit']
    slopeLimits = config['slopeLimits']

    # Plotting control settings
    return_ctrl_figs = config['return_ctrl_figs']
    ctrl_qplot = config['ctrl_qplot']
    qplot_xscale = config['qplot_xscale']
    qplot_yscale = config['qplot_yscale']
    ctrl_evplot = config['ctrl_evplot']
    evplot_xscale = config['evplot_xscale']
    evplot_yscale = config['evplot_yscale']

    ctrl_figures = []

    def _finish(fig, axes):
        # Either collect the figure for the caller or show it right away.
        if return_ctrl_figs:
            ctrl_figures.append((fig, axes))
            plt.close(fig)
        else:
            plt.show()

    def _result(seq):
        # Single place that decides the shape of the return value.
        return (seq, ctrl_figures) if return_ctrl_figs else seq

    # Accept any array-like trajectory (the code below needs .shape and slicing).
    trajectory = np.asarray(trajectory)

    # Evaluate the vector field along the whole trajectory in one batch call
    if batchModel is None:
        batchModel = make_batch_model(model, params)
    Xs = np.asarray(batchModel(trajectory))

    n = trajectory.shape[1]  # dimension of the system

    ############# STEP 1 - identify non-oscillatory saddle-node ghosts #############

    # Q-values along the trajectory. As implemented, this is 0.5 * ||F(x)||
    # (with a square root); pQ = -log(Q) turns Q-minima into peaks.
    Q_ts = 0.5 * np.sqrt(np.sum(Xs**2, axis=1))
    pQ = -np.log(Q_ts)

    idx_minima, pk_props = find_peaks(pQ, **peak_kwargs)  # positions of Q-minima

    if ctrl_qplot:
        t_axis = np.arange(len(Q_ts)) * dt
        fig, ax = plt.subplots()
        fig.set_size_inches(17/(2*2.54), 17/(3*2.54))
        ax.plot(t_axis, pQ, '-k', lw=0.8, label="pQ(t)")
        if len(idx_minima) > 0:
            ax.plot(idx_minima * dt, pQ[idx_minima], 'xr', label="Q-minima")
            # Prominences are only reported by find_peaks when requested
            if "prominences" in pk_props:
                for idx, prom in zip(idx_minima, pk_props["prominences"]):
                    x = idx * dt
                    peak_val = pQ[idx]
                    base_val = peak_val - prom
                    # vertical line showing prominence
                    ax.vlines(x, base_val, peak_val, color="gray", linestyle="--")
                    # add text next to line
                    ax.text(x, base_val - 0.05 * prom, f"{prom:.2f}",
                            ha="center", va="top", fontsize=7.5, color="gray")
        ax.set_ylabel("pQ(t)")
        ax.set_xlabel("t")
        ax.set_xscale(qplot_xscale)
        ax.set_yscale(qplot_yscale)
        ax.legend(fontsize=9)
        ax.set_title("Detected Q-minima with prominences", fontsize=12)
        plt.tight_layout()
        _finish(fig, ax)

    # Nothing to analyse: return early instead of touching state that is only
    # created when minima exist.
    if len(idx_minima) == 0:
        return _result([])

    # Precompile JAX Jacobian function
    J_fun = make_jacfun(model, params)

    # Build KD-tree once for trajectory
    kdtree = cKDTree(trajectory)

    ghostSeq = []          # sequence of visited ghosts
    ghostTimes = []        # visit times - used for sorting ghosts and oscillatory transients
    ghostCoordinates = []  # unique phase-space positions of all ghosts visited
    widths = pk_props["widths"]

    for k, i in enumerate(idx_minima):
        t_ghost = i * dt              # time at which a potential ghost was found
        dur_ghost = widths[k] * dt    # trapping time
        qmin_xyz = trajectory[i]      # position of Q-minimum in phase space

        # KD-tree neighbourhood query
        idcs_Ueps_qmin = np.sort(
            np.asarray(kdtree.query_ball_point(qmin_xyz, epsilon_gid), dtype=int)
        )

        if len(idcs_Ueps_qmin) < 5:
            print("ghostID error: insuffienct number of points in Ueps!")
            return _result([])

        idcs_segment = trjSegment(idcs_Ueps_qmin, i)

        # Check if the trajectory leaves the epsilon environment after the minimum
        dists = np.linalg.norm(trajectory[i:] - qmin_xyz, axis=1)
        leaves_eps_qmin_i = bool(np.any(dists > epsilon_gid))

        if not leaves_eps_qmin_i:
            # NOTE(review): the break is applied regardless of display_warnings;
            # the original indentation was not recoverable - confirm.
            if display_warnings:
                print("GhostID: Trajectory does not leave U_eps - stopping ghostID.")
            break

        # Batch Jacobian + eigenvalue evaluation for the segment
        pts_segment = jnp.asarray(trajectory[idcs_segment])
        J_batch = jax.vmap(J_fun)(pts_segment)
        eigVals = jax.vmap(jnp.linalg.eigvals)(J_batch)
        eigVals_real = np.real(np.asarray(eigVals))  # back to numpy for analysis

        if eigval_NN_sorting:
            eigVals_real = sort_NN(eigVals_real.T).T

        # Determine eigenvalue sign changes along the segment
        ev_signChanges = [
            sign_change(eigVals_real[:, ii], ev_outlier_removal, ev_outlier_removal_ws,
                        ev_outlier_removal_k, display_warnings=display_warnings)
            for ii in range(n)
        ]
        crossings = sum(ev_signChanges)

        # Determine eigenvalue slopes (indirect method, active only if evLimit > 0)
        qualifyingSlopes = []
        r2s = []
        for ii in range(n):
            # Only consider eigenvalues with small median real part along segment
            if np.abs(np.median(eigVals_real[:, ii])) < evLimit:
                slope, r2 = slope_and_r2(eigVals_real[:, ii], dt, ev_outlier_removal,
                                         ev_outlier_removal_ws, ev_outlier_removal_k)
                r2s.append(r2)
                if np.all((slope > slopeLimits[0]) & (slope < slopeLimits[1]) & (r2 >= 0.99)):
                    qualifyingSlopes.append(ii)

        ghostCheck = crossings > 0 or len(qualifyingSlopes) > 0

        if ctrl_evplot:
            n_eig = eigVals_real.shape[1]
            fig, axes = plt.subplots(n_eig, 1,
                                     figsize=(17/(2*2.54), 17/(4*2.54)*n_eig),
                                     sharex=True)
            if n_eig == 1:
                axes = [axes]  # ensure axes is iterable

            t_seg = np.arange(len(eigVals_real)) * dt
            for j, ax in enumerate(axes):
                ax.plot(t_seg, eigVals_real[:, j], '-ok', markersize=1.5, lw=0.75)
                ax.set_ylabel(f'λ{j+1}')
                # Apply the requested scales to every subplot, not only the last
                ax.set_xscale(evplot_xscale)
                ax.set_yscale(evplot_yscale)
            axes[-1].set_xlabel('Time along segment')

            qsl = len(qualifyingSlopes) if evLimit > 0 else 'N/A'
            Q_mami = np.max(Q_ts[idcs_segment]) / Q_ts[i]

            fig.suptitle(
                f"Eig.vals near Qmin at t={t_ghost:.2f}, ghost: {str(ghostCheck)[0]}, leaves Uɛ: {str(leaves_eps_qmin_i)[0]}, Qmami: {Q_mami:.1e} \n"
                f"sign changes λi: " + "".join(np.where(ev_signChanges, 'T ', 'F ')) + f", qualifying slopes: {qsl}, "
                f"R²: {[f'{ri:.3f}' for ri in r2s]}", fontsize=9)
            plt.tight_layout()
            _finish(fig, axes)

        if not ghostCheck:
            continue

        # Ghost found: characterise it and check whether it was visited before
        ghostTimes.append(t_ghost)
        gdim = max(crossings, len(qualifyingSlopes))

        distances = np.asarray([np.linalg.norm(g - qmin_xyz) for g in ghostCoordinates])
        known = np.flatnonzero(distances < delta_gid) if len(ghostCoordinates) > 0 else []
        if len(known) > 0:
            gidx = known[0] + 1            # re-use the ID of the first matching ghost
        else:
            ghostCoordinates.append(qmin_xyz)
            gidx = len(ghostCoordinates)   # new ghost gets the next free ID

        # Eigenvalues exactly at the Q-minimum. `i` indexes the full trajectory,
        # not the rows of the per-segment eigVals array, so evaluate the Jacobian
        # directly at the minimum instead of indexing eigVals with `i`.
        eig_qmin = np.asarray(jnp.linalg.eigvals(J_fun(jnp.asarray(qmin_xyz))))

        ghostSeq.append({
            "id": "G" + str(gidx),
            "time": t_ghost,
            "duration": dur_ghost,
            "position": qmin_xyz,
            "dimension": gdim,
            "q-value": Q_ts[i],
            "crossing_eigenvalues": np.where(np.array(ev_signChanges) == True)[0],
            "qualifying_slopes": qualifyingSlopes,
            "eigenvalues_qmin": eig_qmin,
        })

    ############# STEP 2 - identify oscillatory transients #############
    # Placeholder: no oscillatory transients are detected yet, so these stay empty.
    oscSeq = []
    oscTimes = []

    # Merge transient state lists in chronological order
    fullTransientSeq = []
    allTimes = np.asarray(ghostTimes + oscTimes)
    allTimes.sort()
    ghostTimes = np.asarray(ghostTimes)
    oscTimes = np.asarray(oscTimes)

    for t in allTimes:
        i_t = np.where(ghostTimes == t)[0]
        if len(i_t) > 0:
            fullTransientSeq.append(ghostSeq[i_t[0]])
        else:
            i_t = np.where(oscTimes == t)[0]
            fullTransientSeq.append(oscSeq[i_t[0]])

    return _result(fullTransientSeq)
427
+
428
def _ghostID_phaseSpace_worker(model, model_params, ic, t_span, t_eval, dt,
                               epsilon_gid, solver_kwargs, ghost_kwargs):
    # Simulate one trajectory from `ic` and analyse it with ghostID. Kept at
    # module level so it can be pickled and sent to worker processes.
    sol = solve_ivp(model, t_span, ic, t_eval=t_eval, args=(model_params,),
                    **solver_kwargs)
    result = ghostID(model, model_params, dt, sol.y.T, epsilon_gid, **ghost_kwargs)
    if ghost_kwargs.get("return_ctrl_figs"):
        result, _ctrl_figures = result  # control figures are ignored for now
    return result if result else None


def ghostID_phaseSpaceSample(model, model_params, t_start, t_end, dt, state_ranges,
                             n_samples=50, method='RK45', rtol=1.e-3, atol=1.e-6, n_workers=None, **kwargs):
    """
    Identify ghost states across a region of phase space by simulating multiple
    trajectories from initial conditions drawn by Latin-hypercube sampling and
    evaluating each trajectory with ghostID.

    Implements an adaptive parallel-processing routine: uses threads when run
    inside Jupyter or Spyder (avoiding pickling issues) and processes when run
    as a standalone script for full CPU utilisation. In process mode the work
    is pickled, so model (and batchModel, if given) must be picklable, e.g.
    functions defined at module level.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics.
    model_params : list or array-like
        Parameters passed as arguments to model.
    t_start : float
        Start time of each simulated trajectory.
    t_end : float
        End time of each simulated trajectory. Must be larger than t_start.
    dt : float
        Step-size for numerical integration. Must be positive.
    state_ranges : list of tuple
        List of n tuples (one per dimension of the system), each specifying the
        (lower, upper) boundaries of the phase space region to be sampled along
        that coordinate. For example, for a 2D system:
        [(x_min, x_max), (y_min, y_max)].
    n_samples : int, optional
        Number of trajectories to simulate and analyse for ghost states.
    method : str, optional
        Integration method passed to scipy.integrate.solve_ivp. Default is 'RK45'.
    rtol : float, optional
        Relative tolerance for the numerical integrator. Default is 1e-3.
    atol : float, optional
        Absolute tolerance for the numerical integrator. Default is 1e-6.
    n_workers : int or None, optional
        Number of workers used for parallel processing. If None, one less than
        the number of CPUs reported by os.cpu_count() is used (at least 1).
        Default is None.
    **kwargs
        All keyword arguments accepted by ghostID are also accepted here (see
        ghostID documentation). Note that control plots are unlikely to render
        correctly when parallel processing is enabled.

        Additional kwargs specific to ghostID_phaseSpaceSample:

        delta_unify : float
            Distance in phase space above which two ghosts identified across
            different trajectories in the phase space sample are considered
            distinct and assigned different identifiers. Analogous to delta_gid
            in ghostID, but applied globally across all trajectories rather than
            within a single trajectory. Default is 0.1.
            NOTE(review): the value is read from the parsed configuration under
            the key 'epsilon_unify'; confirm which keyword parse_kwargs accepts.
        seed : int or None
            Random seed for the Latin-hypercube sampler. Setting a specific value
            ensures exact reproducibility of the phase space sample. Default is
            None (non-reproducible random sample).

    Returns
    -------
    results : list
        List containing a ghostSeq (see ghostID documentation) for each
        simulated trajectory in which at least one ghost was found, with ghost
        identifiers unified across trajectories. Empty if no ghosts were found.

    Raises
    ------
    ValueError
        If dt is not positive or t_end is not larger than t_start.
    """
    if dt <= 0 or t_end <= t_start:
        raise ValueError("ghostID_phaseSpaceSample requires dt > 0 and t_end > t_start.")

    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    display_warnings = config['display_warnings']
    return_ctrl_figs = config['return_ctrl_figs']
    epsilon_gid = config['epsilon_gid']
    epsilon_unify = config['epsilon_unify']  # documented as delta_unify
    seed = config['seed']

    # Copy so the caller's dictionary is not mutated by the default below.
    peak_kwargs = dict(config['peak_kwargs'])
    peak_kwargs.setdefault("width", 5 * dt)  # default if not supplied

    # Keyword arguments forwarded unchanged to every ghostID call
    ghost_kwargs = dict(
        delta_gid=config['delta_gid'],
        peak_kwargs=peak_kwargs,
        batchModel=config['batchModel'],
        return_ctrl_figs=return_ctrl_figs,
        ctrlOutputs=kwargs.get("ctrlOutputs", {}),
        evLimit=config['evLimit'],
        slopeLimits=config['slopeLimits'],
        eigval_NN_sorting=config['eigval_NN_sorting'],
        # Must be forwarded, otherwise the user's setting is silently dropped
        ev_outlier_removal=config['ev_outlier_removal'],
        ev_outlier_removal_ws=config['ev_outlier_removal_ws'],
        ev_outlier_removal_k=config['ev_outlier_removal_k'],
        display_warnings=display_warnings,
    )
    solver_kwargs = dict(method=method, rtol=rtol, atol=atol)

    if n_workers is None:
        n_workers = max(1, (os.cpu_count() or 4) - 1)

    # Output grid spans [t_start, t_end]; solve_ivp rejects t_eval values outside
    # t_span. The small tolerance keeps e.g. 0.3 / 0.1 from truncating to 2.
    npts = int(np.floor((t_end - t_start) / dt + 1e-9))
    t_eval = np.linspace(t_start, t_end, npts + 1)

    ICs = phaseSpaceLHS(state_ranges, n_samples, seed)

    # ---- Choose backend automatically ----
    in_spyder_or_jupyter = (
        "SPYDER" in sys.modules or
        "spyder_kernels" in sys.modules or
        "ipykernel" in sys.modules
    )
    Executor = ThreadPoolExecutor if in_spyder_or_jupyter else ProcessPoolExecutor
    mode = "threads" if Executor is ThreadPoolExecutor else "processes"
    print(f"[ghostID_phaseSpaceSample] Running with {mode} ({n_workers} workers)")

    # ---- Parallel execution ----
    ghostSeqs = []
    with Executor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(_ghostID_phaseSpace_worker, model, model_params, ic,
                            (t_start, t_end), t_eval, dt, epsilon_gid,
                            solver_kwargs, ghost_kwargs)
            for ic in ICs
        ]
        for f in tqdm(as_completed(futures), total=len(futures),
                      desc="Processing ICs", unit="IC"):
            res = f.result()
            if res is not None:
                ghostSeqs.append(res)

    # ---- Unify results ----
    if ghostSeqs:
        return unify_IDs(ghostSeqs, epsilon_unify, False)

    if display_warnings:
        print("ghostID_phaseSpaceSample: No ghosts found in any trajectory.")
    return []
570
+
571
def make_batch_model(model, params):
    """
    Build a batched version of a single-point model with jax.vmap.

    Parameters
    ----------
    model : callable
        Function with signature (t, z, params) -> dz/dt for a single state z.
    params : list or array-like
        Model parameters; bound once and passed unchanged on every call.

    Returns
    -------
    callable
        Function taking only Zs and returning dZs/dt, where the leading axis
        of Zs indexes the points. The model is always evaluated at t = 0,
        so any explicit time dependence is ignored.
    """
    # vmap maps over the leading axis of the single positional argument
    return jax.vmap(lambda z: model(0, z, params))
586
+
587
def find_local_Qminimum(model, x0, model_params, delta, *, global_method="lhs", local_method="L-BFGS-B",
                        global_options=None, local_options=None, verbose=False):
    """
    Search for a Q-minimum in a region of half-width delta around a given point
    x0 in phase space, where Q(x) = 0.5 * ||F(x)||^2 and F is the vector field
    defined by model (evaluated at t = 0). The search combines a global strategy
    (either sampling or a global optimisation algorithm) with an optional
    subsequent local optimisation step.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics, called as
        model(t, z, model_params).
    x0 : array-like
        Centre of the search region in phase space.
    model_params : list or array-like
        Parameters passed as arguments to model.
    delta : float
        Half-width of the search region: along every coordinate i the search
        is bounded to [x0[i] - delta, x0[i] + delta] (an axis-aligned box).
    global_method : str, optional
        Strategy used for the global search stage. Options are:

        'lhs' (default)
            Draw a Latin-hypercube sample from the box around x0,
            evaluate Q at each sample point, and pass the k_seeds points with
            the lowest Q-values to the local optimiser as starting points.
            Controlled via global_options (see below).
        'differential_evolution'
            Use scipy.optimize.differential_evolution within the box.
        'dual_annealing'
            Use scipy.optimize.dual_annealing within the box.
        'basin_hopping'
            Use scipy.optimize.basinhopping started from x0. The box bounds
            are only enforced if local_method is one of 'L-BFGS-B', 'TNC'
            or 'SLSQP'.

    local_method : str or None, optional
        Local optimisation method applied after the global search stage.
        Accepts any method available in scipy.optimize.minimize. If None, the
        local step is skipped and the best global candidate is returned.
        Default is 'L-BFGS-B'.
    global_options : dict or None, optional
        Options controlling the global search stage. For the scipy.optimize
        global methods, any keyword argument accepted by the corresponding
        scipy.optimize function may be provided; entries given here take
        precedence over the defaults set by this function (e.g. 'disp',
        or individual entries of 'minimizer_kwargs' for basin hopping).
        For global_method='lhs', the following keys are recognised:

        n_samples : int or None
            Size of the Latin-hypercube sample. If None, chosen automatically
            as min(2000, max(200, 20 * dim)), where dim is the dimension of
            the model.
        k_seeds : int or None
            Number of lowest-Q sample points passed to the local optimiser as
            starting points. If None, chosen automatically as
            min(5, max(2, int(sqrt(dim)))).
        seed : int or None
            Random seed for the Latin-hypercube sampler. Default is None.

        Default for global_options is None (all sub-options use their defaults).
    local_options : dict or None, optional
        Solver options for the local optimisation step, passed as the
        ``options`` argument of scipy.optimize.minimize. Default is None.
    verbose : bool, optional
        If True, enables control outputs during the search. Default is False.

    Returns
    -------
    x_min : numpy.ndarray
        Position of the best Q-minimum found.
    Q_min : float
        Q-value at x_min.
    result : scipy.optimize.OptimizeResult or None
        Result object of the best local optimisation run, or None if
        local_method is None.

    Raises
    ------
    ValueError
        If global_method is not one of the supported options.
    """
    F = model
    p = model_params

    x0 = np.asarray(x0, dtype=float)
    dim = x0.size
    bounds = [(x0[i] - delta, x0[i] + delta) for i in range(dim)]
    # Only these scipy.optimize.minimize methods are given the box bounds here
    bounded_local = local_method in {"L-BFGS-B", "TNC", "SLSQP"}

    global_options = {} if global_options is None else dict(global_options)
    local_options = {} if local_options is None else dict(local_options)

    # --------------------------------------------------
    # Define Q and grad Q
    # --------------------------------------------------
    def _q_jax(z):
        return 0.5 * jnp.sum(F(0.0, z, p) ** 2)

    def Q_func(x):
        return float(_q_jax(jnp.asarray(x)))

    grad_Q = jax.grad(_q_jax)

    def grad_Q_np(x):
        # Hand SciPy a writable float64 copy rather than a view of the JAX
        # buffer (which may be read-only and lower precision).
        return np.array(grad_Q(jnp.asarray(x)), dtype=float)

    # --------------------------------------------------
    # Global search
    # --------------------------------------------------
    if verbose:
        print(f"[Q-min] Global search method: {global_method}")

    # ---- LHS ------------------------------------------------
    if global_method == "lhs":
        # Compute dimension-aware defaults if None
        n_samples = global_options.get("n_samples", None)
        if n_samples is None:
            n_samples = min(2000, max(200, 20 * dim))

        k_seeds = global_options.get("k_seeds", None)
        if k_seeds is None:
            k_seeds = min(5, max(2, int(np.sqrt(dim))))

        seed = global_options.get("seed", None)

        if verbose:
            print(
                f"[Q-min] LHS: n_samples={n_samples}, "
                f"k_seeds={k_seeds}, seed={seed}"
            )

        sampler = qmc.LatinHypercube(d=dim, seed=seed)
        points = qmc.scale(
            sampler.random(n=n_samples),
            [b[0] for b in bounds],
            [b[1] for b in bounds],
        )

        Q_vals = np.array([Q_func(x) for x in points])
        candidate_points = points[np.argsort(Q_vals)[:k_seeds]]

    # ---- Differential Evolution -----------------------------
    elif global_method == "differential_evolution":
        # Merge so a user-supplied 'disp' overrides instead of clashing
        de_options = {"disp": verbose, **global_options}
        res = differential_evolution(Q_func, bounds, **de_options)
        candidate_points = [res.x]

    # ---- Dual Annealing -------------------------------------
    elif global_method == "dual_annealing":
        res = dual_annealing(Q_func, bounds, **global_options)
        candidate_points = [res.x]

    # ---- Basin Hopping --------------------------------------
    elif global_method == "basin_hopping":
        user_minimizer_kwargs = global_options.pop("minimizer_kwargs", None) or {}
        minimizer_kwargs = {
            "method": local_method,
            "jac": grad_Q_np,
            "bounds": bounds if bounded_local else None,
            **user_minimizer_kwargs,
        }
        res = basinhopping(
            Q_func,
            x0,
            minimizer_kwargs=minimizer_kwargs,
            **global_options,
        )
        candidate_points = [res.x]

    else:
        raise ValueError(f"Unknown global_method '{global_method}'")

    # --------------------------------------------------
    # Local refinement
    # --------------------------------------------------
    if local_method is None:
        Q_vals = [Q_func(x) for x in candidate_points]
        best = int(np.argmin(Q_vals))
        return candidate_points[best], Q_vals[best], None

    # Only request solver output when verbose; omitting 'disp' is equivalent to
    # disp=False. NOTE(review): recent SciPy releases appear to deprecate 'disp'
    # for L-BFGS-B - confirm against the installed SciPy version.
    solver_options = dict(local_options)
    if verbose:
        solver_options.setdefault("disp", True)

    results_local = [
        minimize(
            Q_func,
            x_start,
            jac=grad_Q_np,
            method=local_method,
            bounds=bounds if bounded_local else None,
            options=solver_options,
        )
        for x_start in candidate_points
    ]

    best_res = min(results_local, key=lambda r: r.fun)

    if verbose:
        print(f"[Q-min] Final Q = {best_res.fun:.3e}")
        print(f"[Q-min] x = {best_res.x}")

    return best_res.x, best_res.fun, best_res
793
+
794
def qOnGrid(model, model_params, coords=None, n_points=50, ranges=None, overrides=None, indexing="ij", jit=False):
    """
    Evaluate Q(x) = 0.5 * ||F(x)||^2 on a phase space grid, where F is the
    vector field defined by the model.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics. It is called as
        ``model(0.0, x, model_params)``.
    model_params : list or array-like
        Parameters passed as arguments to model.
    coords : list of 1D array-like or None, optional
        Explicit phase space grid on which to evaluate Q-values, provided as a
        list of 1D arrays (one per dimension). Each entry is converted with
        ``jnp.asarray``, so plain Python lists are accepted. If None, the grid
        is constructed automatically using n_points, ranges, and overrides
        (see below). Default is None.
    n_points : int or list of int, optional
        Number of grid points along each dimension. A single integer (Python
        or NumPy) applies the same value to all dimensions; a list of integers
        sets each dimension individually (the last entry is reused if the list
        is shorter than the number of dimensions). Ignored if coords is
        provided. Default is 50.
    ranges : tuple or list of tuple, optional
        Bounds of the grid along each dimension, given as a
        single (lower, upper) tuple applied to all dimensions, or a list of
        such tuples to set each dimension individually (the last entry is
        reused if the list is shorter than the number of dimensions). Ignored
        if coords is provided. Default is (-2, 2).
    overrides : dict or None, optional
        Dictionary for overriding n_points and/or ranges for specific
        individual axes without affecting the others. Keys are axis indices
        (integers) and values are dictionaries with keys 'n' (int) and/or
        'range' (tuple). For example, to set axis 1 to 100 points over
        (-5.0, 5.0): {1: {'n': 100, 'range': (-5.0, 5.0)}}. Ignored if
        coords is provided. Default is None.
    indexing : str, optional
        Indexing convention passed to jax.numpy.meshgrid. Default is 'ij'
        (matrix-style indexing, where the first index varies along rows).
    jit : bool, optional
        If True, enables JAX just-in-time (JIT) compilation to speed up
        evaluation of Q on the grid. Default is False.

    Returns
    -------
    Q_grid : jax array
        Array of Q-values evaluated at each point of the phase space grid;
        its shape equals the grid shape (``grid_points.shape[:-1]``).
    grid_points : jax array
        The grid itself: the meshgrid arrays stacked along a trailing axis,
        so that ``grid_points[..., d]`` holds the coordinate of dimension d.
    """

    if coords is None:

        # The system dimension is read from the length of the model output
        # when probed with a length-1 zero state.
        # NOTE(review): this relies on the model tolerating a length-1 input;
        # confirm for models that unpack the state vector explicitly.
        test = model(0.0, jnp.zeros(1), model_params)
        dim = test.shape[0]

        if ranges is None:
            ranges = [(-2.0, 2.0)] * dim
        elif np.ndim(ranges[0]) == 0:
            # A single (lower, upper) pair: first entry is a scalar of any
            # numeric type (Python or NumPy), so broadcast it to all axes.
            ranges = [ranges] * dim

        if isinstance(n_points, (int, np.integer)):
            n_points = [n_points] * dim

        coords = []
        for d in range(dim):
            n = n_points[d] if d < len(n_points) else n_points[-1]
            r = ranges[d] if d < len(ranges) else ranges[-1]
            if overrides and d in overrides:
                axis_override = overrides[d]
                if "n" in axis_override:
                    n = axis_override["n"]
                if "range" in axis_override:
                    r = axis_override["range"]
            coords.append(jnp.linspace(r[0], r[1], int(n)))
    else:
        # Accept any 1D array-like (including plain lists) per axis.
        coords = [jnp.asarray(c) for c in coords]

    meshes = jnp.meshgrid(*coords, indexing=indexing)
    grid_points = jnp.stack(meshes, axis=-1)

    def core(grid_points):
        # Flatten the grid to a list of points, evaluate the vector field on
        # all of them at once, then restore the grid shape.
        flat_pts = grid_points.reshape(-1, grid_points.shape[-1])
        F_vmapped = jax.vmap(lambda pt: model(0.0, pt, model_params))
        values = F_vmapped(flat_pts)
        Q_flat = 0.5 * jnp.sum(values**2, axis=-1)
        return Q_flat.reshape(grid_points.shape[:-1])

    core = jax.jit(core) if jit else core
    return core(grid_points), grid_points
877
+
878
def track_ghost_branch(ghost, model, model_params, par_nr, par_steps, dpar, t_end, dt, delta=0.5, icStep=0.1, mode="first",
                       epsilon_gid=0.1, solve_ivp_method='RK45', rtol=1.e-3, atol=1.e-6, **kwargs):
    """
    Track a ghost state across a parameter sweep by iteratively updating the parameter
    and re-identifying the ghost at each step.

    At every step the function (1) searches for a local Q-minimum near the current
    ghost position, (2) places two initial conditions at +/- icStep from that minimum
    and keeps the one whose short test trajectory ends closer to the minimum than it
    started, (3) integrates a trajectory of length t_end from it, and (4) runs ghostID
    on that trajectory using the current (swept) parameter values.

    Parameters
    ----------
    ghost : dict
        Ghost to be tracked, identified either by 'ghostID' or 'ghostID_phaseSpaceSample'.
        Must provide the keys 'position' and 'dimension'.
    model : callable
        Python function describing the system dynamics.
    model_params : list or array-like
        Parameters for the model function. The input is not modified; a copy is swept.
    par_nr : int
        Index of the parameter in model_params to be varied during the sweep.
    par_steps : int
        Number of times the parameter is updated during the sweep.
    dpar : float
        Size of the parameter increment per iteration. Use a positive value to increase
        the parameter and a negative value to decrease it.
    t_end : float
        Length of the trajectory integrated at each iteration step.
    dt : float
        Step-size used for numerical integration.
    delta : float, optional
        Size of the region around the current ghost position (xg) in which to search
        for xQmin (the point of minimum speed). Default is 0.5.
    icStep : float, optional
        Distance from xQmin at which to initialize a trajectory that will be analyzed
        by GhostID. Default is 0.1.
    mode : {'first', 'closest'}, optional
        Strategy for selecting among multiple ghosts potentially identified along a
        trajectory. Any other value prints a notice and falls back to 'first'.
        - 'first'   : take the first ghost found (default).
        - 'closest' : take the ghost closest in phase space to xQmin.
    epsilon_gid : float, optional
        Threshold parameter for GhostID (see section 1.1 of the paper). Default is 0.1.
    solve_ivp_method : str, optional
        Integration method passed to scipy.integrate.solve_ivp (e.g. 'RK45').
        Default is 'RK45'.
    rtol : float, optional
        Relative tolerance for the numerical integrator. Default is 1e-3.
    atol : float, optional
        Absolute tolerance for the numerical integrator. Default is 1e-6.
    **kwargs
        Optional keyword arguments (parsed by parse_kwargs):
        - distQminThr (float): Maximum allowable distance between the identified ghost
          and xQmin. Any ghost candidate whose distance from xQmin exceeds this
          threshold is rejected. Default is infinity (no constraint).

    Returns
    -------
    ghostPositions : array-like
        Positions in phase space of the ghosts found at each parameter step.
    parSeq : array-like
        Sequence of parameter values at which each ghost was identified.
    ghostSeq_p : list
        Full list of ghost objects found at each parameter value, enabling extraction
        of additional ghost properties beyond phase-space position.
    control_plots : optional
        Control plots for GhostID, returned only if the corresponding option is enabled.

    If neither test trajectory approaches xQmin, a message is printed and a tuple of
    None values (of the same length as the regular return) is returned.
    """

    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    display_warnings = config['display_warnings']

    # Extract parameters from config
    delta_gid = config['delta_gid']
    peak_kwargs = config['peak_kwargs']
    if "width" not in peak_kwargs:
        peak_kwargs["width"] = 5 * dt  # default if not supplied
    batchModel = config['batchModel']
    eigval_NN_sorting = config['eigval_NN_sorting']
    ev_outlier_removal_ws = config['ev_outlier_removal_ws']
    ev_outlier_removal_k = config['ev_outlier_removal_k']
    evLimit = config['evLimit']
    slopeLimits = config['slopeLimits']

    # Plotting control settings
    return_ctrl_figs = config['return_ctrl_figs']
    ctrlOutputs = kwargs.get("ctrlOutputs", {})

    # Q-min search related kwargs
    distQminThr = config['distQminThr']
    qmin_glob_method = config['qmin_glob_method']
    qmin_loc_method = config['qmin_loc_method']
    qmin_glob_options = config['qmin_glob_options']
    qmin_loc_options = config['qmin_loc_options']

    # Validate the selection mode once instead of re-running a whole iteration.
    if mode not in ("first", "closest"):
        print("Unknown mode argument. Use default mode instead.")
        mode = "first"

    ghostSeq_p = []
    ctrl_figures_p = []
    parSeq = []

    parNext = model_params[par_nr]
    try:
        model_params_ = np.asarray(model_params).copy()
    except Exception:
        model_params_ = model_params.copy()
    else:
        # An integer array would silently truncate the swept parameter value
        # on assignment, so promote it to float.
        if np.issubdtype(model_params_.dtype, np.integer):
            model_params_ = model_params_.astype(float)

    ghost_ = ghost.copy()

    # Time grids are identical for every iteration.
    t_probe = np.arange(0, 5 * dt, dt)
    t_full = np.arange(0, t_end, dt)

    with tqdm(total=par_steps + 1) as pbar:
        for i in range(par_steps + 1):
            pct = 100 * i / par_steps if par_steps else 100.0
            pbar.set_description(f"Progress: {pct:6.2f}% | param value={parNext:.5f}")

            x0 = ghost_["position"]

            qmin = find_local_Qminimum(model, x0, model_params_, delta,
                                       global_method=qmin_glob_method,
                                       local_method=qmin_loc_method,
                                       global_options=qmin_glob_options,
                                       local_options=qmin_loc_options,
                                       verbose=False)[0]

            ic_plus, _, _ = icAtQmin(qmin, icStep, ghost_["dimension"], model, model_params_)
            ic_minus, _, _ = icAtQmin(qmin, -icStep, ghost_["dimension"], model, model_params_)

            # Short test integrations from both sides of the Q-minimum.
            sol_plus = solve_ivp(model, (0, 5 * dt), jnp.real(ic_plus), t_eval=t_probe, rtol=rtol, atol=atol, args=([model_params_]), method=solve_ivp_method)
            sol_minus = solve_ivp(model, (0, 5 * dt), jnp.real(ic_minus), t_eval=t_probe, rtol=rtol, atol=atol, args=([model_params_]), method=solve_ivp_method)

            dist_ic_plus = np.linalg.norm(qmin - ic_plus)
            dist_sol_plus = np.linalg.norm(qmin - sol_plus.y[:, -1])

            dist_ic_minus = np.linalg.norm(qmin - ic_minus)
            dist_sol_minus = np.linalg.norm(qmin - sol_minus.y[:, -1])

            # Keep the side whose trajectory moves towards the Q-minimum.
            if dist_sol_plus < dist_ic_plus:
                ic_pick = ic_plus
            elif dist_sol_minus < dist_ic_minus:
                ic_pick = ic_minus
            else:
                print("Terminating track_ghost_branch: Error in chosing initial conditions around qmin (both trajectories are diverging). Try different global/local method/options for finding qmin or different icStep size.")
                if return_ctrl_figs == False:
                    return None, None, None
                else:
                    return None, None, None, None

            ic_pick = jnp.real(ic_pick)

            sol = solve_ivp(model, (0, t_end), ic_pick, t_eval=t_full, rtol=rtol, atol=atol, args=([model_params_]), method=solve_ivp_method)

            # Analyse the trajectory with the same (swept) parameters that
            # were used to simulate it.
            gid_output = ghostID(model, model_params_, dt, sol.y.T,
                                 epsilon_gid, delta_gid=delta_gid, peak_kwargs=peak_kwargs,
                                 batchModel=batchModel, return_ctrl_figs=return_ctrl_figs, ctrlOutputs=ctrlOutputs,
                                 evLimit=evLimit, slopeLimits=slopeLimits, eigval_NN_sorting=eigval_NN_sorting,
                                 ev_outlier_removal_ws=ev_outlier_removal_ws, ev_outlier_removal_k=ev_outlier_removal_k,
                                 display_warnings=display_warnings)

            if return_ctrl_figs == False:
                ghostSeq = gid_output
            else:
                ghostSeq, ctrl_figures = gid_output
                ctrl_figures_p.append(ctrl_figures)

            if len(ghostSeq) == 0:
                print("No further ghosts found.")
                break

            # Select the candidate ghost according to the chosen mode.
            if mode == "first":
                candidate = ghostSeq[0]
                distance = np.linalg.norm(candidate["position"] - qmin)
            else:  # mode == "closest"
                positions = np.array([g["position"] for g in ghostSeq])
                distances = np.linalg.norm(positions - qmin, axis=1)
                idx_min = np.argmin(distances)
                candidate = ghostSeq[idx_min]
                distance = distances[idx_min]

            if distance < distQminThr:
                ghostSeq_p.append(candidate)
                parSeq.append(parNext)
                # Follow the branch: the next Q-minimum search is centred on
                # the ghost that was just accepted.
                ghost_ = candidate

            # update
            parNext = parNext + dpar
            model_params_[par_nr] = parNext
            pbar.update(1)

    ghostPositions = np.asarray([g["position"] for g in ghostSeq_p])

    if return_ctrl_figs == False:
        return ghostPositions, np.asarray(parSeq), ghostSeq_p
    else:
        return ghostPositions, np.asarray(parSeq), ghostSeq_p, ctrl_figures_p
1075
+
1076
def ghost_connections(gSeqs):
    """
    Takes list of ghost sequences and turns it into an adjacency matrix.

    Input:
        - gSeqs: list of ghost sequences that have been generated by ghostID;
          each sequence is an iterable of dicts with an 'id' entry.
    Output:
        - adjM: adjacency matrix representing connections between identified
          ghosts in phase space. adjM[i, j] == 1 if, in at least one sequence,
          the ghost labels[i] is immediately followed by the ghost labels[j].
        - labels: labels of matrix rows/columns (ghost ids starting with "G",
          in order of first appearance)

    Entries whose id does not start with "G" get no row/column; transitions
    into or out of such entries are skipped.
    """

    # Map each ghost id to its row/column index (order of first appearance).
    index_of = {}
    for seq in gSeqs:
        for obj in seq:
            gid = obj["id"]
            if gid[:1] == "G" and gid not in index_of:
                index_of[gid] = len(index_of)

    labels = list(index_of)
    ng = len(labels)
    adjM = np.zeros((ng, ng))

    for seq in gSeqs:
        ids = [obj["id"] for obj in seq]
        # Walk over consecutive pairs of the sequence.
        for src, dst in zip(ids, ids[1:]):
            if src in index_of and dst in index_of:
                adjM[index_of[src], index_of[dst]] = 1

    return adjM, labels
1107
+
1108
def unique_ghosts(gSeq):
    """
    Takes list of unified ghost sequences and returns a list of all unique ghosts.

    Input:
        - gSeq: list of ghost sequences; each sequence is an iterable of dicts
          with an 'id' entry.
    Output:
        - ghostsUnique: list containing, for every id starting with "G", the
          first ghost object encountered with that id (in order of first
          appearance). Entries with other ids are ignored.
    """

    seen_ids = set()  # set membership keeps the scan linear in the number of ghosts
    ghostsUnique = []

    for seq in gSeq:
        for obj in seq:
            gid = obj["id"]
            if gid[:1] == "G" and gid not in seen_ids:
                seen_ids.add(gid)
                ghostsUnique.append(obj)

    return ghostsUnique
1123
+
1124
def unify_IDs(seqs, delta_unify=0.1, update=True):
    """
    Unify ids across multiple sequences of transient ghost state objects found in timecourse simulations.

    Ghosts of the first sequence serve as references. Every ghost of a later
    sequence is given the id of the first known ghost lying closer than
    delta_unify; otherwise it is registered under a new id.

    Parameters
    ----------
    seqs : list[list[dict]]
        Each inner list corresponds to one simulation run.
        Each dict represents a ghost state and must contain:
            - 'position'  : np.ndarray, phase-space position of the ghost
            - 'id'        : str of the form 'G{i}'
            - 'q-value'   : float, scalar quality / stability measure
            - 'dimension' : int, dimension associated with the ghost

    delta_unify : float, optional
        Distance threshold below which two ghost states are considered
        identical (i.e. the same ghost across runs). Default is 0.1.

    update : bool, optional (default=True)
        If True, perform a second pass after ID unification that
        synchronizes properties (position, q-value, dimension)
        across all ghosts sharing the same ID: position and q-value are
        taken from the ghost with the smallest q-value, dimension is the
        largest dimension within the group.

    Returns
    -------
    Seqs : list[list[dict]]
        A new outer list with the same inner lists and ghost dicts, with
        unified IDs and (optionally) updated ghost properties. The ghost
        dicts are modified in place, so the input sequences reflect the
        changes as well. An empty input yields an empty list.

    Raises
    ------
    ValueError
        If a ghost id does not start with 'G'.
    KeyError
        If update is True and a ghost lacks the 'q-value' entry.
    """

    # ------------------------------------------------------------------
    # STEP 1: Initialize reference ghosts from the first sequence
    # ------------------------------------------------------------------

    Seqs = list(seqs)
    if not Seqs:
        # Nothing to unify.
        return Seqs

    first = Seqs[0]

    # Dictionary mapping ghost ID -> representative position
    known_G = {}

    # Track the highest numerical ghost index encountered so far
    max_g = 0

    for obj in first:
        pid = obj['id']
        if not pid.startswith('G'):
            raise ValueError(f"Unrecognized id '{pid}'")
        idx = int(pid[1:])  # extract numerical part of ID
        known_G[pid] = obj['position'].copy()
        max_g = max(max_g, idx)

    # ------------------------------------------------------------------
    # STEP 2: Unify IDs across all subsequent sequences
    # ------------------------------------------------------------------

    for seq in Seqs[1:]:
        for obj in seq:
            pos = obj['position']
            orig = obj['id']

            if not orig.startswith('G'):
                raise ValueError(f"Unrecognized id '{orig}'")

            # Try to match this ghost against previously known ghosts
            # (first known ghost within delta_unify wins).
            matched = False
            for pid, refpos in known_G.items():
                # Compare Euclidean distance in phase space
                if np.linalg.norm(pos - refpos) < delta_unify:
                    obj['id'] = pid  # reuse existing ID
                    matched = True
                    break

            # If no match was found, register a new ghost ID
            if not matched:
                max_g += 1
                new_id = f'G{max_g}'
                obj['id'] = new_id
                known_G[new_id] = pos.copy()

    # ------------------------------------------------------------------
    # STEP 3 (optional): Update ghost properties across identical IDs
    # ------------------------------------------------------------------

    if update:
        # Collect all ghosts grouped by ID
        ghosts_by_id = {}
        for seq in Seqs:
            for obj in seq:
                ghosts_by_id.setdefault(obj['id'], []).append(obj)

        # For each ghost ID, synchronize properties
        for gid, ghosts in ghosts_by_id.items():
            # ----------------------------------------------------------
            # (a) Find ghost with minimal q-value
            # ----------------------------------------------------------
            missing = [o for o in ghosts if 'q-value' not in o]
            if missing:
                print(f"Ghosts missing q-value for id {gid}:")
                for m in missing:
                    print(m.keys())
                raise KeyError(f"'q-value' missing for ghost(s) with id {gid}")
            ref = min(ghosts, key=lambda o: o['q-value'])
            ref_pos = ref['position'].copy()
            ref_q = ref['q-value']

            # ----------------------------------------------------------
            # (b) Find maximal dimension across ghosts with this ID
            # ----------------------------------------------------------
            max_dim = max(o['dimension'] for o in ghosts)

            # ----------------------------------------------------------
            # (c) Update all ghosts with synchronized values
            # ----------------------------------------------------------
            for o in ghosts:
                o['position'] = ref_pos.copy()
                o['q-value'] = ref_q
                o['dimension'] = max_dim

    return Seqs
1244
+
1245
def draw_network(adj_matrix, nodeCols, nlbls, layout="fdp", graphviz_args=None, layout_kwargs=None, rankdir="TB", node_size=1800, label_font_size=16.5, font="Arial"):
    """
    Draw a directed network from an adjacency matrix onto the current
    matplotlib axes.

    Parameters
    ----------
    adj_matrix : np.ndarray
        Square adjacency matrix. An entry adj_matrix[i, j] == 1 is drawn as a
        black edge from node i to node j, an entry == -1 as a red edge.
    nodeCols : color or sequence of colors
        Passed as node_color to nx.draw_networkx_nodes.
    nlbls : sequence
        Node labels; nlbls[i] labels node i.
    layout : str or callable, optional
        Layout used to position the nodes (see "layout options" below).
        Default is 'fdp'.
    graphviz_args : str or None, optional
        Argument string passed to Graphviz when layout is a string. If None,
        a default argument string (including rankdir) is used.
    layout_kwargs : dict or None, optional
        Keyword arguments forwarded to layout when it is a callable.
    rankdir : str, optional
        Graphviz rank direction, used only in the default graphviz_args.
        Default is 'TB'.
    node_size : float, optional
        Node size passed to nx.draw_networkx_nodes. Default is 1800.
    label_font_size : float, optional
        Font size of the node labels. Default is 16.5.
    font : str, optional
        Font family of the node labels. Default is 'Arial'.

    Raises
    ------
    ValueError
        If layout is neither a string nor a callable, or if it is a callable
        whose name starts with 'draw_'.

    layout options
    --------------
    Graphviz:
        'fdp', 'dot', 'neato', 'sfdp', 'circo'
    Semantic aliases:
        'hierarchical' -> Graphviz 'dot'
    NetworkX:
        any nx.*_layout function (NOT nx.draw_*)
    """

    if layout_kwargs is None:
        layout_kwargs = {}

    nw_dim = adj_matrix.shape[0]
    # NOTE(review): the graph is seeded from the transposed matrix and the
    # loop below then adds the untransposed edges as well, so the layout graph
    # may contain both directions of an edge. Only actEdges/inhEdges are
    # drawn. Confirm whether the transpose is intended.
    G = nx.from_numpy_array(adj_matrix.transpose(), create_using=nx.DiGraph)

    # --- Define edges ------------------------------------------------------
    inhEdges, actEdges = [], []
    for i in range(nw_dim):
        for ii in range(nw_dim):
            if adj_matrix[i, ii] == 1:
                G.add_edge(i, ii, weight=1)
                actEdges.append((i, ii))
            elif adj_matrix[i, ii] == -1:
                G.add_edge(i, ii, weight=-1)
                inhEdges.append((i, ii))

    # --- Layout handling ---------------------------------------------------
    if graphviz_args is None:
        graphviz_args = (
            f"-Grankdir={rankdir} "
            "-Nwidth=350 -Nheight=350 -Nfixedsize=true "
            "-Goverlap=scale -Gnodesep=5000 -Granksep=200 "
            "-Nshape=oval -Nfontsize=14 -Econstraint=true"
        )

    if layout == "hierarchical":
        layout = "dot"

    if isinstance(layout, str):
        pos = graphviz_layout(G, prog=layout, args=graphviz_args)

    elif callable(layout):
        # Callables such as functools.partial objects carry no __name__.
        layout_name = getattr(layout, "__name__", "")
        if layout_name.startswith("draw_"):
            raise ValueError(
                f"{layout_name} is a drawing function. "
                "Use a layout function like nx.spectral_layout instead."
            )
        pos = layout(G, **layout_kwargs)

    else:
        raise ValueError("layout must be a string or a callable")

    # --- Draw nodes --------------------------------------------------------
    node_opts = {
        "node_size": node_size,
        "edgecolors": "white",
    }
    nx.draw_networkx_nodes(
        G, pos, node_color=nodeCols, **node_opts, alpha=0.90
    )

    # --- Draw edges --------------------------------------------------------
    draw_custom_edges(
        G, pos, actEdges,
        color="k", head_length=8, head_width=4,
        width=1.1, trim_fraction=0.2
    )
    draw_custom_edges(
        G, pos, inhEdges,
        color="red", head_length=8, head_width=4,
        width=1.1, trim_fraction=0.2
    )

    # --- Draw labels -------------------------------------------------------
    labels = {i: nlbls[i] for i in range(nw_dim)}
    nx.draw_networkx_labels(
        G,
        pos,
        labels,
        font_size=label_font_size,
        font_weight="bold",
        font_family=font,
    )