litert-tunner 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litert_tunner-0.1.1/LICENSE +21 -0
- litert_tunner-0.1.1/PKG-INFO +149 -0
- litert_tunner-0.1.1/README.md +96 -0
- litert_tunner-0.1.1/pyproject.toml +181 -0
- litert_tunner-0.1.1/setup.cfg +4 -0
- litert_tunner-0.1.1/src/litert_tunner/__init__.py +46 -0
- litert_tunner-0.1.1/src/litert_tunner/export.py +100 -0
- litert_tunner-0.1.1/src/litert_tunner/flatbuffer/__init__.py +9 -0
- litert_tunner-0.1.1/src/litert_tunner/flatbuffer/parser.py +375 -0
- litert_tunner-0.1.1/src/litert_tunner/flatbuffer/writer.py +182 -0
- litert_tunner-0.1.1/src/litert_tunner/graph/__init__.py +44 -0
- litert_tunner-0.1.1/src/litert_tunner/graph/builder.py +77 -0
- litert_tunner-0.1.1/src/litert_tunner/graph/types.py +167 -0
- litert_tunner-0.1.1/src/litert_tunner/logging.py +30 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/__init__.py +68 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/add.py +292 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/concatenation.py +257 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/conv2d.py +530 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/dense.py +480 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/depthwise_conv2d.py +574 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/div.py +323 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/expand_dims.py +99 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/gelu.py +169 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/logistic.py +151 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/mean.py +251 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/mul.py +292 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/neg.py +152 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/pack.py +143 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/pool.py +139 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/quantize_op.py +281 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/registry.py +77 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/relu.py +148 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/reshape.py +136 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/resize_nearest_neighbor.py +122 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/rsqrt.py +148 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/shape_op.py +72 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/softmax.py +171 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/squared_difference.py +255 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/strided_slice.py +186 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/sub.py +291 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/tile.py +94 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/transpose.py +96 -0
- litert_tunner-0.1.1/src/litert_tunner/ops/utils.py +512 -0
- litert_tunner-0.1.1/src/litert_tunner/testing_utils.py +89 -0
- litert_tunner-0.1.1/src/litert_tunner/trainer.py +225 -0
- litert_tunner-0.1.1/src/litert_tunner.egg-info/PKG-INFO +149 -0
- litert_tunner-0.1.1/src/litert_tunner.egg-info/SOURCES.txt +54 -0
- litert_tunner-0.1.1/src/litert_tunner.egg-info/dependency_links.txt +1 -0
- litert_tunner-0.1.1/src/litert_tunner.egg-info/requires.txt +22 -0
- litert_tunner-0.1.1/src/litert_tunner.egg-info/top_level.txt +1 -0
- litert_tunner-0.1.1/tests/test_export.py +158 -0
- litert_tunner-0.1.1/tests/test_finetuning_e2e.py +99 -0
- litert_tunner-0.1.1/tests/test_load_save_roundtrip.py +81 -0
- litert_tunner-0.1.1/tests/test_logging.py +22 -0
- litert_tunner-0.1.1/tests/test_testing_utils.py +84 -0
- litert_tunner-0.1.1/tests/test_trainer.py +102 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Krzysztof Kolasinski
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: litert_tunner
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: LiteRT tuner Python library.
|
|
5
|
+
Author: litert-tunner contributors
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2026 Krzysztof Kolasinski
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Requires-Python: >=3.11
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
License-File: LICENSE
|
|
31
|
+
Requires-Dist: keras>=3.0
|
|
32
|
+
Requires-Dist: numpy
|
|
33
|
+
Requires-Dist: tflite>=2.18.0
|
|
34
|
+
Requires-Dist: flatbuffers
|
|
35
|
+
Provides-Extra: dev
|
|
36
|
+
Requires-Dist: tensorflow>=2.18.0; extra == "dev"
|
|
37
|
+
Requires-Dist: ai-edge-litert; extra == "dev"
|
|
38
|
+
Requires-Dist: bump-my-version>=0.26.0; extra == "dev"
|
|
39
|
+
Requires-Dist: pre-commit>=3.7.0; extra == "dev"
|
|
40
|
+
Requires-Dist: pytest>=8.2.0; extra == "dev"
|
|
41
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
42
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
|
43
|
+
Requires-Dist: pytest-sugar; extra == "dev"
|
|
44
|
+
Requires-Dist: pytest-watch; extra == "dev"
|
|
45
|
+
Requires-Dist: pytest-forked; extra == "dev"
|
|
46
|
+
Requires-Dist: pytest-randomly; extra == "dev"
|
|
47
|
+
Requires-Dist: pytest-xdist; extra == "dev"
|
|
48
|
+
Requires-Dist: pytest-dotenv; extra == "dev"
|
|
49
|
+
Requires-Dist: types-tensorflow; extra == "dev"
|
|
50
|
+
Requires-Dist: ruff>=0.5.0; extra == "dev"
|
|
51
|
+
Requires-Dist: ty>=0.0.1a8; extra == "dev"
|
|
52
|
+
Dynamic: license-file
|
|
53
|
+
|
|
54
|
+
# litert-tunner
|
|
55
|
+
|
|
56
|
+
Fine-tune fully-quantized **INT8 LiteRT** (TFLite) models *after* export — without retraining from scratch.
|
|
57
|
+
|
|
58
|
+
## What is this?
|
|
59
|
+
|
|
60
|
+
**litert-tunner** parses an exported INT8 `.tflite` model, reconstructs it as a Keras 3 model with simulated quantization, lets you fine-tune parameters (biases, scales, weights), and writes them back into the flatbuffer. Graph topology stays intact.
|
|
61
|
+
|
|
62
|
+
> **Note:** Even though this library simulates the quantization process during fine-tuning, in some cases the resulting INT8 model may still perform worse than a full-precision float32 model.
|
|
63
|
+
|
|
64
|
+
## Installation
|
|
65
|
+
|
|
66
|
+
### From source (recommended for development)
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# Clone the repository
|
|
70
|
+
git clone https://github.com/kmkolasinski/litert-tunner.git
|
|
71
|
+
cd litert-tunner
|
|
72
|
+
|
|
73
|
+
# Create a virtual environment and install uv
|
|
74
|
+
make venv
|
|
75
|
+
make init
|
|
76
|
+
|
|
77
|
+
# Install the project in editable mode with dev dependencies
|
|
78
|
+
make install
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### As a dependency (pip)
|
|
82
|
+
|
|
83
|
+
```bash
|
|
84
|
+
pip install litert-tunner
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
*(Core dependencies: `keras>=3.0`, `numpy`, `tflite>=2.18.0`, `flatbuffers`. `tensorflow` is additionally required to use the `litert_tunner.export` module; otherwise `tensorflow` and `ai-edge-litert` are only needed for dev/tests.)*
|
|
88
|
+
|
|
89
|
+
## Quickstart
|
|
90
|
+
|
|
91
|
+
For a complete, runnable end-to-end example, see the **[Example Notebook](notebooks/quickstart_finetuning.ipynb)**.
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
import litert_tunner
|
|
95
|
+
|
|
96
|
+
# 1. Load an INT8 LiteRT model → trainable Keras replica
|
|
97
|
+
model = litert_tunner.load_model("model_int8.tflite")
|
|
98
|
+
|
|
99
|
+
# 2. Inference — should match LiteRT Interpreter output
|
|
100
|
+
predictions = model.predict(inputs)
|
|
101
|
+
|
|
102
|
+
# 3. Prepare for fine-tuning (freeze everything except parameters matching ".*bias", i.e. biases)
|
|
103
|
+
litert_tunner.prepare_for_finetuning(model, trainable_pattern=".*bias")
|
|
104
|
+
|
|
105
|
+
# 4. Fine-tune with any Keras optimizer / loss
|
|
106
|
+
model.compile(optimizer="adam", loss="mse")
|
|
107
|
+
model.fit(x_train, y_train, epochs=5)
|
|
108
|
+
|
|
109
|
+
# 5. Export — writes updated parameters back into the flatbuffer
|
|
110
|
+
litert_tunner.save_model(model, "model_int8_finetuned.tflite")
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### Distillation Fine-Tuning (Recommended)
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
import litert_tunner
|
|
117
|
+
|
|
118
|
+
# 1. Load an INT8 LiteRT model → trainable Keras replica
|
|
119
|
+
tunner_model = litert_tunner.load_model("model_int8.tflite")
|
|
120
|
+
|
|
121
|
+
# 2. Freeze everything except biases
|
|
122
|
+
litert_tunner.prepare_for_finetuning(tunner_model, trainable_pattern=".*bias")
|
|
123
|
+
|
|
124
|
+
# 3. Fine-tune using Trainer (handles distillation & weight drift)
|
|
125
|
+
trainer = litert_tunner.Trainer(
|
|
126
|
+
student_model=tunner_model,
|
|
127
|
+
teacher_model=teacher_model, # Original float32 model
|
|
128
|
+
)
|
|
129
|
+
trainer.compile(optimizer="adam", loss="mse")
|
|
130
|
+
trainer.fit(train_ds, validation_data=val_ds, epochs=5)
|
|
131
|
+
|
|
132
|
+
# 4. Save updated parameters to flatbuffer
|
|
133
|
+
litert_tunner.save_model(tunner_model, "model_int8_finetuned.tflite")
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
## Supported Operations
|
|
137
|
+
|
|
138
|
+
- **Linear:** `FULLY_CONNECTED`, `CONV_2D`, `DEPTHWISE_CONV_2D`
|
|
139
|
+
- **Arithmetic:** `ADD`, `SUB`, `MUL`, `DIV`, `SQUARED_DIFFERENCE`, `NEG`
|
|
140
|
+
- **Activation:** `RELU`, `GELU`, `LOGISTIC`, `SOFTMAX`
|
|
141
|
+
- **Pooling:** `AVERAGE_POOL_2D`, `MAX_POOL_2D`, `MEAN`
|
|
142
|
+
- **Reshape/Resize:** `RESHAPE`, `TRANSPOSE`, `PACK`, `STRIDED_SLICE`, `RESIZE_NEAREST_NEIGHBOR`, `SHAPE`
|
|
143
|
+
- **Other:** `CONCATENATION`, `RSQRT`, `QUANTIZE`, `DEQUANTIZE`
|
|
144
|
+
|
|
145
|
+
Fused activations (`RELU`, `RELU6`, `RELU_N1_TO_1`) are supported.
|
|
146
|
+
|
|
147
|
+
## License
|
|
148
|
+
|
|
149
|
+
[MIT](LICENSE) — Copyright © 2026 Krzysztof Kolasinski
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# litert-tunner
|
|
2
|
+
|
|
3
|
+
Fine-tune fully-quantized **INT8 LiteRT** (TFLite) models *after* export — without retraining from scratch.
|
|
4
|
+
|
|
5
|
+
## What is this?
|
|
6
|
+
|
|
7
|
+
**litert-tunner** parses an exported INT8 `.tflite` model, reconstructs it as a Keras 3 model with simulated quantization, lets you fine-tune parameters (biases, scales, weights), and writes them back into the flatbuffer. Graph topology stays intact.
|
|
8
|
+
|
|
9
|
+
> **Note:** Even though this library simulates the quantization process during fine-tuning, in some cases the resulting INT8 model may still perform worse than a full-precision float32 model.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
|
|
13
|
+
### From source (recommended for development)
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
# Clone the repository
|
|
17
|
+
git clone https://github.com/kmkolasinski/litert-tunner.git
|
|
18
|
+
cd litert-tunner
|
|
19
|
+
|
|
20
|
+
# Create a virtual environment and install uv
|
|
21
|
+
make venv
|
|
22
|
+
make init
|
|
23
|
+
|
|
24
|
+
# Install the project in editable mode with dev dependencies
|
|
25
|
+
make install
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
### As a dependency (pip)
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
pip install litert-tunner
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
*(Core dependencies: `keras>=3.0`, `numpy`, `tflite>=2.18.0`, `flatbuffers`. `tensorflow` is additionally required to use the `litert_tunner.export` module; otherwise `tensorflow` and `ai-edge-litert` are only needed for dev/tests.)*
|
|
35
|
+
|
|
36
|
+
## Quickstart
|
|
37
|
+
|
|
38
|
+
For a complete, runnable end-to-end example, see the **[Example Notebook](notebooks/quickstart_finetuning.ipynb)**.
|
|
39
|
+
|
|
40
|
+
```python
|
|
41
|
+
import litert_tunner
|
|
42
|
+
|
|
43
|
+
# 1. Load an INT8 LiteRT model → trainable Keras replica
|
|
44
|
+
model = litert_tunner.load_model("model_int8.tflite")
|
|
45
|
+
|
|
46
|
+
# 2. Inference — should match LiteRT Interpreter output
|
|
47
|
+
predictions = model.predict(inputs)
|
|
48
|
+
|
|
49
|
+
# 3. Prepare for fine-tuning (freeze everything except parameters matching ".*bias", i.e. biases)
|
|
50
|
+
litert_tunner.prepare_for_finetuning(model, trainable_pattern=".*bias")
|
|
51
|
+
|
|
52
|
+
# 4. Fine-tune with any Keras optimizer / loss
|
|
53
|
+
model.compile(optimizer="adam", loss="mse")
|
|
54
|
+
model.fit(x_train, y_train, epochs=5)
|
|
55
|
+
|
|
56
|
+
# 5. Export — writes updated parameters back into the flatbuffer
|
|
57
|
+
litert_tunner.save_model(model, "model_int8_finetuned.tflite")
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
### Distillation Fine-Tuning (Recommended)
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
import litert_tunner
|
|
64
|
+
|
|
65
|
+
# 1. Load an INT8 LiteRT model → trainable Keras replica
|
|
66
|
+
tunner_model = litert_tunner.load_model("model_int8.tflite")
|
|
67
|
+
|
|
68
|
+
# 2. Freeze everything except biases
|
|
69
|
+
litert_tunner.prepare_for_finetuning(tunner_model, trainable_pattern=".*bias")
|
|
70
|
+
|
|
71
|
+
# 3. Fine-tune using Trainer (handles distillation & weight drift)
|
|
72
|
+
trainer = litert_tunner.Trainer(
|
|
73
|
+
student_model=tunner_model,
|
|
74
|
+
teacher_model=teacher_model, # Original float32 model
|
|
75
|
+
)
|
|
76
|
+
trainer.compile(optimizer="adam", loss="mse")
|
|
77
|
+
trainer.fit(train_ds, validation_data=val_ds, epochs=5)
|
|
78
|
+
|
|
79
|
+
# 4. Save updated parameters to flatbuffer
|
|
80
|
+
litert_tunner.save_model(tunner_model, "model_int8_finetuned.tflite")
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## Supported Operations
|
|
84
|
+
|
|
85
|
+
- **Linear:** `FULLY_CONNECTED`, `CONV_2D`, `DEPTHWISE_CONV_2D`
|
|
86
|
+
- **Arithmetic:** `ADD`, `SUB`, `MUL`, `DIV`, `SQUARED_DIFFERENCE`, `NEG`
|
|
87
|
+
- **Activation:** `RELU`, `GELU`, `LOGISTIC`, `SOFTMAX`
|
|
88
|
+
- **Pooling:** `AVERAGE_POOL_2D`, `MAX_POOL_2D`, `MEAN`
|
|
89
|
+
- **Reshape/Resize:** `RESHAPE`, `TRANSPOSE`, `PACK`, `STRIDED_SLICE`, `RESIZE_NEAREST_NEIGHBOR`, `SHAPE`
|
|
90
|
+
- **Other:** `CONCATENATION`, `RSQRT`, `QUANTIZE`, `DEQUANTIZE`
|
|
91
|
+
|
|
92
|
+
Fused activations (`RELU`, `RELU6`, `RELU_N1_TO_1`) are supported.
|
|
93
|
+
|
|
94
|
+
## License
|
|
95
|
+
|
|
96
|
+
[MIT](LICENSE) — Copyright © 2026 Krzysztof Kolasinski
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=69", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "litert_tunner"
|
|
7
|
+
version = "0.1.1"
|
|
8
|
+
description = "LiteRT tuner Python library."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = { file = "LICENSE" }
|
|
12
|
+
authors = [{ name = "litert-tunner contributors" }]
|
|
13
|
+
dependencies = [
|
|
14
|
+
"keras>=3.0",
|
|
15
|
+
"numpy",
|
|
16
|
+
"tflite>=2.18.0",
|
|
17
|
+
"flatbuffers",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[project.optional-dependencies]
|
|
21
|
+
dev = [
|
|
22
|
+
"tensorflow>=2.18.0",
|
|
23
|
+
"ai-edge-litert",
|
|
24
|
+
"bump-my-version>=0.26.0",
|
|
25
|
+
"pre-commit>=3.7.0",
|
|
26
|
+
"pytest>=8.2.0",
|
|
27
|
+
"pytest-cov",
|
|
28
|
+
"pytest-mock",
|
|
29
|
+
"pytest-sugar",
|
|
30
|
+
"pytest-watch",
|
|
31
|
+
"pytest-forked",
|
|
32
|
+
"pytest-randomly",
|
|
33
|
+
"pytest-xdist",
|
|
34
|
+
"pytest-dotenv",
|
|
35
|
+
"types-tensorflow",
|
|
36
|
+
"ruff>=0.5.0",
|
|
37
|
+
"ty>=0.0.1a8",
|
|
38
|
+
]
|
|
39
|
+
|
|
40
|
+
[tool.setuptools]
|
|
41
|
+
package-dir = { "" = "src" }
|
|
42
|
+
|
|
43
|
+
[tool.setuptools.packages.find]
|
|
44
|
+
where = ["src"]
|
|
45
|
+
|
|
46
|
+
[tool.pytest.ini_options]
|
|
47
|
+
testpaths = [
|
|
48
|
+
"tests",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
# Load environment variables from pytest.env file
|
|
52
|
+
env_files = [
|
|
53
|
+
"pytest.env",
|
|
54
|
+
]
|
|
55
|
+
|
|
56
|
+
addopts = [
|
|
57
|
+
"-vv",
|
|
58
|
+
"--junit-xml=test-results/junit.xml",
|
|
59
|
+
"--randomly-seed=32",
|
|
60
|
+
]
|
|
61
|
+
|
|
62
|
+
[tool.coverage.run]
|
|
63
|
+
parallel = true
|
|
64
|
+
patch = ["_exit"]
|
|
65
|
+
data_file = "test-results/.coverage"
|
|
66
|
+
|
|
67
|
+
[tool.ruff]
|
|
68
|
+
target-version = "py311"
|
|
69
|
+
line-length = 100
|
|
70
|
+
indent-width = 4
|
|
71
|
+
exclude = [
|
|
72
|
+
"metadata_pb2.py",
|
|
73
|
+
]
|
|
74
|
+
|
|
75
|
+
# Apply fixes automatically
|
|
76
|
+
fix = true
|
|
77
|
+
|
|
78
|
+
[tool.ruff.format]
|
|
79
|
+
preview = true
|
|
80
|
+
# Like Black, use double quotes for strings.
|
|
81
|
+
quote-style = "double"
|
|
82
|
+
|
|
83
|
+
# Like Black, indent with spaces, rather than tabs.
|
|
84
|
+
indent-style = "space"
|
|
85
|
+
|
|
86
|
+
# Like Black, respect magic trailing commas.
|
|
87
|
+
skip-magic-trailing-comma = false
|
|
88
|
+
|
|
89
|
+
# Like Black, automatically detect the appropriate line ending.
|
|
90
|
+
line-ending = "auto"
|
|
91
|
+
|
|
92
|
+
# By default, Ruff enables Flake8's F rules, along with a subset of the E rules,
|
|
93
|
+
# omitting stylistic rules that overlap with a formatter; this project instead selects ALL rules and opts out via `ignore`.
|
|
94
|
+
# For more details see https://docs.astral.sh/ruff/rules/
|
|
95
|
+
[tool.ruff.lint]
|
|
96
|
+
exclude = ["*.ipynb"]
|
|
97
|
+
select = ["ALL"]
|
|
98
|
+
|
|
99
|
+
ignore = [
|
|
100
|
+
# Mostly Missing docstring errors - most of them are False Positives
|
|
101
|
+
"D100", "D101", "D102", "D103", "D104", "D105", "D107", "D205",
|
|
102
|
+
"ANN204", # Missing type annotation for self in method
|
|
103
|
+
"ANN201", # Missing return type annotation for public function
|
|
104
|
+
"ANN002", "ANN003", # Missing type annotation for `*args` or `**kwargs`
|
|
105
|
+
"COM812", # Conflicts with another rule
|
|
106
|
+
"TD002", "TD003", "FIX002", # Todos related formatting rules
|
|
107
|
+
"TRY003", # Avoid specifying long messages outside the exception class
|
|
108
|
+
"EM101", # Exception must not use a string literal, assign to variable first
|
|
109
|
+
"G004", # Logging uses f-string formatting
|
|
110
|
+
"PLR0913", # Too many arguments in function - some functions are complex by nature
|
|
111
|
+
]
|
|
112
|
+
|
|
113
|
+
[tool.ruff.lint.per-file-ignores]
|
|
114
|
+
"src/litert_tunner/testing_utils.py" = [
|
|
115
|
+
"S101", # Use of assert detected
|
|
116
|
+
]
|
|
117
|
+
"tests/*" = [
|
|
118
|
+
"SLF001", # Private member accessed
|
|
119
|
+
"PERF", # Performance checks are disabled for tests
|
|
120
|
+
"PLR2004", # Magic value used in comparison
|
|
121
|
+
"ARG001", # Unused argument
|
|
122
|
+
"PT019", # Fixture `_` without value is injected as parameter
|
|
123
|
+
"S101", # Use of assert detected
|
|
124
|
+
"ANN", # Missing type annotation in tests is allowed
|
|
125
|
+
"ISC001", # single-line-implicit-string-concatenation - conflicts with another rule
|
|
126
|
+
"PT011", # ValueError is too broad
|
|
127
|
+
"FBT", # Boolean positional arguments are allowed in tests
|
|
128
|
+
"C901", # Complexity checks are disabled for test code
|
|
129
|
+
"PLR0912", # Branch complexity checks are disabled for test code
|
|
130
|
+
"PLR0915", # Statement complexity checks are disabled for test code
|
|
131
|
+
"S108", # Insecure temp files are allowed in test code
|
|
132
|
+
]
|
|
133
|
+
|
|
134
|
+
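# Only allow autofix for these import-related and docstring rules.
# NOTE(review): this key currently sits under [tool.ruff.lint.per-file-ignores], so TOML
# places it in that table (a per-file glob named "fixable"), not the lint `fixable` setting;
# move it under [tool.ruff.lint] for the restriction to take effect.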
|
|
135
|
+
fixable = ["F401", "F403", "I001", "D411", "D415", "D208", "D209"]
|
|
136
|
+
|
|
137
|
+
[tool.ruff.lint.pydocstyle]
|
|
138
|
+
convention = "google"
|
|
139
|
+
|
|
140
|
+
[tool.bumpversion]
|
|
141
|
+
current_version = "0.1.1"
|
|
142
|
+
commit = true
|
|
143
|
+
tag = true
|
|
144
|
+
tag_name = "v{new_version}"
|
|
145
|
+
allow_dirty = false
|
|
146
|
+
|
|
147
|
+
[[tool.bumpversion.files]]
|
|
148
|
+
filename = "pyproject.toml"
|
|
149
|
+
search = 'version = "{current_version}"'
|
|
150
|
+
replace = 'version = "{new_version}"'
|
|
151
|
+
|
|
152
|
+
[[tool.bumpversion.files]]
|
|
153
|
+
filename = "src/litert_tunner/__init__.py"
|
|
154
|
+
search = '__version__ = "{current_version}"'
|
|
155
|
+
replace = '__version__ = "{new_version}"'
|
|
156
|
+
|
|
157
|
+
[tool.interrogate]
|
|
158
|
+
ignore-init-method = true
|
|
159
|
+
ignore-init-module = true
|
|
160
|
+
ignore-magic = false
|
|
161
|
+
ignore-semiprivate = false
|
|
162
|
+
ignore-private = false
|
|
163
|
+
ignore-property-decorators = false
|
|
164
|
+
ignore-module = true
|
|
165
|
+
ignore-nested-functions = false
|
|
166
|
+
ignore-nested-classes = true
|
|
167
|
+
ignore-overloaded-functions = true
|
|
168
|
+
ignore-setters = false
|
|
169
|
+
fail-under = 80
|
|
170
|
+
exclude = ["setup.py", "docs", "build"]
|
|
171
|
+
verbose = 1
|
|
172
|
+
quiet = false
|
|
173
|
+
whitelist-regex = []
|
|
174
|
+
color = true
|
|
175
|
+
omit-covered-files = false
|
|
176
|
+
|
|
177
|
+
[tool.pyright]
|
|
178
|
+
include = ["src", "tests"]
|
|
179
|
+
extraPaths = ["."]
|
|
180
|
+
venvPath = "."
|
|
181
|
+
venv = ".venv"
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""litert_tunner package."""
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
|
|
5
|
+
from litert_tunner import flatbuffer, graph
|
|
6
|
+
from litert_tunner.testing_utils import (
|
|
7
|
+
assert_allclose_with_mismatch_tolerance,
|
|
8
|
+
assert_cosine_similarity,
|
|
9
|
+
)
|
|
10
|
+
from litert_tunner.trainer import Trainer, cosine_similarity, prepare_for_finetuning
|
|
11
|
+
|
|
12
|
+
# Single source of the runtime version string. bump-my-version rewrites this exact
# line (see the [[tool.bumpversion.files]] entry in pyproject.toml), so keep the
# `__version__ = "X.Y.Z"` form intact.
__version__ = "0.1.1"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_model(path: str) -> keras.Model:
    """Build a trainable Keras replica of an INT8 .tflite model.

    The flatbuffer at ``path`` is first parsed into a graph definition, which is
    then turned into a Keras model.

    Args:
        path: Filesystem location of the .tflite file to load.

    Returns:
        The Keras Model replica reconstructed from the parsed graph.
    """
    return graph.build_keras_model(flatbuffer.parse_tflite(path))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def save_model(model: keras.Model, path: str) -> None:
    """Write a fine-tuned Keras replica's parameters out as a .tflite file.

    Thin wrapper that delegates serialization to ``flatbuffer.save_tflite``.

    Args:
        model: Keras replica whose updated parameters should be serialized
            (typically the model returned by ``load_model``).
        path: Destination path for the resulting .tflite flatbuffer.
    """
    flatbuffer.save_tflite(model, path)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Names re-exported by ``from litert_tunner import *`` (kept alphabetically sorted).
__all__ = [
    "Trainer",
    "assert_allclose_with_mismatch_tolerance",
    "assert_cosine_similarity",
    "cosine_similarity",
    "load_model",
    "prepare_for_finetuning",
    "save_model",
]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterable
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Literal, TypeAlias
|
|
4
|
+
|
|
5
|
+
import keras
|
|
6
|
+
import numpy as np
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
|
|
9
|
+
# One calibration sample: either a list of input arrays or a dict of named input arrays.
_RepresentativeSample: TypeAlias = list[np.ndarray] | dict[str, np.ndarray]

# Zero-argument factory producing calibration samples; it is assigned to
# ``TFLiteConverter.representative_dataset`` during int8 export.
RepresentativeDataset: TypeAlias = Callable[
    [], tf.data.Dataset | Iterable[_RepresentativeSample]
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def export_litert_model(
|
|
15
|
+
model: keras.Model,
|
|
16
|
+
export_dir: str | Path,
|
|
17
|
+
*,
|
|
18
|
+
quantization: Literal["int8", "float32"] = "int8",
|
|
19
|
+
float_io: bool = True,
|
|
20
|
+
representative_dataset: RepresentativeDataset | None = None,
|
|
21
|
+
run_debugger: bool = False,
|
|
22
|
+
denylisted_ops: list[str] | None = None,
|
|
23
|
+
denylisted_nodes: list[str] | None = None,
|
|
24
|
+
) -> tuple[Path, "tf.lite.experimental.QuantizationDebugger | None"]:
|
|
25
|
+
"""Export a Keras model to a TFLite model.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
model: The Keras model to export.
|
|
29
|
+
export_dir: Directory where the exported model and stats will be saved.
|
|
30
|
+
quantization: The quantization mode. Either "int8" or "float32". Defaults to "int8".
|
|
31
|
+
float_io: If True, use float32 inputs/outputs even for int8 quantization. Defaults to True.
|
|
32
|
+
representative_dataset: A generator yielding sample inputs for quantization.
|
|
33
|
+
Required if quantization is "int8".
|
|
34
|
+
run_debugger: If True, run the TFLite QuantizationDebugger and save stats.
|
|
35
|
+
Only supported for int8 quantization. Defaults to False.
|
|
36
|
+
denylisted_ops: List of ops to exclude from quantization.
|
|
37
|
+
denylisted_nodes: List of nodes to exclude from quantization.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
A tuple containing the path to the exported .tflite model and the
|
|
41
|
+
QuantizationDebugger instance (if run_debugger is True, otherwise None).
|
|
42
|
+
"""
|
|
43
|
+
if (denylisted_ops or denylisted_nodes) and not run_debugger:
|
|
44
|
+
raise ValueError(
|
|
45
|
+
"denylisted_ops and denylisted_nodes can only be provided when run_debugger is True"
|
|
46
|
+
)
|
|
47
|
+
if quantization == "int8" and representative_dataset is None:
|
|
48
|
+
raise ValueError("representative_dataset must be provided for int8 quantization")
|
|
49
|
+
if quantization == "float32" and run_debugger:
|
|
50
|
+
raise ValueError("Debugger only supported for int8 quantization")
|
|
51
|
+
if quantization not in ("int8", "float32"):
|
|
52
|
+
msg = f"Unknown quantization: {quantization}"
|
|
53
|
+
raise ValueError(msg)
|
|
54
|
+
|
|
55
|
+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
|
56
|
+
|
|
57
|
+
match quantization:
|
|
58
|
+
case "int8":
|
|
59
|
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
|
60
|
+
converter.representative_dataset = representative_dataset
|
|
61
|
+
if denylisted_ops or denylisted_nodes:
|
|
62
|
+
converter.target_spec.supported_ops = [
|
|
63
|
+
tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
|
|
64
|
+
tf.lite.OpsSet.TFLITE_BUILTINS,
|
|
65
|
+
]
|
|
66
|
+
else:
|
|
67
|
+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
|
68
|
+
case "float32":
|
|
69
|
+
pass
|
|
70
|
+
|
|
71
|
+
io_type = tf.float32 if float_io or quantization == "float32" else tf.int8
|
|
72
|
+
converter.inference_input_type = io_type
|
|
73
|
+
converter.inference_output_type = io_type
|
|
74
|
+
|
|
75
|
+
export_dir = Path(export_dir)
|
|
76
|
+
export_dir.mkdir(parents=True, exist_ok=True)
|
|
77
|
+
|
|
78
|
+
debugger = None
|
|
79
|
+
if run_debugger:
|
|
80
|
+
debug_options = tf.lite.experimental.QuantizationDebugOptions(
|
|
81
|
+
denylisted_ops=denylisted_ops,
|
|
82
|
+
denylisted_nodes=denylisted_nodes,
|
|
83
|
+
)
|
|
84
|
+
debugger = tf.lite.experimental.QuantizationDebugger(
|
|
85
|
+
converter=converter,
|
|
86
|
+
debug_dataset=representative_dataset,
|
|
87
|
+
debug_options=debug_options,
|
|
88
|
+
)
|
|
89
|
+
debugger.run()
|
|
90
|
+
tflite_model = debugger.get_nondebug_quantized_model()
|
|
91
|
+
|
|
92
|
+
with (export_dir / "quantization_stats.csv").open("w") as f:
|
|
93
|
+
debugger.layer_statistics_dump(f)
|
|
94
|
+
else:
|
|
95
|
+
tflite_model = converter.convert()
|
|
96
|
+
|
|
97
|
+
tflite_model_filepath = export_dir / "model.tflite"
|
|
98
|
+
tflite_model_filepath.write_bytes(tflite_model)
|
|
99
|
+
|
|
100
|
+
return tflite_model_filepath, debugger
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Flatbuffer module for litert_tunner.
|
|
2
|
+
|
|
3
|
+
Exposes parse_tflite and save_tflite.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from litert_tunner.flatbuffer.parser import parse_tflite as parse_tflite
|
|
9
|
+
from litert_tunner.flatbuffer.writer import save_tflite as save_tflite
|