astro-nest 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. astro_nest-0.5.0/NEST/NEST.py +412 -0
  2. astro_nest-0.5.0/NEST/__init__.py +1 -0
  3. astro_nest-0.5.0/NEST/domain.pkl +0 -0
  4. astro_nest-0.5.0/NEST/models/BaSTI.sav +0 -0
  5. astro_nest-0.5.0/NEST/models/BaSTI2.sav +0 -0
  6. astro_nest-0.5.0/NEST/models/BaSTI2_BPRP.sav +0 -0
  7. astro_nest-0.5.0/NEST/models/BaSTI_BPRP.sav +0 -0
  8. astro_nest-0.5.0/NEST/models/Dartmouth.sav +0 -0
  9. astro_nest-0.5.0/NEST/models/Dartmouth_BPRP.sav +0 -0
  10. astro_nest-0.5.0/NEST/models/MIST.sav +0 -0
  11. astro_nest-0.5.0/NEST/models/MIST_BPRP.sav +0 -0
  12. astro_nest-0.5.0/NEST/models/NN_BaSTI.json +1 -0
  13. astro_nest-0.5.0/NEST/models/NN_BaSTI.sav +0 -0
  14. astro_nest-0.5.0/NEST/models/NN_BaSTI.sav.old +0 -0
  15. astro_nest-0.5.0/NEST/models/NN_BaSTI_BPRP.json +1 -0
  16. astro_nest-0.5.0/NEST/models/NN_BaSTI_BPRP.sav +0 -0
  17. astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_BPRP.json +1 -0
  18. astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_BPRP.sav +0 -0
  19. astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.json +1 -0
  20. astro_nest-0.5.0/NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
  21. astro_nest-0.5.0/NEST/models/NN_BaSTI_cut.json +1 -0
  22. astro_nest-0.5.0/NEST/models/NN_BaSTI_cut.sav +0 -0
  23. astro_nest-0.5.0/NEST/models/NN_BaSTI_cut_BPRP.json +1 -0
  24. astro_nest-0.5.0/NEST/models/NN_BaSTI_cut_BPRP.sav +0 -0
  25. astro_nest-0.5.0/NEST/models/NN_Dartmouth.json +1 -0
  26. astro_nest-0.5.0/NEST/models/NN_Dartmouth.sav +0 -0
  27. astro_nest-0.5.0/NEST/models/NN_Dartmouth_BPRP.json +1 -0
  28. astro_nest-0.5.0/NEST/models/NN_Dartmouth_BPRP.sav +0 -0
  29. astro_nest-0.5.0/NEST/models/NN_MIST.json +1 -0
  30. astro_nest-0.5.0/NEST/models/NN_MIST.sav +0 -0
  31. astro_nest-0.5.0/NEST/models/NN_MIST_BPRP.json +1 -0
  32. astro_nest-0.5.0/NEST/models/NN_MIST_BPRP.sav +0 -0
  33. astro_nest-0.5.0/NEST/models/NN_PARSEC.json +1 -0
  34. astro_nest-0.5.0/NEST/models/NN_PARSEC.sav +0 -0
  35. astro_nest-0.5.0/NEST/models/NN_PARSEC_BPRP.json +1 -0
  36. astro_nest-0.5.0/NEST/models/NN_PARSEC_BPRP.sav +0 -0
  37. astro_nest-0.5.0/NEST/models/NN_SYCLIST.json +1 -0
  38. astro_nest-0.5.0/NEST/models/NN_SYCLIST.sav +0 -0
  39. astro_nest-0.5.0/NEST/models/NN_SYCLIST_BPRP.json +1 -0
  40. astro_nest-0.5.0/NEST/models/NN_SYCLIST_BPRP.sav +0 -0
  41. astro_nest-0.5.0/NEST/models/NN_YaPSI.json +1 -0
  42. astro_nest-0.5.0/NEST/models/NN_YaPSI.sav +0 -0
  43. astro_nest-0.5.0/NEST/models/NN_YaPSI_BPRP.json +1 -0
  44. astro_nest-0.5.0/NEST/models/NN_YaPSI_BPRP.sav +0 -0
  45. astro_nest-0.5.0/NEST/models/PARSEC.sav +0 -0
  46. astro_nest-0.5.0/NEST/models/PARSEC_BPRP.sav +0 -0
  47. astro_nest-0.5.0/NEST/models/SYCLIST.sav +0 -0
  48. astro_nest-0.5.0/NEST/models/SYCLIST_BPRP.sav +0 -0
  49. astro_nest-0.5.0/NEST/models/YaPSI.sav +0 -0
  50. astro_nest-0.5.0/NEST/models/YaPSI_BPRP.sav +0 -0
  51. astro_nest-0.5.0/NEST/models/scaler_BaSTI.sav +0 -0
  52. astro_nest-0.5.0/NEST/models/scaler_BaSTI_BPRP.sav +0 -0
  53. astro_nest-0.5.0/NEST/models/scaler_BaSTI_BPRP_cut.sav +0 -0
  54. astro_nest-0.5.0/NEST/models/scaler_BaSTI_HST_BPRP.sav +0 -0
  55. astro_nest-0.5.0/NEST/models/scaler_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
  56. astro_nest-0.5.0/NEST/models/scaler_BaSTI_cut.sav +0 -0
  57. astro_nest-0.5.0/NEST/models/scaler_BaSTI_cut_BPRP.sav +0 -0
  58. astro_nest-0.5.0/NEST/models/scaler_Dartmouth.sav +0 -0
  59. astro_nest-0.5.0/NEST/models/scaler_Dartmouth_BPRP.sav +0 -0
  60. astro_nest-0.5.0/NEST/models/scaler_MIST.sav +0 -0
  61. astro_nest-0.5.0/NEST/models/scaler_MIST_BPRP.sav +0 -0
  62. astro_nest-0.5.0/NEST/models/scaler_PARSEC.sav +0 -0
  63. astro_nest-0.5.0/NEST/models/scaler_PARSEC_BPRP.sav +0 -0
  64. astro_nest-0.5.0/NEST/models/scaler_SYCLIST.sav +0 -0
  65. astro_nest-0.5.0/NEST/models/scaler_SYCLIST_BPRP.sav +0 -0
  66. astro_nest-0.5.0/NEST/models/scaler_YaPSI.sav +0 -0
  67. astro_nest-0.5.0/NEST/models/scaler_YaPSI_BPRP.sav +0 -0
  68. astro_nest-0.5.0/NEST/tutorial.ipynb +566 -0
  69. astro_nest-0.5.0/PKG-INFO +25 -0
  70. astro_nest-0.5.0/astro_nest.egg-info/PKG-INFO +25 -0
  71. astro_nest-0.5.0/astro_nest.egg-info/SOURCES.txt +74 -0
  72. astro_nest-0.5.0/astro_nest.egg-info/dependency_links.txt +1 -0
  73. astro_nest-0.5.0/astro_nest.egg-info/top_level.txt +1 -0
  74. astro_nest-0.5.0/pyproject.toml +28 -0
  75. astro_nest-0.5.0/readme.md +14 -0
  76. astro_nest-0.5.0/setup.cfg +4 -0
@@ -0,0 +1,412 @@
1
+ import numpy as np
2
+ import pickle,json,os,warnings
3
+ try:
4
+ from tqdm import tqdm
5
+ has_tqdm = True
6
+ except ImportError:
7
+ has_tqdm = False
8
+ NNSA_DIR = os.path.dirname(os.path.abspath(__file__))
9
+
10
+ def custom_warning(message, category, filename, lineno, file=None, line=None):
11
+ print(f"{category.__name__}: {message}")
12
+
13
+ warnings.showwarning = custom_warning
14
+
15
def get_mode(arr, min_age=0, max_age=14):
    """Return the mode of *arr* as the centre of the tallest of 100 equal-width
    histogram bins spanning [min_age, max_age]."""
    counts, edges = np.histogram(arr, bins=100, range=(min_age, max_age))
    half_width = (edges[1] - edges[0]) / 2
    return edges[np.argmax(counts)] + half_width
18
+
19
def available_models():
    """Print the supported isochrone grids and the URL each one comes from."""
    catalog = [
        ('BaSTI', 'http://basti-iac.oa-abruzzo.inaf.it/'),
        ('PARSEC', 'https://stev.oapd.inaf.it/PARSEC/'),
        ('MIST', 'https://waps.cfa.harvard.edu/MIST/'),
        ('Geneva', 'https://www.unige.ch/sciences/astro/evolution/en/database/syclist'),
        ('Dartmouth', 'https://rcweb.dartmouth.edu/stellar/'),
        ('YaPSI', 'http://www.astro.yale.edu/yapsi/'),
    ]
    for name, url in catalog:
        # Concatenation yields e.g. "BaSTIModel (...)", matching the wrapper
        # class names defined in this module.
        print(name + 'Model (' + url + ')')
31
+
32
+ class AgeModel:
33
+ def __init__(self,model_name,cut=False,use_sklearn=True,use_tqdm=True):
34
+ self.model_name = model_name
35
+ self.use_sklearn = use_sklearn
36
+ self.use_tqdm = use_tqdm
37
+ if not has_tqdm and self.use_tqdm:
38
+ self.use_tqdm = False
39
+ self.cut = cut
40
+ domain_path = os.path.join(NNSA_DIR, 'domain.pkl')
41
+ domain = pickle.load(open(domain_path, 'rb'))
42
+ if model_name in domain:
43
+ self.domain = domain[model_name]
44
+ self.space_col = self.domain['spaces'][0]
45
+ self.space_mag = self.domain['spaces'][1]
46
+ self.space_met = self.domain['spaces'][2]
47
+ self.domain = self.domain['grid']
48
+ else:
49
+ self.domain = None
50
+ self.space_col = None
51
+ self.space_mag = None
52
+ self.space_met = None
53
+ if self.cut:
54
+ self.model_name = self.model_name + '_cut'
55
+ self.neural_networks = {}
56
+ self.scalers = {}
57
+ self.samples = None
58
+ self.ages = None
59
+ self.medians = None
60
+ self.means = None
61
+ self.modes = None
62
+ self.stds = None
63
+ self.load_neural_network(self.model_name)
64
+
65
+ def __str__(self):
66
+ return self.model_name + ' Age Model'
67
+
68
+ def load_neural_network(self, model_name):
69
+ if self.use_sklearn:
70
+ model_path_full = os.path.join(NNSA_DIR, 'models', f'{model_name}.sav')
71
+ model_path_reduced = os.path.join(NNSA_DIR, 'models', f'{model_name}_BPRP.sav')
72
+ if os.path.exists(model_path_full):
73
+ nn = pickle.load(open(model_path_full, 'rb'))
74
+ self.neural_networks['full'] = nn['NN']
75
+ self.scalers['full'] = nn['Scaler']
76
+ if os.path.exists(model_path_reduced):
77
+ nn = pickle.load(open(model_path_reduced, 'rb'))
78
+ self.neural_networks['reduced'] = nn['NN']
79
+ self.scalers['reduced'] = nn['Scaler']
80
+ else:
81
+ model_path_full = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}.json')
82
+ model_path_reduced = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}_BPRP.json')
83
+ if os.path.exists(model_path_full):
84
+ json_nn = json.load(open(model_path_full, 'r'))
85
+ self.neural_networks['full'] = {
86
+ 'weights':json_nn['weights'],
87
+ 'biases':json_nn['biases']
88
+ }
89
+ self.scalers['full'] = {
90
+ 'means':json_nn['means'],
91
+ 'stds':json_nn['stds']
92
+ }
93
+ if os.path.exists(model_path_reduced):
94
+ json_nn = json.load(open(model_path_reduced, 'r'))
95
+ self.neural_networks['reduced'] = {
96
+ 'weights':json_nn['weights'],
97
+ 'biases':json_nn['biases']
98
+ }
99
+ self.scalers['reduced'] = {
100
+ 'means':json_nn['means'],
101
+ 'stds':json_nn['stds']
102
+ }
103
+
104
+ def ages_prediction(self,
105
+ met,mag,col,
106
+ emet=None,emag=None,ecol=None,
107
+ GBP=None,GRP=None,
108
+ eGBP=None,eGRP=None,
109
+ n=1,
110
+ store_samples=True,
111
+ min_age=0,max_age=14):
112
+
113
+ if met is not None and type(met) is not list:
114
+ if hasattr(met,'tolist'):
115
+ met = met.tolist()
116
+ else:
117
+ met = [met]
118
+ if mag is not None and type(mag) is not list:
119
+ if hasattr(mag,'tolist'):
120
+ mag = mag.tolist()
121
+ else:
122
+ mag = [mag]
123
+ if col is not None and type(col) is not list:
124
+ if hasattr(col,'tolist'):
125
+ col = col.tolist()
126
+ else:
127
+ col = [col]
128
+ if emet is not None and type(emet) is not list:
129
+ if hasattr(emet,'tolist'):
130
+ emet = emet.tolist()
131
+ else:
132
+ emet = [emet]
133
+ if emag is not None and type(emag) is not list:
134
+ if hasattr(emag,'tolist'):
135
+ emag = emag.tolist()
136
+ else:
137
+ emag = [emag]
138
+ if ecol is not None and type(ecol) is not list:
139
+ if hasattr(ecol,'tolist'):
140
+ ecol = ecol.tolist()
141
+ else:
142
+ ecol = [ecol]
143
+ if GBP is not None and type(GBP) is not list:
144
+ if hasattr(GBP,'tolist'):
145
+ GBP = GBP.tolist()
146
+ else:
147
+ GBP = [GBP]
148
+ if GRP is not None and type(GRP) is not list:
149
+ if hasattr(GRP,'tolist'):
150
+ GRP = GRP.tolist()
151
+ else:
152
+ GRP = [GRP]
153
+ if eGBP is not None and type(eGBP) is not list:
154
+ if hasattr(eGBP,'tolist'):
155
+ eGBP = eGBP.tolist()
156
+ else:
157
+ eGBP = [eGBP]
158
+ if eGRP is not None and type(eGRP) is not list:
159
+ if hasattr(eGRP,'tolist'):
160
+ eGRP = eGRP.tolist()
161
+ else:
162
+ eGRP = [eGRP]
163
+
164
+ if store_samples and n*len(met) > 1e6:
165
+ warnings.warn('Storing samples for {} stars with {} samples for each will take a lot of memory. Consider setting store_samples=False to only store mean,median,mode and std of individual age distributions.'.format(len(met),n))
166
+
167
+ inputs = [input for input in [met,mag,col,emet,emag,ecol,GBP,GRP,eGBP,eGRP] if input is not None]
168
+ if len(set(map(len,inputs))) != 1:
169
+ raise ValueError('All input arrays must have the same length')
170
+
171
+ is_reduced = True
172
+ has_errors = False
173
+ if GBP is not None and GRP is not None:
174
+ is_reduced = False
175
+ if emet is not None and emag is not None and ecol is not None and eGBP is not None and eGRP is not None:
176
+ has_errors = True
177
+ else:
178
+ if emet is not None and emag is not None and ecol is not None:
179
+ has_errors = True
180
+
181
+ if n > 1 and not has_errors:
182
+ raise ValueError('For more than one sample, errors must be provided')
183
+
184
+ if is_reduced:
185
+ X = np.array([met, mag, col])
186
+ X_errors = np.array([emet, emag, ecol])
187
+ else:
188
+ X = np.array([met, mag, GBP, GRP, col])
189
+ X_errors = np.array([emet, emag, eGBP, eGRP, ecol])
190
+
191
+ X = X.T
192
+ X_errors = X_errors.T
193
+
194
+ if X.shape[1] == 3:
195
+ if self.neural_networks.get('reduced') is None:
196
+ raise ValueError('Reduced neural network not available for this model')
197
+ scaler = self.scalers['reduced']
198
+ neural_network = self.neural_networks['reduced']
199
+ else:
200
+ if self.neural_networks.get('full') is None:
201
+ raise ValueError('Full neural network not available for this model')
202
+ scaler = self.scalers['full']
203
+ neural_network = self.neural_networks['full']
204
+
205
+ self.ages = np.zeros((X.shape[0],n))
206
+ if store_samples:
207
+ self.samples = np.zeros((X.shape[0],n,X.shape[1]))
208
+ else:
209
+ self.samples = None
210
+ self.medians = np.zeros(X.shape[0])
211
+ self.means = np.zeros(X.shape[0])
212
+ self.modes = np.zeros(X.shape[0])
213
+ self.stds = np.zeros(X.shape[0])
214
+
215
+ if self.use_tqdm and (n > 1 or X.shape[0] > 1):
216
+ loop = tqdm(range(X.shape[0]))
217
+ else:
218
+ loop = range(X.shape[0])
219
+ for i in loop:
220
+ if n > 1:
221
+ X_i = np.random.normal(X[i],X_errors[i],(n,X.shape[1]))
222
+ if store_samples:
223
+ self.samples[i] = X_i
224
+ else:
225
+ X_i = X[i].reshape(1,-1)
226
+ if store_samples:
227
+ self.samples[i] = X_i
228
+
229
+ ages = self.propagate(X_i,neural_network,scaler)
230
+ if store_samples:
231
+ self.ages[i] = ages
232
+ else:
233
+ median = np.median(ages)
234
+ mean = np.mean(ages)
235
+ mode = get_mode(ages,min_age,max_age)
236
+ std = np.std(ages)
237
+ self.medians[i] = median
238
+ self.means[i] = mean
239
+ self.modes[i] = mode
240
+ self.stds[i] = std
241
+
242
+ if store_samples:
243
+ return self.ages
244
+ else:
245
+ return {'mean':self.means,'median':self.medians,'mode':self.modes,'std':self.stds}
246
+
247
+ def check_domain(self,met,mag,col,emet=None,emag=None,ecol=None):
248
+ if self.domain is None:
249
+ raise ValueError('No domain defined for this model')
250
+ if met is not None and type(met) is not list:
251
+ if hasattr(met,'tolist'):
252
+ met = met.tolist()
253
+ else:
254
+ met = [met]
255
+ if mag is not None and type(mag) is not list:
256
+ if hasattr(mag,'tolist'):
257
+ mag = mag.tolist()
258
+ else:
259
+ mag = [mag]
260
+ if col is not None and type(col) is not list:
261
+ if hasattr(col,'tolist'):
262
+ col = col.tolist()
263
+ else:
264
+ col = [col]
265
+ if emet is not None and type(emet) is not list:
266
+ if hasattr(emet,'tolist'):
267
+ emet = emet.tolist()
268
+ else:
269
+ emet = [emet]
270
+ if emag is not None and type(emag) is not list:
271
+ if hasattr(emag,'tolist'):
272
+ emag = emag.tolist()
273
+ else:
274
+ emag = [emag]
275
+ if ecol is not None and type(ecol) is not list:
276
+ if hasattr(ecol,'tolist'):
277
+ ecol = ecol.tolist()
278
+ else:
279
+ ecol = [ecol]
280
+
281
+ has_errors = emet != None and emag != None and ecol != None
282
+
283
+ in_domain = np.zeros(len(met),dtype=bool)
284
+
285
+ if self.use_tqdm and len(met) > 1:
286
+ loop = tqdm(range(len(met)))
287
+ else:
288
+ loop = range(len(met))
289
+
290
+ for i in loop:
291
+ if has_errors:
292
+ errors = [ecol[i],emag[i],emet[i]]
293
+ else:
294
+ errors = [0,0,0]
295
+ min_i_col = np.maximum(np.digitize(col[i] - errors[0],self.space_col) - 1,0)
296
+ max_i_col = np.minimum(np.digitize(col[i] + errors[0],self.space_col) - 1,self.space_col.size-2)
297
+ min_i_mag = np.maximum(np.digitize(mag[i] - errors[1],self.space_mag) - 1,0)
298
+ max_i_mag = np.minimum(np.digitize(mag[i] + errors[1],self.space_mag) - 1,self.space_mag.size-2)
299
+ min_i_met = np.maximum(np.digitize(met[i] - errors[2],self.space_met) - 1,0)
300
+ max_i_met = np.minimum(np.digitize(met[i] + errors[2],self.space_met) - 1,self.space_met.size-2)
301
+ #cells[i] = np.array([[min_i_col,max_i_col],[min_i_mag,max_i_mag],[min_i_met,max_i_met]])
302
+ if self.cut and (self.space_col[min_i_col] > 1.25 or self.space_mag[min_i_mag] > 4):
303
+ in_domain[i] = False
304
+ continue
305
+ in_domain[i] = bool(np.any(self.domain[min_i_col:max_i_col+1,min_i_mag:max_i_mag+1,min_i_met:max_i_met+1]) == 1)
306
+
307
+ return in_domain#,cells
308
+
309
+ def propagate(self,X,neural_network,scaler):
310
+ if self.use_sklearn:
311
+ with warnings.catch_warnings():
312
+ warnings.simplefilter("ignore", RuntimeWarning)
313
+ X = scaler.transform(X)
314
+ return neural_network.predict(X)
315
+ else:
316
+ weights = neural_network['weights']
317
+ biases = neural_network['biases']
318
+ means = scaler['means']
319
+ stds = scaler['stds']
320
+ outputs = []
321
+ for x in X:
322
+ a = (x - means)/stds
323
+ output = self.predict_nn(a,weights,biases)
324
+ outputs.append(output)
325
+ return np.array(outputs)
326
+
327
+ def relu(self,x):
328
+ return np.maximum(0,x)
329
+
330
+ def dot(self,x,y):
331
+ x_dot_y = 0
332
+ for i in range(len(x)):
333
+ x_dot_y += x[i]*y[i]
334
+ return x_dot_y
335
+
336
+ def predict_nn(self,X,weights,biases):
337
+ a = X
338
+ for i in range(len(weights)):
339
+ a = self.dot(a,weights[i]) + biases[i]
340
+ a = self.relu(a)
341
+ return a[0]
342
+
343
+ def mean_ages(self):
344
+ if self.ages is None:
345
+ raise ValueError('No age predictions have been made yet')
346
+ self.means = np.mean(self.ages,axis=1)
347
+ return self.means
348
+
349
+ def median_ages(self):
350
+ if self.ages is None:
351
+ raise ValueError('No age predictions have been made yet')
352
+ self.medians = np.median(self.ages,axis=1)
353
+ return self.medians
354
+
355
+ def mode_ages(self):
356
+ if self.ages is None:
357
+ raise ValueError('No age predictions have been made yet')
358
+ #TODO: choose number of bins appropriately
359
+ modes = []
360
+ min_age = max(0,self.ages.min())
361
+ max_age = max(14,self.ages.max())
362
+
363
+ for i in range(len(self.ages)):
364
+ modes.append(get_mode(self.ages[i],min_age,max_age))
365
+ self.modes = np.array(modes)
366
+ return self.modes
367
+
368
+ def std_ages(self):
369
+ if self.ages is None:
370
+ raise ValueError('No age predictions have been made yet')
371
+ self.stds = np.std(self.ages,axis=1)
372
+ return self.stds
373
+
374
class BaSTIModel(AgeModel):
    """Convenience wrapper selecting the BaSTI isochrone grid."""

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('BaSTI', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
377
+
378
+ '''
379
+ class BaSTI2Model(AgeModel):
380
+ def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
381
+ super().__init__('BaSTI2',cut,use_sklearn,use_tqdm)
382
+
383
+ class BaSTI_HSTModel(AgeModel):
384
+ def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
385
+ super().__init__('BaSTI_HST',cut,use_sklearn,use_tqdm)
386
+
387
+ class BaSTI_HST_alpha_zeroModel(AgeModel):
388
+ def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
389
+ super().__init__('BaSTI_HST_alpha_zero',cut,use_sklearn,use_tqdm)
390
+ '''
391
+
392
class PARSECModel(AgeModel):
    """Convenience wrapper selecting the PARSEC isochrone grid."""

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('PARSEC', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
395
+
396
class MISTModel(AgeModel):
    """Convenience wrapper selecting the MIST isochrone grid."""

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('MIST', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
399
+
400
class GenevaModel(AgeModel):
    """Convenience wrapper selecting the Geneva (SYCLIST) isochrone grid.

    NOTE(review): the packaged model files are named NN_SYCLIST*/SYCLIST*,
    not Geneva*, so passing 'Geneva' here may leave the network dict empty
    unless files under that name exist — confirm against the models directory
    and domain.pkl.
    """

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('Geneva', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
403
+
404
class DartmouthModel(AgeModel):
    """Convenience wrapper selecting the Dartmouth isochrone grid."""

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('Dartmouth', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
407
+
408
class YaPSIModel(AgeModel):
    """Convenience wrapper selecting the YaPSI isochrone grid."""

    def __init__(self, cut=False, use_sklearn=True, use_tqdm=True):
        super().__init__('YaPSI', cut=cut, use_sklearn=use_sklearn, use_tqdm=use_tqdm)
411
+
412
+ #TODO: add flavors to models (e.g. trained on cut CMD for optimal performance)
@@ -0,0 +1 @@
1
+ from .NEST import *
Binary file
Binary file
Binary file
Binary file