cubevis 0.5.14__py3-none-any.whl → 0.5.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cubevis might be problematic. Click here for more details.

@@ -1,218 +0,0 @@
1
- '''
2
- Calculate statistics on xradio ProcessingSet data.
3
- '''
4
-
5
- import dask
6
- import numpy as np
7
-
8
- from xradio.measurement_set.load_processing_set import ProcessingSetIterator
9
- from graphviper.graph_tools import generate_dask_workflow
10
- from graphviper.graph_tools.coordinate_utils import make_parallel_coord, interpolate_data_coords_onto_parallel_coords
11
- from graphviper.graph_tools.map import map as graph_map
12
- from graphviper.graph_tools.reduce import reduce as graph_reduce
13
-
14
- try:
15
- from toolviper.dask.client import get_client
16
- _HAVE_TOOLVIPER = True
17
- except ImportError:
18
- _HAVE_TOOLVIPER = False
19
-
20
- from cubevis.data.measurement_set.processing_set._xds_data import get_correlated_data, get_axis_data
21
-
22
def calculate_ps_stats(ps_xdt, ps_store, vis_axis, data_group, logger):
    '''
    Calculate stats for unflagged visibilities: min, max, mean, std
    ps_xdt (xarray.DataTree): input MeasurementSet opened from zarr file
    ps_store (str): path to visibility zarr file
    vis_axis (str): complex component (amp, phase, real, imag)
    Returns: stats tuple (min, max, mean, stddev) or None if all data flagged (count=0)
    '''
    input_params = {
        'input_data_store': ps_store,
        'xdt': ps_xdt,
        'vis_axis': vis_axis,
        'data_group': data_group,
    }
    # Resolve the correlated data name from the first MSv4 containing the data group
    for ms_xdt in ps_xdt.values():
        if data_group in ms_xdt.attrs['data_groups']:
            input_params['correlated_data'] = get_correlated_data(ms_xdt.ds, data_group)
            break

    # Chunk count follows the dask client thread count when one is available
    active_client = get_client() if _HAVE_TOOLVIPER else None  # could be None if not set up outside cubevis
    n_threads = 4 if active_client is None else active_client.thread_info()['n_threads']
    logger.debug(f"Setting {n_threads} n_chunks for parallel coords.")
    mapping = _get_task_data_mapping(ps_xdt, n_threads)

    data_min, data_max, data_mean = _calc_basic_stats(ps_xdt, mapping, input_params, logger)
    if not np.isfinite(data_mean):
        # mean is inf when every value was flagged; no stats to report
        return None

    input_params['mean'] = data_mean
    data_stddev = _calc_stddev(ps_xdt, mapping, input_params, logger)
    return data_min, data_max, data_mean, data_stddev
54
-
55
def _get_task_data_mapping(ps_xdt, n_threads):
    ''' Build the node-task data mapping, chunked along frequency with one chunk per thread. '''
    freq_axis = ps_xdt.xr_ps.get_freq_axis()
    parallel_coords = {
        "frequency": make_parallel_coord(coord=freq_axis, n_chunks=n_threads),
    }
    return interpolate_data_coords_onto_parallel_coords(parallel_coords, ps_xdt)
59
-
60
def _calc_basic_stats(ps_xdt, mapping, input_params, logger):
    ''' Calculate min, max, mean using graph map/reduce '''
    # Map per-chunk stats, then tree-reduce them into global values
    map_graph = graph_map(
        input_data=ps_xdt,
        node_task_data_mapping=mapping,
        node_task=_map_stats,
        input_params=input_params
    )
    reduced_graph = graph_reduce(map_graph, _reduce_stats, input_params, mode='tree')
    workflow = generate_dask_workflow(reduced_graph)
    #workflow.visualize(filename='stats.png')
    results = dask.compute(workflow)

    data_min, data_max, data_sum, data_count = results[0]
    if data_count == 0.0:
        logger.debug("stats: no unflagged data")
        # inf mean signals "all flagged" to the caller
        return (data_min, data_max, np.inf)

    data_mean = data_sum / data_count
    logger.debug(f"basic stats: min={data_min:.4f}, max={data_max:.4f}, sum={data_sum:.4f}, count={data_count} mean={data_mean:.4f}")
    return data_min, data_max, data_mean
83
-
84
def _calc_stddev(ps_xdt, mapping, input_params, logger):
    ''' Calculate stddev using graph map/reduce '''
    # Map per-chunk squared differences from the mean, then tree-reduce
    map_graph = graph_map(
        input_data=ps_xdt,
        node_task_data_mapping=mapping,
        node_task=_map_variance,
        input_params=input_params
    )
    reduced_graph = graph_reduce(map_graph, _reduce_variance, input_params, mode='tree')
    workflow = generate_dask_workflow(reduced_graph)
    results = dask.compute(workflow)

    var_sum, var_count = results[0]
    data_variance = var_sum / var_count
    data_stddev = data_variance ** 0.5
    logger.debug(f"stats: variance={data_variance:.4f}, stddev={data_stddev:.4f}")
    return data_stddev
103
-
104
def _get_stats_xda(xds, vis_axis, data_group):
    ''' Return xda with only unflagged cross-corr visibility data '''
    # NaN out flagged values
    axis_xda = get_axis_data(xds, vis_axis, data_group)
    unflagged = axis_xda.where(np.logical_not(xds.FLAG))

    if "baseline_antenna1_name" in unflagged.coords and unflagged.count() > 0:
        # NaN out autocorrelation baselines as well
        cross_corr = unflagged.where(
            unflagged.baseline_antenna1_name != unflagged.baseline_antenna2_name
        )
        if cross_corr.count() > 0:
            # xda with nan where flagged or auto-corr
            return cross_corr

    # xda with nan where flagged (no baselines, or no unflagged cross-corr data)
    return unflagged
121
-
122
def _map_stats(input_params):
    ''' Return min, max, sum, and count of data chunk '''
    vis_axis = input_params['vis_axis']
    data_group = input_params['data_group']

    ps_iter = ProcessingSetIterator(
        input_params['data_selection'],
        input_params['input_data_store'],
        input_params['xdt'],
        input_params['data_group'],
        include_variables=[input_params['correlated_data'], 'FLAG'],
        load_sub_datasets=False
    )

    chunk_mins = []
    chunk_maxs = []
    chunk_sums = []
    chunk_counts = []

    for xds in ps_iter:
        xda = _get_stats_xda(xds, vis_axis, data_group)
        if xda.count() > 0:
            flat_data = xda.values.ravel()
            # nanargmin/nanargmax raise ValueError on an all-NaN slice
            try:
                chunk_mins.append(flat_data[np.nanargmin(flat_data)])
            except ValueError:
                pass
            try:
                chunk_maxs.append(flat_data[np.nanargmax(flat_data)])
            except ValueError:
                pass
            chunk_sums.append(np.nansum(xda))
            chunk_counts.append(xda.count().values)

    # nanmin/nanmax raise ValueError on an empty list; report nan instead
    try:
        min_value = np.nanmin(chunk_mins)
    except ValueError:
        min_value = np.nan
    try:
        max_value = np.nanmax(chunk_maxs)
    except ValueError:
        max_value = np.nan
    return (min_value, max_value, sum(chunk_sums), sum(chunk_counts))
164
-
165
- # pylint: disable=unused-argument
166
- def _reduce_stats(graph_inputs, input_params):
167
- ''' Compute min, max, sum, and count of all data.
168
- input_parameters seems to be required although unused. '''
169
- data_min = 0.0
170
- mins = [values[0] for values in graph_inputs]
171
- if not np.isnan(mins).all():
172
- data_min = min(0.0, np.nanmin(mins))
173
-
174
- data_max = 0.0
175
- maxs = [values[1] for values in graph_inputs]
176
- if not np.isnan(maxs).all():
177
- data_max = max(0.0, np.nanmax(maxs))
178
-
179
- data_sum = sum(values[2] for values in graph_inputs)
180
- data_count = sum(values[3] for values in graph_inputs)
181
- return (data_min, data_max, data_sum, data_count)
182
- # pylint: enable=unused-argument
183
-
184
def _map_variance(input_params):
    ''' Return sum, count, of (xda - mean) squared '''
    vis_axis = input_params['vis_axis']
    data_group = input_params['data_group']
    mean = input_params['mean']

    ps_iter = ProcessingSetIterator(
        input_params['data_selection'],
        input_params['input_data_store'],
        input_params['xdt'],
        input_params['data_group'],
        include_variables=[input_params['correlated_data'], 'FLAG'],
        load_sub_datasets=False
    )

    sq_diff_sum = 0.0
    sq_diff_count = 0
    for xds in ps_iter:
        xda = _get_stats_xda(xds, vis_axis, data_group)
        if xda.size > 0:
            squared_diff = (xda - mean) ** 2
            # nansum ignores flagged (NaN) values; count() excludes them too
            sq_diff_sum += np.nansum(squared_diff)
            sq_diff_count += squared_diff.count().values
    return (sq_diff_sum, sq_diff_count)
210
-
211
- # pylint: disable=unused-argument
212
- def _reduce_variance(graph_inputs, input_params):
213
- ''' Compute sum and count of all (xda-mean) squared data.
214
- input_parameters seems to be required although unused. '''
215
- sq_diff_sum = sum(values[0] for values in graph_inputs)
216
- sq_diff_count = sum(values[1] for values in graph_inputs)
217
- return (sq_diff_sum, sq_diff_count)
218
- # pylint: enable=unused-argument
@@ -1,149 +0,0 @@
1
- ''' Get MeasurementSet data from xarray Dataset '''
2
-
3
- from astropy import constants
4
- import numpy as np
5
- import xarray as xr
6
-
7
- from cubevis.plot.ms_plot._ms_plot_constants import SPECTRUM_AXIS_OPTIONS, UVW_AXIS_OPTIONS, VIS_AXIS_OPTIONS, WEIGHT_AXIS_OPTIONS
8
-
9
def get_correlated_data(xds, data_group):
    ''' Return correlated_data value in data_group dict '''
    group_info = xds.attrs['data_groups'][data_group]
    return group_info['correlated_data']
12
-
13
def get_axis_data(xds, axis, data_group=None):
    ''' Get requested axis data from xarray dataset.
    xds (dict): msv4 xarray.Dataset
    axis (str): axis data to retrieve.
    data_group (str): correlated data group name.
    Returns: xarray.DataArray, or None for an unrecognized axis.
    '''
    group_info = xds.data_groups[data_group]

    # Dispatch by axis category, checked in order
    if _is_coordinate_axis(axis):
        return xds[axis]
    if _is_antenna_axis(axis):
        return _get_antenna_axis(xds, axis)
    if _is_vis_axis(axis):
        return _calc_vis_axis(xds, axis, group_info)
    if _is_uvw_axis(axis):
        return _calc_uvw_axis(xds, axis, group_info)
    if _is_weight_axis(axis):
        return _calc_weight_axis(xds, axis, group_info)
    if 'spw' in axis:
        return _get_spw_axis(xds, axis)
    if 'wave' in axis:
        return _calc_wave_axis(xds, axis, group_info)
    if axis == 'field':
        return xr.DataArray([xds[group_info['correlated_data']].field_and_source_xds.field_name])
    if axis == 'flag':
        return xds[group_info['flag']]
    if axis == 'intents':
        return xr.DataArray(["".join(xds.partition_info['intents'])])
    if axis == 'channel':
        return xr.DataArray(np.array(range(xds.frequency.size)))
    return None
45
-
46
- def _is_coordinate_axis(axis):
47
- return axis in ['scan_number', 'time', 'frequency', 'polarization',
48
- #'velocity': 'frequency', # calculate
49
- # TODO?
50
- #'observation': no id, xds.observation_info (observer, project, release date)
51
- #'feed1': no id, xds.antenna_xds
52
- #'feed2': no id, xds.antenna_xds
53
- ]
54
-
55
def _is_vis_axis(axis):
    ''' Whether axis is a visibility component (see VIS_AXIS_OPTIONS). '''
    return axis in VIS_AXIS_OPTIONS
57
-
58
def _is_uvw_axis(axis):
    ''' Whether axis is a uvw component (see UVW_AXIS_OPTIONS). '''
    return axis in UVW_AXIS_OPTIONS
60
-
61
- def _is_antenna_axis(axis):
62
- return 'baseline' in axis or 'antenna' in axis
63
-
64
def _is_weight_axis(axis):
    ''' Whether axis is a weight quantity (see WEIGHT_AXIS_OPTIONS). '''
    return axis in WEIGHT_AXIS_OPTIONS
66
-
67
- def _get_spw_axis(xds, axis):
68
- if axis == 'spw_id':
69
- return xr.DataArray([xds.frequency.spectral_window_id])
70
- if axis == 'spw_name':
71
- return xr.DataArray([xds.frequency.spectral_window_name])
72
- raise ValueError(f"Invalid spw axis {axis}")
73
-
74
def _calc_vis_axis(xds, axis, group_info):
    ''' Calculate axis from correlated data '''
    correlated_data = group_info['correlated_data']
    xda = xds[correlated_data]

    # Single dish spectrum: only SPECTRUM_AXIS_OPTIONS are valid
    if correlated_data == "SPECTRUM":
        if axis not in SPECTRUM_AXIS_OPTIONS:
            raise RuntimeError(f"Vis axis {axis} invalid for SPECTRUM dataset, select from {SPECTRUM_AXIS_OPTIONS}")
        return xda.assign_attrs(units='Jy')

    # Interferometry visibilities
    if axis == 'amp':
        return np.absolute(xda).assign_attrs(units='Jy')
    if axis == 'phase':
        # np.angle(xda) returns ndarray not xr.DataArray, so use arctan2 instead
        phase_deg = np.arctan2(xda.imag, xda.real) * 180.0/np.pi
        return phase_deg.assign_attrs(units="deg")
    if axis == 'real':
        return np.real(xda.assign_attrs(units='Jy'))
    if axis == 'imag':
        return np.imag(xda.assign_attrs(units='Jy'))
    return None
96
-
97
- def _calc_uvw_axis(xds, axis, group_info):
98
- ''' Calculate axis from UVW xarray DataArray '''
99
- if 'uvw' not in group_info:
100
- raise RuntimeError(f"Axis {axis} is not valid in this dataset, no uvw data")
101
-
102
- uvw_xda = xds[group_info['uvw']]
103
-
104
- if axis == 'u':
105
- return uvw_xda.isel(uvw_label=0)
106
- if axis == 'v':
107
- return uvw_xda.isel(uvw_label=1)
108
- if axis == 'w':
109
- return uvw_xda.isel(uvw_label=2)
110
-
111
- # uvdist
112
- u_xda = uvw_xda.isel(uvw_label=0)
113
- v_xda = uvw_xda.isel(uvw_label=1)
114
- return np.sqrt(np.square(u_xda) + np.square(v_xda))
115
-
116
def _calc_wave_axis(xds, axis, group_info):
    ''' Calculate u, v, w, or uv-distance in wavelength units.
    xds: msv4 xarray.Dataset with a frequency coordinate.
    axis (str): one of 'uwave', 'vwave', 'wwave', 'uvwave'.
    group_info (dict): data group info containing 'uvw' key.
    Returns: xarray.DataArray with dims (frequency, uvw_label).
    Raises: ValueError for an invalid wave axis.
    '''
    wave_axes = {'uwave': 'u', 'vwave': 'v', 'wwave': 'w', 'uvwave': 'uvdist'}
    if axis not in wave_axes:
        raise ValueError(f"Invalid wave axis {axis}")
    # distance / c, scaled by frequency below, gives distance in wavelengths
    uvwdist_array = _calc_uvw_axis(xds, wave_axes[axis], group_info).values / constants.c
    uvwdist_len = len(uvwdist_array)
    freq_array = xds.frequency.values
    wave = np.zeros(shape=(len(freq_array), uvwdist_len), dtype=np.double)

    # Fix: original read "for i, in range(uvwdist_len)" which attempts to
    # unpack each int and raises TypeError on the first iteration.
    for i in range(uvwdist_len):
        wave[:, i] = uvwdist_array[i] * freq_array

    wave_xda = xr.DataArray(wave, dims=['frequency', 'uvw_label'])
    return wave_xda
130
-
131
- def _calc_weight_axis(xds, axis, group_info):
132
- weight = xds[group_info['weight']]
133
- if axis == 'weight':
134
- return weight
135
- return np.sqrt(1.0 / weight)
136
-
137
- def _get_antenna_axis(xds, axis):
138
- if 'antenna_name' in xds.coords:
139
- return xds.antenna_name
140
-
141
- if axis == 'antenna1':
142
- return xds.baseline_antenna1_name
143
- if axis == 'antenna2':
144
- return xds.baseline_antenna2_name
145
- if axis == 'baseline_id':
146
- return xds.baseline_id
147
- if axis == 'baseline':
148
- return xds.baseline
149
- raise ValueError(f"Invalid antenna/baseline axis {axis}")
cubevis/plot/__init__.py DELETED
@@ -1 +0,0 @@
1
- ''' Module for plotting functions. Currently supported MeasurementSet plots only. '''
@@ -1,29 +0,0 @@
1
- ''' This module contains classes and utilities to support MS plotting '''
2
-
3
- from ._ms_plot import MsPlot
4
-
5
- from ._ms_plot_constants import (
6
- SPECTRUM_AXIS_OPTIONS,
7
- UVW_AXIS_OPTIONS,
8
- VIS_AXIS_OPTIONS,
9
- WEIGHT_AXIS_OPTIONS,
10
- PS_SELECTION_OPTIONS,
11
- MS_SELECTION_OPTIONS,
12
- AGGREGATOR_OPTIONS,
13
- DEFAULT_UNFLAGGED_CMAP,
14
- DEFAULT_FLAGGED_CMAP,
15
- )
16
-
17
- from ._ms_plot_selectors import (
18
- file_selector,
19
- title_selector,
20
- style_selector,
21
- axis_selector,
22
- aggregation_selector,
23
- iteration_selector,
24
- selection_selector,
25
- plot_starter,
26
- )
27
-
28
- from ._raster_plot_inputs import check_inputs
29
- from ._raster_plot import RasterPlot
@@ -1,242 +0,0 @@
1
- '''
2
- Base class for ms plots
3
- '''
4
-
5
- import os
6
- import time
7
-
8
- from bokeh.plotting import show
9
- import hvplot
10
- import holoviews as hv
11
- import numpy as np
12
- import panel as pn
13
-
14
- try:
15
- from toolviper.utils.logger import get_logger, setup_logger
16
- _HAVE_TOOLVIPER = True
17
- except ImportError:
18
- _HAVE_TOOLVIPER = False
19
-
20
- from cubevis.data.measurement_set._ms_data import MsData
21
- from cubevis.toolbox import AppContext
22
- from cubevis.utils._logging import get_logger
23
-
24
class MsPlot:

    ''' Base class for MS plots with common functionality '''

    def __init__(self, ms=None, log_level="info", show_gui=False, app_name="MsPlot"):
        ''' Set up logging, app context, and MS data.
        ms (str): path to MS (MSv2 or zarr) file; optional only when show_gui=True.
        log_level (str): logging level name, e.g. "info", "debug".
        show_gui (bool): enable GUI "toast" notifications.
        app_name (str): name used for the logger and app context.
        '''
        if not ms and not show_gui:
            raise RuntimeError("Must provide ms/zarr path if gui not shown.")

        # Set logger: use toolviper logger else casalog else python logger
        if _HAVE_TOOLVIPER:
            self._logger = setup_logger(app_name, log_to_term=True, log_to_file=False, log_level=log_level.upper())
        else:
            self._logger = get_logger()
            self._logger.setLevel(log_level.upper())

        # Save parameters; ms set below
        self._show_gui = show_gui
        self._app_name = app_name

        # Set up temp dir for output html files; do not add cubevis bokeh libraries
        self._app_context = AppContext(app_name, init_bokeh=False)

        if show_gui:
            # Enable "toast" notifications
            pn.config.notifications = True
            self._toast = None # for destroy() with new plot or new notification

        # Initialize plot inputs and params
        self._plot_inputs = {}

        # Initialize plots
        self._plot_init = False
        self._plots_locked = False
        self._plots = []

        # Set data (if ms)
        self._data = None
        self._ms_info = {}
        self._set_ms(ms)

    def summary(self, data_group='base', columns=None):
        ''' Print ProcessingSet summary.
        Args:
            data_group (str): data group to use for summary.
            columns (None, str, list): type of metadata to list.
                None: Print all summary columns in ProcessingSet.
                'by_msv4': Print formatted summary metadata by MSv4.
                str, list: Print a subset of summary columns in ProcessingSet.
                    Options: 'name', 'intents', 'shape', 'polarization', 'scan_name', 'spw_name',
                    'field_name', 'source_name', 'field_coords', 'start_frequency', 'end_frequency'
        Returns: list of unique values when single column is requested, else None
        '''
        self._data.summary(data_group, columns)

    def data_groups(self):
        ''' Returns set of data groups from all ProcessingSet ms_xds. '''
        return self._data.data_groups()

    def antennas(self, plot_positions=False, label_antennas=False):
        ''' Returns list of antenna names in ProcessingSet antenna_xds.
        plot_positions (bool): show plot of antenna positions.
        label_antennas (bool): label positions with antenna names.
        '''
        return self._data.get_antennas(plot_positions, label_antennas)

    def plot_phase_centers(self, data_group='base', label_fields=False):
        ''' Plot the phase center locations of all fields in the Processing Set and highlight central field.
        data_group (str): data group to use for field and source xds.
        label_fields (bool): label all fields on the plot if True, else label central field only
        '''
        self._data.plot_phase_centers(data_group, label_fields)

    def clear_plots(self):
        ''' Clear plot list '''
        # Wait until any in-progress render releases the plot list
        while self._plots_locked:
            time.sleep(1)
        self._plots.clear()

    def clear_selection(self):
        ''' Clear selection in data and restore to original '''
        if self._data:
            self._data.clear_selection()

    def show(self):
        '''
        Show interactive Bokeh plots in a browser. Plot tools include pan, zoom, hover, and save.
        '''
        if not self._plots:
            raise RuntimeError("No plots to show. Run plot() to create plot.")

        # Do not delete plot list until rendered
        self._plots_locked = True

        # Single plot or combine plots into layout using subplots (rows, columns)
        # Not layout if subplots is single plot (default if None) or if only one plot saved
        subplots = self._plot_inputs['subplots']
        layout_plot, is_layout = self._layout_plots(subplots)

        # Render to bokeh figure
        if is_layout:
            # Show plots in columns
            plot = hv.render(layout_plot.cols(subplots[1]))
        else:
            # Show single plot
            plot = hv.render(layout_plot)

        self._plots_locked = False
        show(plot)

    def save(self, filename='ms_plot.png', fmt='auto', width=900, height=600):
        '''
        Save plot to file with filename, format, and size.
        If iteration plots were created:
            If subplots is a grid, the layout plot will be saved to a single file.
            If subplots is a single plot, iteration plots will be saved individually,
            with a plot index appended to the filename: {name}_{index}{ext}.
        '''
        if not self._plots:
            raise RuntimeError("No plot to save. Run plot() to create plot.")

        start_time = time.time()

        # Save single plot or combine plots into layout using subplots (rows, columns).
        # Not layout if subplots is single plot or if only one plot saved.
        subplots = self._plot_inputs['subplots']
        layout_plot, is_layout = self._layout_plots(subplots)

        if is_layout:
            # Save plots combined into one layout
            hvplot.save(layout_plot.cols(subplots[1]), filename=filename, fmt=fmt)
            self._logger.info("Saved plot to %s.", filename)
        else:
            # Save plots individually, with index appended if exprange='all' and multiple plots.
            if self._plot_inputs['iter_axis'] is None:
                hvplot.save(layout_plot.opts(width=width, height=height), filename=filename, fmt=fmt)
                self._logger.info("Saved plot to %s.", filename)
            else:
                name, ext = os.path.splitext(filename)
                iter_range = self._plot_inputs['iter_range'] # None or (start, end)
                plot_idx = 0 if iter_range is None else iter_range[0]

                for plot in self._plots:
                    # Fix: splitext() returns ext with its leading dot, so do not add
                    # another one (previously produced e.g. "plot_0..png")
                    exportname = f"{name}_{plot_idx}{ext}"
                    hvplot.save(plot.opts(width=width, height=height), filename=exportname, fmt=fmt)
                    self._logger.info("Saved plot to %s.", exportname)
                    plot_idx += 1

        self._logger.debug("Save elapsed time: %.2fs.", time.time() - start_time)

    def _layout_plots(self, subplots):
        ''' Combine saved plots into a layout of at most rows*columns plots.
        Returns (plot, is_layout): the single plot with is_layout False,
        or the combined layout with is_layout True. '''
        subplots = (1, 1) if subplots is None else subplots
        num_plots = len(self._plots)
        num_layout_plots = min(num_plots, np.prod(subplots))

        if num_layout_plots == 1:
            return self._plots[0], False

        # Set plots in layout
        plot_count = 0
        layout = None
        for i in range(num_layout_plots):
            plot = self._plots[i]
            layout = plot if layout is None else layout + plot
            plot_count += 1

        is_layout = plot_count > 1
        return layout, is_layout

    def _set_ms(self, ms):
        ''' Update ms info for input ms filepath (MSv2 or zarr), or None in show_gui mode.
        Return whether ms changed (false if ms is None). '''
        self._ms_info['ms'] = ms
        ms_error = ""
        ms_changed = ms and (not self._data or not self._data.is_ms_path(ms))

        if ms_changed:
            try:
                # Set new MS data
                self._data = MsData(ms, self._logger)
                ms_path = self._data.get_path()
                self._ms_info['ms'] = ms_path
                # Strip all extensions (e.g. .ps.zarr) to get the base name
                root, ext = os.path.splitext(os.path.basename(ms_path))
                while ext != '':
                    root, ext = os.path.splitext(root)
                self._ms_info['basename'] = root
                self._ms_info['data_dims'] = self._data.get_data_dimensions()
            except RuntimeError as e:
                ms_error = str(e)
                self._data = None

        if ms_error:
            self._notify(ms_error, 'error', 0)

        return ms_changed

    def _notify(self, message, level, duration=3000):
        ''' Log message. If show_gui, notify user with toast for duration in ms.
        Zero duration must be dismissed. '''
        if self._show_gui:
            pn.state.notifications.position = 'top-center'
            # Remove previous toast before showing a new one
            if self._toast:
                self._toast.destroy()

        if level == "info":
            self._logger.info(message)
            if self._show_gui:
                self._toast = pn.state.notifications.info(message, duration=duration)
        elif level == "error":
            self._logger.error(message)
            if self._show_gui:
                self._toast = pn.state.notifications.error(message, duration=duration)
        elif level == "success":
            self._logger.info(message)
            if self._show_gui:
                self._toast = pn.state.notifications.success(message, duration=duration)
        elif level == "warning":
            self._logger.warning(message)
            if self._show_gui:
                self._toast = pn.state.notifications.warning(message, duration=duration)
@@ -1,22 +0,0 @@
1
''' Define constants used for plotting MeasurementSets '''

# Plottable components per correlated-data type
SPECTRUM_AXIS_OPTIONS = ['amp', 'real']
UVW_AXIS_OPTIONS = ['u', 'v', 'w', 'uvdist']
VIS_AXIS_OPTIONS = ['amp', 'phase', 'real', 'imag']
WEIGHT_AXIS_OPTIONS = ['weight', 'sigma']

# ProcessingSet selection: display label -> summary column name
PS_SELECTION_OPTIONS = {
    'MSv4 Name': 'name',
    'Intents': 'intents',
    'Scan Name': 'scan_name',
    'Spectral Window Name': 'spw_name',
    'Field Name': 'field_name',
    'Source Name': 'source_name',
    'Line Name': 'line_name'
    }
# MeasurementSet selection category labels
MS_SELECTION_OPTIONS = ['Data Group', 'Time', 'Baseline', 'Antenna1', 'Antenna2', 'Frequency', 'Polarization']

# Aggregator options for plots ('None' disables aggregation)
AGGREGATOR_OPTIONS = ['None', 'max', 'mean', 'median', 'min', 'std', 'sum', 'var']

# Default colormaps for unflagged and flagged data
DEFAULT_UNFLAGGED_CMAP = "Viridis"
DEFAULT_FLAGGED_CMAP = "Reds"