ncpi 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ncpi/Analysis.py +524 -0
- ncpi/Features.py +725 -0
- ncpi/FieldPotential.py +884 -0
- ncpi/Inference.py +545 -0
- ncpi/Simulation.py +65 -0
- ncpi/__init__.py +10 -0
- ncpi/tools.py +69 -0
- ncpi-0.1.dist-info/METADATA +42 -0
- ncpi-0.1.dist-info/RECORD +11 -0
- ncpi-0.1.dist-info/WHEEL +5 -0
- ncpi-0.1.dist-info/top_level.txt +1 -0
ncpi/Analysis.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import scipy.interpolate
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import matplotlib.patches as mpatches
|
|
6
|
+
from matplotlib.cm import ScalarMappable
|
|
7
|
+
from ncpi import tools
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Analysis:
    """Statistical analysis and data visualization helpers.

    Parameters
    ----------
    data: (list, np.ndarray, pd.DataFrame)
        Data to be analyzed. ``lmer`` requires a pandas DataFrame;
        ``EEG_topographic_plot`` requires a list or numpy array with 19 or 20 values.
    """

    def __init__(self, data):
        self.data = data

    # ------------------------------------------------------------------
    # Linear (mixed-effects) models
    # ------------------------------------------------------------------

    @staticmethod
    def _fit_command(model_name, formula):
        """Return the R statement that fits the formula stored in the R variable ``model_name``.

        Formulas containing a random-effects bar (``|``, e.g. ``(1 | ID)`` or ``(1|Sensor)``)
        are fitted with ``lmer``; all other formulas with ``lm``. The fitted model overwrites
        the R variable ``model_name``, which previously held the formula string.
        """
        fitter = 'lmer' if '|' in formula else 'lm'
        return f"{model_name} <- {fitter}({model_name}, data=df_pair)"

    @staticmethod
    def _bic_candidates(models, bic_models):
        """Return, in ``models`` order, the names of the models fitted for BIC selection.

        When ``bic_models`` is None only the first model of ``models`` is used.
        """
        if bic_models is None:
            return list(models)[:1]
        return [name for name in models if name in bic_models]

    def lmer(self, control_group='HC', data_col='Y', data_index=-1, models=None,
             bic_models=None, anova_tests=None, specs=None):
        """Compare every group against ``control_group`` with linear (mixed-effects) models.

        The models are fitted in R (``lme4`` / ``stats::lm``) through rpy2. For every group
        other than ``control_group``, the rows of that group and of the control group are
        selected. Candidate models are fitted and one is chosen, by BIC and then optionally
        by ANOVA. The function returns the Holm-adjusted pairwise ``emmeans`` contrasts of
        the chosen model.

        Parameters
        ----------
        control_group: str
            The control group to be used for comparisons.
        data_col: str
            The name of the data column to be analyzed. Rows whose value is 0 are dropped.
        data_index: int
            If >= 0, each entry of ``data_col`` is indexed with ``data_index`` before the
            analysis. If negative (default -1), the column values are used as they are.
        models: dict
            A dictionary of models to be used for analysis. The keys are model names and the
            values are model formulas. If models is None, the default models are used:
                - mod00: Y ~ Group + (1 | ID)
                - mod01: Y ~ Group
            Formulas containing a random-effects term (a ``|``) are fitted with ``lmer``;
            all other formulas with ``lm``.
        bic_models: list
            A list of models to be evaluated using BIC. The model with the lowest BIC is
            selected. If bic_models is None, the first model of ``models`` is selected.
            All models have to be defined in the models dictionary.
            Example:
                bic_models = ["mod01", "mod02"]  # Compare mod01 vs. mod02
        anova_tests: dict
            A dictionary that specifies which models should undergo an ANOVA test after BIC
            selection. If anova_tests is None, no ANOVA tests are performed. Each test must
            contain exactly two models, all defined in the models dictionary.
            Only tests that include the BIC-selected model are evaluated. The first such
            test with an ANOVA p-value >= 0.05 replaces the selection with whichever of its
            two models has fewer formula terms. Terms are counted by R's ``terms()``, so
            random-effect terms such as ``1 | ID`` are included.
            Example:
                anova_tests = {
                    "test1": ["mod00", "mod02"],  # Compare mod00 vs. mod02
                    "test2": ["mod01", "mod03"]   # Compare mod01 vs. mod03
                }
        specs: str
            The specifications for the emmeans function in R. If specs is None, the
            default ``~Group`` is used.

        Returns
        -------
        results: dict
            Keys are ``f'{group}vs{control_group}'``. Values are DataFrames with the emmeans
            pairwise contrasts. If, after zero-valued rows are dropped, one of the two groups
            has no rows, the value is a placeholder DataFrame with ``p.value`` 1 and
            ``z.ratio`` 0.

        Raises
        ------
        ImportError
            If rpy2 is not installed.
        ValueError
            If the data is not a DataFrame or required columns are missing. Also raised if
            ``models``, ``bic_models`` or ``anova_tests`` are inconsistent, or if no model
            would be fitted.
        Missing R packages make R stop with an error (raised through rpy2).
        """

        # Check if rpy2 is installed
        if not tools.ensure_module("rpy2"):
            raise ImportError("rpy2 is required for lmer but is not installed.")
        pandas2ri = tools.dynamic_import("rpy2.robjects.pandas2ri")
        r = tools.dynamic_import("rpy2.robjects", "r")
        ListVector = tools.dynamic_import("rpy2.robjects", "ListVector")
        ro = tools.dynamic_import("rpy2", "robjects")

        # Activate pandas <-> R conversion (used by py2rpy and by r(...) results below)
        pandas2ri.activate()

        # Load the required R packages; R stops with an error if one is missing
        ro.r('''
        load_packages <- function(packages) {
            for (pkg in packages) {
                if (!require(pkg, character.only = TRUE)) {
                    stop("R package '", pkg, "' is not installed.")
                }
            }
        }

        load_packages(c("dplyr", "lme4", "emmeans", "ggplot2", "repr", "mgcv"))
        ''')

        # Validate the input DataFrame
        if not isinstance(self.data, pd.DataFrame):
            raise ValueError('The data must be a pandas DataFrame.')
        if data_col not in self.data.columns:
            raise ValueError(f'The data_col "{data_col}" is not in the DataFrame columns.')
        for col in ['ID', 'Group', 'Epoch', 'Sensor']:
            if col not in self.data.columns:
                raise ValueError(f'The column "{col}" is not in the DataFrame.')

        # Explicit copy of the relevant columns: the assignments below must neither touch
        # self.data nor operate on a pandas view.
        df = self.data[['ID', 'Group', 'Epoch', 'Sensor', data_col]].copy()
        if data_index >= 0:
            df[data_col] = df[data_col].apply(lambda x: x[data_index])

        # Groups are collected before zero-valued rows are dropped, so a group whose values
        # are all zero still gets a (placeholder) entry in the results.
        groups = [group for group in df['Group'].unique() if group != control_group]
        groups_comp = [f'{group}vs{control_group}' for group in groups]

        # Remove rows where the data_col is zero, rename it to Y and force categorical Sensor
        df = df[df[data_col] != 0].copy()
        df.rename(columns={data_col: 'Y'}, inplace=True)
        df["Sensor"] = df["Sensor"].astype('category')

        if models is None:
            models = {
                'mod00': 'Y ~ Group + (1 | ID)',
                'mod01': 'Y ~ Group'
            }
        if specs is None:
            specs = '~Group'

        # Check that all models defined in bic_models are in the models dictionary
        if bic_models is not None:
            for model in bic_models:
                if model not in models:
                    raise ValueError(f'bic_models: the model "{model}" is not defined in the models dictionary.')

        # Check the ANOVA tests: known models, exactly two per test (the R code compares
        # models[1] with models[2])
        if anova_tests is not None:
            for test_name, test in anova_tests.items():
                for model in test:
                    if model not in models:
                        raise ValueError(f'anova_tests: the model "{model}" is not defined in the models dictionary.')
                if len(test) != 2:
                    raise ValueError(f'anova_tests: the test "{test_name}" must contain exactly two models.')

        # Loop-invariant: which models take part in the BIC selection
        bic_candidates = self._bic_candidates(models, bic_models)
        if not bic_candidates:
            raise ValueError('No model to fit: models is empty or bic_models selects no model.')

        # Loop-invariant: ANOVA tests as an R list of character vectors
        r_anova_tests = None
        if anova_tests is not None:
            r_anova_tests = ListVector({name: ro.StrVector([str(m) for m in test])
                                        for name, test in anova_tests.items()})

        results = {}
        for label, label_comp in zip(groups, groups_comp):
            print(f'\n\n--- Group: {label}')
            # Every comparison starts from an empty R global environment
            r('rm(list = ls())')
            df_pair = df[df['Group'].isin([control_group, label])]
            ro.globalenv['df_pair'] = pandas2ri.py2rpy(df_pair)
            ro.globalenv['label'] = label
            ro.globalenv['control_group'] = control_group

            # Convert to factors (label is the first Group level)
            r('''
            df_pair$ID = as.factor(df_pair$ID)
            df_pair$Group = factor(df_pair$Group, levels = c(label, control_group))
            df_pair$Epoch = as.factor(df_pair$Epoch)
            df_pair$Sensor = as.factor(df_pair$Sensor)
            print(table(df_pair$Group))
            ''')

            # If either group has no rows, skip the analysis and store a placeholder
            group_counts = r('table(df_pair$Group)')
            if group_counts[0] == 0 or group_counts[1] == 0:
                results[label_comp] = pd.DataFrame({'p.value': [1], 'z.ratio': [0]})
                continue

            # Fit the models taking part in the BIC selection
            for model_name in bic_candidates:
                ro.globalenv[model_name] = models[model_name]
                r(self._fit_command(model_name, models[model_name]))

            # BIC test: every fitted lm/merMod object in the environment is a candidate
            r('''
            all_models <- names(which(sapply(ls(), function(x) inherits(get(x), "merMod") || inherits(get(x), "lm"))))

            if (length(all_models) == 1) {
                m_sel <- all_models[1]  # Use the only model available
            } else {
                bics <- sapply(all_models, function(m) BIC(get(m)))
                index <- which.min(bics)
                m_sel <- all_models[index]
            }

            final_model <- get(m_sel)
            ''')
            print(f'--- BIC test. Selected model: {r("m_sel")}')

            # Perform ANOVA tests only for user-specified comparisons
            if anova_tests is not None:
                # Fit the remaining models (those not already fitted for BIC)
                fitted = set(r.ls())
                for model_name, formula in models.items():
                    if model_name not in fitted:
                        ro.globalenv[model_name] = formula
                        r(self._fit_command(model_name, formula))

                ro.globalenv['anova_tests'] = r_anova_tests

                # The p-value is read from the anova table itself (last row = the compared
                # model; column "Pr(>F)" for lm, "Pr(>Chisq)" for merMod) instead of being
                # parsed from printed output.
                r('''
                m_name <- m_sel
                for (comparison in names(anova_tests)) {
                    models <- as.character(anova_tests[[comparison]])

                    # Only tests that include the selected model are considered
                    if (!(m_sel %in% models)) next

                    anova_result <- anova(get(models[2]), get(models[1]))
                    p_col <- intersect(c("Pr(>F)", "Pr(>Chisq)"), colnames(anova_result))
                    if (length(p_col) == 0) next
                    p_value <- anova_result[2, p_col[1]]

                    if (!is.na(p_value) && p_value >= 0.05) {
                        # Count formula terms (random-effect terms such as "1 | ID" included)
                        terms1 <- length(attr(terms(formula(get(models[1]))), "term.labels"))
                        terms2 <- length(attr(terms(formula(get(models[2]))), "term.labels"))

                        # Select the model with less complexity
                        if (terms1 <= terms2) {
                            final_model <- get(models[1])
                            m_name <- models[1]
                        } else {
                            final_model <- get(models[2])
                            m_name <- models[2]
                        }
                        break
                    }
                }
                ''')
                print(f'--- ANOVA test. Selected model: {r("m_name")}')

            # Compute Holm-adjusted pairwise comparisons of the selected model
            ro.globalenv['specs'] = specs
            r('''
            emm <- suppressMessages(emmeans(final_model, specs=as.formula(specs)))
            res <- pairs(emm, adjust='holm')
            df_res <- as.data.frame(res)
            ''')

            # Ensure Sensor remains as a character column
            if 'Sensor' in r('names(df_res)'):
                r('df_res$Sensor <- as.character(df_res$Sensor)')

            df_res_r = ro.r['df_res']
            with (ro.default_converter + pandas2ri.converter).context():
                results[label_comp] = ro.conversion.get_conversion().rpy2py(df_res_r)

        return results

    # ------------------------------------------------------------------
    # EEG topographic plot
    # ------------------------------------------------------------------

    @staticmethod
    def _is_real_number(value):
        """Return True for Python/numpy int or float scalars, excluding booleans."""
        return (isinstance(value, (int, float, np.integer, np.floating))
                and not isinstance(value, (bool, np.bool_)))

    @staticmethod
    def _electrode_coordinates(n_electrodes, radius, pos):
        """Return ``{electrode: [x, y]}`` for the 10-20 layout drawn by the topographic plot.

        The dict order defines which data value belongs to which electrode. ``Oz`` is only
        included when ``n_electrodes`` is not 19.
        """
        coords = {
            'Fp1': [pos - 0.25 * radius, 0.8 * radius],
            'Fp2': [pos + 0.25 * radius, 0.8 * radius],
            'F3': [pos - 0.3 * radius, 0.35 * radius],
            'F4': [pos + 0.3 * radius, 0.35 * radius],
            'C3': [pos - 0.35 * radius, 0.0],
            'C4': [pos + 0.35 * radius, 0.0],
            'P3': [pos - 0.3 * radius, -0.4 * radius],
            'P4': [pos + 0.3 * radius, -0.4 * radius],
            'O1': [pos - 0.35 * radius, -0.8 * radius],
            'O2': [pos + 0.35 * radius, -0.8 * radius],
            'F7': [pos - 0.6 * radius, 0.45 * radius],
            'F8': [pos + 0.6 * radius, 0.45 * radius],
            'T3': [pos - 0.8 * radius, 0.0],
            'T4': [pos + 0.8 * radius, 0.0],
            # Scales with the head like every other electrode (-0.2 at the default radius 0.6)
            'T5': [pos - 0.6 * radius, -radius / 3],
            'T6': [pos + 0.6 * radius, -radius / 3],
            'Fz': [pos, 0.35 * radius],
            'Cz': [pos, 0.0],
            'Pz': [pos, -0.4 * radius],
            'Oz': [pos, -0.8 * radius]
        }
        if n_electrodes == 19:
            del coords['Oz']
        return coords

    @staticmethod
    def _interpolation_points(data, radius, pos, n_boundary=50):
        """Return ``(x, y, z)`` arrays used for interpolating the topographic map.

        The first ``len(data)`` points are the electrodes carrying ``data``. They are
        followed by ``2 * n_boundary`` zero-valued anchor points on the head circle centred
        at ``(pos, 0)``, interleaved as upper/lower pairs. ``data`` itself is not modified.
        """
        values = np.asarray(data, dtype=float)
        electrodes = np.array(list(Analysis._electrode_coordinates(len(values), radius, pos).values()))

        ring_x = np.linspace(pos - radius, pos + radius, n_boundary)
        # Clip guards against tiny negative values from rounding at the ends of the ring
        ring_y = np.sqrt(np.clip(radius ** 2 - (ring_x - pos) ** 2, 0.0, None))
        ring = np.empty((2 * n_boundary, 2))
        ring[0::2, 0], ring[0::2, 1] = ring_x, ring_y
        ring[1::2, 0], ring[1::2, 1] = ring_x, -ring_y

        points = np.vstack((electrodes, ring))
        z = np.concatenate((values, np.zeros(2 * n_boundary)))
        return points[:, 0], points[:, 1], z

    @staticmethod
    def _plot_simple_head(ax, radius=0.6, pos=0):
        """Draw a simple head outline (circle, ears and nose) centred at ``(pos, 0)`` on ``ax``."""
        ax.set_aspect('equal')

        # Head
        ax.add_patch(mpatches.Circle((pos, 0), radius + 0.02, edgecolor='k',
                                     facecolor='none', linewidth=0.5))

        # Ears (right, then left)
        for ear_x in (pos + radius + radius / 20, pos - radius - radius / 20 - radius / 50):
            ax.add_patch(mpatches.FancyBboxPatch([ear_x, -radius / 10],
                                                 radius / 50, radius / 5,
                                                 boxstyle=mpatches.BoxStyle("Round", pad=radius / 20),
                                                 linewidth=0.5))

        # Nose
        ax.plot([pos - radius / 10, pos, pos + radius / 10],
                [radius + 0.02, radius + radius / 10 + 0.02, radius + 0.02],
                'k', linewidth=0.5)

    def EEG_topographic_plot(self, **kwargs):
        """Generate a topographic plot of 19 or 20 EEG values (10-20 electrode system).

        ``self.data`` must be a list or numpy array whose values are ordered as Fp1, Fp2,
        F3, F4, C3, C4, P3, P4, O1, O2, F7, F8, T3, T4, T5, T6, Fz, Cz, Pz (and Oz when 20
        values are given). The values are cubic-interpolated over the head together with
        zero-valued anchor points on the head outline. ``self.data`` is not modified.

        Parameters
        ----------
        **kwargs: keyword arguments:
            - radius: (float)
                Radius of the head circumference. Default 0.6.
            - pos: (float)
                Position of the head on the x-axis. Default 0.0.
            - electrode_size: (float)
                Size of the electrodes. Default 0.9.
            - label: (bool)
                Show the colorbar label. Default True.
            - ax: (matplotlib Axes object)
                Axes object to plot the data. If None, ``fig.gca()`` is used, or a new
                figure and axes are created when ``fig`` is also None.
            - fig: (matplotlib Figure object)
                Figure object to plot the data. If None, the figure of ``ax`` is used.
            - vmin: (float)
                Min value used for plotting. None lets matplotlib choose.
            - vmax: (float)
                Max value used for plotting. None lets matplotlib choose.

        Raises
        ------
        ImportError
            If mpl_toolkits is not available.
        ValueError
            On unknown keyword arguments, parameters of the wrong type, or data that is
            not a list/array of 19 or 20 values.
        """

        # Check if mpl_toolkits is installed
        if not tools.ensure_module("mpl_toolkits"):
            raise ImportError("mpl_toolkits is required for EEG_topographic_plot but is not installed.")
        make_axes_locatable = tools.dynamic_import("mpl_toolkits.axes_grid1",
                                                   "make_axes_locatable")

        default_parameters = {
            'radius': 0.6,
            'pos': 0.0,
            'electrode_size': 0.9,
            'label': True,
            'ax': None,
            'fig': None,
            'vmin': None,
            'vmax': None
        }

        for key in kwargs:
            if key not in default_parameters:
                raise ValueError(f'Invalid parameter: {key}')
        params = {**default_parameters, **kwargs}

        radius = params['radius']
        pos = params['pos']
        electrode_size = params['electrode_size']
        label = params['label']
        ax = params['ax']
        fig = params['fig']
        vmin = params['vmin']
        vmax = params['vmax']

        # Validate parameters (ints are accepted wherever a float is expected)
        if not self._is_real_number(radius):
            raise ValueError('The radius parameter must be a float.')
        if not self._is_real_number(pos):
            raise ValueError('The pos parameter must be a float.')
        if not self._is_real_number(electrode_size):
            raise ValueError('The electrode_size parameter must be a float.')
        if not isinstance(label, bool):
            raise ValueError('The label parameter must be a boolean.')
        if ax is not None and not isinstance(ax, plt.Axes):
            raise ValueError('The ax parameter must be a matplotlib Axes object.')
        if fig is not None and not isinstance(fig, plt.Figure):
            raise ValueError('The fig parameter must be a matplotlib Figure object.')
        if vmin is not None and not self._is_real_number(vmin):
            raise ValueError('The vmin parameter must be a float.')
        if vmax is not None and not self._is_real_number(vmax):
            raise ValueError('The vmax parameter must be a float.')
        if not isinstance(self.data, (list, np.ndarray)):
            raise ValueError('The data parameter must be a list or numpy array.')
        if len(self.data) not in [19, 20]:
            raise ValueError('The data parameter must contain 19 or 20 elements.')

        # Resolve the figure/axes pair when one or both were not given
        if ax is None and fig is None:
            fig, ax = plt.subplots()
        elif ax is None:
            ax = fig.gca()
        elif fig is None:
            fig = ax.figure

        self._plot_simple_head(ax, radius, pos)

        # Interpolate electrode values (plus zero anchors on the outline) on a grid
        # covering the head centred at (pos, 0)
        x, y, z = self._interpolation_points(self.data, radius, pos)
        n_electrodes = len(self.data)
        n_grid = 100
        xi = np.linspace(pos - radius, pos + radius, n_grid)
        yi = np.linspace(-radius, radius, n_grid)
        zi = scipy.interpolate.griddata((x, y), z, (xi[None, :], yi[:, None]), method='cubic')

        # Use different number of levels for the fill and the lines
        contour_set = ax.contourf(xi, yi, zi, 30, cmap=plt.cm.bwr, zorder=1,
                                  vmin=vmin, vmax=vmax)
        ax.contour(xi, yi, zi, 5, colors="grey", zorder=2, linewidths=0.4,
                   vmin=vmin, vmax=vmax)

        # Colorbar in its own axes to the right of the head
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        # The colorbar is only drawn when the summed |value| exceeds 2; the original code
        # describes this as "not significant" data otherwise. NOTE(review): threshold origin
        # is not documented in this file.
        if np.sum(np.abs(z)) > 2:
            colorbar = fig.colorbar(ScalarMappable(norm=contour_set.norm, cmap=contour_set.cmap), cax=cax)
            colorbar.ax.tick_params(labelsize=8)
            if label:
                colorbar.ax.xaxis.set_label_position('bottom')
                colorbar.set_label('z-ratio', size=5, labelpad=-15, rotation=0, y=0.)
        else:
            # Hide the colorbar if the data is not significant
            cax.axis('off')

        # Add the EEG electrode positions (the anchor points are not drawn)
        ax.scatter(x[:n_electrodes], y[:n_electrodes], marker='o', c='k',
                   s=electrode_size, zorder=3)
|