specula 0.0.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. specula/_version.py +34 -0
  2. specula/data/EltM1SegmMap1015pix38570mm.fits +0 -0
  3. specula/display/__init__.py +0 -0
  4. specula/display/base_display.py +99 -0
  5. specula/display/data_plotter.py +177 -0
  6. specula/display/double_phase_display.py +142 -0
  7. specula/display/modes_display.py +89 -0
  8. specula/display/phase_display.py +48 -0
  9. specula/display/pixels_display.py +197 -0
  10. specula/display/plot_display.py +135 -0
  11. specula/display/plot_vector_display.py +270 -0
  12. specula/display/psf_display.py +72 -0
  13. specula/display/pupil_display.py +134 -0
  14. specula/display/slopec_display.py +47 -0
  15. specula/lib/ELT_M1_ifunc_calculator.py +114 -0
  16. specula/lib/__init__.py +0 -0
  17. specula/lib/calc_noise_cov_elong.py +322 -0
  18. specula/lib/calc_phasescreen.py +105 -0
  19. specula/lib/calc_psf.py +185 -0
  20. specula/lib/calc_spatialfrequency.py +36 -0
  21. specula/lib/compute_petal_ifunc.py +146 -0
  22. specula/lib/compute_zern_ifunc.py +30 -0
  23. specula/lib/compute_zonal_ifunc.py +191 -0
  24. specula/lib/extrapolation_2d.py +172 -0
  25. specula/lib/flatten.py +37 -0
  26. specula/lib/fsoc_lib.py +71 -0
  27. specula/lib/interp2d.py +170 -0
  28. specula/lib/make_mask.py +147 -0
  29. specula/lib/make_xy.py +186 -0
  30. specula/lib/mask.py +383 -0
  31. specula/lib/mmse_reconstructor.py +132 -0
  32. specula/lib/modal_base_generator.py +370 -0
  33. specula/lib/modal_pushpull_signal.py +114 -0
  34. specula/lib/n_phot.py +304 -0
  35. specula/lib/phasescreen_manager.py +54 -0
  36. specula/lib/platescale_coeff.py +125 -0
  37. specula/lib/process_utils.py +50 -0
  38. specula/lib/rebin.py +58 -0
  39. specula/lib/toccd.py +190 -0
  40. specula/lib/utils.py +245 -0
  41. specula/lib/zernike_generator.py +457 -0
  42. specula/scripts/parse_classes.py +159 -0
  43. specula/scripts/specula_main.py +31 -0
  44. specula/scripts/templates/specula_display.html +168 -0
  45. specula/scripts/templates/specula_frontend.html +128 -0
  46. specula/scripts/web_frontend.py +47 -0
  47. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/METADATA +7 -2
  48. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/RECORD +52 -6
  49. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/WHEEL +0 -0
  50. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/entry_points.txt +0 -0
  51. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/licenses/LICENSE +0 -0
  52. {specula-0.0.0.dist-info → specula-1.0.0.dist-info}/top_level.txt +0 -0
specula/_version.py ADDED
@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control

# NOTE(review): per the header above, this module is written by
# setuptools-scm at build time; hand edits will be overwritten.

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Local stand-in for typing.TYPE_CHECKING: always False at runtime, so the
# typing imports below are only seen by static type checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    # Runtime placeholders so the module-level annotations below can be
    # evaluated without importing typing.
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

# Package version as a string and as a tuple; each value is exported under
# both a dunder and a plain name.
__version__ = version = '1.0.0'
__version_tuple__ = version_tuple = (1, 0, 0)

# Identifier of the commit the build was made from.
__commit_id__ = commit_id = 'gd50bd1fba'
File without changes
specula/display/base_display.py ADDED
@@ -0,0 +1,99 @@
1
+ from specula.base_processing_obj import BaseProcessingObj
2
+ import matplotlib.pyplot as plt
3
+
4
+
class BaseDisplay(BaseProcessingObj):
    """Common machinery for matplotlib-based live displays.

    Subclasses implement ``_update_display(data)`` and either set
    ``input_key`` (the input read by the default ``_get_data``) or override
    ``_get_data``. The figure is created lazily on the first trigger, and
    failures are reported as red text inside the axes.

    Parameters
    ----------
    title : str
        Figure super-title; not drawn when empty.
    figsize : tuple
        Figure size forwarded to ``plt.figure``.
    """

    def __init__(self,
                 title='',
                 figsize=(8, 6)):
        super().__init__()
        self._title = title
        self._figsize = figsize
        self._opened = False
        self._colorbar_added = False
        self.input_key = ''
        self.fig = None
        self.ax = None
        self.img = None
        self.line = None

    def _create_figure(self):
        """Create the matplotlib figure and its single axes (no-op once opened)."""
        if self._opened:
            return

        self.fig = plt.figure(figsize=self._figsize)
        self.ax = self.fig.add_subplot(111)
        if self._title:
            self.fig.suptitle(self._title)
        self.fig.show()
        self._opened = True

    def _update_display(self, data):
        """Draw ``data``; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method")

    def _get_data(self):
        """Return the input named ``self.input_key``.

        Derived classes can override this method in case of complex data.

        Returns
        -------
        The input value, or None when it is missing (an error message is
        shown in the figure in that case).
        """
        data = self.local_inputs.get(self.input_key)
        if data is None:
            self._show_error(f"No {self.input_key} data available")
            return None
        return data

    def trigger_code(self):
        """Refresh the display from the current inputs.

        Exceptions raised while fetching or drawing data are shown in the
        figure instead of being propagated.
        """
        try:
            if not self._opened:
                self._create_figure()
            data = self._get_data()
            if data is None:
                # _get_data has already shown a specific message; passing
                # None on would only replace it with a generic error.
                return
            self._update_display(data)
        except Exception as e:
            self._show_error(f"Display error: {str(e)}")

    # ============ UTILITY METHODS ============

    def set_y_range(self, ymin, ymax):
        """Set a fixed Y axis range.

        The range is remembered only by subclasses that define ``_yrange``;
        the axes limits are applied immediately if the figure is open.
        """
        if hasattr(self, '_yrange'):
            self._yrange = (ymin, ymax)
        if self._opened and self.ax is not None:
            self.ax.set_ylim(ymin, ymax)
            self._safe_draw()

    def auto_y_range(self):
        """Enable automatic Y axis scaling - override in subclasses for specific logic"""
        if hasattr(self, '_yrange'):
            self._yrange = (0, 0)

    def _add_colorbar_if_needed(self, image_obj):
        """Add a colorbar for ``image_obj`` unless one was already added."""
        if not hasattr(self, '_colorbar_added'):
            self._colorbar_added = False

        if not self._colorbar_added and image_obj is not None:
            plt.colorbar(image_obj, ax=self.ax)
            self._colorbar_added = True

    def _update_image_data(self, image_obj, data):
        """Replace the image data and rescale its colour limits to the data range."""
        if image_obj is not None:
            image_obj.set_data(data)
            image_obj.set_clim(data.min(), data.max())

    def _show_error(self, message):
        """Replace the axes content with ``message`` in red."""
        if not self._opened:
            self._create_figure()
        self.ax.clear()
        # ax.clear() removes the image/line artists from the axes. Forget
        # them so the next successful update recreates them, instead of
        # updating detached artists that are never drawn again.
        self.img = None
        self.line = None
        self.ax.text(0.5, 0.5, message, ha='center', va='center',
                     transform=self.ax.transAxes, color='red', fontsize=12)
        self._safe_draw()

    def _safe_draw(self):
        """Request a redraw and flush pending GUI events.

        Drawing errors are printed instead of raised. Note that this does
        not by itself make drawing safe to call from other threads.
        """
        try:
            if self.fig is not None and self.fig.canvas:
                self.fig.canvas.draw_idle()
                self.fig.canvas.flush_events()
        except Exception as e:
            print(f"Drawing error: {e}")
specula/display/data_plotter.py ADDED
@@ -0,0 +1,177 @@
1
+
2
+ import matplotlib
3
+ import numpy as np
4
+ matplotlib.use('Agg') # Memory backend, no GUI
5
+
6
+ from matplotlib.figure import Figure
7
+
# Registry of DataPlotter instances keyed by plot name, filled by
# DataPlotter.plot_best_effort so that repeated calls with the same name
# update one figure instead of re-plotting from scratch.
# NOTE(review): nothing in this module removes entries, so each distinct
# plot name keeps its figure alive for the lifetime of the process.
dataplotter_cache = {}

class DataPlotter():
    '''
    Plot any kind of data in a memory backend.

    The first plot call decides the plot type (image, vector, scalar
    history or text); later calls update the existing artists instead of
    re-plotting from scratch. Plotting methods return the matplotlib Figure.

    Parameters
    ----------
    disp_factor : float
        Scale applied to the window size; sizes in pixels are divided by
        100 to obtain the figure size in inches.
    histlen : int
        Number of samples kept by ``plot_history``.
    wsize : tuple
        (width, height) of one frame in pixels. When it sums to 0, image
        plots are sized from the image itself.
    yrange : tuple
        Stored only; not used by the methods of this class.
    title : str
        Figure super-title.
    '''
    def __init__(self, disp_factor=1, histlen=200, wsize=(400, 300), yrange=(-10, 10), title=''):
        super().__init__()

        self._wsize = wsize
        self._history = np.zeros(histlen)
        self._count = 0
        self._yrange = yrange
        self._value = None
        self._title = title
        self._w_is_set = False
        self._first = True
        self._disp_factor = disp_factor
        self.text = None  # Text artist drawn by plot_text, if currently shown

    def set_w(self, size_frame=None, nframes=1):
        '''
        Create the figure with ``nframes`` side-by-side axes (first call only).

        ``size_frame`` is (width, height) of one frame in pixels and
        defaults to the ``wsize`` given at construction.
        '''
        if self._w_is_set:
            return

        if size_frame is None:
            size_frame = self._wsize
        self.fig = Figure(figsize=(size_frame[0] * self._disp_factor / 100 * nframes,
                                   size_frame[1] * self._disp_factor / 100))
        self.ax = [self.fig.add_subplot(1, nframes, i + 1) for i in range(nframes)]
        self.fig.subplots_adjust(hspace=0.1, wspace=0.1)  # Adjust these values as needed

        self._w_is_set = True

    def _clear_text(self):
        '''Remove the message left by plot_text, if any.'''
        text = getattr(self, 'text', None)
        if text is not None:
            text.remove()
            self.text = None

    def multi_plot(self, obj_list):
        '''
        Plot a list of data objects one next to the other.

        Returns a matplotlib.Figure object.

        Generation of the numpy array to display is delegated to each data object.
        '''

        if obj_list is None:
            return self.plot_text('This is None')

        if len(obj_list) < 1:
            return self.plot_text('No values to plot')

        # TODO: commented out because it does not work for prop.layer_list,
        # since layers are of different types
        #
        # for obj in objrefs[1:]:
        #     if type(obj) is not type(objrefs[0]):
        #         raise ValueError('All objects in multi_plot() must be of the same type')

        if not hasattr(obj_list[0], 'array_for_display'):
            return self.plot_text(f'Plot not implemented for class {obj_list[0].__class__.__name__}')

        frames = [x.array_for_display() for x in obj_list]

        if any(f is None for f in frames):
            return self.plot_text('Cannot plot None values')

        # 2d images: imshow
        if len(frames[0].shape) == 2:
            return self.imshow(frames)

        # Single vector value: plot_vector
        elif len(frames[0].shape) == 1 and len(frames[0]) > 1 and len(frames) == 1:
            return self.plot_vector(frames[0])

        # Scalar value: plot history
        elif len(frames[0].shape) == 1 and len(frames[0]) == 1 and len(frames) == 1:
            return self.plot_history(frames[0])

        # Another kind of scalar value: plot history
        elif len(frames[0].shape) == 0 and len(frames) == 1:
            return self.plot_history(frames[0])

        else:
            return self.plot_text(f'Cannot plot: data shape is {frames[0].shape} x {len(frames)}')

    def plot_history(self, value):
        '''Append a scalar to the rolling history buffer and plot the buffer.'''
        n = len(self._history)
        if self._count >= n:
            # Buffer full: shift left by one sample to make room.
            self._history[:-1] = self._history[1:]
            self._count = n - 1

        self.set_w()
        self._clear_text()
        self._history[self._count] = value
        self._count += 1

        x = np.arange(self._count)
        y = self._history[:self._count]
        if self._first:
            self.fig.suptitle(self._title)
            self.line = self.ax[0].plot(x, y, marker='.')
            self._first = False
        else:
            self.line[0].set_xdata(x)
            self.line[0].set_ydata(y)
            self.ax[0].set_xlim(x.min(), x.max())
            self.ax[0].set_ylim(y.min(), y.max())
        self.fig.canvas.draw()
        return self.fig

    def plot_vector(self, vector):
        '''Plot a 1-D vector; later calls only update the Y data.'''
        self.set_w()
        self._clear_text()

        if self._first:
            self._line = self.ax[0].plot(vector, '.-')
            self.fig.suptitle(self._title)
            self.ax[0].set_ylim([vector.min(), vector.max()])
            self._first = False
        else:
            self._line[0].set_ydata(vector)
        return self.fig

    def imshow(self, frames):
        '''Show a list of 2-D frames side by side, rescaling colours on update.'''
        if np.sum(self._wsize) == 0:
            # Size the window from the image. set_w expects (width, height),
            # while numpy shapes are (rows, cols).
            size_frame = (frames[0].shape[1], frames[0].shape[0])
        else:
            size_frame = self._wsize

        self.set_w(size_frame, len(frames))
        self._clear_text()

        if self._first:
            self.img = [self.ax[i].imshow(frame) for i, frame in enumerate(frames)]
            self._first = False
        else:
            for img, frame in zip(self.img, frames):
                img.set_data(frame)
                img.set_clim(frame.min(), frame.max())
        self.fig.canvas.draw()
        return self.fig

    def plot_text(self, text):
        '''Show a text message, replacing any previous message.'''
        self.set_w()
        # Remove the old Text artist; merely dropping the reference would
        # leave it drawn underneath the new one.
        self._clear_text()
        self.text = self.ax[0].text(0, 0, text, fontsize=14)
        self.fig.canvas.draw()
        return self.fig


    @staticmethod
    def plot_best_effort(plot_name, dataobj_or_list):
        '''
        Plot a data object or a list of data objects as best as it can be done.
        The plot_name is used to remember the DataPlotter instance and allow
        plot updates instead of expensive re-plots from scratch
        '''

        if plot_name not in dataplotter_cache:
            dataplotter_cache[plot_name] = DataPlotter()

        if isinstance(dataobj_or_list, list):
            for obj in dataobj_or_list:
                obj.xp = np
            fig = dataplotter_cache[plot_name].multi_plot(dataobj_or_list)
        else:
            dataobj_or_list.xp = np  # Supply a numpy instance, sometimes it is needed
            fig = dataplotter_cache[plot_name].multi_plot([dataobj_or_list])
        return fig
specula/display/double_phase_display.py ADDED
@@ -0,0 +1,142 @@
1
+ import numpy as np
2
+
3
+ from specula import cpuArray
4
+
5
+ from specula.display.base_display import BaseDisplay
6
+ from specula.connections import InputValue
7
+ from specula.data_objects.electric_field import ElectricField
8
+
9
+ from symao.turbolence import ft_ft2
10
+
11
+
class DoublePhaseDisplay(BaseDisplay):
    """Display two phase screens side by side together with their PSDs.

    The figure has four panels: phase1, phase2, the instantaneous PSD cuts
    of both phases, and their running averages over all displayed frames.
    Inputs ``phase1`` and ``phase2`` are ElectricField objects.
    """

    def __init__(self,
                 title='Double Phase Display',
                 figsize=(12, 3)):  # 4 subplots side by side

        super().__init__(
            title=title,
            figsize=figsize
        )

        self.img1 = None
        self.img2 = None
        self.nframes = 0          # number of frames averaged into psd_statTot*
        self.psd_statTot1 = None  # running-mean PSD of phase1
        self.psd_statTot2 = None  # running-mean PSD of phase2

        # Setup inputs - two phase inputs
        self.inputs['phase1'] = InputValue(type=ElectricField)
        self.inputs['phase2'] = InputValue(type=ElectricField)

    def _process_phase_data(self, phase):
        """Process phase data: mask where A <= 0 and remove the average
        over the unmasked pixels."""
        frame = cpuArray(phase.phaseInNm * (phase.A > 0).astype(float))

        # Get valid indices (where amplitude > 0)
        valid_mask = cpuArray(phase.A) > 0

        if np.any(valid_mask):
            # Remove average phase only from valid pixels
            frame[valid_mask] -= np.mean(frame[valid_mask])

            if self.verbose:
                print('Removing average phase in double_phase_display')

        return frame

    def _calculate_psd(self, frame):
        """Calculate the power spectral density as |ft_ft2(frame, 1)|**2.

        NOTE(review): ft_ft2 comes from the external ``symao`` package; it
        is presumably a centred 2-D Fourier transform, which the centre-row
        cuts in _update_display assume. Confirm.
        """
        return np.absolute(ft_ft2(frame, 1))**2

    def _get_data(self):
        """Return [phase1, phase2], or [] when either input is missing."""
        phase1 = self.local_inputs.get('phase1')
        phase2 = self.local_inputs.get('phase2')

        if phase1 is None or phase2 is None:
            return []  # _update_display reports the missing input

        return [phase1, phase2]

    def _update_display(self, data_list):
        """Draw both phases and their PSD cuts.

        Parameters
        ----------
        data_list : list
            [phase1, phase2] as returned by _get_data; any other length is
            reported as a missing input.
        """
        if len(data_list) != 2:
            self._show_error("Need both phase1 and phase2 inputs")
            return

        phase1, phase2 = data_list

        # Process both phases
        frame1 = self._process_phase_data(phase1)
        frame2 = self._process_phase_data(phase2)

        # Calculate PSDs
        psd_stat1 = self._calculate_psd(frame1)
        psd_stat2 = self._calculate_psd(frame2)
        ss = frame1.shape[0]

        # Update frame counter and accumulate PSDs as a running mean
        self.nframes += 1

        if self.psd_statTot1 is None:
            self.psd_statTot1 = np.zeros_like(psd_stat1)
        if self.psd_statTot2 is None:
            self.psd_statTot2 = np.zeros_like(psd_stat2)

        self.psd_statTot1 = (self.psd_statTot1 * (self.nframes-1) + psd_stat1) / self.nframes
        self.psd_statTot2 = (self.psd_statTot2 * (self.nframes-1) + psd_stat2) / self.nframes

        # Create subplots on first run (or after _show_error has reset them)
        if self.img1 is None:
            # Clear default axis and create 4 subplots
            self.fig.clear()
            self.ax1 = self.fig.add_subplot(141)
            self.ax2 = self.fig.add_subplot(142)
            self.ax3 = self.fig.add_subplot(143)
            self.ax4 = self.fig.add_subplot(144)

            # Create images
            self.img1 = self.ax1.imshow(frame1)
            self.img2 = self.ax2.imshow(frame2)

            # Set titles
            self.ax1.set_title('Phase 1')
            self.ax2.set_title('Phase 2')
            self.ax3.set_title('PSD Instantaneous')
            self.ax4.set_title('PSD Average')
        else:
            # Update existing images
            self.img1.set_data(frame1)
            self.img1.set_clim(frame1.min(), frame1.max())
            self.img2.set_data(frame2)
            self.img2.set_clim(frame2.min(), frame2.max())

        # Clear and update PSD plots
        self.ax3.clear()
        self.ax4.clear()

        # Plot instantaneous PSDs with low alpha
        # NOTE(review): ax3 is cleared on every update, so these faint curves
        # never accumulate and may be barely visible. Confirm the intended alpha.
        self.ax3.loglog(psd_stat1[ss//2, ss//2+1:], alpha=0.025, color='r', label='Phase 1')
        self.ax3.loglog(psd_stat2[ss//2, ss//2+1:], alpha=0.025, color='b', label='Phase 2')
        self.ax3.set_title('PSD Instantaneous')
        self.ax3.legend()

        # Plot averaged PSDs
        self.ax4.loglog(self.psd_statTot1[ss//2, ss//2+1:], color='r', label='Phase 1')
        self.ax4.loglog(self.psd_statTot2[ss//2, ss//2+1:], color='b', label='Phase 2')
        self.ax4.set_title(f'PSD Average (n={self.nframes})')
        self.ax4.legend()

        self._safe_draw()

    def _show_error(self, message):
        """Show ``message`` across the whole figure.

        The four panels are discarded so the next successful update
        rebuilds them. Once _update_display has called fig.clear(), the
        base-class axes is no longer part of the figure, so drawing the
        error there would be invisible.
        """
        if not self._opened:
            self._create_figure()
        self.fig.clear()
        self.ax = self.fig.add_subplot(111)
        self.img1 = None
        self.img2 = None
        self.ax.text(0.5, 0.5, message, ha='center', va='center',
                     transform=self.ax.transAxes, color='red', fontsize=12)
        self._safe_draw()

    def trigger_code(self):
        """Override to handle dual phase inputs"""
        try:
            if not self._opened:
                self._create_figure()

            # DoublePhaseDisplay handles dual inputs: _get_data returns
            # [phase1, phase2] or [] (reported by _update_display).
            self._update_display(self._get_data())
        except Exception as e:
            self._show_error(f"Double phase display error: {str(e)}")
specula/display/modes_display.py ADDED
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+
3
+ from specula.display.base_display import BaseDisplay
4
+ from specula.connections import InputValue
5
+ from specula.base_value import BaseValue
6
+ from specula import cpuArray
7
+
8
+
class ModesDisplay(BaseDisplay):
    """Plot a vector of modal coefficients against the mode index.

    Parameters
    ----------
    title : str
        Figure title.
    figsize : tuple
        Figure size passed to BaseDisplay.
    xrange : tuple or None
        Fixed (xmin, xmax). None fits the X axis to the vector length.
    yrange : tuple or None
        Fixed (ymin, ymax). (0, 0) or None scales the Y axis to the data on
        every update.
    """

    def __init__(self,
                 title='Modes Display',
                 figsize=(6, 3),
                 xrange=None,
                 yrange=(-500, 500)):

        super().__init__(
            title=title,
            figsize=figsize
        )

        self._xrange = xrange
        self._yrange = yrange
        self.line = None

        # Setup input
        self.input_key = 'modes'  # Used by base class
        self.inputs['modes'] = InputValue(type=BaseValue)

    def _auto_y(self):
        """Return True when no fixed Y range is configured."""
        return self._yrange is None or np.sum(np.abs(self._yrange)) == 0

    @staticmethod
    def _data_ylim(y):
        """Return (low, high) Y limits enclosing ``y`` with a 10% margin.

        Each bound is moved away from the data by 10% of its own magnitude.
        This matches ``(min*1.1, max*1.1)`` when min < 0 < max, but unlike
        that formula it never clips the data when both bounds share a sign.
        """
        if len(y) == 0:
            return -1.0, 1.0
        ymin = float(np.min(y))
        ymax = float(np.max(y))
        low = ymin - 0.1 * abs(ymin)
        high = ymax + 0.1 * abs(ymax)
        if low == high:
            # Only possible for an all-zero vector: avoid singular limits.
            low, high = low - 1.0, high + 1.0
        return low, high

    def _apply_limits(self, x, y, first):
        """Set the axis limits for the data (x, y).

        The X range is refreshed on every call. A fixed Y range is applied
        only on the first draw (set_y_range handles later changes), while
        auto-scaling follows the data on every call.
        """
        if self._xrange is not None:
            self.ax.set_xlim(self._xrange[0], self._xrange[1])
        elif len(x) > 1:
            # A single point would give identical limits, which triggers a
            # matplotlib singular-limits warning.
            self.ax.set_xlim(0, len(x) - 1)

        if self._auto_y():
            self.ax.set_ylim(*self._data_ylim(y))
        elif first:
            self.ax.set_ylim(self._yrange[0], self._yrange[1])

    def _update_display(self, modes):
        """Plot ``modes.value`` (moved to CPU) against its index."""
        y = cpuArray(modes.value)
        x = np.arange(len(y))

        if self.line is None:
            # First time: create line, limits, zero reference and labels
            self.line = self.ax.plot(x, y, '.-')[0]
            self._apply_limits(x, y, first=True)

            # Add reference line at y=0
            self.ax.axhline(y=0, color='grey', linestyle='--',
                            dashes=(4, 8), linewidth=0.5, alpha=0.7)

            self.ax.set_xlabel('Mode Index')
            self.ax.set_ylabel('Mode Value')
        else:
            # Update existing line
            self.line.set_xdata(x)
            self.line.set_ydata(y)
            self._apply_limits(x, y, first=False)

        self._safe_draw()

    def set_x_range(self, xmin, xmax):
        """Set fixed X axis range"""
        self._xrange = (xmin, xmax)
        if self.line is not None:
            self.ax.set_xlim(xmin, xmax)
            self._safe_draw()

    def set_y_range(self, ymin, ymax):
        """Set fixed Y axis range"""
        self._yrange = (ymin, ymax)
        if self.line is not None:
            self.ax.set_ylim(ymin, ymax)
            self._safe_draw()
specula/display/phase_display.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import numpy as np
3
+
4
+ from specula import cpuArray
5
+
6
+ from specula.display.base_display import BaseDisplay
7
+ from specula.connections import InputValue
8
+ from specula.data_objects.electric_field import ElectricField
9
+
class PhaseDisplay(BaseDisplay):
    """Image display of the ``phaseInNm`` map of an ElectricField.

    Pixels whose amplitude ``A`` is not positive are zeroed, and the mean
    over the remaining pixels is subtracted before display.
    """

    def __init__(self,
                 title='Phase Display',
                 figsize=(8, 6)):  # Default size in inches
        super().__init__(title=title, figsize=figsize)

        # Single input; BaseDisplay._get_data reads it through input_key.
        self.input_key = 'phase'
        self.inputs['phase'] = InputValue(type=ElectricField)

    def _process_phase_data(self, phase):
        """Return the masked, mean-removed phase map as a CPU array."""
        pupil = phase.A > 0
        frame = cpuArray(phase.phaseInNm * pupil.astype(float))

        inside = cpuArray(phase.A) > 0
        if not np.any(inside):
            return frame

        frame[inside] -= np.mean(frame[inside])
        if self.verbose:
            print('Removing average phase in phase_display')
        return frame

    def _update_display(self, phase):
        """Draw the processed phase, creating the image on first use."""
        frame = self._process_phase_data(phase)

        if self.img is not None:
            self._update_image_data(self.img, frame)
        else:
            self.img = self.ax.imshow(frame)
            self._add_colorbar_if_needed(self.img)

        self._safe_draw()