py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,205 +1,251 @@
1
- from scipy import signal
2
- import numpy as np
3
- from typing import Iterable
4
- import warnings
5
-
6
- from py_neuromodulation import nm_features_abc
7
-
8
-
9
class CoherenceObject:
    """Coherence feature calculator for one pair of channels.

    The instance keeps the spectral configuration together with the spectra
    of the most recent ``get_coh`` call. ``get_coh`` writes every enabled
    feature into the dictionary handed in by the caller.
    """

    def __init__(
        self,
        sfreq,
        window,
        fbands,
        fband_names,
        ch_1_name,
        ch_2_name,
        ch_1_idx,
        ch_2_idx,
        coh,
        icoh,
        features_coh,
    ) -> None:
        # Spectral estimation configuration.
        self.sfreq, self.window = sfreq, window
        # Band edges (list of lists, e.g. [[10, 15], [15, 20]]) and the
        # index-aligned band labels used in the feature names.
        self.fbands, self.fband_names = fbands, fband_names
        # Channel labels (for feature names) and row indices (for the caller).
        self.ch_1, self.ch_2 = ch_1_name, ch_2_name
        self.ch_1_idx, self.ch_2_idx = ch_1_idx, ch_2_idx
        # Which measures / which per-band statistics are requested.
        self.coh, self.icoh = coh, icoh
        self.features_coh = features_coh

        # Filled in by get_coh.
        self.f = None
        self.Pxx = self.Pyy = self.Pxy = None
        self.coh_val = self.icoh_val = None

    def get_coh(self, features_compute, x, y):
        """Add the coherence features of ``x`` versus ``y`` to ``features_compute``.

        Parameters
        ----------
        features_compute : dict
            Mapping that receives one entry per enabled measure, statistic
            and frequency band; it is modified in place.
        x, y : array-like
            The two signals handed to ``signal.welch`` / ``signal.csd``.

        Returns
        -------
        dict
            The same ``features_compute`` object.
        """
        estimator_args = (self.sfreq, self.window)
        self.f, self.Pxx = signal.welch(x, *estimator_args, nperseg=128)
        _, self.Pyy = signal.welch(y, *estimator_args, nperseg=128)
        _, self.Pxy = signal.csd(x, y, *estimator_args, nperseg=128)

        if self.coh is True:
            self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy)
        if self.icoh is True:
            self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag

        candidates = (
            ("coh", self.coh, self.coh_val),
            ("icoh", self.icoh, self.icoh_val),
        )
        for coh_name, enabled, coh_val in candidates:
            if enabled is not True:
                continue

            name_head = [coh_name, self.ch_1, "to", self.ch_2]
            for band_idx, fband in enumerate(self.fbands):
                if self.features_coh["mean_fband"] is True:
                    in_band = (self.f > fband[0]) & (self.f < fband[1])
                    band_mean = np.mean(coh_val[in_band])
                    key = "_".join(
                        name_head + ["mean_fband", self.fband_names[band_idx]]
                    )
                    features_compute[key] = band_mean

                if self.features_coh["max_fband"] is True:
                    in_band = (self.f > fband[0]) & (self.f < fband[1])
                    band_max = np.max(coh_val[in_band])
                    key = "_".join(
                        name_head + ["max_fband", self.fband_names[band_idx]]
                    )
                    features_compute[key] = band_max

                if self.features_coh["max_allfbands"] is True:
                    # Frequency of the largest value over the whole spectrum;
                    # stored under every band label.
                    peak_freq = self.f[np.argmax(coh_val)]
                    key = "_".join(
                        name_head + ["max_allfbands", self.fband_names[band_idx]]
                    )
                    features_compute[key] = peak_freq

        return features_compute
110
-
111
-
112
class NM_Coherence(nm_features_abc.Feature):
    """Coherence features for the channel pairs listed in the settings.

    One ``CoherenceObject`` is created per entry of
    ``settings["coherence"]["channels"]``; ``calc_feature`` runs all of them
    on a data batch.
    """

    def __init__(
        self, settings: dict, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.s = settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.coherence_objects: list[CoherenceObject] = []

        coh_settings = self.s["coherence"]

        # The band selection is the same for every channel pair.
        fband_names = coh_settings["frequency_bands"]
        fband_specs = [
            self.s["frequency_ranges_hz"][band_name] for band_name in fband_names
        ]

        for ch_pair in coh_settings["channels"]:
            ch_1_name, ch_2_name = ch_pair[0], ch_pair[1]
            self.coherence_objects.append(
                CoherenceObject(
                    sfreq,
                    "hann",
                    fband_specs,
                    fband_names,
                    ch_1_name,
                    ch_2_name,
                    self._channel_index(ch_1_name),
                    self._channel_index(ch_2_name),
                    coh_settings["method"]["coh"],
                    coh_settings["method"]["icoh"],
                    coh_settings["features"],
                )
            )

    def _channel_index(self, ch_prefix: str) -> int:
        """Return the index of the first channel name starting with ``ch_prefix``."""
        # Configured names are matched by prefix against the actual names.
        matching = [ch for ch in self.ch_names if ch.startswith(ch_prefix)]
        return self.ch_names.index(matching[0])

    @staticmethod
    def test_settings(
        s: dict,
        ch_names: Iterable[str],
        sfreq: int | float,
    ):
        """Validate the coherence section of the settings dictionary.

        Parameters
        ----------
        s : dict
            Settings; reads ``s["coherence"]`` and ``s["frequency_ranges_hz"]``.
        ch_names : Iterable[str]
            Available channel names.
        sfreq : int | float
            Sampling frequency; band edges must lie below ``sfreq / 2``.

        Raises
        ------
        AssertionError
            If no band is selected, a configured channel matches no channel
            name, a band is not defined in ``s["frequency_ranges_hz"]``, or a
            band edge is not below the Nyquist frequency.
        """
        # Explicit raises instead of ``assert`` statements: the checks keep
        # their exception type but are not stripped under ``python -O``.
        coh_settings = s["coherence"]
        fband_names = coh_settings["frequency_bands"]
        ch_names = list(ch_names)  # iterated once per configured channel

        if not len(fband_names) > 0:
            raise AssertionError(
                "coherence frequency_bands list needs to specify at least one frequency band"
            )

        # Use the same prefix matching as __init__. The previous check
        # asserted on a generator object and therefore could never fail.
        channels_ok = all(
            any(ch.startswith(ch_coh) for ch in ch_names)
            for ch_pair in coh_settings["channels"]
            for ch_coh in ch_pair
        )
        if not channels_ok:
            raise AssertionError(
                "coherence selected channels don't match the ones in nm_channels; "
                f"ch_names: {ch_names} settings['coherence']['channels']: {coh_settings['channels']}"
            )

        if not all(f_band_coh in s["frequency_ranges_hz"] for f_band_coh in fband_names):
            raise AssertionError(
                "coherence selected frequency bands don't match the ones "
                "specified in s['frequency_ranges_hz']; "
                f"coherence frequency bands: {fband_names}; "
                f"specified frequency_ranges_hz: {s['frequency_ranges_hz']}"
            )

        nyquist = sfreq / 2
        if not all(
            s["frequency_ranges_hz"][fb][0] < nyquist
            and s["frequency_ranges_hz"][fb][1] < nyquist
            for fb in fband_names
        ):
            raise AssertionError(
                "the coherence frequency band ranges need to be smaller than the nyquist frequency; "
                f"got sfreq = {sfreq} and fband ranges {fband_names}"
            )

        if not any(coh_settings["method"].values()):
            warnings.warn(
                "feature coherence enabled, but no coherence['method'] selected"
            )

    def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
        """Run every configured channel pair on ``data``.

        ``data`` is indexed as ``data[channel_index, :]``. The features are
        added to ``features_compute``, which is also returned.
        """
        for coh_obj in self.coherence_objects:
            features_compute = coh_obj.get_coh(
                features_compute,
                data[coh_obj.ch_1_idx, :],
                data[coh_obj.ch_2_idx, :],
            )

        return features_compute
1
+ import numpy as np
2
+ from collections.abc import Iterable
3
+
4
+
5
+ from typing import TYPE_CHECKING, Annotated
6
+ from pydantic import Field, field_validator
7
+
8
+ from py_neuromodulation.nm_features import NMFeature
9
+ from py_neuromodulation.nm_types import BoolSelector, FrequencyRange, NMBaseModel
10
+ from py_neuromodulation import logger
11
+
12
+ if TYPE_CHECKING:
13
+ from py_neuromodulation.nm_settings import NMSettings
14
+
15
+
16
class CoherenceMethods(BoolSelector):
    """On/off switches for the coherence measures to compute."""

    # Measure stored under the "coh" feature-name prefix.
    coh: bool = True
    # Measure stored under the "icoh" feature-name prefix
    # (imaginary part of the spectral ratio).
    icoh: bool = True
19
+
20
+
21
class CoherenceFeatures(BoolSelector):
    """On/off switches for the statistics extracted from each coherence spectrum."""

    # Mean of the values inside each selected frequency band.
    mean_fband: bool = True
    # Maximum of the values inside each selected frequency band.
    max_fband: bool = True
    # Frequency at which the spectrum is largest over all frequencies.
    max_allfbands: bool = True
25
+
26
+
27
# One coherence channel pair: a list of exactly two channel names.
ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]


class CoherenceSettings(NMBaseModel):
    """Settings of the coherence feature.

    Attributes are documented inline; ``frequency_bands`` entries are
    normalised so that spaces become underscores.
    """

    # Which per-band statistics to extract.
    features: CoherenceFeatures = CoherenceFeatures()
    # Which coherence measures to compute.
    method: CoherenceMethods = CoherenceMethods()
    # Channel pairs to evaluate; each entry holds two channel names.
    channels: list[ListOfTwoStr] = []
    # Names of the frequency bands to evaluate (at least one).
    frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)

    @field_validator("frequency_bands")
    @classmethod  # explicit: the validator receives the class, not an instance
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Return the band names with every space replaced by an underscore."""
        return [f.replace(" ", "_") for f in frequency_bands]
39
+
40
+
41
class CoherenceObject:
    """Coherence feature calculator for one pair of channels.

    Stores the spectral configuration and the spectra of the latest
    ``get_coh`` call; ``get_coh`` writes the enabled features into the
    dictionary supplied by the caller.
    """

    def __init__(
        self,
        sfreq: float,
        window: str,
        fbands: "list[FrequencyRange]",
        fband_names: list[str],
        ch_1_name: str,
        ch_2_name: str,
        ch_1_idx: int,
        ch_2_idx: int,
        coh: bool,
        icoh: bool,
        features_coh: "CoherenceFeatures",
    ) -> None:
        # Spectral estimation configuration.
        self.sfreq, self.window = sfreq, window
        # Band edges and the index-aligned labels used in feature names.
        self.fbands, self.fband_names = fbands, fband_names
        # Channel labels (feature names) and row indices (used by the caller).
        self.ch_1, self.ch_2 = ch_1_name, ch_2_name
        self.ch_1_idx, self.ch_2_idx = ch_1_idx, ch_2_idx
        # Requested measures and per-band statistics.
        self.coh, self.icoh = coh, icoh
        self.features_coh = features_coh

        # Populated by get_coh.
        self.Pxx = self.Pyy = self.Pxy = None
        self.f = None
        self.coh_val = self.icoh_val = None

    def get_coh(self, feature_results, x, y):
        """Add the coherence features of ``x`` versus ``y`` to ``feature_results``.

        Parameters
        ----------
        feature_results : dict
            Receives one entry per enabled measure, statistic and frequency
            band; modified in place.
        x, y : array-like
            The two signals passed to ``welch`` / ``csd``.

        Returns
        -------
        dict
            The same ``feature_results`` object.
        """
        # Imported here rather than at module level, as in the original.
        from scipy.signal import welch, csd

        estimator_args = (self.sfreq, self.window)
        self.f, self.Pxx = welch(x, *estimator_args, nperseg=128)
        _, self.Pyy = welch(y, *estimator_args, nperseg=128)
        _, self.Pxy = csd(x, y, *estimator_args, nperseg=128)

        if self.coh:
            self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy)
        if self.icoh:
            self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag

        candidates = (
            ("coh", self.coh, self.coh_val),
            ("icoh", self.icoh, self.icoh_val),
        )
        for coh_name, enabled, coh_val in candidates:
            if not enabled:
                continue

            name_head = [coh_name, self.ch_1, "to", self.ch_2]
            for band_idx, fband in enumerate(self.fbands):
                if self.features_coh.mean_fband:
                    in_band = (self.f > fband[0]) & (self.f < fband[1])
                    band_mean = np.mean(coh_val[in_band])
                    key = "_".join(
                        name_head + ["mean_fband", self.fband_names[band_idx]]
                    )
                    feature_results[key] = band_mean

                if self.features_coh.max_fband:
                    in_band = (self.f > fband[0]) & (self.f < fband[1])
                    band_max = np.max(coh_val[in_band])
                    key = "_".join(
                        name_head + ["max_fband", self.fband_names[band_idx]]
                    )
                    feature_results[key] = band_max

                if self.features_coh.max_allfbands:
                    # Frequency of the largest value over the whole spectrum;
                    # stored under every band label.
                    peak_freq = self.f[np.argmax(coh_val)]
                    key = "_".join(
                        name_head + ["max_allfbands", self.fband_names[band_idx]]
                    )
                    feature_results[key] = peak_freq

        return feature_results
141
+
142
+
143
class NMCoherence(NMFeature):
    """Coherence features for the channel pairs listed in the settings.

    The settings are validated on construction and one ``CoherenceObject``
    is created per entry of ``settings.coherence.channels``.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: list[str], sfreq: float
    ) -> None:
        self.settings = settings.coherence
        self.frequency_ranges_hz = settings.frequency_ranges_hz
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.coherence_objects: list[CoherenceObject] = []

        self.test_settings(settings, ch_names, sfreq)

        # The band selection is identical for every channel pair.
        fband_names = self.settings.frequency_bands
        fband_specs = [
            self.frequency_ranges_hz[band_name] for band_name in fband_names
        ]

        for ch_pair in self.settings.channels:
            ch_1_name, ch_2_name = ch_pair[0], ch_pair[1]
            self.coherence_objects.append(
                CoherenceObject(
                    sfreq,
                    "hann",
                    fband_specs,
                    fband_names,
                    ch_1_name,
                    ch_2_name,
                    self._channel_index(ch_1_name),
                    self._channel_index(ch_2_name),
                    self.settings.method.coh,
                    self.settings.method.icoh,
                    self.settings.features,
                )
            )

    def _channel_index(self, ch_prefix: str) -> int:
        """Return the index of the first channel name starting with ``ch_prefix``."""
        # Configured names are matched by prefix against the actual names.
        matching = [ch for ch in self.ch_names if ch.startswith(ch_prefix)]
        return self.ch_names.index(matching[0])

    @staticmethod
    def test_settings(
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
    ):
        """Validate the coherence settings against the available channels.

        Parameters
        ----------
        settings : NMSettings
            Reads ``settings.coherence`` and ``settings.frequency_ranges_hz``.
        ch_names : Iterable[str]
            Available channel names.
        sfreq : float
            Sampling frequency; band edges must lie below ``sfreq / 2``.

        Raises
        ------
        RuntimeError
            If a configured channel name is a prefix of no channel name, or
            of more than one.
        AssertionError
            If a selected band is missing from ``settings.frequency_ranges_hz``
            or a band edge is not below the Nyquist frequency.
        """
        coh_settings = settings.coherence
        # Materialise once: the names are scanned for every configured channel.
        ch_names = list(ch_names)

        for ch_pair in coh_settings.channels:
            for ch_coh in ch_pair:
                n_matches = sum(ch.startswith(ch_coh) for ch in ch_names)
                if n_matches == 0:
                    raise RuntimeError(
                        f"Coherence selected channel {ch_coh} does not match any channel name: \n"
                        f"  - settings.coherence.channels: {coh_settings.channels}\n"
                        f"  - ch_names: {ch_names} \n"
                    )
                if n_matches > 1:
                    raise RuntimeError(
                        f"Coherence selected channel {ch_coh} is ambiguous and matches more than one channel name: \n"
                        f"  - settings.coherence.channels: {coh_settings.channels}\n"
                        f"  - ch_names: {ch_names} \n"
                    )

        # Explicit raises instead of ``assert`` statements: same exception
        # type, but the checks are not stripped under ``python -O``.
        if not all(
            f_band_coh in settings.frequency_ranges_hz
            for f_band_coh in coh_settings.frequency_bands
        ):
            raise AssertionError(
                "coherence selected frequency bands don't match the ones "
                "specified in settings.frequency_ranges_hz: \n"
                f"  - coherence frequency bands: {coh_settings.frequency_bands}\n"
                f"  - specified frequency_ranges_hz: {settings.frequency_ranges_hz}\n"
            )

        nyquist = sfreq / 2
        if not all(
            settings.frequency_ranges_hz[fb][0] < nyquist
            and settings.frequency_ranges_hz[fb][1] < nyquist
            for fb in coh_settings.frequency_bands
        ):
            raise AssertionError(
                "the coherence frequency band ranges need to be smaller than the Nyquist frequency; "
                f"got sfreq = {sfreq} and fband ranges {coh_settings.frequency_bands}"
            )

        if not coh_settings.method.get_enabled():
            # ``warn`` is a deprecated alias of ``warning`` on standard loggers.
            # NOTE(review): assumes the package logger is a logging.Logger - confirm.
            logger.warning(
                "feature coherence enabled, but no coherence['method'] selected"
            )

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the features of every configured channel pair.

        ``data`` is indexed as ``data[channel_index, :]``. Returns a new
        dictionary mapping feature names to values.
        """
        feature_results = {}

        for coh_obj in self.coherence_objects:
            feature_results = coh_obj.get_coh(
                feature_results,
                data[coh_obj.ch_1_idx, :],
                data[coh_obj.ch_2_idx, :],
            )

        return feature_results
@@ -0,0 +1,149 @@
1
+ import sqlite3
2
+ from pathlib import Path
3
+ import pandas as pd
4
+ from py_neuromodulation.nm_types import _PathLike
5
+ from py_neuromodulation.nm_IO import generate_unique_filename
6
+
7
+
8
class NMDatabase:
    """
    Class to create a SQLite database and insert feature dictionaries into it.

    Parameters
    ----------
    name : str
        Base name; the database file is ``<name>.db`` and the table is
        ``<name>_data``.
    out_dir : _PathLike
        The directory to save the database. Created if it does not exist.
    csv_path : _PathLike, optional
        The path to save the csv file. If not provided, it will be saved in
        the same folder as the database.
    """

    def __init__(
        self,
        name: str,
        out_dir: "_PathLike",
        csv_path: "_PathLike | None" = None,
    ):
        # Make sure out_dir exists
        Path(out_dir).mkdir(parents=True, exist_ok=True)

        self.db_path = Path(out_dir, f"{name}.db")

        self.table_name = f"{name}_data"  # change to param?
        self.table_created = False
        # Column order of the table, fixed by the first inserted dictionary.
        self._column_names: list[str] = []

        if self.db_path.exists():
            # Do not touch an existing database: pick a fresh file name and
            # use its stem for the default csv name as well.
            self.db_path = generate_unique_filename(self.db_path)
            name = self.db_path.stem

        if csv_path is None:
            self.csv_path = Path(out_dir, f"{name}.csv")
        else:
            self.csv_path = Path(csv_path)

        self.csv_path.parent.mkdir(parents=True, exist_ok=True)

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()

        # Database config and optimization, prioritize data integrity
        self.cursor.execute("PRAGMA journal_mode=WAL")  # Write-Ahead Logging mode
        self.cursor.execute("PRAGMA synchronous=FULL")  # Sync on every commit
        self.cursor.execute("PRAGMA temp_store=MEMORY")  # Store temp tables in memory
        self.cursor.execute(
            "PRAGMA wal_autocheckpoint = 1000"
        )  # WAL checkpoint every 1000 pages (default, 4MB, might change)
        self.cursor.execute(
            f"PRAGMA mmap_size = {2 * 1024 * 1024 * 1024}"
        )  # 2GB of memory mapped

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return ``identifier`` as a double-quoted SQL identifier."""
        # Embedded double quotes are escaped by doubling them.
        return '"' + str(identifier).replace('"', '""') + '"'

    def infer_type(self, value):
        """Infer the column type of a value to create the table schema.

        Parameters
        ----------
        value : int, float, str
            The value to infer the type.

        Returns
        -------
        str
            ``"REAL"`` for int/float, ``"TEXT"`` for str, ``"BLOB"`` otherwise.
        """

        if isinstance(value, (int, float)):
            return "REAL"
        elif isinstance(value, str):
            return "TEXT"
        else:
            return "BLOB"

    def create_table(self, feature_dict: dict):
        """
        Create a table in the database.

        Parameters
        ----------
        feature_dict : dict
            The dictionary with the feature names and values. The keys become
            the column names, the values determine the column types.
        """
        self._column_names = list(feature_dict)
        quoted_columns = [self._quote_identifier(c) for c in self._column_names]

        columns_schema = ", ".join(
            f"{quoted} {self.infer_type(value)}"
            for quoted, value in zip(quoted_columns, feature_dict.values())
        )

        self.cursor.execute(
            f"CREATE TABLE IF NOT EXISTS {self._quote_identifier(self.table_name)} ({columns_schema})"
        )

        # Column list and placeholders for the insert statement. Positional
        # placeholders are used because feature names are arbitrary strings
        # ('-', spaces, ...) and cannot serve as ``:name`` parameters; the
        # values are bound by name in insert_data, so dictionary key order
        # still does not matter.
        self.columns: str = ", ".join(quoted_columns)
        self.placeholders = ", ".join("?" for _ in quoted_columns)

    def insert_data(self, feature_dict: dict):
        """
        Insert one row into the database.

        The table is created from the first dictionary that is inserted.

        Parameters
        ----------
        feature_dict : dict
            The dictionary with the feature names and values. Keys that are
            not table columns are ignored.

        Raises
        ------
        sqlite3.ProgrammingError
            If ``feature_dict`` lacks a value for one of the table columns.
        """

        if not self.table_created:
            self.create_table(feature_dict)
            self.table_created = True

        try:
            values = [feature_dict[column] for column in self._column_names]
        except KeyError as err:
            # Same exception type sqlite3 raises for a missing named binding.
            raise sqlite3.ProgrammingError(
                f"No value supplied for column {err}"
            ) from err

        insert_sql = (
            f"INSERT INTO {self._quote_identifier(self.table_name)} "
            f"({self.columns}) VALUES ({self.placeholders})"
        )

        self.cursor.execute(insert_sql, values)

    def commit(self):
        """Commit the pending transaction."""
        self.conn.commit()

    def fetch_all(self):
        """
        Fetch all the data from the database.

        Returns
        -------
        pd.DataFrame
            The data in a pandas DataFrame.
        """
        return pd.read_sql_query(
            f"SELECT * FROM {self._quote_identifier(self.table_name)}", self.conn
        )

    def head(self, n: int = 5):
        """
        Returns the first N rows of the database.

        Parameters
        ----------
        n : int, optional
            The number of rows to fetch, by default 5

        Returns
        -------
        pd.DataFrame
            The data in a pandas DataFrame.
        """
        return pd.read_sql_query(
            f"SELECT * FROM {self._quote_identifier(self.table_name)} LIMIT {int(n)}",
            self.conn,
        )

    def save_as_csv(self):
        """Write the whole table to ``self.csv_path`` (without index column)."""
        df = self.fetch_all()
        df.to_csv(self.csv_path, index=False)

    def close(self):
        """Run ``PRAGMA optimize`` and close the connection."""
        # Optimize before closing is recommended:
        # https://www.sqlite.org/pragma.html#pragma_optimize
        try:
            self.cursor.execute("PRAGMA optimize")
        finally:
            # Release the connection even if the optimize step fails.
            self.conn.close()