amica-python 0.1.0__py3-none-any.whl

amica/utils/fetch.py ADDED
@@ -0,0 +1,274 @@
+ """Utilities for downloading and caching datasets for AMICA Python."""
+
+ import os
+ import tempfile
+ from pathlib import Path
+
+ # Cache directory for all test data
+ CACHE_DIR = Path.home() / "amica_test_data"
+
+
+ # -------------------------------
+ # EEGLAB test data
+ # -------------------------------
+ EEGLAB_BASE = "https://github.com/sccn/eeglab/raw/develop/sample_data/"
+ EEGLAB_FILES = {
+     "fdt": ("eeglab_data.fdt", "md5:a135e79e2acc93670746b2b6f44570a7"),
+     "set": ("eeglab_data.set", "md5:cf2d4549e48b8fd82cff776d13adb46c"),
+ }
+
+
+ def fetch_datasets() -> Path:
+     """
+     Download the default AMICA test datasets.
+
+     Notes
+     -----
+     This intentionally excludes the optional Planck astronomy maps because they
+     are large and should not be fetched automatically for every user.
+
+     Returns
+     -------
+     pathlib.Path
+         Path to the directory containing all cached test datasets.
+     """
+     fetch_test_data()
+     fetch_fortran_outputs()
+     fetch_photos()
+     return CACHE_DIR
+
+
+ def fetch_test_data() -> Path:  # pragma: no cover
+     """
+     Download the EEGLAB sample dataset (.set + .fdt).
+
+     Returns
+     -------
+     pathlib.Path
+         Path to the directory containing the cached
+         ``eeglab_data.set`` and ``eeglab_data.fdt`` files.
+     """
+     import pooch
+
+     sample_dir = CACHE_DIR / "eeglab_sample_data"
+     sample_dir.mkdir(parents=True, exist_ok=True)
+     for _, (fname, known_hash) in EEGLAB_FILES.items():
+         url = f"{EEGLAB_BASE}{fname}"
+         _ = pooch.retrieve(
+             url=url,
+             known_hash=known_hash,
+             path=sample_dir,
+             fname=fname,
+             progressbar=True,
+         )
+     return sample_dir
+
+
+ # -------------------------------
+ # Fortran golden outputs
+ # -------------------------------
+ FORTRAN_VERSION = "v0.6.0"
+ FORTRAN_URL = (
+     "https://github.com/scott-huberty/amica/"
+     f"releases/download/{FORTRAN_VERSION}/test_output.tar.gz"
+ )
+ FORTRAN_HASH = "sha256:46ec71a0f66565a43480825f85611ea0126e1aa05eaa52ceaf7b628631d753c1"
+
+
+ def fetch_fortran_outputs() -> Path:
+     """
+     Download and extract golden outputs from the Fortran implementation.
+
+     Returns
+     -------
+     pathlib.Path
+         Path to the directory containing the files extracted from the tarball.
+     """
+     import pooch
+
+     unpack = pooch.Untar(extract_dir=".")
+     outputs_dir = pooch.retrieve(
+         url=FORTRAN_URL,
+         known_hash=FORTRAN_HASH,
+         path=CACHE_DIR,
+         processor=unpack,
+         progressbar=True,
+     )
+     return Path(outputs_dir[0]).parent  # the directory containing the files
+
+
+ # -------------------------------
+ # Photos dataset
+ # -------------------------------
+ COCKTAIL_BASE = "https://github.com/marcromani/cocktail/raw/refs/heads/master/examples/data/"  # noqa: E501
+ PHOTOS_FILES = {
+     "example2_baboon": ("example2_baboon", "md5:b38c092ca8fda06e29182926866a1950"),
+     "example2_cameraman": ("example2_cameraman", "md5:cb541d1814d3e27a4133d31653bbb01a"),  # noqa: E501
+     "example2_lena": ("example2_lena", "md5:ce2f1b9f96561a409a5afc10654ab744"),
+     "example2_mona": ("example2_mona", "md5:01f7b4cc8cade346843caac334d793d7"),
+     "example2_texture": ("example2_texture", "md5:65c169b03e55686a66f6fb6471ca7d60"),
+ }
+
+ def fetch_photos() -> Path:
+     """
+     Download the photos dataset used in the cocktail party example.
+
+     Returns
+     -------
+     pathlib.Path
+         Path to the directory containing all cached photo files.
+     """
+     import pooch
+
+     photos_dir = CACHE_DIR / "photos"
+     photos_dir.mkdir(parents=True, exist_ok=True)
+     for _, (fname, known_hash) in PHOTOS_FILES.items():
+         url = f"{COCKTAIL_BASE}{fname}"
+         _ = pooch.retrieve(
+             url=url,
+             known_hash=known_hash,
+             path=photos_dir,
+             fname=fname,
+             progressbar=True,
+         )
+     return photos_dir
+
+
+ # -------------------------------
+ # Optional Planck PR3 astronomy maps
+ # -------------------------------
+ PLANCK_BASE_URL = "https://irsa.ipac.caltech.edu/data/Planck/release_3/all-sky-maps/maps/"  # noqa: E501
+ PLANCK_MAP_FILENAMES = {
+     30: "LFI_SkyMap_030-BPassCorrected-field-IQU_1024_R3.00_full.fits",
+     44: "LFI_SkyMap_044-BPassCorrected-field-IQU_1024_R3.00_full.fits",
+     70: "LFI_SkyMap_070-BPassCorrected-field-IQU_1024_R3.00_full.fits",
+     100: "HFI_SkyMap_100_2048_R3.01_full.fits",
+     143: "HFI_SkyMap_143_2048_R3.01_full.fits",
+     217: "HFI_SkyMap_217_2048_R3.01_full.fits",
+ }
+
+
+ # -------------------------------
+ # Optional MICA benchmark dataset
+ # -------------------------------
+ MICA_RELEASE_URL = "http://sccn.ucsd.edu/pub/mica_release.zip"
+
+
+ def _prepare_cache_dir(root: Path, dataset_name: str) -> Path:
+     """Create a writable cache directory, falling back when needed."""
+     cache_dir = root / dataset_name
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     if not os.access(cache_dir, os.W_OK):
+         raise PermissionError(f"Cache directory is not writable: {cache_dir}")
+     return cache_dir
+
+
+ def fetch_planck_temperature_map(filename: str) -> Path:  # pragma: no cover
+     """Download one public Planck PR3 map and return the local path."""
+     import pooch
+
+     cache_root = Path(os.environ.get("AMICA_PLANCK_CACHE", CACHE_DIR))
+     fallback_cache_root = Path(tempfile.gettempdir()) / "amica-python"
+     url = f"{PLANCK_BASE_URL}{filename}"
+
+     try:
+         cache_dir = _prepare_cache_dir(cache_root, "planck_pr3")
+     except PermissionError:
+         cache_dir = _prepare_cache_dir(fallback_cache_root, "planck_pr3")
+
+     try:
+         return Path(
+             pooch.retrieve(
+                 url=url,
+                 known_hash=None,
+                 path=cache_dir,
+                 fname=filename,
+                 progressbar=True,
+             )
+         )
+     except Exception as exc:  # pragma: no cover - depends on network/data host
+         raise RuntimeError(
+             f"Could not download the Planck map '{filename}' from {url}."
+         ) from exc
+
+
+ def fetch_planck_temperature_maps(
+     frequencies_ghz: tuple[int, ...] | None = None,
+ ) -> dict[int, Path]:  # pragma: no cover
+     """Download selected Planck PR3 temperature maps.
+
+     Parameters
+     ----------
+     frequencies_ghz : tuple of int | None
+         Requested Planck channel frequencies in GHz. If ``None``, downloads all
+         channels listed in ``PLANCK_MAP_FILENAMES``.
+
+     Returns
+     -------
+     dict of int to pathlib.Path
+         Mapping from channel frequency in GHz to the local cached file path.
+     """
+     if frequencies_ghz is None:
+         frequencies_ghz = tuple(PLANCK_MAP_FILENAMES)
+
+     missing = sorted(set(frequencies_ghz) - set(PLANCK_MAP_FILENAMES))
+     if missing:
+         raise ValueError(
+             f"Unsupported Planck frequencies requested: {missing}. "
+             f"Available channels are {sorted(PLANCK_MAP_FILENAMES)}."
+         )
+
+     return {
+         frequency_ghz: fetch_planck_temperature_map(PLANCK_MAP_FILENAMES[frequency_ghz])
+         for frequency_ghz in frequencies_ghz
+     }
+
+
+ def fetch_mica_release(output_dir: Path | None = None) -> Path:  # pragma: no cover
+     """Download and extract the optional EEGLAB MICA benchmark dataset.
+
+     This dataset is large, so it is intentionally excluded from
+     :func:`fetch_datasets` and must be requested explicitly.
+
+     Parameters
+     ----------
+     output_dir : pathlib.Path | None
+         Directory where the extracted ``mica_release`` folder should live.
+         Defaults to ``~/amica_test_data``.
+
+     Returns
+     -------
+     pathlib.Path
+         Path to the extracted ``mica_release`` directory.
+     """
+     import pooch
+
+     cache_root = Path(output_dir).expanduser() if output_dir is not None else CACHE_DIR
+     cache_root.mkdir(parents=True, exist_ok=True)
+
+     zip_path = Path(
+         pooch.retrieve(
+             url=MICA_RELEASE_URL,
+             known_hash=None,
+             path=cache_root,
+             fname="mica_release.zip",
+             progressbar=True,
+         )
+     )
+
+     release_dir = cache_root / "mica_release"
+     if release_dir.exists():
+         return release_dir
+
+     import zipfile
+
+     with zipfile.ZipFile(zip_path) as zf:
+         zf.extractall(cache_root)
+
+     if not release_dir.exists():
+         raise RuntimeError(
+             f"Expected extracted benchmark directory at {release_dir}, "
+             f"but it was not created from {zip_path}."
+         )
+
+     return release_dir
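
Illustrative usage of the fetchers above (a minimal sketch, not part of the package diff; it assumes pooch is installed and that network access is available):

    from amica.utils.fetch import fetch_datasets, fetch_planck_temperature_maps

    # Default test data (EEGLAB sample, Fortran golden outputs, photos).
    # Files land under ~/amica_test_data; pooch skips any download whose
    # cached copy already matches the registered hash.
    data_dir = fetch_datasets()

    # The large Planck maps are opt-in; fetch only the LFI channels.
    maps = fetch_planck_temperature_maps(frequencies_ghz=(30, 44, 70))
    for freq, path in maps.items():
        print(freq, "GHz ->", path)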
amica/utils/fortran.py ADDED
@@ -0,0 +1,387 @@
+ """Utilities for interfacing with Fortran AMICA outputs."""
+ import inspect
+ from dataclasses import MISSING, asdict, dataclass, fields
+ from functools import lru_cache
+ from pathlib import Path
+
+ import numpy as np
+
+
+ def load_initial_weights(fortran_outdir, *, n_components, n_mixtures):
+     """Load w_init, sbeta_init, and mu_init binary files from a Fortran AMICA run."""
+     fortran_outdir = Path(fortran_outdir)
+     assert fortran_outdir.exists()
+
+     initial_weights = np.fromfile(
+         fortran_outdir / "Wtmp.bin",
+         dtype=np.float64
+     )
+     initial_weights = initial_weights.reshape((n_components, n_components), order="F")
+     initial_scales = np.fromfile(
+         fortran_outdir / "sbetatmp.bin",
+         dtype=np.float64
+     )
+     initial_scales = initial_scales.reshape((n_mixtures, n_components), order="F")
+     initial_scales = initial_scales.T  # Match our (n_components, n_mixtures) convention
+     initial_locations = np.fromfile(
+         fortran_outdir / "mutmp.bin",
+         dtype=np.float64
+     )
+     initial_locations = initial_locations.reshape((n_mixtures, n_components), order="F")
+     initial_locations = initial_locations.T  # Match our (n_components, n_mixtures) convention
+     return initial_weights, initial_scales, initial_locations
+
+
+ def load_fortran_results(fortran_outdir, *, n_components, n_mixtures, n_features=None):
+     """Load results from a completed Fortran AMICA run for comparison."""
+     fortran_outdir = Path(fortran_outdir)
+     assert fortran_outdir.exists()
+     assert fortran_outdir.is_dir()
+     if n_features is None:
+         n_features = n_components
+
+     # Channel means
+     mean_f = np.fromfile(f"{fortran_outdir}/mean")
+
+     # Sphering matrix
+     S_f = np.fromfile(f"{fortran_outdir}/S", dtype=np.float64)
+     S_f = S_f.reshape((n_features, n_features), order="F")
+
+     # Unmixing matrix
+     W_f = np.fromfile(f"{fortran_outdir}/W", dtype=np.float64)
+     W_f = W_f.reshape((n_components, n_components, 1), order="F")
+
+     # Mixing matrix
+     A_f = np.fromfile(f"{fortran_outdir}/A")
+     A_f = A_f.reshape((n_components, n_components), order="F")
+
+     # Bias term
+     c_f = np.fromfile(f"{fortran_outdir}/c")
+     c_f = c_f.reshape((n_components, 1), order="F")
+
+     # Log-likelihood
+     LL_f = np.fromfile(f"{fortran_outdir}/LL")
+
+     # Mixture model parameters
+     # Fortran order is n_mixtures x n_components. Ours is n_components x n_mixtures
+     alpha_f = np.fromfile(f"{fortran_outdir}/alpha")
+     alpha_f = alpha_f.reshape((n_mixtures, n_components), order="F").T
+     # Remember that alpha (and sbeta, mu etc) are (num_comps, num_mix) in Python
+
+     # Scale parameters
+     sbeta_f = np.fromfile(f"{fortran_outdir}/sbeta", dtype=np.float64)
+     sbeta_f = sbeta_f.reshape((n_mixtures, n_components), order="F").T
+
+     # Location parameters
+     mu_f = np.fromfile(f"{fortran_outdir}/mu", dtype=np.float64)
+     mu_f = mu_f.reshape((n_mixtures, n_components), order="F").T
+
+     rho_f = np.fromfile(f"{fortran_outdir}/rho", dtype=np.float64)
+     rho_f = rho_f.reshape((n_mixtures, n_components), order="F").T
+
+
+     comp_list_f = np.fromfile(f"{fortran_outdir}/comp_list", dtype=np.int32)
+     # NOTE: expected (num_comps, num_models) = (32, 1) here, but the file holds two columns
+     comp_list_f = np.reshape(comp_list_f, (n_components, 2), order="F")
+
+     gm_f = np.fromfile(f"{fortran_outdir}/gm")
+     return {
+         "W": W_f,
+         "S": S_f,
+         "sbeta": sbeta_f,
+         "rho": rho_f,
+         "mu": mu_f,
+         "mean": mean_f,
+         "gm": gm_f,
+         "comp_list": comp_list_f,
+         "c": c_f,
+         "alpha": alpha_f,
+         "A": A_f,
+         "LL": LL_f,
+     }
+
+
+ def write_data(data, filename):
+     """Save data to a binary file in Fortran-compatible format.
+
+     Parameters
+     ----------
+     data : array-like
+         The data of shape (n_samples, n_features) to save. Written to disk as
+         little-endian float32 in the layout Fortran AMICA expects.
+     filename : str or Path
+         The path to the output binary file.
+
+     Returns
+     -------
+     path : Path
+         The path to the saved file.
+     data : np.ndarray
+         The input data array, returned unchanged.
+     """
+     fpath = Path(filename).expanduser().resolve()
+     # The Fortran program expects (n_features, n_samples) in column-major order,
+     # so transpose first and then ravel in Fortran order before writing.
+     # (This is equivalent to writing ``data`` itself in C order; ``tofile``
+     # always writes the raveled buffer as-is.)
+     vector = data.T.astype("<f4").ravel(order="F")
+     vector.tofile(fpath)
+     return fpath, data
+
+
+ def load_data(filename, *, dtype=np.float32, shape=None):
+     """Load a binary data file that was saved for use with Fortran AMICA.
+
+     Parameters
+     ----------
+     filename : str or Path
+         The path to the input binary file.
+     dtype : data-type
+         The desired data-type for the loaded array. Default is np.float32.
+     shape : tuple of int
+         The shape of the array to load.
+
+     Returns
+     -------
+     data : np.ndarray
+         The Fortran-contiguous array that was loaded.
+
+     Notes
+     -----
+     Fortran stores arrays in column-major order, and the Fortran program
+     expects data in shape (n_features, n_samples). So when loading data
+     for use in Python, you should reshape to (n_features, n_samples) and
+     then transpose to (n_samples, n_features) to match the common Python
+     convention.
+
+     Examples
+     --------
+     >>> data = load_data('data.bin', shape=(64, 1000)).T
+     """
+     data = np.fromfile(filename, dtype=dtype)
+     if shape is not None:
+         data = data.reshape(shape, order="F")
+     return data
+
+
+ @lru_cache(maxsize=1)
+ def _get_fortran_param_defaults():
+     """Collect Fortran parameter defaults aligned with AMICA-Python."""
+     defaults = {}
+     for field in fields(FortranParams):
+         if field.default is not MISSING:
+             defaults[field.name] = field.default
+
+     # Public API defaults from fit_amica.
+     from amica.core import fit_amica
+
+     fit_defaults = {
+         name: param.default
+         for name, param in inspect.signature(fit_amica).parameters.items()
+         if param.default is not inspect.Signature.empty
+     }
+     fit_to_fortran = {
+         "max_iter": "max_iter",
+         "n_models": "num_models",
+         "n_mixtures": "num_mix_comps",
+         "lrate": "lrate",
+         "rholrate": "rholrate",
+         "pdftype": "pdftype",
+         "do_newton": "do_newton",
+         "newt_start": "newt_start",
+         "newtrate": "newtrate",
+         "newt_ramp": "newt_ramp",
+         "do_reject": "do_reject",
+     }
+     for fit_key, fortran_key in fit_to_fortran.items():
+         if fit_key in fit_defaults:
+             defaults[fortran_key] = fit_defaults[fit_key]
+
+     # Internal defaults from constants.py.
+     from amica import constants as c
+
+     const_to_fortran = {
+         "lratefact": "lratefact",
+         "rholratefact": "rholratefact",
+         "use_min_dll": "use_min_dll",
+         "use_grad_norm": "use_grad_norm",
+         "min_dll": "min_dll",
+         "min_nd": "min_grad_norm",
+         "do_opt_block": "do_opt_block",
+         "mineig": "mineig",
+         "minlrate": "minlrate",
+         "rho0": "rho0",
+         "minrho": "minrho",
+         "maxrho": "maxrho",
+         "invsigmax": "invsigmax",
+         "invsigmin": "invsigmin",
+         "dorho": "do_rho",
+         "doscaling": "doscaling",
+         "maxdecs": "max_decs",
+         "share_comps": "share_comps",
+         "share_start": "share_start",
+         "share_iter": "share_iter",
+         "outstep": "writestep",
+         "fix_init": "fix_init",
+     }
+     for const_key, fortran_key in const_to_fortran.items():
+         if hasattr(c, const_key):
+             defaults[fortran_key] = getattr(c, const_key)
+
+     return defaults
+
+
+ def write_param_file(fpath, data, **kwargs):
+     """Write a Fortran AMICA parameter file.
+
+     Parameters
+     ----------
+     fpath : str or Path
+         The path to the output parameter file.
+     data : np.ndarray
+         The data array to write to the file.
+     **kwargs : dict
+         Additional parameters to write to the file. ``files`` and ``outdir`` are
+         required and should be passed via kwargs.
+
+     Returns
+     -------
+     tuple of (Path, FortranParams)
+         The path to the saved parameter file and the parameters written to it.
+     """
+     fpath = Path(fpath).expanduser().resolve()
+     missing = [key for key in ("files", "outdir") if key not in kwargs]
+     if missing:
+         missing_joined = ", ".join(missing)
+         raise ValueError(
+             f"Missing required parameter(s): {missing_joined}. Pass them via kwargs."
+         )
+
+     merged = {**_get_fortran_param_defaults(), **kwargs}
+     merged.setdefault("data_dim", data.shape[1])
+     merged.setdefault("field_dim", data.shape[0])
+     merged.setdefault("block_size", data.shape[0])
+     merged.setdefault("pcakeep", data.shape[1])
+
+     params = FortranParams(**merged)
+     params_dict = params.to_param_dict()
+
+     fpath.write_text("".join(f"{key} {value}\n" for key, value in params_dict.items()))
+     return fpath, params
+
+
+ @dataclass
+ class FortranParams:
+     """Dataclass to hold Fortran AMICA parameters."""
+
+     # Required parameters
+     files: str | Path
+     outdir: str | Path
+     # Data Shape
+     block_size: int
+     data_dim: int  # n_features
+     field_dim: int  # n_samples
+     max_iter: int = 500
+     blk_min: int | None = None
+     blk_step: int | None = None
+     blk_max: int | None = None
+     # Whitening
+     do_mean: int = 1
+     do_sphere: int = 1
+     doPCA: int = 1
+     pcakeep: int | None = None
+     pcadb: float = 30.000000
+     # You'll probably never need to change these...
+     # Main Model Params
+     num_models: int = 1
+     max_threads: int = 1  # Single-threaded (aids debugging)
+     # Newton
+     do_newton: int = 1
+     newt_start: int = 50
+     newt_ramp: int = 10
+     newtrate: float = 1.000000
+     # Learning Rates
+     lrate: float = 0.050000
+     rholrate: float = 0.050000
+     lratefact: float = 0.500000
+     rholratefact: float = 0.500000
+     # Convergence
+     use_min_dll: int | bool = 1
+     min_dll: float = 1.000000e-09
+     use_grad_norm: int = 1
+     min_grad_norm: float = 1.000000e-07
+     # Misc.
+     do_opt_block: int | bool = 0
+     num_mix_comps: int = 3
+     pdftype: int = 0
+     num_samples: int = 1
+     field_blocksize: int = 1
+     do_history: int = 0
+     histstep: int = 10
+     share_comps: int = 0
+     share_start: int = 100
+     comp_thresh: float = 0.990000
+     share_iter: int = 100
+     minlrate: float = 1.000000e-08
+     mineig: float = 1.000000e-12
+     rho0: float = 1.500000
+     minrho: float = 1.000000
+     maxrho: float = 2.000000
+     kurt_start: int = 3
+     num_kurt: int = 5
+     kurt_int: int = 1
+     # Rejection
+     do_reject: int = 0
+     numrej: int = 3
+     rejsig: float = 3.000000
+     rejstart: int = 2
+     rejint: int = 3
+     # Saving
+     writestep: int = 1  # Write every iteration (aids debugging)
+     write_nd: int = 0
+     write_LLt: int = 1
+     decwindow: int = 1
+     max_decs: int = 3
+     fix_init: int = 0
+     update_A: int = 1
+     update_c: int = 1
+     update_gm: int = 1
+     update_alpha: int = 1
+     update_mu: int = 1
+     update_beta: int = 1
+     invsigmax: float = 100.000000
+     invsigmin: float = 1.0e-08
+     do_rho: int = 1
+     # Debugging
+     load_rej: int = 0
+     load_W: int = 0
+     load_c: int = 0
+     load_gm: int = 0
+     load_alpha: int = 0
+     load_mu: int = 0
+     load_beta: int = 0
+     load_rho: int = 0
+     load_comp_list: int = 0
+     byte_size: int = 4
+     doscaling: int = 1
+     scalestep: int = 1
+
+     def __post_init__(self):
+         """Fill in derived block-size/PCA defaults and normalize booleans to ints."""
+         if self.blk_min is None:
+             self.blk_min = self.block_size // 4
+         if self.blk_step is None:
+             self.blk_step = self.block_size // 4
+         if self.blk_max is None:
+             self.blk_max = self.block_size
+         if self.pcakeep is None:
+             self.pcakeep = self.data_dim
+
+         # Convert bools to int
+         for field in fields(self):
+             if isinstance(getattr(self, field.name), bool):
+                 setattr(self, field.name, int(getattr(self, field.name)))
+
+     def to_param_dict(self):
+         """Convert to a dict suitable for writing to a Fortran AMICA parameter file."""
+         return asdict(self)
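
A minimal round-trip sketch for the binary helpers above (illustrative, not part of the package diff; it assumes NumPy is installed and, for write_param_file, that amica.core and amica.constants are importable):

    import numpy as np
    from amica.utils.fortran import load_data, write_data, write_param_file

    rng = np.random.default_rng(0)
    X = rng.standard_normal((1000, 32))  # (n_samples, n_features)

    # write_data stores X in the (n_features, n_samples) column-major
    # layout the Fortran program expects.
    fpath, _ = write_data(X, "data.bin")

    # Reshape to (n_features, n_samples), then transpose back to the
    # Python (n_samples, n_features) convention.
    X_back = load_data(fpath, shape=(32, 1000)).T
    np.testing.assert_allclose(X_back, X.astype(np.float32))

    # Parameter files pair the data with a FortranParams instance;
    # ``files`` and ``outdir`` must be supplied explicitly.
    param_path, params = write_param_file(
        "amica.param", X, files=str(fpath), outdir="amica_out/"
    )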
amica/utils/imports.py ADDED
@@ -0,0 +1,46 @@
+ """Soft import helper."""
+ from importlib import import_module
+ from importlib.util import find_spec
+ from types import ModuleType
+
+ # A mapping from import name to package name (on PyPI) when the package name
+ # is different.
+ _INSTALL_MAPPING: dict[str, str] = {
+     "codespell_lib": "codespell",
+     "cv2": "opencv-python",
+     "parallel": "pyparallel",
+     "pytest_cov": "pytest-cov",
+     "serial": "pyserial",
+     "sklearn": "scikit-learn",
+     "sksparse": "scikit-sparse",
+ }
+
+ def import_optional_dependency(
+     name: str,
+     extra: str = "",
+ ) -> ModuleType:
+     """Import an optional dependency.
+
+     If the dependency is missing, an ImportError with an informative message
+     is raised.
+
+     Parameters
+     ----------
+     name : str
+         The module name.
+     extra : str
+         Additional text to include in the ImportError message.
+
+     Returns
+     -------
+     module : ModuleType
+         The imported module, when found.
+     """
+     package_name = _INSTALL_MAPPING.get(name)
+     install_name = package_name if package_name is not None else name
+     if find_spec(name) is None:
+         raise ImportError(
+             f"Missing optional dependency '{install_name}'. {extra} Use pip or "
+             f"conda to install {install_name}."
+         )
+     return import_module(name)
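
Illustrative use of the soft-import helper (a sketch, not part of the package diff; the extra message text is hypothetical):

    from amica.utils.imports import import_optional_dependency

    # Raises ImportError with an install hint if scikit-learn is absent;
    # note the mapping from import name "sklearn" to install name "scikit-learn".
    sklearn = import_optional_dependency(
        "sklearn", extra="It is required for this example."
    )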