xradio 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xradio/__init__.py CHANGED
@@ -1,13 +1,13 @@
1
1
  import os
2
2
  from graphviper.utils.logger import setup_logger
3
3
 
4
- _logger_name = "xradio"
5
- if os.getenv("VIPER_LOGGER_NAME") != _logger_name:
6
- os.environ["VIPER_LOGGER_NAME"] = _logger_name
7
- setup_logger(
8
- logger_name="xradio",
9
- log_to_term=True,
10
- log_to_file=False, # True
11
- log_file="xradio-logfile",
12
- log_level="DEBUG",
13
- )
4
+ # _logger_name = "xradio"
5
+ # if os.getenv("VIPER_LOGGER_NAME") != _logger_name:
6
+ # os.environ["VIPER_LOGGER_NAME"] = _logger_name
7
+ # setup_logger(
8
+ # logger_name="xradio",
9
+ # log_to_term=True,
10
+ # log_to_file=False, # True
11
+ # log_file="xradio-logfile",
12
+ # log_level="DEBUG",
13
+ # )
xradio/_utils/common.py CHANGED
@@ -1,3 +1,60 @@
1
1
  import numpy as np
2
2
 
3
3
  _deg_to_rad = np.pi / 180
4
+
5
+
6
def cast_to_str(x):
    """Unwrap a list to its first element; pass any other value through.

    Despite the name, no string conversion is performed: a list yields its
    element at index 0 (an empty list therefore raises ``IndexError``), and
    every non-list value, including tuples, is returned unchanged.
    """
    return x[0] if isinstance(x, list) else x
11
+
12
+
13
def convert_to_si_units(xds):
    """Rescale non-SI components of every data variable to SI units, in place.

    For each data variable carrying a ``"units"`` attribute, that attribute is
    walked entry by entry. Entry ``i`` describes the slice ``[..., i]`` (last
    axis) of the variable. Recognised entries are converted and the unit
    string in the attribute list is overwritten:

    * ``km`` -> ``m`` and ``km/s`` -> ``m/s`` (x 1e3)
    * ``deg`` -> ``rad`` (x pi / 180) and ``arcsec`` -> ``rad`` (x pi / 648000)
    * ``Au``/``AU`` -> ``m`` (x 149597870700)
    * ``Au/d``/``AU/d`` -> ``m/s`` (x 149597870700 / 86400)

    Unrecognised unit entries are left alone. If ``units`` is a plain string,
    it is iterated character by character, so nothing matches and the
    variable is left untouched.

    Both the variable data and the ``units`` list are modified in place; the
    same ``xds`` object is returned for convenience.
    """
    # unit -> (conversion applied to the component values, SI unit written back)
    to_si = {
        "km": (lambda v: v * 1e3, "m"),
        "km/s": (lambda v: v * 1e3, "m/s"),
        "deg": (lambda v: v * np.pi / 180, "rad"),
        "Au": (lambda v: v * 149597870700, "m"),
        "AU": (lambda v: v * 149597870700, "m"),
        "Au/d": (lambda v: v * 149597870700 / 86400, "m/s"),
        "AU/d": (lambda v: v * 149597870700 / 86400, "m/s"),
        "arcsec": (lambda v: v * np.pi / 648000, "rad"),
    }

    for var_name in xds.data_vars:
        if "units" not in xds[var_name].attrs:
            continue
        for component, unit in enumerate(xds[var_name].attrs["units"]):
            rule = to_si.get(unit) if isinstance(unit, str) else None
            if rule is None:
                continue
            convert, si_unit = rule
            xds[var_name][..., component] = convert(xds[var_name][..., component])
            xds[var_name].attrs["units"][component] = si_unit
    return xds
38
+
39
+
40
def add_position_offsets(dv_1, dv_2):
    """Add two arrays of (longitude, latitude) angles and re-normalise the sum.

    Column 0 of each row is treated as a longitude and column 1 as a latitude,
    both in radians (assumed from the pi-based bounds below; the code only
    relies on the column order). Any further columns are passed through.

    * The summed longitude is wrapped into [-pi, pi].
    * A summed latitude outside [-pi/2, pi/2] has crossed a pole. It is
      reflected back (``lat -> +pi - lat`` or ``-pi - lat``) and the longitude
      is rotated by pi, which describes the same point on the sphere. The
      previous version subtracted pi instead, which sent points near a pole to
      the opposite hemisphere.

    Values already inside their range are returned unchanged (including
    exactly +/-pi). Non-finite inputs yield NaN instead of looping forever.

    Parameters
    ----------
    dv_1, dv_2
        Array-likes supporting ``+`` and ``[:, k]`` indexing with at least two
        columns, e.g. (n, 2) numpy arrays or xarray DataArrays (TODO confirm
        the caller's types).

    Returns
    -------
    The summed array (the object produced by ``dv_1 + dv_2``) with columns 0
    and 1 normalised.
    """
    two_pi = 2 * np.pi
    half_pi = np.pi / 2

    def _wrap_to_pi(angle):
        # Only touch out-of-range values so that in-range angles (including
        # exactly +/-pi) pass through bit-for-bit unchanged.
        out_of_range = (angle > np.pi) | (angle < -np.pi)
        return np.where(out_of_range, np.mod(angle + np.pi, two_pi) - np.pi, angle)

    new_pos = dv_1 + dv_2
    lon = np.asarray(new_pos[:, 0], dtype=float)
    lat = np.asarray(new_pos[:, 1], dtype=float)

    # np.mod of +/-inf yields NaN with an "invalid" warning; NaN is the
    # intended result there (the old while-loops never terminated).
    with np.errstate(invalid="ignore"):
        lat = _wrap_to_pi(lat)
        over_north = lat > half_pi
        over_south = lat < -half_pi
        lat = np.where(over_north, np.pi - lat, lat)
        lat = np.where(over_south, -np.pi - lat, lat)
        # Crossing a pole puts the point on the opposite meridian.
        lon = np.where(over_north | over_south, lon + np.pi, lon)
        lon = _wrap_to_pi(lon)

    # lon/lat are fresh arrays (np.where never returns a view), so writing
    # them back cannot clobber values still being read.
    new_pos[:, 0] = lon
    new_pos[:, 1] = lat
    return new_pos
@@ -1,6 +1,23 @@
1
1
  """Contains optimised functions to be used within other modules."""
2
2
 
3
3
  import numpy as np
4
+ import xarray as xr
5
+
6
+
7
def to_list(x):
    """Return *x* as a Python list.

    * ``list`` -> a shallow copy of it. Previously this path crashed,
      because lists have no ``.ndim`` attribute.
    * 0-d ``np.ndarray`` -> ``[x.item()]``, a one-element list holding a
      native Python scalar.
    * n-d ``np.ndarray`` -> ``list(x)``, i.e. its items along the first axis.
    * anything else -> ``[x]``.
    """
    if isinstance(x, list):
        return list(x)
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            return [x.item()]
        return list(x)
    return [x]
13
+
14
+
15
def to_np_array(x):
    """Return *x* as a numpy array with at least one dimension.

    * ``list`` -> ``np.array(x)``. Previously this path crashed, because
      lists have no ``.ndim`` attribute.
    * 0-d ``np.ndarray`` -> ``np.array([x.item()])``, a length-1 array.
    * n-d ``np.ndarray`` -> a copy made with ``np.array(x)``.
    * anything else -> ``np.array([x])``.
    """
    if isinstance(x, list):
        return np.array(x)
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            return np.array([x.item()])
        return np.array(x)
    return np.array([x])
4
21
 
5
22
 
6
23
  def check_if_consistent(array: np.ndarray, array_name: str) -> np.ndarray:
@@ -45,6 +62,12 @@ def unique_1d(array: np.ndarray) -> np.ndarray:
45
62
  a sorted array of unique values.
46
63
 
47
64
  """
65
+ if isinstance(array, xr.core.dataarray.DataArray):
66
+ array = array.values
67
+
68
+ if array.ndim == 0:
69
+ return np.array([array.item()])
70
+
48
71
  return np.sort(pd.unique(array))
49
72
 
50
73
 
@@ -6,12 +6,6 @@ class processing_set(dict):
6
6
  super().__init__(*args, **kwargs)
7
7
  self.meta = {"summary": {}}
8
8
 
9
- # generate_meta(self)
10
-
11
- # def generate_meta(self):
12
- # self.meta['summary'] = {"base": _summary(self)}
13
- # self.meta['max_dims'] = _get_ps_max_dims(self)
14
-
15
9
  def summary(self, data_group="base"):
16
10
  if data_group in self.meta["summary"]:
17
11
  return self.meta["summary"][data_group]
@@ -36,49 +30,65 @@ class processing_set(dict):
36
30
  def _summary(self, data_group="base"):
37
31
  summary_data = {
38
32
  "name": [],
39
- "ddi": [],
40
- "intent": [],
41
- "field_id": [],
33
+ "obs_mode": [],
34
+ "shape": [],
35
+ "polarization": [],
36
+ "spw_id": [],
37
+ # "field_id": [],
42
38
  "field_name": [],
39
+ # "source_id": [],
40
+ "source_name": [],
41
+ "field_coords": [],
43
42
  "start_frequency": [],
44
43
  "end_frequency": [],
45
- "shape": [],
46
- "field_coords": [],
47
44
  }
48
45
  from astropy.coordinates import SkyCoord
49
46
  import astropy.units as u
50
47
 
51
48
  for key, value in self.items():
52
49
  summary_data["name"].append(key)
53
- summary_data["ddi"].append(value.attrs["ddi"])
54
- summary_data["intent"].append(value.attrs["intent"])
50
+ summary_data["obs_mode"].append(value.attrs["partition_info"]["obs_mode"])
51
+ summary_data["spw_id"].append(
52
+ value.attrs["partition_info"]["spectral_window_id"]
53
+ )
54
+ summary_data["polarization"].append(value.polarization.values)
55
55
 
56
56
  if "visibility" in value.attrs["data_groups"][data_group]:
57
57
  data_name = value.attrs["data_groups"][data_group]["visibility"]
58
+ center_name = "FIELD_PHASE_CENTER"
58
59
 
59
60
  if "spectrum" in value.attrs["data_groups"][data_group]:
60
61
  data_name = value.attrs["data_groups"][data_group]["spectrum"]
62
+ center_name = "FIELD_REFERENCE_CENTER"
61
63
 
62
64
  summary_data["shape"].append(value[data_name].shape)
63
65
 
64
- summary_data["field_id"].append(value.attrs["partition_info"]["field_id"])
66
+ # summary_data["field_id"].append(value.attrs["partition_info"]["field_id"])
67
+ # summary_data["source_id"].append(value.attrs["partition_info"]["source_id"])
68
+
65
69
  summary_data["field_name"].append(
66
- value[data_name].attrs["field_and_source_xds"].attrs["field_name"]
70
+ value.attrs["partition_info"]["field_name"]
71
+ )
72
+ summary_data["source_name"].append(
73
+ value.attrs["partition_info"]["source_name"]
67
74
  )
68
75
  summary_data["start_frequency"].append(value["frequency"].values[0])
69
76
  summary_data["end_frequency"].append(value["frequency"].values[-1])
70
77
 
71
78
  if value[data_name].attrs["field_and_source_xds"].is_ephemeris:
72
79
  summary_data["field_coords"].append("Ephemeris")
80
+ elif (
81
+ "time"
82
+ in value[data_name].attrs["field_and_source_xds"][center_name].coords
83
+ ):
84
+ summary_data["field_coords"].append("Multi-Phase-Center")
73
85
  else:
74
86
  ra_dec_rad = (
75
- value[data_name]
76
- .attrs["field_and_source_xds"]["FIELD_PHASE_CENTER"]
77
- .values
87
+ value[data_name].attrs["field_and_source_xds"][center_name].values
78
88
  )
79
89
  frame = (
80
90
  value[data_name]
81
- .attrs["field_and_source_xds"]["FIELD_PHASE_CENTER"]
91
+ .attrs["field_and_source_xds"][center_name]
82
92
  .attrs["frame"]
83
93
  .lower()
84
94
  )
@@ -90,8 +100,8 @@ class processing_set(dict):
90
100
  summary_data["field_coords"].append(
91
101
  [
92
102
  frame,
93
- coord.ra.to_string(unit=u.hour),
94
- coord.dec.to_string(unit=u.deg),
103
+ coord.ra.to_string(unit=u.hour, precision=2),
104
+ coord.dec.to_string(unit=u.deg, precision=2),
95
105
  ]
96
106
  )
97
107
 
@@ -108,8 +118,8 @@ class processing_set(dict):
108
118
  assert (
109
119
  frame == ms_xds.frequency.attrs["frame"]
110
120
  ), "Frequency reference frame not consistent in processing set."
111
- if ms_xds.frequency.attrs["spw_id"] not in spw_ids:
112
- spw_ids.append(ms_xds.frequency.attrs["spw_id"])
121
+ if ms_xds.frequency.attrs["spectral_window_id"] not in spw_ids:
122
+ spw_ids.append(ms_xds.frequency.attrs["spectral_window_id"])
113
123
  freq_axis_list.append(ms_xds.frequency)
114
124
 
115
125
  freq_axis = xr.concat(freq_axis_list, dim="frequency").sortby("frequency")
@@ -131,3 +141,36 @@ class processing_set(dict):
131
141
 
132
142
  def get(self, id):
133
143
  return self[list(self.keys())[id]]
144
+
145
def sel(self, **kwargs):
    """Select the members whose summary row matches every given criterion.

    Each keyword names a column of ``self.summary()`` and its value selects
    rows as follows:

    * ``list`` or ``np.ndarray`` -> rows whose column value is in it
      (``Series.isin``);
    * ``slice(start, stop)`` -> rows with ``start <= value <= stop``
      (``Series.between``, both bounds inclusive). A ``None`` bound is
      open-ended; previously it was passed straight to ``between`` and failed.
      ``slice.step`` is ignored.
    * anything else -> rows whose column value equals it.

    Criteria are combined with logical AND. Returns a new ``processing_set``
    holding the matching entries (the same objects, not copies), in this
    set's iteration order.
    """
    import numpy as np

    summary_table = self.summary()
    for key, value in kwargs.items():
        column = summary_table[key]
        if isinstance(value, (list, np.ndarray)):
            mask = column.isin(value)
        elif isinstance(value, slice):
            # Series.between cannot compare against None, so open bounds
            # become -/+ infinity (meaningful for numeric columns).
            lower = -np.inf if value.start is None else value.start
            upper = np.inf if value.stop is None else value.stop
            mask = column.between(lower, upper)
        else:
            mask = column == value
        summary_table = summary_table[mask]

    # Build the membership set once instead of scanning an array per key.
    selected_names = set(summary_table["name"].values)
    sub_ps = processing_set()
    for key, val in self.items():
        if key in selected_names:
            sub_ps[key] = val
    return sub_ps
165
+
166
def ms_sel(self, **kwargs):
    """Apply label-based ``.sel`` to every member and collect the results.

    All keyword arguments are forwarded as a single mapping, passed
    positionally as ``.sel``'s indexers argument. Each member is presumably
    an xarray Dataset (TODO confirm). Returns a new ``processing_set`` with
    the same keys, in the same order.
    """
    selected = processing_set()
    for ms_name, ms_xds in self.items():
        selected[ms_name] = ms_xds.sel(kwargs)
    return selected
171
+
172
def ms_isel(self, **kwargs):
    """Apply positional ``.isel`` to every member and collect the results.

    All keyword arguments are forwarded as a single mapping, passed
    positionally as ``.isel``'s indexers argument. Each member is presumably
    an xarray Dataset (TODO confirm). Returns a new ``processing_set`` with
    the same keys, in the same order.
    """
    indexed = processing_set()
    for ms_name, ms_xds in self.items():
        indexed[ms_name] = ms_xds.isel(kwargs)
    return indexed