AeroViz 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of AeroViz might be problematic. Click here for more details.

Files changed (102) hide show
  1. AeroViz/__init__.py +15 -0
  2. AeroViz/dataProcess/Chemistry/__init__.py +63 -0
  3. AeroViz/dataProcess/Chemistry/_calculate.py +27 -0
  4. AeroViz/dataProcess/Chemistry/_isoropia.py +99 -0
  5. AeroViz/dataProcess/Chemistry/_mass_volume.py +175 -0
  6. AeroViz/dataProcess/Chemistry/_ocec.py +184 -0
  7. AeroViz/dataProcess/Chemistry/_partition.py +29 -0
  8. AeroViz/dataProcess/Chemistry/_teom.py +16 -0
  9. AeroViz/dataProcess/Optical/_IMPROVE.py +61 -0
  10. AeroViz/dataProcess/Optical/__init__.py +62 -0
  11. AeroViz/dataProcess/Optical/_absorption.py +54 -0
  12. AeroViz/dataProcess/Optical/_extinction.py +36 -0
  13. AeroViz/dataProcess/Optical/_mie.py +16 -0
  14. AeroViz/dataProcess/Optical/_mie_sd.py +143 -0
  15. AeroViz/dataProcess/Optical/_scattering.py +30 -0
  16. AeroViz/dataProcess/SizeDistr/__init__.py +61 -0
  17. AeroViz/dataProcess/SizeDistr/__merge.py +250 -0
  18. AeroViz/dataProcess/SizeDistr/_merge.py +245 -0
  19. AeroViz/dataProcess/SizeDistr/_merge_v1.py +254 -0
  20. AeroViz/dataProcess/SizeDistr/_merge_v2.py +243 -0
  21. AeroViz/dataProcess/SizeDistr/_merge_v3.py +518 -0
  22. AeroViz/dataProcess/SizeDistr/_merge_v4.py +424 -0
  23. AeroViz/dataProcess/SizeDistr/_size_distr.py +93 -0
  24. AeroViz/dataProcess/VOC/__init__.py +19 -0
  25. AeroViz/dataProcess/VOC/_potential_par.py +76 -0
  26. AeroViz/dataProcess/__init__.py +11 -0
  27. AeroViz/dataProcess/core/__init__.py +92 -0
  28. AeroViz/plot/__init__.py +7 -0
  29. AeroViz/plot/distribution/__init__.py +1 -0
  30. AeroViz/plot/distribution/distribution.py +582 -0
  31. AeroViz/plot/improve/__init__.py +1 -0
  32. AeroViz/plot/improve/improve.py +240 -0
  33. AeroViz/plot/meteorology/__init__.py +1 -0
  34. AeroViz/plot/meteorology/meteorology.py +317 -0
  35. AeroViz/plot/optical/__init__.py +2 -0
  36. AeroViz/plot/optical/aethalometer.py +77 -0
  37. AeroViz/plot/optical/optical.py +388 -0
  38. AeroViz/plot/templates/__init__.py +8 -0
  39. AeroViz/plot/templates/contour.py +47 -0
  40. AeroViz/plot/templates/corr_matrix.py +108 -0
  41. AeroViz/plot/templates/diurnal_pattern.py +42 -0
  42. AeroViz/plot/templates/event_evolution.py +65 -0
  43. AeroViz/plot/templates/koschmieder.py +156 -0
  44. AeroViz/plot/templates/metal_heatmap.py +57 -0
  45. AeroViz/plot/templates/regression.py +256 -0
  46. AeroViz/plot/templates/scatter.py +130 -0
  47. AeroViz/plot/templates/templates.py +398 -0
  48. AeroViz/plot/timeseries/__init__.py +1 -0
  49. AeroViz/plot/timeseries/timeseries.py +317 -0
  50. AeroViz/plot/utils/__init__.py +3 -0
  51. AeroViz/plot/utils/_color.py +71 -0
  52. AeroViz/plot/utils/_decorator.py +74 -0
  53. AeroViz/plot/utils/_unit.py +55 -0
  54. AeroViz/process/__init__.py +31 -0
  55. AeroViz/process/core/DataProc.py +19 -0
  56. AeroViz/process/core/SizeDist.py +90 -0
  57. AeroViz/process/core/__init__.py +4 -0
  58. AeroViz/process/method/PyMieScatt_update.py +567 -0
  59. AeroViz/process/method/__init__.py +2 -0
  60. AeroViz/process/method/mie_theory.py +258 -0
  61. AeroViz/process/method/prop.py +62 -0
  62. AeroViz/process/script/AbstractDistCalc.py +143 -0
  63. AeroViz/process/script/Chemical.py +176 -0
  64. AeroViz/process/script/IMPACT.py +49 -0
  65. AeroViz/process/script/IMPROVE.py +161 -0
  66. AeroViz/process/script/Others.py +65 -0
  67. AeroViz/process/script/PSD.py +103 -0
  68. AeroViz/process/script/PSD_dry.py +94 -0
  69. AeroViz/process/script/__init__.py +5 -0
  70. AeroViz/process/script/retrieve_RI.py +70 -0
  71. AeroViz/rawDataReader/__init__.py +68 -0
  72. AeroViz/rawDataReader/core/__init__.py +397 -0
  73. AeroViz/rawDataReader/script/AE33.py +31 -0
  74. AeroViz/rawDataReader/script/AE43.py +34 -0
  75. AeroViz/rawDataReader/script/APS_3321.py +47 -0
  76. AeroViz/rawDataReader/script/Aurora.py +38 -0
  77. AeroViz/rawDataReader/script/BC1054.py +46 -0
  78. AeroViz/rawDataReader/script/EPA_vertical.py +18 -0
  79. AeroViz/rawDataReader/script/GRIMM.py +35 -0
  80. AeroViz/rawDataReader/script/IGAC_TH.py +104 -0
  81. AeroViz/rawDataReader/script/IGAC_ZM.py +90 -0
  82. AeroViz/rawDataReader/script/MA350.py +45 -0
  83. AeroViz/rawDataReader/script/NEPH.py +57 -0
  84. AeroViz/rawDataReader/script/OCEC_LCRES.py +34 -0
  85. AeroViz/rawDataReader/script/OCEC_RES.py +28 -0
  86. AeroViz/rawDataReader/script/SMPS_TH.py +41 -0
  87. AeroViz/rawDataReader/script/SMPS_aim11.py +51 -0
  88. AeroViz/rawDataReader/script/SMPS_genr.py +51 -0
  89. AeroViz/rawDataReader/script/TEOM.py +46 -0
  90. AeroViz/rawDataReader/script/Table.py +28 -0
  91. AeroViz/rawDataReader/script/VOC_TH.py +30 -0
  92. AeroViz/rawDataReader/script/VOC_ZM.py +37 -0
  93. AeroViz/rawDataReader/script/__init__.py +22 -0
  94. AeroViz/tools/__init__.py +3 -0
  95. AeroViz/tools/database.py +94 -0
  96. AeroViz/tools/dataclassifier.py +117 -0
  97. AeroViz/tools/datareader.py +66 -0
  98. AeroViz-0.1.0.dist-info/LICENSE +21 -0
  99. AeroViz-0.1.0.dist-info/METADATA +117 -0
  100. AeroViz-0.1.0.dist-info/RECORD +102 -0
  101. AeroViz-0.1.0.dist-info/WHEEL +5 -0
  102. AeroViz-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,156 @@
1
+ from typing import Literal
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ from matplotlib.pyplot import Figure, Axes
7
+ from scipy.optimize import curve_fit
8
+
9
+ from AeroViz.plot.utils import *
10
+
11
+ __all__ = ['koschmieder']
12
+
13
+
14
@set_figure(fs=12)
def koschmieder(df: pd.DataFrame,
                y: Literal['Vis_Naked', 'Vis_LPV'],
                function: Literal['log', 'reciprocal'] = 'log',
                ax: Axes | None = None,
                **kwargs) -> tuple[Figure, Axes]:
    """
    Plot total extinction against visibility and fit a Koschmieder-type curve.

    Two extinction series are drawn, each as box plots per visibility group
    plus a faint scatter of the raw points:

    * ``df['Extinction'] + df['ExtinctionByGas']``
    * ``df['total_ext_dry'] + df['ExtinctionByGas']``

    A curve is then fitted through the per-group centre values of each series
    and drawn on top, with the fitted constants shown in the legend.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'Extinction', 'total_ext_dry',
        'ExtinctionByGas' and the column named by `y`.

    y : {'Vis_Naked', 'Vis_LPV'}
        Visibility column used for the x-axis. 'Vis_Naked' is grouped by its
        distinct values; 'Vis_LPV' is cut into 35 equal bins between 0 and 70.

    function : {'log', 'reciprocal'}, optional
        'log' fits ``Ext = a / Vis`` in log-log space; any other value fits
        ``Ext = a / Vis ** b`` directly. Default is 'log'.

    ax : matplotlib.axes.Axes or None, optional
        Axes to draw on. If None, a new figure is created using
        ``kwargs['fig_kws']``. Default is None.

    **kwargs
        Only 'fig_kws' (keyword arguments for ``plt.subplots``) is read here.

    Returns
    -------
    tuple of (Figure, Axes)
        The figure and the axes that were drawn on.

    Raises
    ------
    ValueError
        If `y` is not a supported column name, or if no visibility group holds
        more than 20 samples (nothing to plot or fit).
    """
    if y not in ('Vis_Naked', 'Vis_LPV'):
        raise ValueError(f"y must be 'Vis_Naked' or 'Vis_LPV', got {y!r}")

    def _r_squared(observed, predicted):
        # Coefficient of determination of `predicted` against `observed`.
        ss_res = np.sum((observed - predicted) ** 2)
        ss_total = np.sum((observed - np.mean(observed)) ** 2)
        return 1 - (ss_res / ss_total)

    # x = Visibility, y = Extinction, log-log fit!!
    def _log_fit(vis_vals, ext_vals, func=lambda x, a: -x + a):
        # Fits log(Ext) = -log(Vis) + a, i.e. Ext = exp(a) / Vis; returns exp(a).
        x_log = np.log(vis_vals)
        y_log = np.log(ext_vals)

        popt, pcov = curve_fit(func, x_log, y_log)

        r_squared = _r_squared(y_log, func(x_log, *popt))
        print(f'Const_Log = {popt[0].round(3)}')
        print(f'Const = {np.exp(popt)[0].round(3)}')
        print(f'R^2 = {r_squared.round(3)}')
        return np.exp(popt)[0], pcov

    def _reciprocal_fit(vis_vals, ext_vals, func=lambda x, a, b: a / (x ** b)):
        # Fits Ext = a / Vis ** b directly; returns the array (a, b).
        popt, pcov = curve_fit(func, vis_vals, ext_vals)

        r_squared = _r_squared(ext_vals, func(vis_vals, *popt))
        print(f'Const = {popt.round(3)}')
        print(f' R^2 = {r_squared.round(3)}')
        return popt, pcov

    fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

    # Per-mode grouping settings (the two modes differ only in these values).
    if y == 'Vis_Naked':
        bins, centres = None, None
        label_fmt, position_dtype, box_width = '{:.0f}', 'int', 0.4
    else:
        bins = np.linspace(0, 70, 36)
        centres = (bins + (bins[1] - bins[0]) / 2)[0:-1]
        label_fmt, position_dtype, box_width = '{:.1f}', 'float', (bins[1] - bins[0]) / 2.5

    ext_columns = ['Extinction', 'total_ext_dry']
    boxcolors = ['#3f83bf', '#a5bf6b']
    para_coeff = []

    for ext_col, boxcolor in zip(ext_columns, boxcolors):
        df_ = df[[ext_col, 'ExtinctionByGas', y]].dropna().copy()
        x_data = df_[y]
        y_data = df_[ext_col] + df_['ExtinctionByGas']
        df_['Total_Ext'] = y_data

        if y == 'Vis_Naked':
            grouped = df_.groupby(y)
        else:
            bin_col = f'{y}_bins'
            df_[bin_col] = pd.cut(x=x_data, bins=bins, labels=centres)
            grouped = df_.groupby(bin_col, observed=False)

        vals, centre_vals, vis = [], [], []
        for name, subdf in grouped:
            ext = subdf['Total_Ext'].dropna()
            # Only groups with more than 20 samples are plotted and fitted.
            if len(ext) > 20:
                vis.append(label_fmt.format(name))
                vals.append(ext.values)
                # NOTE(review): the original uses the median for 'Vis_Naked' but the
                # mean for 'Vis_LPV'; kept as-is — confirm this asymmetry is intended.
                centre_vals.append(ext.median() if y == 'Vis_Naked' else ext.mean())

        if not vals:
            raise ValueError(f"no '{y}' group has more than 20 valid '{ext_col}' samples")

        # Draw on the requested axes rather than whatever pyplot considers current.
        ax.boxplot(vals, labels=vis, positions=np.array(vis, dtype=position_dtype), widths=box_width,
                   showfliers=False, showmeans=True, meanline=False, patch_artist=True,
                   boxprops=dict(facecolor=boxcolor, alpha=.7),
                   meanprops=dict(marker='o', markerfacecolor='white', markeredgecolor='k', markersize=4),
                   medianprops=dict(color='#000000', ls='-'))

        ax.scatter(x_data, y_data, marker='.', s=10, facecolor='white', edgecolor=boxcolor, alpha=0.1)

        # fit curve
        _x = np.array(vis, dtype='float')
        _y = np.array(centre_vals, dtype='float')

        if function == 'log':
            coeff, pcov = _log_fit(_x, _y)
        else:
            coeff, pcov = _reciprocal_fit(_x, _y)

        para_coeff.append(coeff)

    # Plot lines (ref & Measurement)
    x_fit = np.linspace(0.1, 70, 1000)

    # NOTE(review): the first series is built from 'Extinction' but labelled
    # "Dry Extinction", the second from 'total_ext_dry' but labelled
    # "Amb Extinction"; label text is unchanged — confirm they are not swapped.
    if function == 'log':
        line1, = ax.plot(x_fit, para_coeff[0] / x_fit, c='b', lw=3)
        line2, = ax.plot(x_fit, para_coeff[1] / x_fit, c='g', lw=3)

        labels = ['Vis (km) = ' + f'{round(para_coeff[0])}' + ' / Ext (Dry Extinction)',
                  'Vis (km) = ' + f'{round(para_coeff[1])}' + ' / Ext (Amb Extinction)']

    else:
        (a0, b0), (a1, b1) = para_coeff[0], para_coeff[1]
        line1, = ax.plot(x_fit, a0 / (x_fit ** b0), c='b', lw=3)
        line2, = ax.plot(x_fit, a1 / (x_fit ** b1), c='g', lw=3)

        labels = ['Ext = ' + '{:.0f} / Vis ^ {:.3f}'.format(*para_coeff[0]) + ' (Dry Extinction)',
                  'Ext = ' + '{:.0f} / Vis ^ {:.3f}'.format(*para_coeff[1]) + ' (Amb Extinction)']

    ax.legend(handles=[line1, line2], labels=labels, loc='upper right', prop=dict(size=10, weight='bold'),
              bbox_to_anchor=(0.99, 0.99))

    # Replace the per-box tick labels with a regular 0-50 axis.
    ticks = np.array(range(0, 51, 5))
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_xlim(0, 50)
    ax.set_ylim(0, 700)
    ax.set_title(r'$\bf Koschmieder\ relationship$')
    ax.set_xlabel(f'{y} (km)')
    ax.set_ylabel(r'$\bf Extinction\ coefficient\ (1/Mm)$')

    plt.show()

    return fig, ax
150
+
151
+
152
if __name__ == '__main__':
    # Manual demo: plot the relationship for the 'Vis_LPV' column with the
    # log-space fit. Swap the last two arguments for 'Vis_Naked' / 'reciprocal'
    # to try the other mode.
    from AeroViz.tools import DataBase

    demo_data = DataBase()
    koschmieder(demo_data, 'Vis_LPV', 'log')
@@ -0,0 +1,57 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import seaborn as sns
4
+ from matplotlib.pyplot import Figure, Axes
5
+ from pandas import DataFrame, date_range
6
+ from sklearn.preprocessing import StandardScaler
7
+
8
+ from AeroViz.plot.utils import *
9
+
10
+
11
def process_data(df, detection_limit=0.0001, outlier_threshold=6, window=3):
    """
    Standardise, de-spike and smooth every column of `df`.

    Steps, in order:

    1. values below `detection_limit` are replaced with NaN;
    2. each column is standardised with ``StandardScaler``;
    3. standardised values whose magnitude is not below `outlier_threshold`
       are replaced with NaN;
    4. missing values are filled by linear interpolation;
    5. a rolling mean of `window` rows (at least one valid row) is applied.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric input data.

    detection_limit : float, optional
        Values below this limit are treated as missing. Default is 0.0001.

    outlier_threshold : float, optional
        Absolute standardised value at or above which a point is discarded.
        Default is 6.

    window : int, optional
        Number of rows in the rolling-mean window. Default is 3.

    Returns
    -------
    pandas.DataFrame
        Frame with the same index and columns as `df`.
    """
    # Mask values below the detection limit.
    df = df.where(df >= detection_limit, np.nan)

    # Normalize the data
    scaled = StandardScaler().fit_transform(df)
    df = DataFrame(scaled, index=df.index, columns=df.columns)

    # Remove outliers (boolean-frame indexing turns the rejected cells into NaN)
    df = df[np.abs(df) < outlier_threshold]

    # Interpolate the missing values
    df = df.interpolate(method='linear')

    # Smooth the data
    return df.rolling(window=window, min_periods=1).mean()
24
+
25
+
26
@set_figure(figsize=(12, 3), fs=6)
def metal_heatmaps(df, major_freq='24h', minor_freq='12h', ax: Axes | None = None, title=None, **kwargs
                   ) -> tuple[Figure, Axes]:
    """
    Draw a heatmap of a fixed list of element columns (rows) against the index (columns).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every element column listed in `items` below. The index
        is used for the x-axis; tick positions are looked up in it, so every
        generated tick timestamp has to be present in the index.

    major_freq : str, optional
        Frequency passed to ``pandas.date_range`` for major x-ticks. Default is '24h'.

    minor_freq : str, optional
        Frequency passed to ``pandas.date_range`` for minor x-ticks. Default is '12h'.

    ax : matplotlib.axes.Axes or None, optional
        Axes to draw on. If None, a new figure is created using
        ``kwargs['fig_kws']``. Default is None.

    title : str or None, optional
        Axes title. If None, the title is left empty. Default is None.

    **kwargs
        Only 'fig_kws' (keyword arguments for ``plt.subplots``) is read here.

    Returns
    -------
    tuple of (Figure, Axes)
        The figure and the axes that were drawn on.
    """
    items = ['Al', 'Zr', 'Si', 'Ca', 'Ti', 'Mn', 'Fe', 'V', 'Cl', 'K',
             'Sr', 'Ba', 'Bi', 'Pd', 'Sn', 'Cr', 'W', 'Cu', 'Zn',
             'As', 'Co', 'Se', 'Br', 'Cd', 'Sb', 'In', 'Pb', 'Ni']

    df = df[items]

    fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

    # Pass `ax` explicitly so the heatmap lands on the same axes that is
    # ticked and titled below, even when it is not pyplot's current axes.
    sns.heatmap(df.T, vmax=3, cmap="jet", xticklabels=False, yticklabels=True,
                cbar_kws={'label': 'Z score'}, ax=ax)
    ax.grid(color='gray', linestyle='-', linewidth=0.3)

    # Set x-tick positions and labels
    major_tick = date_range(start=df.index[0], end=df.index[-1], freq=major_freq)
    minor_tick = date_range(start=df.index[0], end=df.index[-1], freq=minor_freq)

    # Set the major and minor ticks (heatmap columns are addressed by position)
    ax.set_xticks(ticks=[df.index.get_loc(t) for t in major_tick])
    ax.set_xticks(ticks=[df.index.get_loc(t) for t in minor_tick], minor=True)
    # '%Y-%m-%d' is the portable spelling of the platform-specific '%F'.
    ax.set_xticklabels(major_tick.strftime('%Y-%m-%d'))
    ax.tick_params(axis='y', rotation=0)

    # Avoid rendering the literal text "None" when no title is given.
    ax.set_title('' if title is None else f"{title}", fontsize=10)
    ax.set(xlabel='',
           ylabel='',
           )

    plt.show()

    return fig, ax
@@ -0,0 +1,256 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ from matplotlib.pyplot import Figure, Axes
5
+ from sklearn.linear_model import LinearRegression
6
+ from tabulate import tabulate
7
+
8
+ from AeroViz.plot.utils import *
9
+
10
+ __all__ = [
11
+ 'linear_regression',
12
+ 'multiple_linear_regression',
13
+ ]
14
+
15
+
16
def _linear_regression(x_array: np.ndarray,
                       y_array: np.ndarray,
                       columns: str | list[str] | None = None,
                       positive: bool = True,
                       fit_intercept: bool = True):
    """
    Fit a linear model, print a summary table and return a legend text.

    Parameters
    ----------
    x_array : numpy.ndarray
        Predictor values. An array with two or more columns takes the
        multi-feature path; anything else is reshaped to a single column.

    y_array : numpy.ndarray
        Target values.

    columns : str or list of str or None, optional
        Names of the predictors, used only on the multi-feature path for the
        equation text and table headers. If None, 'x1', 'x2', ... are used.

    positive : bool, optional
        Passed to ``LinearRegression(positive=...)``. Default is True.

    fit_intercept : bool, optional
        Passed to ``LinearRegression(fit_intercept=...)``. Default is True.

    Returns
    -------
    tuple
        ``(text, y_predict, coefficients)`` on the multi-feature path and
        ``(text, y_predict, slope)`` otherwise. `text` is the equation plus an
        R^2 line, `y_predict` is the model prediction for `x_array`, and the
        coefficient(s) are rounded to 3 decimals.
    """
    if x_array.ndim > 1 and x_array.shape[1] >= 2:
        model = LinearRegression(positive=positive, fit_intercept=fit_intercept).fit(x_array, y_array)

        # atleast_2d / ravel make the indexing work for both 1-D and (n, 1) targets.
        coefficients = np.round(np.atleast_2d(model.coef_)[0], 3)
        intercept = np.round(np.ravel(model.intercept_)[0], 3) if fit_intercept else 'None'
        # score() may return a plain Python float, which has no .round() method.
        r_square = round(float(model.score(x_array, y_array)), 3)
        y_predict = model.predict(x_array)

        if columns is None:
            columns = [f'x{i + 1}' for i in range(len(coefficients))]
        elif isinstance(columns, str):
            columns = [columns]

        equation = ' + '.join([f'{coeff:.3f} * {col}' for coeff, col in zip(coefficients, columns)])
        equation = equation.replace(' + 0.000 * Const', '')  # Remove terms with coefficient 0

        text = 'y = ' + str(equation) + '\n' + r'$\bf R^2 = $' + str(r_square)
        tab = tabulate([[*coefficients, intercept, r_square]], headers=[*columns, 'intercept', 'R^2'],
                       floatfmt=".3f", tablefmt="fancy_grid")
        print('\n' + tab)

        return text, y_predict, coefficients

    x_array = x_array.reshape(-1, 1)
    y_array = y_array.reshape(-1, 1)

    model = LinearRegression(positive=positive, fit_intercept=fit_intercept).fit(x_array, y_array)

    slope = np.round(model.coef_[0][0], 3)
    intercept = np.round(model.intercept_[0], 3) if fit_intercept else 'None'
    r_square = round(float(model.score(x_array, y_array)), 3)
    y_predict = model.predict(x_array)

    # Without a fitted intercept the line passes through the origin; feed 0.0
    # to poly1d instead of the 'None' placeholder, which is not a number.
    equation = np.poly1d([slope, intercept if fit_intercept else 0.0])
    text = 'y = ' + str(equation).replace('\n', "") + '\n' + r'$\bf R^2 = $' + str(r_square)

    tab = tabulate([[slope, intercept, r_square]], headers=['slope', 'intercept', 'R^2'], floatfmt=".3f",
                   tablefmt="fancy_grid")
    print('\n' + tab)

    return text, y_predict, slope
58
+
59
+
60
@set_figure
def linear_regression(df: pd.DataFrame,
                      x: str | list[str],
                      y: str | list[str],
                      labels: str | list[str] = None,
                      ax: Axes | None = None,
                      diagonal=False,
                      positive: bool = True,
                      fit_intercept: bool = True,
                      **kwargs
                      ) -> tuple[Figure, Axes]:
    """
    Create a scatter plot with one fitted regression line per y variable.

    Parameters
    ----------
    df : DataFrame
        Input DataFrame containing the data.

    x : str or list of str
        Column name for the x-axis variable. If a list is given, only its
        first element is used.

    y : str or list of str
        Column name(s) for the y-axis variable(s).

    labels : str or list of str, optional
        Labels for the y-axis variable(s). If None, column names are used as labels. Default is None.

    ax : matplotlib.axes.Axes, optional
        Axes to use for the plot. If None, a new subplot is created. Default is None.

    diagonal : bool, optional
        If True, a diagonal line (1:1 line) is added to the plot. Default is False.

    positive : bool, optional
        Whether to let coefficient positive. Default is True.

    fit_intercept: bool, optional
        Whether to fit intercept. Default is True.

    **kwargs
        Recognised keys: 'fig_kws' (for ``plt.subplots``), 'xlim', 'ylim',
        'xlabel', 'ylabel' and 'title'.

    Returns
    -------
    tuple of (Figure, Axes)
        The figure and the axes containing the scatter plot.

    Notes
    -----
    - Rows with a missing value in `x` or in any of the `y` columns are dropped
      before plotting.
    - A regression line is always fitted for each y variable, and its equation
      and R^2 are shown in the legend.

    Example
    -------
    >>> linear_regression(df, x='X', y=['Y1', 'Y2'], labels=['Label1', 'Label2'],
    ...                   diagonal=True, xlim=(0, 10), ylim=(0, 20),
    ...                   xlabel="X-axis", ylabel="Y-axis", title="Scatter Plot with Regressions")
    """
    fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

    if not isinstance(x, str):
        x = x[0]

    if not isinstance(y, list):
        y = [y]

    if labels is None:
        labels = y
    elif isinstance(labels, str):
        # A bare string would otherwise be indexed character by character.
        labels = [labels]

    df = df.dropna(subset=[x, *y])
    x_array = df[[x]].to_numpy()

    color_cycle = Color.linecolor

    handles, text_list, line_colors = [], [], []

    for i, y_var in enumerate(y):
        y_array = df[[y_var]].to_numpy()

        color = color_cycle[i % len(color_cycle)]

        points = ax.scatter(x_array, y_array, s=25, color=color['face'], edgecolors=color['edge'], alpha=0.8,
                            label=labels[i])
        handles.append(points)

        text, y_predict, slope = _linear_regression(x_array, y_array,
                                                    columns=labels[i],
                                                    positive=positive,
                                                    fit_intercept=fit_intercept)

        text_list.append(f'{labels[i]}: {text}')
        line_colors.append(color['line'])
        # Draw on the requested axes rather than pyplot's current axes.
        ax.plot(x_array, y_predict, linewidth=3, color=color['line'], alpha=1, zorder=3)

    ax.set(xlim=kwargs.get('xlim'), ylim=kwargs.get('ylim'),
           xlabel=kwargs.get('xlabel', Unit(x)), ylabel=kwargs.get('ylabel', Unit(y[0])),
           title=kwargs.get('title'))

    # Add regression info to the legend
    leg = ax.legend(handles=handles, labels=text_list, loc='upper left', prop={'weight': 'bold', 'size': 10})

    # Colour each legend entry like its regression line (cycling with the markers).
    for legend_text, line_color in zip(leg.get_texts(), line_colors):
        legend_text.set_color(line_color)

    if diagonal:
        ax.axline((0, 0), slope=1., color='k', lw=2, ls='--', alpha=0.5, label='1:1')
        ax.text(0.97, 0.97, r'$\bf 1:1\ Line$', color='k', ha='right', va='top', transform=ax.transAxes)

    plt.show()

    return fig, ax
171
+
172
+
173
@set_figure
def multiple_linear_regression(df: pd.DataFrame,
                               x: str | list[str],
                               y: str | list[str],
                               labels: str | list[str] = None,
                               ax: Axes | None = None,
                               diagonal=False,
                               positive: bool = True,
                               fit_intercept: bool = True,
                               **kwargs
                               ) -> tuple[Figure, Axes]:
    """
    Perform multiple linear regression analysis and plot predicted against actual values.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing the data.

    x : str or list of str
        Column name(s) for the independent variable(s). Can be a single string or a list of strings.

    y : str or list of str
        Column name for the dependent variable. If a list is given, only its
        first element is used.

    labels : str or list of str, optional
        Names shown for the independent variable(s) in the printed summary.
        If None, the `x` column names are used. Default is None.

    ax : matplotlib.axes.Axes or None, optional
        Matplotlib Axes object to use for the plot. If None, a new subplot is created. Default is None.

    diagonal : bool, optional
        Whether to include a diagonal line (1:1 line) in the plot. Default is False.

    positive : bool, optional
        Whether to let coefficient positive. Default is True.

    fit_intercept: bool, optional
        Whether to fit intercept. Default is True.

    **kwargs
        'fig_kws' is used to create the figure when `ax` is None; all keyword
        arguments are also forwarded to `linear_regression` (e.g. 'xlim',
        'ylim', 'title').

    Returns
    -------
    tuple of (Figure, Axes)
        The figure and the axes containing the regression plot.

    Notes
    -----
    Rows with a missing value in any used column are dropped. The model is
    fitted on the `x` columns, then the actual `y` values are plotted against
    the model predictions via `linear_regression`.

    Example
    -------
    >>> multiple_linear_regression(df, x=['X1', 'X2'], y='Y', labels=['Label1', 'Label2'],
    ...                            diagonal=True, title="Multiple Linear Regression Plot")
    """
    fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

    if not isinstance(x, list):
        x = [x]

    if not isinstance(y, str):
        y = y[0]

    if labels is None:
        labels = x
    elif isinstance(labels, str):
        # A bare string would otherwise be iterated character by character.
        labels = [labels]

    df = df[[*x, y]].dropna()
    x_array = df[[*x]].to_numpy()
    y_array = df[[y]].to_numpy()

    text, y_predict, coefficients = _linear_regression(x_array, y_array,
                                                       columns=labels,
                                                       positive=positive,
                                                       fit_intercept=fit_intercept)

    df = pd.DataFrame(np.concatenate([y_array, y_predict], axis=1), columns=['y_actual', 'y_predict'])

    # Forward the caller's customisation (xlim, ylim, title, ...) to the plot.
    # The former `regression=True` argument is dropped: linear_regression has
    # no such parameter and always fits a line.
    linear_regression(df, x='y_actual', y='y_predict', ax=ax, diagonal=diagonal, **kwargs)

    return fig, ax
@@ -0,0 +1,130 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ import seaborn as sns
5
+ from matplotlib.colors import Normalize
6
+ from matplotlib.pyplot import Figure, Axes
7
+ from matplotlib.ticker import ScalarFormatter
8
+
9
+ from AeroViz.plot.templates.regression import _linear_regression
10
+ from AeroViz.plot.utils import *
11
+
12
+ __all__ = ['scatter']
13
+
14
+
15
@set_figure
def scatter(df: pd.DataFrame,
            x: str,
            y: str,
            c: str | None = None,
            s: str | None = None,
            cmap='jet',
            regression=False,
            diagonal=False,
            box=False,
            ax: Axes | None = None,
            **kwargs) -> tuple[Figure, Axes]:
    """
    Scatter `y` against `x`, optionally coloured by `c` and sized by `s`.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data. Rows with a missing value in any used column are dropped.
        The frame itself is not modified.

    x, y : str
        Column names for the x- and y-axis.

    c : str or None, optional
        Column mapped to marker colour; a colorbar is added when given.

    s : str or None, optional
        Column mapped to marker size; a size legend is added when given.

    cmap : str, optional
        Colormap used when `c` is given. Default is 'jet'.

    regression : bool, optional
        If True, fit and draw a linear regression of `y` on `x`. Default is False.

    diagonal : bool, optional
        If True, add a 1:1 line. Default is False.

    box : bool, optional
        If True, overlay box plots of `y` in 10 equal-width bins of `x`. Default is False.

    ax : matplotlib.axes.Axes or None, optional
        Axes to draw on. If None, a new figure is created using
        ``kwargs['fig_kws']``. Default is None.

    **kwargs
        Recognised keys: 'fig_kws', 'xlim', 'ylim', 'xlabel', 'ylabel', 'title'.

    Returns
    -------
    tuple of (Figure, Axes)
        The figure and the axes that were drawn on.
    """
    fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

    def _marker_size(values, largest):
        # Marker area grows with (value / largest) ** 1.5.
        return 50 * (values / largest) ** 1.5

    def _add_size_legend(sizes):
        # Legend of four reference marker sizes between the min and max of `sizes`.
        dot = np.linspace(sizes.min(), sizes.max(), 6).round(-1)

        for dott in dot[1:-1]:
            ax.scatter([], [], c='k', alpha=0.8, s=_marker_size(dott, sizes.max()), label='{:.0f}'.format(dott))

        ax.legend(title=Unit(s))

    used_columns = [col for col in (x, y, c, s) if col is not None]
    df_ = df.dropna(subset=used_columns).copy()
    x_data, y_data = df_[x].to_numpy(), df_[y].to_numpy()
    c_data = df_[c].to_numpy() if c is not None else None
    s_data = df_[s].to_numpy() if s is not None else None

    if c is not None and s is not None:
        points = ax.scatter(x_data, y_data, c=c_data,
                            norm=Normalize(vmin=np.percentile(c_data, 10), vmax=np.percentile(c_data, 90)),
                            cmap=cmap, s=_marker_size(s_data, s_data.max()), alpha=0.7, edgecolors=None)
        _add_size_legend(s_data)

    elif c is not None:
        points = ax.scatter(x_data, y_data, c=c_data, vmin=c_data.min(), vmax=np.percentile(c_data, 90),
                            cmap=cmap, alpha=0.7, edgecolors=None)

    elif s is not None:
        points = ax.scatter(x_data, y_data, s=_marker_size(s_data, s_data.max()), color='#7a97c9', alpha=0.7,
                            edgecolors='white')
        _add_size_legend(s_data)

    else:
        points = ax.scatter(x_data, y_data, s=30, color='#7a97c9', alpha=0.7, edgecolors='white')

    xlim = kwargs.get('xlim', (x_data.min(), x_data.max()))
    ylim = kwargs.get('ylim', (y_data.min(), y_data.max()))
    xlabel = kwargs.get('xlabel', Unit(x))
    ylabel = kwargs.get('ylabel', Unit(y))
    title = kwargs.get('title', '')
    ax.set(xlim=xlim, ylim=ylim, xlabel=xlabel, ylabel=ylabel, title=title)

    # color_bar (attached to the plotted axes explicitly)
    if c is not None:
        color_bar = fig.colorbar(points, ax=ax, extend='both')
        color_bar.set_label(label=Unit(c), size=14)

    if regression:
        text, y_predict, slope = _linear_regression(x_data, y_data)
        ax.plot(x_data, y_predict, linewidth=3, color=sns.xkcd_rgb["denim blue"], alpha=1, zorder=3)

        ax.text(0.05, 0.95, f'{text}', fontdict={'weight': 'bold'}, color=sns.xkcd_rgb["denim blue"],
                ha='left', va='top', transform=ax.transAxes)

    if diagonal:
        ax.axline((0, 0), slope=1., color='k', lw=2, ls='--', alpha=0.5, label='1:1')
        ax.text(0.91, 0.97, r'$\bf 1:1\ Line$', color='k', ha='right', va='top', transform=ax.transAxes)

    if box:
        bins = np.linspace(x_data.min(), x_data.max(), 11, endpoint=True)
        wid = (bins + (bins[1] - bins[0]) / 2)[0:-1]

        # Bin the NaN-filtered copy: `x_data` comes from `df_`, so assigning to
        # the caller's `df` would fail on a length mismatch whenever rows were
        # dropped, and would also mutate the caller's frame.
        # include_lowest keeps the minimum x value inside the first bin.
        group = x + '_bin'
        df_[group] = pd.cut(x=x_data, bins=bins, labels=wid, include_lowest=True)

        names, vals = [], []

        for name, subdf in df_.groupby(group, observed=False):
            names.append('{:.0f}'.format(name))
            vals.append(subdf[y].dropna().values)

        ax.boxplot(vals, labels=names, positions=wid, widths=(bins[1] - bins[0]) / 3,
                   showfliers=False, showmeans=True, meanline=True, patch_artist=True,
                   boxprops=dict(facecolor='#f2c872', alpha=.7),
                   meanprops=dict(color='#000000', ls='none'),
                   medianprops=dict(ls='-', color='#000000'))

        ax.set_xlim(x_data.min(), x_data.max())
        ax.set_xticks(bins, labels=bins.astype(int))

    ax.xaxis.set_major_formatter(ScalarFormatter())
    ax.yaxis.set_major_formatter(ScalarFormatter())

    plt.show()

    return fig, ax