foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/UNET.py ADDED
@@ -0,0 +1,200 @@
1
+ import numpy as np
2
+
3
+ import foscat.scat_cov as sc
4
+ import foscat.HOrientedConvol as hs
5
+
6
class UNET:
    """U-Net style network on pixelised maps built from foscat operators.

    Encoder: at each level two oriented convolutions are applied, the result is
    downgraded with ``ud_grade_2`` and the input map, downgraded to the same
    level, is concatenated along axis 1 (channels).
    Decoder: at each level the data is up-graded with ``up_grade``, the input
    map at that level is concatenated as skip input, then two convolutions are
    applied.

    All convolution weights live in a single flat vector ``self.x``;
    ``wconv`` / ``t_wconv`` map a convolution index (``2*l`` and ``2*l+1``
    for level ``l``) to its start offset in that vector.
    """

    def __init__(
        self,
        nparam=1,
        KERNELSZ=3,
        NORIENT=4,
        chanlist=None,
        in_nside=1,
        n_chan_in=1,
        n_chan_out=1,
        cell_ids=None,
        SEED=1234,
        filename=None,
    ):
        """Build the operators and initialise the weight vector.

        Args:
            nparam: accepted but not used here.
            KERNELSZ: kernel size; each convolution owns
                ``n_in * n_out * KERNELSZ**2`` weights. Also passed to
                ``sc.funct``.
            NORIENT: accepted but not used here.
            chanlist: output channels of each encoder level. Defaults to
                ``[4, 8, 16, ...]`` with ``int(log2(in_nside))`` levels.
            in_nside: nside of the input map; halved at each encoder level.
            n_chan_in: number of input channels.
            n_chan_out: number of output channels of the final convolution.
            cell_ids: pixel ids of the input map (required; must support
                ``.copy()`` and ``.shape``).
            SEED: seed for the weight initialisation; ``None`` gives an
                unseeded (non-reproducible) initialisation.
            filename: accepted but not used here.

        Raises:
            ValueError: if ``cell_ids`` is None or the network has no layer.

        Prints per-layer channel / pixel bookkeeping to stdout.
        """
        # NOTE(review): nparam, NORIENT and filename are currently unused.
        if cell_ids is None:
            raise ValueError("UNET requires cell_ids (pixel ids of the input map)")

        if chanlist is None:
            nlayer = int(np.log2(in_nside))
            chanlist = [4 * 2**k for k in range(nlayer)]
        else:
            nlayer = len(chanlist)
        if nlayer < 1:
            # With zero layers both loops below are skipped and the network
            # is undefined (out_chan would never be assigned).
            raise ValueError(
                "UNET needs at least one layer: pass a non-empty chanlist "
                "or in_nside >= 2 (got %d layers)" % nlayer
            )
        print('N_layer ', nlayer)

        self.f = sc.funct(KERNELSZ=KERNELSZ)
        self.KERNELSZ = KERNELSZ
        kernelsz = self.KERNELSZ

        n = 0  # running offset into the flat weight vector
        wconv = {}
        hconv = {}
        l_cell_ids = {}

        # ---- encoder (CNN part) ----
        l_nside = in_nside
        l_cell_ids[0] = cell_ids.copy()
        # Dummy single-channel map, only used to derive the downgraded cell ids.
        l_data = self.f.backend.bk_cast(np.ones([1, 1, l_cell_ids[0].shape[0]]))
        l_chan = n_chan_in
        print('Initial chan %d Npix=%d' % (l_chan, l_data.shape[2]))
        for l in range(nlayer):
            print('Layer %d Npix=%d' % (l, l_data.shape[2]))
            # Reserve room for the two convolutions of this level.
            wconv[2 * l] = n
            print('Layer %d conv [%d,%d]' % (l, l_chan, chanlist[l]))
            n += l_chan * chanlist[l] * kernelsz * kernelsz
            wconv[2 * l + 1] = n
            print('Layer %d conv [%d,%d]' % (l, chanlist[l], chanlist[l]))
            n += chanlist[l] * chanlist[l] * kernelsz * kernelsz

            # NOTE(review): the second argument is hard-coded to 3 while the
            # weight blocks are sized with KERNELSZ; if it is the kernel size
            # it should probably be kernelsz -- confirm HOrientedConvol's API.
            hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=l_cell_ids[l])
            hconvol.make_idx_weights()
            hconv[l] = hconvol
            l_data, n_cell_ids = self.f.ud_grade_2(
                l_data, cell_ids=l_cell_ids[l], nside=l_nside
            )
            l_cell_ids[l + 1] = self.f.backend.to_numpy(n_cell_ids)
            l_nside //= 2
            # Features of this level + the input map downgraded to this level.
            l_chan = chanlist[l] + n_chan_in

        # NOTE(review): overwritten below with the total weight count.
        self.n_cnn = n
        self.l_cell_ids = l_cell_ids
        self.wconv = wconv
        self.hconv = hconv

        # ---- decoder (transpose CNN part) ----
        m_cell_ids = {}
        m_cell_ids[0] = l_cell_ids[nlayer]
        t_wconv = {}
        t_hconv = {}

        for l in range(nlayer):
            # The skip input (n_chan_in channels) is concatenated after up-grading.
            l_chan += n_chan_in
            l_data = self.f.up_grade(
                l_data,
                l_nside * 2,
                cell_ids=l_cell_ids[nlayer - l],
                o_cell_ids=l_cell_ids[nlayer - 1 - l],
                nside=l_nside,
            )
            print('Transpose Layer %d Npix=%d' % (l, l_data.shape[2]))

            m_cell_ids[l] = l_cell_ids[nlayer - 1 - l]
            l_nside *= 2

            # Reserve room for the two convolutions of this level.
            t_wconv[2 * l] = n
            print('Transpose Layer %d conv [%d,%d]' % (l, l_chan, l_chan))
            n += l_chan * l_chan * kernelsz * kernelsz
            t_wconv[2 * l + 1] = n
            # Intermediate levels also carry chanlist[...] feature channels;
            # the last level outputs only n_chan_out channels.
            out_chan = n_chan_out
            if nlayer - 1 - l > 0:
                out_chan += chanlist[nlayer - 1 - l]
            print('Transpose Layer %d conv [%d,%d]' % (l, l_chan, out_chan))
            n += l_chan * out_chan * kernelsz * kernelsz

            hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=m_cell_ids[l])
            hconvol.make_idx_weights()
            t_hconv[l] = hconvol

            l_chan = out_chan
        print('Final chan %d Npix=%d' % (out_chan, l_data.shape[2]))
        self.n_cnn = n
        # NOTE(review): stores the encoder ids dict (l_cell_ids), not the local
        # m_cell_ids built above -- confirm which one is intended.
        self.m_cell_ids = l_cell_ids
        self.t_wconv = t_wconv
        self.t_hconv = t_hconv
        # Local generator: honours SEED without touching numpy's global state.
        rng = np.random.default_rng(SEED)
        self.x = self.f.backend.bk_cast((rng.random(n) - 0.5) / self.KERNELSZ)
        self.nside = in_nside
        self.n_chan_in = n_chan_in
        self.n_chan_out = n_chan_out
        self.chanlist = chanlist

    def get_param(self):
        """Return the flat weight vector ``self.x``."""
        return self.x

    def set_param(self, x):
        """Replace the flat weight vector, casting it with the backend."""
        self.x = self.f.backend.bk_cast(x)

    def _get_weights(self, offset, n_in, n_out):
        """Slice one kernel out of ``self.x`` and reshape it to
        ``[n_in, n_out, KERNELSZ*KERNELSZ]``."""
        ksq = self.KERNELSZ * self.KERNELSZ
        ww = self.x[offset:offset + n_in * n_out * ksq]
        return self.f.backend.bk_reshape(ww, [n_in, n_out, ksq])

    def eval(self, data):
        """Run the network forward on ``data`` and return the output map.

        ``data`` is concatenated along axis 1 with intermediate features, so
        axis 1 is the channel axis. It is presumably laid out as
        ``[batch, n_chan_in, Npix]`` like the dummy map built in ``__init__``;
        confirm against callers.
        """
        l_nside = self.nside
        l_chan = self.n_chan_in
        l_data = data
        m_data = data  # input map, downgraded level by level
        nlayer = len(self.chanlist)
        ud_data = {}  # input map at each encoder level (decoder skip input)

        # ---- encoder ----
        for l in range(nlayer):
            c = self.chanlist[l]
            ww = self._get_weights(self.wconv[2 * l], l_chan, c)
            l_data = self.hconv[l].Convol_torch(l_data, ww)
            ww = self._get_weights(self.wconv[2 * l + 1], c, c)
            l_data = self.hconv[l].Convol_torch(l_data, ww)

            l_data, _ = self.f.ud_grade_2(
                l_data, cell_ids=self.l_cell_ids[l], nside=l_nside
            )
            ud_data[l] = m_data
            m_data, _ = self.f.ud_grade_2(
                m_data, cell_ids=self.l_cell_ids[l], nside=l_nside
            )
            l_data = self.f.backend.bk_concat([m_data, l_data], 1)

            l_nside //= 2
            # Features of this level + the input map downgraded to this level.
            l_chan = c + self.n_chan_in

        # ---- decoder ----
        for l in range(nlayer):
            l_chan += self.n_chan_in
            l_data = self.f.up_grade(
                l_data,
                l_nside * 2,
                cell_ids=self.l_cell_ids[nlayer - l],
                o_cell_ids=self.l_cell_ids[nlayer - 1 - l],
                nside=l_nside,
            )
            l_data = self.f.backend.bk_concat([ud_data[nlayer - 1 - l], l_data], 1)
            l_nside *= 2

            out_chan = self.n_chan_out
            if nlayer - 1 - l > 0:
                out_chan += self.chanlist[nlayer - 1 - l]

            ww = self._get_weights(self.t_wconv[2 * l], l_chan, l_chan)
            c_data = self.t_hconv[l].Convol_torch(l_data, ww)
            ww = self._get_weights(self.t_wconv[2 * l + 1], l_chan, out_chan)
            l_data = self.t_hconv[l].Convol_torch(c_data, ww)

            l_chan = out_chan

        return l_data