vscope 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vscope/localmotion.py ADDED
@@ -0,0 +1,407 @@
1
+ #!/usr/bin/python3
2
+
3
+ import vscope.rois
4
+ import numpy as np
5
+ import spectralign.funcs as swiftir
6
+ import scipy.signal
7
+
8
def upsampledImages(x, cam):
    '''Load camera CAM of recording X as square-pixel float32 frames.

    The QuantEM baseline of 1000 counts is subtracted from every pixel.
    When a frame has fewer rows than columns, each row is repeated
    (columns // rows) times so that the pixels become square.

    Returns (frames, SY): FRAMES is a T x (Y*SY) x X float32 array and SY
    is the integer Y expansion factor (1 when no expansion was applied).
    '''
    raw = x.ccd[cam].astype(float) - 1000
    nframes, nrows, ncols = raw.shape
    factor = ncols // nrows if nrows < ncols else 1
    # Duplicate each row `factor` times in place (row order r0,r0,..,r1,r1,..).
    stretched = np.repeat(raw, factor, axis=1)
    return stretched.astype(np.float32), factor
18
+
19
+
20
def stationaryImage(frms, skipstart=30, skipend=10):
    '''Return the temporal mean image of a stack of frames.

    FRMS is indexed by time along its first axis. The first SKIPSTART and
    the last SKIPEND frames are excluded from the average; SKIPEND=0 keeps
    every frame through the end of the stack.
    '''
    # An explicit stop index is needed: frms[a:-0] is an empty slice, which
    # made skipend=0 return an all-NaN image.
    stop = len(frms) - skipend
    return frms[skipstart:stop].mean(0)
23
+
24
+
25
def blockcorners(ff, r):
    '''Remove the corners from a 2d fft to block low frequencies.

    Returns a copy of FF (indexed as (y, x, ...)) with the four R x R
    corner regions set to zero; any trailing axes are zeroed along with
    each corner. R <= 0 blocks nothing. FF itself is not modified.
    '''
    out = 0 + ff  # copy, so the caller's spectrum is untouched
    if r <= 0:
        # Guard: out[-0:] would select the whole axis and wipe everything.
        return out
    out[:r, :r] = 0
    out[:r, -r:] = 0
    out[-r:, :r] = 0
    out[-r:, -r:] = 0
    return out
33
+
34
def blockedges(ff, r):
    '''Remove the edges from a 2d fft to block low frequencies.

    Returns a copy of FF (indexed as (y, x, ...)) with the first and last
    R rows and the first and last R columns set to zero; trailing axes are
    zeroed along with them. R <= 0 blocks nothing. FF is not modified.
    '''
    out = 0 + ff  # copy, so the caller's spectrum is untouched
    if r <= 0:
        # Guard: out[-0:] would select the whole axis and wipe everything.
        return out
    out[:r] = 0
    out[-r:] = 0
    out[:, :r] = 0
    out[:, -r:] = 0
    return out
42
+
43
def blockhighfreq(ff, SY):
    '''Zero the high Y frequencies of the FFT image FF (shape Y x X x C).

    After the frames were upsampled by an integer factor SY along Y, the
    spectrum above the original limit is an upsampling artifact and must
    not take part in alignment. Returns a modified copy; FF is untouched.
    '''
    nrows, _ncols, _nchan = ff.shape
    # Same bounds as the expression `Y//SY//2 : -Y//SY//2`; the upper bound
    # is (-Y)//SY//2 because unary minus binds tighter than //.
    first = nrows // SY // 2
    last = (-nrows) // SY // 2
    # NOTE(review): the zeroed band [first, last) contains row `first` but
    # not its mirror row nrows-first, so it is not symmetric by one row —
    # confirm this is intended.
    result = ff + 0
    result[first:last, :, :] = 0
    return result
52
+
53
def globalAffines(frms, stat=None, lp=None, blockhigh=None, siz=(320, 320),
                  rad=10):
    '''Calculate affine transforms to align a stack of frames with a
    stationary image.

    Four windows of size SIZ are centred at 40%/60% of the image width and
    height. Each window is cut from STAT, apodized, Fourier transformed and
    high-pass filtered (blockedges, r=4). The matching window of every frame
    is then registered against it with swiftir.swim. The four matched
    positions per frame are fitted with swiftir.mirAffine.

    Parameters:
      FRMS - T x Y x X stack of frames
      STAT - stationary image; computed with stationaryImage(FRMS) if None
      LP - if given and > 1, the matched positions are smoothed over time
           with a first-order Butterworth low-pass (normalized cutoff 1/LP)
           run forward and backward (filtfilt). Values <= 1 skip the filter,
           as in localSwim (butter rejects a cutoff >= 1).
      BLOCKHIGH - Y expansion factor; when given, blockhighfreq removes the
                  upsampling-artifact frequencies from the reference spectra
      SIZ - window size passed to swiftir.extractStraightWindow
      RAD - passed as rad= to swiftir.swim (default 10, as before)

    Returns a list of T affine matrices (element [0] of swiftir.mirAffine),
    fitted between the window centres in STAT and their matched positions
    in each frame.
    '''
    if stat is None:
        stat = stationaryImage(frms)
    T, Y, X = frms.shape
    # Window centres: row 0 holds x coordinates, row 1 holds y coordinates.
    pa = np.array([[.4*X, .4*X, .6*X, .6*X],
                   [.4*Y, .6*Y, .4*Y, .6*Y]])
    ppb = []
    for p in pa.T:
        apostat = swiftir.apodize(swiftir.extractStraightWindow(stat, p, siz))
        fftstat = blockedges(swiftir.fft(apostat), 4)
        if blockhigh is not None:
            fftstat = blockhighfreq(fftstat, blockhigh)
        swim = np.array([
            swiftir.swim(fftstat,
                         swiftir.extractStraightWindow(frms[t], p, siz),
                         rad=rad)
            for t in range(T)])
        # Columns 0 and 1 of the swim output are the x and y displacements.
        ppb.append(p.reshape(2, 1) + swim[:, :2].T)
    ppb = np.array(ppb).transpose(2, 1, 0)  # T x 2 x 4
    if lp is not None and lp > 1:
        b, a = scipy.signal.butter(1, 1/lp)
        ppb = scipy.signal.filtfilt(b, a, ppb, axis=0)
    return [swiftir.mirAffine(pa, pb)[0] for pb in ppb]
87
+
88
+
89
def correctGlobal(frms, afms):
    '''Warp frames of FRMS onto the stationary image.

    AFMS must be the per-frame list returned by globalAffines. Frame t is
    resampled with swiftir.affineImage(AFMS[t], FRMS[t]); only the first
    len(AFMS) frames are used. Returns the warped frames as one array.
    '''
    warped = [swiftir.affineImage(afm, frms[t])
              for t, afm in enumerate(afms)]
    return np.array(warped)
95
+
96
+
97
def localSwim(stat, frms, p0, Q, SY, lp=None):
    '''Measure the per-frame shift of the window centred on P0.

    A window of size Q around P0 is cut from the stationary image STAT,
    apodized, Fourier transformed and high-pass filtered (blockedges, r=4).
    If SY is truthy, blockhighfreq also removes the Y frequencies that are
    upsampling artifacts. The window at P0 of every frame in the T-frame
    stack FRMS is then matched against it with swiftir.swim.

    If LP is given and exceeds 1, the shifts are smoothed over time by a
    first-order Butterworth low-pass (normalized cutoff 1/LP) applied with
    filtfilt.

    Returns (dxy, sxy, snr): the Tx2 shifts (swim columns 0-1), swim
    columns 2-3 as a Tx2 array, and the T values of swim column 4 (SNR).
    '''
    nframes, _, _ = frms.shape
    reference = swiftir.fft(
        swiftir.apodize(swiftir.extractStraightWindow(stat, p0, Q)))
    reference = blockedges(reference, 4)
    if SY:
        reference = blockhighfreq(reference, SY)
    results = np.array([
        swiftir.swim(reference, swiftir.extractStraightWindow(frms[t], p0, Q))
        for t in range(nframes)])
    shifts = results[:, :2]
    if lp is not None and lp > 1:
        num, den = scipy.signal.butter(1, 1/lp)
        shifts = scipy.signal.filtfilt(num, den, shifts, axis=0)
    return shifts, results[:, 2:4], results[:, 4]
118
+
119
+
120
def localShift(clndfrms, x, roiid, cam, stat, lp=None, blockhigh=None):
    '''Calculate xy shift near a cell of interest.

    CLNDFRMS is the globally corrected, Y-upsampled T x Y x X stack and
    STAT its stationary image. The ROI's pixel coordinates come from
    vscope.rois.pixelcoords(X, ROIID, CAM). Its y coordinate is multiplied
    by BLOCKHIGH (the Y expansion factor; 4 when not given, as before) to
    reach the upsampled grid. A fixed 128-pixel tile centred on the ROI is
    clamped to lie inside the image and passed to localSwim with LP and
    BLOCKHIGH.

    Returns (dx,dy) [Tx2], (sx,sy) [Tx2], snr [T], p0, Q,
    where p0 is the center of the image tile and Q its width.
    '''
    xx, yy = vscope.rois.pixelcoords(x, roiid, cam)
    Q = 128  # Fixed tile size to avoid surprises
    M = Q//2
    T, Y, X = clndfrms.shape
    # Previously hard-coded to 4; use the actual expansion factor if known.
    yscale = blockhigh if blockhigh else 4
    x0 = xx.mean()
    if x0 < M:
        x0 = M
    elif x0 > X-M:
        x0 = X-M
    y0 = yy.mean()*yscale
    if y0 < M:
        y0 = M
    elif y0 > Y-M:
        y0 = Y-M
    p0 = np.round(np.array([x0, y0]))
    dxy, sxy, snr = localSwim(stat, clndfrms, p0, Q, blockhigh, lp)
    return dxy, sxy, snr, p0, Q
143
+
144
+
145
def correctLocal(clndfrms, dxy, p0, Q, SY=4):
    '''To be used after localShift, returns the contents of the tile
    mapped onto the stationary image.

    For frame t, the tile of size Q centred at P0 + DXY[t] is cut from
    CLNDFRMS[t]. Q is either a single int (square tile) or a pair (qx, qy).
    Every SY consecutive rows of the tile are then averaged to undo the Y
    upsampling (SY defaults to 4, the previously hard-coded factor).

    Returns a T x (qy//SY) x qx array.
    '''
    T = len(clndfrms)
    tiles = np.array([swiftir.extractStraightWindow(clndfrms[t], p0 + dxy[t], Q)
                      for t in range(T)])
    try:
        qx = Q[0]
        qy = Q[1]
    except (TypeError, IndexError):
        # Q is a scalar (int or numpy scalar): square tile.
        qx = qy = Q
    return tiles.reshape(T, qy//SY, SY, qx).mean(2)
157
+
158
+
159
+
160
def patchwork(stat, frms, width=128, step=32, SY=4, sigma=64):
    '''Calculates local motion in a grid of patches.

    Parameters are:
      STAT - Stationary image (YxX)
      FRMS - TxYxX stack of moving images
      WIDTH - Size of square patches in pixels
      STEP - Step between adjacent patches in pixels, usually smaller
             than WIDTH, so patches overlap
      SY - Known expansion factor in the Y dimension of the images. (This
           is passed to LOCALSWIM, which uses it to drop the Y frequencies
           that are upsampling artifacts.)
      SIGMA - Sigma (pixels) of the Gaussian used to interpolate between
              patches; patches with poor definition get low weight.

    Returns:
      DELTAXY - A function (x, y) -> DXY that takes scalar x- and
                y-coordinates (in the pixel grid of FRMS) and returns a Tx2
                array: the average of the patch shifts, weighted by WEIGHT
                times a Gaussian of the distance to each patch centre. The
                result is NaN if all of those weights are zero.
      DETAILS - A dictionary with additional information, containing:
        xx0: x-coordinates of centers of patches
        yy0: y-coordinates of centers of patches
        dxy: raw shift at centers of patches (YxXxTx2)
        snr: signal-to-noise for each patch (YxXxT)
        weight: interpolation weights for each patch (YxX):
                exp((mean snr - 15)/5) capped at 1, and 0 for patches
                whose largest shift exceeds 5 pixels
        dxymax: max. shift in each patch (YxX)
        width: from function call
        step: from function call
        SY: from function call
        sigma: from function call

    Raises:
      ValueError if the frames are too small to hold a single patch.
    '''
    T, Y, X = frms.shape
    xx0 = np.arange(width//2 + step//2, X - width//2, step)
    yy0 = np.arange(width//2 + step//2, Y - width//2, step)
    NX = len(xx0)
    NY = len(yy0)
    if NX == 0 or NY == 0:
        raise ValueError(f"patchwork: {Y}x{X} images are too small for "
                         f"{width}-pixel patches with step {step}")

    dxy = []
    snr = []
    for ny, y0 in enumerate(yy0):
        print(f"LOCALMOTION Patchwork row {ny} of {NY}")
        rowdxy = []
        rowsnr = []
        for x0 in xx0:
            dxy1, _sxy1, snr1 = localSwim(stat, frms, np.array([x0, y0]),
                                          width, SY)
            rowdxy.append(dxy1)
            rowsnr.append(snr1)
        dxy.append(rowdxy)
        snr.append(rowsnr)

    dxy = np.array(dxy)  # YxXxTx2
    snr = np.array(snr)  # YxXxT

    snr0 = np.mean(snr, -1)  # YxX
    # Soft confidence: 1 for mean SNR >= 15, decaying exponentially below.
    weight = np.minimum(np.exp(-(15-snr0)/5), 1)
    dxymax = np.sqrt(np.max(dxy[..., 0]**2 + dxy[..., 1]**2, axis=-1))
    # Distrust patches whose shift ever exceeds 5 pixels.
    weight[dxymax > 5] = 0

    details = {
        "xx0": xx0,
        "yy0": yy0,
        "dxy": dxy,
        "snr": snr,
        "weight": weight,
        "dxymax": dxymax,
        "width": width,
        "step": step,
        "SY": SY,
        "sigma": sigma}

    def dist(x, y):
        # Gaussian of the distance from (x, y) to every patch centre (NYxNX).
        d2x = (x - xx0)**2
        d2y = (y - yy0)**2
        return np.exp(-.5*(d2x[None, :] + d2y[:, None])/sigma**2)

    def deltaxy(x, y):
        ww = dist(x, y) * weight
        norm = np.sum(ww)
        return np.sum(ww[:, :, None, None] * dxy, (0, 1)) / norm

    return deltaxy, details
238
+
239
+
240
class LocalCorrector:
    '''LocalCorrector - Local motion correction for one camera of a vscope
    recording.

    Construction runs the whole estimation pipeline (see __init__); the
    *Signal methods then extract ROI traces at the various stages of
    correction.
    '''

    def __init__(self, x, cam, global_lp=5, local_lp=5, twice=True):
        '''LocalCorrector - Local motion correction for vscope
        LocalCorrector(x, cam), where X is from LOADER.LOAD and CAM is
        a camera ID, constructs an object that can perform local motion
        correction for the given recording.
        Optional arguments are:
          GLOBAL_LP: tau for a low-pass filter to be applied to global
                     motion, in frames
          LOCAL_LP: tau for a low-pass filter for local motion, ditto.
                    None or values <= 1 disable that filter in
                    PATCHCORRECTEDSIGNAL.
          TWICE: Perform global estimation twice to improve accuracy.
        This uses SWIFTIR for global motion estimation, using four large
        overlapping patches to estimate an affine transformation, then
        PATCHWORK to further refine local motion.

        Attributes set here:
          FRMS, SY: Y-upsampled, baseline-subtracted frames and the Y
                    expansion factor (from upsampledImages)
          STAT: stationary image of the final aligned stack
          AFMS, AFMS2: first- and second-pass global affines (AFMS2 is
                       None when TWICE is false)
          CLND: aligned, upsampled stack
          C0: aligned stack at native Y resolution
          TT: midpoints of the (start, end) times from vscope.ccdtime
          PATCHDXY, PATCHDETAILS: results of PATCHWORK on CLND
        '''
        self.x = x
        self.cam = cam
        self.frms, self.SY = upsampledImages(x, cam)
        print(f"LOCALMOTION Calculating stationary image for {cam}")
        self.stat = stationaryImage(self.frms)
        self.global_lp = global_lp  # low pass timescale for global affine
        self.local_lp = local_lp  # low pass timescale for local shift
        ts, te, _ok = vscope.ccdtime(x, cam)
        self.tt = (ts+te)/2
        print(f"LOCALMOTION Calculating global affines for {cam}")
        self.afms = globalAffines(self.frms, self.stat,
                                  lp=self.global_lp,
                                  blockhigh=self.SY)
        print(f"LOCALMOTION Aligning image stack for {cam}")
        self.clnd = correctGlobal(self.frms, self.afms)
        print(f"LOCALMOTION Recalculating stationary image for {cam}")
        self.stat = stationaryImage(self.clnd)

        # Always defined, so patchCorrectedSignal works when twice=False.
        self.afms2 = None
        if twice:
            print(f"LOCALMOTION Final recalc of global affines for {cam}")
            self.afms2 = globalAffines(self.clnd, self.stat,
                                       lp=self.global_lp,
                                       blockhigh=self.SY)
            print(f"LOCALMOTION Final aligning image stack for {cam}")
            self.clnd = correctGlobal(self.clnd, self.afms2)
            print(f"LOCALMOTION Final recalc of stationary image for {cam}")
            self.stat = stationaryImage(self.clnd)

        # Undo the Y upsampling with the actual factor (was hard-coded 4).
        self.c0 = self._shrinkY(self.clnd, self.SY)
        self.patchdxy, self.patchdetails = patchwork(self.stat, self.clnd,
                                                     SY=self.SY)

    @staticmethod
    def _shrinkY(frames, SY):
        '''Average every SY consecutive rows of a TxYxX stack.'''
        T, Y, X = frames.shape
        return frames.reshape(T, Y//SY, SY, X).mean(2)

    @staticmethod
    def _normalized(sig):
        '''Convert a trace to %dF/F relative to its own mean.'''
        return 100 * (sig/sig.mean() - 1)

    def timestamps(self):
        '''TIMESTAMPS - Timestamps of each of the frames'''
        return self.tt

    def rawSignal(self, roiid, normalize=True):
        '''RAWSIGNAL - Optionally normalized raw signal from given ROI
        yy = RAWSIGNAL(roiid) returns the %dF/F signal for the given ROI.
        yy = RAWSIGNAL(roiid, False) returns the raw signal (after
        subtracting 1000 for the QuantEM camera baseline).'''
        xx, yy = vscope.rois.pixelcoords(self.x, roiid, self.cam)
        sig = self.x.ccd[self.cam][:, yy, xx].mean(-1) - 1000
        return self._normalized(sig) if normalize else sig

    def globallyCorrectedSignal(self, roiid, normalize=True):
        '''GLOBALLYCORRECTEDSIGNAL - Signal after global motion correction.
        yy = GLOBALLYCORRECTEDSIGNAL(roiid) returns the %dF/F signal for
        the given ROI after global motion correction (only). The data in
        C0 already has the 1000 baseline removed.'''
        xx, yy = vscope.rois.pixelcoords(self.x, roiid, self.cam)
        sig = self.c0[:, yy, xx].mean(-1)
        return self._normalized(sig) if normalize else sig

    def localShift(self, roiid):
        '''LOCALSHIFT - Calculate local shift at location of ROI
        dxy, p0, Q = LOCALSHIFT(roiid) returns parameters for local
        shift (to be applied on top of the global correction) at the
        location of a given cell.
        This function is deprecated. Use the PATCHWORK-based approach
        instead.'''
        dxy, sxy, snr, p0, Q = localShift(self.clnd, self.x, roiid,
                                          self.cam, self.stat,
                                          lp=self.local_lp,
                                          blockhigh=self.SY)
        return dxy, p0, Q

    def locallyCorrectedSignal(self, roiid, normalize=True):
        '''LOCALLYCORRECTEDSIGNAL - Signal after local motion correction.
        yy = LOCALLYCORRECTEDSIGNAL(roiid) returns the %dF/F signal for
        the given ROI after local motion correction.
        This function is deprecated. Use the PATCHWORK-based approach
        instead.'''
        dxy, p0, Q = self.localShift(roiid)
        img = correctLocal(self.clnd, dxy, p0, Q)
        xx, yy = vscope.rois.pixelcoords(self.x, roiid, self.cam)
        # NOTE(review): the /4 and Q//8 below match correctLocal's default
        # 4x row binning, i.e. this deprecated path assumes SY == 4.
        y0 = int(p0[1]/4)
        x0 = int(p0[0])
        sig = img[:, Q//8+yy-y0, Q//2+xx-x0].mean(-1)
        return self._normalized(sig) if normalize else sig

    def patchCorrectedSignal(self, roiid, normalize=True, return_dxy=False,
                             data=None):
        '''PATCHCORRECTEDSIGNAL - Signal after local motion correction.
        yy = PATCHCORRECTEDSIGNAL(roiid) returns the %dF/F signal for
        the given ROI after patchwork-based motion correction.
        yy, dxy, dxy1 = PATCHCORRECTEDSIGNAL(roiid, return_dxy=True) also
        returns:
          DXY: the total shift at the center of the ROI as a Tx2 array,
               in native (non-upsampled) pixels
          DXY1: only the local shift at the center of the ROI as a Tx2
                array, in upsampled pixels
        Optional argument DATA overrides the data from the original vscope
        structure. It must be TxYxX (frame, row, column), indexed like the
        raw CCD stack; no baseline is subtracted from it.
        '''
        T = len(self.afms)
        afms2 = getattr(self, 'afms2', None)
        xx, yy = vscope.rois.pixelcoords(self.x, roiid, self.cam)
        # ROI centre in the upsampled coordinates used for estimation.
        x0 = np.mean(xx)
        y0 = np.mean(yy)*self.SY
        dxy1 = self.patchdxy(x0, y0)
        if self.local_lp is not None and self.local_lp > 1:
            b, a = scipy.signal.butter(1, 1/self.local_lp)
            dxy1 = scipy.signal.filtfilt(b, a, dxy1, axis=0)
        xx0 = []
        yy0 = []
        for t in range(T):
            # NOTE(review): clnd was built as affineImage(afms2[t],
            # affineImage(afms[t], frms[t])), so mapping a final-stationary
            # point back to the raw frame would seem to need afms2 applied
            # first, then afms (with dxy1, measured on the final stack,
            # added before both). The original order is kept here; confirm
            # against swiftir's conventions.
            pmov = swiftir.stationaryToMoving(self.afms[t], [x0, y0])
            if afms2 is not None:
                pmov = swiftir.stationaryToMoving(afms2[t], pmov)
            xx0.append(pmov[0] + dxy1[t, 0])
            yy0.append(pmov[1] + dxy1[t, 1])

        # Total displacement of the ROI centre, back in native pixels.
        dxx = np.array(xx0) - x0
        dyy = (np.array(yy0) - y0)/self.SY
        dxy = np.array((dxx, dyy)).T

        # Tile just large enough to contain the ROI (about 2 px margin).
        xmax = np.max(xx)
        xmin = np.min(xx)
        ymax = np.max(yy)
        ymin = np.min(yy)
        p0 = np.array([int((xmax+xmin)/2+.5), int((ymax+ymin)/2+.5)])
        qx = int((xmax - xmin) + 4)//2
        qy = int((ymax - ymin) + 4)//2

        def frame(t):
            if data is None:
                return self.x.ccd[self.cam][t].astype(np.float32)-1000
            return data[t].astype(np.float32)

        imgs = np.array([swiftir.extractStraightWindow(frame(t),
                                                       p0 + dxy[t],
                                                       (qx*2, qy*2))
                         for t in range(T)])
        psig = imgs[:, yy-p0[1]+qy, xx-p0[0]+qx].mean(-1)
        if normalize:
            psig = self._normalized(psig)
        if return_dxy:
            return psig, dxy, dxy1
        return psig

    def correctedFrames(self):
        '''CORRECTEDFRAMES - CCD frames after global motion correction
        frms = CORRECTEDFRAMES() returns the CCD frames after global
        motion correction as a TxYxX array. The vertical upsampling (by
        the factor SY) used for parameter estimation is undone first, so
        Y=128 if 4x binning was used in recording. The 1000 camera
        baseline was already subtracted when the frames were loaded
        (upsampledImages); no further subtraction is applied here.'''
        return self.c0