cubevis 0.5.13__py3-none-any.whl → 0.5.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cubevis might be problematic. Click here for more details.
- cubevis/__version__.py +1 -1
- cubevis/private/apps/__init__.py +5 -1
- cubevis/toolbox/_cube.py +14 -11
- {cubevis-0.5.13.dist-info → cubevis-0.5.15.dist-info}/METADATA +1 -1
- {cubevis-0.5.13.dist-info → cubevis-0.5.15.dist-info}/RECORD +7 -27
- cubevis/data/measurement_set/__init__.py +0 -7
- cubevis/data/measurement_set/_ms_data.py +0 -178
- cubevis/data/measurement_set/processing_set/__init__.py +0 -30
- cubevis/data/measurement_set/processing_set/_ps_concat.py +0 -98
- cubevis/data/measurement_set/processing_set/_ps_coords.py +0 -78
- cubevis/data/measurement_set/processing_set/_ps_data.py +0 -213
- cubevis/data/measurement_set/processing_set/_ps_io.py +0 -55
- cubevis/data/measurement_set/processing_set/_ps_raster_data.py +0 -154
- cubevis/data/measurement_set/processing_set/_ps_select.py +0 -91
- cubevis/data/measurement_set/processing_set/_ps_stats.py +0 -218
- cubevis/data/measurement_set/processing_set/_xds_data.py +0 -149
- cubevis/plot/__init__.py +0 -1
- cubevis/plot/ms_plot/__init__.py +0 -29
- cubevis/plot/ms_plot/_ms_plot.py +0 -242
- cubevis/plot/ms_plot/_ms_plot_constants.py +0 -22
- cubevis/plot/ms_plot/_ms_plot_selectors.py +0 -348
- cubevis/plot/ms_plot/_raster_plot.py +0 -292
- cubevis/plot/ms_plot/_raster_plot_inputs.py +0 -116
- cubevis/plot/ms_plot/_xds_plot_axes.py +0 -110
- cubevis/private/apps/_ms_raster.py +0 -815
- {cubevis-0.5.13.dist-info → cubevis-0.5.15.dist-info}/WHEEL +0 -0
- {cubevis-0.5.13.dist-info → cubevis-0.5.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
'''
|
|
2
|
-
Calculate statistics on xradio ProcessingSet data.
|
|
3
|
-
'''
|
|
4
|
-
|
|
5
|
-
import dask
|
|
6
|
-
import numpy as np
|
|
7
|
-
|
|
8
|
-
from xradio.measurement_set.load_processing_set import ProcessingSetIterator
|
|
9
|
-
from graphviper.graph_tools import generate_dask_workflow
|
|
10
|
-
from graphviper.graph_tools.coordinate_utils import make_parallel_coord, interpolate_data_coords_onto_parallel_coords
|
|
11
|
-
from graphviper.graph_tools.map import map as graph_map
|
|
12
|
-
from graphviper.graph_tools.reduce import reduce as graph_reduce
|
|
13
|
-
|
|
14
|
-
try:
|
|
15
|
-
from toolviper.dask.client import get_client
|
|
16
|
-
_HAVE_TOOLVIPER = True
|
|
17
|
-
except ImportError:
|
|
18
|
-
_HAVE_TOOLVIPER = False
|
|
19
|
-
|
|
20
|
-
from cubevis.data.measurement_set.processing_set._xds_data import get_correlated_data, get_axis_data
|
|
21
|
-
|
|
22
|
-
def calculate_ps_stats(ps_xdt, ps_store, vis_axis, data_group, logger):
    '''
    Calculate stats for unflagged visibilities: min, max, mean, std
        ps_xdt (xarray.DataTree): input MeasurementSet opened from zarr file
        ps_store (str): path to visibility zarr file
        vis_axis (str): complex component (amp, phase, real, imag)
        data_group (str): data group whose correlated data is used for stats
        logger: logger used for debug messages
    Returns: stats tuple (min, max, mean, stddev) or None if all data flagged (count=0)
    Raises: KeyError if no MSv4 in ps_xdt contains data_group
    '''
    input_params = {
        'input_data_store': ps_store,
        'xdt': ps_xdt,
        'vis_axis': vis_axis,
        'data_group': data_group,
    }

    # Map tasks need the correlated data variable name to select variables;
    # take it from the first MSv4 that contains the data group.
    for ms_xdt in ps_xdt.values():
        if data_group in ms_xdt.attrs['data_groups']:
            input_params['correlated_data'] = get_correlated_data(ms_xdt.ds, data_group)
            break
    else:
        # Fail early with a clear message rather than a KeyError raised inside a dask task.
        raise KeyError(f"Data group '{data_group}' not found in any ProcessingSet MSv4")

    # get_client() could be None if not set up outside cubevis
    active_client = get_client() if _HAVE_TOOLVIPER else None
    n_threads = active_client.thread_info()['n_threads'] if active_client is not None else 4
    logger.debug(f"Setting {n_threads} n_chunks for parallel coords.")
    mapping = _get_task_data_mapping(ps_xdt, n_threads)

    data_min, data_max, data_mean = _calc_basic_stats(ps_xdt, mapping, input_params, logger)
    if not np.isfinite(data_mean):
        # _calc_basic_stats returns an infinite mean when there is no unflagged data
        return None

    # The variance pass needs the global mean
    input_params['mean'] = data_mean
    data_stddev = _calc_stddev(ps_xdt, mapping, input_params, logger)
    return data_min, data_max, data_mean, data_stddev
|
|
54
|
-
|
|
55
|
-
def _get_task_data_mapping(ps_xdt, n_threads):
    ''' Split the ProcessingSet frequency axis into n_threads parallel chunks.
        ps_xdt (xarray.DataTree): ProcessingSet
        n_threads (int): number of frequency chunks
        Returns: mapping used as node_task_data_mapping for graph_map '''
    freq_axis = ps_xdt.xr_ps.get_freq_axis()
    freq_chunks = make_parallel_coord(coord=freq_axis, n_chunks=n_threads)
    return interpolate_data_coords_onto_parallel_coords({"frequency": freq_chunks}, ps_xdt)
|
|
59
|
-
|
|
60
|
-
def _calc_basic_stats(ps_xdt, mapping, input_params, logger):
    ''' Calculate min, max, mean using graph map/reduce.
        Returns (min, max, mean); mean is np.inf when no unflagged data was found. '''
    mapped_graph = graph_map(
        input_data=ps_xdt,
        node_task_data_mapping=mapping,
        node_task=_map_stats,
        input_params=input_params,
    )
    reduced_graph = graph_reduce(mapped_graph, _reduce_stats, input_params, mode='tree')
    workflow = generate_dask_workflow(reduced_graph)

    # dask.compute returns a tuple with one result per input
    data_min, data_max, data_sum, data_count = dask.compute(workflow)[0]

    if data_count == 0.0:
        logger.debug("stats: no unflagged data")
        return (data_min, data_max, np.inf)

    data_mean = data_sum / data_count
    logger.debug(f"basic stats: min={data_min:.4f}, max={data_max:.4f}, sum={data_sum:.4f}, count={data_count} mean={data_mean:.4f}")
    return data_min, data_max, data_mean
|
|
83
|
-
|
|
84
|
-
def _calc_stddev(ps_xdt, mapping, input_params, logger):
    ''' Calculate stddev using graph map/reduce.
        Variance is sum((x - mean)**2) / count (no ddof correction);
        input_params must already contain 'mean'. '''
    workflow = generate_dask_workflow(
        graph_reduce(
            graph_map(
                input_data=ps_xdt,
                node_task_data_mapping=mapping,
                node_task=_map_variance,
                input_params=input_params,
            ),
            _reduce_variance,
            input_params,
            mode='tree',
        )
    )
    sq_diff_sum, sq_diff_count = dask.compute(workflow)[0]

    data_variance = sq_diff_sum / sq_diff_count
    data_stddev = data_variance ** 0.5
    logger.debug(f"stats: variance={data_variance:.4f}, stddev={data_stddev:.4f}")
    return data_stddev
|
|
103
|
-
|
|
104
|
-
def _get_stats_xda(xds, vis_axis, data_group):
    ''' Return xda with only unflagged cross-corr visibility data.
        Flagged values become nan. Autocorrelation baselines also become nan,
        unless that would leave no data or the xda has no baseline antenna names. '''
    axis_xda = get_axis_data(xds, vis_axis, data_group)
    unflagged_xda = axis_xda.where(np.logical_not(xds.FLAG))

    if unflagged_xda.count() > 0 and "baseline_antenna1_name" in unflagged_xda.coords:
        is_cross_corr = unflagged_xda.baseline_antenna1_name != unflagged_xda.baseline_antenna2_name
        cross_xda = unflagged_xda.where(is_cross_corr)
        if cross_xda.count() > 0:
            # nan where flagged or auto-corr
            return cross_xda

    # nan where flagged
    return unflagged_xda
|
|
121
|
-
|
|
122
|
-
def _map_stats(input_params):
    ''' Return min, max, sum, and count of unflagged data in this task's data chunk.
        min and max are nan when the chunk has no unflagged data. '''
    vis_axis = input_params['vis_axis']
    data_group = input_params['data_group']
    correlated_data = input_params['correlated_data']
    chunk_mins, chunk_maxs, chunk_sums, chunk_counts = [], [], [], []

    # NOTE(review): 'data_selection' is presumably added per node task by graph_map; confirm
    ps_iter = ProcessingSetIterator(
        input_params['data_selection'],
        input_params['input_data_store'],
        input_params['xdt'],
        input_params['data_group'],
        include_variables=[correlated_data, 'FLAG'],
        load_sub_datasets=False,
    )

    for xds in ps_iter:
        xda = _get_stats_xda(xds, vis_axis, data_group)
        if not xda.count() > 0:
            continue
        flat_values = xda.values.ravel()
        for arg_finder, collected in ((np.nanargmin, chunk_mins), (np.nanargmax, chunk_maxs)):
            try:
                collected.append(flat_values[arg_finder(flat_values)])
            except ValueError:
                pass
        chunk_sums.append(np.nansum(xda))
        chunk_counts.append(xda.count().values)

    try:
        min_value = np.nanmin(chunk_mins)
    except ValueError:
        min_value = np.nan
    try:
        max_value = np.nanmax(chunk_maxs)
    except ValueError:
        max_value = np.nan
    return (min_value, max_value, sum(chunk_sums), sum(chunk_counts))
|
|
164
|
-
|
|
165
|
-
# pylint: disable=unused-argument
|
|
166
|
-
def _reduce_stats(graph_inputs, input_params):
|
|
167
|
-
''' Compute min, max, sum, and count of all data.
|
|
168
|
-
input_parameters seems to be required although unused. '''
|
|
169
|
-
data_min = 0.0
|
|
170
|
-
mins = [values[0] for values in graph_inputs]
|
|
171
|
-
if not np.isnan(mins).all():
|
|
172
|
-
data_min = min(0.0, np.nanmin(mins))
|
|
173
|
-
|
|
174
|
-
data_max = 0.0
|
|
175
|
-
maxs = [values[1] for values in graph_inputs]
|
|
176
|
-
if not np.isnan(maxs).all():
|
|
177
|
-
data_max = max(0.0, np.nanmax(maxs))
|
|
178
|
-
|
|
179
|
-
data_sum = sum(values[2] for values in graph_inputs)
|
|
180
|
-
data_count = sum(values[3] for values in graph_inputs)
|
|
181
|
-
return (data_min, data_max, data_sum, data_count)
|
|
182
|
-
# pylint: enable=unused-argument
|
|
183
|
-
|
|
184
|
-
def _map_variance(input_params):
    ''' Return sum and count of (xda - mean) squared over this task's unflagged data.
        input_params must contain 'mean' from the basic stats pass. '''
    vis_axis = input_params['vis_axis']
    data_group = input_params['data_group']
    correlated_data = input_params['correlated_data']
    mean = input_params['mean']

    sq_diff_sum = 0.0
    sq_diff_count = 0

    ps_iter = ProcessingSetIterator(
        input_params['data_selection'],
        input_params['input_data_store'],
        input_params['xdt'],
        input_params['data_group'],
        include_variables=[correlated_data, 'FLAG'],
        load_sub_datasets=False,
    )

    for xds in ps_iter:
        xda = _get_stats_xda(xds, vis_axis, data_group)
        if xda.size == 0:
            continue
        # nan (flagged) values stay nan: excluded by nansum and count
        sq_diff = (xda - mean) ** 2
        sq_diff_sum += np.nansum(sq_diff)
        sq_diff_count += sq_diff.count().values
    return (sq_diff_sum, sq_diff_count)
|
|
210
|
-
|
|
211
|
-
# pylint: disable=unused-argument
|
|
212
|
-
def _reduce_variance(graph_inputs, input_params):
|
|
213
|
-
''' Compute sum and count of all (xda-mean) squared data.
|
|
214
|
-
input_parameters seems to be required although unused. '''
|
|
215
|
-
sq_diff_sum = sum(values[0] for values in graph_inputs)
|
|
216
|
-
sq_diff_count = sum(values[1] for values in graph_inputs)
|
|
217
|
-
return (sq_diff_sum, sq_diff_count)
|
|
218
|
-
# pylint: enable=unused-argument
|
|
@@ -1,149 +0,0 @@
|
|
|
1
|
-
''' Get MeasurementSet data from xarray Dataset '''
|
|
2
|
-
|
|
3
|
-
from astropy import constants
|
|
4
|
-
import numpy as np
|
|
5
|
-
import xarray as xr
|
|
6
|
-
|
|
7
|
-
from cubevis.plot.ms_plot._ms_plot_constants import SPECTRUM_AXIS_OPTIONS, UVW_AXIS_OPTIONS, VIS_AXIS_OPTIONS, WEIGHT_AXIS_OPTIONS
|
|
8
|
-
|
|
9
|
-
def get_correlated_data(xds, data_group):
    ''' Return the 'correlated_data' variable name of data_group.
        xds: object whose attrs hold a 'data_groups' dict of per-group dicts.
        Raises KeyError if data_group or its 'correlated_data' entry is missing. '''
    data_groups = xds.attrs['data_groups']
    group_info = data_groups[data_group]
    return group_info['correlated_data']
|
|
12
|
-
|
|
13
|
-
def get_axis_data(xds, axis, data_group=None):
    ''' Get requested axis data from xarray dataset.
        xds (dict): msv4 xarray.Dataset
        axis (str): axis data to retrieve.
        data_group (str): correlated data group name.
        Returns: xarray.DataArray, or None if axis is not recognized.
    '''
    group_info = xds.data_groups[data_group]

    # Checks are order-dependent: e.g. uvw axes must be tested before the 'wave' substring.
    if _is_coordinate_axis(axis):
        return xds[axis]
    if _is_antenna_axis(axis):
        return _get_antenna_axis(xds, axis)
    if _is_vis_axis(axis):
        return _calc_vis_axis(xds, axis, group_info)
    if _is_uvw_axis(axis):
        return _calc_uvw_axis(xds, axis, group_info)
    if _is_weight_axis(axis):
        return _calc_weight_axis(xds, axis, group_info)
    if 'spw' in axis:
        return _get_spw_axis(xds, axis)
    if 'wave' in axis:
        return _calc_wave_axis(xds, axis, group_info)
    if axis == 'field':
        field_name = xds[group_info['correlated_data']].field_and_source_xds.field_name
        return xr.DataArray([field_name])
    if axis == 'flag':
        return xds[group_info['flag']]
    if axis == 'intents':
        return xr.DataArray(["".join(xds.partition_info['intents'])])
    if axis == 'channel':
        return xr.DataArray(np.array(range(xds.frequency.size)))
    return None
|
|
45
|
-
|
|
46
|
-
def _is_coordinate_axis(axis):
|
|
47
|
-
return axis in ['scan_number', 'time', 'frequency', 'polarization',
|
|
48
|
-
#'velocity': 'frequency', # calculate
|
|
49
|
-
# TODO?
|
|
50
|
-
#'observation': no id, xds.observation_info (observer, project, release date)
|
|
51
|
-
#'feed1': no id, xds.antenna_xds
|
|
52
|
-
#'feed2': no id, xds.antenna_xds
|
|
53
|
-
]
|
|
54
|
-
|
|
55
|
-
def _is_vis_axis(axis):
|
|
56
|
-
return axis in VIS_AXIS_OPTIONS
|
|
57
|
-
|
|
58
|
-
def _is_uvw_axis(axis):
|
|
59
|
-
return axis in UVW_AXIS_OPTIONS
|
|
60
|
-
|
|
61
|
-
def _is_antenna_axis(axis):
|
|
62
|
-
return 'baseline' in axis or 'antenna' in axis
|
|
63
|
-
|
|
64
|
-
def _is_weight_axis(axis):
|
|
65
|
-
return axis in WEIGHT_AXIS_OPTIONS
|
|
66
|
-
|
|
67
|
-
def _get_spw_axis(xds, axis):
    ''' Return the spectral window id or name of xds.frequency as a 1-element DataArray.
        Raises ValueError for any axis other than 'spw_id' or 'spw_name'. '''
    spw_fields = {'spw_id': 'spectral_window_id', 'spw_name': 'spectral_window_name'}
    field = spw_fields.get(axis)
    if field is None:
        raise ValueError(f"Invalid spw axis {axis}")
    return xr.DataArray([getattr(xds.frequency, field)])
|
|
73
|
-
|
|
74
|
-
def _calc_vis_axis(xds, axis, group_info):
    ''' Calculate axis from correlated data.
        SPECTRUM (single dish) data only allows SPECTRUM_AXIS_OPTIONS and is returned
        without conversion (units 'Jy'); otherwise RuntimeError is raised.
        Interferometry data: amp/real/imag in 'Jy', phase in 'deg'; None for other axes. '''
    correlated_data = group_info['correlated_data']
    xda = xds[correlated_data]

    # Single dish spectrum
    if correlated_data == "SPECTRUM":
        if axis not in SPECTRUM_AXIS_OPTIONS:
            raise RuntimeError(f"Vis axis {axis} invalid for SPECTRUM dataset, select from {SPECTRUM_AXIS_OPTIONS}")
        return xda.assign_attrs(units='Jy')

    # Interferometry visibilities
    if axis == 'amp':
        return np.absolute(xda).assign_attrs(units='Jy')
    if axis == 'phase':
        # np.angle(xda) returns ndarray not xr.DataArray
        phase_deg = np.arctan2(xda.imag, xda.real) * 180.0 / np.pi
        return phase_deg.assign_attrs(units="deg")
    if axis == 'real':
        return np.real(xda.assign_attrs(units='Jy'))
    if axis == 'imag':
        return np.imag(xda.assign_attrs(units='Jy'))
    return None
|
|
96
|
-
|
|
97
|
-
def _calc_uvw_axis(xds, axis, group_info):
    ''' Calculate axis from UVW xarray DataArray.
        'u', 'v', 'w' select one uvw_label; any other axis returns uvdist = sqrt(u**2 + v**2).
        Raises RuntimeError if the data group has no uvw data. '''
    if 'uvw' not in group_info:
        raise RuntimeError(f"Axis {axis} is not valid in this dataset, no uvw data")

    uvw_xda = xds[group_info['uvw']]
    label_index = {'u': 0, 'v': 1, 'w': 2}
    if axis in label_index:
        return uvw_xda.isel(uvw_label=label_index[axis])

    # uvdist
    u_xda = uvw_xda.isel(uvw_label=0)
    v_xda = uvw_xda.isel(uvw_label=1)
    return np.sqrt(np.square(u_xda) + np.square(v_xda))
|
|
115
|
-
|
|
116
|
-
def _calc_wave_axis(xds, axis, group_info):
    ''' Calculate u, v, w, or uvdist in wavelengths for each frequency.
        xds: msv4 xarray.Dataset with a frequency coordinate and uvw data.
        axis (str): one of 'uwave', 'vwave', 'wwave', 'uvwave'.
        group_info (dict): data group info with a 'uvw' entry.
        Returns: xarray.DataArray of the uvw axis times frequency / c, broadcast
            over the uvw axis dimensions plus the frequency dimension.
        Raises: ValueError for an invalid wave axis; RuntimeError if no uvw data.
    '''
    wave_axes = {'uwave': 'u', 'vwave': 'v', 'wwave': 'w', 'uvwave': 'uvdist'}
    if axis not in wave_axes:
        raise ValueError(f"Invalid wave axis {axis}")

    uvw_xda = _calc_uvw_axis(xds, wave_axes[axis], group_info)

    # Distance in wavelengths = distance * frequency / c. constants.c is an astropy
    # Quantity, so use its float value (m/s); assumes uvw in m and frequency in Hz
    # (TODO confirm). xarray broadcasting aligns the uvw dims with the frequency dim,
    # which handles multi-dimensional uvw data (e.g. time x baseline).
    wave_xda = uvw_xda * (xds.frequency / constants.c.value)
    return wave_xda
|
|
130
|
-
|
|
131
|
-
def _calc_weight_axis(xds, axis, group_info):
|
|
132
|
-
weight = xds[group_info['weight']]
|
|
133
|
-
if axis == 'weight':
|
|
134
|
-
return weight
|
|
135
|
-
return np.sqrt(1.0 / weight)
|
|
136
|
-
|
|
137
|
-
def _get_antenna_axis(xds, axis):
|
|
138
|
-
if 'antenna_name' in xds.coords:
|
|
139
|
-
return xds.antenna_name
|
|
140
|
-
|
|
141
|
-
if axis == 'antenna1':
|
|
142
|
-
return xds.baseline_antenna1_name
|
|
143
|
-
if axis == 'antenna2':
|
|
144
|
-
return xds.baseline_antenna2_name
|
|
145
|
-
if axis == 'baseline_id':
|
|
146
|
-
return xds.baseline_id
|
|
147
|
-
if axis == 'baseline':
|
|
148
|
-
return xds.baseline
|
|
149
|
-
raise ValueError(f"Invalid antenna/baseline axis {axis}")
|
cubevis/plot/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
''' Module for plotting functions. Currently only MeasurementSet plots are supported. '''
|
cubevis/plot/ms_plot/__init__.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
''' This module contains classes and utilities to support MS plotting '''
|
|
2
|
-
|
|
3
|
-
from ._ms_plot import MsPlot
|
|
4
|
-
|
|
5
|
-
from ._ms_plot_constants import (
|
|
6
|
-
SPECTRUM_AXIS_OPTIONS,
|
|
7
|
-
UVW_AXIS_OPTIONS,
|
|
8
|
-
VIS_AXIS_OPTIONS,
|
|
9
|
-
WEIGHT_AXIS_OPTIONS,
|
|
10
|
-
PS_SELECTION_OPTIONS,
|
|
11
|
-
MS_SELECTION_OPTIONS,
|
|
12
|
-
AGGREGATOR_OPTIONS,
|
|
13
|
-
DEFAULT_UNFLAGGED_CMAP,
|
|
14
|
-
DEFAULT_FLAGGED_CMAP,
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
from ._ms_plot_selectors import (
|
|
18
|
-
file_selector,
|
|
19
|
-
title_selector,
|
|
20
|
-
style_selector,
|
|
21
|
-
axis_selector,
|
|
22
|
-
aggregation_selector,
|
|
23
|
-
iteration_selector,
|
|
24
|
-
selection_selector,
|
|
25
|
-
plot_starter,
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
from ._raster_plot_inputs import check_inputs
|
|
29
|
-
from ._raster_plot import RasterPlot
|
cubevis/plot/ms_plot/_ms_plot.py
DELETED
|
@@ -1,242 +0,0 @@
|
|
|
1
|
-
'''
|
|
2
|
-
Base class for ms plots
|
|
3
|
-
'''
|
|
4
|
-
|
|
5
|
-
import os
|
|
6
|
-
import time
|
|
7
|
-
|
|
8
|
-
from bokeh.plotting import show
|
|
9
|
-
import hvplot
|
|
10
|
-
import holoviews as hv
|
|
11
|
-
import numpy as np
|
|
12
|
-
import panel as pn
|
|
13
|
-
|
|
14
|
-
try:
|
|
15
|
-
from toolviper.utils.logger import get_logger, setup_logger
|
|
16
|
-
_HAVE_TOOLVIPER = True
|
|
17
|
-
except ImportError:
|
|
18
|
-
_HAVE_TOOLVIPER = False
|
|
19
|
-
|
|
20
|
-
from cubevis.data.measurement_set._ms_data import MsData
|
|
21
|
-
from cubevis.toolbox import AppContext
|
|
22
|
-
from cubevis.utils._logging import get_logger
|
|
23
|
-
|
|
24
|
-
class MsPlot:

    ''' Base class for MS plots with common functionality '''

    def __init__(self, ms=None, log_level="info", show_gui=False, app_name="MsPlot"):
        ''' Set up logging, app context, and plot state, then load ms if given.
            ms (str or None): MSv2 or zarr path; may be None only when show_gui is True.
            log_level (str): logger level name (converted to upper case).
            show_gui (bool): whether a GUI is shown; enables panel notifications.
            app_name (str): name used for the logger and the AppContext.
            Raises: RuntimeError if ms is not provided and show_gui is False.
        '''
        if not ms and not show_gui:
            raise RuntimeError("Must provide ms/zarr path if gui not shown.")

        # Set logger: use toolviper logger else casalog else python logger
        if _HAVE_TOOLVIPER:
            self._logger = setup_logger(app_name, log_to_term=True, log_to_file=False, log_level=log_level.upper())
        else:
            self._logger = get_logger()
            self._logger.setLevel(log_level.upper())

        # Save parameters; ms set below
        self._show_gui = show_gui
        self._app_name = app_name

        # Set up temp dir for output html files; do not add cubevis bokeh libraries
        self._app_context = AppContext(app_name, init_bokeh=False)

        if show_gui:
            # Enable "toast" notifications
            pn.config.notifications = True
            self._toast = None # for destroy() with new plot or new notification

        # Initialize plot inputs and params. show() and save() read 'subplots',
        # 'iter_axis', and 'iter_range' from here (presumably set by a subclass plot()).
        self._plot_inputs = {}

        # Initialize plots; _plots_locked keeps clear_plots() waiting while show() renders
        self._plot_init = False
        self._plots_locked = False
        self._plots = []

        # Set data (if ms)
        self._data = None
        self._ms_info = {}
        self._set_ms(ms)

    def summary(self, data_group='base', columns=None):
        ''' Print ProcessingSet summary.
            Args:
                data_group (str): data group to use for summary.
                columns (None, str, list): type of metadata to list.
                    None: Print all summary columns in ProcessingSet.
                    'by_msv4': Print formatted summary metadata by MSv4.
                    str, list: Print a subset of summary columns in ProcessingSet.
                        Options: 'name', 'intents', 'shape', 'polarization', 'scan_name', 'spw_name',
                        'field_name', 'source_name', 'field_coords', 'start_frequency', 'end_frequency'
            Returns: list of unique values when single column is requested, else None
            Raises: RuntimeError if no MS data is set.
        '''
        self._check_data()
        # Pass through the result so the documented return value reaches the caller
        return self._data.summary(data_group, columns)

    def data_groups(self):
        ''' Returns set of data groups from all ProcessingSet ms_xds.
            Raises RuntimeError if no MS data is set. '''
        self._check_data()
        return self._data.data_groups()

    def antennas(self, plot_positions=False, label_antennas=False):
        ''' Returns list of antenna names in ProcessingSet antenna_xds.
            plot_positions (bool): show plot of antenna positions.
            label_antennas (bool): label positions with antenna names.
            Raises RuntimeError if no MS data is set.
        '''
        self._check_data()
        return self._data.get_antennas(plot_positions, label_antennas)

    def plot_phase_centers(self, data_group='base', label_fields=False):
        ''' Plot the phase center locations of all fields in the Processing Set and highlight central field.
            data_group (str): data group to use for field and source xds.
            label_fields (bool): label all fields on the plot if True, else label central field only
            Raises RuntimeError if no MS data is set.
        '''
        self._check_data()
        self._data.plot_phase_centers(data_group, label_fields)

    def clear_plots(self):
        ''' Clear plot list, waiting while show() is rendering it. '''
        while self._plots_locked:
            time.sleep(1)
        self._plots.clear()

    def clear_selection(self):
        ''' Clear selection in data and restore to original '''
        if self._data:
            self._data.clear_selection()

    def show(self):
        '''
        Show interactive Bokeh plots in a browser. Plot tools include pan, zoom, hover, and save.
        Raises RuntimeError if there are no plots.
        '''
        if not self._plots:
            raise RuntimeError("No plots to show. Run plot() to create plot.")

        # Do not delete plot list until rendered. Always release the lock, even if
        # rendering raises, otherwise clear_plots() would wait forever.
        self._plots_locked = True
        try:
            # Single plot or combine plots into layout using subplots (rows, columns)
            # Not layout if subplots is single plot (default if None) or if only one plot saved
            subplots = self._plot_inputs['subplots']
            layout_plot, is_layout = self._layout_plots(subplots)

            # Render to bokeh figure
            if is_layout:
                # Show plots in columns
                plot = hv.render(layout_plot.cols(subplots[1]))
            else:
                # Show single plot
                plot = hv.render(layout_plot)
        finally:
            self._plots_locked = False

        show(plot)

    def save(self, filename='ms_plot.png', fmt='auto', width=900, height=600):
        '''
        Save plot to file with filename, format, and size.
        If iteration plots were created:
            If subplots is a grid, the layout plot will be saved to a single file.
            If subplots is a single plot, iteration plots will be saved individually,
                with a plot index appended to the filename stem: {name}_{index}{ext},
                e.g. ms_plot.png -> ms_plot_0.png.
        Raises RuntimeError if there are no plots.
        '''
        if not self._plots:
            raise RuntimeError("No plot to save. Run plot() to create plot.")

        start_time = time.time()

        # Save single plot or combine plots into layout using subplots (rows, columns).
        # Not layout if subplots is single plot or if only one plot saved.
        subplots = self._plot_inputs['subplots']
        layout_plot, is_layout = self._layout_plots(subplots)

        if is_layout:
            # Save plots combined into one layout
            hvplot.save(layout_plot.cols(subplots[1]), filename=filename, fmt=fmt)
            self._logger.info("Saved plot to %s.", filename)
        else:
            # Save plots individually, with index appended if exprange='all' and multiple plots.
            if self._plot_inputs['iter_axis'] is None:
                hvplot.save(layout_plot.opts(width=width, height=height), filename=filename, fmt=fmt)
                self._logger.info("Saved plot to %s.", filename)
            else:
                # splitext keeps the leading dot in ext (e.g. '.png'), so do not add another
                name, ext = os.path.splitext(filename)
                iter_range = self._plot_inputs['iter_range'] # None or (start, end)
                first_idx = 0 if iter_range is None else iter_range[0]

                for plot_idx, plot in enumerate(self._plots, start=first_idx):
                    exportname = f"{name}_{plot_idx}{ext}"
                    hvplot.save(plot.opts(width=width, height=height), filename=exportname, fmt=fmt)
                    self._logger.info("Saved plot to %s.", exportname)

        self._logger.debug("Save elapsed time: %.2fs.", time.time() - start_time)

    def _check_data(self):
        ''' Raise RuntimeError if no MS data has been set (no ms given, or it failed to load). '''
        if self._data is None:
            raise RuntimeError("No MeasurementSet data: provide a valid ms/zarr path.")

    def _layout_plots(self, subplots):
        ''' Combine up to rows*columns saved plots with "+" into a layout.
            subplots (tuple or None): (rows, columns); None means (1, 1).
            Returns: (plot or layout, is_layout); is_layout is True when more than one
                plot was combined. With no plots, returns (None, False).
        '''
        subplots = (1, 1) if subplots is None else subplots
        num_layout_plots = min(len(self._plots), np.prod(subplots))

        if num_layout_plots == 1:
            return self._plots[0], False

        # Set plots in layout
        layout_members = [self._plots[idx] for idx in range(num_layout_plots)]
        layout = None
        for plot in layout_members:
            layout = plot if layout is None else layout + plot
        return layout, len(layout_members) > 1

    def _set_ms(self, ms):
        ''' Update ms info for input ms filepath (MSv2 or zarr), or None in show_gui mode.
            On load failure (RuntimeError) the error is reported via _notify and data is cleared.
            Return whether ms changed as a bool (False if ms is None). '''
        self._ms_info['ms'] = ms
        ms_error = ""
        ms_changed = bool(ms) and (not self._data or not self._data.is_ms_path(ms))

        if ms_changed:
            try:
                # Set new MS data
                self._data = MsData(ms, self._logger)
                ms_path = self._data.get_path()
                self._ms_info['ms'] = ms_path
                # Strip every extension (e.g. "name.ms.zarr" -> "name")
                root, ext = os.path.splitext(os.path.basename(ms_path))
                while ext != '':
                    root, ext = os.path.splitext(root)
                self._ms_info['basename'] = root
                self._ms_info['data_dims'] = self._data.get_data_dimensions()
            except RuntimeError as e:
                ms_error = str(e)
                self._data = None

        if ms_error:
            self._notify(ms_error, 'error', 0)

        return ms_changed

    def _notify(self, message, level, duration=3000):
        ''' Log message. If show_gui, notify user with toast for duration in ms.
            Zero duration must be dismissed.
            level (str): "info", "error", "success", or "warning"; other levels are ignored. '''
        if self._show_gui:
            pn.state.notifications.position = 'top-center'
            if self._toast:
                self._toast.destroy()

        if level == "info":
            self._logger.info(message)
            if self._show_gui:
                self._toast = pn.state.notifications.info(message, duration=duration)
        elif level == "error":
            self._logger.error(message)
            if self._show_gui:
                self._toast = pn.state.notifications.error(message, duration=duration)
        elif level == "success":
            self._logger.info(message)
            if self._show_gui:
                self._toast = pn.state.notifications.success(message, duration=duration)
        elif level == "warning":
            self._logger.warning(message)
            if self._show_gui:
                self._toast = pn.state.notifications.warning(message, duration=duration)
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
''' Define constants used for plotting MeasurementSets '''
|
|
2
|
-
|
|
3
|
-
# Complex components allowed for single dish SPECTRUM data
SPECTRUM_AXIS_OPTIONS = ['amp', 'real']
# Axes derived from UVW data
UVW_AXIS_OPTIONS = ['u', 'v', 'w', 'uvdist']
# Complex components allowed for interferometry visibilities
VIS_AXIS_OPTIONS = ['amp', 'phase', 'real', 'imag']
# Axes derived from weight data
WEIGHT_AXIS_OPTIONS = ['weight', 'sigma']

# ProcessingSet selection: display label -> summary column name
PS_SELECTION_OPTIONS = {
    'MSv4 Name': 'name',
    'Intents': 'intents',
    'Scan Name': 'scan_name',
    'Spectral Window Name': 'spw_name',
    'Field Name': 'field_name',
    'Source Name': 'source_name',
    'Line Name': 'line_name',
}
MS_SELECTION_OPTIONS = [
    'Data Group',
    'Time',
    'Baseline',
    'Antenna1',
    'Antenna2',
    'Frequency',
    'Polarization',
]

AGGREGATOR_OPTIONS = ['None', 'max', 'mean', 'median', 'min', 'std', 'sum', 'var']

# Default colormap names for unflagged and flagged data
DEFAULT_UNFLAGGED_CMAP = "Viridis"
DEFAULT_FLAGGED_CMAP = "Reds"
|