amica-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amica/__init__.py +5 -0
- amica/_batching.py +194 -0
- amica/_newton.py +77 -0
- amica/_sklearn_interface.py +387 -0
- amica/_types.py +44 -0
- amica/conftest.py +30 -0
- amica/constants.py +47 -0
- amica/core.py +1165 -0
- amica/datasets.py +15 -0
- amica/kernels.py +1308 -0
- amica/linalg.py +349 -0
- amica/state.py +385 -0
- amica/tests/test_amica.py +497 -0
- amica/utils/__init__.py +36 -0
- amica/utils/_logging.py +64 -0
- amica/utils/_progress.py +34 -0
- amica/utils/_verbose.py +14 -0
- amica/utils/fetch.py +274 -0
- amica/utils/fortran.py +387 -0
- amica/utils/imports.py +46 -0
- amica/utils/mne.py +74 -0
- amica/utils/parallel.py +72 -0
- amica/utils/simulation.py +36 -0
- amica/utils/tests/test_fetch.py +9 -0
- amica/utils/tests/test_fortran.py +47 -0
- amica/utils/tests/test_imports.py +0 -0
- amica/utils/tests/test_logger.py +29 -0
- amica/utils/tests/test_mne.py +27 -0
- amica_python-0.1.0.dist-info/METADATA +196 -0
- amica_python-0.1.0.dist-info/RECORD +33 -0
- amica_python-0.1.0.dist-info/WHEEL +5 -0
- amica_python-0.1.0.dist-info/licenses/LICENSE +25 -0
- amica_python-0.1.0.dist-info/top_level.txt +1 -0
amica/utils/fetch.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Utilities for downloading and caching datasets for AMICA Python."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import tempfile
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
# Cache directory for all test data
|
|
8
|
+
CACHE_DIR = Path.home() / "amica_test_data"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# -------------------------------
|
|
12
|
+
# EEGLAB test data
|
|
13
|
+
# -------------------------------
|
|
14
|
+
EEGLAB_BASE = "https://github.com/sccn/eeglab/raw/develop/sample_data/"
|
|
15
|
+
EEGLAB_FILES = {
|
|
16
|
+
"fdt": ("eeglab_data.fdt", "md5:a135e79e2acc93670746b2b6f44570a7"),
|
|
17
|
+
"set": ("eeglab_data.set", "md5:cf2d4549e48b8fd82cff776d13adb46c"),
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def fetch_datasets() -> Path:
    """
    Download the default AMICA test datasets.

    Notes
    -----
    This intentionally excludes the optional Planck astronomy maps because they
    are large and should not be fetched automatically for every user.

    Returns
    -------
    pathlib.Path
        Path to the directory containing all cached test datasets.
    """
    # Each fetcher caches its files under CACHE_DIR.
    for fetcher in (fetch_test_data, fetch_fortran_outputs, fetch_photos):
        fetcher()
    return CACHE_DIR
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def fetch_test_data() -> Path:  # pragma: no cover
    """
    Download the EEGLAB sample dataset (.set + .fdt).

    Returns
    -------
    pathlib.Path
        Path to the directory containing the cached ``.set`` and ``.fdt``
        files.
    """
    import pooch

    sample_dir = CACHE_DIR / "eeglab_sample_data"
    sample_dir.mkdir(parents=True, exist_ok=True)
    # Fetch (or reuse from cache) each file listed in EEGLAB_FILES.
    for _, (fname, known_hash) in EEGLAB_FILES.items():
        url = f"{EEGLAB_BASE}{fname}"
        _ = pooch.retrieve(
            url=url,
            known_hash=known_hash,
            path=sample_dir,
            fname=fname,
            progressbar=True,
        )
    return sample_dir
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# -------------------------------
# Fortran golden outputs
# -------------------------------
# Release tag of the archive holding reference Fortran AMICA outputs.
version = "v0.6.0"
FORTRAN_URL = (
    "https://github.com/scott-huberty/amica/"
    f"releases/download/{version}/test_output.tar.gz"
)
# pooch-style sha256 hash of the tarball above.
FORTRAN_HASH = "sha256:46ec71a0f66565a43480825f85611ea0126e1aa05eaa52ceaf7b628631d753c1"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def fetch_fortran_outputs() -> Path:
    """
    Download and extract golden outputs from the Fortran implementation.

    Returns
    -------
    pathlib.Path
        Path to the directory containing the files extracted from the tarball.
    """
    import pooch

    unpack = pooch.Untar(extract_dir=".")
    # pooch.retrieve with an Untar processor returns the list of extracted
    # file paths.
    outputs_dir = pooch.retrieve(
        url=FORTRAN_URL,
        known_hash=FORTRAN_HASH,
        path=CACHE_DIR,
        processor=unpack,
        progressbar=True,
    )
    return Path(outputs_dir[0]).parent  # return the directory containing the files
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# -------------------------------
# Photos dataset
# -------------------------------
# Raw-file base URL for the cocktail-party example images.
COCKTAIL_BASE = "https://github.com/marcromani/cocktail/raw/refs/heads/master/examples/data/" # noqa E501
# Mapping of short name -> (remote file name, pooch-style known hash).
PHOTOS_FILES = {
    "example2_baboon": ("example2_baboon", "md5:b38c092ca8fda06e29182926866a1950"),
    "example2_cameraman": ("example2_cameraman", "md5:cb541d1814d3e27a4133d31653bbb01a"), # noqa E501
    "example2_lena": ("example2_lena", "md5:ce2f1b9f96561a409a5afc10654ab744"),
    "example2_mona": ("example2_mona", "md5:01f7b4cc8cade346843caac334d793d7"),
    "example2_texture": ("example2_texture", "md5:65c169b03e55686a66f6fb6471ca7d60"),
}
|
|
111
|
+
|
|
112
|
+
def fetch_photos() -> Path:
    """
    Download the photos dataset used in cocktail party example.

    Returns
    -------
    pathlib.Path
        Path to the directory containing all cached photo files.
    """
    import pooch

    destination = CACHE_DIR / "photos"
    destination.mkdir(parents=True, exist_ok=True)
    # Download (or reuse from cache) every photo listed in PHOTOS_FILES.
    for remote_name, checksum in PHOTOS_FILES.values():
        pooch.retrieve(
            url=f"{COCKTAIL_BASE}{remote_name}",
            known_hash=checksum,
            path=destination,
            fname=remote_name,
            progressbar=True,
        )
    return destination
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# -------------------------------
# Optional Planck PR3 astronomy maps
# -------------------------------
# IRSA base URL for the public Planck release-3 all-sky maps.
PLANCK_BASE_URL = "https://irsa.ipac.caltech.edu/data/Planck/release_3/all-sky-maps/maps/"  # noqa: E501
# Mapping of channel frequency (GHz) -> remote FITS file name.
PLANCK_MAP_FILENAMES = {
    30: "LFI_SkyMap_030-BPassCorrected-field-IQU_1024_R3.00_full.fits",
    44: "LFI_SkyMap_044-BPassCorrected-field-IQU_1024_R3.00_full.fits",
    70: "LFI_SkyMap_070-BPassCorrected-field-IQU_1024_R3.00_full.fits",
    100: "HFI_SkyMap_100_2048_R3.01_full.fits",
    143: "HFI_SkyMap_143_2048_R3.01_full.fits",
    217: "HFI_SkyMap_217_2048_R3.01_full.fits",
}


# -------------------------------
# Optional MICA benchmark dataset
# -------------------------------
# Large benchmark archive; fetched only on explicit request (see
# fetch_mica_release), never by fetch_datasets.
MICA_RELEASE_URL = "http://sccn.ucsd.edu/pub/mica_release.zip"
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _prepare_cache_dir(root: Path, dataset_name: str) -> Path:
|
|
158
|
+
"""Create a writable cache directory, falling back when needed."""
|
|
159
|
+
cache_dir = root / dataset_name
|
|
160
|
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
161
|
+
if not os.access(cache_dir, os.W_OK):
|
|
162
|
+
raise PermissionError(f"Cache directory is not writable: {cache_dir}")
|
|
163
|
+
return cache_dir
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def fetch_planck_temperature_map(filename: str) -> Path:  # pragma: no cover
    """Download one public Planck PR3 map and return the local path.

    Parameters
    ----------
    filename : str
        Remote FITS file name, e.g. one of ``PLANCK_MAP_FILENAMES.values()``.

    Returns
    -------
    pathlib.Path
        Local path to the cached map file.

    Raises
    ------
    RuntimeError
        If the map could not be downloaded.
    """
    import pooch

    cache_root = Path(os.environ.get("AMICA_PLANCK_CACHE", CACHE_DIR))
    fallback_cache_root = Path(tempfile.gettempdir()) / "amica-python"
    # BUG FIX: the URL previously appended the literal text "(unknown)" to the
    # base URL instead of the requested file name, so downloads could never
    # resolve. Use the filename parameter.
    url = f"{PLANCK_BASE_URL}{filename}"

    try:
        cache_dir = _prepare_cache_dir(cache_root, "planck_pr3")
    except PermissionError:
        # Primary cache location not writable; fall back to a temp-dir cache.
        cache_dir = _prepare_cache_dir(fallback_cache_root, "planck_pr3")

    try:
        return Path(
            pooch.retrieve(
                url=url,
                known_hash=None,
                path=cache_dir,
                fname=filename,
                progressbar=True,
            )
        )
    except Exception as exc:  # pragma: no cover - depends on network/data host
        # Same fix in the error message: name the actual file that failed.
        raise RuntimeError(
            f"Could not download the Planck map '{filename}' from {url}."
        ) from exc
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def fetch_planck_temperature_maps(
    frequencies_ghz: tuple[int, ...] | None = None,
) -> dict[int, Path]:  # pragma: no cover
    """Download selected Planck PR3 temperature maps.

    Parameters
    ----------
    frequencies_ghz : tuple of int | None
        Requested Planck channel frequencies in GHz. If ``None``, downloads all
        channels listed in ``PLANCK_MAP_FILENAMES``.

    Returns
    -------
    dict of int to pathlib.Path
        Mapping from channel frequency in GHz to the local cached file path.
    """
    requested = (
        tuple(PLANCK_MAP_FILENAMES) if frequencies_ghz is None else frequencies_ghz
    )

    # Reject unknown channels up front, before any network traffic.
    missing = sorted(set(requested) - set(PLANCK_MAP_FILENAMES))
    if missing:
        raise ValueError(
            f"Unsupported Planck frequencies requested: {missing}. "
            f"Available channels are {sorted(PLANCK_MAP_FILENAMES)}."
        )

    maps: dict[int, Path] = {}
    for frequency_ghz in requested:
        maps[frequency_ghz] = fetch_planck_temperature_map(
            PLANCK_MAP_FILENAMES[frequency_ghz]
        )
    return maps
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def fetch_mica_release(output_dir: Path | None = None) -> Path:  # pragma: no cover
    """Download and extract the optional EEGLAB MICA benchmark dataset.

    This dataset is large, so it is intentionally excluded from
    :func:`fetch_datasets` and must be requested explicitly.

    Parameters
    ----------
    output_dir : pathlib.Path | None
        Directory where the extracted ``mica_release`` folder should live.
        Defaults to ``~/amica_test_data``.

    Returns
    -------
    pathlib.Path
        Path to the extracted ``mica_release`` directory.

    Raises
    ------
    RuntimeError
        If the archive did not contain the expected ``mica_release`` folder.
    """
    import zipfile

    import pooch

    cache_root = Path(output_dir).expanduser() if output_dir is not None else CACHE_DIR
    cache_root.mkdir(parents=True, exist_ok=True)

    release_dir = cache_root / "mica_release"
    # Improvement: check for the already-extracted directory BEFORE calling
    # pooch.retrieve. Previously the retrieve ran first, so the early-return
    # path still went through pooch (and the network on a cold cache).
    if release_dir.exists():
        return release_dir

    zip_path = Path(
        pooch.retrieve(
            url=MICA_RELEASE_URL,
            known_hash=None,
            path=cache_root,
            fname="mica_release.zip",
            progressbar=True,
        )
    )

    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(cache_root)

    if not release_dir.exists():
        raise RuntimeError(
            f"Expected extracted benchmark directory at {release_dir}, "
            f"but it was not created from {zip_path}."
        )

    return release_dir
|
amica/utils/fortran.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""Utilities for interfacing with Fortran AMICA outputs."""
|
|
2
|
+
import inspect
|
|
3
|
+
from dataclasses import MISSING, asdict, dataclass, fields
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_initial_weights(fortran_outdir, *, n_components, n_mixtures):
    """Load w_init, sbeta_init, and mu_init binary files from a Fortran AMICA run."""
    outdir = Path(fortran_outdir)
    assert outdir.exists()

    def _read(fname, shape):
        # Fortran writes column-major float64 arrays.
        return np.fromfile(outdir / fname, dtype=np.float64).reshape(shape, order="F")

    initial_weights = _read("Wtmp.bin", (n_components, n_components))
    # Fortran stores these as (n_mixtures, n_components); transpose to match
    # our (n_components, n_mixtures) convention.
    initial_scales = _read("sbetatmp.bin", (n_mixtures, n_components)).T
    initial_locations = _read("mutmp.bin", (n_mixtures, n_components)).T
    return initial_weights, initial_scales, initial_locations
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load_fortran_results(fortran_outdir, *, n_components, n_mixtures, n_features=None):
    """Load results from a completed Fortran AMICA run for comparison."""
    outdir = Path(fortran_outdir)
    assert outdir.exists()
    assert outdir.is_dir()
    if n_features is None:
        n_features = n_components

    def _read(name, dtype=np.float64):
        # All result files are raw binary dumps from the Fortran program.
        return np.fromfile(outdir / name, dtype=dtype)

    # Channel means (flat vector).
    mean_f = _read("mean")
    # Sphering matrix.
    S_f = _read("S").reshape((n_features, n_features), order="F")
    # Unmixing matrix; trailing axis is the model index.
    W_f = _read("W").reshape((n_components, n_components, 1), order="F")
    # Mixing matrix.
    A_f = _read("A").reshape((n_components, n_components), order="F")
    # Bias term.
    c_f = _read("c").reshape((n_components, 1), order="F")
    # Log-likelihood trace.
    LL_f = _read("LL")

    # Mixture model parameters.
    # Fortran order is n_mixtures x n_components. Ours is n_components x
    # n_mixtures, hence the transposes below.
    alpha_f = _read("alpha").reshape((n_mixtures, n_components), order="F").T
    sbeta_f = _read("sbeta").reshape((n_mixtures, n_components), order="F").T
    mu_f = _read("mu").reshape((n_mixtures, n_components), order="F").T
    rho_f = _read("rho").reshape((n_mixtures, n_components), order="F").T

    # Something weird is happening there. I expect (num_comps, num_models) = (32, 1)
    comp_list_f = _read("comp_list", dtype=np.int32).reshape(
        (n_components, 2), order="F"
    )

    gm_f = _read("gm")
    return {
        "W": W_f,
        "S": S_f,
        "sbeta": sbeta_f,
        "rho": rho_f,
        "mu": mu_f,
        "mean": mean_f,
        "gm": gm_f,
        "comp_list": comp_list_f,
        "c": c_f,
        "alpha": alpha_f,
        "A": A_f,
        "LL": LL_f,
    }
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def write_data(data, filename):
    """Save data to a binary file in Fortran-compatible format.

    Parameters
    ----------
    data : array-like
        The data of shape (n_samples, n_features) to save. It is written as
        little-endian float32 in the layout the Fortran program expects.
    filename : str or Path
        The path to the output binary file.

    Returns
    -------
    path : Path
        The path to the saved file.
    data : np.ndarray
        The array that was passed in.
    """
    destination = Path(filename).expanduser().resolve()
    # The Fortran program reads (n_features, n_samples) column-major, which is
    # exactly the row-major (C-order) layout of (n_samples, n_features). So we
    # transpose first and then ravel in Fortran order before writing.
    flat = data.T.astype("<f4").ravel(order="F")
    flat.tofile(destination)
    return destination, data
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def load_data(filename, *, dtype=np.float32, shape=None):
    """Load a binary data file that was saved for use with Fortran AMICA.

    Parameters
    ----------
    filename : str or Path
        The path to the input binary file.
    dtype : data-type
        The desired data-type for the loaded array. Default is np.float32.
    shape : tuple of int
        The shape of the array to load. If None, a flat array is returned.

    Returns
    -------
    data : np.ndarray
        The loaded array, reshaped in column-major (Fortran) order when
        ``shape`` is given.

    Notes
    -----
    Fortran stores arrays in column-major order, and the Fortran program
    expects data in shape (n_features, n_samples). So when loading data
    for use in Python, you should reshape to (n_features, n_samples) and
    then transpose to (n_samples, n_features) to match the common Python
    convention.

    Examples
    --------
    >>> data = load_data('data.bin', shape=(64, 1000)).T
    """
    raw = np.fromfile(filename, dtype=dtype)
    if shape is None:
        return raw
    return raw.reshape(shape, order="F")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@lru_cache(maxsize=1)
def _get_fortran_param_defaults():
    """Collect Fortran parameter defaults aligned with AMICA-Python."""
    # Layer 1: the dataclass's own defaults.
    defaults = {
        field.name: field.default
        for field in fields(FortranParams)
        if field.default is not MISSING
    }

    # Layer 2: public API defaults from fit_amica (override layer 1).
    from amica.core import fit_amica

    signature = inspect.signature(fit_amica)
    fit_to_fortran = {
        "max_iter": "max_iter",
        "n_models": "num_models",
        "n_mixtures": "num_mix_comps",
        "lrate": "lrate",
        "rholrate": "rholrate",
        "pdftype": "pdftype",
        "do_newton": "do_newton",
        "newt_start": "newt_start",
        "newtrate": "newtrate",
        "newt_ramp": "newt_ramp",
        "do_reject": "do_reject",
    }
    for fit_key, fortran_key in fit_to_fortran.items():
        param = signature.parameters.get(fit_key)
        if param is not None and param.default is not inspect.Signature.empty:
            defaults[fortran_key] = param.default

    # Layer 3: internal defaults from constants.py (override layers 1-2).
    from amica import constants as c

    const_to_fortran = {
        "lratefact": "lratefact",
        "rholratefact": "rholratefact",
        "use_min_dll": "use_min_dll",
        "use_grad_norm": "use_grad_norm",
        "min_dll": "min_dll",
        "min_nd": "min_grad_norm",
        "do_opt_block": "do_opt_block",
        "mineig": "mineig",
        "minlrate": "minlrate",
        "rho0": "rho0",
        "minrho": "minrho",
        "maxrho": "maxrho",
        "invsigmax": "invsigmax",
        "invsigmin": "invsigmin",
        "dorho": "do_rho",
        "doscaling": "doscaling",
        "maxdecs": "max_decs",
        "share_comps": "share_comps",
        "share_start": "share_start",
        "share_iter": "share_iter",
        "outstep": "writestep",
        "fix_init": "fix_init",
    }
    for const_key, fortran_key in const_to_fortran.items():
        if hasattr(c, const_key):
            defaults[fortran_key] = getattr(c, const_key)

    return defaults
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def write_param_file(fpath, data, **kwargs):
    """Write a Fortran AMICA parameter file.

    Parameters
    ----------
    fpath : str or Path
        The path to the output parameter file.
    data : np.ndarray
        The data array to write to the file.
    **kwargs : dict
        Additional parameters to write to the file. ``files`` and ``outdir`` are
        required and should be passed via kwargs.

    Returns
    -------
    path : Path
        The path to the saved parameter file.
    params : FortranParams
        The parameter set that was written.
    """
    fpath = Path(fpath).expanduser().resolve()

    # `files` and `outdir` have no sensible defaults, so demand them up front.
    missing = [key for key in ("files", "outdir") if key not in kwargs]
    if missing:
        missing_joined = ", ".join(missing)
        raise ValueError(
            f"Missing required parameter(s): {missing_joined}. Pass them via kwargs."
        )

    # Defaults first, explicit kwargs win; data-derived shapes fill any gaps.
    merged = dict(_get_fortran_param_defaults())
    merged.update(kwargs)
    merged.setdefault("data_dim", data.shape[1])
    merged.setdefault("field_dim", data.shape[0])
    merged.setdefault("block_size", data.shape[0])
    merged.setdefault("pcakeep", data.shape[1])

    params = FortranParams(**merged)
    lines = [f"{key} {value}\n" for key, value in params.to_param_dict().items()]
    fpath.write_text("".join(lines))
    return fpath, params
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@dataclass
class FortranParams:
    """Dataclass to hold Fortran AMICA parameters.

    Field names mirror the keys of a Fortran AMICA parameter file; the
    values are written verbatim by ``to_param_dict`` /
    :func:`write_param_file`.
    """

    # Required parameters
    files: str | Path
    outdir: str | Path
    # Data Shape
    block_size: int
    data_dim: int  # n_features
    field_dim: int  # n_samples
    max_iter: int = 500
    # blk_min/blk_step/blk_max default to fractions of block_size; see
    # __post_init__.
    blk_min: int | None = None
    blk_step: int | None = None
    blk_max: int | None = None
    # Whitening
    do_mean: int = 1
    do_sphere: int = 1
    doPCA: int = 1
    pcakeep: int | None = None  # defaults to data_dim; see __post_init__
    pcadb: float = 30.000000
    # You'll probably never need to change these...
    # Main Model Params
    num_models : int = 1
    max_threads : int = 1 # Single-threaded (aids debugging)
    # Newton
    do_newton: int =1
    newt_start: int = 50
    newt_ramp: int = 10
    newtrate: float = 1.000000
    # Learning Rates
    lrate: float = 0.050000
    rholrate: float = 0.050000
    lratefact: float = 0.500000
    rholratefact: float = 0.500000
    # Convergence
    use_min_dll: int | bool = 1
    min_dll: float = 1.000000e-09
    use_grad_norm: int = 1
    min_grad_norm: float = 1.000000e-07
    # Misc.
    do_opt_block: int | bool = 0
    num_mix_comps: int = 3
    pdftype: int = 0
    num_samples: int = 1
    field_blocksize: int = 1
    do_history: int = 0
    histstep: int = 10
    share_comps: int = 0
    share_start: int = 100
    comp_thresh: float = 0.990000
    share_iter: int = 100
    minlrate: float = 1.000000e-08
    mineig: float = 1.000000e-12
    rho0: float = 1.500000
    minrho: float = 1.000000
    maxrho: float = 2.000000
    kurt_start: int = 3
    num_kurt: int = 5
    kurt_int: int = 1
    # Rejection
    do_reject: int = 0
    numrej: int = 3
    rejsig: float = 3.000000
    rejstart: int = 2
    rejint: int = 3
    # Saving
    writestep: int = 1 # Write every iteration (aids debugging)
    write_nd: int = 0
    write_LLt: int = 1
    decwindow: int = 1
    max_decs: int = 3
    fix_init: int = 0
    update_A: int = 1
    update_c: int = 1
    update_gm: int = 1
    update_alpha: int = 1
    update_mu: int = 1
    update_beta: int = 1
    invsigmax: float = 100.000000
    invsigmin: float = 1.0e-08
    do_rho: int = 1
    # Debugging
    load_rej: int = 0
    load_W: int = 0
    load_c: int = 0
    load_gm: int = 0
    load_alpha: int = 0
    load_mu: int = 0
    load_beta: int = 0
    load_rho: int = 0
    load_comp_list: int = 0
    byte_size: int = 4
    doscaling: int = 1
    scalestep: int = 1

    def __post_init__(self):
        """Fill in derived defaults and normalize boolean fields to ints."""
        # Block-search bounds default to fractions of block_size.
        if self.blk_min is None:
            self.blk_min = self.block_size // 4
        if self.blk_step is None:
            self.blk_step = self.block_size // 4
        if self.blk_max is None:
            self.blk_max = self.block_size
        if self.pcakeep is None:
            self.pcakeep = self.data_dim

        # Convert bools to int
        for field in fields(self):
            if isinstance(getattr(self, field.name), bool):
                setattr(self, field.name, int(getattr(self, field.name)))

    def to_param_dict(self):
        """Convert to a dict suitable for writing to a Fortran AMICA parameter file."""
        return asdict(self)
|
amica/utils/imports.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Soft import helper."""
|
|
2
|
+
from importlib import import_module
|
|
3
|
+
from importlib.util import find_spec
|
|
4
|
+
from types import ModuleType
|
|
5
|
+
|
|
6
|
+
# A mapping from import name to package name (on PyPI) when the package name
|
|
7
|
+
# is different.
|
|
8
|
+
_INSTALL_MAPPING: dict[str, str] = {
|
|
9
|
+
"codespell_lib": "codespell",
|
|
10
|
+
"cv2": "opencv-python",
|
|
11
|
+
"parallel": "pyparallel",
|
|
12
|
+
"pytest_cov": "pytest-cov",
|
|
13
|
+
"serial": "pyserial",
|
|
14
|
+
"sklearn": "scikit-learn",
|
|
15
|
+
"sksparse": "scikit-sparse",
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
def import_optional_dependency(
    name: str,
    extra: str = "",
) -> ModuleType:
    """Import an optional dependency.

    If a dependency is missing an ImportError with a nice message will be
    raised.

    Parameters
    ----------
    name : str
        The module name.
    extra : str
        Additional text to include in the ImportError message.

    Returns
    -------
    module : Module
        The imported module when found.
    """
    # Prefer the PyPI distribution name in the message when it differs from
    # the import name.
    install_name = _INSTALL_MAPPING.get(name, name)
    if find_spec(name) is not None:
        return import_module(name)
    raise ImportError(
        f"Missing optional dependency '{install_name}'. {extra} Use pip or "
        f"conda to install {install_name}."
    )
|