bidsreader 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bidsreader/__init__.py +15 -0
- bidsreader/_errorwrap.py +50 -0
- bidsreader/basereader.py +208 -0
- bidsreader/cmlbidsreader.py +269 -0
- bidsreader/convert.py +57 -0
- bidsreader/exc.py +23 -0
- bidsreader/filtering.py +178 -0
- bidsreader/helpers.py +148 -0
- bidsreader/units.py +287 -0
- bidsreader-0.1.0.dist-info/METADATA +494 -0
- bidsreader-0.1.0.dist-info/RECORD +14 -0
- bidsreader-0.1.0.dist-info/WHEEL +5 -0
- bidsreader-0.1.0.dist-info/licenses/LICENSE +21 -0
- bidsreader-0.1.0.dist-info/top_level.txt +1 -0
bidsreader/filtering.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import mne
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from typing import Iterable, Optional, Dict
|
|
5
|
+
from ._errorwrap import public_api
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _label_has_trial_type(label: str, trial_types: list[str]) -> bool:
|
|
9
|
+
# exact token match within merged labels like "WORD/STIM"
|
|
10
|
+
tokens = label.split("/")
|
|
11
|
+
return any(t in tokens for t in trial_types)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _ensure_list(trial_types: Iterable[str] | str) -> list[str]:
|
|
15
|
+
return [trial_types] if isinstance(trial_types, str) else list(trial_types)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@public_api
|
|
19
|
+
def filter_events_df_by_trial_types(
|
|
20
|
+
events_df: pd.DataFrame,
|
|
21
|
+
trial_types: Iterable[str] | str,
|
|
22
|
+
) -> tuple[pd.DataFrame, np.ndarray]:
|
|
23
|
+
tt = _ensure_list(trial_types)
|
|
24
|
+
|
|
25
|
+
mask = events_df["trial_type"].isin(tt).to_numpy()
|
|
26
|
+
filtered_df = events_df.loc[mask].copy()
|
|
27
|
+
|
|
28
|
+
# integer positions (0..n-1) into the *original* events_df rows
|
|
29
|
+
df_idx = np.flatnonzero(mask)
|
|
30
|
+
|
|
31
|
+
return filtered_df, df_idx
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@public_api
|
|
35
|
+
def filter_raw_events_by_trial_types(
|
|
36
|
+
raw: mne.io.BaseRaw,
|
|
37
|
+
trial_types: Iterable[str] | str,
|
|
38
|
+
) -> tuple[np.ndarray, Dict[str, int], np.ndarray]:
|
|
39
|
+
tt = _ensure_list(trial_types)
|
|
40
|
+
|
|
41
|
+
events_raw, event_id = mne.events_from_annotations(raw)
|
|
42
|
+
|
|
43
|
+
filtered_event_id = {
|
|
44
|
+
k: v for k, v in event_id.items()
|
|
45
|
+
if _label_has_trial_type(k, tt)
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
if filtered_event_id:
|
|
49
|
+
codes = np.fromiter(filtered_event_id.values(), dtype=int)
|
|
50
|
+
mask = np.isin(events_raw[:, 2], codes)
|
|
51
|
+
filtered_events = events_raw[mask]
|
|
52
|
+
raw_idx = np.flatnonzero(mask) # indices into events_raw
|
|
53
|
+
else:
|
|
54
|
+
filtered_events = events_raw[:0].copy()
|
|
55
|
+
filtered_event_id = {}
|
|
56
|
+
raw_idx = np.array([], dtype=int)
|
|
57
|
+
|
|
58
|
+
return filtered_events, filtered_event_id, raw_idx
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@public_api
|
|
62
|
+
def filter_epochs_by_trial_types(
|
|
63
|
+
epochs: mne.Epochs,
|
|
64
|
+
trial_types: Iterable[str] | str,
|
|
65
|
+
) -> tuple[mne.Epochs, Dict[str, int], np.ndarray]:
|
|
66
|
+
tt = _ensure_list(trial_types)
|
|
67
|
+
|
|
68
|
+
keys = [
|
|
69
|
+
k for k in epochs.event_id.keys()
|
|
70
|
+
if _label_has_trial_type(k, tt)
|
|
71
|
+
]
|
|
72
|
+
filtered_event_id = {k: epochs.event_id[k] for k in keys}
|
|
73
|
+
|
|
74
|
+
if keys:
|
|
75
|
+
filtered_epochs = epochs[keys]
|
|
76
|
+
codes = np.fromiter(filtered_event_id.values(), dtype=int)
|
|
77
|
+
mask = np.isin(epochs.events[:, 2], codes)
|
|
78
|
+
ep_idx = np.flatnonzero(mask) # indices into original epochs
|
|
79
|
+
else:
|
|
80
|
+
filtered_epochs = epochs.copy()[[]]
|
|
81
|
+
ep_idx = np.array([], dtype=int)
|
|
82
|
+
|
|
83
|
+
return filtered_epochs, filtered_event_id, ep_idx
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@public_api
|
|
87
|
+
def filter_by_trial_types(
|
|
88
|
+
trial_types: Iterable[str] | str,
|
|
89
|
+
*,
|
|
90
|
+
events_df: Optional[pd.DataFrame] = None,
|
|
91
|
+
raw: Optional[mne.io.BaseRaw] = None,
|
|
92
|
+
epochs: Optional[mne.Epochs] = None,
|
|
93
|
+
) -> tuple[
|
|
94
|
+
Optional[pd.DataFrame],
|
|
95
|
+
Optional[np.ndarray], # filtered_events (from raw)
|
|
96
|
+
Optional[mne.Epochs],
|
|
97
|
+
Dict[str, int],
|
|
98
|
+
np.ndarray, # filtered_event_idx (0..n-1)
|
|
99
|
+
]:
|
|
100
|
+
tt = _ensure_list(trial_types)
|
|
101
|
+
|
|
102
|
+
filtered_df: Optional[pd.DataFrame] = None
|
|
103
|
+
filtered_events: Optional[np.ndarray] = None
|
|
104
|
+
filtered_epochs: Optional[mne.Epochs] = None
|
|
105
|
+
|
|
106
|
+
df_idx: Optional[np.ndarray] = None
|
|
107
|
+
raw_idx: Optional[np.ndarray] = None
|
|
108
|
+
ep_idx: Optional[np.ndarray] = None
|
|
109
|
+
|
|
110
|
+
event_id_raw: Optional[Dict[str, int]] = None
|
|
111
|
+
event_id_epochs: Optional[Dict[str, int]] = None
|
|
112
|
+
|
|
113
|
+
n_df = None
|
|
114
|
+
n_raw = None
|
|
115
|
+
n_ep = None
|
|
116
|
+
|
|
117
|
+
# ---- DF ----
|
|
118
|
+
if events_df is not None:
|
|
119
|
+
filtered_df, df_idx = filter_events_df_by_trial_types(events_df, tt)
|
|
120
|
+
n_df = len(filtered_df)
|
|
121
|
+
|
|
122
|
+
# ---- RAW ----
|
|
123
|
+
raw_onsets = None
|
|
124
|
+
if raw is not None:
|
|
125
|
+
filtered_events, event_id_raw, raw_idx = filter_raw_events_by_trial_types(raw, tt)
|
|
126
|
+
n_raw = int(filtered_events.shape[0])
|
|
127
|
+
raw_onsets = filtered_events[:, 0].astype(int)
|
|
128
|
+
|
|
129
|
+
# ---- EPOCHS ----
|
|
130
|
+
ep_onsets = None
|
|
131
|
+
if epochs is not None:
|
|
132
|
+
filtered_epochs, event_id_epochs, ep_idx = filter_epochs_by_trial_types(epochs, tt)
|
|
133
|
+
n_ep = len(filtered_epochs)
|
|
134
|
+
|
|
135
|
+
if event_id_epochs:
|
|
136
|
+
codes = np.fromiter(event_id_epochs.values(), dtype=int)
|
|
137
|
+
mask = np.isin(epochs.events[:, 2], codes)
|
|
138
|
+
ep_onsets = epochs.events[mask, 0].astype(int)
|
|
139
|
+
else:
|
|
140
|
+
ep_onsets = np.array([], dtype=int)
|
|
141
|
+
|
|
142
|
+
# ---- Check event_id consistency (keys) ----
|
|
143
|
+
if event_id_raw is not None and event_id_epochs is not None:
|
|
144
|
+
if set(event_id_raw.keys()) != set(event_id_epochs.keys()):
|
|
145
|
+
raise ValueError(
|
|
146
|
+
"filtered_event_id key mismatch between raw and epochs.\n"
|
|
147
|
+
f"raw keys={sorted(event_id_raw.keys())}\n"
|
|
148
|
+
f"epochs keys={sorted(event_id_epochs.keys())}"
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# Strong alignment check: same event onsets
|
|
152
|
+
if raw_onsets is not None and ep_onsets is not None:
|
|
153
|
+
if raw_onsets.shape != ep_onsets.shape or not np.array_equal(raw_onsets, ep_onsets):
|
|
154
|
+
raise ValueError(
|
|
155
|
+
"raw/epochs trial alignment mismatch: filtered event sample onsets differ.\n"
|
|
156
|
+
f"n_raw={len(raw_onsets)} n_epochs={len(ep_onsets)}"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
filtered_event_id = event_id_raw # choose one
|
|
160
|
+
elif event_id_raw is not None:
|
|
161
|
+
filtered_event_id = event_id_raw
|
|
162
|
+
elif event_id_epochs is not None:
|
|
163
|
+
filtered_event_id = event_id_epochs
|
|
164
|
+
else:
|
|
165
|
+
filtered_event_id = {}
|
|
166
|
+
|
|
167
|
+
# ---- Trial count consistency (for whichever inputs are provided) ----
|
|
168
|
+
ns = [n for n in (n_df, n_raw, n_ep) if n is not None]
|
|
169
|
+
n = ns[0] if ns else 0
|
|
170
|
+
|
|
171
|
+
if any(x != n for x in ns):
|
|
172
|
+
raise ValueError(
|
|
173
|
+
"Trial count mismatch across provided inputs.\n"
|
|
174
|
+
f"n_df={n_df} n_raw={n_raw} n_epochs={n_ep}"
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
filtered_event_idx = np.arange(n, dtype=int)
|
|
178
|
+
return filtered_df, filtered_events, filtered_epochs, filtered_event_id, filtered_event_idx
|
bidsreader/helpers.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from typing import Iterable, Any, Tuple, Sequence, Optional, Dict
|
|
4
|
+
import re
|
|
5
|
+
from .exc import InvalidOptionError
|
|
6
|
+
|
|
7
|
+
def validate_option(name: str, value: Any, allowed: Iterable[Any]) -> Any:
|
|
8
|
+
if value is None:
|
|
9
|
+
return None
|
|
10
|
+
if value not in allowed:
|
|
11
|
+
raise InvalidOptionError(f"{name} must be one of: {allowed}. Got {value!r}")
|
|
12
|
+
return value
|
|
13
|
+
|
|
14
|
+
def space_from_coordsystem_fname(fname: str) -> Optional[str]:
|
|
15
|
+
if "_space-" not in fname:
|
|
16
|
+
return None
|
|
17
|
+
return fname.split("_space-")[1].split("_coordsystem.json")[0]
|
|
18
|
+
|
|
19
|
+
def add_prefix(value: Optional[str], prefix: str) -> Optional[str]:
|
|
20
|
+
if value is None:
|
|
21
|
+
return None
|
|
22
|
+
|
|
23
|
+
value = str(value)
|
|
24
|
+
|
|
25
|
+
if value.startswith(prefix):
|
|
26
|
+
return value
|
|
27
|
+
|
|
28
|
+
return f"{prefix}{value}"
|
|
29
|
+
|
|
30
|
+
def merge_duplicate_sample_events(evs: pd.DataFrame, sample_col: str = "sample") -> pd.DataFrame:
|
|
31
|
+
df = evs.copy()
|
|
32
|
+
|
|
33
|
+
# Ensure stable ordering so "first" is well-defined.
|
|
34
|
+
df["_orig_order"] = np.arange(len(df))
|
|
35
|
+
|
|
36
|
+
def first_non_nan(s: pd.Series):
|
|
37
|
+
s2 = s.dropna()
|
|
38
|
+
return s2.iloc[0] if len(s2) else np.nan
|
|
39
|
+
|
|
40
|
+
def merge_series(s: pd.Series):
|
|
41
|
+
# General "take the first non-NaN; if only one non-NaN, that's what it is" behavior
|
|
42
|
+
return first_non_nan(s)
|
|
43
|
+
|
|
44
|
+
def merge_trial_type(s: pd.Series):
|
|
45
|
+
vals = [v for v in s.tolist() if pd.notna(v)]
|
|
46
|
+
# preserve order but avoid duplicates like A/A
|
|
47
|
+
uniq = []
|
|
48
|
+
for v in vals:
|
|
49
|
+
if v not in uniq:
|
|
50
|
+
uniq.append(v)
|
|
51
|
+
if not uniq:
|
|
52
|
+
return np.nan
|
|
53
|
+
return "/".join(map(str, uniq))
|
|
54
|
+
|
|
55
|
+
merged_rows = []
|
|
56
|
+
for sample_val, g in df.sort_values("_orig_order").groupby(sample_col, sort=False):
|
|
57
|
+
out = {}
|
|
58
|
+
for col in df.columns:
|
|
59
|
+
if col in ("_orig_order",):
|
|
60
|
+
continue
|
|
61
|
+
if col == "trial_type":
|
|
62
|
+
out[col] = merge_trial_type(g[col])
|
|
63
|
+
else:
|
|
64
|
+
out[col] = merge_series(g[col])
|
|
65
|
+
merged_rows.append(out)
|
|
66
|
+
|
|
67
|
+
out_df = pd.DataFrame(merged_rows)
|
|
68
|
+
|
|
69
|
+
# If you want to preserve original column order (minus helper col)
|
|
70
|
+
out_df = out_df[[c for c in evs.columns if c in out_df.columns]]
|
|
71
|
+
|
|
72
|
+
return out_df
|
|
73
|
+
|
|
74
|
+
def find_coord_triplets(columns: Sequence[str]) -> Dict[str, Tuple[str, str, str]]:
|
|
75
|
+
cols = set(columns)
|
|
76
|
+
|
|
77
|
+
triplets = {}
|
|
78
|
+
|
|
79
|
+
if {"x", "y", "z"} <= cols:
|
|
80
|
+
triplets[""] = ("x", "y", "z")
|
|
81
|
+
|
|
82
|
+
prefixed = [c for c in cols if re.match(r"^.+\.(x|y|z)$", c)]
|
|
83
|
+
prefixes = set(c.rsplit(".", 1)[0] for c in prefixed)
|
|
84
|
+
|
|
85
|
+
for p in prefixes:
|
|
86
|
+
x, y, z = f"{p}.x", f"{p}.y", f"{p}.z"
|
|
87
|
+
if {x, y, z} <= cols:
|
|
88
|
+
triplets[p] = (x, y, z)
|
|
89
|
+
|
|
90
|
+
return triplets
|
|
91
|
+
|
|
92
|
+
def combine_bipolar_electrodes(
|
|
93
|
+
pairs_df: pd.DataFrame,
|
|
94
|
+
elec_df: pd.DataFrame,
|
|
95
|
+
pair_col: str = "name",
|
|
96
|
+
elec_name_col: str = "name",
|
|
97
|
+
region_cols: Sequence[str] = ("wb.region", "ind.region", "stein.region"),
|
|
98
|
+
) -> pd.DataFrame:
|
|
99
|
+
sep = "-"
|
|
100
|
+
out = pairs_df.copy()
|
|
101
|
+
|
|
102
|
+
# Split bipolar pair
|
|
103
|
+
ch = out[pair_col].astype(str).str.split(sep, n=1, expand=True)
|
|
104
|
+
out["ch1"] = ch[0].str.strip()
|
|
105
|
+
out["ch2"] = ch[1].str.strip()
|
|
106
|
+
|
|
107
|
+
# Detect all coordinate triplets present in electrodes df
|
|
108
|
+
coord_triplets = find_coord_triplets(elec_df.columns)
|
|
109
|
+
|
|
110
|
+
# Keep electrode name + region cols + all coordinate columns we found
|
|
111
|
+
coord_cols = [c for trip in coord_triplets.values() for c in trip]
|
|
112
|
+
keep_cols = [elec_name_col, *region_cols, *coord_cols]
|
|
113
|
+
keep_cols = [c for c in keep_cols if c in elec_df.columns] # safety
|
|
114
|
+
|
|
115
|
+
look = elec_df[keep_cols].copy()
|
|
116
|
+
|
|
117
|
+
# Merge ch1 metadata
|
|
118
|
+
look1 = look.add_suffix("_ch1").rename(columns={f"{elec_name_col}_ch1": "ch1"})
|
|
119
|
+
out = out.merge(look1, on="ch1", how="left")
|
|
120
|
+
|
|
121
|
+
# Merge ch2 metadata
|
|
122
|
+
look2 = look.add_suffix("_ch2").rename(columns={f"{elec_name_col}_ch2": "ch2"})
|
|
123
|
+
out = out.merge(look2, on="ch2", how="left")
|
|
124
|
+
|
|
125
|
+
# Region agreement
|
|
126
|
+
for rc in region_cols:
|
|
127
|
+
if f"{rc}_ch1" in out.columns and f"{rc}_ch2" in out.columns:
|
|
128
|
+
a = out[f"{rc}_ch1"]
|
|
129
|
+
b = out[f"{rc}_ch2"]
|
|
130
|
+
out[f"{rc}_pair"] = np.where(a.notna() & (a == b), a, np.nan)
|
|
131
|
+
|
|
132
|
+
# Midpoints for every detected coordinate triplet
|
|
133
|
+
for prefix, (xcol, ycol, zcol) in coord_triplets.items():
|
|
134
|
+
for col in (xcol, ycol, zcol):
|
|
135
|
+
a = out[f"{col}_ch1"]
|
|
136
|
+
b = out[f"{col}_ch2"]
|
|
137
|
+
mid_name = f"{col}_mid" # e.g., "x_mid" or "tal.x_mid"
|
|
138
|
+
out[mid_name] = np.where(a.notna() & b.notna(), (a + b) / 2.0, np.nan)
|
|
139
|
+
|
|
140
|
+
return out
|
|
141
|
+
|
|
142
|
+
def normalize_trial_types(trial_types: Iterable[str]) -> set[str]:
|
|
143
|
+
return {str(t) for t in trial_types}
|
|
144
|
+
|
|
145
|
+
def match_event_label(label: str, trial_types: list[str]) -> bool:
|
|
146
|
+
# exact token match within merged labels like "WORD/STIM"
|
|
147
|
+
tokens = label.split("/")
|
|
148
|
+
return any(t in tokens for t in trial_types)
|
bidsreader/units.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import mne
|
|
4
|
+
import numpy as np
|
|
5
|
+
from typing import Optional, Union, TYPE_CHECKING
|
|
6
|
+
from ._errorwrap import public_api
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from ptsa.data.timeseries import TimeSeries
|
|
10
|
+
|
|
11
|
+
# ---------- unit constants ----------
|
|
12
|
+
_EXP_TO_PREFIX = {
|
|
13
|
+
15: "P", 14: "e14_", 13: "e13_", 12: "T", 11: "e11_", 10: "e10_",
|
|
14
|
+
9: "G", 8: "e8_", 7: "e7_", 6: "M", 5: "e5_", 4: "e4_",
|
|
15
|
+
3: "k", 2: "h", 1: "da",
|
|
16
|
+
0: "",
|
|
17
|
+
-1: "d", -2: "c",
|
|
18
|
+
-3: "m", -4: "e-4_", -5: "e-5_", -6: "u", -7: "e-7_", -8: "e-8_",
|
|
19
|
+
-9: "n", -10: "e-10_", -11: "e-11_", -12: "p", -13: "e-13_", -14: "e-14_",
|
|
20
|
+
-15: "f",
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
_UNIT_EXPONENTS = {
|
|
24
|
+
# Volts
|
|
25
|
+
"PV": 15, "TV": 12, "GV": 9, "MV": 6, "kV": 3,
|
|
26
|
+
"hV": 2, "daV": 1,
|
|
27
|
+
"V": 0,
|
|
28
|
+
"dV": -1, "cV": -2,
|
|
29
|
+
"mV": -3, "uV": -6, "nV": -9, "pV": -12, "fV": -15,
|
|
30
|
+
# Tesla
|
|
31
|
+
"PT": 15, "TT": 12, "GT": 9, "MT": 6, "kT": 3,
|
|
32
|
+
"hT": 2, "daT": 1,
|
|
33
|
+
"T": 0,
|
|
34
|
+
"dT": -1, "cT": -2,
|
|
35
|
+
"mT": -3, "uT": -6, "nT": -9, "pT": -12, "fT": -15,
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
_FIFF_UNIT_TO_BASE = {107: "V", 201: "T", 0: None}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------- internal helpers ----------
|
|
42
|
+
|
|
43
|
+
def _normalize_unit(unit: str) -> str:
|
|
44
|
+
return (
|
|
45
|
+
unit
|
|
46
|
+
.replace("\u00b5", "u") # micro sign
|
|
47
|
+
.replace("\u03bc", "u") # greek mu
|
|
48
|
+
.replace("\u03a9", "Ohm") # greek omega
|
|
49
|
+
.replace("\u2126", "Ohm") # ohm sign
|
|
50
|
+
.replace("\u00b0", "deg") # degree sign
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _detect_unit_mne(inst: Union[mne.io.BaseRaw, mne.Epochs]) -> str:
|
|
55
|
+
"""Detect unit string from an MNE Raw or Epochs object."""
|
|
56
|
+
eeg_types = {"eeg", "seeg", "ecog", "ieeg", "dbs"}
|
|
57
|
+
|
|
58
|
+
for ch_info in inst.info["chs"]:
|
|
59
|
+
ch_kind = mne.channel_type(
|
|
60
|
+
inst.info, inst.ch_names.index(ch_info["ch_name"]),
|
|
61
|
+
)
|
|
62
|
+
if ch_kind not in eeg_types:
|
|
63
|
+
continue
|
|
64
|
+
|
|
65
|
+
fiff_unit = ch_info.get("unit", 0)
|
|
66
|
+
fiff_mul = ch_info.get("unit_mul", 0)
|
|
67
|
+
|
|
68
|
+
base = _FIFF_UNIT_TO_BASE.get(fiff_unit)
|
|
69
|
+
if base is None:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"Unknown FIFF unit code {fiff_unit} on channel "
|
|
72
|
+
f"'{ch_info['ch_name']}'. Pass current_unit= explicitly."
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
exp = fiff_mul
|
|
76
|
+
prefix = _EXP_TO_PREFIX.get(exp, "")
|
|
77
|
+
return f"{prefix}{base}"
|
|
78
|
+
|
|
79
|
+
raise ValueError(
|
|
80
|
+
"No EEG/iEEG/SEEG/ECoG channel found. Cannot detect unit."
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _detect_unit_ptsa(ts: TimeSeries) -> str:
|
|
85
|
+
"""Detect unit string from a PTSA TimeSeries."""
|
|
86
|
+
for key in ("units", "unit"):
|
|
87
|
+
val = ts.attrs.get(key)
|
|
88
|
+
if val is not None and str(val).strip():
|
|
89
|
+
unit_str = _normalize_unit(str(val).strip())
|
|
90
|
+
if unit_str in _UNIT_EXPONENTS:
|
|
91
|
+
return unit_str
|
|
92
|
+
raise ValueError(
|
|
93
|
+
f"TimeSeries has unit '{val}' which is not recognized. "
|
|
94
|
+
f"Known: {sorted(_UNIT_EXPONENTS.keys())}"
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
raise ValueError(
|
|
98
|
+
"TimeSeries has no 'units' or 'unit' attribute. "
|
|
99
|
+
"Pass current_unit= explicitly."
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _convert_mne(
|
|
104
|
+
inst: Union[mne.io.BaseRaw, mne.Epochs],
|
|
105
|
+
factor: float,
|
|
106
|
+
target_unit: str,
|
|
107
|
+
copy: bool,
|
|
108
|
+
) -> Union[mne.io.BaseRaw, mne.Epochs]:
|
|
109
|
+
"""Scale MNE data and update FIFF unit metadata."""
|
|
110
|
+
if copy:
|
|
111
|
+
inst = inst.copy()
|
|
112
|
+
|
|
113
|
+
inst.apply_function(lambda x: x * factor, picks="all", channel_wise=False)
|
|
114
|
+
|
|
115
|
+
base_char = target_unit[-1]
|
|
116
|
+
target_exp = _UNIT_EXPONENTS[target_unit]
|
|
117
|
+
fiff_unit_code = {"V": 107, "T": 201}.get(base_char, 0)
|
|
118
|
+
fiff_mul = max(-15, min(15, target_exp))
|
|
119
|
+
|
|
120
|
+
eeg_kinds = {2, 302, 802, 803}
|
|
121
|
+
for ch in inst.info["chs"]:
|
|
122
|
+
if ch.get("kind", 0) in eeg_kinds or ch.get("unit", 0) in (107, 201):
|
|
123
|
+
ch["unit"] = fiff_unit_code
|
|
124
|
+
ch["unit_mul"] = fiff_mul
|
|
125
|
+
|
|
126
|
+
return inst
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _convert_ptsa(
|
|
130
|
+
ts: TimeSeries,
|
|
131
|
+
factor: float,
|
|
132
|
+
target_unit: str,
|
|
133
|
+
copy: bool,
|
|
134
|
+
) -> TimeSeries:
|
|
135
|
+
"""Scale PTSA TimeSeries data and update attrs."""
|
|
136
|
+
if copy:
|
|
137
|
+
result = ts * factor
|
|
138
|
+
else:
|
|
139
|
+
ts.values[:] *= factor
|
|
140
|
+
result = ts
|
|
141
|
+
|
|
142
|
+
result.attrs["units"] = target_unit
|
|
143
|
+
result.attrs["unit"] = target_unit
|
|
144
|
+
|
|
145
|
+
return result
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _is_timeseries(obj) -> bool:
|
|
149
|
+
"""Check if obj is a PTSA TimeSeries without requiring PTSA at import time."""
|
|
150
|
+
try:
|
|
151
|
+
from ptsa.data.timeseries import TimeSeries
|
|
152
|
+
return isinstance(obj, TimeSeries)
|
|
153
|
+
except ImportError:
|
|
154
|
+
return False
|
|
155
|
+
|
|
156
|
+
# ---------- public API ----------
|
|
157
|
+
|
|
158
|
+
@public_api
|
|
159
|
+
def detect_unit(
|
|
160
|
+
data: Union[mne.io.BaseRaw, mne.Epochs, TimeSeries],
|
|
161
|
+
current_unit: Optional[str] = None,
|
|
162
|
+
) -> str:
|
|
163
|
+
"""Detect or validate the unit of EEG data.
|
|
164
|
+
|
|
165
|
+
Parameters
|
|
166
|
+
----------
|
|
167
|
+
data : mne.io.BaseRaw, mne.Epochs, or PTSA TimeSeries
|
|
168
|
+
The data object to inspect.
|
|
169
|
+
current_unit : str, optional
|
|
170
|
+
If provided, overrides auto-detection. Validated against
|
|
171
|
+
known units and returned directly.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
str
|
|
176
|
+
Unit string like "V", "mV", "uV", "nV", "T", etc.
|
|
177
|
+
|
|
178
|
+
Raises
|
|
179
|
+
------
|
|
180
|
+
ValueError
|
|
181
|
+
If unit cannot be detected and current_unit is not provided.
|
|
182
|
+
"""
|
|
183
|
+
if current_unit is not None:
|
|
184
|
+
normalized = _normalize_unit(current_unit)
|
|
185
|
+
if normalized not in _UNIT_EXPONENTS:
|
|
186
|
+
raise ValueError(
|
|
187
|
+
f"Unknown unit '{current_unit}'. "
|
|
188
|
+
f"Known: {sorted(_UNIT_EXPONENTS.keys())}"
|
|
189
|
+
)
|
|
190
|
+
return normalized
|
|
191
|
+
|
|
192
|
+
if isinstance(data, (mne.io.BaseRaw, mne.Epochs)):
|
|
193
|
+
return _detect_unit_mne(data)
|
|
194
|
+
|
|
195
|
+
if _is_timeseries(data):
|
|
196
|
+
return _detect_unit_ptsa(data)
|
|
197
|
+
|
|
198
|
+
raise TypeError(
|
|
199
|
+
f"Cannot detect unit from {type(data).__name__}. "
|
|
200
|
+
f"Expected mne.io.BaseRaw, mne.Epochs, or TimeSeries."
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@public_api
|
|
205
|
+
def get_scale_factor(from_unit: str, to_unit: str) -> float:
|
|
206
|
+
"""Compute multiplicative factor to convert between units.
|
|
207
|
+
|
|
208
|
+
Parameters
|
|
209
|
+
----------
|
|
210
|
+
from_unit : str
|
|
211
|
+
Current unit (e.g. "V").
|
|
212
|
+
to_unit : str
|
|
213
|
+
Target unit (e.g. "uV").
|
|
214
|
+
|
|
215
|
+
Returns
|
|
216
|
+
-------
|
|
217
|
+
float
|
|
218
|
+
Multiply data by this value to convert.
|
|
219
|
+
|
|
220
|
+
Examples
|
|
221
|
+
--------
|
|
222
|
+
>>> get_scale_factor("V", "uV")
|
|
223
|
+
1000000.0
|
|
224
|
+
>>> get_scale_factor("uV", "V")
|
|
225
|
+
1e-06
|
|
226
|
+
"""
|
|
227
|
+
from_u = _normalize_unit(from_unit)
|
|
228
|
+
to_u = _normalize_unit(to_unit)
|
|
229
|
+
|
|
230
|
+
if from_u not in _UNIT_EXPONENTS:
|
|
231
|
+
raise ValueError(f"Unknown source unit '{from_unit}'")
|
|
232
|
+
if to_u not in _UNIT_EXPONENTS:
|
|
233
|
+
raise ValueError(f"Unknown target unit '{to_unit}'")
|
|
234
|
+
|
|
235
|
+
from_base = from_u[-1]
|
|
236
|
+
to_base = to_u[-1]
|
|
237
|
+
if from_base != to_base:
|
|
238
|
+
raise ValueError(
|
|
239
|
+
f"Cannot convert between different base units: "
|
|
240
|
+
f"'{from_unit}' ({from_base}) -> '{to_unit}' ({to_base})"
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
from_exp = _UNIT_EXPONENTS[from_u]
|
|
244
|
+
to_exp = _UNIT_EXPONENTS[to_u]
|
|
245
|
+
return 10.0 ** (from_exp - to_exp)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@public_api
|
|
249
|
+
def convert_unit(
|
|
250
|
+
data: Union[mne.io.BaseRaw, mne.Epochs, TimeSeries],
|
|
251
|
+
target: str,
|
|
252
|
+
*,
|
|
253
|
+
current_unit: Optional[str] = None,
|
|
254
|
+
copy: bool = True,
|
|
255
|
+
) -> Union[mne.io.BaseRaw, mne.Epochs, TimeSeries]:
|
|
256
|
+
"""Convert EEG data to a target unit.
|
|
257
|
+
|
|
258
|
+
Parameters
|
|
259
|
+
----------
|
|
260
|
+
data : mne.io.BaseRaw, mne.Epochs, or PTSA TimeSeries
|
|
261
|
+
The data to convert.
|
|
262
|
+
target : str
|
|
263
|
+
Target unit string (e.g. "uV", "mV", "V").
|
|
264
|
+
current_unit : str, optional
|
|
265
|
+
Override auto-detection of the current unit. Required if
|
|
266
|
+
the data object doesn't store unit metadata.
|
|
267
|
+
copy : bool
|
|
268
|
+
If True (default), return a copy. If False, modify in place.
|
|
269
|
+
|
|
270
|
+
Returns
|
|
271
|
+
-------
|
|
272
|
+
Same type as input, with data scaled to the target unit.
|
|
273
|
+
"""
|
|
274
|
+
detected = detect_unit(data, current_unit=current_unit)
|
|
275
|
+
target_normalized = _normalize_unit(target)
|
|
276
|
+
factor = get_scale_factor(detected, target_normalized)
|
|
277
|
+
|
|
278
|
+
if factor == 1.0:
|
|
279
|
+
return data.copy() if copy else data
|
|
280
|
+
|
|
281
|
+
if isinstance(data, (mne.io.BaseRaw, mne.Epochs)):
|
|
282
|
+
return _convert_mne(data, factor, target_normalized, copy)
|
|
283
|
+
|
|
284
|
+
if _is_timeseries(data):
|
|
285
|
+
return _convert_ptsa(data, factor, target_normalized, copy)
|
|
286
|
+
|
|
287
|
+
raise TypeError(f"Cannot convert type {type(data).__name__}")
|