StraitFlux 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
StraitFlux/__init__.py ADDED
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,301 @@
1
+ import xarray as xa
2
+ import numpy as np
3
+ try:
4
+ import matplotlib.pyplot as plt
5
+ except ImportError:
6
+ print('skipping matplotlib')
7
+ from tqdm import tqdm
8
+ import sys
9
+ try:
10
+ from dask.diagnostics import ProgressBar
11
+ except ImportError:
12
+ print('skipping dask import')
13
+ from StraitFlux.indices import check_availability_indices, prepare_indices
14
+
15
def check_Arakawa(u_data, v_data, T_data, model):
    """Identify the Arakawa grid type from where the u, v and T points sit.

    The lat/lon of the u, v and T points are compared at a single reference
    cell: row ``int(len(T_data.y) / 2)``, column 0.

    Parameters
    ----------
    u_data, v_data, T_data : dataset-like
        Objects exposing ``lat`` and ``lon`` indexable as ``[row, col]`` whose
        items have ``.values``. ``T_data`` must also expose ``y``.
    model : str
        Model name. 'MPI-ESM1-2-LR' and 'MPI-ESM1-2-HR' are reported as
        Arakawa-C without inspecting any coordinates.

    Returns
    -------
    str
        'Arakawa-A', 'Arakawa-B', 'Arakawa-C' or 'Arakawa-E'. Arakawa-E is
        only a guess and is printed with a trailing '?'.

    Raises
    ------
    SystemExit
        With exit status 1 when the layout is not recognised.
    """
    u, v, t = u_data, v_data, T_data

    print('checking grid')
    if model in ['MPI-ESM1-2-LR', 'MPI-ESM1-2-HR']:
        grid = 'Arakawa-C'
        print(grid)
        return grid

    mid = int(len(t.y) / 2)
    u_lat, u_lon = u.lat[mid, 0].values, u.lon[mid, 0].values
    v_lat, v_lon = v.lat[mid, 0].values, v.lon[mid, 0].values
    t_lat, t_lon = t.lat[mid, 0].values, t.lon[mid, 0].values

    if u_lat == t_lat and v_lat != t_lat:
        # u shares T's latitude, v shares T's longitude: staggered C grid.
        if v_lon == t_lon and u_lon != t_lon:
            grid = 'Arakawa-C'
            print(grid)
            return grid
    elif u_lat == v_lat and u_lon == v_lon:
        # u and v are co-located; decide by their position relative to T.
        if u_lat != t_lat and u_lon != t_lon:
            grid = 'Arakawa-B'
            print(grid)
            return grid
        if u_lat == t_lat and u_lon == t_lon:
            grid = 'Arakawa-A'
            print(grid)
            return grid
        if u_lat == t_lat == v_lat and u_lon != t_lon != v_lon:
            grid = 'Arakawa-E'
            print(grid + '?')
            return grid

    print('grid not recognized, check manually')
    # Non-zero status: a bare sys.exit() would report success to the shell.
    sys.exit(1)
49
+
50
def transform_Arakawa(grid, mu, mv, deltaz, dzu3, dzv3, udata, vdata):
    """Put u/v velocities onto Arakawa-C cell faces.

    Arakawa-C input is returned unchanged. For Arakawa-B and Arakawa-A input,
    each velocity is converted to a transport (velocity x face width x layer
    thickness) and averaged with its neighbour by a 2-point rolling mean. It
    is then divided by the target face width and thickness to give a velocity
    again.

    Parameters
    ----------
    grid : str
        'Arakawa-A', 'Arakawa-B' or 'Arakawa-C'.
    mu, mv : xarray objects
        Mesh data for u and v points. ``mu.dyu`` and ``mv.dxv`` are used as
        face widths.
    deltaz : xarray.Dataset
        Layer thicknesses in ``thkcello`` with a ``lev`` dimension. Not
        accessed for Arakawa-C input.
    dzu3, dzv3 : xarray objects
        Layer thicknesses at u and v faces (presumably from calc_dz_faces).
        Returned unchanged.
    udata, vdata : xarray objects
        u and v velocity fields on the native grid.

    Returns
    -------
    tuple
        ``(udata, vdata, dzu3, dzv3, mu2, mv2)``. ``udata`` and ``vdata`` are
        the velocities on C-grid faces. ``mu2`` and ``mv2`` are the mesh
        objects: unchanged for C, rolling-averaged for B and A.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the three supported layouts.
    """
    if grid not in ('Arakawa-A', 'Arakawa-B', 'Arakawa-C'):
        raise ValueError(
            f'transform_Arakawa: unsupported grid {grid!r}; '
            "expected 'Arakawa-A', 'Arakawa-B' or 'Arakawa-C'"
        )

    if grid == 'Arakawa-C':
        mu2 = mu
        mv2 = mv
    elif grid == 'Arakawa-B':
        print('transforming Arakawa-B to Arakawa-C')
        deltaz2 = deltaz.thkcello
        # Total column depth; all-zero (land) columns become NaN.
        dzs = deltaz2.sum(dim='lev').where(deltaz2.sum(dim='lev') != 0)
        # 2x2 rolling mean of T-cell thicknesses (presumably the B-grid velocity points).
        dzuv = deltaz2.rolling(x=2, min_periods=1).mean().rolling(y=2, min_periods=1).mean()
        # Cap the cumulative depth at the local column depth, then difference
        # back to per-layer thicknesses.
        dzuv2 = dzuv.cumsum('lev').where(dzuv.cumsum('lev') <= dzs)
        dzuv2 = dzuv2.fillna(dzs)
        dzuv3 = dzuv2.copy(data=np.diff(dzuv2, axis=dzuv2.get_axis_num("lev"), prepend=0))
        mu2 = mu.rolling(x=2, min_periods=1).mean().rolling(y=2, min_periods=1).mean()
        mv2 = mv.rolling(y=2, min_periods=1).mean().rolling(x=2, min_periods=1).mean()
        udata2 = (udata * mu2.dyu.values * dzuv3.values).rolling(y=2, min_periods=1).mean() / (mu.dyu.values * dzu3.values)
        vdata2 = (vdata * mv2.dxv.values * dzuv3.values).rolling(x=2, min_periods=1).mean() / (mv.dxv.values * dzv3.values)
        # Drop implausible magnitudes (|value| >= 2, presumably m/s) and NaNs.
        udata2 = udata2.where(udata2 > -2).where(udata2 < 2).fillna(0)
        vdata2 = vdata2.where(vdata2 > -2).where(vdata2 < 2).fillna(0)
        # x/x is 1 where the input is finite and non-zero and NaN elsewhere.
        # Each component is re-masked with its OWN input; v used to get u's mask.
        u_mask = (udata / udata).values
        v_mask = (vdata / vdata).values
        udata = udata2 * u_mask
        vdata = vdata2 * v_mask
    else:  # 'Arakawa-A'
        print('transforming Arakawa-A to Arakawa-C')
        deltaz2 = deltaz.thkcello
        mu2 = mu.rolling(x=2, min_periods=1).mean()
        mv2 = mv.rolling(y=2, min_periods=1).mean()
        print('equation to get u/v at T faces')
        udata2 = (udata * mu.dyu.values * deltaz2.values).rolling(x=2, min_periods=1).mean() / (mu2.dyu.values * dzu3.values)
        vdata2 = (vdata * mv.dxv.values * deltaz2.values).rolling(y=2, min_periods=1).mean() / (mv2.dxv.values * dzv3.values)
        udata = udata2.where(udata2 > -1000).where(udata2 < 1000).fillna(0)
        vdata = vdata2.where(vdata2 > -1000).where(vdata2 < 1000).fillna(0)

    return udata, vdata, dzu3, dzv3, mu2, mv2
85
+
86
def _has_flow(da, point):
    """Return True if ``da`` is strictly positive or negative at ``point``.

    ``point`` holds (x, y) indices stored with a +1 offset, so ``da`` is read
    at ``[y - 1, x - 1]``. NaN compares False both ways and so counts as no flow.
    """
    val = da[int(point[1] - 1), int(point[0] - 1)].values
    return bool(val > 0 or val < 0)


def _report_endpoint(da, point, neighbour, label):
    """Print a diagnostic for the ``label`` ('first'/'last') end of a section.

    This only prints; nothing is actually dropped from the section here.
    """
    if not _has_flow(da, point):
        print('line good')
    elif _has_flow(da, neighbour):
        print('!!!ATTENTION!!!: ' + label + ' point water, recheck indices line!')
    else:
        print('dropping ' + label + ' point...')


def check_indices(indices, out_u, out_v, t, u, v, strait, model, path_save):
    """Sanity-check the end points of a section line and save a diagnostic plot.

    The first and last section points, and their neighbours, are tested for
    non-zero velocity. The result is printed. Nothing is returned and the
    inputs are not modified.

    Parameters
    ----------
    indices : dataset-like
        ``indices.indices`` holds one row per section point. The non-zero
        entries of a row are its indices (offset by +1). When the first two
        entries are zero the point is checked against ``v``, otherwise against
        ``u``. For the first point, zeros in entries 2 and 3 select ``u``.
    out_u, out_v :
        Ignored; recomputed from ``indices`` via ``prepare_indices``.
    t :
        Unused; kept for signature compatibility.
    u, v : xarray.Dataset
        Velocity datasets with ``uo`` / ``vo`` indexed as ``[y, x]``.
    strait, model : str
        Used in the plot title and the output file name.
    path_save : str
        Prefix of the saved PNG:
        ``path_save + strait + '_' + model + '_indices_check.png``.
    """
    idx = indices.indices
    lp = idx[-1][idx[-1] != 0].values
    slp = idx[-2][idx[-2] != 0].values
    fp = idx[0][idx[0] != 0].values
    sfp = idx[1][idx[1] != 0].values

    # last point:
    last_var = v.vo if (idx[-1][0] == 0 and idx[-1][1] == 0) else u.uo
    _report_endpoint(last_var, lp, slp, 'last')

    # first point:
    first_var = u.uo if (idx[0][2] == 0 and idx[0][3] == 0) else v.vo
    _report_endpoint(first_var, fp, sfp, 'first')

    out_u, out_v, _ = prepare_indices(indices)

    min_x = np.nanmin((min(out_u[:, 0], default=np.nan), min(out_v[:, 0], default=np.nan)))
    max_x = np.nanmax((max(out_u[:, 0], default=np.nan), max(out_v[:, 0], default=np.nan)))
    min_y = np.nanmin((min(out_u[:, 1], default=np.nan), min(out_v[:, 1], default=np.nan)))
    max_y = np.nanmax((max(out_u[:, 1], default=np.nan), max(out_v[:, 1], default=np.nan)))

    # An x index of -1 is clamped to 0 and the window widened by one column.
    if min_x == -1:
        min_x = 0
        max_x = max_x + 1

    # Only u is plotted, so only u needs to be windowed and loaded.
    u = u.sel(x=slice(int(min_x) - 1, int(max_x) + 1), y=slice(int(min_y) - 1, int(max_y) + 1)).load()
    try:
        plt.title(model + '_' + strait, fontsize=14)
        u.uo.plot()
        plt.scatter(out_v[:, 0], out_v[:, 1] + 0.5, marker='_', c='r', s=200)
        plt.scatter(out_u[:, 0] + 0.5, out_u[:, 1], marker='|', c='r', s=200)
        plt.ylabel('y', fontsize=14)
        plt.xlabel('x', fontsize=14)
        plt.savefig(path_save + strait + '_' + model + '_indices_check.png')
        plt.close()
    except NameError:
        # matplotlib failed to import at module load, so ``plt`` is undefined.
        print('skipping Plot')
156
+
157
def interp_TS(ds, d):
    """Apply a two-point running mean to ``ds`` along dimension ``d``.

    Uses ``rolling`` with a window of 2 and ``min_periods=1``, so each value
    is averaged with its predecessor along ``d``. The first element, which
    has no predecessor, is kept as is.
    """
    window = {d: 2}
    smoother = ds.rolling(window, min_periods=1)
    return smoother.mean()
159
+
160
def _flip_y(ds):
    """Return a copy of ``ds`` with its data reversed along ``y``.

    The ``y`` labels are reassigned in descending order and the result is
    sorted by them. The copy therefore has ascending labels 0..ny-1 with the
    data order reversed. ``ds`` itself is not modified.
    """
    return ds.assign_coords(y=np.arange(len(ds.y) - 1, -1, -1)).sortby('y')


def _zero_thkcello_like(deltaz, shape):
    """Return a Dataset whose ``thkcello`` is zeros on the (time,) lev, y, x coords of ``deltaz``."""
    if 'time' in deltaz.dims:
        dims = ('time', 'lev', 'y', 'x')
    else:
        dims = ('lev', 'y', 'x')
    coords = {d: (d, deltaz[d].data) for d in dims}
    return xa.Dataset({'thkcello': (dims, np.zeros(shape))}, coords=coords)


def calc_dz_faces(deltaz, grid, model, path_mesh):
    """Compute layer thicknesses on u and v cell faces from T-cell thicknesses.

    Arakawa-C: a u face between T columns (i, j) and (i, j+1) takes the
    profile of whichever column has the smaller summed thickness. v faces do
    the same with (i, j) and (i+1, j). With a time axis, the choice is made
    per time step, and when the second column is shallower only its last
    valid level is taken from it.

    Arakawa-B / Arakawa-A: each face takes the profile of the shallowest of
    the 2x2 neighbouring columns. u and v get identical thicknesses, because
    check_Arakawa only reports these grids when u and v are co-located.

    In all cases the last x column and the last y row are copied from the
    T-cell thicknesses.

    Parameters
    ----------
    deltaz : xarray.Dataset
        Must contain ``thkcello`` with dims ('lev', 'y', 'x') or
        ('time', 'lev', 'y', 'x'). Not modified.
    grid : str
        'Arakawa-C', 'Arakawa-B' or 'Arakawa-A'.
    model : str
        For 'MPI-ESM1-2-LR'/'MPI-ESM1-2-HR' the y axis is flipped before the
        computation and flipped back afterwards.
    path_mesh :
        Unused.

    Returns
    -------
    (xarray.DataArray, xarray.DataArray)
        ``thkcello`` at u faces and at v faces.
    """
    flip = model in ['MPI-ESM1-2-LR', 'MPI-ESM1-2-HR']
    if flip:
        print('swap')
        # Flip a copy; assigning deltaz['y'] would relabel the caller's dataset.
        deltaz = _flip_y(deltaz)
    try:
        with ProgressBar():
            zv = deltaz.thkcello.values
    except NameError:
        # dask's ProgressBar failed to import at module load.
        zv = deltaz.thkcello.values
    z0u = np.zeros(np.shape(zv))
    z0v = np.zeros(np.shape(zv))
    print('calc dz at cell faces')
    deltazu = _zero_thkcello_like(deltaz, np.shape(zv))
    deltazv = _zero_thkcello_like(deltaz, np.shape(zv))
    has_time = 'time' in deltaz.dims
    ny = len(deltaz.y)
    nx = len(deltaz.x)

    if grid == 'Arakawa-C':
        if has_time:
            # u faces: between T columns (i, j) and (i, j+1).
            for i in tqdm(range(ny - 1)):
                for j in range(nx - 1):
                    # Per time step: 0 or 1, whichever column has the smaller total thickness.
                    p = np.nansum(zv[:, :, i, j:j + 2], axis=1).argmin(axis=1)
                    for k in range(len(p)):
                        if p[k] == 0:
                            z0u[k, :, i, j] = zv[k, :, i, j]
                        if p[k] == 1:
                            # Level before the first NaN of the shallower column
                            # (-1 if it has no NaN or its top level is NaN).
                            last = np.isnan(zv[k, :, i, j + 1]).argmax(axis=0) - 1
                            z0u[k, :last, i, j] = zv[k, :last, i, j]
                            z0u[k, last, i, j] = zv[k, last, i, j + 1]
                            if last >= 0:
                                z0u[k, last + 1:, i, j] = np.nan
            deltazu['thkcello'][:, :, :, :] = z0u
            deltazu['thkcello'][:, :, :, -1] = zv[:, :, :, -1]
            deltazu['thkcello'][:, :, -1, :] = zv[:, :, -1, :]

            # v faces: between T columns (i, j) and (i+1, j).
            for j in tqdm(range(nx - 1)):
                for i in range(ny - 1):
                    p = np.nansum(zv[:, :, i:i + 2, j], axis=1).argmin(axis=1)
                    for k in range(len(p)):
                        if p[k] == 0:
                            z0v[k, :, i, j] = zv[k, :, i, j]
                        if p[k] == 1:
                            last = np.isnan(zv[k, :, i + 1, j]).argmax(axis=0) - 1
                            z0v[k, :last, i, j] = zv[k, :last, i, j]
                            z0v[k, last, i, j] = zv[k, last, i + 1, j]
                            if last >= 0:
                                z0v[k, last + 1:, i, j] = np.nan
            deltazv['thkcello'][:, :, :, :] = z0v
            deltazv['thkcello'][:, :, :, -1] = zv[:, :, :, -1]
            deltazv['thkcello'][:, :, -1, :] = zv[:, :, -1, :]
        else:
            for i in tqdm(range(ny - 1)):
                for j in range(nx - 1):
                    p = np.nansum(zv[:, i, j:j + 2], axis=0).argmin()
                    if p == 0:
                        z0u[:, i, j] = zv[:, i, j]
                    if p == 1:
                        z0u[:, i, j] = zv[:, i, j + 1]
            deltazu['thkcello'][:, :, :] = z0u
            deltazu['thkcello'][:, :, -1] = zv[:, :, -1]
            deltazu['thkcello'][:, -1, :] = zv[:, -1, :]

            for j in tqdm(range(nx - 1)):
                for i in range(ny - 1):
                    p = np.nansum(zv[:, i:i + 2, j], axis=0).argmin()
                    if p == 0:
                        z0v[:, i, j] = zv[:, i, j]
                    if p == 1:
                        z0v[:, i, j] = zv[:, i + 1, j]
            deltazv['thkcello'][:, :, :] = z0v
            deltazv['thkcello'][:, :, -1] = zv[:, :, -1]
            deltazv['thkcello'][:, -1, :] = zv[:, -1, :]

    elif grid in ['Arakawa-B', 'Arakawa-A']:
        if has_time:
            for k in tqdm(range(len(deltaz.time))):
                zv = deltaz.thkcello[k].values
                z0u = np.zeros(np.shape(zv))
                for i in range(ny - 1):
                    for j in range(nx - 1):
                        totals = np.nansum(zv[:, i:i + 2, j:j + 2], axis=0)
                        # (row, col) offset of the shallowest of the 2x2 columns;
                        # first in C order on ties.
                        di, dj = np.argwhere(totals == np.min(totals))[0]
                        z0u[:, i, j] = zv[:, i + di, j + dj]
                deltazu['thkcello'][k, :, :, :] = z0u
                deltazu['thkcello'][k, :, :, -1] = zv[:, :, -1]
                deltazu['thkcello'][k, :, -1, :] = zv[:, -1, :]
                deltazv['thkcello'][k, :, :, :] = z0u
                deltazv['thkcello'][k, :, :, -1] = zv[:, :, -1]
                deltazv['thkcello'][k, :, -1, :] = zv[:, -1, :]
        else:
            for i in tqdm(range(ny - 1)):
                for j in range(nx - 1):
                    totals = np.nansum(zv[:, i:i + 2, j:j + 2], axis=0)
                    di, dj = np.argwhere(totals == np.min(totals))[0]
                    z0u[:, i, j] = zv[:, i + di, j + dj]
            deltazu['thkcello'][:, :, :] = z0u
            deltazu['thkcello'][:, :, -1] = zv[:, :, -1]
            deltazu['thkcello'][:, -1, :] = zv[:, -1, :]
            deltazv['thkcello'][:, :, :] = z0u
            deltazv['thkcello'][:, :, -1] = zv[:, :, -1]
            deltazv['thkcello'][:, -1, :] = zv[:, -1, :]
    # NOTE(review): any other grid value (e.g. 'Arakawa-E') falls through and
    # yields all-zero thicknesses; consider raising instead.

    if flip:
        print('swap')
        deltazu = _flip_y(deltazu)
        deltazv = _flip_y(deltazv)

    return deltazu.thkcello, deltazv.thkcello