astro-nest 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
NEST/NEST.py CHANGED
@@ -1,12 +1,24 @@
1
1
  import numpy as np
2
- import pickle,json,os,warnings
2
+ import pickle,json,os,warnings,shutil,zipfile,urllib.request,json
3
3
  try:
4
4
  from tqdm import tqdm
5
5
  has_tqdm = True
6
6
  except ImportError:
7
7
  has_tqdm = False
8
+ try:
9
+ import requests
10
+ has_requests = True
11
+ except ImportError:
12
+ has_requests = False
13
+ try:
14
+ import matplotlib.pyplot as plt
15
+ has_matplotlib = True
16
+ except ImportError:
17
+ has_matplotlib = False
8
18
  NNSA_DIR = os.path.dirname(os.path.abspath(__file__))
9
19
 
20
+ loaded_isochrones = {}
21
+
10
22
  def custom_warning(message, category, filename, lineno, file=None, line=None):
11
23
  print(f"{category.__name__}: {message}")
12
24
 
@@ -16,6 +28,53 @@ def get_mode(arr,min_age=0,max_age=14):
16
28
  hist, bins = np.histogram(arr,bins=100,range=(min_age,max_age))
17
29
  return bins[np.argmax(hist)] + (bins[1]-bins[0])/2
18
30
 
31
def download_isochrones():
    """Download the isochrone curves used for plotting into the package directory.

    Asks for confirmation on stdin, fetches the ``star-age.github.io`` archive
    from GitHub, extracts only its ``isochrones/`` folder and installs it as
    ``<NNSA_DIR>/isochrones`` (replacing any existing copy).

    Returns:
        None. Declining the prompt returns without downloading anything.

    Raises:
        FileNotFoundError: if the archive contains no ``isochrones/`` entries.
        Any network or zip error raised by ``requests``, ``urllib`` or
        ``zipfile``; the temporary archive and the partially extracted tree
        are removed before the error propagates.
    """
    answer = input('Isochrone curves for plots do not exist. Download them ? (15.6Mb) [Y/n] (default: Y)')
    if answer.strip().lower() in ('n', 'no'):
        return None

    iso_url = "https://github.com/star-age/star-age.github.io/archive/refs/heads/main.zip"
    archive_root = 'star-age.github.io-main'
    iso_dir = os.path.join(NNSA_DIR, 'isochrones')
    tmp_zip = os.path.join(NNSA_DIR, 'isochrones_tmp.zip')
    extracted_root = os.path.join(NNSA_DIR, archive_root)

    print("Downloading isochrones from GitHub...")
    try:
        if has_requests and has_tqdm:
            # A timeout keeps a stalled connection from hanging forever, and the
            # context manager releases the connection even on error.
            with requests.get(iso_url, stream=True, timeout=30) as response:
                # Fail on 4xx/5xx instead of saving the error page as a "zip".
                response.raise_for_status()
                total = int(response.headers.get('content-length', 0))
                with open(tmp_zip, 'wb') as file, tqdm(
                    desc="Downloading isochrones",
                    # The server may omit content-length; None gives an
                    # open-ended progress display rather than "x/0".
                    total=total or None,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=65536):
                        bar.update(file.write(data))
        else:
            urllib.request.urlretrieve(iso_url, tmp_zip)

        prefix = archive_root + '/isochrones/'
        with zipfile.ZipFile(tmp_zip, 'r') as zip_ref:
            members = [m for m in zip_ref.namelist() if m.startswith(prefix)]
            if not members:
                raise FileNotFoundError(
                    'No isochrones found in the downloaded archive (' + iso_url + ')'
                )
            zip_ref.extractall(NNSA_DIR, members)

        if os.path.exists(iso_dir):
            shutil.rmtree(iso_dir)
        shutil.move(os.path.join(extracted_root, 'isochrones'), iso_dir)
    finally:
        # Always clean up the scratch files, whether or not the install succeeded.
        shutil.rmtree(extracted_root, ignore_errors=True)
        if os.path.exists(tmp_zip):
            os.remove(tmp_zip)
    print("Isochrones downloaded and extracted.")
66
+
67
def get_isochrones(model):
    """Return the isochrone data for ``model``, loading and caching it on first use.

    Looks up ``model.model_name`` in the module-level ``loaded_isochrones``
    cache. On a miss, it triggers ``download_isochrones()`` if the
    ``isochrones`` directory is absent, then reads
    ``isochrones/isochrones_<model_name>.json`` from the package directory.

    Args:
        model: object exposing a ``model_name`` attribute.

    Returns:
        The parsed JSON content, or None when no isochrone file exists for
        this model (for example when the download was declined).
    """
    name = model.model_name
    try:
        return loaded_isochrones[name]
    except KeyError:
        pass

    if not os.path.exists(os.path.join(NNSA_DIR, 'isochrones')):
        download_isochrones()

    isochrone_path = os.path.join(NNSA_DIR, 'isochrones', 'isochrones_' + name + '.json')
    if not os.path.exists(isochrone_path):
        return None

    # Use a context manager so the file handle is closed deterministically.
    with open(isochrone_path, 'r') as handle:
        loaded_isochrones[name] = json.load(handle)
    return loaded_isochrones[name]
77
+
19
78
  def available_models():
20
79
  model_list = ['BaSTI','PARSEC','MIST','Geneva','Dartmouth','YaPSI']
21
80
  model_sources = [
@@ -298,13 +357,12 @@ class AgeModel:
298
357
  max_i_mag = np.minimum(np.digitize(mag[i] + errors[1],self.space_mag) - 1,self.space_mag.size-2)
299
358
  min_i_met = np.maximum(np.digitize(met[i] - errors[2],self.space_met) - 1,0)
300
359
  max_i_met = np.minimum(np.digitize(met[i] + errors[2],self.space_met) - 1,self.space_met.size-2)
301
- #cells[i] = np.array([[min_i_col,max_i_col],[min_i_mag,max_i_mag],[min_i_met,max_i_met]])
302
360
  if self.cut and (self.space_col[min_i_col] > 1.25 or self.space_mag[min_i_mag] > 4):
303
361
  in_domain[i] = False
304
362
  continue
305
363
  in_domain[i] = bool(np.any(self.domain[min_i_col:max_i_col+1,min_i_mag:max_i_mag+1,min_i_met:max_i_met+1]) == 1)
306
364
 
307
- return in_domain#,cells
365
+ return in_domain
308
366
 
309
367
  def propagate(self,X,neural_network,scaler):
310
368
  if self.use_sklearn:
@@ -371,6 +429,117 @@ class AgeModel:
371
429
  self.stds = np.std(self.ages,axis=1)
372
430
  return self.stds
373
431
 
432
def HR_diagram(self,met=None,mag=None,col=None,age=None,isochrone_met=0,age_type='median',check_domain=True,**kwargs):
    """Plot the model's isochrones, optionally with stars coloured by age.

    Either pass ``met``, ``mag`` and ``col`` together (``age`` is predicted
    with ``self.ages_prediction`` when omitted), or pass none of the four to
    reuse the samples and ages stored on the instance. With no arguments and
    nothing stored, only the isochrones are drawn.

    Args:
        met, mag, col: per-star inputs forwarded to ``ages_prediction`` and
            ``check_domain``; ``col`` is plotted on x and ``mag`` on y.
        age: per-star values used to colour the scatter points.
        isochrone_met: key of the isochrone set to draw. Numbers are rounded
            to the nearest integer and converted to a string; unknown keys
            fall back to ``'0'``.
        age_type: ``'median'``, ``'mean'`` or ``'mode'``; anything else falls
            back to ``'median'``.
        check_domain: when True, keep only the stars flagged by
            ``self.check_domain``.
        **kwargs: forwarded to ``ax.plot`` for the isochrone lines. Colour and
            line width are always forced to black / 0.5.

    Raises:
        ImportError: matplotlib is not installed.
        ValueError: invalid argument combination, isochrones unavailable, or
            ``isochrone_met`` of an unsupported type.
    """
    if not has_matplotlib:
        raise ImportError('matplotlib is required for HR diagram plotting')

    if age_type not in ('median','mean','mode'):
        print('Age type not available, using median instead')
        age_type = 'median'

    if met is not None and mag is not None and col is not None:
        if age is None:
            result = self.ages_prediction(met,mag,col,store_samples=False)
            age = result[age_type]
    elif met is None and mag is None and col is None and age is None:
        has_stored = (
            self.samples is not None and len(self.samples) > 0
            and self.ages is not None and len(self.ages) == len(self.samples)
        )
        if has_stored:
            if age_type == 'median':
                age = self.median_ages()
            elif age_type == 'mean':
                age = self.mean_ages()
            else:
                age = self.mode_ages()
            # Compute the per-star median once instead of three times.
            medians = np.median(self.samples,axis=1)
            met, mag, col = medians[:,0], medians[:,1], medians[:,2]
    else:
        raise ValueError('Not a valid combination of arguments.')

    if check_domain and met is not None:
        in_domain = np.asarray(self.check_domain(met,mag,col),dtype=bool)
        # asarray lets plain lists be masked as well as arrays.
        age = np.asarray(age)[in_domain]
        met = np.asarray(met)[in_domain]
        mag = np.asarray(mag)[in_domain]
        col = np.asarray(col)[in_domain]

    # Validate the isochrone request before creating the figure, so a failure
    # does not leave an orphaned empty figure behind.
    isochrones = get_isochrones(self)
    if isochrones is None:
        raise ValueError('Isochrones not available for this model')
    if not isinstance(isochrone_met,str):
        is_number = isinstance(isochrone_met,(int,float,np.integer,np.floating))
        # bool is an int subclass but is not a meaningful metallicity.
        if isinstance(isochrone_met,(bool,np.bool_)) or not is_number:
            raise ValueError('isochrone_met must be a float or int')
        isochrone_met = str(int(round(float(isochrone_met))))
    if isochrone_met not in isochrones:
        print('Metallicity not available for this model (-2,-1,0), using [M/H]=0 instead')
        isochrone_met = '0'
    isochrones = isochrones[isochrone_met]

    # Build the line style once: drop the short aliases so they cannot clash
    # with the forced full-name keys below.
    line_kwargs = dict(kwargs)
    line_kwargs.pop('c',None)
    line_kwargs.pop('lw',None)
    line_kwargs['color'] = 'k'
    line_kwargs['linewidth'] = 0.5

    fig,ax = plt.subplots(figsize=(10,7))

    # Keep line objects and their ages for the hover tooltip.
    iso_lines = []
    iso_ages = []
    for isochrone in isochrones:
        line, = ax.plot(isochrone['BP-RP'], isochrone['MG'], **line_kwargs)
        iso_lines.append(line)
        iso_ages.append(isochrone['age'])

    if mag is not None and col is not None:
        scatter = ax.scatter(
            col,
            mag,
            c=age,
            cmap='viridis',
            zorder=4
        )
        plt.colorbar(scatter,label='Age [Gyr]')

    # Tooltip shown when the pointer is over an isochrone line.
    annot = ax.annotate("", xy=(0,0), xytext=(0,15), textcoords="offset points",
                        ha='center',
                        bbox=dict(boxstyle="round", fc="w"),
                        arrowprops=dict(arrowstyle="->"),zorder=5)
    annot.set_visible(False)

    def update_annot(event, iso_age):
        annot.xy = (event.xdata, event.ydata)
        annot.set_text(f"Age: {iso_age:.2f} Gyr")
        patch = annot.get_bbox_patch()
        patch.set_facecolor('white')
        patch.set_alpha(0.8)

    def hover(event):
        was_visible = annot.get_visible()
        if event.inaxes == ax:
            for line, iso_age in zip(iso_lines, iso_ages):
                hit, _ = line.contains(event)
                if hit:
                    update_annot(event, iso_age)
                    annot.set_visible(True)
                    fig.canvas.draw_idle()
                    return
        # Pointer is not over any line: hide a tooltip that is still showing.
        if was_visible:
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)

    ax.set_xlim(-1,3)
    ax.set_ylim(10,-5)  # inverted y axis: smaller M_G values at the top
    ax.set_xlabel(r'$(G_{BP}-G_{RP})_0$ [mag]')
    ax.set_ylabel(r'$M_G$ [mag]')

    plt.tight_layout()
    plt.show()
542
+
374
543
  class BaSTIModel(AgeModel):
375
544
  def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
376
545
  super().__init__('BaSTI',cut,use_sklearn,use_tqdm)
@@ -407,6 +576,4 @@ class DartmouthModel(AgeModel):
407
576
 
408
577
  class YaPSIModel(AgeModel):
409
578
  def __init__(self,cut=False,use_sklearn=True,use_tqdm=True):
410
- super().__init__('YaPSI',cut,use_sklearn,use_tqdm)
411
-
412
- #TODO: add flavors to models (e.g. trained on cut CMD for optimal performance)
579
+ super().__init__('YaPSI',cut,use_sklearn,use_tqdm)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: astro-nest
3
- Version: 0.5.0
3
+ Version: 0.5.2
4
4
  Summary: A python package that provides neural networks to estimate ages of stars/star populations.
5
5
  Author-email: Tristan Boin <tristan.boin@gmail.com>
6
6
  License-Expression: MIT
@@ -17,6 +17,7 @@ Based on an upcoming paper by Boin et al. 2025.
17
17
  - numpy
18
18
  - scikit-learn (optional but recommended for speed)
19
19
  - tqdm (optional)
20
+ - matplotlib (optional)
20
21
 
21
22
  The [documentation website](https://star-age.github.io/NEST-docs/) guides you through the package usage.
22
23
 
@@ -1,4 +1,4 @@
1
- NEST/NEST.py,sha256=Xdjp0p6N3ct2fwfP7RA8qqLc9COtrAYxzzY62aaYFdY,15865
1
+ NEST/NEST.py,sha256=3F0ejOiHD03-PXmjDtTAIRHPPLkL42X6mshYuLUQgdk,22452
2
2
  NEST/__init__.py,sha256=OKJbgoTPkmXisdQCOj7rSIZf-kxym0077bfFPa0zxb0,19
3
3
  NEST/domain.pkl,sha256=Z0zv80P6Kc49Bw932gXe4q2ScANEjTvCetlaY9NMNCo,1353728
4
4
  NEST/tutorial.ipynb,sha256=_YcbYokFZcWgxO0kZZjJc4v5sBjd0t8td00flc1Mxh4,67643
@@ -66,7 +66,7 @@ NEST/models/scaler_SYCLIST.sav,sha256=mpr_hzBcB3jso80j2PNfMHD2IqRdHupp4LrAggBjrP
66
66
  NEST/models/scaler_SYCLIST_BPRP.sav,sha256=Eb9F6BAx8MCFFSRLgZUonGLAY1sII6X4xnyn8Qil23w,522
67
67
  NEST/models/scaler_YaPSI.sav,sha256=MzhMbTcjE3o9U_i1IKhZvw0pOjQe040X_qJJPDCJV4A,570
68
68
  NEST/models/scaler_YaPSI_BPRP.sav,sha256=j9gCxIQndfTPd84fhSNGyCWX3jcyBeLJEZZCjvgFrro,522
69
- astro_nest-0.5.0.dist-info/METADATA,sha256=EPPShbUBYUO-5ZYpEjfsrlaP1HrHqeFNPE14pBZYp3w,943
70
- astro_nest-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
71
- astro_nest-0.5.0.dist-info/top_level.txt,sha256=TpsNVEvaG60SlJ4tjyQu4aOsE5ov9RW9lQtLpgBAx4E,5
72
- astro_nest-0.5.0.dist-info/RECORD,,
69
+ astro_nest-0.5.2.dist-info/METADATA,sha256=SpX-YgOKtgPOIRDh2ysojH9ykqQAPSEHok1JEp9g1zQ,967
70
+ astro_nest-0.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
71
+ astro_nest-0.5.2.dist-info/top_level.txt,sha256=TpsNVEvaG60SlJ4tjyQu4aOsE5ov9RW9lQtLpgBAx4E,5
72
+ astro_nest-0.5.2.dist-info/RECORD,,