astro-nest 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. NEST/NEST.py +86 -137
  2. NEST/domain.pkl +0 -0
  3. NEST/models/BaSTI_HST_BPRP.sav +0 -0
  4. NEST/models/NN_BaSTI.json +1 -1
  5. NEST/models/NN_BaSTI_BPRP.json +1 -1
  6. {astro_nest-0.5.5.dist-info → astro_nest-0.5.6.dist-info}/METADATA +1 -1
  7. astro_nest-0.5.6.dist-info/RECORD +34 -0
  8. NEST/models/BaSTI2.sav +0 -0
  9. NEST/models/BaSTI2_BPRP.sav +0 -0
  10. NEST/models/NN_BaSTI.sav +0 -0
  11. NEST/models/NN_BaSTI.sav.old +0 -0
  12. NEST/models/NN_BaSTI_BPRP.sav +0 -0
  13. NEST/models/NN_BaSTI_HST_BPRP.sav +0 -0
  14. NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.json +0 -1
  15. NEST/models/NN_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
  16. NEST/models/NN_BaSTI_cut.json +0 -1
  17. NEST/models/NN_BaSTI_cut.sav +0 -0
  18. NEST/models/NN_BaSTI_cut_BPRP.json +0 -1
  19. NEST/models/NN_BaSTI_cut_BPRP.sav +0 -0
  20. NEST/models/NN_Dartmouth.sav +0 -0
  21. NEST/models/NN_Dartmouth_BPRP.sav +0 -0
  22. NEST/models/NN_MIST.sav +0 -0
  23. NEST/models/NN_MIST_BPRP.sav +0 -0
  24. NEST/models/NN_PARSEC.sav +0 -0
  25. NEST/models/NN_PARSEC_BPRP.sav +0 -0
  26. NEST/models/NN_SYCLIST.sav +0 -0
  27. NEST/models/NN_SYCLIST_BPRP.sav +0 -0
  28. NEST/models/NN_YaPSI.sav +0 -0
  29. NEST/models/NN_YaPSI_BPRP.sav +0 -0
  30. NEST/models/scaler_BaSTI.sav +0 -0
  31. NEST/models/scaler_BaSTI_BPRP.sav +0 -0
  32. NEST/models/scaler_BaSTI_BPRP_cut.sav +0 -0
  33. NEST/models/scaler_BaSTI_HST_BPRP.sav +0 -0
  34. NEST/models/scaler_BaSTI_HST_alpha_zero_BPRP.sav +0 -0
  35. NEST/models/scaler_BaSTI_cut.sav +0 -0
  36. NEST/models/scaler_BaSTI_cut_BPRP.sav +0 -0
  37. NEST/models/scaler_Dartmouth.sav +0 -0
  38. NEST/models/scaler_Dartmouth_BPRP.sav +0 -0
  39. NEST/models/scaler_MIST.sav +0 -0
  40. NEST/models/scaler_MIST_BPRP.sav +0 -0
  41. NEST/models/scaler_PARSEC.sav +0 -0
  42. NEST/models/scaler_PARSEC_BPRP.sav +0 -0
  43. NEST/models/scaler_SYCLIST.sav +0 -0
  44. NEST/models/scaler_SYCLIST_BPRP.sav +0 -0
  45. NEST/models/scaler_YaPSI.sav +0 -0
  46. NEST/models/scaler_YaPSI_BPRP.sav +0 -0
  47. astro_nest-0.5.5.dist-info/RECORD +0 -72
  48. /NEST/models/{SYCLIST.sav → Geneva.sav} +0 -0
  49. /NEST/models/{SYCLIST_BPRP.sav → Geneva_BPRP.sav} +0 -0
  50. /NEST/models/{NN_SYCLIST.json → NN_Geneva.json} +0 -0
  51. /NEST/models/{NN_SYCLIST_BPRP.json → NN_Geneva_BPRP.json} +0 -0
  52. {astro_nest-0.5.5.dist-info → astro_nest-0.5.6.dist-info}/WHEEL +0 -0
  53. {astro_nest-0.5.5.dist-info → astro_nest-0.5.6.dist-info}/top_level.txt +0 -0
NEST/NEST.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import numpy as np
2
2
  import pickle,json,os,warnings,shutil,zipfile,urllib.request,json
3
+
3
4
  try:
4
5
  from tqdm import tqdm
5
6
  has_tqdm = True
@@ -15,7 +16,8 @@ try:
15
16
  has_matplotlib = True
16
17
  except ImportError:
17
18
  has_matplotlib = False
18
- NNSA_DIR = os.path.dirname(os.path.abspath(__file__))
19
+
20
+ NEST_DIR = os.path.dirname(os.path.abspath(__file__))
19
21
 
20
22
  loaded_isochrones = {}
21
23
 
@@ -24,16 +26,25 @@ def custom_warning(message, category, filename, lineno, file=None, line=None):
24
26
 
25
27
  warnings.showwarning = custom_warning
26
28
 
27
- def get_mode(arr,min_age=0,max_age=14):
28
- hist, bins = np.histogram(arr,bins=100,range=(min_age,max_age))
29
+ def sanitize_input(input):
30
+ if input is not None and type(input) is not list:
31
+ if hasattr(input,'tolist'):
32
+ input = input.tolist()
33
+ else:
34
+ input = [input]
35
+ return input
36
+
37
+ def get_mode(arr,min_age=0,max_age=14,nbins=280):
38
+ #TODO: choose number of bins appropriately
39
+ hist, bins = np.histogram(arr,bins=nbins,range=(min_age,max_age))
29
40
  return bins[np.argmax(hist)] + (bins[1]-bins[0])/2
30
41
 
31
42
  def download_isochrones():
32
43
  if input('Isochrone curves for plots do not exist. Download them ? (15.6Mb) [Y/n] (default: Y)') in ['n', 'N']:
33
44
  return None
34
45
  iso_url = "https://github.com/star-age/star-age.github.io/archive/refs/heads/main.zip"
35
- iso_dir = os.path.join(NNSA_DIR, 'isochrones')
36
- tmp_zip = os.path.join(NNSA_DIR, 'isochrones_tmp.zip')
46
+ iso_dir = os.path.join(NEST_DIR, 'isochrones')
47
+ tmp_zip = os.path.join(NEST_DIR, 'isochrones_tmp.zip')
37
48
 
38
49
  print("Downloading isochrones from GitHub...")
39
50
  if has_requests and has_tqdm:
@@ -54,49 +65,52 @@ def download_isochrones():
54
65
 
55
66
  with zipfile.ZipFile(tmp_zip, 'r') as zip_ref:
56
67
  members = [m for m in zip_ref.namelist() if m.startswith('star-age.github.io-main/isochrones/')]
57
- zip_ref.extractall(NNSA_DIR, members)
68
+ zip_ref.extractall(NEST_DIR, members)
58
69
 
59
- src = os.path.join(NNSA_DIR, 'star-age.github.io-main', 'isochrones')
70
+ src = os.path.join(NEST_DIR, 'star-age.github.io-main', 'isochrones')
60
71
  if os.path.exists(iso_dir):
61
72
  shutil.rmtree(iso_dir)
62
73
  shutil.move(src, iso_dir)
63
- shutil.rmtree(os.path.join(NNSA_DIR, 'star-age.github.io-main'))
74
+ shutil.rmtree(os.path.join(NEST_DIR, 'star-age.github.io-main'))
64
75
  os.remove(tmp_zip)
65
76
  print("Isochrones downloaded and extracted.")
66
77
 
67
78
  def get_isochrones(model):
68
79
  if model.model_name in loaded_isochrones:
69
80
  return loaded_isochrones[model.model_name]
70
- if os.path.exists(os.path.join(NNSA_DIR, 'isochrones')) == False:
81
+ if os.path.exists(os.path.join(NEST_DIR, 'isochrones')) == False:
71
82
  download_isochrones()
72
- isochrone_path = os.path.join(NNSA_DIR, 'isochrones', 'isochrones_' + model.model_name + '.json')
83
+ isochrone_path = os.path.join(NEST_DIR, 'isochrones', 'isochrones_' + model.model_name + '.json')
73
84
  if os.path.exists(isochrone_path):
74
85
  loaded_isochrones[model.model_name] = json.load(open(isochrone_path, 'r'))
75
86
  return loaded_isochrones[model.model_name]
76
87
  return None
77
88
 
78
89
  def available_models():
79
- model_list = ['BaSTI','PARSEC','MIST','Geneva','Dartmouth','YaPSI']
90
+ model_list = ['BaSTI','PARSEC','MIST','Geneva','Dartmouth','YaPSI','BaSTI_HST']
80
91
  model_sources = [
81
92
  'http://basti-iac.oa-abruzzo.inaf.it/',
82
93
  'https://stev.oapd.inaf.it/PARSEC/',
83
94
  'https://waps.cfa.harvard.edu/MIST/',
84
- 'https://www.unige.ch/sciences/astro/evolution/en/database/syclist',
95
+ 'https://www.unige.ch/sciences/astro/evolution/en/database/syclist/',
85
96
  'https://rcweb.dartmouth.edu/stellar/',
86
- 'http://www.astro.yale.edu/yapsi/'
97
+ 'http://www.astro.yale.edu/yapsi/',
98
+ 'http://basti-iac.oa-abruzzo.inaf.it/'
87
99
  ]
88
100
  for model,source in zip(model_list,model_sources):
89
101
  print(model + 'Model (' + source + ')')
90
102
 
91
103
  class AgeModel:
92
- def __init__(self,model_name,cut=False,use_sklearn=True,use_tqdm=True):
104
+ def __init__(self,model_name,use_sklearn=True,use_tqdm=True,photometric_type=None):
93
105
  self.model_name = model_name
94
106
  self.use_sklearn = use_sklearn
95
107
  self.use_tqdm = use_tqdm
108
+ if photometric_type == None:
109
+ photometric_type = 'Gaia'
110
+ self.photometric_type = photometric_type
96
111
  if not has_tqdm and self.use_tqdm:
97
112
  self.use_tqdm = False
98
- self.cut = cut
99
- domain_path = os.path.join(NNSA_DIR, 'domain.pkl')
113
+ domain_path = os.path.join(NEST_DIR, 'domain.pkl')
100
114
  domain = pickle.load(open(domain_path, 'rb'))
101
115
  if model_name in domain:
102
116
  self.domain = domain[model_name]
@@ -109,8 +123,6 @@ class AgeModel:
109
123
  self.space_col = None
110
124
  self.space_mag = None
111
125
  self.space_met = None
112
- if self.cut:
113
- self.model_name = self.model_name + '_cut'
114
126
  self.neural_networks = {}
115
127
  self.scalers = {}
116
128
  self.samples = None
@@ -121,13 +133,21 @@ class AgeModel:
121
133
  self.stds = None
122
134
  self.load_neural_network(self.model_name)
123
135
 
136
+ def __repr__(self):
137
+ return self.__str__()
138
+
124
139
  def __str__(self):
125
- return self.model_name + ' Age Model'
140
+ _str = self.model_name + ' Age Model'
141
+ if self.photometric_type == 'Gaia':
142
+ _str += ', Gaia photometry: mag=MG, col=(GBP-GRP).'
143
+ elif self.photometric_type == 'HST':
144
+ _str += ', HST photometry: mag=F814W, col=(F606W-F814W).'
145
+ return _str
126
146
 
127
147
  def load_neural_network(self, model_name):
128
148
  if self.use_sklearn:
129
- model_path_full = os.path.join(NNSA_DIR, 'models', f'{model_name}.sav')
130
- model_path_reduced = os.path.join(NNSA_DIR, 'models', f'{model_name}_BPRP.sav')
149
+ model_path_full = os.path.join(NEST_DIR, 'models', f'{model_name}.sav')
150
+ model_path_reduced = os.path.join(NEST_DIR, 'models', f'{model_name}_BPRP.sav')
131
151
  if os.path.exists(model_path_full):
132
152
  nn = pickle.load(open(model_path_full, 'rb'))
133
153
  self.neural_networks['full'] = nn['NN']
@@ -137,8 +157,8 @@ class AgeModel:
137
157
  self.neural_networks['reduced'] = nn['NN']
138
158
  self.scalers['reduced'] = nn['Scaler']
139
159
  else:
140
- model_path_full = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}.json')
141
- model_path_reduced = os.path.join(NNSA_DIR, 'models', f'NN_{model_name}_BPRP.json')
160
+ model_path_full = os.path.join(NEST_DIR, 'models', f'NN_{model_name}.json')
161
+ model_path_reduced = os.path.join(NEST_DIR, 'models', f'NN_{model_name}_BPRP.json')
142
162
  if os.path.exists(model_path_full):
143
163
  json_nn = json.load(open(model_path_full, 'r'))
144
164
  self.neural_networks['full'] = {
@@ -169,56 +189,17 @@ class AgeModel:
169
189
  store_samples=True,
170
190
  min_age=0,max_age=14):
171
191
 
172
- if met is not None and type(met) is not list:
173
- if hasattr(met,'tolist'):
174
- met = met.tolist()
175
- else:
176
- met = [met]
177
- if mag is not None and type(mag) is not list:
178
- if hasattr(mag,'tolist'):
179
- mag = mag.tolist()
180
- else:
181
- mag = [mag]
182
- if col is not None and type(col) is not list:
183
- if hasattr(col,'tolist'):
184
- col = col.tolist()
185
- else:
186
- col = [col]
187
- if emet is not None and type(emet) is not list:
188
- if hasattr(emet,'tolist'):
189
- emet = emet.tolist()
190
- else:
191
- emet = [emet]
192
- if emag is not None and type(emag) is not list:
193
- if hasattr(emag,'tolist'):
194
- emag = emag.tolist()
195
- else:
196
- emag = [emag]
197
- if ecol is not None and type(ecol) is not list:
198
- if hasattr(ecol,'tolist'):
199
- ecol = ecol.tolist()
200
- else:
201
- ecol = [ecol]
202
- if GBP is not None and type(GBP) is not list:
203
- if hasattr(GBP,'tolist'):
204
- GBP = GBP.tolist()
205
- else:
206
- GBP = [GBP]
207
- if GRP is not None and type(GRP) is not list:
208
- if hasattr(GRP,'tolist'):
209
- GRP = GRP.tolist()
210
- else:
211
- GRP = [GRP]
212
- if eGBP is not None and type(eGBP) is not list:
213
- if hasattr(eGBP,'tolist'):
214
- eGBP = eGBP.tolist()
215
- else:
216
- eGBP = [eGBP]
217
- if eGRP is not None and type(eGRP) is not list:
218
- if hasattr(eGRP,'tolist'):
219
- eGRP = eGRP.tolist()
220
- else:
221
- eGRP = [eGRP]
192
+
193
+ met = sanitize_input(met)
194
+ mag = sanitize_input(mag)
195
+ col = sanitize_input(col)
196
+ emet = sanitize_input(emet)
197
+ emag = sanitize_input(emag)
198
+ ecol = sanitize_input(ecol)
199
+ GBP = sanitize_input(GBP)
200
+ GRP = sanitize_input(GRP)
201
+ eGBP = sanitize_input(eGBP)
202
+ eGRP = sanitize_input(eGRP)
222
203
 
223
204
  if store_samples and n*len(met) > 1e6:
224
205
  warnings.warn('Storing samples for {} stars with {} samples for each will take a lot of memory. Consider setting store_samples=False to only store mean,median,mode and std of individual age distributions.'.format(len(met),n))
@@ -306,36 +287,12 @@ class AgeModel:
306
287
  def check_domain(self,met,mag,col,emet=None,emag=None,ecol=None):
307
288
  if self.domain is None:
308
289
  raise ValueError('No domain defined for this model')
309
- if met is not None and type(met) is not list:
310
- if hasattr(met,'tolist'):
311
- met = met.tolist()
312
- else:
313
- met = [met]
314
- if mag is not None and type(mag) is not list:
315
- if hasattr(mag,'tolist'):
316
- mag = mag.tolist()
317
- else:
318
- mag = [mag]
319
- if col is not None and type(col) is not list:
320
- if hasattr(col,'tolist'):
321
- col = col.tolist()
322
- else:
323
- col = [col]
324
- if emet is not None and type(emet) is not list:
325
- if hasattr(emet,'tolist'):
326
- emet = emet.tolist()
327
- else:
328
- emet = [emet]
329
- if emag is not None and type(emag) is not list:
330
- if hasattr(emag,'tolist'):
331
- emag = emag.tolist()
332
- else:
333
- emag = [emag]
334
- if ecol is not None and type(ecol) is not list:
335
- if hasattr(ecol,'tolist'):
336
- ecol = ecol.tolist()
337
- else:
338
- ecol = [ecol]
290
+ met = sanitize_input(met)
291
+ mag = sanitize_input(mag)
292
+ col = sanitize_input(col)
293
+ emet = sanitize_input(emet)
294
+ emag = sanitize_input(emag)
295
+ ecol = sanitize_input(ecol)
339
296
 
340
297
  has_errors = emet != None and emag != None and ecol != None
341
298
 
@@ -357,9 +314,6 @@ class AgeModel:
357
314
  max_i_mag = np.minimum(np.digitize(mag[i] + errors[1],self.space_mag) - 1,self.space_mag.size-2)
358
315
  min_i_met = np.maximum(np.digitize(met[i] - errors[2],self.space_met) - 1,0)
359
316
  max_i_met = np.minimum(np.digitize(met[i] + errors[2],self.space_met) - 1,self.space_met.size-2)
360
- if self.cut and (self.space_col[min_i_col] > 1.25 or self.space_mag[min_i_mag] > 4):
361
- in_domain[i] = False
362
- continue
363
317
  in_domain[i] = bool(np.any(self.domain[min_i_col:max_i_col+1,min_i_mag:max_i_mag+1,min_i_met:max_i_met+1]) == 1)
364
318
 
365
319
  return in_domain
@@ -413,7 +367,6 @@ class AgeModel:
413
367
  def mode_ages(self):
414
368
  if self.ages is None:
415
369
  raise ValueError('No age predictions have been made yet')
416
- #TODO: choose number of bins appropriately
417
370
  modes = []
418
371
  min_age = max(0,self.ages.min())
419
372
  max_age = max(14,self.ages.max())
@@ -458,6 +411,9 @@ class AgeModel:
458
411
  if check_domain and met is not None:
459
412
  in_domain = self.check_domain(met,mag,col)
460
413
  age = age[in_domain]
414
+ met = np.array(sanitize_input(met))
415
+ mag = np.array(sanitize_input(mag))
416
+ col = np.array(sanitize_input(col))
461
417
  met = met[in_domain]
462
418
  mag = mag[in_domain]
463
419
  col = col[in_domain]
@@ -475,7 +431,7 @@ class AgeModel:
475
431
  print('Metallicity not available for this model (-2,-1,0), using [M/H]=0 instead')
476
432
  isochrone_met = '0'
477
433
  isochrones = isochrones[isochrone_met]
478
- # Store line objects and their ages for interactivity
434
+
479
435
  lines = []
480
436
  ages = []
481
437
  for isochrone in isochrones:
@@ -503,7 +459,6 @@ class AgeModel:
503
459
  )
504
460
  plt.colorbar(scatter,label='Age [Gyr]')
505
461
 
506
- # Add hover info box for isochrone lines
507
462
  annot = ax.annotate("", xy=(0,0), xytext=(0,15), textcoords="offset points",
508
463
  ha='center',
509
464
  bbox=dict(boxstyle="round", fc="w"),
@@ -535,46 +490,40 @@ class AgeModel:
535
490
 
536
491
  ax.set_xlim(-1,3)
537
492
  ax.set_ylim(10,-5)
538
- ax.set_xlabel(r'$(G_{BP}-G_{RP})_0$ [mag]')
539
- ax.set_ylabel(r'$M_G$ [mag]')
493
+ if self.photometric_type == 'Gaia':
494
+ ax.set_xlabel(r'$(G_{BP}-G_{RP})_0$ [mag]')
495
+ ax.set_ylabel(r'$M_G$ [mag]')
496
+ elif self.photometric_type == 'HST':
497
+ ax.set_xlabel(r'$(F606W-F814W)$ [mag]')
498
+ ax.set_ylabel(r'$F606W$ [mag]')
540
499
 
541
500
  plt.tight_layout()
542
501
  plt.show()
543
502
 
544
503
  class BaSTIModel(AgeModel):
545
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
546
- super().__init__('BaSTI',cut,use_sklearn,use_tqdm)
547
-
548
- '''
549
- class BaSTI2Model(AgeModel):
550
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
551
- super().__init__('BaSTI2',cut,use_sklearn,use_tqdm)
552
-
553
- class BaSTI_HSTModel(AgeModel):
554
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
555
- super().__init__('BaSTI_HST',cut,use_sklearn,use_tqdm)
556
-
557
- class BaSTI_HST_alpha_zeroModel(AgeModel):
558
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
559
- super().__init__('BaSTI_HST_alpha_zero',cut,use_sklearn,use_tqdm)
560
- '''
504
+ def __init__(self,use_sklearn=True,use_tqdm=True):
505
+ super().__init__('BaSTI',use_sklearn,use_tqdm)
561
506
 
562
507
  class PARSECModel(AgeModel):
563
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
564
- super().__init__('PARSEC',cut,use_sklearn,use_tqdm)
508
+ def __init__(self,use_sklearn=True,use_tqdm=True):
509
+ super().__init__('PARSEC',use_sklearn,use_tqdm)
565
510
 
566
511
  class MISTModel(AgeModel):
567
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
568
- super().__init__('MIST',cut,use_sklearn,use_tqdm)
512
+ def __init__(self,use_sklearn=True,use_tqdm=True):
513
+ super().__init__('MIST',use_sklearn,use_tqdm)
569
514
 
570
515
  class GenevaModel(AgeModel):
571
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
572
- super().__init__('Geneva',cut,use_sklearn,use_tqdm)
516
+ def __init__(self,use_sklearn=True,use_tqdm=True):
517
+ super().__init__('Geneva',use_sklearn,use_tqdm)
573
518
 
574
519
  class DartmouthModel(AgeModel):
575
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
576
- super().__init__('Dartmouth',cut,use_sklearn,use_tqdm)
520
+ def __init__(self,use_sklearn=True,use_tqdm=True):
521
+ super().__init__('Dartmouth',use_sklearn,use_tqdm)
577
522
 
578
523
  class YaPSIModel(AgeModel):
579
- def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
580
- super().__init__('YaPSI',cut,use_sklearn,use_tqdm)
524
+ def __init__(self,use_sklearn=True,use_tqdm=True):
525
+ super().__init__('YaPSI',use_sklearn,use_tqdm)
526
+
527
+ class BaSTI_HSTModel(AgeModel):
528
+ def __init__(self,use_sklearn=True,use_tqdm=True,photometric_type='HST'):
529
+ super().__init__('BaSTI_HST',use_sklearn,use_tqdm,photometric_type)
NEST/domain.pkl CHANGED
Binary file
Binary file