astro-nest 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astro_nest-0.5.0/NEST/NEST.py +412 -0
- astro_nest-0.5.0/NEST/__init__.py +1 -0
- astro_nest-0.5.0/NEST/domain.pkl +0 -0
- astro_nest-0.5.0/NEST/models/BaSTI.sav +0 -0
- astro_nest-0.5.0/NEST/models/BaSTI2.sav +0 -0
- astro_nest-0.5.0/NEST/models/BaSTI2_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/BaSTI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/Dartmouth.sav +0 -0
- astro_nest-0.5.0/NEST/models/Dartmouth_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/MIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/MIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI.sav.old +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_cut.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_cut.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_cut_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_BaSTI_cut_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_Dartmouth.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_Dartmouth.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_Dartmouth_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_Dartmouth_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_MIST.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_MIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_MIST_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_MIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_PARSEC.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_PARSEC.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_PARSEC_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_PARSEC_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_SYCLIST.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_SYCLIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_SYCLIST_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_SYCLIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_YaPSI.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_YaPSI.sav +0 -0
- astro_nest-0.5.0/NEST/models/NN_YaPSI_BPRP.json +1 -0
- astro_nest-0.5.0/NEST/models/NN_YaPSI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/PARSEC.sav +0 -0
- astro_nest-0.5.0/NEST/models/PARSEC_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/SYCLIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/SYCLIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/YaPSI.sav +0 -0
- astro_nest-0.5.0/NEST/models/YaPSI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_BPRP_cut.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_HST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_cut.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_BaSTI_cut_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_Dartmouth.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_Dartmouth_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_MIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_MIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_PARSEC.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_PARSEC_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_SYCLIST.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_SYCLIST_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_YaPSI.sav +0 -0
- astro_nest-0.5.0/NEST/models/scaler_YaPSI_BPRP.sav +0 -0
- astro_nest-0.5.0/NEST/tutorial.ipynb +566 -0
- astro_nest-0.5.0/PKG-INFO +25 -0
- astro_nest-0.5.0/astro_nest.egg-info/PKG-INFO +25 -0
- astro_nest-0.5.0/astro_nest.egg-info/SOURCES.txt +74 -0
- astro_nest-0.5.0/astro_nest.egg-info/dependency_links.txt +1 -0
- astro_nest-0.5.0/astro_nest.egg-info/top_level.txt +1 -0
- astro_nest-0.5.0/pyproject.toml +28 -0
- astro_nest-0.5.0/readme.md +14 -0
- astro_nest-0.5.0/setup.cfg +4 -0
@@ -0,0 +1,412 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pickle,json,os,warnings
|
3
|
+
try:
|
4
|
+
from tqdm import tqdm
|
5
|
+
has_tqdm = True
|
6
|
+
except ImportError:
|
7
|
+
has_tqdm = False
|
8
|
+
NNSA_DIR = os.path.dirname(os.path.abspath(__file__))
|
9
|
+
|
10
|
+
def custom_warning(message, category, filename, lineno, file=None, line=None):
|
11
|
+
print(f"{category.__name__}: {message}")
|
12
|
+
|
13
|
+
warnings.showwarning = custom_warning
|
14
|
+
|
15
|
+
def get_mode(arr,min_age=0,max_age=14):
    """Estimate the mode of ``arr`` from a 100-bin histogram.

    The histogram spans ``[min_age, max_age]``; values outside that range are
    not counted. The returned value is the centre of the most populated bin
    (the first such bin when several tie).
    """
    counts, edges = np.histogram(arr, bins=100, range=(min_age, max_age))
    # All bins share the same width, so the first bin gives the half-width.
    half_width = (edges[1] - edges[0]) / 2
    peak = np.argmax(counts)
    return edges[peak] + half_width
|
18
|
+
|
19
|
+
def available_models():
    """Print every supported model, one per line, as ``<name>Model (<source URL>)``."""
    # (model name, URL of the model's source) pairs, in display order.
    catalogue = (
        ('BaSTI', 'http://basti-iac.oa-abruzzo.inaf.it/'),
        ('PARSEC', 'https://stev.oapd.inaf.it/PARSEC/'),
        ('MIST', 'https://waps.cfa.harvard.edu/MIST/'),
        ('Geneva', 'https://www.unige.ch/sciences/astro/evolution/en/database/syclist'),
        ('Dartmouth', 'https://rcweb.dartmouth.edu/stellar/'),
        ('YaPSI', 'http://www.astro.yale.edu/yapsi/'),
    )
    for name, url in catalogue:
        print(f'{name}Model ({url})')
|
31
|
+
|
32
|
+
class AgeModel:
|
33
|
+
def __init__(self,model_name,cut=False,use_sklearn=True,use_tqdm=True):
|
34
|
+
self.model_name = model_name
|
35
|
+
self.use_sklearn = use_sklearn
|
36
|
+
self.use_tqdm = use_tqdm
|
37
|
+
if not has_tqdm and self.use_tqdm:
|
38
|
+
self.use_tqdm = False
|
39
|
+
self.cut = cut
|
40
|
+
domain_path = os.path.join(NNSA_DIR, 'domain.pkl')
|
41
|
+
domain = pickle.load(open(domain_path, 'rb'))
|
42
|
+
if model_name in domain:
|
43
|
+
self.domain = domain[model_name]
|
44
|
+
self.space_col = self.domain['spaces'][0]
|
45
|
+
self.space_mag = self.domain['spaces'][1]
|
46
|
+
self.space_met = self.domain['spaces'][2]
|
47
|
+
self.domain = self.domain['grid']
|
48
|
+
else:
|
49
|
+
self.domain = None
|
50
|
+
self.space_col = None
|
51
|
+
self.space_mag = None
|
52
|
+
self.space_met = None
|
53
|
+
if self.cut:
|
54
|
+
self.model_name = self.model_name + '_cut'
|
55
|
+
self.neural_networks = {}
|
56
|
+
self.scalers = {}
|
57
|
+
self.samples = None
|
58
|
+
self.ages = None
|
59
|
+
self.medians = None
|
60
|
+
self.means = None
|
61
|
+
self.modes = None
|
62
|
+
self.stds = None
|
63
|
+
self.load_neural_network(self.model_name)
|
64
|
+
|
65
|
+
def __str__(self):
|
66
|
+
return self.model_name + ' Age Model'
|
67
|
+
|
68
|
+
def load_neural_network(self, model_name):
|
69
|
+
if self.use_sklearn:
|
70
|
+
model_path_full = os.path.join(NNSA_DIR, 'models', f'{model_name}.sav')
|
71
|
+
model_path_reduced = os.path.join(NNSA_DIR, 'models', f'{model_name}_BPRP.sav')
|
72
|
+
if os.path.exists(model_path_full):
|
73
|
+
nn = pickle.load(open(model_path_full, 'rb'))
|
74
|
+
self.neural_networks['full'] = nn['NN']
|
75
|
+
self.scalers['full'] = nn['Scaler']
|
76
|
+
if os.path.exists(model_path_reduced):
|
77
|
+
nn = pickle.load(open(model_path_reduced, 'rb'))
|
78
|
+
self.neural_networks['reduced'] = nn['NN']
|
79
|
+
self.scalers['reduced'] = nn['Scaler']
|
80
|
+
else:
|
81
|
+
model_path_full = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}.json')
|
82
|
+
model_path_reduced = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}_BPRP.json')
|
83
|
+
if os.path.exists(model_path_full):
|
84
|
+
json_nn = json.load(open(model_path_full, 'r'))
|
85
|
+
self.neural_networks['full'] = {
|
86
|
+
'weights':json_nn['weights'],
|
87
|
+
'biases':json_nn['biases']
|
88
|
+
}
|
89
|
+
self.scalers['full'] = {
|
90
|
+
'means':json_nn['means'],
|
91
|
+
'stds':json_nn['stds']
|
92
|
+
}
|
93
|
+
if os.path.exists(model_path_reduced):
|
94
|
+
json_nn = json.load(open(model_path_reduced, 'r'))
|
95
|
+
self.neural_networks['reduced'] = {
|
96
|
+
'weights':json_nn['weights'],
|
97
|
+
'biases':json_nn['biases']
|
98
|
+
}
|
99
|
+
self.scalers['reduced'] = {
|
100
|
+
'means':json_nn['means'],
|
101
|
+
'stds':json_nn['stds']
|
102
|
+
}
|
103
|
+
|
104
|
+
def ages_prediction(self,
|
105
|
+
met,mag,col,
|
106
|
+
emet=None,emag=None,ecol=None,
|
107
|
+
GBP=None,GRP=None,
|
108
|
+
eGBP=None,eGRP=None,
|
109
|
+
n=1,
|
110
|
+
store_samples=True,
|
111
|
+
min_age=0,max_age=14):
|
112
|
+
|
113
|
+
if met is not None and type(met) is not list:
|
114
|
+
if hasattr(met,'tolist'):
|
115
|
+
met = met.tolist()
|
116
|
+
else:
|
117
|
+
met = [met]
|
118
|
+
if mag is not None and type(mag) is not list:
|
119
|
+
if hasattr(mag,'tolist'):
|
120
|
+
mag = mag.tolist()
|
121
|
+
else:
|
122
|
+
mag = [mag]
|
123
|
+
if col is not None and type(col) is not list:
|
124
|
+
if hasattr(col,'tolist'):
|
125
|
+
col = col.tolist()
|
126
|
+
else:
|
127
|
+
col = [col]
|
128
|
+
if emet is not None and type(emet) is not list:
|
129
|
+
if hasattr(emet,'tolist'):
|
130
|
+
emet = emet.tolist()
|
131
|
+
else:
|
132
|
+
emet = [emet]
|
133
|
+
if emag is not None and type(emag) is not list:
|
134
|
+
if hasattr(emag,'tolist'):
|
135
|
+
emag = emag.tolist()
|
136
|
+
else:
|
137
|
+
emag = [emag]
|
138
|
+
if ecol is not None and type(ecol) is not list:
|
139
|
+
if hasattr(ecol,'tolist'):
|
140
|
+
ecol = ecol.tolist()
|
141
|
+
else:
|
142
|
+
ecol = [ecol]
|
143
|
+
if GBP is not None and type(GBP) is not list:
|
144
|
+
if hasattr(GBP,'tolist'):
|
145
|
+
GBP = GBP.tolist()
|
146
|
+
else:
|
147
|
+
GBP = [GBP]
|
148
|
+
if GRP is not None and type(GRP) is not list:
|
149
|
+
if hasattr(GRP,'tolist'):
|
150
|
+
GRP = GRP.tolist()
|
151
|
+
else:
|
152
|
+
GRP = [GRP]
|
153
|
+
if eGBP is not None and type(eGBP) is not list:
|
154
|
+
if hasattr(eGBP,'tolist'):
|
155
|
+
eGBP = eGBP.tolist()
|
156
|
+
else:
|
157
|
+
eGBP = [eGBP]
|
158
|
+
if eGRP is not None and type(eGRP) is not list:
|
159
|
+
if hasattr(eGRP,'tolist'):
|
160
|
+
eGRP = eGRP.tolist()
|
161
|
+
else:
|
162
|
+
eGRP = [eGRP]
|
163
|
+
|
164
|
+
if store_samples and n*len(met) > 1e6:
|
165
|
+
warnings.warn('Storing samples for {} stars with {} samples for each will take a lot of memory. Consider setting store_samples=False to only store mean,median,mode and std of individual age distributions.'.format(len(met),n))
|
166
|
+
|
167
|
+
inputs = [input for input in [met,mag,col,emet,emag,ecol,GBP,GRP,eGBP,eGRP] if input is not None]
|
168
|
+
if len(set(map(len,inputs))) != 1:
|
169
|
+
raise ValueError('All input arrays must have the same length')
|
170
|
+
|
171
|
+
is_reduced = True
|
172
|
+
has_errors = False
|
173
|
+
if GBP is not None and GRP is not None:
|
174
|
+
is_reduced = False
|
175
|
+
if emet is not None and emag is not None and ecol is not None and eGBP is not None and eGRP is not None:
|
176
|
+
has_errors = True
|
177
|
+
else:
|
178
|
+
if emet is not None and emag is not None and ecol is not None:
|
179
|
+
has_errors = True
|
180
|
+
|
181
|
+
if n > 1 and not has_errors:
|
182
|
+
raise ValueError('For more than one sample, errors must be provided')
|
183
|
+
|
184
|
+
if is_reduced:
|
185
|
+
X = np.array([met, mag, col])
|
186
|
+
X_errors = np.array([emet, emag, ecol])
|
187
|
+
else:
|
188
|
+
X = np.array([met, mag, GBP, GRP, col])
|
189
|
+
X_errors = np.array([emet, emag, eGBP, eGRP, ecol])
|
190
|
+
|
191
|
+
X = X.T
|
192
|
+
X_errors = X_errors.T
|
193
|
+
|
194
|
+
if X.shape[1] == 3:
|
195
|
+
if self.neural_networks.get('reduced') is None:
|
196
|
+
raise ValueError('Reduced neural network not available for this model')
|
197
|
+
scaler = self.scalers['reduced']
|
198
|
+
neural_network = self.neural_networks['reduced']
|
199
|
+
else:
|
200
|
+
if self.neural_networks.get('full') is None:
|
201
|
+
raise ValueError('Full neural network not available for this model')
|
202
|
+
scaler = self.scalers['full']
|
203
|
+
neural_network = self.neural_networks['full']
|
204
|
+
|
205
|
+
self.ages = np.zeros((X.shape[0],n))
|
206
|
+
if store_samples:
|
207
|
+
self.samples = np.zeros((X.shape[0],n,X.shape[1]))
|
208
|
+
else:
|
209
|
+
self.samples = None
|
210
|
+
self.medians = np.zeros(X.shape[0])
|
211
|
+
self.means = np.zeros(X.shape[0])
|
212
|
+
self.modes = np.zeros(X.shape[0])
|
213
|
+
self.stds = np.zeros(X.shape[0])
|
214
|
+
|
215
|
+
if self.use_tqdm and (n > 1 or X.shape[0] > 1):
|
216
|
+
loop = tqdm(range(X.shape[0]))
|
217
|
+
else:
|
218
|
+
loop = range(X.shape[0])
|
219
|
+
for i in loop:
|
220
|
+
if n > 1:
|
221
|
+
X_i = np.random.normal(X[i],X_errors[i],(n,X.shape[1]))
|
222
|
+
if store_samples:
|
223
|
+
self.samples[i] = X_i
|
224
|
+
else:
|
225
|
+
X_i = X[i].reshape(1,-1)
|
226
|
+
if store_samples:
|
227
|
+
self.samples[i] = X_i
|
228
|
+
|
229
|
+
ages = self.propagate(X_i,neural_network,scaler)
|
230
|
+
if store_samples:
|
231
|
+
self.ages[i] = ages
|
232
|
+
else:
|
233
|
+
median = np.median(ages)
|
234
|
+
mean = np.mean(ages)
|
235
|
+
mode = get_mode(ages,min_age,max_age)
|
236
|
+
std = np.std(ages)
|
237
|
+
self.medians[i] = median
|
238
|
+
self.means[i] = mean
|
239
|
+
self.modes[i] = mode
|
240
|
+
self.stds[i] = std
|
241
|
+
|
242
|
+
if store_samples:
|
243
|
+
return self.ages
|
244
|
+
else:
|
245
|
+
return {'mean':self.means,'median':self.medians,'mode':self.modes,'std':self.stds}
|
246
|
+
|
247
|
+
def check_domain(self,met,mag,col,emet=None,emag=None,ecol=None):
|
248
|
+
if self.domain is None:
|
249
|
+
raise ValueError('No domain defined for this model')
|
250
|
+
if met is not None and type(met) is not list:
|
251
|
+
if hasattr(met,'tolist'):
|
252
|
+
met = met.tolist()
|
253
|
+
else:
|
254
|
+
met = [met]
|
255
|
+
if mag is not None and type(mag) is not list:
|
256
|
+
if hasattr(mag,'tolist'):
|
257
|
+
mag = mag.tolist()
|
258
|
+
else:
|
259
|
+
mag = [mag]
|
260
|
+
if col is not None and type(col) is not list:
|
261
|
+
if hasattr(col,'tolist'):
|
262
|
+
col = col.tolist()
|
263
|
+
else:
|
264
|
+
col = [col]
|
265
|
+
if emet is not None and type(emet) is not list:
|
266
|
+
if hasattr(emet,'tolist'):
|
267
|
+
emet = emet.tolist()
|
268
|
+
else:
|
269
|
+
emet = [emet]
|
270
|
+
if emag is not None and type(emag) is not list:
|
271
|
+
if hasattr(emag,'tolist'):
|
272
|
+
emag = emag.tolist()
|
273
|
+
else:
|
274
|
+
emag = [emag]
|
275
|
+
if ecol is not None and type(ecol) is not list:
|
276
|
+
if hasattr(ecol,'tolist'):
|
277
|
+
ecol = ecol.tolist()
|
278
|
+
else:
|
279
|
+
ecol = [ecol]
|
280
|
+
|
281
|
+
has_errors = emet != None and emag != None and ecol != None
|
282
|
+
|
283
|
+
in_domain = np.zeros(len(met),dtype=bool)
|
284
|
+
|
285
|
+
if self.use_tqdm and len(met) > 1:
|
286
|
+
loop = tqdm(range(len(met)))
|
287
|
+
else:
|
288
|
+
loop = range(len(met))
|
289
|
+
|
290
|
+
for i in loop:
|
291
|
+
if has_errors:
|
292
|
+
errors = [ecol[i],emag[i],emet[i]]
|
293
|
+
else:
|
294
|
+
errors = [0,0,0]
|
295
|
+
min_i_col = np.maximum(np.digitize(col[i] - errors[0],self.space_col) - 1,0)
|
296
|
+
max_i_col = np.minimum(np.digitize(col[i] + errors[0],self.space_col) - 1,self.space_col.size-2)
|
297
|
+
min_i_mag = np.maximum(np.digitize(mag[i] - errors[1],self.space_mag) - 1,0)
|
298
|
+
max_i_mag = np.minimum(np.digitize(mag[i] + errors[1],self.space_mag) - 1,self.space_mag.size-2)
|
299
|
+
min_i_met = np.maximum(np.digitize(met[i] - errors[2],self.space_met) - 1,0)
|
300
|
+
max_i_met = np.minimum(np.digitize(met[i] + errors[2],self.space_met) - 1,self.space_met.size-2)
|
301
|
+
#cells[i] = np.array([[min_i_col,max_i_col],[min_i_mag,max_i_mag],[min_i_met,max_i_met]])
|
302
|
+
if self.cut and (self.space_col[min_i_col] > 1.25 or self.space_mag[min_i_mag] > 4):
|
303
|
+
in_domain[i] = False
|
304
|
+
continue
|
305
|
+
in_domain[i] = bool(np.any(self.domain[min_i_col:max_i_col+1,min_i_mag:max_i_mag+1,min_i_met:max_i_met+1]) == 1)
|
306
|
+
|
307
|
+
return in_domain#,cells
|
308
|
+
|
309
|
+
def propagate(self,X,neural_network,scaler):
|
310
|
+
if self.use_sklearn:
|
311
|
+
with warnings.catch_warnings():
|
312
|
+
warnings.simplefilter("ignore", RuntimeWarning)
|
313
|
+
X = scaler.transform(X)
|
314
|
+
return neural_network.predict(X)
|
315
|
+
else:
|
316
|
+
weights = neural_network['weights']
|
317
|
+
biases = neural_network['biases']
|
318
|
+
means = scaler['means']
|
319
|
+
stds = scaler['stds']
|
320
|
+
outputs = []
|
321
|
+
for x in X:
|
322
|
+
a = (x - means)/stds
|
323
|
+
output = self.predict_nn(a,weights,biases)
|
324
|
+
outputs.append(output)
|
325
|
+
return np.array(outputs)
|
326
|
+
|
327
|
+
def relu(self,x):
|
328
|
+
return np.maximum(0,x)
|
329
|
+
|
330
|
+
def dot(self,x,y):
|
331
|
+
x_dot_y = 0
|
332
|
+
for i in range(len(x)):
|
333
|
+
x_dot_y += x[i]*y[i]
|
334
|
+
return x_dot_y
|
335
|
+
|
336
|
+
def predict_nn(self,X,weights,biases):
|
337
|
+
a = X
|
338
|
+
for i in range(len(weights)):
|
339
|
+
a = self.dot(a,weights[i]) + biases[i]
|
340
|
+
a = self.relu(a)
|
341
|
+
return a[0]
|
342
|
+
|
343
|
+
def mean_ages(self):
|
344
|
+
if self.ages is None:
|
345
|
+
raise ValueError('No age predictions have been made yet')
|
346
|
+
self.means = np.mean(self.ages,axis=1)
|
347
|
+
return self.means
|
348
|
+
|
349
|
+
def median_ages(self):
|
350
|
+
if self.ages is None:
|
351
|
+
raise ValueError('No age predictions have been made yet')
|
352
|
+
self.medians = np.median(self.ages,axis=1)
|
353
|
+
return self.medians
|
354
|
+
|
355
|
+
def mode_ages(self):
|
356
|
+
if self.ages is None:
|
357
|
+
raise ValueError('No age predictions have been made yet')
|
358
|
+
#TODO: choose number of bins appropriately
|
359
|
+
modes = []
|
360
|
+
min_age = max(0,self.ages.min())
|
361
|
+
max_age = max(14,self.ages.max())
|
362
|
+
|
363
|
+
for i in range(len(self.ages)):
|
364
|
+
modes.append(get_mode(self.ages[i],min_age,max_age))
|
365
|
+
self.modes = np.array(modes)
|
366
|
+
return self.modes
|
367
|
+
|
368
|
+
def std_ages(self):
|
369
|
+
if self.ages is None:
|
370
|
+
raise ValueError('No age predictions have been made yet')
|
371
|
+
self.stds = np.std(self.ages,axis=1)
|
372
|
+
return self.stds
|
373
|
+
|
374
|
+
class BaSTIModel(AgeModel):
    """AgeModel with the model name fixed to ``'BaSTI'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # Only the model name is fixed; every option is forwarded unchanged.
        super().__init__(
            'BaSTI',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
377
|
+
|
378
|
+
'''
|
379
|
+
class BaSTI2Model(AgeModel):
|
380
|
+
def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
|
381
|
+
super().__init__('BaSTI2',cut,use_sklearn,use_tqdm)
|
382
|
+
|
383
|
+
class BaSTI_HSTModel(AgeModel):
|
384
|
+
def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
|
385
|
+
super().__init__('BaSTI_HST',cut,use_sklearn,use_tqdm)
|
386
|
+
|
387
|
+
class BaSTI_HST_alpha_zeroModel(AgeModel):
|
388
|
+
def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
|
389
|
+
super().__init__('BaSTI_HST_alpha_zero',cut,use_sklearn,use_tqdm)
|
390
|
+
'''
|
391
|
+
|
392
|
+
class PARSECModel(AgeModel):
    """AgeModel with the model name fixed to ``'PARSEC'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # Only the model name is fixed; every option is forwarded unchanged.
        super().__init__(
            'PARSEC',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
395
|
+
|
396
|
+
class MISTModel(AgeModel):
    """AgeModel with the model name fixed to ``'MIST'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # Only the model name is fixed; every option is forwarded unchanged.
        super().__init__(
            'MIST',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
399
|
+
|
400
|
+
class GenevaModel(AgeModel):
    """AgeModel with the model name fixed to ``'Geneva'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # NOTE(review): the network loader builds file names from this model
        # name, but the packaged model files appear to be named 'SYCLIST'
        # rather than 'Geneva' -- confirm that networks are actually found.
        super().__init__(
            'Geneva',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
403
|
+
|
404
|
+
class DartmouthModel(AgeModel):
    """AgeModel with the model name fixed to ``'Dartmouth'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # Only the model name is fixed; every option is forwarded unchanged.
        super().__init__(
            'Dartmouth',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
407
|
+
|
408
|
+
class YaPSIModel(AgeModel):
    """AgeModel with the model name fixed to ``'YaPSI'``."""

    def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
        # Only the model name is fixed; every option is forwarded unchanged.
        super().__init__(
            'YaPSI',
            cut=cut,
            use_sklearn=use_sklearn,
            use_tqdm=use_tqdm,
        )
|
411
|
+
|
412
|
+
#TODO: add flavors to models (e.g. trained on cut CMD for optimal performance)
|
@@ -0,0 +1 @@
|
|
1
|
+
from .NEST import *
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|