FlowCyPy 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FlowCyPy/__init__.py +15 -0
- FlowCyPy/_version.py +16 -0
- FlowCyPy/classifier.py +196 -0
- FlowCyPy/coupling_mechanism/__init__.py +4 -0
- FlowCyPy/coupling_mechanism/empirical.py +47 -0
- FlowCyPy/coupling_mechanism/mie.py +205 -0
- FlowCyPy/coupling_mechanism/rayleigh.py +115 -0
- FlowCyPy/coupling_mechanism/uniform.py +39 -0
- FlowCyPy/cytometer.py +198 -0
- FlowCyPy/detector.py +616 -0
- FlowCyPy/directories.py +36 -0
- FlowCyPy/distribution/__init__.py +16 -0
- FlowCyPy/distribution/base_class.py +59 -0
- FlowCyPy/distribution/delta.py +86 -0
- FlowCyPy/distribution/lognormal.py +94 -0
- FlowCyPy/distribution/normal.py +95 -0
- FlowCyPy/distribution/particle_size_distribution.py +110 -0
- FlowCyPy/distribution/uniform.py +96 -0
- FlowCyPy/distribution/weibull.py +80 -0
- FlowCyPy/event_correlator.py +244 -0
- FlowCyPy/flow_cell.py +122 -0
- FlowCyPy/helper.py +85 -0
- FlowCyPy/logger.py +322 -0
- FlowCyPy/noises.py +29 -0
- FlowCyPy/particle_count.py +102 -0
- FlowCyPy/peak_locator/__init__.py +4 -0
- FlowCyPy/peak_locator/base_class.py +163 -0
- FlowCyPy/peak_locator/basic.py +108 -0
- FlowCyPy/peak_locator/derivative.py +143 -0
- FlowCyPy/peak_locator/moving_average.py +114 -0
- FlowCyPy/physical_constant.py +19 -0
- FlowCyPy/plottings.py +270 -0
- FlowCyPy/population.py +239 -0
- FlowCyPy/populations_instances.py +49 -0
- FlowCyPy/report.py +236 -0
- FlowCyPy/scatterer.py +373 -0
- FlowCyPy/source.py +249 -0
- FlowCyPy/units.py +26 -0
- FlowCyPy/utils.py +191 -0
- FlowCyPy-0.5.0.dist-info/LICENSE +21 -0
- FlowCyPy-0.5.0.dist-info/METADATA +252 -0
- FlowCyPy-0.5.0.dist-info/RECORD +44 -0
- FlowCyPy-0.5.0.dist-info/WHEEL +5 -0
- FlowCyPy-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import seaborn as sns
|
|
6
|
+
from MPSPlots.styles import mps
|
|
7
|
+
from FlowCyPy.units import second
|
|
8
|
+
import warnings
|
|
9
|
+
from FlowCyPy.cytometer import FlowCytometer
|
|
10
|
+
from FlowCyPy.logger import EventCorrelatorLogger
|
|
11
|
+
from FlowCyPy.report import Report
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EventCorrelator:
    """
    Correlate and analyze pulse peaks detected by the detectors of a flow cytometer.

    Peak features produced by each detector's peak-finding algorithm (e.g. ``Heights``,
    ``PeakTimes``) are gathered into a single DataFrame, and events seen by the first
    two detectors within a time margin can be paired as coincident events.

    Parameters
    ----------
    cytometer : FlowCytometer
        Flow cytometer whose ``detectors`` (each exposing ``name`` and
        ``algorithm.peak_properties``) are analyzed.

    Attributes
    ----------
    dataframe : pd.DataFrame
        Set by ``run_analysis``; peak properties indexed by (Detector, Event).
    coincidence : pd.DataFrame
        Set by ``get_coincidence``; paired events of the first two detectors.

    Methods
    -------
    run_analysis(compute_peak_area=False):
        Gathers the peak properties of all detectors into ``self.dataframe``.
    get_coincidence(margin):
        Pairs events of the first two detectors whose peak times differ by at most ``margin``.
    display_features():
        Prints the properties of every entry in ``self.datasets``.
    plot(show=True, log_plot=True, x_limits=None, y_limits=None, bandwidth_adjust=1, color_palette=None):
        Draws a 2D density plot of the peak heights of the coincident events.
    generate_report(filename):
        Generates a report summarizing the results of the analysis.
    """

    def __init__(self, cytometer: FlowCytometer) -> None:
        """
        Initialize the EventCorrelator with the cytometer object.

        Parameters
        ----------
        cytometer : FlowCytometer
            An instance of FlowCytometer that contains detectors and their signals.
        """
        self.cytometer = cytometer
        # Read by display_features(); this class itself never appends to it.
        self.datasets = []

    def run_analysis(self, compute_peak_area: bool = False) -> pd.DataFrame:
        """
        Gather the peak properties computed by every detector into one DataFrame.

        The per-detector ``algorithm.peak_properties`` tables are concatenated with
        the detector names as outer keys, stored in ``self.dataframe`` with index
        levels ``['Detector', 'Event']``, and per-detector statistics are logged.

        Parameters
        ----------
        compute_peak_area : bool, optional
            Currently not used by this method (kept for API compatibility), by default False.

        Returns
        -------
        pd.DataFrame
            The combined peak-property table (also stored in ``self.dataframe``).
        """
        detectors = self.cytometer.detectors

        self.dataframe = pd.concat(
            [detector.algorithm.peak_properties for detector in detectors],
            keys=[f'{detector.name}' for detector in detectors]
        )

        self.dataframe.index.names = ['Detector', 'Event']

        self._log_statistics()

        return self.dataframe

    def _log_statistics(self) -> EventCorrelatorLogger:
        """
        Log statistics about the detected peaks of each detector.

        Delegates to ``EventCorrelatorLogger.log_statistics`` using the
        ``"fancy_grid"`` table format.

        Returns
        -------
        EventCorrelatorLogger
            The logger instance used to emit the statistics.
        """
        logger = EventCorrelatorLogger(self)

        logger.log_statistics(table_format="fancy_grid")

        return logger

    def get_coincidence(self, margin: second.dimensionality) -> pd.DataFrame:
        """
        Pair events of the first two detectors whose peak times lie within ``margin``.

        Must be called after ``run_analysis`` (reads ``self.dataframe``). As a side
        effect, the ``PeakTimes`` column of ``self.dataframe`` is converted in place
        to the units of ``margin``.

        Parameters
        ----------
        margin : pint.Quantity
            Maximum absolute peak-time difference for two events to be considered
            coincident; must have time dimensionality.

        Returns
        -------
        pd.DataFrame
            One row per coincident pair. Columns are a MultiIndex of
            ``(detector name, feature)`` for the first two detectors, plus a
            ``Label`` column initialised to 0. Also stored in ``self.coincidence``.

        Raises
        ------
        AssertionError
            If ``margin`` is not a time, or if any event of either detector matches
            more than one event of the other detector (ambiguous pairing).

        Warns
        -----
        UserWarning
            If no coincident events are found.
        """
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # AssertionError is kept for backward compatibility with callers.
        if margin.dimensionality != second.dimensionality:
            raise AssertionError("Margin must have time dimensionality.")

        self.dataframe['PeakTimes'] = self.dataframe['PeakTimes'].pint.to(margin.units)

        detector_0_name = self.cytometer.detectors[0].name
        detector_1_name = self.cytometer.detectors[1].name

        d0 = self.dataframe.xs(detector_0_name, level='Detector')
        d1 = self.dataframe.xs(detector_1_name, level='Detector')

        # Both PeakTimes columns are now expressed in margin.units, so their raw
        # magnitudes can be compared directly against margin's magnitude.
        times_0 = np.asarray(d0['PeakTimes'].values.numpy_data)
        times_1 = np.asarray(d1['PeakTimes'].values.numpy_data)

        # mask[i, j] is True when event i of detector 0 and event j of detector 1 coincide.
        mask = np.abs(times_0[:, np.newaxis] - times_1[np.newaxis, :]) <= margin.magnitude

        if not mask.any():
            warnings.warn("No coincidence events found, the margin might be too low.")

        # A pairing is ambiguous if an event of EITHER detector matches several events
        # of the other one (checking only one axis would silently duplicate rows).
        matches_per_event_1 = mask.sum(axis=0)
        matches_per_event_0 = mask.sum(axis=1)
        if np.any(matches_per_event_1 > 1) or np.any(matches_per_event_0 > 1):
            raise AssertionError("Coincidence events are ambiguously defined, the margin might be too high.")

        rows, cols = np.nonzero(mask)

        coincident_detector_0 = d0.iloc[rows].reset_index(drop=True)
        coincident_detector_1 = d1.iloc[cols].reset_index(drop=True)

        # `keys` builds the (detector name, feature) column MultiIndex from exactly the
        # two detectors being compared, whatever the total number of detectors.
        combined_coincidences = pd.concat(
            [coincident_detector_0, coincident_detector_1],
            axis=1,
            keys=[detector_0_name, detector_1_name]
        )

        self.coincidence = combined_coincidences

        self.coincidence['Label'] = 0

        return self.coincidence

    def display_features(self) -> None:
        """
        Print the properties of every entry of ``self.datasets``.

        ``self.datasets`` starts empty in ``__init__`` and is not populated by this
        class, so nothing is printed unless it was filled externally.
        """
        for i, dataset in enumerate(self.datasets):
            print(f"\nFeatures for Dataset {i + 1}:")
            dataset.print_properties()

    def plot(self, show: bool = True, log_plot: bool = True, x_limits: tuple = None, y_limits: tuple = None, bandwidth_adjust: float = 1, color_palette: Optional[Union[str, dict]] = None) -> None:
        """
        Plot a 2D density of the peak heights of the coincident events, with a
        scatter overlay coloured by ``Label``.

        Must be called after ``get_coincidence`` (reads ``self.coincidence``).

        Parameters
        ----------
        show : bool, optional
            Whether to display the plot immediately, by default True.
        log_plot : bool, optional
            Whether to use logarithmic scaling for the plot axes, by default True.
        x_limits : tuple of pint.Quantity, optional
            The x-axis limits (min, max); converted to the plotted x units. Default None.
        y_limits : tuple of pint.Quantity, optional
            The y-axis limits (min, max); converted to the plotted y units. Default None.
        bandwidth_adjust : float, optional
            Bandwidth adjustment factor for the kernel density estimate. Default is 1.
        color_palette : str or dict, optional
            The color palette to use for the hue in the scatterplot. Can be a seaborn palette name
            (e.g., 'viridis', 'coolwarm') or a dictionary mapping hue levels to specific colors. Default is None.
        """
        df_reset = self.coincidence.reset_index()

        x_data = df_reset[(self.cytometer.detectors[0].name, 'Heights')]
        y_data = df_reset[(self.cytometer.detectors[1].name, 'Heights')]

        # Pick compact display units from the largest value of each axis.
        x_units = x_data.max().to_compact().units
        y_units = y_data.max().to_compact().units

        x_data = x_data.pint.to(x_units)
        y_data = y_data.pint.to(y_units)

        with plt.style.context(mps):

            g = sns.jointplot(data=df_reset, x=x_data, y=y_data,
                kind='kde', alpha=0.8, fill=True,
                joint_kws={'alpha': 0.7, 'bw_adjust': bandwidth_adjust}
            )
            sns.scatterplot(data=df_reset, x=x_data, y=y_data,
                hue='Label', palette=color_palette, ax=g.ax_joint, alpha=0.6, zorder=1
            )

            # Set the x and y labels with units
            g.ax_joint.set_xlabel(f"Heights : {self.cytometer.detectors[0].name} [{x_units:P}]")
            g.ax_joint.set_ylabel(f"Heights: {self.cytometer.detectors[1].name} [{y_units:P}]")

            if log_plot:
                g.ax_marg_x.set_xscale('log')
                g.ax_marg_y.set_yscale('log')

            if x_limits is not None:
                x0, x1 = x_limits
                x0 = x0.to(x_units).magnitude
                x1 = x1.to(x_units).magnitude
                g.ax_joint.set_xlim(x0, x1)

            if y_limits is not None:
                y0, y1 = y_limits
                y0 = y0.to(y_units).magnitude
                y1 = y1.to(y_units).magnitude
                g.ax_joint.set_ylim(y0, y1)

            plt.tight_layout()

            if show:
                plt.show()

    def generate_report(self, filename: str) -> None:
        """
        Generate a report summarizing the analysis through ``Report``.

        Parameters
        ----------
        filename : str
            Intended output filename.
            NOTE(review): this value is currently not forwarded to ``Report``;
            confirm where the report is actually written.
        """
        report = Report(
            flow_cell=self.cytometer.scatterer.flow_cell,
            scatterer=self.cytometer.scatterer,
            analyzer=self
        )

        report.generate_report()
|
FlowCyPy/flow_cell.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from FlowCyPy.units import meter, second
|
|
3
|
+
from PyMieSim.units import Quantity
|
|
4
|
+
from tabulate import tabulate
|
|
5
|
+
from pydantic.dataclasses import dataclass
|
|
6
|
+
from pydantic import field_validator
|
|
7
|
+
|
|
8
|
+
# Configuration handed to the pydantic ``dataclass`` decorator of FlowCell.
# NOTE(review): in pydantic v2, ``kw_only`` and ``slots`` are arguments of the
# ``dataclass`` decorator itself rather than ``ConfigDict`` keys; confirm they
# actually take effect when passed through ``config=``.
config_dict = {
    'arbitrary_types_allowed': True,  # permit non-pydantic field types such as Quantity
    'kw_only': True,
    'slots': True,
    'extra': 'forbid',
}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(config=config_dict)
class FlowCell(object):
    """
    Model the flow parameters of a flow cytometer: flow speed, flow area and run time.

    Parameters
    ----------
    flow_speed : Quantity
        Speed of the flow. Any unit dimensionally compatible with meter / second is
        accepted; it is re-expressed in meter / second in ``__post_init__``.
    flow_area : Quantity
        Cross-sectional area of the flow. Any unit dimensionally compatible with
        meter ** 2 is accepted; it is re-expressed in meter ** 2 in ``__post_init__``.
    run_time : Quantity
        Total duration of the flow simulation. Any time unit is accepted; it is
        re-expressed in second in ``__post_init__``.

    Attributes
    ----------
    volume : Quantity
        Volume of fluid that flows during the run, computed in ``__post_init__``
        as ``flow_area * flow_speed * run_time``.
    """
    flow_speed: Quantity
    flow_area: Quantity
    run_time: Quantity

    @field_validator('flow_speed')
    @classmethod
    def _validate_flow_speed(cls, value):
        """
        Validate that the flow speed has the dimensionality of meter / second.

        Parameters
        ----------
        value : Quantity
            The flow speed to validate.

        Returns
        -------
        Quantity
            The validated flow speed, unchanged.

        Raises
        ------
        ValueError
            If the flow speed is not dimensionally compatible with meter / second.
        """
        if not value.check(meter / second):
            raise ValueError(f"flow_speed must be in meter per second, but got {value.units}")
        return value

    @field_validator('flow_area')
    @classmethod
    def _validate_flow_area(cls, value):
        """
        Validate that the flow area has the dimensionality of meter ** 2.

        Parameters
        ----------
        value : Quantity
            The flow area to validate.

        Returns
        -------
        Quantity
            The validated flow area, unchanged.

        Raises
        ------
        ValueError
            If the flow area is not dimensionally compatible with meter ** 2.
        """
        if not value.check(meter ** 2):
            raise ValueError(f"flow_area must be in meter ** 2, but got {value.units}")
        return value

    @field_validator('run_time')
    @classmethod
    def _validate_run_time(cls, value):
        """
        Validate that the run time has the dimensionality of time.

        Parameters
        ----------
        value : Quantity
            The total run time to validate.

        Returns
        -------
        Quantity
            The validated run time, unchanged.

        Raises
        ------
        ValueError
            If the run time is not dimensionally compatible with second.
        """
        if not value.check(second):
            raise ValueError(f"run_time must be in second, but got {value.units}")
        return value

    def __post_init__(self):
        """Normalize the flow parameters to SI units and compute the sampled volume."""
        # Quantity(q, unit) converts an existing quantity into `unit`.
        self.flow_speed = Quantity(self.flow_speed, meter / second)
        self.flow_area = Quantity(self.flow_area, meter ** 2)
        self.run_time = Quantity(self.run_time, second)

        self.volume = self.flow_area * self.flow_speed * self.run_time

    def print_properties(self) -> None:
        """
        Print the flow properties as a grid table.
        """
        print("\nFlow Properties")
        print(tabulate(self.get_properties(), headers=["Property", "Value"], tablefmt="grid"))

    def get_properties(self) -> List[List[str]]:
        """
        Return the flow properties as ``[name, formatted value]`` rows.

        Values are formatted with two decimals using pint's abbreviated pretty format.
        """
        return [
            ['Flow Speed', f"{self.flow_speed:.2f~#P}"],
            ['Flow Area', f"{self.flow_area:.2f~#P}"],
            ['Total Time', f"{self.run_time:.2f~#P}"]
        ]
|
FlowCyPy/helper.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from MPSPlots.styles import mps
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def plot_helper(function: Callable) -> Callable:
    """
    Decorate a plotting method with axes creation, legend handling, saving and display.

    Parameters
    ----------
    function : Callable
        The plotting method to decorate. It is called as
        ``function(self, ax=ax, **kwargs)`` and is expected to draw onto ``ax``.

    Returns
    -------
    Callable
        A wrapper with the call signature
        ``wrapper(self, ax=None, show=True, save_filename=None, figure_size=None, **kwargs)``:

        - ax : matplotlib Axes, optional
            Axes to draw on. If None, a new figure and axes are created inside the
            ``mps`` style context, sized by ``figure_size``.
        - show : bool, optional
            Whether to call ``plt.show()`` after drawing. Default True.
        - save_filename : str, optional
            If truthy, the figure is saved to this path. Default None.
        - figure_size : tuple, optional
            Figure size used only when a new figure is created. Default None.
        - **kwargs
            Forwarded to ``function``.

        The wrapper returns whatever ``function`` returns, and carries the decorated
        function's name, qualified name, module and docstring (via ``functools.wraps``).

    Notes
    -----
    A legend is added only when the axes have labelled artists.
    """
    from functools import wraps

    # String annotations avoid evaluating `plt.Axes` each time a method is decorated.
    @wraps(function)
    def wrapper(self, ax: "plt.Axes" = None, show: bool = True, save_filename: str = None, figure_size: tuple = None, **kwargs):
        if ax is None:
            with plt.style.context(mps):
                figure, ax = plt.subplots(1, 1, figsize=figure_size)

        else:
            figure = ax.get_figure()

        output = function(self, ax=ax, **kwargs)

        _, labels = ax.get_legend_handles_labels()

        # Only add a legend if there are labels
        if labels:
            ax.legend()

        if save_filename:
            figure.savefig(save_filename)

        if show:
            plt.show()

        return output

    return wrapper
|