nxs-analysis-tools 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _meta/__init__.py +10 -0
- nxs_analysis_tools/__init__.py +15 -0
- nxs_analysis_tools/chess.py +866 -0
- nxs_analysis_tools/datareduction.py +1542 -0
- nxs_analysis_tools/datasets.py +137 -0
- nxs_analysis_tools/fitting.py +301 -0
- nxs_analysis_tools/lineartransformations.py +51 -0
- nxs_analysis_tools/pairdistribution.py +1758 -0
- nxs_analysis_tools-0.1.13.dist-info/METADATA +89 -0
- nxs_analysis_tools-0.1.13.dist-info/RECORD +13 -0
- nxs_analysis_tools-0.1.13.dist-info/WHEEL +5 -0
- nxs_analysis_tools-0.1.13.dist-info/licenses/LICENSE +21 -0
- nxs_analysis_tools-0.1.13.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pooch
|
|
3
|
+
|
|
4
|
+
GOODBOY = pooch.create(
    path=pooch.os_cache("nxs_analysis_tools/cubic"),
    base_url="https://raw.githubusercontent.com/stevenjgomez/dataset-cubic/main/data/",
    # Each temperature T contributes two files: "cubic_{T}.nxs" at the data
    # root and "{T}/transform.nxs" in a per-temperature subdirectory.
    # The hash values are None, so pooch performs no checksum verification
    # on these downloads.
    registry={
        fname: None
        for T in (15, 25, 35, 45, 55, 65, 75, 80, 104, 128,
                  153, 177, 202, 226, 251, 275, 300)
        for fname in (f"cubic_{T}.nxs", f"{T}/transform.nxs")
    },
)
|
|
44
|
+
|
|
45
|
+
def fetch_cubic(temperatures=None):
    """
    Load the cubic dataset.

    For every requested temperature ``T`` this fetches two files through the
    ``GOODBOY`` pooch registry: ``cubic_{T}.nxs`` and ``{T}/transform.nxs``.
    pooch downloads a file on first use and returns the cached copy afterwards.

    Parameters
    ----------
    temperatures : int, str, or iterable of int/str, optional
        Temperature(s) to fetch. A single value is treated as a one-element
        list. Defaults to every registered temperature:
        15, 25, 35, 45, 55, 65, 75, 80, 104, 128, 153, 177, 202, 226, 251,
        275, 300.

    Returns
    -------
    list of str
        Local file paths, in the order
        ``[cubic_{T1}.nxs, {T1}/transform.nxs, cubic_{T2}.nxs, ...]``.
        pooch raises an error for a temperature that is not in the registry.
    """
    if temperatures is None:
        temperatures = [15, 25, 35, 45, 55, 65, 75, 80, 104, 128,
                        153, 177, 202, 226, 251, 275, 300]
    elif isinstance(temperatures, str) or not hasattr(temperatures, "__iter__"):
        # A lone temperature such as 15 or "15": wrap it rather than failing
        # (int) or iterating over its characters (str).
        temperatures = [temperatures]

    fnames = []
    for T in temperatures:
        fnames.append(GOODBOY.fetch(f"cubic_{T}.nxs"))
        fnames.append(GOODBOY.fetch(f"{T}/transform.nxs"))
    return fnames
|
|
56
|
+
|
|
57
|
+
def cubic(temperatures=None):
    """
    Fetch the cubic dataset and return the directory it is cached in.

    Parameters
    ----------
    temperatures : optional
        Forwarded unchanged to ``fetch_cubic``; ``None`` fetches every
        registered temperature.

    Returns
    -------
    str
        Directory containing the first fetched ``cubic_{T}.nxs`` file. The
        matching ``{T}/transform.nxs`` files sit in ``{T}/`` subdirectories
        beneath it.

    Raises
    ------
    ValueError
        If no files were fetched (e.g. ``temperatures`` is an empty list).
    """
    fnames = fetch_cubic(temperatures)
    if not fnames:
        raise ValueError("No cubic dataset files were fetched; "
                         "provide at least one temperature.")
    return os.path.dirname(fnames[0])
|
|
61
|
+
|
|
62
|
+
POOCH = pooch.create(
    path=pooch.os_cache("nxs_analysis_tools/hexagonal"),
    base_url=(
        "https://raw.githubusercontent.com/stevenjgomez/"
        "dataset-hexagonal/main/data/"
    ),
    # SHA256 checksums that pooch verifies after each download. Every
    # temperature T has a "hexagonal_{T}.nxs" file at the data root and a
    # "{T}/transform.nxs" file in a per-temperature subdirectory.
    registry={
        # T = 15
        "hexagonal_15.nxs":
            "850d666d6fb0c7bbf7f7159fed952fbd53355c3c0bfb40410874d3918a3cca49",
        "15/transform.nxs":
            "45c089be295e0a5b927e963540a90b41f567edb75f283811dbc6bb4a26f2fba5",
        # T = 300
        "hexagonal_300.nxs":
            "c6a9ff704d1e42d9576d007a92a333f529e3ddf605e3f76a82ff15557b7d4a43",
        "300/transform.nxs":
            "e665ba59debe8e60c90c3181e2fb1ebbce668a3d3918a89a6bf31e3563ebf32e",
    },
)
|
|
72
|
+
|
|
73
|
+
def fetch_hexagonal(temperatures=None):
    """
    Load the hexagonal dataset.

    For every requested temperature ``T`` this fetches two files through the
    ``POOCH`` registry: ``hexagonal_{T}.nxs`` and ``{T}/transform.nxs``.
    pooch downloads a file on first use and returns the cached copy afterwards.

    Parameters
    ----------
    temperatures : int, str, or iterable of int/str, optional
        Temperature(s) to fetch. A single value is treated as a one-element
        list. Defaults to every registered temperature: 15 and 300.

    Returns
    -------
    list of str
        Local file paths, in the order
        ``[hexagonal_{T1}.nxs, {T1}/transform.nxs, hexagonal_{T2}.nxs, ...]``.
        pooch raises an error for a temperature that is not in the registry.
    """
    if temperatures is None:
        temperatures = [15, 300]
    elif isinstance(temperatures, str) or not hasattr(temperatures, "__iter__"):
        # A lone temperature such as 15 or "15": wrap it rather than failing
        # (int) or iterating over its characters (str).
        temperatures = [temperatures]

    fnames = []
    for T in temperatures:
        fnames.append(POOCH.fetch(f"hexagonal_{T}.nxs"))
        fnames.append(POOCH.fetch(f"{T}/transform.nxs"))
    return fnames
|
|
83
|
+
|
|
84
|
+
def hexagonal(temperatures=None):
    """
    Fetch the hexagonal dataset and return the directory it is cached in.

    Parameters
    ----------
    temperatures : optional
        Forwarded unchanged to ``fetch_hexagonal``; ``None`` fetches every
        registered temperature.

    Returns
    -------
    str
        Directory containing the first fetched ``hexagonal_{T}.nxs`` file. The
        matching ``{T}/transform.nxs`` files sit in ``{T}/`` subdirectories
        beneath it.

    Raises
    ------
    ValueError
        If no files were fetched (e.g. ``temperatures`` is an empty list).
    """
    fnames = fetch_hexagonal(temperatures)
    if not fnames:
        raise ValueError("No hexagonal dataset files were fetched; "
                         "provide at least one temperature.")
    return os.path.dirname(fnames[0])
|
|
88
|
+
|
|
89
|
+
LARGEBOI = pooch.create(
    path=pooch.os_cache("nxs_analysis_tools/orthorhombic"),
    base_url="https://raw.githubusercontent.com/stevenjgomez/dataset-orthorhombic/main/data/",
    # Each temperature T contributes "orthorhombic_{T}.nxs" at the data root
    # and "{T}/transform.nxs" in a per-temperature subdirectory.
    # The hash values are None, so pooch performs no checksum verification
    # on these downloads.
    registry={
        fname: None
        for T in (15, 100, 300)
        for fname in (f"orthorhombic_{T}.nxs", f"{T}/transform.nxs")
    },
)
|
|
101
|
+
|
|
102
|
+
def fetch_orthorhombic(temperatures=None):
    """
    Load the orthorhombic dataset.

    For every requested temperature ``T`` this fetches two files through the
    ``LARGEBOI`` registry: ``orthorhombic_{T}.nxs`` and ``{T}/transform.nxs``.
    pooch downloads a file on first use and returns the cached copy afterwards.

    Parameters
    ----------
    temperatures : int, str, or iterable of int/str, optional
        Temperature(s) to fetch. A single value is treated as a one-element
        list. Defaults to every registered temperature: 15, 100 and 300.

    Returns
    -------
    list of str
        Local file paths, in the order
        ``[orthorhombic_{T1}.nxs, {T1}/transform.nxs, orthorhombic_{T2}.nxs, ...]``.
        pooch raises an error for a temperature that is not in the registry.
    """
    if temperatures is None:
        temperatures = [15, 100, 300]
    elif isinstance(temperatures, str) or not hasattr(temperatures, "__iter__"):
        # A lone temperature such as 15 or "15": wrap it rather than failing
        # (int) or iterating over its characters (str).
        temperatures = [temperatures]

    fnames = []
    for T in temperatures:
        fnames.append(LARGEBOI.fetch(f"orthorhombic_{T}.nxs"))
        fnames.append(LARGEBOI.fetch(f"{T}/transform.nxs"))
    return fnames
|
|
112
|
+
|
|
113
|
+
def orthorhombic(temperatures=None):
    """
    Fetch the orthorhombic dataset and return the directory it is cached in.

    Parameters
    ----------
    temperatures : optional
        Forwarded unchanged to ``fetch_orthorhombic``; ``None`` fetches every
        registered temperature.

    Returns
    -------
    str
        Directory containing the first fetched ``orthorhombic_{T}.nxs`` file.
        The matching ``{T}/transform.nxs`` files sit in ``{T}/``
        subdirectories beneath it.

    Raises
    ------
    ValueError
        If no files were fetched (e.g. ``temperatures`` is an empty list).
    """
    fnames = fetch_orthorhombic(temperatures)
    if not fnames:
        raise ValueError("No orthorhombic dataset files were fetched; "
                         "provide at least one temperature.")
    return os.path.dirname(fnames[0])
|
|
117
|
+
|
|
118
|
+
BONES = pooch.create(
    path=pooch.os_cache("nxs_analysis_tools/vacancies"),
    # Unlike the other dataset registries, these files are served from the
    # repository root rather than a data/ subdirectory.
    base_url=(
        "https://raw.githubusercontent.com/stevenjgomez/"
        "dataset-vacancies/main/"
    ),
    # SHA256 checksums that pooch verifies after each download.
    registry={
        "vacancies.nxs":
            "39eaf8df84a0dbcacbe6ce7c6017da4da578fbf68a6218ee18ade3953c26efb5",
        "fft.nxs":
            "c81178eda0ec843502935f29fcb2b0b878f7413e461612c731d37ea9e5e414a9",
    },
)
|
|
126
|
+
|
|
127
|
+
def vacancies():
    """
    Load the vacancies dataset.

    Returns
    -------
    str
        Local path to ``vacancies.nxs``, fetched through the ``BONES`` pooch
        registry (downloaded on first use, cached afterwards).
    """
    # Plain string: the original f-string had no placeholders.
    return BONES.fetch("vacancies.nxs")
|
|
132
|
+
|
|
133
|
+
def vacanciesfft():
    """
    Load the vacancies fft dataset.

    Returns
    -------
    str
        Local path to ``fft.nxs``, fetched through the ``BONES`` pooch
        registry (downloaded on first use, cached afterwards).
    """
    # Plain string: the original f-string had no placeholders.
    return BONES.fetch("fft.nxs")
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module for fitting of linecuts using the lmfit package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import operator
|
|
6
|
+
from lmfit import Parameters
|
|
7
|
+
from lmfit.model import Model, CompositeModel
|
|
8
|
+
from lmfit.models import PseudoVoigtModel, LinearModel
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LinecutModel:
    """
    A class representing a linecut model for data analysis and fitting.

    Attributes
    ----------
    y : array-like or None
        The dependent variable data.
    x : array-like or None
        The independent variable data.
    y_eval : array-like or None
        The evaluated y-values of the fitted model.
    x_eval : array-like or None
        The x-values used for evaluation.
    y_eval_components : dict or None
        The evaluated y-values of the model components.
    y_fit_components : dict or None
        The fitted y-values of the model components.
    y_fit : array-like or None
        The fitted y-values of the model.
    x_fit : array-like or None
        The x-values used for fitting (set by `fit`).
    y_init_fit : array-like or None
        The initial guess of the y-values.
    params : Parameters or None
        The parameters of the model.
    model_components : list of Models or None
        The model component(s) used for fitting.
    model : Model or None
        The (possibly composite) model used for fitting.
    modelresult : ModelResult or None
        The result of the model fitting.
    data : NXdata or None
        The 1D linecut data used for analysis.

    Methods
    -------
    __init__(self, data=None)
        Initialize the LinecutModel.
    set_data(self, data)
        Set the data for analysis.
    set_model_components(self, model_components)
        Set the model components.
    set_param_hint(self, *args, **kwargs)
        Set parameter hints for the model.
    make_params(self)
        Create and initialize the parameters for the model.
    guess(self)
        Perform initial guesses for each model component.
    print_initial_params(self)
        Print the parameter hints stored on the model.
    plot_initial_guess(self, numpoints=None)
        Plot the initial guess.
    fit(self)
        Fit the model to the data.
    plot_fit(self, numpoints=None, fit_report=True, **kwargs)
        Plot the fitted model.
    fit_peak_simple(self)
        Perform a basic fit using a pseudo-Voigt peak shape, linear background, and no constraints.
    print_fit_report(self)
        Print the fit report.
    """

    def __init__(self, data=None):
        """
        Initialize the LinecutModel.

        Parameters
        ----------
        data : NXdata, optional
            The 1D linecut data to analyze. When given, `x` and `y` are
            populated from its axis and signal fields via `set_data`.
        """
        self.y = None
        self.x = None
        self.y_eval = None
        self.x_eval = None
        self.y_eval_components = None
        self.y_fit_components = None
        self.y_fit = None
        self.x_fit = None
        self.y_init_fit = None
        self.params = None
        self.model_components = None
        self.model = None
        self.modelresult = None
        self.data = None
        if data is not None:
            self.set_data(data)

    def set_data(self, data):
        """
        Set the data for analysis.

        Parameters
        ----------
        data : NXdata
            The 1D linecut data to be used for analysis. `x` is read from the
            field named by ``data.axes`` and `y` from the field named by
            ``data.signal``.
        """
        self.data = data
        self.x = data[data.axes].nxdata
        self.y = data[data.signal].nxdata

    def set_model_components(self, model_components):
        """
        Set the model components and build a fresh `params` set.

        Parameters
        ----------
        model_components : Model, CompositeModel, or iterable of Model
            The model component(s) to be used for fitting. A single Model or
            CompositeModel is used as-is; an iterable is summed into a
            CompositeModel in order.

        Raises
        ------
        ValueError
            If an empty iterable is given.
        """
        if isinstance(model_components, Model):
            # Covers both a plain Model and a CompositeModel (CompositeModel
            # subclasses Model, so a separate CompositeModel branch after this
            # one could never run). `.components` is [model] for a plain
            # Model and the list of leaf models for a composite.
            self.model = model_components
            self.model_components = self.model.components
            self.params = self.model.make_params()
            return

        # Materialize first so generators / other non-indexable iterables work.
        components = list(model_components)
        if not components:
            raise ValueError("model_components must contain at least one Model.")

        self.model_components = components
        self.model = components[0]
        # Combine remaining components into the composite model
        for component in components[1:]:
            self.model = CompositeModel(self.model, component, operator.add)

        # Make params for each component of the model
        self.params = Parameters()
        for component in self.model.components:
            self.params.update(component.make_params())

    def set_param_hint(self, *args, **kwargs):
        """
        Set parameter hints for the model. These are implemented when the .make_params() method
        is called.

        Parameters
        ----------
        *args : positional arguments
            Positional arguments passed to the `set_param_hint` method of the underlying model.

        **kwargs : keyword arguments
            Keyword arguments passed to the `set_param_hint` method of the underlying model.
        """
        self.model.set_param_hint(*args, **kwargs)

    def make_params(self):
        """
        Create the parameters for the model from its current hints.

        The result replaces `self.params` and is also returned.

        Returns
        -------
        Parameters
            The initialized parameters for the model.
        """
        params = self.model.make_params()
        self.params = params
        return params

    def guess(self):
        """
        Perform initial guesses for each model component and update params. This overwrites any
        prior initial values and constraints.

        Returns
        -------
        components_params : list
            A list containing params objects for each component of the model.
        """
        import copy

        components_params = []
        for model_component in self.model.components:
            # Guess once per component. `Parameters.update` stores the given
            # Parameter objects themselves, so merge a deep copy to keep the
            # returned per-component params independent of `self.params`.
            component_params = model_component.guess(self.y, x=self.x)
            self.params.update(copy.deepcopy(component_params))
            components_params.append(component_params)
        return components_params

    def print_initial_params(self):
        """
        Print the parameter hints stored on the model.

        For each parameter name in ``self.model.param_hints`` the name is
        printed, followed by one tab-indented ``key: value`` line per hint.
        """
        for param, hint in self.model.param_hints.items():
            print(f'{param}')
            for key, value in hint.items():
                print(f'\t{key}: {value}')

    def plot_initial_guess(self, numpoints=None):
        """
        Plot the data together with the model evaluated at the current params.

        The full model is drawn at the data x-values. Each component is drawn
        as a filled curve on an evenly spaced grid of `numpoints` points
        (default: ``len(self.x)``) stored in `x_eval`. The figure is shown with
        ``plt.show()``.

        Parameters
        ----------
        numpoints : int, optional
            Number of grid points used for the component curves.
        """
        model = self.model
        params = self.params
        x = self.x
        y = self.y
        y_init_fit = model.eval(params=params, x=x)
        self.y_init_fit = y_init_fit
        plt.plot(x, y, 'o', label='data')
        plt.plot(x, y_init_fit, '--', label='guess')

        # Plot the components of the model
        if numpoints is None:
            numpoints = len(self.x)
        self.x_eval = np.linspace(self.x.min(), self.x.max(), numpoints)
        y_init_fit_components = model.eval_components(params=params, x=self.x_eval)
        ax = plt.gca()
        for model_component, value in y_init_fit_components.items():
            ax.fill_between(self.x_eval, value, alpha=0.3, label=model_component)
        plt.legend()
        plt.show()

    def fit(self):
        """
        Fit the model to the data.

        This method fits the model to `y` using the current `params` and the
        x-values. It stores the result in `modelresult`. It also records the
        x-values used in `x_fit` and stores the fitted curve and per-component
        curves at those x-values in `y_fit` and `y_fit_components`.
        """
        self.x_fit = self.x
        self.modelresult = self.model.fit(self.y, self.params, x=self.x)
        self.y_fit = self.modelresult.eval(x=self.x)
        self.y_fit_components = self.modelresult.eval_components(x=self.x)

    def plot_fit(self, numpoints=None, fit_report=True, **kwargs):
        """
        Plot the fitted model.

        This method plots the fitted model along with the original data.
        It evaluates the model and its components at the specified number of points (numpoints)
        and plots the results.

        Parameters
        ----------
        numpoints : int, optional
            Number of points to evaluate the model and its components. If not provided,
            it defaults to the length of the x-values.

        fit_report : bool, optional
            Whether to print the fit report. Default is True.

        **kwargs : dict, optional
            Additional keyword arguments to be passed to the `plot` method.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The Axes object containing the plot.
        """
        if numpoints is None:
            numpoints = len(self.x)
        self.x_eval = np.linspace(self.x.min(), self.x.max(), numpoints)
        self.y_eval = self.modelresult.eval(x=self.x_eval)
        self.y_eval_components = self.modelresult.eval_components(x=self.x_eval)
        self.modelresult.plot(numpoints=numpoints, **kwargs)
        ax = plt.gca()
        for model_component, value in self.y_eval_components.items():
            ax.fill_between(self.x_eval, value, alpha=0.3, label=model_component)
        plt.legend()
        plt.show()
        if fit_report:
            print(self.modelresult.fit_report())
        return ax

    def fit_peak_simple(self):
        """
        Fit this linecut using a pseudo-Voigt peak shape and a linear
        background, with no constraints.

        Parameter names carry the prefixes ``'peak'`` and ``'background'``.
        The model is set up, params are made and guessed, and `fit` is run.
        """
        self.set_model_components([PseudoVoigtModel(prefix='peak'),
                                   LinearModel(prefix='background')])
        self.make_params()
        self.guess()
        self.fit()

    def print_fit_report(self):
        """
        Show fit report.
        """
        print(self.modelresult.fit_report())
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.ndimage import affine_transform
|
|
3
|
+
from matplotlib.transforms import Affine2D
|
|
4
|
+
|
|
5
|
+
def shear_transformation(angle):
    """
    Build a matplotlib affine transform that shears along x by ``angle`` degrees.

    The x-skew is conjugated by a y scaling:
    1. y is scaled by ``cos(angle)``.
    2. ``skew_deg(angle, 0)`` is applied.
    3. The y scaling is undone.

    Mathematically the composition is ``x' = x + sin(angle) * y``, ``y' = y``.

    Parameters
    ----------
    angle : float
        Shear angle in degrees.

    Returns
    -------
    matplotlib.transforms.Transform
        The composed affine transform (applied left to right).
    """
    cos_angle = np.cos(angle * np.pi / 180)
    return (
        Affine2D()
        + Affine2D().scale(1, cos_angle)              # squash y by cos(angle)
        + Affine2D().skew_deg(angle, 0)               # shear along x
        + Affine2D().scale(1, cos_angle).inverted()   # restore y scaling
    )
|
|
19
|
+
|
|
20
|
+
class ShearTransformer:
    """
    Shear-and-rescale a 2D image (``apply``) and undo that operation (``invert``).

    The shear angle is ``90 - angle`` degrees, so ``angle == 90`` is the
    identity. NOTE(review): ``angle`` is presumably the angle between two
    in-plane axes of the data (e.g. a non-orthogonal lattice); confirm
    against callers.

    Both steps call ``scipy.ndimage.affine_transform`` with nearest-neighbour
    interpolation (``order=0``). The output keeps the input's shape.

    Attributes
    ----------
    shear_angle : float
        ``90 - angle``, in degrees.
    t : matplotlib.transforms.Transform
        ``shear_transformation(shear_angle)``.
    scale : float
        ``cos(shear_angle)``; the axis-0 entry of the rescaling matrix.
    """

    def __init__(self, angle):
        self.shear_angle = 90 - angle
        self.t = shear_transformation(self.shear_angle)
        self.scale = np.cos(self.shear_angle * np.pi / 180)

    def apply(self, image):
        """
        Shear ``image`` and then rescale it along axis 0 about its centre row.

        NOTE(review): both offsets are derived from ``image.shape[0]`` only.
        The shear pivot therefore sits at axis-1 index ``shape[0] / 2``, which
        appears to assume a square image.

        Parameters
        ----------
        image : numpy.ndarray
            2D array to transform.

        Returns
        -------
        numpy.ndarray
            Transformed array, same shape as ``image``.
        """
        n_rows = image.shape[0]
        sin_shear = np.sin(self.shear_angle * np.pi / 180)

        # affine_transform maps output -> input coordinates, so the forward
        # shear is expressed with the inverted transform's matrix.
        shear_matrix = self.t.inverted().get_matrix()[:2, :2]
        skewed = affine_transform(
            image,
            shear_matrix,
            offset=[n_rows / 2 * sin_shear, 0],
            order=0,
        )

        # Rescale axis 0 by cos(shear_angle); the offset keeps index
        # n_rows / 2 fixed.
        rescale_matrix = Affine2D().scale(self.scale, 1).get_matrix()[:2, :2]
        return affine_transform(
            skewed,
            rescale_matrix,
            offset=[(1 - self.scale) * n_rows / 2, 0],
            order=0,
        )

    def invert(self, image):
        """
        Undo ``apply``: reverse the axis-0 rescaling, then reverse the shear.

        Nearest-neighbour resampling is used in both directions, so the round
        trip is not guaranteed to be lossless.

        Parameters
        ----------
        image : numpy.ndarray
            2D array previously produced by ``apply``.

        Returns
        -------
        numpy.ndarray
            Array with the same shape as ``image``.
        """
        n_rows = image.shape[0]
        sin_shear = np.sin(self.shear_angle * np.pi / 180)

        # Reverse the rescaling with the inverse scale matrix and offset.
        unscale_matrix = Affine2D().scale(self.scale, 1).inverted().get_matrix()[:2, :2]
        unscaled = affine_transform(
            image,
            unscale_matrix,
            offset=[-(1 - self.scale) * n_rows / 2 / self.scale, 0],
            order=0,
        )

        # Reverse the shear: the forward transform's matrix maps
        # output -> input here.
        unskewed = affine_transform(
            unscaled,
            self.t.get_matrix()[:2, :2],
            offset=[(-n_rows / 2 * sin_shear), 0],
            order=0,
        )
        return unskewed
|