brainsets 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainsets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ __version__ = "0.1.0"
2
+
3
+ from .core import serialize_fn_map
brainsets/cli.py ADDED
@@ -0,0 +1,128 @@
1
+ import click
2
+ import json
3
+ from pathlib import Path
4
+ import subprocess
5
+
6
+
7
# Per-user JSON file holding the raw/processed data directories chosen with
# ``brainsets config``.
CONFIG_FILE = Path.home().joinpath(".brainsets_config.json")

# Dataset identifiers accepted by ``brainsets prepare``; the chosen one is
# passed to snakemake as the target to build.
# TODO: Implement a function to dynamically generate this list
DATASETS = [
    "perich_miller_population_2018",
    "pei_pandarinath_nlb_2021",
]
11
+
12
+
13
def load_config():
    """Load the brainsets CLI configuration from ``CONFIG_FILE``.

    Returns:
        dict: The stored configuration. The ``"raw_dir"`` and
        ``"processed_dir"`` keys are always present and default to ``None``
        when the file does not exist or does not define them, so callers can
        index them directly without risking a ``KeyError``. Any extra keys in
        the file are preserved.
    """
    # Start from the defaults and overlay whatever is on disk, so a config
    # file that lacks one of the keys (e.g. hand-edited) still yields a
    # complete configuration.
    config = {"raw_dir": None, "processed_dir": None}
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            config.update(json.load(f))
    return config
18
+
19
+
20
def save_config(config):
    """Persist *config* to ``CONFIG_FILE`` as JSON indented by two spaces.

    The whole file is rewritten on every call.
    """
    with open(CONFIG_FILE, mode="w") as out_file:
        json.dump(obj=config, fp=out_file, indent=2)
23
+
24
+
25
@click.group()
def cli():
    """Brainsets CLI tool."""
    # Intentionally empty: this group exists only so that the sub-commands
    # (prepare, list, config) can be registered on it. The docstring above is
    # shown by click as the group's --help text.
29
+
30
+
31
@cli.command()
@click.argument("dataset", type=click.Choice(DATASETS, case_sensitive=False))
@click.option("-c", "--cores", default=4, help="Number of cores to use")
def prepare(dataset, cores):
    """Download and process a specific dataset."""
    import sys

    click.echo(f"Preparing {dataset}...")

    # Both directories must be configured (via ``brainsets config``) because
    # they are forwarded to the snakemake workflow below. ``.get`` avoids a
    # KeyError on a config file that lacks either key.
    config = load_config()
    if not config.get("raw_dir") or not config.get("processed_dir"):
        click.echo(
            "Error: Please set raw and processed directories first using 'brainsets config'",
            err=True,
        )
        sys.exit(1)

    # Output is not captured, so snakemake's progress streams live to the
    # terminal.
    command = [
        "snakemake",
        "--config",
        f"raw_dir={config['raw_dir']}",
        f"processed_dir={config['processed_dir']}",
        f"-c{cores}",
        dataset,
    ]
    try:
        # check=True turns any non-zero exit status into CalledProcessError,
        # so reaching the line after this call means success.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        click.echo(f"Error: Command failed with return code {e.returncode}", err=True)
        sys.exit(1)
    except OSError as e:
        # e.g. the snakemake executable is not installed / not on PATH.
        click.echo(f"Error: {e}", err=True)
        sys.exit(1)

    click.echo(f"Successfully downloaded {dataset}")
70
+
71
+
72
@cli.command(name="list")
def list_datasets():
    """List available datasets."""
    # The Python function is deliberately not named ``list`` so it does not
    # shadow the builtin for the rest of this module; the explicit
    # ``name="list"`` keeps the sub-command invoked as ``brainsets list``.
    click.echo("Available datasets:")
    for dataset in DATASETS:
        click.echo(f"- {dataset}")
78
+
79
+
80
@cli.command()
@click.option(
    "--raw",
    prompt="Enter raw data directory",
    type=click.Path(file_okay=False, dir_okay=True),
    required=False,
)
@click.option(
    "--processed",
    prompt="Enter processed data directory",
    type=click.Path(file_okay=False, dir_okay=True),
    required=False,
)
def config(raw, processed):
    """Set raw and processed data directories."""
    import os

    # click already prompts for any option missing from the command line
    # (``prompt=`` above). This fallback only matters when the callback is
    # invoked directly without click's option processing.
    if raw is None:
        raw = click.prompt(
            "Enter raw data directory",
            type=click.Path(file_okay=False, dir_okay=True),
        )
    if processed is None:
        processed = click.prompt(
            "Enter processed data directory",
            type=click.Path(file_okay=False, dir_okay=True),
        )

    # Expand "~" before touching the filesystem: a path typed at the prompt
    # is not shell-expanded, so "~/data" would otherwise create a literal "~"
    # directory inside the current working directory. Paths are stored
    # absolute so the config does not depend on where the CLI was run from.
    raw = os.path.abspath(os.path.expanduser(raw))
    processed = os.path.abspath(os.path.expanduser(processed))

    # Create directories if they don't exist
    os.makedirs(raw, exist_ok=True)
    os.makedirs(processed, exist_ok=True)

    cfg = load_config()
    cfg["raw_dir"] = raw
    cfg["processed_dir"] = processed
    save_config(cfg)
    click.echo("Configuration updated successfully.")
    click.echo(f"Raw data directory: {raw}")
    click.echo(f"Processed data directory: {processed}")
125
+
126
+
127
if __name__ == "__main__":
    # Executing this module directly (e.g. ``python -m brainsets.cli``)
    # dispatches to the click command group.
    cli()
brainsets/core.py ADDED
@@ -0,0 +1,135 @@
1
+ from enum import Enum
2
+ import datetime
3
+
4
+
5
class NestedEnumType(type(Enum)):
    r"""Enum metaclass that lets an enum be nested under a member of another enum.

    Declaring ``class Child(StringIntEnum, parent=Parent.MEMBER)``:

    - stores ``Parent.MEMBER`` as ``Child._parent``,
    - stores ``Child`` on the parent member as ``Parent.MEMBER._parent_cls``,
    - exposes every child member as an attribute of the parent member
      (``Parent.MEMBER.CHILD_MEMBER``).

    Any other class keyword arguments are forwarded to the standard enum
    metaclass.
    """

    def __new__(cls, clsname, bases, clsdict, parent=None, **kwargs):
        new_cls = super().__new__(cls, clsname, bases, clsdict, **kwargs)
        new_cls._parent = parent

        if parent is not None:
            parent._parent_cls = new_cls
            for name, member in new_cls.__members__.items():
                setattr(parent, name, member)

        return new_cls

    def __contains__(cls, member):
        # Direct membership: one of this enum's own members.
        if isinstance(member, cls) and member._name_ in cls._member_map_:
            return True
        # Indirect membership: a member of a nested enum whose parent member
        # is (recursively) contained in ``cls``. Objects without ``_parent``
        # (strings, ints, members of plain Enums) are simply not contained,
        # instead of raising AttributeError.
        parent = getattr(member, "_parent", None)
        return parent is not None and parent in cls


class StringIntEnum(Enum, metaclass=NestedEnumType):
    r"""Base class for string-integer enums.

    This class extends Python's built-in Enum class to provide:
    - String representation via __str__
    - Integer representation via __int__
    - Case-insensitive string parsing via from_string()
    - Maximum value lookup via max_value()

    .. code-block:: python

        >>> class Color(StringIntEnum):
        ...     RED = 1
        ...     BLUE = 2
        >>> str(Color.RED)
        'RED'
        >>> int(Color.RED)
        1
        >>> Color.from_string("red")
        <Color.RED: 1>
        >>> Color.max_value()
        2
    """

    def __str__(self):
        # Members of a nested enum render with their full dotted path,
        # e.g. "PARENT.CHILD".
        if self._parent is not None:
            return f"{str(self._parent)}.{self.name}"
        else:
            return self.name

    def __int__(self):
        return self.value

    @classmethod
    def from_string(cls, string: str) -> "StringIntEnum":
        r"""Convert a string to an enum member. This method is case insensitive and
        will replace spaces with underscores.

        A dotted string such as ``"PARENT.CHILD"`` is resolved through nested
        enums: ``PARENT`` is looked up in this enum, then ``CHILD`` in the enum
        nested under that member.

        Args:
            string: The string to convert to an enum member.

        Raises:
            ValueError: If the normalized name is not a member of the enum, or
                if a dotted path descends into a member that has no nested enum.

        Examples:
            >>> from brainsets.taxonomy import Sex
            >>> Sex.from_string("Male")
            <Sex.MALE: 1>
            >>> Sex.from_string("M")
            <Sex.MALE: 1>
        """
        nested_string = string.split(".", maxsplit=1)
        if len(nested_string) > 1:
            parent = cls.from_string(nested_string[0])
            # Only members that had an enum nested under them carry
            # ``_parent_cls``; report anything else as a normal lookup failure.
            child_cls = getattr(parent, "_parent_cls", None)
            if child_cls is None:
                raise ValueError(
                    f"{parent} has no nested enum in {cls.__name__}, so "
                    f"{nested_string[1]!r} cannot be resolved."
                )
            return child_cls.from_string(nested_string[1])

        # normalize string by replacing spaces with underscores and converting
        # to upper case
        normalized_string = string.strip().upper().replace(" ", "_")
        # create a mapping of enum names to enum members
        mapping = {name.upper(): member for name, member in cls.__members__.items()}
        # try to match the string to an enum name
        if normalized_string in mapping:
            return mapping[normalized_string]
        # if there is no match raise an error
        raise ValueError(
            f"{normalized_string} does not exist in {cls.__name__}, "
            "consider adding it to the enum."
        )

    @classmethod
    def max_value(cls):
        r"""Return the maximum value in the enum class."""
        return max(cls.__members__.values(), key=lambda x: x.value).value
94
+
95
+
96
class Dictable:
    r"""Mixin that adds :meth:`to_dict` to dataclasses.

    Subclasses are expected to be dataclasses; calling :meth:`to_dict` on an
    instance that is not a dataclass raises ``TypeError``.
    """

    def to_dict(self):
        r"""Convert the dataclass instance to a dictionary.

        Nested dataclasses, lists, tuples and dicts are converted recursively,
        as done by :func:`dataclasses.asdict`.

        Returns:
            dict: A new dictionary mapping every field name of the dataclass to
            its value, in field-declaration order.

        .. code-block:: python

            >>> from dataclasses import dataclass
            >>> @dataclass
            ... class Person(Dictable):
            ...     name: str
            ...     age: int

            >>> p = Person("Alice", 30)
            >>> p.to_dict()
            {'name': 'Alice', 'age': 30}
        """
        from dataclasses import asdict

        # asdict already builds a fresh dict; no extra copy is needed.
        return asdict(self)  # type: ignore
120
+
121
+
122
def string_int_enum_serialize_fn(obj, serialize_fn_map=None):
    r"""Serialize a StringIntEnum member as its string form.

    ``serialize_fn_map`` is accepted so every serializer shares the same
    ``(obj, serialize_fn_map=None)`` signature; it is not used here.
    """
    as_text = str(obj)
    return as_text
125
+
126
+
127
def datetime_serialize_fn(obj, serialize_fn_map=None):
    r"""Serialize a ``datetime.datetime`` as text.

    Uses ``str(obj)``, which for a datetime is its ISO 8601 form with a space
    separator, e.g. ``"2020-01-02 03:04:05"``. ``serialize_fn_map`` is not used.
    """
    as_text = str(obj)
    return as_text
130
+
131
+
132
# Type -> function used to serialize instances of that type. Each function has
# the signature ``(obj, serialize_fn_map=None)`` and returns a string.
serialize_fn_map = {}
serialize_fn_map[StringIntEnum] = string_int_enum_serialize_fn
serialize_fn_map[datetime.datetime] = datetime_serialize_fn
@@ -0,0 +1,124 @@
1
+ import datetime
2
+ from typing import Dict, List, Tuple, Optional, Union
3
+
4
+ from pydantic.dataclasses import dataclass
5
+ import temporaldata
6
+
7
+ import brainsets
8
+ from brainsets.taxonomy import *
9
+ from brainsets.taxonomy.mice import *
10
+
11
+
12
@dataclass
class BrainsetDescription(temporaldata.Data):
    r"""A class for describing a brainset.

    ``dataclass`` here is pydantic's, so the field annotations below are also
    used to validate values when an instance is created.

    Parameters
    ----------
    id : str
        Unique identifier for the brainset
    origin_version : str
        Version identifier for the original data source
    derived_version : str
        Version identifier for the derived/processed data
    source : str
        Original data source (usually a URL, or a short description otherwise)
    description : str
        Text description of the brainset
    brainsets_version : str, optional
        Version of the brainsets package used; defaults to
        ``brainsets.__version__`` as evaluated when this module was imported
    temporaldata_version : str, optional
        Version of the temporaldata package used; defaults to
        ``temporaldata.__version__`` as evaluated when this module was imported
    """

    id: str
    origin_version: str
    derived_version: str
    source: str
    description: str
    # Defaults are captured once, at import time, from the installed packages.
    brainsets_version: str = brainsets.__version__
    temporaldata_version: str = temporaldata.__version__
41
+
42
+
43
@dataclass
class SubjectDescription(temporaldata.Data):
    r"""A class for describing a subject.

    ``dataclass`` here is pydantic's, so the field annotations below are also
    used to validate values when an instance is created.

    Parameters
    ----------
    id : str
        Unique identifier for the subject
    species : Species
        Species of the subject
    age : float, optional
        Age of the subject in days, defaults to 0.0
    sex : Sex, optional
        Sex of the subject, defaults to ``Sex.UNKNOWN``
    genotype : str, optional
        Genotype of the subject, defaults to "unknown"
    cre_line : Cre_line, optional
        Cre line of the subject, defaults to None
    """

    id: str
    species: Species
    # NOTE(review): 0.0 appears to double as "age unknown" — confirm.
    age: float = 0.0  # in days
    sex: Sex = Sex.UNKNOWN
    # Free-form string rather than an enum: the set of genotypes is not known
    # in advance.
    genotype: str = "unknown"
    cre_line: Optional[Cre_line] = None
69
+
70
+
71
@dataclass
class SessionDescription(temporaldata.Data):
    r"""A class for describing an experimental session.

    ``dataclass`` here is pydantic's, so the field annotations below are also
    used to validate values when an instance is created.

    Parameters
    ----------
    id : str
        Unique identifier for the session
    recording_date : datetime.datetime
        Date and time when the recording was made
    task : Task, optional
        Task performed during the session, defaults to None
    """

    id: str
    recording_date: datetime.datetime
    task: Optional[Task] = None
88
+
89
+
90
@dataclass
class DeviceDescription(temporaldata.Data):
    r"""A class for describing a recording device.

    ``dataclass`` here is pydantic's, so the field annotations below are also
    used to validate values when an instance is created.

    Parameters
    ----------
    id : str
        Unique identifier for the device
    recording_tech : RecordingTech or List[RecordingTech], optional
        Recording technology (or technologies) used, defaults to None
    processing : str, optional
        Processing applied to the recording, defaults to None
    chronic : bool, optional
        Whether the device was chronically implanted, defaults to False
    start_date : datetime.datetime, optional
        Date when device was implanted/first used, defaults to None
    end_date : datetime.datetime, optional
        Date when device was removed/last used, defaults to None
    imaging_depth : float, optional
        Depth of imaging in micrometers, defaults to None
    target_area : BrainRegion, optional
        Target brain region for recording, defaults to None
    """

    id: str
    # units: List[str]
    # areas: Union[List[StringIntEnum], List[Macaque]]
    # ``Optional`` makes the annotation admit the ``None`` default: pydantic v2
    # no longer infers Optional from a None default, so passing
    # ``recording_tech=None`` explicitly would otherwise fail validation.
    recording_tech: Optional[Union[RecordingTech, List[RecordingTech]]] = None
    processing: Optional[str] = None
    chronic: bool = False
    start_date: Optional[datetime.datetime] = None
    end_date: Optional[datetime.datetime] = None
    # Ophys
    imaging_depth: Optional[float] = None  # in um
    target_area: Optional[BrainRegion] = None
@@ -0,0 +1 @@
1
+ from .signal import downsample_wideband, extract_bands, cube_to_long
@@ -0,0 +1,169 @@
1
+ """Signal processing functions. Inspired by Stavisky et al. (2015).
2
+
3
+ https://dx.doi.org/10.1088/1741-2560/12/3/036009
4
+ """
5
+
6
+ from typing import List, Tuple
7
+
8
+ import numpy as np
9
+ import tqdm
10
+ from scipy import signal
11
+
12
+ from temporaldata import Data, IrregularTimeSeries, ArrayDict
13
+ from brainsets.taxonomy import RecordingTech
14
+
15
+
16
def downsample_wideband(
    wideband: np.ndarray,
    timestamps: np.ndarray,
    wideband_Fs: float,
    lfp_Fs: float = 1000,
    dec_factor: int = 4,
) -> Tuple[np.ndarray, np.ndarray]:
    """Downsample a wideband signal to the LFP sampling rate.

    The signal is first block-averaged over ``dec_factor`` consecutive samples,
    then zero-phase low-pass filtered (4th-order Butterworth, cutoff at
    ``0.333 * lfp_Fs``), and finally linearly interpolated onto a regular grid
    with spacing ``1 / lfp_Fs``.

    Args:
        wideband: Array of shape ``(n_samples, n_channels)``; time must be the
            first dimension.
        timestamps: Array of shape ``(n_samples,)`` with the sample times.
        wideband_Fs: Sampling rate of ``wideband`` in Hz.
        lfp_Fs: Target sampling rate in Hz.
        dec_factor: Number of consecutive samples averaged together before
            filtering (4 by default, the previously hard-coded value).

    Returns:
        ``(lfp, t_new)``: ``t_new`` is ``np.arange(first, last, 1 / lfp_Fs)``
        over the block timestamps, and ``lfp`` has shape
        ``(len(t_new), n_channels)``.

    Raises:
        ValueError: If ``dec_factor`` is smaller than 1.
    """
    assert wideband.shape[0] == timestamps.shape[0], "Time should be first dimension."
    if dec_factor < 1:
        raise ValueError(f"dec_factor must be >= 1, got {dec_factor}")

    # Block-average by dec_factor. Trailing samples that do not fill a whole
    # block are dropped first so the reshape is exact.
    remainder = wideband.shape[0] % dec_factor
    if remainder != 0:
        wideband = wideband[:-remainder, :]
        timestamps = timestamps[:-remainder]
    wideband = wideband.reshape(-1, dec_factor, wideband.shape[1]).mean(axis=1)
    # Each averaged block is stamped with the time of its first sample.
    timestamps = timestamps[::dec_factor]

    # Low-pass at one third of the target rate (333 Hz for the default
    # lfp_Fs=1000) to limit aliasing when resampling to lfp_Fs.
    nyq = 0.5 * wideband_Fs / dec_factor  # Nyquist frequency after averaging
    cutoff = 0.333 * lfp_Fs
    normal_cutoff = cutoff / nyq
    b, a = signal.butter(4, normal_cutoff, btype="low", analog=False, output="ba")

    # Interpolation to achieve the desired sampling rate
    t_new = np.arange(timestamps[0], timestamps[-1], 1 / lfp_Fs)
    lfp = np.zeros((len(t_new), wideband.shape[1]))
    for i in range(wideband.shape[1]):
        # One channel at a time to limit peak memory.
        broadband_low = signal.filtfilt(b, a, wideband[:, i], axis=0)
        lfp[:, i] = np.interp(t_new, timestamps, broadband_low)

    return lfp, t_new
50
+
51
+
52
def extract_bands(
    lfps: np.ndarray, ts: np.ndarray, Fs: float = 1000, notch: float = 60
) -> Tuple[np.ndarray, np.ndarray, List]:
    """Extract band-power features from LFP.

    We prefer to extract bands from the LFP upstream rather than downstream, because
    it can be difficult to estimate e.g. the phase of low-frequency LFPs from
    short segments.

    We use the proposed bands from Stavisky et al. (2015), but we use the MNE toolbox
    rather than straight scipy signal.

    The signal is first notch-filtered at ``notch`` and its multiples up to
    ``5 * notch``. For each of the five bands (delta, theta, alpha, beta,
    gamma) it is then band-pass filtered, squared, low-pass filtered at 18 Hz
    to smooth the squared signal into a power envelope, and resampled to
    50 Hz. The sixth feature, ``"lmp"``, is the 0.1-20 Hz component resampled
    to 50 Hz.

    Args:
        lfps: LFP array of shape ``(n_samples, n_channels)`` (time first).
        ts: Timestamps of shape ``(n_samples,)``.
        Fs: Sampling rate of ``lfps`` in Hz; must be a multiple of 50.
        notch: Line-noise frequency in Hz.

    Returns:
        ``(stacked, ts, band_names)``: ``stacked`` has shape
        ``(n_out, n_channels, 6)``; ``ts`` keeps one input timestamp per output
        sample (taken near the middle of each ``Fs / 50``-sample window);
        ``band_names`` labels the last axis of ``stacked``.

    Raises:
        ImportError: If MNE is not installed.
    """
    try:
        import mne
    except ImportError as err:
        raise ImportError(
            "This function requires the MNE library which you can install with "
            "`pip install mne`"
        ) from err

    target_Fs = 50
    assert (
        Fs % target_Fs == 0
    ), "Sampling rate must be a multiple of the target frequency"

    assert lfps.shape[0] == ts.shape[0], "Time should be first dimension."
    info = mne.create_info(
        ch_names=lfps.shape[1], sfreq=Fs, ch_types=["eeg"] * lfps.shape[1]
    )
    # MNE expects (n_channels, n_times).
    data = mne.io.RawArray(lfps.T, info)
    # Remove line noise at notch, 2*notch, ..., 5*notch.
    data = data.notch_filter(np.arange(notch, notch * 5 + 1, notch), n_jobs=4)

    filtered = []
    band_names = ["delta", "theta", "alpha", "beta", "gamma", "lmp"]
    bands = [(1, 4), (3, 10), (12, 23), (27, 38), (50, 300)]
    for band_low, band_hi in bands:
        # MNE's filter/apply_function/resample modify the Raw object in place,
        # so work on a copy to keep `data` intact for the next band.
        band = data.copy().filter(band_low, band_hi, fir_design="firwin", n_jobs=4)
        band = band.apply_function(lambda x: x**2, n_jobs=4)
        # Smooth the squared signal into a power envelope with a LOW-pass at
        # 18 Hz (l_freq=None, h_freq=18). Note that filter(18, None) would be
        # a high-pass and would discard the envelope itself.
        band = band.filter(None, 18, fir_design="firwin", n_jobs=4)
        band = band.resample(target_Fs, npad="auto", n_jobs=4)

        filtered.append(band.get_data().T)

    # "lmp": 0.1-20 Hz component of the signal.
    lmp = data.copy().filter(0.1, 20, fir_design="firwin", n_jobs=4)
    lmp = lmp.resample(target_Fs, npad="auto", n_jobs=4)
    filtered.append(lmp.get_data().T)

    # Keep one timestamp per output sample, offset by half a window.
    ts = ts[int(Fs / target_Fs / 2) :: int(Fs / target_Fs)]
    stacked = np.stack(filtered, axis=2)

    # Resampling can produce one sample more than the decimated timestamps.
    if stacked.shape[0] != len(ts):
        stacked = stacked[: len(ts), :, :]

    return stacked, ts, band_names
109
+
110
+
111
def cube_to_long(
    ts: np.ndarray, cube: np.ndarray, channel_prefix="chan"
) -> Tuple[List[IrregularTimeSeries], ArrayDict]:
    """Convert a cube of binned threshold-crossing counts to per-trial event trains.

    Args:
        ts: Bin timestamps, one per bin (``len(ts) == cube.shape[1]``).
        cube: Non-negative integer counts of shape
            ``(n_trials, n_bins, n_channels)``.
        channel_prefix: Prefix used to build channel names and unit ids.

    Returns:
        ``(trials, units)``: ``trials`` holds one ``IrregularTimeSeries`` per
        trial, where a bin with count N contributes N events at that bin's
        timestamp, sorted by time; ``units`` is an ``ArrayDict`` with one entry
        per channel, including its total count over all trials and bins.
    """
    # Check the rank first: the shape[1] check below needs a 3-D cube.
    assert cube.ndim == 3
    assert cube.shape[1] == len(ts)
    # First dim is batch (trial), second is time, third is channel.
    assert np.issubdtype(cube.dtype, np.integer)
    assert cube.min() >= 0

    n_bins, n_channels = cube.shape[1], cube.shape[2]
    # Timestamp and channel index of every (bin, channel) cell, flattened in
    # the same row-major order as each trial's counts.
    ts_grid = np.tile(np.reshape(ts, (-1, 1)), [1, n_channels]).ravel()
    channel_grid = np.tile(np.arange(n_channels), [n_bins, 1]).ravel()

    trials = []
    for b in tqdm.tqdm(range(cube.shape[0])):
        # This data is binned, so a cell with N spikes yields N identical
        # timestamps. np.repeat also handles all-zero trials by returning
        # empty arrays, where concatenating per-count slices would fail.
        # NOTE(review): whether IrregularTimeSeries accepts an empty series
        # with domain="auto" is up to temporaldata — confirm.
        bin_counts = cube[b].ravel().astype(np.intp, copy=False)
        ts_ = np.repeat(ts_grid, bin_counts)
        channels_ = np.repeat(channel_grid, bin_counts)

        # Stable sort so events within the same bin keep ascending channel
        # order, making the output deterministic.
        tidx = np.argsort(ts_, kind="stable")
        ts_ = ts_[tidx]
        channels_ = channels_[tidx]

        trials.append(
            IrregularTimeSeries(
                timestamps=ts_,
                unit_index=channels_,
                types=np.ones(len(ts_))
                * int(RecordingTech.UTAH_ARRAY_THRESHOLD_CROSSINGS),
                domain="auto",
            )
        )

    # Total count per channel, summed over trials and bins.
    total_counts = cube.sum(axis=0).sum(axis=0)
    units = ArrayDict(
        count=np.array(total_counts.astype(int)),
        channel_name=np.array(
            [f"{channel_prefix}{c:03}" for c in range(n_channels)]
        ),
        unit_number=np.zeros(n_channels),
        # NOTE(review): ``id`` is not zero-padded while ``channel_name`` is
        # ("chan1" vs "chan001"); kept as-is since stored data may key on it.
        id=np.array([f"{channel_prefix}{c}" for c in range(n_channels)]),
        channel_number=np.arange(n_channels),
        type=np.ones(n_channels) * int(RecordingTech.UTAH_ARRAY_THRESHOLD_CROSSINGS),
    )

    return trials, units
@@ -0,0 +1,17 @@
1
+ from .subject import (
2
+ Species,
3
+ Sex,
4
+ )
5
+
6
+ from .task import (
7
+ Task,
8
+ )
9
+
10
+ from .drifting_gratings import Orientation_8_Classes
11
+ from .macaque import Macaque
12
+ from .mice import Cre_line
13
+
14
+ from .recording_tech import (
15
+ RecordingTech,
16
+ Hemisphere,
17
+ )
@@ -0,0 +1,28 @@
1
# Stimulus value -> contiguous class index. Angle values look like degrees:
# 8 orientations 45 apart and 12 orientations 30 apart.
ORIENTATION_8_CLASSES_map = {45.0 * k: k for k in range(8)}
ORIENTATION_12_CLASSES_map = {30.0 * k: k for k in range(12)}

# Remaining grating parameters, each mapped to its index in the listed order.
TEMPORAL_FREQ_5_map = {f: k for k, f in enumerate((1.0, 2.0, 4.0, 8.0, 15.0))}
SPATIAL_FREQ_5_map = {f: k for k, f in enumerate((0.02, 0.04, 0.08, 0.16, 0.32))}
PHASE_4_map = {p: k for k, p in enumerate((0.0, 90.0, 180.0, 270.0))}
@@ -0,0 +1,12 @@
1
+ from brainsets.core import StringIntEnum
2
+
3
+
4
class Orientation_8_Classes(StringIntEnum):
    r"""Eight drifting-grating orientations as a class-index enum.

    Member names encode the angle as ``angle_<value>`` (presumably degrees,
    45 apart); member values are the contiguous class indices 0-7.
    """

    angle_0 = 0
    angle_45 = 1
    angle_90 = 2
    angle_135 = 3
    angle_180 = 4
    angle_225 = 5
    angle_270 = 6
    angle_315 = 7