segmodels-keras 0.1.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- segmodels_keras-0.1.0.dev0/LICENSE +21 -0
- segmodels_keras-0.1.0.dev0/MANIFEST.in +1 -0
- segmodels_keras-0.1.0.dev0/PKG-INFO +35 -0
- segmodels_keras-0.1.0.dev0/README.rst +260 -0
- segmodels_keras-0.1.0.dev0/pyproject.toml +63 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/__init__.py +132 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/__version__.py +3 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/_compat.py +3 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/backbones/__init__.py +0 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/backbones/backbones_factory.py +269 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/backbones/inception_resnet_v2.py +466 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/backbones/inception_v3.py +494 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/base/__init__.py +4 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/base/functional.py +360 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/base/objects.py +113 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/losses.py +285 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/metrics.py +281 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/__init__.py +0 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/_common_blocks.py +68 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/_utils.py +14 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/fpn.py +278 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/linknet.py +287 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/pspnet.py +263 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/models/unet.py +258 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras/utils.py +85 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras.egg-info/PKG-INFO +35 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras.egg-info/SOURCES.txt +34 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras.egg-info/dependency_links.txt +1 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras.egg-info/requires.txt +6 -0
- segmodels_keras-0.1.0.dev0/segmodels_keras.egg-info/top_level.txt +1 -0
- segmodels_keras-0.1.0.dev0/setup.cfg +4 -0
- segmodels_keras-0.1.0.dev0/setup.py +81 -0
- segmodels_keras-0.1.0.dev0/tests/test_backbones.py +21 -0
- segmodels_keras-0.1.0.dev0/tests/test_metrics.py +205 -0
- segmodels_keras-0.1.0.dev0/tests/test_models.py +128 -0
- segmodels_keras-0.1.0.dev0/tests/test_utils.py +110 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
The MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2018, Pavel Yakubovskiy
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in
|
|
13
|
+
all copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
21
|
+
THE SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
include README.rst LICENSE requirements.txt
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: segmodels_keras
|
|
3
|
+
Version: 0.1.0.dev0
|
|
4
|
+
Summary: Image segmentation models with pre-trained backbones with Keras.
|
|
5
|
+
Home-page: https://github.com/theroggy/segmodels_keras
|
|
6
|
+
Author: Pieter Roggemans
|
|
7
|
+
Author-email: pieter.roggemans@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
13
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
14
|
+
Requires-Python: >=3.0.0
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: keras
|
|
18
|
+
Provides-Extra: tests
|
|
19
|
+
Requires-Dist: numpy; extra == "tests"
|
|
20
|
+
Requires-Dist: pytest; extra == "tests"
|
|
21
|
+
Requires-Dist: scikit-image; extra == "tests"
|
|
22
|
+
Dynamic: author
|
|
23
|
+
Dynamic: author-email
|
|
24
|
+
Dynamic: classifier
|
|
25
|
+
Dynamic: description
|
|
26
|
+
Dynamic: description-content-type
|
|
27
|
+
Dynamic: home-page
|
|
28
|
+
Dynamic: license
|
|
29
|
+
Dynamic: license-file
|
|
30
|
+
Dynamic: provides-extra
|
|
31
|
+
Dynamic: requires-dist
|
|
32
|
+
Dynamic: requires-python
|
|
33
|
+
Dynamic: summary
|
|
34
|
+
|
|
35
|
+
Image segmentation models with pre-trained backbones with Keras.
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
segmodels_keras
|
|
2
|
+
=========================
|
|
3
|
+
|
|
4
|
+
This is a fork of the
|
|
5
|
+
`segmentation_models <https://github.com/qubvel/segmentation_models>`__ library by
|
|
6
|
+
Pavel Iakubovskii, which is not maintained anymore.
|
|
7
|
+
|
|
8
|
+
This fork is updated to support latest versions of Keras and TensorFlow, and also
|
|
9
|
+
contains some bug fixes and improvements.
|
|
10
|
+
|
|
11
|
+
It is not meant as a full replacement of the original library, but rather as a
|
|
12
|
+
solution for a library I developed that depended on segmentation_models:
|
|
13
|
+
`orthoseg <https://github.com/orthoseg/orthoseg>`__.
|
|
14
|
+
|
|
15
|
+
Hence, backwards compatibility is not guaranteed, nor is it an explicit goal.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
**The main features** of this library are:
|
|
19
|
+
|
|
20
|
+
- High-level API (just two lines of code to create a segmentation model)
|
|
21
|
+
- **4** model architectures for binary and multi-class image segmentation
|
|
22
|
+
(including legendary **Unet**)
|
|
23
|
+
- **25** available backbones for each architecture
|
|
24
|
+
- All backbones have **pre-trained** weights for faster and better
|
|
25
|
+
convergence
|
|
26
|
+
- Helpful segmentation losses (Jaccard, Dice, Focal) and metrics (IoU, F-score)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
Table of Contents
|
|
30
|
+
~~~~~~~~~~~~~~~~~
|
|
31
|
+
- `Quick start`_
|
|
32
|
+
- `Simple training pipeline`_
|
|
33
|
+
- `Examples`_
|
|
34
|
+
- `Models and Backbones`_
|
|
35
|
+
- `Installation`_
|
|
36
|
+
- `Documentation`_
|
|
37
|
+
- `Change log`_
|
|
38
|
+
- `Citing`_
|
|
39
|
+
- `License`_
|
|
40
|
+
|
|
41
|
+
Quick start
|
|
42
|
+
~~~~~~~~~~~
|
|
43
|
+
The library is built to work with both the Keras and the TensorFlow Keras frameworks.
|
|
44
|
+
|
|
45
|
+
.. code:: python
|
|
46
|
+
|
|
47
|
+
import segmodels_keras as sm
|
|
48
|
+
# Segmentation Models: using `keras` framework.
|
|
49
|
+
|
|
50
|
+
By default it tries to import ``keras``; if that is not installed, it falls back to the ``tensorflow.keras`` framework.
|
|
51
|
+
There are several ways to choose framework:
|
|
52
|
+
|
|
53
|
+
- Set the environment variable ``SM_FRAMEWORK=keras`` / ``SM_FRAMEWORK=tf.keras`` before importing ``segmodels_keras``
|
|
54
|
+
- Change the framework at runtime with ``sm.set_framework('keras')`` / ``sm.set_framework('tf.keras')``
|
|
55
|
+
|
|
56
|
+
You can also specify which ``image_data_format`` to use; the library works with both ``channels_last`` and ``channels_first``.
|
|
57
|
+
This can be useful for further model conversion to Nvidia TensorRT format or optimizing model for cpu/gpu computations.
|
|
58
|
+
|
|
59
|
+
.. code:: python
|
|
60
|
+
|
|
61
|
+
import keras
|
|
62
|
+
# or from tensorflow import keras
|
|
63
|
+
|
|
64
|
+
keras.backend.set_image_data_format('channels_last')
|
|
65
|
+
# or keras.backend.set_image_data_format('channels_first')
|
|
66
|
+
|
|
67
|
+
A created segmentation model is just an instance of a Keras ``Model``, which can be built as easily as:
|
|
68
|
+
|
|
69
|
+
.. code:: python
|
|
70
|
+
|
|
71
|
+
model = sm.Unet()
|
|
72
|
+
|
|
73
|
+
Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
|
|
74
|
+
|
|
75
|
+
.. code:: python
|
|
76
|
+
|
|
77
|
+
model = sm.Unet('resnet34', encoder_weights='imagenet')
|
|
78
|
+
|
|
79
|
+
Change number of output classes in the model (choose your case):
|
|
80
|
+
|
|
81
|
+
.. code:: python
|
|
82
|
+
|
|
83
|
+
# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
|
|
84
|
+
model = sm.Unet('resnet34', classes=1, activation='sigmoid')
|
|
85
|
+
|
|
86
|
+
.. code:: python
|
|
87
|
+
|
|
88
|
+
# multiclass segmentation with non overlapping class masks (your classes + background)
|
|
89
|
+
model = sm.Unet('resnet34', classes=3, activation='softmax')
|
|
90
|
+
|
|
91
|
+
.. code:: python
|
|
92
|
+
|
|
93
|
+
# multiclass segmentation with independent overlapping/non-overlapping class masks
|
|
94
|
+
model = sm.Unet('resnet34', classes=3, activation='sigmoid')
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
Change input shape of the model:
|
|
98
|
+
|
|
99
|
+
.. code:: python
|
|
100
|
+
|
|
101
|
+
# if you set input channels not equal to 3, you have to set encoder_weights=None
|
|
102
|
+
# how to handle such case with encoder_weights='imagenet' described in docs
|
|
103
|
+
model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
|
|
104
|
+
|
|
105
|
+
Simple training pipeline
|
|
106
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
107
|
+
|
|
108
|
+
.. code:: python
|
|
109
|
+
|
|
110
|
+
import segmodels_keras as sm
|
|
111
|
+
|
|
112
|
+
BACKBONE = 'resnet34'
|
|
113
|
+
preprocess_input = sm.get_preprocessing(BACKBONE)
|
|
114
|
+
|
|
115
|
+
# load your data
|
|
116
|
+
x_train, y_train, x_val, y_val = load_data(...)
|
|
117
|
+
|
|
118
|
+
# preprocess input
|
|
119
|
+
x_train = preprocess_input(x_train)
|
|
120
|
+
x_val = preprocess_input(x_val)
|
|
121
|
+
|
|
122
|
+
# define model
|
|
123
|
+
model = sm.Unet(BACKBONE, encoder_weights='imagenet')
|
|
124
|
+
model.compile(
|
|
125
|
+
'Adam',
|
|
126
|
+
loss=sm.losses.bce_jaccard_loss,
|
|
127
|
+
metrics=[sm.metrics.iou_score],
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
# fit model
|
|
131
|
+
# if you use data generator use model.fit_generator(...) instead of model.fit(...)
|
|
132
|
+
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
|
|
133
|
+
model.fit(
|
|
134
|
+
x=x_train,
|
|
135
|
+
y=y_train,
|
|
136
|
+
batch_size=16,
|
|
137
|
+
epochs=100,
|
|
138
|
+
validation_data=(x_val, y_val),
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
Same manipulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about models API and use cases `Read the Docs <https://segmentation-models.readthedocs.io/en/latest/>`__.
|
|
142
|
+
|
|
143
|
+
Examples
|
|
144
|
+
~~~~~~~~
|
|
145
|
+
Models training examples:
|
|
146
|
+
- [Jupyter Notebook] Binary segmentation (`cars`) on CamVid dataset `here <https://github.com/qubvel/segmentation_models/blob/master/examples/binary%20segmentation%20(camvid).ipynb>`__.
|
|
147
|
+
- [Jupyter Notebook] Multi-class segmentation (`cars`, `pedestrians`) on CamVid dataset `here <https://github.com/qubvel/segmentation_models/blob/master/examples/multiclass%20segmentation%20(camvid).ipynb>`__.
|
|
148
|
+
|
|
149
|
+
Models and Backbones
|
|
150
|
+
~~~~~~~~~~~~~~~~~~~~
|
|
151
|
+
**Models**
|
|
152
|
+
|
|
153
|
+
- `Unet <https://arxiv.org/abs/1505.04597>`__
|
|
154
|
+
- `FPN <http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf>`__
|
|
155
|
+
- `Linknet <https://arxiv.org/abs/1707.03718>`__
|
|
156
|
+
- `PSPNet <https://arxiv.org/abs/1612.01105>`__
|
|
157
|
+
|
|
158
|
+
============= ==============
|
|
159
|
+
Unet Linknet
|
|
160
|
+
============= ==============
|
|
161
|
+
|unet_image| |linknet_image|
|
|
162
|
+
============= ==============
|
|
163
|
+
|
|
164
|
+
============= ==============
|
|
165
|
+
PSPNet FPN
|
|
166
|
+
============= ==============
|
|
167
|
+
|psp_image| |fpn_image|
|
|
168
|
+
============= ==============
|
|
169
|
+
|
|
170
|
+
.. _Unet: https://arxiv.org/abs/1505.04597
|
|
171
|
+
.. _Linknet: https://arxiv.org/abs/1707.03718
|
|
172
|
+
.. _PSPNet: https://arxiv.org/abs/1612.01105
|
|
173
|
+
.. _FPN: http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
|
|
174
|
+
|
|
175
|
+
.. |unet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/unet.png
|
|
176
|
+
.. |linknet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/linknet.png
|
|
177
|
+
.. |psp_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/pspnet.png
|
|
178
|
+
.. |fpn_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/fpn.png
|
|
179
|
+
|
|
180
|
+
**Backbones**
|
|
181
|
+
|
|
182
|
+
.. table::
|
|
183
|
+
|
|
184
|
+
============= =====
|
|
185
|
+
Type Names
|
|
186
|
+
============= =====
|
|
187
|
+
VGG ``'vgg16' 'vgg19'``
|
|
188
|
+
ResNet ``'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'``
|
|
189
|
+
SE-ResNet ``'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'``
|
|
190
|
+
ResNeXt ``'resnext50' 'resnext101'``
|
|
191
|
+
SE-ResNeXt ``'seresnext50' 'seresnext101'``
|
|
192
|
+
SENet154 ``'senet154'``
|
|
193
|
+
DenseNet ``'densenet121' 'densenet169' 'densenet201'``
|
|
194
|
+
Inception ``'inceptionv3' 'inceptionresnetv2'``
|
|
195
|
+
MobileNet ``'mobilenet' 'mobilenetv2'``
|
|
196
|
+
EfficientNet ``'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'``
|
|
197
|
+
============= =====
|
|
198
|
+
|
|
199
|
+
.. epigraph::
|
|
200
|
+
All backbones have weights trained on 2012 ILSVRC ImageNet dataset (``encoder_weights='imagenet'``).
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
Installation
|
|
204
|
+
~~~~~~~~~~~~
|
|
205
|
+
|
|
206
|
+
**Requirements**
|
|
207
|
+
|
|
208
|
+
1) python 3
|
|
209
|
+
2) keras >= 2.2.0 or tensorflow >= 1.13
|
|
210
|
+
3) keras-applications >= 1.0.7, <=1.0.8
|
|
211
|
+
4) image-classifiers == 1.0.*
|
|
212
|
+
5) efficientnet == 1.0.*
|
|
213
|
+
|
|
214
|
+
**PyPI stable package**
|
|
215
|
+
|
|
216
|
+
.. code:: bash
|
|
217
|
+
|
|
218
|
+
$ pip install -U segmodels-keras
|
|
219
|
+
|
|
220
|
+
**PyPI latest package**
|
|
221
|
+
|
|
222
|
+
.. code:: bash
|
|
223
|
+
|
|
224
|
+
$ pip install -U --pre segmodels-keras
|
|
225
|
+
|
|
226
|
+
**Source latest version**
|
|
227
|
+
|
|
228
|
+
.. code:: bash
|
|
229
|
+
|
|
230
|
+
$ pip install git+https://github.com/theroggy/segmodels_keras
|
|
231
|
+
|
|
232
|
+
Documentation
|
|
233
|
+
~~~~~~~~~~~~~
|
|
234
|
+
Latest **documentation** is available on `Read the
|
|
235
|
+
Docs <https://segmentation-models.readthedocs.io/en/latest/>`__
|
|
236
|
+
|
|
237
|
+
Change Log
|
|
238
|
+
~~~~~~~~~~
|
|
239
|
+
To see important changes between versions look at CHANGELOG.md_
|
|
240
|
+
|
|
241
|
+
Citing
|
|
242
|
+
~~~~~~~~
|
|
243
|
+
|
|
244
|
+
.. code::
|
|
245
|
+
|
|
246
|
+
@misc{Yakubovskiy:2019,
|
|
247
|
+
Author = {Pavel Iakubovskii},
|
|
248
|
+
Title = {Segmentation Models},
|
|
249
|
+
Year = {2019},
|
|
250
|
+
Publisher = {GitHub},
|
|
251
|
+
Journal = {GitHub repository},
|
|
252
|
+
Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
License
|
|
256
|
+
~~~~~~~
|
|
257
|
+
Project is distributed under `MIT Licence`_.
|
|
258
|
+
|
|
259
|
+
.. _CHANGELOG.md: https://github.com/qubvel/segmentation_models/blob/master/CHANGELOG.md
|
|
260
|
+
.. _`MIT Licence`: https://github.com/qubvel/segmentation_models/blob/master/LICENSE
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
[tool.pyright]
|
|
2
|
+
exclude = ["local_ignore", "**/node_modules", "**/__pycache__", "**/.*"]
|
|
3
|
+
ignore = ["*"]
|
|
4
|
+
|
|
5
|
+
[tool.pytest.ini_options]
|
|
6
|
+
log_level = "DEBUG"
|
|
7
|
+
# Write all logging and warnings immediately while running test with debugger.
|
|
8
|
+
# log_cli = true
|
|
9
|
+
# addopts = "-p no:warnings"
|
|
10
|
+
|
|
11
|
+
[tool.ruff]
|
|
12
|
+
line-length = 88
|
|
13
|
+
target-version = "py310"
|
|
14
|
+
extend-exclude = ["docs/conf.py", "local_ignore/*", "examples"]
|
|
15
|
+
|
|
16
|
+
[tool.ruff.lint]
|
|
17
|
+
select = [
|
|
18
|
+
"YTT", # flake8-2020
|
|
19
|
+
# "ANN", # flake8-annotations
|
|
20
|
+
"B", # flake8-bugbear
|
|
21
|
+
"A", # flake8-builtins
|
|
22
|
+
"C4", # flake8-comprehensions
|
|
23
|
+
"T10", # flake8-debugger
|
|
24
|
+
"ISC", # flake8-implicit string concatenation
|
|
25
|
+
# "G", # flake8-logging-format
|
|
26
|
+
"PIE", # flake8-pie misc lints
|
|
27
|
+
# "SIM", # flake8-simplify
|
|
28
|
+
"TC", # flake8-type-checking imports
|
|
29
|
+
"ARG", # flake8-unused-arguments
|
|
30
|
+
# "PTH", # flake8-use-pathlib
|
|
31
|
+
"FLY", # flynt
|
|
32
|
+
"I", # isort
|
|
33
|
+
"NPY", # NumPy-specific rules
|
|
34
|
+
"E", # pycodestyle errors
|
|
35
|
+
"W", # pycodestyle warnings
|
|
36
|
+
# "D", # pydocstyle
|
|
37
|
+
"F", # pyflakes
|
|
38
|
+
"PLC", # pylint convention
|
|
39
|
+
"PLE", # pylint error
|
|
40
|
+
"PLR", # pylint refactor
|
|
41
|
+
"PLW", # pylint warning
|
|
42
|
+
"UP", # pyupgrade
|
|
43
|
+
"RUF", # Ruff-specific rules
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
ignore = [
|
|
47
|
+
### Intentionally disabled
|
|
48
|
+
"E402", # Module level import not at top of file
|
|
49
|
+
"PLR0913", # Too many arguments to function call
|
|
50
|
+
"PLR0911", # Too many returns
|
|
51
|
+
"PLR0912", # Too many branches
|
|
52
|
+
"PLR0915", # Too many statements
|
|
53
|
+
"PLR2004", # Magic number
|
|
54
|
+
"PLR5501", # Use `elif` instead of `else` then `if`, to reduce indentation
|
|
55
|
+
"PLW2901", # Loop variable overwritten by assignment target
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
[tool.ruff.lint.per-file-ignores]
|
|
59
|
+
"tests/*" = ["ANN", "D"]
|
|
60
|
+
"perftests/*" = ["ANN", "D"]
|
|
61
|
+
|
|
62
|
+
[tool.ruff.lint.pydocstyle]
|
|
63
|
+
convention = "google"
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from . import base
|
|
5
|
+
from .__version__ import __version__
|
|
6
|
+
|
|
7
|
+
# Framework names accepted by ``set_framework`` (compared after lower-casing).
_KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME = "keras", "tf.keras"

# Framework tried first when the SM_FRAMEWORK environment variable is unset.
_DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME

# Active framework name and its submodules. All of them stay ``None`` until
# ``set_framework`` rebinds them.
_KERAS_FRAMEWORK = _KERAS_BACKEND = _KERAS_LAYERS = None
_KERAS_MODELS = _KERAS_UTILS = _KERAS_LOSSES = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def inject_global_losses(func):
    """Decorate ``func`` so it always receives the active Keras losses module.

    The wrapper overwrites any caller-supplied ``losses`` keyword argument with
    the module-level ``_KERAS_LOSSES`` value. That value is read at call time,
    not at decoration time.
    """

    @functools.wraps(func)
    def _with_losses(*args, **kwargs):
        kwargs.update(losses=_KERAS_LOSSES)
        return func(*args, **kwargs)

    return _with_losses
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def filter_kwargs(func):
    """Decorate ``func`` so only Keras-submodule keyword arguments reach it.

    Positional arguments are passed through unchanged. Of the keyword
    arguments, only ``backend``, ``layers``, ``models`` and ``utils`` are
    forwarded; every other keyword is silently dropped.
    """
    allowed = ("backend", "layers", "models", "utils")

    @functools.wraps(func)
    def _filtered(*args, **kwargs):
        kept = {key: kwargs[key] for key in kwargs if key in allowed}
        return func(*args, **kept)

    return _filtered
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def framework():
    """Return the name of the active Segmentation Models framework.

    This is the value most recently stored in ``_KERAS_FRAMEWORK`` by
    ``set_framework`` (``"keras"`` or ``"tf.keras"``). It is ``None`` if no
    framework has been set.
    """
    active = _KERAS_FRAMEWORK
    return active
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def set_framework(name):
    """Select the Keras implementation used by Segmentation Models.

    Args:
        name (str): ``keras`` or ``tf.keras``, compared case-insensitively.

    Raises:
        ValueError: if ``name`` is not one of the supported framework names.
        ImportError: if the requested framework is not installed.

    Side effects:
        Rebinds the module-level ``_KERAS_*`` globals to the chosen framework.
        Registers its ``backend``, ``layers``, ``models`` and ``utils``
        submodules with ``base.KerasObject``.
    """
    global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS
    global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK

    name = name.lower()

    # Validate first so an unknown name never triggers an import.
    if name not in (_KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME):
        raise ValueError(
            f"Not correct module name `{name}`, use `{_KERAS_FRAMEWORK_NAME}` or "
            f"`{_TF_KERAS_FRAMEWORK_NAME}`"
        )

    # Imported lazily: only the requested framework has to be installed.
    if name == _TF_KERAS_FRAMEWORK_NAME:
        from tensorflow import keras
    else:
        import keras

    _KERAS_FRAMEWORK = name
    (
        _KERAS_BACKEND,
        _KERAS_LAYERS,
        _KERAS_MODELS,
        _KERAS_UTILS,
        _KERAS_LOSSES,
    ) = (keras.backend, keras.layers, keras.models, keras.utils, keras.losses)

    # Give losses/metrics (built on base.KerasObject) access to the submodules.
    base.KerasObject.set_submodules(
        backend=_KERAS_BACKEND,
        layers=_KERAS_LAYERS,
        models=_KERAS_MODELS,
        utils=_KERAS_UTILS,
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# set default framework
|
|
89
|
+
# Choose the framework at import time. Use the SM_FRAMEWORK environment
# variable if it is set, otherwise the default (``keras``). If the requested
# framework cannot be imported, fall back to the other one. An unrecognised
# name raises ValueError from ``set_framework``.
_framework = os.environ.get("SM_FRAMEWORK", _DEFAULT_KERAS_FRAMEWORK)
# Normalise the value before comparing it below. ``set_framework`` lower-cases
# its argument, so comparing the raw value would make e.g.
# SM_FRAMEWORK="Keras" retry the same missing framework instead of falling
# back to ``tf.keras``.
_framework = _framework.strip().lower()
try:
    set_framework(_framework)
except ImportError:
    other = (
        _TF_KERAS_FRAMEWORK_NAME
        if _framework == _KERAS_FRAMEWORK_NAME
        else _KERAS_FRAMEWORK_NAME
    )
    set_framework(other)

print(f"Segmentation Models: using `{_KERAS_FRAMEWORK}` framework.")
|
|
101
|
+
|
|
102
|
+
# import helper modules
|
|
103
|
+
from . import losses, metrics, utils
|
|
104
|
+
|
|
105
|
+
# wrap segmentation models with framework modules
|
|
106
|
+
from .backbones.backbones_factory import Backbones
|
|
107
|
+
from .models.fpn import FPN
|
|
108
|
+
from .models.linknet import Linknet
|
|
109
|
+
from .models.pspnet import PSPNet
|
|
110
|
+
from .models.unet import Unet
|
|
111
|
+
|
|
112
|
+
# Public alias for ``Backbones.models_names``.
get_available_backbone_names = Backbones.models_names


def get_preprocessing(name):
    """Return the preprocessing callable for backbone ``name``.

    Thin wrapper that delegates to ``Backbones.get_preprocessing``.
    """
    preprocessing = Backbones.get_preprocessing(name)
    return preprocessing
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# Names exported by ``from segmodels_keras import *``.
__all__ = [
    # model architectures
    "FPN",
    "Linknet",
    "PSPNet",
    "Unet",
    # package metadata and framework/backbone helpers
    "__version__",
    "framework",
    "get_available_backbone_names",
    "get_preprocessing",
    # helper submodules and framework selection
    "losses",
    "metrics",
    "set_framework",
    "utils",
]
|
|
File without changes
|