eegdash 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eegdash might be problematic.

Files changed (72)
  1. eegdash/SignalStore/__init__.py +0 -0
  2. eegdash/SignalStore/signalstore/__init__.py +3 -0
  3. eegdash/SignalStore/signalstore/adapters/read_adapters/abstract_read_adapter.py +13 -0
  4. eegdash/SignalStore/signalstore/adapters/read_adapters/domain_modeling/schema_read_adapter.py +16 -0
  5. eegdash/SignalStore/signalstore/adapters/read_adapters/domain_modeling/vocabulary_read_adapter.py +19 -0
  6. eegdash/SignalStore/signalstore/adapters/read_adapters/handmade_records/excel_study_organizer_read_adapter.py +114 -0
  7. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/axona/axona_read_adapter.py +912 -0
  8. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/ReadIntanSpikeFile.py +140 -0
  9. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/intan_read_adapter.py +29 -0
  10. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/__init__.py +0 -0
  11. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/data_to_result.py +62 -0
  12. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/get_bytes_per_data_block.py +36 -0
  13. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/notch_filter.py +50 -0
  14. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/qstring.py +41 -0
  15. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/read_header.py +135 -0
  16. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/intanutil/read_one_data_block.py +45 -0
  17. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhd_format/load_intan_rhd_format.py +204 -0
  18. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/__init__.py +0 -0
  19. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/data_to_result.py +60 -0
  20. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/get_bytes_per_data_block.py +37 -0
  21. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/notch_filter.py +50 -0
  22. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/qstring.py +41 -0
  23. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/read_header.py +153 -0
  24. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/intanutil/read_one_data_block.py +47 -0
  25. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/intan/load_intan_rhs_format/load_intan_rhs_format.py +213 -0
  26. eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/neurodata_without_borders/neurodata_without_borders_read_adapter.py +14 -0
  27. eegdash/SignalStore/signalstore/operations/__init__.py +4 -0
  28. eegdash/SignalStore/signalstore/operations/handler_executor.py +22 -0
  29. eegdash/SignalStore/signalstore/operations/handler_factory.py +41 -0
  30. eegdash/SignalStore/signalstore/operations/handlers/base_handler.py +44 -0
  31. eegdash/SignalStore/signalstore/operations/handlers/domain/property_model_handlers.py +79 -0
  32. eegdash/SignalStore/signalstore/operations/handlers/domain/schema_handlers.py +3 -0
  33. eegdash/SignalStore/signalstore/operations/helpers/abstract_helper.py +17 -0
  34. eegdash/SignalStore/signalstore/operations/helpers/neuroscikit_extractor.py +33 -0
  35. eegdash/SignalStore/signalstore/operations/helpers/neuroscikit_rawio.py +165 -0
  36. eegdash/SignalStore/signalstore/operations/helpers/spikeinterface_helper.py +100 -0
  37. eegdash/SignalStore/signalstore/operations/helpers/wrappers/neo_wrappers.py +21 -0
  38. eegdash/SignalStore/signalstore/operations/helpers/wrappers/nwb_wrappers.py +27 -0
  39. eegdash/SignalStore/signalstore/store/__init__.py +8 -0
  40. eegdash/SignalStore/signalstore/store/data_access_objects.py +1181 -0
  41. eegdash/SignalStore/signalstore/store/datafile_adapters.py +131 -0
  42. eegdash/SignalStore/signalstore/store/repositories.py +928 -0
  43. eegdash/SignalStore/signalstore/store/store_errors.py +68 -0
  44. eegdash/SignalStore/signalstore/store/unit_of_work.py +97 -0
  45. eegdash/SignalStore/signalstore/store/unit_of_work_provider.py +67 -0
  46. eegdash/SignalStore/signalstore/utilities/data_adapters/spike_interface_adapters/si_recording.py +1 -0
  47. eegdash/SignalStore/signalstore/utilities/data_adapters/spike_interface_adapters/si_sorter.py +1 -0
  48. eegdash/SignalStore/signalstore/utilities/testing/data_mocks.py +513 -0
  49. eegdash/SignalStore/signalstore/utilities/tools/dataarrays.py +49 -0
  50. eegdash/SignalStore/signalstore/utilities/tools/mongo_records.py +25 -0
  51. eegdash/SignalStore/signalstore/utilities/tools/operation_response.py +78 -0
  52. eegdash/SignalStore/signalstore/utilities/tools/purge_orchestration_response.py +21 -0
  53. eegdash/SignalStore/signalstore/utilities/tools/quantities.py +15 -0
  54. eegdash/SignalStore/signalstore/utilities/tools/strings.py +38 -0
  55. eegdash/SignalStore/signalstore/utilities/tools/time.py +17 -0
  56. eegdash/SignalStore/tests/conftest.py +799 -0
  57. eegdash/SignalStore/tests/data/valid_data/data_arrays/make_fake_data.py +59 -0
  58. eegdash/SignalStore/tests/unit/store/conftest.py +0 -0
  59. eegdash/SignalStore/tests/unit/store/test_data_access_objects.py +1235 -0
  60. eegdash/SignalStore/tests/unit/store/test_repositories.py +1309 -0
  61. eegdash/SignalStore/tests/unit/store/test_unit_of_work.py +7 -0
  62. eegdash/SignalStore/tests/unit/test_ci_cd.py +8 -0
  63. eegdash/__init__.py +1 -0
  64. eegdash/aws_ingest.py +29 -0
  65. eegdash/data_utils.py +213 -0
  66. eegdash/main.py +17 -0
  67. eegdash/signalstore_data_utils.py +280 -0
  68. eegdash-0.0.1.dist-info/LICENSE +20 -0
  69. eegdash-0.0.1.dist-info/METADATA +72 -0
  70. eegdash-0.0.1.dist-info/RECORD +72 -0
  71. eegdash-0.0.1.dist-info/WHEEL +5 -0
  72. eegdash-0.0.1.dist-info/top_level.txt +1 -0
eegdash/SignalStore/signalstore/adapters/read_adapters/recording_acquisitions/axona/axona_read_adapter.py
@@ -0,0 +1,912 @@
+ from signalstore.operations.importers.adapters.abstract_read_adapter import AbstractReadAdapter
+
+ class AxonaReadAdapter(AbstractReadAdapter):
+     def __init__(self, directory):
+         self.data_directory = UPath(directory)
+         self.session_paths = self._assemble_session_paths()
+
+     # __iter__ and __next__ are inherited from AbstractReadAdapter; they use
+     # self.read() to get the data generator, so there is no need to call
+     # read() directly when using the Axona read adapter. Preferred usage:
+     #     data = AxonaReadAdapter(directory)
+     # then list(data) to get the data as a list,
+     # or next(data) to get the next item from the generator,
+     # or "for item in data:" to iterate over the generator.
+     def read(self):
+         for name in self.session_paths:
+             data = self._read_session_data(name)
+             yield data
+
+     def _read_session_data(self, session_key):
+         session_paths = self.session_paths[session_key]
+         data = []
+         data.extend(read_position_data_set(session_paths['pos'], session_key))
+         # data.extend(read_settings_data_set(session_paths['set'],
+         #                                    session_key))
+         tetrode_keys = list(set(session_paths.keys()) - {'pos', 'set'})
+         for tetrode_key in tetrode_keys:
+             tetrode_paths = session_paths[tetrode_key]
+             for name in tetrode_paths:
+                 if name == 'cut':
+                     continue
+                 if name == 'eeg' or name == 'egf':
+                     data.extend(
+                         read_lfp_data_set(
+                             tetrode_paths[name],
+                             session_key, tetrode_key))
+                 elif name == 'tet':
+                     data.extend(
+                         read_spike_data_set(
+                             tetrode_paths[name],
+                             tetrode_paths['cut'],
+                             session_key, tetrode_key))
+         for record in data:
+             if isinstance(record, xr.DataArray):
+                 record.attrs.update(session_key=session_key)
+             elif isinstance(record, dict):
+                 record.update(session_key=session_key)
+         return data
+
+     def _assemble_session_paths(self):
+         paths = {}
+         for path in self.data_directory.glob('*'):
+             try:
+                 keys = self._path_key_tuple(path)
+             except AxonaReaderFileExtensionError:
+                 continue
+             if keys[0] not in paths.keys():
+                 paths[keys[0]] = {}
+             if len(keys) == 2:
+                 paths[keys[0]][keys[1]] = path
+             elif len(keys) == 3:
+                 if keys[1] not in paths[keys[0]].keys():
+                     paths[keys[0]][keys[1]] = {}
+                 paths[keys[0]][keys[1]][keys[2]] = path
+         return paths
+
+     def _path_key_tuple(self, path):
+         if self._ispos(path):
+             return self._pos_session_key(path), 'pos'
+         elif self._isset(path):
+             return self._set_session_key(path), 'set'
+         elif self._iscut(path):
+             return (self._cut_session_key(path),
+                     self._cut_tetrode_key(path), 'cut')
+         elif self._isclu(path):
+             return (self._clu_session_key(path),
+                     self._clu_tetrode_key(path), 'clu')
+         elif self._istet(path):
+             return (self._tet_session_key(path),
+                     self._tet_tetrode_key(path), 'tet')
+         elif self._iseeg(path):
+             return (self._eeg_session_key(path),
+                     self._eeg_tetrode_key(path), 'eeg')
+         elif self._isegf(path):
+             return (self._egf_session_key(path),
+                     self._egf_tetrode_key(path), 'egf')
+         else:
+             raise AxonaReaderFileExtensionError(
+                 f'This file has an unsupported axona extension: {path}')
+
+     def _ispos(self, path):
+         return path.suffix == '.pos'
+
+     def _pos_session_key(self, path):
+         return path.stem
+
+     def _isset(self, path):
+         return path.suffix == '.set'
+
+     def _set_session_key(self, path):
+         return path.stem
+
+     def _iscut(self, path):
+         return path.suffix == '.cut'
+
+     def _cut_session_key(self, path):
+         # get the file name before the last underscore
+         name = path.stem.split('_')[:-1]
+         if len(name) == 1:
+             return name[0]
+         elif len(name) > 1:
+             return '_'.join(name)
+         else:
+             raise AxonaReaderFileExtensionError(
+                 f'This .cut file has an invalid extension: {path}')
+
+     def _cut_tetrode_key(self, path):
+         # get the number after the last underscore but before the extension
+         return int(path.stem.split('_')[-1])
+
+     def _isclu(self, path):
+         return UPath(path.stem).suffix == '.clu'
+
+     def _clu_session_key(self, path):
+         # strip the trailing .clu to recover the session name
+         return UPath(path.stem).stem
+
+     def _clu_tetrode_key(self, path):
+         # get the number after the final period
+         return int(path.suffix[1:])
+
+     def _istet(self, path):
+         # check whether the file extension is an integer
+         return path.suffix[1:].isnumeric()
+
+     def _tet_session_key(self, path):
+         # get the file name before the last period
+         return path.stem
+
+     def _tet_tetrode_key(self, path):
+         return int(path.suffix[1:])
+
+     def _iseeg(self, path):
+         return path.suffix[:4] == '.eeg'
+
+     def _eeg_session_key(self, path):
+         return path.stem
+
+     def _eeg_tetrode_key(self, path):
+         if path.suffix == '.eeg':
+             return 1
+         elif path.suffix.replace('.eeg', '').isnumeric():
+             return int(path.suffix.replace('.eeg', ''))
+         else:
+             raise AxonaReaderFileExtensionError(
+                 f'This .eeg* file has an invalid extension: {path}')
+
+     def _isegf(self, path):
+         return path.suffix[:4] == '.egf'
+
+     def _egf_session_key(self, path):
+         return path.stem
+
+     def _egf_tetrode_key(self, path):
+         if path.suffix == '.egf':
+             return 1
+         elif path.suffix.replace('.egf', '').isnumeric():
+             return int(path.suffix.replace('.egf', ''))
+         else:
+             raise AxonaReaderFileExtensionError(
+                 f'This .egf* file has an invalid extension: \n{path}')
+
+
+ class AxonaReaderFileExtensionError(ValueError):
+     pass
+
+
+ # =============================================================================
+ # EEG or EGF Helpers
+ # =============================================================================
+
+ # Internal Dependencies
+ import contextlib  # for closing the file
+ import mmap
+
+ # from core.data_voltage import EphysSeries
+
+ # A+ Grade Dependencies
+ import numpy as np
+ import xarray as xr
+ from upath import UPath
+
+ # A Grade Dependencies
+
+ # Other Dependencies
+
+
+ def read_lfp_data_set(file_path: UPath, session_key: str, tetrode_key: str):
+     with open(file_path, 'rb') as eeg_or_egf_file:
+         if 'eeg' in file_path.suffix:
+             file_type = 'eeg'
+         elif 'egf' in file_path.suffix:
+             file_type = 'egf'
+         else:
+             raise ValueError(
+                 'The file extension must be either "eeg*" or "egf*". '
+                 + f'The current file extension is {file_path.suffix}'
+             )
+         records = read_eeg_or_egf(eeg_or_egf_file, file_type,
+                                   session_key, tetrode_key)
+         for record in records:
+             record.attrs['session_key'] = session_key
+         return records
+
+
+ def read_eeg_or_egf(opened_eeg_or_egf_file, file_type: str,
+                     session_key, tetrode_key) -> list:
+     """Input:
+         opened_eeg_or_egf_file: an open file object for the .eeg or .egf file
+     Output:
+         a list [lfp_signal, lfp_time] of xarray DataArrays holding the LFP
+         waveform and its timestamps; the sampling frequency is stored in
+         each array's attrs
+     """
+     is_eeg = False
+     is_egf = False
+     if 'eeg' == file_type:
+         is_eeg = True
+     elif 'egf' == file_type:
+         is_egf = True
+     else:
+         raise ValueError('The file extension must be either "eeg" or "egf".')
+
+     if is_eeg == is_egf:
+         raise LFPFileTypeConflict('The file extension must be EITHER "eeg" '
+                                   + 'XOR "egf".\nCurrently is_eeg = '
+                                   + f'{is_eeg}, and is_egf = {is_egf}')
+
+     mp = mmap.mmap(opened_eeg_or_egf_file.fileno(), 0, access=mmap.ACCESS_READ)
+     with contextlib.closing(mp) as memory_map:
+         # find the data_start marker
+         start_index = int(
+             memory_map.find(b'data_start') +
+             len('data_start'))  # start of the data
+         stop_index = int(memory_map.find(b'\r\ndata_end'))  # end of the data
+
+         sample_rate_start = memory_map.find(b'sample_rate')
+         sample_rate_end = memory_map[sample_rate_start:].find(b'\r\n')
+         Fs = float(memory_map[sample_rate_start:sample_rate_start +
+                               sample_rate_end].decode('utf-8').split(' ')[1])
+
+         data_string = memory_map[start_index:stop_index]
+
+         if is_eeg and not is_egf:
+             assert Fs == 250
+             lfp_data = np.frombuffer(data_string, dtype='>b')
+         elif is_egf and not is_eeg:
+             assert Fs == 4.8e3
+             lfp_data = np.frombuffer(data_string, dtype='<h')
+         else:
+             raise ValueError(
+                 'The file extension must be either "eeg" or "egf"')
+
+         lfp_signal = xr.DataArray(
+             lfp_data,
+             name=f"{session_key}_lfp_tetrode_{str(tetrode_key)}_{int(Fs)}_hz",
+             dims=['time'],
+             attrs={'type': 'lfp',
+                    'units': 'uV',
+                    'dimensionality': 'voltage',
+                    'sample_rate': Fs})
+
+         lfp_time = xr.DataArray(
+             np.arange(lfp_data.size) / Fs,
+             name=f"{session_key}_lfp_tetrode_{str(tetrode_key)}_time_{int(Fs)}_hz",
+             dims=['time'],
+             attrs={'type': 'lfp',
+                    'units': 's',
+                    'dimensionality': 'time',
+                    'sample_rate': Fs})
+
+         return [lfp_signal, lfp_time]
+
+
+ class LFPFileTypeConflict(ValueError):
+     pass
+
+
+ # =============================================================================
+ # pos (position) Helpers
+ # =============================================================================
+
+ import numpy as np
+ import scipy.ndimage
+ import struct
+ import xarray as xr
+
+ # A Grade Dependencies
+
+ # Other Dependencies
+
+
+ def read_position_data_set(pos_path: str, session_key, ppm=None) -> list:
+     data = read_position_data(pos_path, ppm=ppm)
+     pos_x = xr.DataArray(data['x'],
+                          name=session_key + '_pos_x',
+                          dims='time',
+                          attrs={'type': 'position',
+                                 'units': 'meters',
+                                 'dimensionality': 'length',
+                                 'orientation': 'horizontal',
+                                 'session_key': session_key})
+     pos_y = xr.DataArray(data['y'],
+                          dims='time',
+                          name=session_key + '_pos_y',
+                          attrs={'type': 'position',
+                                 'units': 'meters',
+                                 'dimensionality': 'length',
+                                 'orientation': 'vertical',
+                                 'session_key': session_key})
+     pos_t = xr.DataArray(data['t'],
+                          name=session_key + '_pos_t',
+                          dims='time',
+                          attrs={'type': 'position',
+                                 'units': 'seconds',
+                                 'dimensionality': 'time',
+                                 'session_key': session_key})
+     return [pos_x, pos_y, pos_t]
+
+
+ def read_position_data(pos_path: str, ppm=None) -> dict:
+     '''
+     Extracts position data from a .pos file.
+
+     Params:
+         pos_path (str):
+             Path to the position (.pos) file
+         ppm (float):
+             Pixels-per-meter value
+
+     Returns:
+         dict with keys:
+             t, x, y (np.ndarray):
+                 arrays of timestamps and x, y coordinates
+             arena_width (float):
+                 max - min x coordinate value (arena width)
+             arena_height (float):
+                 max - min y coordinate value (arena length)
+             sample_rate (float)
+             ppm (float)
+     '''
+     pos_data = _get_position(pos_path, ppm=ppm)
+
+     # Correct pos_t data in case of a bad position file
+     new_pos_t = np.copy(pos_data[2])
+     if len(new_pos_t) < len(pos_data[0]):
+         while len(new_pos_t) != len(pos_data[0]):
+             new_pos_t = np.append(new_pos_t, float(new_pos_t[-1] + 0.02))
+     elif len(new_pos_t) > len(pos_data[0]):
+         while len(new_pos_t) != len(pos_data[0]):
+             new_pos_t = np.delete(new_pos_t, -1)
+
+     Fs_pos = pos_data[3]
+
+     file_ppm = pos_data[-1]
+
+     if file_ppm is None:
+         raise AxonaPosAdapterError('PPM must be in position file or settings'
+                                    + ' dictionary to proceed')
+
+     pos_x = pos_data[0]
+     pos_y = pos_data[1]
+     pos_t = new_pos_t
+
+     # Rescale coordinate values with respect to a center point
+     # (i.e. arena center = origin (0, 0))
+     center = _center_box(pos_x, pos_y)
+     pos_x = pos_x - center[0]
+     pos_y = pos_y - center[1]
+
+     # Correct for bad tracking
+     pos_data_corrected = _rem_bad_track(pos_x, pos_y, pos_t, 2)
+     pos_x = pos_data_corrected[0]
+     pos_y = pos_data_corrected[1]
+     pos_t = pos_data_corrected[2]
+
+     # Remove NaN values
+     nonNanValues = np.where(np.isnan(pos_x) == False)[0]
+     pos_t = pos_t[nonNanValues]
+     pos_x = pos_x[nonNanValues]
+     pos_y = pos_y[nonNanValues]
+
+     # Smooth data using boxcar convolution
+     B = np.ones((int(np.ceil(0.4 * Fs_pos)), 1)) / np.ceil(0.4 * Fs_pos)
+     pos_x = scipy.ndimage.convolve(pos_x, B, mode='nearest')
+     pos_y = scipy.ndimage.convolve(pos_y, B, mode='nearest')
+
+     pos_x_width = max(pos_x) - min(pos_x)
+     pos_y_width = max(pos_y) - min(pos_y)
+
+     return {"t": pos_t.flatten(),
+             "x": pos_x.flatten(),
+             "y": pos_y.flatten(),
+             "arena_width": pos_x_width,
+             "arena_height": pos_y_width,
+             "sample_rate": pos_data[3],
+             "ppm": file_ppm}
+
+
+ def _rem_bad_track(x, y, t, threshold):
+     """Remove position samples where the tracked point jumps farther than
+     `threshold` between consecutive samples (port of a MATLAB routine).
+     """
+     remInd = []
+     diffx = np.diff(x, axis=0)
+     diffy = np.diff(y, axis=0)
+     diffR = np.sqrt(diffx ** 2 + diffy ** 2)
+
+     # The MATLAB code works fine without NaNs; if there are NaNs, just set
+     # them to threshold -- they will be removed later.
+     diffR[np.isnan(diffR)] = threshold  # setting the NaN values to threshold
+     ind = np.where((diffR > threshold))[0]
+
+     if len(ind) == 0:  # no bad samples to remove
+         return x, y, t
+
+     if ind[-1] == len(x):
+         offset = 2
+     else:
+         offset = 1
+
+     for index in range(len(ind) - offset):
+         if ind[index + 1] == ind[index] + 1:
+             # A single-sample position jump: the tracker jumps out for one
+             # sample and then jumps back on the next sample. Remove the bad
+             # sample.
+             remInd.append(ind[index] + 1)
+         else:
+             '''Not a single jump. 2 possibilities:
+             1. The tracker jumps out and stays out at the same place for
+                several samples, then jumps back.
+             2. The tracker just has a small jump before continuing as
+                normal, for unknown reasons. In the latter case the samples
+                are left untouched.'''
+             con = x[ind[index] + 1:ind[index + 1] + 1 + 1] == x[ind[index] + 1]
+             idx = np.where(con)[0]
+             if len(idx) == len(x[ind[index] + 1:ind[index + 1] + 1 + 1]):
+                 remInd.extend(
+                     list(range(ind[index] + 1, ind[index + 1] + 1 + 1)))
+                 # have that extra +1 since range goes to end-1
+
+     # keep_ind = [val for val in range(len(x)) if val not in remInd]
+     keep_ind = np.setdiff1d(np.arange(len(x)), remInd)
+
+     x = x[keep_ind].flatten()
+     y = y[keep_ind].flatten()
+     t = t[keep_ind].flatten()
+
+     return x.reshape((len(x), 1)), y.reshape((len(y), 1)), t.reshape((len(t), 1))
+
+
+ def _find_center(NE, NW, SW, SE):
+     """Finds the center point (x, y) of the position boundaries"""
+     x = np.mean([np.amax([NE[0], SE[0]]), np.amin([NW[0], SW[0]])])
+     y = np.mean([np.amax([NW[1], NE[1]]), np.amin([SW[1], SE[1]])])
+     return np.array([x, y])
+
+
+ def _center_box(posx, posy):
+     # must remove NaNs first because np.amin returns nan if there is a nan
+     posx = posx[~np.isnan(posx)]  # removes NaNs
+     posy = posy[~np.isnan(posy)]  # removes NaNs
+
+     NE = np.array([np.amax(posx), np.amax(posy)])
+     NW = np.array([np.amin(posx), np.amax(posy)])
+     SW = np.array([np.amin(posx), np.amin(posy)])
+     SE = np.array([np.amax(posx), np.amin(posy)])
+
+     return _find_center(NE, NW, SW, SE)
+
+
+ def _fix_timestamps(post):
+     first = post[0]
+     N = len(post)
+     uniquePost = np.unique(post)
+
+     if len(uniquePost) != N:
+         didFix = True
+         numZeros = 0
+         # find the number of zeros at the end of the file
+         while True:
+             if post[-1 - numZeros] == 0:
+                 numZeros += 1
+             else:
+                 break
+         last = first + (N - 1 - numZeros) * 0.02
+         fixedPost = np.arange(first, last + 0.02, 0.02)
+         fixedPost = fixedPost.reshape((len(fixedPost), 1))
+     else:
+         didFix = False
+         fixedPost = []
+
+     return didFix, fixedPost
+
+
+ def _arena_config(posx, posy, ppm, center, flip_y=True):
+     """
+     :param posx: x coordinates (pixels)
+     :param posy: y coordinates (pixels)
+     :param ppm: pixels-per-meter conversion factor
+     :param center: (x, y) center point of the arena
+     :param flip_y: bool that determines whether to flip y or not. When
+         recording on Intan the positions were inverted due to the camera
+         position; in the virtualmaze you might not want to flip y values.
+     :return: converted posx, posy
+     """
+     conversion = ppm
+
+     posx = 100 * (posx - center[0]) / conversion
+
+     if flip_y:
+         # flip the y axis
+         posy = 100 * (-posy + center[1]) / conversion
+     else:
+         posy = 100 * (posy + center[1]) / conversion
+
+     return posx, posy
+
+
+ def _remove_nan(posx, posy, post):
+     """Remove any NaNs from the end of the arrays"""
+     remove_nan = True
+     while remove_nan:
+         if np.isnan(posx[-1]):
+             posx = posx[:-1]
+             posy = posy[:-1]
+             post = post[:-1]
+         else:
+             remove_nan = False
+     return posx, posy, post
+
+
+ def _get_position(pos_fpath, ppm=None, method='', flip_y=True):
+     """
+     _get_position function
+     ---------------------------------------------
+     variables:
+         pos_fpath: the full path to the .pos file (e.g. C:/example/session.pos)
+
+     output:
+         t: column numpy array of the timestamps
+         x: column array of the x-values (in pixels)
+         y: column array of the y-values (in pixels)
+     """
+     with open(pos_fpath, 'rb+') as f:  # opening the .pos file
+         headers = ''  # initializing the header string
+         for line in f:  # read line by line to parse the header of the file
+             # print(line)
+             if 'data_start' in str(line):
+                 # reading data_start means the header has ended
+                 headers += 'data_start'
+                 break  # break out of the for loop once the header is done
+             elif 'duration' in str(line):
+                 headers += line.decode(encoding='UTF-8')
+             elif 'num_pos_samples' in str(line):
+                 num_pos_samples = int(line.decode(encoding='UTF-8')[len('num_pos_samples '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'bytes_per_timestamp' in str(line):
+                 bytes_per_timestamp = int(line.decode(encoding='UTF-8')[len('bytes_per_timestamp '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'bytes_per_coord' in str(line):
+                 bytes_per_coord = int(line.decode(encoding='UTF-8')[len('bytes_per_coord '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'timebase' in str(line):
+                 timebase = (line.decode(encoding='UTF-8')[len('timebase '):]).split(' ')[0]
+                 headers += line.decode(encoding='UTF-8')
+             elif 'pixels_per_metre' in str(line):
+                 # print('READING PIXELS PER METRE FROM FILE')
+                 ppm = float(line.decode(encoding='UTF-8')[len('pixels_per_metre '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'min_x' in str(line) and 'window' not in str(line):
+                 min_x = int(line.decode(encoding='UTF-8')[len('min_x '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'max_x' in str(line) and 'window' not in str(line):
+                 max_x = int(line.decode(encoding='UTF-8')[len('max_x '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'min_y' in str(line) and 'window' not in str(line):
+                 min_y = int(line.decode(encoding='UTF-8')[len('min_y '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'max_y' in str(line) and 'window' not in str(line):
+                 max_y = int(line.decode(encoding='UTF-8')[len('max_y '):])
+                 headers += line.decode(encoding='UTF-8')
+             elif 'pos_format' in str(line):
+                 headers += line.decode(encoding='UTF-8')
+                 if 't,x1,y1,x2,y2,numpix1,numpix2' in str(line):
+                     two_spot = True
+                 else:
+                     two_spot = False
+                     print('The position format is unrecognized!')
+             elif 'sample_rate' in str(line):
+                 sample_rate = line.decode(encoding='UTF-8').split(' ')[1]
+                 sample_rate = float(sample_rate)
+                 headers += line.decode(encoding='UTF-8')
+             else:
+                 headers += line.decode(encoding='UTF-8')
+
+     assert ppm is not None, 'PPM must be in position file or settings dictionary to proceed'
+
+     if two_spot:
+         '''Runs when two-spot mode is on (one_spot has the same format, so it also runs here)'''
+         with open(pos_fpath, 'rb+') as f:
+             '''get_pos for one_spot'''
+             pos_data = f.read()  # all the position data values (including the header)
+             pos_data = pos_data[len(headers):-12]  # remove the header values
+
+             byte_string = 'i8h'
+
+             pos_data = np.asarray(struct.unpack('>%s' % (num_pos_samples * byte_string), pos_data))
+             pos_data = pos_data.astype(float).reshape((num_pos_samples, 9))  # there are 8 words and 1 time sample
+
+         x = pos_data[:, 1]
+         y = pos_data[:, 2]
+         t = pos_data[:, 0]
+
+         x = x.reshape((len(x), 1))
+         y = y.reshape((len(y), 1))
+         t = t.reshape((len(t), 1))
+
+         if method == 'raw':
+             return x, y, t, sample_rate
+
+         t = np.divide(t, np.float64(timebase))  # convert the Axona frame number to a time value
+
+         # values that are NaN are set to 1023 in Axona's system; replace those values with NaNs
+         x[np.where(x == 1023)] = np.nan
+         y[np.where(y == 1023)] = np.nan
+
+         didFix, fixedPost = _fix_timestamps(t)
+
+         if didFix:
+             t = fixedPost
+
+         t = t - t[0]
+
+         x, y = _arena_config(x, y, ppm,
+                              center=_center_box(x, y),
+                              flip_y=flip_y)
+
+         # remove any NaNs at the end of the file
+         x, y, t = _remove_nan(x, y, t)
+
+     else:
+         print("Haven't made any code for this part yet.")
+
+     return x.reshape((len(x), 1)), y.reshape((len(y), 1)), t.reshape((len(t), 1)), sample_rate, ppm
+
+
+ class AxonaPosAdapterError(Exception):
+     pass
+
+
+ # =============================================================================
+ # cut and tet (tetrode) Helpers
+ # =============================================================================
+
+ import xarray as xr
+ import numpy as np
+ from datetime import datetime, timezone
+
+
+ def read_spike_data_set(tetrode_file_path, cut_file_path,
+                         session_data_ref, tetrode_name):
+     spike_labels = _read_cut_file(
+         cut_file_path,
+         session_data_ref
+     )
+     spike_timestamps, spike_waveforms = _read_tetrode_file(
+         tetrode_file_path,
+         session_data_ref
+     )
+     session_data_ref, tetrode_name = str(session_data_ref), str(tetrode_name)
+     spike_timestamps.name = session_data_ref + '_tetrode_' + tetrode_name + '_spike_timestamps'
+     spike_waveforms.name = session_data_ref + '_tetrode_' + tetrode_name + '_spike_waveforms'
+     spike_labels.name = session_data_ref + '_tetrode_' + tetrode_name + '_spike_labels'
+
+     return [spike_timestamps, spike_waveforms, spike_labels]
+
+
+ def _read_cut_file(cut_file_path, session_data_ref):
+     with open(cut_file_path, 'r') as cut_file:
+         lines = cut_file.readlines()
+     # print('>>>>>>>>>>>>>> ', lines)
+     spike_labels = []
+     extract_cut = False
+     for line in lines:
+         if 'Exact_cut' in line:  # finding the beginning of the cut values
+             extract_cut = True
+             continue
+         if extract_cut:  # read all the cut values
+             line_labels = str(line)
+             for string_val in ['\\n', ',', "'", '[', ']']:
+                 # remove non-base-10-integer values
+                 line_labels = line_labels.replace(string_val, '')
+             spike_labels.extend([int(val) for val in line_labels.split()])
+         else:
+             continue
+     if len(spike_labels) == 0:
+         raise ValueError('There are no spike labels in this file.')
+     return xr.DataArray(
+         data=np.array(spike_labels),
+         dims=['spikes'],
+         attrs={'type': 'spike_labels',
+                'session_data_ref': session_data_ref,
+                'units': 'neuron',
+                'dimensionality': 'nominal',
+                }
+     )
+
+
+ def _read_tetrode_file(tetrode_file_path, session_data_ref):
+     raw_data = _get_raw_tetrode_data(tetrode_file_path)
+
+     spike_timestamps = _extract_spike_timestamps(raw_data)
+
+     spike_waveforms = _extract_spike_waveforms(raw_data)
+
+     metadata = {
+         'session_data_ref': session_data_ref,
+         'session_start': _session_start(
+             raw_data['trial_date'],
+             raw_data['trial_time']
+         ),
+         "session_duration": raw_data['duration'],
+         "sample_rate": raw_data['sample_rate'],
+     }
+
+     spike_timestamps.attrs.update(metadata)
+     spike_waveforms.attrs.update(metadata)
+
+     return spike_timestamps, spike_waveforms
+
750
+
751
+ def _get_raw_tetrode_data(tetrode_file_path):
752
+
753
+ data = {}
754
+ with TetrodeDecoder(tetrode_file_path) as d:
755
+ for line in d.tetrode_file:
756
+ for keyword in d.line_decoders.keys():
757
+ if keyword in str(line):
758
+ data[keyword] = d.line_decoders[keyword](line)
759
+ break
760
+
761
+ # replace data_start with spike_data name for clarity
762
+ data['spike_data'] = data.pop('data_start')
763
+ return data
764
+
+
+ class TetrodeDecoder:
+
+     def __init__(self, tetrode_file):
+         self.tetrode_file = open(tetrode_file, 'rb')
+         self.line_decoders = {
+             'num_spikes': self.numeric,
+             'bytes_per_timestamp': self.numeric,
+             'samples_per_spike': self.numeric,
+             'bytes_per_sample': self.numeric,
+             'timebase': self.numeric,
+             'duration': self.numeric,
+             'sample_rate': self.numeric,
+             'num_chans': self.numeric,
+             'trial_date': self.datestring,
+             'trial_time': self.timestring,
+             'data_start': self.data
+         }
+
+     def __del__(self):
+         self.tetrode_file.close()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.tetrode_file.close()
+
+     def data(self, line):
+         start_index = len('data_start')
+         stop_index = -len('\r\ndata_end\r\n')
+         data_string = (line + self.tetrode_file.read())[start_index:stop_index]
+         spike_data = np.frombuffer(data_string, dtype='uint8')
+         return spike_data
+
+     def numeric(self, line):
+         return int(line.decode(encoding='UTF-8').split(" ")[1])
+
+     def datestring(self, line):
+         return line.decode(encoding='UTF-8').split(" ")[2:]
+
+     def timestring(self, line):
+         return line.decode(encoding='UTF-8').split(" ")[1]
+
+
+ def _session_start(trial_date, trial_time):
+     day, month, year = trial_date
+     month = datetime.strptime(str(month), '%b').month
+     hour, minute, second = trial_time.split(':')
+     date_time = datetime(int(year), int(month), int(day),
+                          int(hour), int(minute), int(second))
+     return date_time
+
+
+ def _extract_spike_timestamps(raw_data):
+     bytes_per_timestamp = raw_data["bytes_per_timestamp"]
+     num_spikes = raw_data["num_spikes"]
+     num_channels = raw_data["num_chans"]
+     spike_data = raw_data["spike_data"]
+     timebase = raw_data["timebase"]
+
+     big_endian_vector = 256 ** np.arange(bytes_per_timestamp - 1, -1, -1)
+
+     t_start_indices = _compute_time_start_indexes(raw_data)
+
+     t_indices = t_start_indices
+
+     for chan in np.arange(1, num_channels):
+         t_indices = np.hstack((t_indices, t_start_indices + chan))
+
+     # acquiring the time bytes
+     timestamps = spike_data[t_indices].reshape(num_spikes, bytes_per_timestamp)
+     # converting from bytes to float values
+     timestamps = np.sum(np.multiply(timestamps, big_endian_vector), axis=1) / timebase
+
+     return xr.DataArray(timestamps.flatten(),
+                         dims=['spikes'],
+                         attrs={'type': 'spike_times',
+                                'units': 's',
+                                'dimensionality': 'time'})
+
+
+ def _extract_spike_waveforms(raw_data):
+     bytes_per_timestamp = raw_data["bytes_per_timestamp"]
+     samples_per_spike = raw_data["samples_per_spike"]
+     num_spikes = raw_data["num_spikes"]
+     num_channels = raw_data["num_chans"]
+     spike_data = raw_data["spike_data"]
+
+     t_start_indices = _compute_time_start_indexes(raw_data)
+
+     # the raw data is formatted in the order: t,ch1,t,ch2,t,ch3,t,ch4
+
+     little_endian_matrix = _compute_little_endian_matrix(raw_data)
+
+     channels = []
+
+     for chan in range(num_channels):  # only really care about the first time that gets written
+         chan_start_indices = t_start_indices + bytes_per_timestamp * (chan + 1) + samples_per_spike * chan
+         # print(chan_start_indices[0:100])
+         for spike_sample in np.arange(1, samples_per_spike):
+             chan_start_indices = np.hstack((chan_start_indices, t_start_indices + chan * samples_per_spike + bytes_per_timestamp * (chan + 1) + spike_sample))
+
+         # acquiring the channel bytes
+         bts = spike_data[chan_start_indices]
+         bts = bts.reshape(num_spikes, samples_per_spike).astype('int8')
+         channels.append(bts)
+
+         channels[chan][np.where(channels[chan][:][:] > 127)] -= 256
+         channels[chan] = np.multiply(channels[chan][:][:],
+                                      little_endian_matrix,
+                                      dtype=np.float16)
+
+     waveform_data = np.stack(channels, axis=1)
+
+     return xr.DataArray(waveform_data,
+                         dims=['spikes', 'channels', 'samples'],
+                         attrs={'type': 'spike_waveforms',
+                                'units': 'uV',
+                                'dimensionality': 'voltage'})
+
+
+ def _compute_time_start_indexes(raw_data):
+     """Computes the index of the first byte of each timestamp in the raw
+     spike data.
+     """
+     bytes_per_timestamp = raw_data["bytes_per_timestamp"]
+     bytes_per_sample = raw_data["bytes_per_sample"]
+     samples_per_spike = raw_data["samples_per_spike"]
+     num_spikes = raw_data["num_spikes"]
+     step = (bytes_per_sample * samples_per_spike * 4 + bytes_per_timestamp * 4)
+     t_start_indices = np.arange(0, step * num_spikes, step).astype(int)
+     t_start_indices = t_start_indices.reshape(num_spikes, 1)
+     return t_start_indices
+
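+ # Worked example (editor's note; the figures below are typical Axona tetrode
+ # parameters used purely for illustration, not read from any particular file):
+ # with bytes_per_timestamp = 4, bytes_per_sample = 1 and
+ # samples_per_spike = 50, each spike record packs one (timestamp + waveform)
+ # pair per channel for all 4 channels, so
+ #     step = 1 * 50 * 4 + 4 * 4 = 216 bytes,
+ # and the function returns offsets [0, 216, 432, ...] -- the position of the
+ # first timestamp byte of every spike record in the buffer.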
+
+
+ def _compute_little_endian_matrix(raw_data):
+     bps = raw_data["bytes_per_sample"]
+     sps = raw_data["samples_per_spike"]
+     little_endian_matrix = np.arange(0, bps).reshape(bps, 1)
+     little_endian_matrix = 256 ** np.tile(little_endian_matrix, (1, sps))
+     return little_endian_matrix
+
+
+ # =============================================================================
+ # set (settings) Helpers
+ # =============================================================================
+ # TODO: Add a function to read the settings file
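
For orientation, here is a minimal usage sketch of the AxonaReadAdapter defined in this file. It is an editor's illustration, not part of the packaged wheel: the directory name 'sessions/' is hypothetical, and the dotted import path simply mirrors the wheel layout listed above (adjust it to wherever the module is importable in practice).

    import xarray as xr
    from eegdash.SignalStore.signalstore.adapters.read_adapters.recording_acquisitions.axona.axona_read_adapter import AxonaReadAdapter

    adapter = AxonaReadAdapter('sessions/')  # directory of Axona session files
    for session_records in adapter:  # inherited __iter__/__next__ drive self.read()
        # each session yields a list of xarray DataArrays (position, LFP, spikes)
        for rec in session_records:
            if isinstance(rec, xr.DataArray):
                print(rec.name, rec.dims, rec.attrs.get('type'))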