nisplus 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nisplus/__init__.py +24 -0
- nisplus/data/__init__.py +41 -0
- nisplus/data/kuramoto.py +459 -0
- nisplus/data/sir.py +263 -0
- nisplus/ei.py +355 -0
- nisplus/losses.py +139 -0
- nisplus/models.py +300 -0
- nisplus-0.1.0.dist-info/METADATA +222 -0
- nisplus-0.1.0.dist-info/RECORD +11 -0
- nisplus-0.1.0.dist-info/WHEEL +5 -0
- nisplus-0.1.0.dist-info/top_level.txt +1 -0
nisplus/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Core public API for NIS+."""
|
|
2
|
+
|
|
3
|
+
from .ei import (
|
|
4
|
+
compute_causal_emergence,
|
|
5
|
+
compute_weights,
|
|
6
|
+
estimate_dei,
|
|
7
|
+
estimate_dei_micro,
|
|
8
|
+
recompute_weights,
|
|
9
|
+
)
|
|
10
|
+
from .losses import multi_step_prediction_loss, stage1_loss, stage2_loss
|
|
11
|
+
from .models import NISPlusMicroDynamics, NISPlusNN
|
|
12
|
+
|
|
13
|
+
# Names exported by ``from nisplus import *``: the two model classes (from
# .models), the three loss functions (from .losses) and the five functions
# imported from .ei above.
__all__ = [
    "NISPlusNN",
    "NISPlusMicroDynamics",
    "stage1_loss",
    "stage2_loss",
    "multi_step_prediction_loss",
    "compute_weights",
    "recompute_weights",
    "estimate_dei",
    "estimate_dei_micro",
    "compute_causal_emergence",
]
|
nisplus/data/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Synthetic data utilities shipped with NIS+."""
|
|
2
|
+
|
|
3
|
+
from .kuramoto import (
|
|
4
|
+
KuramotoDataset,
|
|
5
|
+
angles_to_micro,
|
|
6
|
+
build_two_community_coupling,
|
|
7
|
+
community_order_parameters,
|
|
8
|
+
generate_kuramoto_dataset,
|
|
9
|
+
generate_kuramoto_dataset_multistep,
|
|
10
|
+
get_kuramoto_dataloaders,
|
|
11
|
+
global_order_parameter,
|
|
12
|
+
micro_to_community_macro,
|
|
13
|
+
simulate_kuramoto_trajectory,
|
|
14
|
+
)
|
|
15
|
+
from .sir import (
|
|
16
|
+
SIRDataset,
|
|
17
|
+
generate_sir_dataset,
|
|
18
|
+
get_sir_dataloaders,
|
|
19
|
+
macro_to_micro,
|
|
20
|
+
simulate_sir_trajectory,
|
|
21
|
+
sir_derivatives,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Names exported by ``from nisplus.data import *``: the SIR utilities imported
# from .sir, followed by the Kuramoto utilities imported from .kuramoto.
__all__ = [
    "SIRDataset",
    "sir_derivatives",
    "simulate_sir_trajectory",
    "macro_to_micro",
    "generate_sir_dataset",
    "get_sir_dataloaders",
    "KuramotoDataset",
    "angles_to_micro",
    "build_two_community_coupling",
    "community_order_parameters",
    "global_order_parameter",
    "micro_to_community_macro",
    "simulate_kuramoto_trajectory",
    "generate_kuramoto_dataset",
    "generate_kuramoto_dataset_multistep",
    "get_kuramoto_dataloaders",
]
|
nisplus/data/kuramoto.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Kuramoto data generation for NIS+ experiments.
|
|
3
|
+
|
|
4
|
+
This module builds a small two-community Kuramoto system:
|
|
5
|
+
- 10 oscillators split into two interacting communities
|
|
6
|
+
- strong intra-community coupling, weaker inter-community coupling
|
|
7
|
+
- micro-state is the concatenated cosine/sine embedding of all phases
|
|
8
|
+
- macro-state is the pair of community order parameters
|
|
9
|
+
|
|
10
|
+
The construction makes the 2-community scale the intended macro level.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import math
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def wrap_angles(theta):
    """Map angles (scalar or array) onto the half-open interval [-pi, pi)."""
    two_pi = 2 * math.pi
    shifted = theta + math.pi
    return shifted % two_pi - math.pi
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def make_group_centers(n_groups, freq_bound=0.45):
    """Return ``n_groups`` natural-frequency centers as a tuple of floats.

    Centers are evenly spaced over ``[-freq_bound, freq_bound]``; a single
    community is centered at 0.0.
    """
    if n_groups != 1:
        centers = np.linspace(-freq_bound, freq_bound, n_groups)
        return tuple(centers.tolist())
    return (0.0,)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def build_two_community_coupling(group_sizes=(5, 5), intra_coupling=2.2, inter_coupling=0.35):
    """Build a block-structured Kuramoto coupling matrix.

    Every entry starts at ``inter_coupling / N`` (N = total oscillators). Entries
    between two oscillators of the same community are then set to
    ``intra_coupling / N``, and the diagonal (self-coupling) is zeroed.

    Returns:
        (coupling, groups): the (N, N) float matrix and a list of index arrays,
        one per community, covering consecutive index ranges in
        ``group_sizes`` order.
    """
    total = sum(group_sizes)
    coupling = np.full((total, total), inter_coupling / total, dtype=float)
    intra_value = intra_coupling / total

    offsets = np.cumsum((0,) + tuple(group_sizes))
    groups = [np.arange(lo, hi) for lo, hi in zip(offsets[:-1], offsets[1:])]
    for members in groups:
        coupling[np.ix_(members, members)] = intra_value

    np.fill_diagonal(coupling, 0.0)
    return coupling, groups
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def sample_natural_frequencies(groups, base_frequencies=(-0.45, 0.45), jitter=0.05, rng=None):
    """Draw per-oscillator natural frequencies around a per-community mean.

    Oscillators in ``groups[k]`` get ``base_frequencies[k] + N(0, jitter)``.
    When ``base_frequencies`` is None or does not supply exactly one value per
    community, evenly spaced centers from ``make_group_centers`` are used.
    """
    rng = np.random.default_rng() if rng is None else rng

    mismatched = base_frequencies is None or len(base_frequencies) != len(groups)
    centers = make_group_centers(len(groups)) if mismatched else base_frequencies

    n_total = sum(len(members) for members in groups)
    omegas = np.zeros(n_total, dtype=float)
    for members, center in zip(groups, centers):
        noise = rng.normal(0.0, jitter, size=len(members))
        omegas[members] = center + noise
    return omegas
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def kuramoto_step(theta, omegas, coupling, dt=0.05, process_noise=0.01, rng=None, groups=None):
    """Advance one Euler-Maruyama step of the Kuramoto dynamics.

    theta_i += dt * (omega_i + sum_j K_ij sin(theta_j - theta_i))
               + sqrt(dt) * process_noise * N(0, 1)

    Args:
        theta: current phases, one per oscillator.
        omegas: natural frequencies, one per oscillator.
        coupling: square coupling matrix K (row i = couplings felt by oscillator i).
        dt: integration step.
        process_noise: scale of the Gaussian increment.
        rng: numpy Generator; a fresh default one is created when None.
        groups: optional list of index arrays partitioning the oscillators.
            When given, the interaction is evaluated through per-community mean
            fields. That is equivalent to the dense sum only if ``coupling`` is
            constant on the off-diagonal entries of every community block, as
            produced by ``build_two_community_coupling``.

    Returns:
        The next phases, wrapped to [-pi, pi).
    """
    if rng is None:
        rng = np.random.default_rng()

    if groups is None:
        # phase_diff[i, j] = theta_j - theta_i
        phase_diff = theta[None, :] - theta[:, None]
        interaction = np.sum(coupling * np.sin(phase_diff), axis=1)
    else:
        interaction = np.zeros_like(theta, dtype=float)
        # Complex order parameter z_h = mean_j exp(i * theta_j) of each community.
        group_order = [np.exp(1j * theta[group]).mean() for group in groups]
        for g_idx, group in enumerate(groups):
            theta_g = theta[group]
            total = np.zeros(len(group), dtype=float)
            for h_idx, other_group in enumerate(groups):
                if h_idx == g_idx:
                    # The diagonal K_ii is zeroed, so the intra-community strength
                    # must be read from an off-diagonal entry of the block. A
                    # singleton community has no partner and contributes nothing.
                    if len(group) < 2:
                        continue
                    coeff = coupling[group[0], group[1]]
                else:
                    coeff = coupling[group[0], other_group[0]]
                # sum_{j in h} sin(theta_j - theta_i) = |h| * Im(z_h * exp(-i theta_i));
                # the j == i term is sin(0) = 0, so including it in z_h is harmless.
                total += coeff * len(other_group) * np.imag(group_order[h_idx] * np.exp(-1j * theta_g))
            interaction[group] = total

    noise = math.sqrt(dt) * process_noise * rng.normal(size=theta.shape)
    theta_next = theta + dt * (omegas + interaction) + noise
    return wrap_angles(theta_next)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def simulate_kuramoto_trajectory(theta0, omegas, coupling, dt=0.05, n_steps=4,
                                 burn_in=30, sample_stride=4, process_noise=0.01, rng=None,
                                 groups=None):
    """
    Integrate one Kuramoto trajectory and return phases sampled at a fixed stride.

    The first ``burn_in`` Euler steps are discarded. The state reached after
    burn-in is the first sample, and each further sample lies ``sample_stride``
    Euler steps after the previous one.

    The returned array has shape (n_steps + 1, n_oscillators).
    """
    if rng is None:
        rng = np.random.default_rng()

    def advance(state):
        # One Euler-Maruyama step with the fixed system parameters and shared rng.
        return kuramoto_step(
            state, omegas, coupling, dt=dt, process_noise=process_noise, rng=rng, groups=groups
        )

    state = np.asarray(theta0, dtype=float).copy()
    for _ in range(burn_in):
        state = advance(state)

    samples = [state.copy()]
    for _ in range(n_steps):
        for _ in range(sample_stride):
            state = advance(state)
        samples.append(state.copy())

    return np.stack(samples, axis=0)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def angles_to_micro(theta, observation_noise=0.0, rng=None):
    """Embed phases as ``[cos(theta), sin(theta)]`` concatenated on the last axis.

    With ``observation_noise > 0``, Gaussian noise of that standard deviation
    is added element-wise and the result is clipped to [-1.2, 1.2].
    """
    rng = rng if rng is not None else np.random.default_rng()

    phases = np.asarray(theta)
    embedding = np.concatenate((np.cos(phases), np.sin(phases)), axis=-1)
    if observation_noise > 0:
        noisy = embedding + rng.normal(0.0, observation_noise, size=embedding.shape)
        return np.clip(noisy, -1.2, 1.2)
    return embedding
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def community_order_parameters(theta, groups):
    """
    Compute community-level order parameters.

    For each community k, z_k is the mean of exp(i * theta) over the oscillators
    in ``groups[k]`` (taken along the last axis of ``theta``). The result
    interleaves [Re z_1, Im z_1, Re z_2, Im z_2, ...] along a new last axis.
    """
    phases = np.asarray(theta)
    columns = []
    for members in groups:
        subset = phases[..., members]
        columns.extend((np.cos(subset).mean(axis=-1), np.sin(subset).mean(axis=-1)))
    return np.stack(columns, axis=-1)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def global_order_parameter(theta):
    """Return [Re z, Im z] for z = mean(exp(i * theta)) taken over the last axis."""
    phases = np.asarray(theta)
    real_part = np.cos(phases).mean(axis=-1)
    imag_part = np.sin(phases).mean(axis=-1)
    return np.stack([real_part, imag_part], axis=-1)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def micro_to_community_macro(x, groups):
    """Average the cosine and sine halves of a micro-state within each community.

    ``x`` is laid out as [cos part, sin part] along its last axis, each half
    holding one entry per oscillator. The output interleaves
    [mean cos, mean sin] per community along a new last axis.
    """
    micro = np.asarray(x)
    half = micro.shape[-1] // 2
    cosines, sines = micro[..., :half], micro[..., half:]

    columns = []
    for members in groups:
        columns += [cosines[..., members].mean(axis=-1), sines[..., members].mean(axis=-1)]
    return np.stack(columns, axis=-1)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def sample_initial_phases(groups, rng, phase_gap_range=(0.9, 2.1), init_spread=0.35):
    """Sample community-structured initial phases.

    Each community gets a center phase, and its oscillators are drawn around
    that center with Gaussian spread ``init_spread`` (then wrapped to
    [-pi, pi)). The first center is uniform on [-pi, pi). Each following center
    is offset from the previous one by a gap whose magnitude is uniform in
    ``phase_gap_range`` and whose sign is random.

    For one or two communities, only the base phase and a single gap are drawn
    before the per-oscillator noise.

    Args:
        groups: list of index arrays, one per community.
        rng: numpy Generator used for every draw.
        phase_gap_range: (low, high) bounds for the magnitude of each gap.
        init_spread: standard deviation of the per-oscillator phase noise.

    Returns:
        Array of initial phases, one per oscillator.
    """
    base_phase = rng.uniform(-math.pi, math.pi)
    gap = rng.uniform(*phase_gap_range) * rng.choice([-1.0, 1.0])
    centers = [base_phase, wrap_angles(base_phase + gap)]
    # Communities beyond the second need their own centers; otherwise zip()
    # below would skip them and leave their oscillators at phase 0 with no spread.
    while len(centers) < len(groups):
        extra_gap = rng.uniform(*phase_gap_range) * rng.choice([-1.0, 1.0])
        centers.append(wrap_angles(centers[-1] + extra_gap))

    theta0 = np.zeros(sum(len(group) for group in groups), dtype=float)
    for center, group in zip(centers, groups):
        theta0[group] = wrap_angles(center + rng.normal(0.0, init_spread, size=len(group)))
    return theta0
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def generate_kuramoto_dataset(n_trajectories=500, n_steps=4, group_sizes=(5, 5),
                              intra_coupling=3.5, inter_coupling=0.45,
                              base_frequencies=(-0.45, 0.45), omega_jitter=0.05,
                              dt=0.05, burn_in=40, sample_stride=3, process_noise=0.05,
                              init_spread=0.35, phase_gap_range=(0.9, 2.1),
                              observation_noise=0.08, seed=42):
    """
    Generate a two-community Kuramoto dataset for one-step prediction.

    Each of ``n_trajectories`` trajectories contributes ``n_steps`` transition
    pairs. Micro-states are embedded separately for ``x_t`` and ``x_tp1``, so
    with ``observation_noise > 0`` the ``x_tp1`` of one pair and the ``x_t`` of
    the next pair in a trajectory are independent noisy views of the same
    phases. Macro-states are computed from the noise-free phases.

    Returns:
        x_t, x_tp1: micro-states, shape (N, 2 * n_oscillators)
        y_t, y_tp1: community macro-states, shape (N, 2 * n_groups)
        meta: system metadata for visualization / evaluation (groups,
            group_sizes, n_groups, n_oscillators, omegas, coupling, global
            order parameters and raw phases)

    Raises:
        ValueError: if ``n_trajectories`` or ``n_steps`` is smaller than 1,
            because no transition pairs would be produced.
    """
    if n_trajectories < 1 or n_steps < 1:
        raise ValueError(
            f"n_trajectories={n_trajectories} and n_steps={n_steps} must both be >= 1"
        )

    rng = np.random.default_rng(seed)

    coupling, groups = build_two_community_coupling(
        group_sizes=group_sizes,
        intra_coupling=intra_coupling,
        inter_coupling=inter_coupling,
    )
    omegas = sample_natural_frequencies(
        groups,
        base_frequencies=base_frequencies,
        jitter=omega_jitter,
        rng=rng,
    )

    x_t_all = []
    x_tp1_all = []
    y_t_all = []
    y_tp1_all = []
    theta_t_all = []
    theta_tp1_all = []

    for _ in range(n_trajectories):
        theta0 = sample_initial_phases(
            groups,
            rng=rng,
            phase_gap_range=phase_gap_range,
            init_spread=init_spread,
        )
        trajectory = simulate_kuramoto_trajectory(
            theta0,
            omegas,
            coupling,
            dt=dt,
            n_steps=n_steps,
            burn_in=burn_in,
            sample_stride=sample_stride,
            process_noise=process_noise,
            rng=rng,
            groups=groups,
        )

        for t in range(n_steps):
            theta_t = trajectory[t]
            theta_tp1 = trajectory[t + 1]

            x_t_all.append(angles_to_micro(theta_t, observation_noise=observation_noise, rng=rng))
            x_tp1_all.append(angles_to_micro(theta_tp1, observation_noise=observation_noise, rng=rng))
            y_t_all.append(community_order_parameters(theta_t, groups))
            y_tp1_all.append(community_order_parameters(theta_tp1, groups))
            theta_t_all.append(theta_t)
            theta_tp1_all.append(theta_tp1)

    # Stack the raw phases once and reuse them for the global order parameters.
    theta_t_stack = np.stack(theta_t_all, axis=0)
    theta_tp1_stack = np.stack(theta_tp1_all, axis=0)

    meta = {
        "groups": [group.tolist() for group in groups],
        # Same descriptive keys as generate_kuramoto_dataset_multistep's meta.
        "group_sizes": list(group_sizes),
        "n_groups": len(groups),
        "n_oscillators": int(sum(group_sizes)),
        "omegas": omegas,
        "coupling": coupling,
        "global_order_t": global_order_parameter(theta_t_stack),
        "global_order_tp1": global_order_parameter(theta_tp1_stack),
        "theta_t": theta_t_stack,
        "theta_tp1": theta_tp1_stack,
    }

    return (
        np.stack(x_t_all, axis=0),
        np.stack(x_tp1_all, axis=0),
        np.stack(y_t_all, axis=0),
        np.stack(y_tp1_all, axis=0),
        meta,
    )
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class KuramotoDataset(Dataset):
    """PyTorch dataset for Kuramoto transition pairs.

    Each item is a dict with ``x_t``, ``x_tp1`` and ``idx``, plus whichever of
    ``y_t``, ``y_tp1``, ``x_future`` and ``y_future`` were provided. Arrays are
    stored as float32 tensors and must all share the same number of samples
    (length of the leading dimension).

    Raises:
        ValueError: from ``__init__`` if any provided array's length differs
            from that of ``x_t``.
    """

    def __init__(self, x_t, x_tp1, y_t=None, y_tp1=None, x_future=None, y_future=None):
        self.x_t = torch.FloatTensor(x_t)
        self.x_tp1 = torch.FloatTensor(x_tp1)
        self.y_t = torch.FloatTensor(y_t) if y_t is not None else None
        self.y_tp1 = torch.FloatTensor(y_tp1) if y_tp1 is not None else None
        self.x_future = torch.FloatTensor(x_future) if x_future is not None else None
        self.y_future = torch.FloatTensor(y_future) if y_future is not None else None

        # __len__ is driven by x_t: a shorter companion array would raise
        # IndexError mid-epoch and a longer one would be silently truncated,
        # so reject misaligned inputs up front.
        n_samples = len(self.x_t)
        for name in ("x_tp1", "y_t", "y_tp1", "x_future", "y_future"):
            value = getattr(self, name)
            if value is not None and len(value) != n_samples:
                raise ValueError(
                    f"{name} has {len(value)} samples but x_t has {n_samples}"
                )

    def __len__(self):
        return len(self.x_t)

    def __getitem__(self, idx):
        item = {
            "x_t": self.x_t[idx],
            "x_tp1": self.x_tp1[idx],
            "idx": idx,
        }
        # Optional fields are included only when they were supplied.
        if self.y_t is not None:
            item["y_t"] = self.y_t[idx]
        if self.y_tp1 is not None:
            item["y_tp1"] = self.y_tp1[idx]
        if self.x_future is not None:
            item["x_future"] = self.x_future[idx]
        if self.y_future is not None:
            item["y_future"] = self.y_future[idx]
        return item
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def generate_kuramoto_dataset_multistep(n_trajectories=500, n_steps=6, group_sizes=(5, 5),
                                        intra_coupling=3.5, inter_coupling=0.45,
                                        base_frequencies=(-0.45, 0.45), omega_jitter=0.05,
                                        dt=0.05, burn_in=40, sample_stride=3, process_noise=0.05,
                                        init_spread=0.35, phase_gap_range=(0.9, 2.1),
                                        observation_noise=0.08, multi_horizon=1, seed=42):
    """
    Generate a Kuramoto dataset whose samples carry a window of future states.

    Every sampled state of a trajectory is observed once (one micro-state per
    state), and the trajectory is then cut into ``n_steps - multi_horizon + 1``
    overlapping windows. Window ``t`` holds the states at ``t`` and ``t + 1``
    plus the ``multi_horizon`` states starting at ``t + 1``.

    ``multi_horizon`` is clamped to at least 1.

    Returns:
        x_t, x_tp1, y_t, y_tp1, x_future, y_future, meta
        (x_future / y_future are stacked as (N, multi_horizon, dim)).

    Raises:
        ValueError: if ``n_steps`` is smaller than the clamped ``multi_horizon``.
    """
    rng = np.random.default_rng(seed)
    multi_horizon = max(int(multi_horizon), 1)
    if not n_steps >= multi_horizon:
        raise ValueError(f"n_steps={n_steps} must be >= multi_horizon={multi_horizon}")

    coupling, groups = build_two_community_coupling(
        group_sizes=group_sizes, intra_coupling=intra_coupling, inter_coupling=inter_coupling,
    )
    omegas = sample_natural_frequencies(
        groups, base_frequencies=base_frequencies, jitter=omega_jitter, rng=rng,
    )

    x_t_rows, x_tp1_rows, y_t_rows, y_tp1_rows = [], [], [], []
    x_future_rows, y_future_rows = [], []
    theta_t_rows, theta_tp1_rows = [], []
    starts = range(n_steps - multi_horizon + 1)

    for _ in range(n_trajectories):
        theta0 = sample_initial_phases(
            groups, rng=rng, phase_gap_range=phase_gap_range, init_spread=init_spread,
        )
        phases = simulate_kuramoto_trajectory(
            theta0, omegas, coupling,
            dt=dt, n_steps=n_steps, burn_in=burn_in, sample_stride=sample_stride,
            process_noise=process_noise, rng=rng, groups=groups,
        )
        # One noisy observation per sampled state, shared by every window containing it.
        micro = np.stack(
            [angles_to_micro(state, observation_noise=observation_noise, rng=rng) for state in phases],
            axis=0,
        )
        macro = community_order_parameters(phases, groups)

        x_t_rows.extend(micro[s] for s in starts)
        x_tp1_rows.extend(micro[s + 1] for s in starts)
        y_t_rows.extend(macro[s] for s in starts)
        y_tp1_rows.extend(macro[s + 1] for s in starts)
        x_future_rows.extend(micro[s + 1:s + 1 + multi_horizon] for s in starts)
        y_future_rows.extend(macro[s + 1:s + 1 + multi_horizon] for s in starts)
        theta_t_rows.extend(phases[s] for s in starts)
        theta_tp1_rows.extend(phases[s + 1] for s in starts)

    theta_t = np.stack(theta_t_rows, axis=0)
    theta_tp1 = np.stack(theta_tp1_rows, axis=0)
    meta = {
        "groups": [group.tolist() for group in groups],
        "group_sizes": list(group_sizes),
        "n_groups": len(groups),
        "n_oscillators": int(sum(group_sizes)),
        "omegas": omegas,
        "coupling": coupling,
        "multi_horizon": multi_horizon,
        "global_order_t": global_order_parameter(theta_t),
        "global_order_tp1": global_order_parameter(theta_tp1),
        "theta_t": theta_t,
        "theta_tp1": theta_tp1,
    }

    return (
        np.stack(x_t_rows, axis=0),
        np.stack(x_tp1_rows, axis=0),
        np.stack(y_t_rows, axis=0),
        np.stack(y_tp1_rows, axis=0),
        np.stack(x_future_rows, axis=0),
        np.stack(y_future_rows, axis=0),
        meta,
    )
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def get_kuramoto_dataloaders(n_trajectories=500, n_steps=4, batch_size=128,
                             test_ratio=0.1, seed=42, multi_horizon=1, **dataset_kwargs):
    """Build train/test dataloaders and return dataset metadata.

    Samples come from ``generate_kuramoto_dataset_multistep``. They are
    permuted with a generator seeded by ``seed``. The first
    ``N - int(N * test_ratio)`` permuted samples form the training set and the
    rest the test set.

    Note: the split is over individual windows, not trajectories, so windows
    cut from the same trajectory can end up in both splits.

    Args:
        n_trajectories, n_steps, multi_horizon, seed, **dataset_kwargs:
            forwarded to ``generate_kuramoto_dataset_multistep``.
        batch_size: batch size for both loaders.
        test_ratio: fraction of samples held out for testing, in [0, 1].

    Returns:
        (train_loader, test_loader, info). ``info`` holds ``x_dim``, ``y_dim``,
        ``n_train`` and ``n_test``. It also holds float32 tensors for the
        train / test / full arrays (``train_x_t`` ... ``full_y_future``) and
        the generator's ``meta`` dict.

    Raises:
        ValueError: if ``test_ratio`` lies outside [0, 1].
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio={test_ratio} must lie in [0, 1]")

    x_t, x_tp1, y_t, y_tp1, x_future, y_future, meta = generate_kuramoto_dataset_multistep(
        n_trajectories=n_trajectories,
        n_steps=n_steps,
        multi_horizon=multi_horizon,
        seed=seed,
        **dataset_kwargs,
    )

    n_total = len(x_t)
    n_test = int(n_total * test_ratio)
    n_train = n_total - n_test
    indices = np.random.default_rng(seed).permutation(n_total)
    train_idx = indices[:n_train]
    test_idx = indices[n_train:]

    # Insertion order matches KuramotoDataset's positional parameters.
    arrays = {
        "x_t": x_t,
        "x_tp1": x_tp1,
        "y_t": y_t,
        "y_tp1": y_tp1,
        "x_future": x_future,
        "y_future": y_future,
    }
    train_dataset = KuramotoDataset(*(array[train_idx] for array in arrays.values()))
    test_dataset = KuramotoDataset(*(array[test_idx] for array in arrays.values()))

    # drop_last keeps training batches full-sized, but with fewer training
    # samples than batch_size it would leave the loader with no batches at all.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=n_train >= batch_size,
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    info = {
        "x_dim": x_t.shape[1],
        "y_dim": y_t.shape[1],
        "n_train": n_train,
        "n_test": n_test,
    }
    for prefix, split_idx in (("train", train_idx), ("test", test_idx)):
        for name, array in arrays.items():
            info[f"{prefix}_{name}"] = torch.FloatTensor(array[split_idx])
    for name, array in arrays.items():
        info[f"full_{name}"] = torch.FloatTensor(array)
    info["meta"] = meta

    return train_loader, test_loader, info
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
if __name__ == "__main__":
    # Smoke test on a tiny dataset: report dimensions and the shapes of one training batch.
    train_loader, _, info = get_kuramoto_dataloaders(
        n_trajectories=32,
        n_steps=4,
        batch_size=16,
        multi_horizon=2,
    )
    print(f"x_dim={info['x_dim']}, y_dim={info['y_dim']}")
    print(f"n_train={info['n_train']}, n_test={info['n_test']}")
    first_batch = next(iter(train_loader))
    print(f"x_t batch shape: {first_batch['x_t'].shape}")
    print(f"x_future batch shape: {first_batch['x_future'].shape}")
|