litert-tunner 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. litert_tunner-0.1.1/LICENSE +21 -0
  2. litert_tunner-0.1.1/PKG-INFO +149 -0
  3. litert_tunner-0.1.1/README.md +96 -0
  4. litert_tunner-0.1.1/pyproject.toml +181 -0
  5. litert_tunner-0.1.1/setup.cfg +4 -0
  6. litert_tunner-0.1.1/src/litert_tunner/__init__.py +46 -0
  7. litert_tunner-0.1.1/src/litert_tunner/export.py +100 -0
  8. litert_tunner-0.1.1/src/litert_tunner/flatbuffer/__init__.py +9 -0
  9. litert_tunner-0.1.1/src/litert_tunner/flatbuffer/parser.py +375 -0
  10. litert_tunner-0.1.1/src/litert_tunner/flatbuffer/writer.py +182 -0
  11. litert_tunner-0.1.1/src/litert_tunner/graph/__init__.py +44 -0
  12. litert_tunner-0.1.1/src/litert_tunner/graph/builder.py +77 -0
  13. litert_tunner-0.1.1/src/litert_tunner/graph/types.py +167 -0
  14. litert_tunner-0.1.1/src/litert_tunner/logging.py +30 -0
  15. litert_tunner-0.1.1/src/litert_tunner/ops/__init__.py +68 -0
  16. litert_tunner-0.1.1/src/litert_tunner/ops/add.py +292 -0
  17. litert_tunner-0.1.1/src/litert_tunner/ops/concatenation.py +257 -0
  18. litert_tunner-0.1.1/src/litert_tunner/ops/conv2d.py +530 -0
  19. litert_tunner-0.1.1/src/litert_tunner/ops/dense.py +480 -0
  20. litert_tunner-0.1.1/src/litert_tunner/ops/depthwise_conv2d.py +574 -0
  21. litert_tunner-0.1.1/src/litert_tunner/ops/div.py +323 -0
  22. litert_tunner-0.1.1/src/litert_tunner/ops/expand_dims.py +99 -0
  23. litert_tunner-0.1.1/src/litert_tunner/ops/gelu.py +169 -0
  24. litert_tunner-0.1.1/src/litert_tunner/ops/logistic.py +151 -0
  25. litert_tunner-0.1.1/src/litert_tunner/ops/mean.py +251 -0
  26. litert_tunner-0.1.1/src/litert_tunner/ops/mul.py +292 -0
  27. litert_tunner-0.1.1/src/litert_tunner/ops/neg.py +152 -0
  28. litert_tunner-0.1.1/src/litert_tunner/ops/pack.py +143 -0
  29. litert_tunner-0.1.1/src/litert_tunner/ops/pool.py +139 -0
  30. litert_tunner-0.1.1/src/litert_tunner/ops/quantize_op.py +281 -0
  31. litert_tunner-0.1.1/src/litert_tunner/ops/registry.py +77 -0
  32. litert_tunner-0.1.1/src/litert_tunner/ops/relu.py +148 -0
  33. litert_tunner-0.1.1/src/litert_tunner/ops/reshape.py +136 -0
  34. litert_tunner-0.1.1/src/litert_tunner/ops/resize_nearest_neighbor.py +122 -0
  35. litert_tunner-0.1.1/src/litert_tunner/ops/rsqrt.py +148 -0
  36. litert_tunner-0.1.1/src/litert_tunner/ops/shape_op.py +72 -0
  37. litert_tunner-0.1.1/src/litert_tunner/ops/softmax.py +171 -0
  38. litert_tunner-0.1.1/src/litert_tunner/ops/squared_difference.py +255 -0
  39. litert_tunner-0.1.1/src/litert_tunner/ops/strided_slice.py +186 -0
  40. litert_tunner-0.1.1/src/litert_tunner/ops/sub.py +291 -0
  41. litert_tunner-0.1.1/src/litert_tunner/ops/tile.py +94 -0
  42. litert_tunner-0.1.1/src/litert_tunner/ops/transpose.py +96 -0
  43. litert_tunner-0.1.1/src/litert_tunner/ops/utils.py +512 -0
  44. litert_tunner-0.1.1/src/litert_tunner/testing_utils.py +89 -0
  45. litert_tunner-0.1.1/src/litert_tunner/trainer.py +225 -0
  46. litert_tunner-0.1.1/src/litert_tunner.egg-info/PKG-INFO +149 -0
  47. litert_tunner-0.1.1/src/litert_tunner.egg-info/SOURCES.txt +54 -0
  48. litert_tunner-0.1.1/src/litert_tunner.egg-info/dependency_links.txt +1 -0
  49. litert_tunner-0.1.1/src/litert_tunner.egg-info/requires.txt +22 -0
  50. litert_tunner-0.1.1/src/litert_tunner.egg-info/top_level.txt +1 -0
  51. litert_tunner-0.1.1/tests/test_export.py +158 -0
  52. litert_tunner-0.1.1/tests/test_finetuning_e2e.py +99 -0
  53. litert_tunner-0.1.1/tests/test_load_save_roundtrip.py +81 -0
  54. litert_tunner-0.1.1/tests/test_logging.py +22 -0
  55. litert_tunner-0.1.1/tests/test_testing_utils.py +84 -0
  56. litert_tunner-0.1.1/tests/test_trainer.py +102 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Krzysztof Kolasinski
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,149 @@
1
+ Metadata-Version: 2.4
2
+ Name: litert_tunner
3
+ Version: 0.1.1
4
+ Summary: LiteRT tuner Python library.
5
+ Author: litert-tunner contributors
6
+ License: MIT License
7
+
8
+ Copyright (c) 2026 Krzysztof Kolasinski
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Requires-Python: >=3.11
29
+ Description-Content-Type: text/markdown
30
+ License-File: LICENSE
31
+ Requires-Dist: keras>=3.0
32
+ Requires-Dist: numpy
33
+ Requires-Dist: tflite>=2.18.0
34
+ Requires-Dist: flatbuffers
35
+ Provides-Extra: dev
36
+ Requires-Dist: tensorflow>=2.18.0; extra == "dev"
37
+ Requires-Dist: ai-edge-litert; extra == "dev"
38
+ Requires-Dist: bump-my-version>=0.26.0; extra == "dev"
39
+ Requires-Dist: pre-commit>=3.7.0; extra == "dev"
40
+ Requires-Dist: pytest>=8.2.0; extra == "dev"
41
+ Requires-Dist: pytest-cov; extra == "dev"
42
+ Requires-Dist: pytest-mock; extra == "dev"
43
+ Requires-Dist: pytest-sugar; extra == "dev"
44
+ Requires-Dist: pytest-watch; extra == "dev"
45
+ Requires-Dist: pytest-forked; extra == "dev"
46
+ Requires-Dist: pytest-randomly; extra == "dev"
47
+ Requires-Dist: pytest-xdist; extra == "dev"
48
+ Requires-Dist: pytest-dotenv; extra == "dev"
49
+ Requires-Dist: types-tensorflow; extra == "dev"
50
+ Requires-Dist: ruff>=0.5.0; extra == "dev"
51
+ Requires-Dist: ty>=0.0.1a8; extra == "dev"
52
+ Dynamic: license-file
53
+
54
+ # litert-tunner
55
+
56
+ Fine-tune fully-quantized **INT8 LiteRT** (TFLite) models *after* export — without retraining from scratch.
57
+
58
+ ## What is this?
59
+
60
+ **litert-tunner** parses an exported INT8 `.tflite` model, reconstructs it as a Keras 3 model with simulated quantization, lets you fine-tune parameters (biases, scales, weights), and writes them back into the flatbuffer. Graph topology stays intact.
61
+
62
+ > **Note:** Even though this library simulates the quantization process during fine-tuning, in some cases the resulting INT8 model may still perform worse than a full-precision float32 model.
63
+
64
+ ## Installation
65
+
66
+ ### From source (recommended for development)
67
+
68
+ ```bash
69
+ # Clone the repository
70
+ git clone https://github.com/kmkolasinski/litert-tunner.git
71
+ cd litert-tunner
72
+
73
+ # Create a virtual environment and install uv
74
+ make venv
75
+ make init
76
+
77
+ # Install the project in editable mode with dev dependencies
78
+ make install
79
+ ```
80
+
81
+ ### As a dependency (pip)
82
+
83
+ ```bash
84
+ pip install litert-tunner
85
+ ```
86
+
87
+ *(Core dependencies: `keras>=3.0`, `numpy`, `tflite>=2.18.0`, `flatbuffers`. `tensorflow` and `ai-edge-litert` are only needed for dev/tests; `tensorflow` is also required by the `litert_tunner.export` helper.)*
88
+
89
+ ## Quickstart
90
+
91
+ For a complete, runnable end-to-end example, see the **[Example Notebook](notebooks/quickstart_finetuning.ipynb)**.
92
+
93
+ ```python
94
+ import litert_tunner
95
+
96
+ # 1. Load an INT8 LiteRT model → trainable Keras replica
97
+ model = litert_tunner.load_model("model_int8.tflite")
98
+
99
+ # 2. Inference — should match LiteRT Interpreter output
100
+ predictions = model.predict(inputs)
101
+
102
+ # 3. Prepare for fine-tuning (freeze everything except biases)
103
+ litert_tunner.prepare_for_finetuning(model, trainable_pattern=".*bias")
104
+
105
+ # 4. Fine-tune with any Keras optimizer / loss
106
+ model.compile(optimizer="adam", loss="mse")
107
+ model.fit(x_train, y_train, epochs=5)
108
+
109
+ # 5. Export — writes updated parameters back into the flatbuffer
110
+ litert_tunner.save_model(model, "model_int8_finetuned.tflite")
111
+ ```
112
+
113
+ ### Distillation Fine-Tuning (Recommended)
114
+
115
+ ```python
116
+ import litert_tunner
117
+
118
+ # 1. Load an INT8 LiteRT model → trainable Keras replica
119
+ tunner_model = litert_tunner.load_model("model_int8.tflite")
120
+
121
+ # 2. Freeze everything except biases
122
+ litert_tunner.prepare_for_finetuning(tunner_model, trainable_pattern=".*bias")
123
+
124
+ # 3. Fine-tune using Trainer (handles distillation & weight drift)
125
+ trainer = litert_tunner.Trainer(
126
+ student_model=tunner_model,
127
+ teacher_model=teacher_model, # Original float32 model
128
+ )
129
+ trainer.compile(optimizer="adam", loss="mse")
130
+ trainer.fit(train_ds, validation_data=val_ds, epochs=5)
131
+
132
+ # 4. Save updated parameters to flatbuffer
133
+ litert_tunner.save_model(tunner_model, "model_int8_finetuned.tflite")
134
+ ```
135
+
136
+ ## Supported Operations
137
+
138
+ - **Linear:** `FULLY_CONNECTED`, `CONV_2D`, `DEPTHWISE_CONV_2D`
139
+ - **Arithmetic:** `ADD`, `SUB`, `MUL`, `DIV`, `SQUARED_DIFFERENCE`, `NEG`
140
+ - **Activation:** `RELU`, `GELU`, `LOGISTIC`, `SOFTMAX`
141
+ - **Pooling:** `AVERAGE_POOL_2D`, `MAX_POOL_2D`, `MEAN`
142
+ - **Reshape/Resize:** `RESHAPE`, `TRANSPOSE`, `PACK`, `STRIDED_SLICE`, `RESIZE_NEAREST_NEIGHBOR`, `SHAPE`
143
+ - **Other:** `CONCATENATION`, `RSQRT`, `QUANTIZE`, `DEQUANTIZE`
144
+
145
+ Fused activations (`RELU`, `RELU6`, `RELU_N1_TO_1`) are supported.
146
+
147
+ ## License
148
+
149
+ [MIT](LICENSE) — Copyright © 2026 Krzysztof Kolasinski
@@ -0,0 +1,96 @@
1
+ # litert-tunner
2
+
3
+ Fine-tune fully-quantized **INT8 LiteRT** (TFLite) models *after* export — without retraining from scratch.
4
+
5
+ ## What is this?
6
+
7
+ **litert-tunner** parses an exported INT8 `.tflite` model, reconstructs it as a Keras 3 model with simulated quantization, lets you fine-tune parameters (biases, scales, weights), and writes them back into the flatbuffer. Graph topology stays intact.
8
+
9
+ > **Note:** Even though this library simulates the quantization process during fine-tuning, in some cases the resulting INT8 model may still perform worse than a full-precision float32 model.
10
+
11
+ ## Installation
12
+
13
+ ### From source (recommended for development)
14
+
15
+ ```bash
16
+ # Clone the repository
17
+ git clone https://github.com/kmkolasinski/litert-tunner.git
18
+ cd litert-tunner
19
+
20
+ # Create a virtual environment and install uv
21
+ make venv
22
+ make init
23
+
24
+ # Install the project in editable mode with dev dependencies
25
+ make install
26
+ ```
27
+
28
+ ### As a dependency (pip)
29
+
30
+ ```bash
31
+ pip install litert-tunner
32
+ ```
33
+
34
+ *(Core dependencies: `keras>=3.0`, `numpy`, `tflite>=2.18.0`, `flatbuffers`. `tensorflow` and `ai-edge-litert` are only needed for dev/tests; `tensorflow` is also required by the `litert_tunner.export` helper.)*
35
+
36
+ ## Quickstart
37
+
38
+ For a complete, runnable end-to-end example, see the **[Example Notebook](notebooks/quickstart_finetuning.ipynb)**.
39
+
40
+ ```python
41
+ import litert_tunner
42
+
43
+ # 1. Load an INT8 LiteRT model → trainable Keras replica
44
+ model = litert_tunner.load_model("model_int8.tflite")
45
+
46
+ # 2. Inference — should match LiteRT Interpreter output
47
+ predictions = model.predict(inputs)
48
+
49
+ # 3. Prepare for fine-tuning (freeze everything except biases)
50
+ litert_tunner.prepare_for_finetuning(model, trainable_pattern=".*bias")
51
+
52
+ # 4. Fine-tune with any Keras optimizer / loss
53
+ model.compile(optimizer="adam", loss="mse")
54
+ model.fit(x_train, y_train, epochs=5)
55
+
56
+ # 5. Export — writes updated parameters back into the flatbuffer
57
+ litert_tunner.save_model(model, "model_int8_finetuned.tflite")
58
+ ```
59
+
60
+ ### Distillation Fine-Tuning (Recommended)
61
+
62
+ ```python
63
+ import litert_tunner
64
+
65
+ # 1. Load an INT8 LiteRT model → trainable Keras replica
66
+ tunner_model = litert_tunner.load_model("model_int8.tflite")
67
+
68
+ # 2. Freeze everything except biases
69
+ litert_tunner.prepare_for_finetuning(tunner_model, trainable_pattern=".*bias")
70
+
71
+ # 3. Fine-tune using Trainer (handles distillation & weight drift)
72
+ trainer = litert_tunner.Trainer(
73
+ student_model=tunner_model,
74
+ teacher_model=teacher_model, # Original float32 model
75
+ )
76
+ trainer.compile(optimizer="adam", loss="mse")
77
+ trainer.fit(train_ds, validation_data=val_ds, epochs=5)
78
+
79
+ # 4. Save updated parameters to flatbuffer
80
+ litert_tunner.save_model(tunner_model, "model_int8_finetuned.tflite")
81
+ ```
82
+
83
+ ## Supported Operations
84
+
85
+ - **Linear:** `FULLY_CONNECTED`, `CONV_2D`, `DEPTHWISE_CONV_2D`
86
+ - **Arithmetic:** `ADD`, `SUB`, `MUL`, `DIV`, `SQUARED_DIFFERENCE`, `NEG`
87
+ - **Activation:** `RELU`, `GELU`, `LOGISTIC`, `SOFTMAX`
88
+ - **Pooling:** `AVERAGE_POOL_2D`, `MAX_POOL_2D`, `MEAN`
89
+ - **Reshape/Resize:** `RESHAPE`, `TRANSPOSE`, `PACK`, `STRIDED_SLICE`, `RESIZE_NEAREST_NEIGHBOR`, `SHAPE`
90
+ - **Other:** `CONCATENATION`, `RSQRT`, `QUANTIZE`, `DEQUANTIZE`
91
+
92
+ Fused activations (`RELU`, `RELU6`, `RELU_N1_TO_1`) are supported.
93
+
94
+ ## License
95
+
96
+ [MIT](LICENSE) — Copyright © 2026 Krzysztof Kolasinski
@@ -0,0 +1,181 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "litert_tunner"
7
+ version = "0.1.1"
8
+ description = "LiteRT tuner Python library."
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ license = { file = "LICENSE" }
12
+ authors = [{ name = "litert-tunner contributors" }]
13
+ dependencies = [
14
+ "keras>=3.0",
15
+ "numpy",
16
+ "tflite>=2.18.0",
17
+ "flatbuffers",
18
+ ]
19
+
20
+ [project.optional-dependencies]
21
+ dev = [
22
+ "tensorflow>=2.18.0",
23
+ "ai-edge-litert",
24
+ "bump-my-version>=0.26.0",
25
+ "pre-commit>=3.7.0",
26
+ "pytest>=8.2.0",
27
+ "pytest-cov",
28
+ "pytest-mock",
29
+ "pytest-sugar",
30
+ "pytest-watch",
31
+ "pytest-forked",
32
+ "pytest-randomly",
33
+ "pytest-xdist",
34
+ "pytest-dotenv",
35
+ "types-tensorflow",
36
+ "ruff>=0.5.0",
37
+ "ty>=0.0.1a8",
38
+ ]
39
+
40
+ [tool.setuptools]
41
+ package-dir = { "" = "src" }
42
+
43
+ [tool.setuptools.packages.find]
44
+ where = ["src"]
45
+
46
+ [tool.pytest.ini_options]
47
+ testpaths = [
48
+ "tests",
49
+ ]
50
+
51
+ # Load environment variables from pytest.env file
52
+ env_files = [
53
+ "pytest.env",
54
+ ]
55
+
56
+ addopts = [
57
+ "-vv",
58
+ "--junit-xml=test-results/junit.xml",
59
+ "--randomly-seed=32",
60
+ ]
61
+
62
+ [tool.coverage.run]
63
+ parallel = true
64
+ patch = ["_exit"]
65
+ data_file = "test-results/.coverage"
66
+
67
+ [tool.ruff]
68
+ target-version = "py311"
69
+ line-length = 100
70
+ indent-width = 4
71
+ exclude = [
72
+ "metadata_pb2.py",
73
+ ]
74
+
75
+ # Apply fixes automatically
76
+ fix = true
77
+
78
+ [tool.ruff.format]
79
+ preview = true
80
+ # Like Black, use double quotes for strings.
81
+ quote-style = "double"
82
+
83
+ # Like Black, indent with spaces, rather than tabs.
84
+ indent-style = "space"
85
+
86
+ # Like Black, respect magic trailing commas.
87
+ skip-magic-trailing-comma = false
88
+
89
+ # Like Black, automatically detect the appropriate line ending.
90
+ line-ending = "auto"
91
+
92
+ # By default, Ruff enables only Flake8's F rules and a subset of the E rules. This project
93
+ # opts into every rule instead (select = ["ALL"]) and disables the unwanted ones in `ignore`.
94
+ # For more details see https://docs.astral.sh/ruff/rules/
95
+ [tool.ruff.lint]
96
+ exclude = ["*.ipynb"]
97
+ select = ["ALL"]
98
+
99
+ ignore = [
100
+ # Mostly Missing docstring errors - most of them are False Positives
101
+ "D100", "D101", "D102", "D103", "D104", "D105", "D107", "D205",
102
+ "ANN204", # Missing return type annotation for special method (e.g. __init__)
103
+ "ANN201", # Missing return type annotation for public function
104
+ "ANN002", "ANN003", # Missing type annotation for `*args` or `**kwargs`
105
+ "COM812", # Conflicts with the ruff formatter
106
+ "TD002", "TD003", "FIX002", # Todos related formatting rules
107
+ "TRY003", # Avoid specifying long messages outside the exception class
108
+ "EM101", # Exception must not use a string literal, assign to variable first
109
+ "G004", # Logging uses f-string formatting
110
+ "PLR0913", # Too many arguments in function - some functions are complex by nature
111
+ ]
112
+
113
+ # Only allow autofix for these import-related and docstring rules
114
+ fixable = ["F401", "F403", "I001", "D411", "D415", "D208", "D209"]
115
+
116
+ [tool.ruff.lint.per-file-ignores]
117
+ "src/litert_tunner/testing_utils.py" = [
118
+ "S101", # Use of assert detected
119
+ ]
120
+ "tests/*" = [
121
+ "SLF001", # Private member accessed
122
+ "PERF", # Performance checks are disabled for tests
123
+ "PLR2004", # Magic value used in comparison
124
+ "ARG001", # Unused argument
125
+ "PT019", # Fixture `_` without value is injected as parameter
126
+ "S101", # Use of assert detected
127
+ "ANN", # Missing type annotation in tests is allowed
128
+ "ISC001", # single-line-implicit-string-concatenation - conflicts with the ruff formatter
129
+ "PT011", # ValueError is too broad
130
+ "FBT", # Boolean positional arguments are allowed in tests
131
+ "C901", # Complexity checks are disabled for test code
132
+ "PLR0912", # Branch complexity checks are disabled for test code
133
+ "PLR0915", # Statement complexity checks are disabled for test code
134
+ "S108", # Insecure temp files are allowed in test code
135
+ ]
136
+
137
+ [tool.ruff.lint.pydocstyle]
138
+ convention = "google"
139
+
140
+ [tool.bumpversion]
141
+ current_version = "0.1.1"
142
+ commit = true
143
+ tag = true
144
+ tag_name = "v{new_version}"
145
+ allow_dirty = false
146
+
147
+ [[tool.bumpversion.files]]
148
+ filename = "pyproject.toml"
149
+ search = 'version = "{current_version}"'
150
+ replace = 'version = "{new_version}"'
151
+
152
+ [[tool.bumpversion.files]]
153
+ filename = "src/litert_tunner/__init__.py"
154
+ search = '__version__ = "{current_version}"'
155
+ replace = '__version__ = "{new_version}"'
156
+
157
+ [tool.interrogate]
158
+ ignore-init-method = true
159
+ ignore-init-module = true
160
+ ignore-magic = false
161
+ ignore-semiprivate = false
162
+ ignore-private = false
163
+ ignore-property-decorators = false
164
+ ignore-module = true
165
+ ignore-nested-functions = false
166
+ ignore-nested-classes = true
167
+ ignore-overloaded-functions = true
168
+ ignore-setters = false
169
+ fail-under = 80
170
+ exclude = ["setup.py", "docs", "build"]
171
+ verbose = 1
172
+ quiet = false
173
+ whitelist-regex = []
174
+ color = true
175
+ omit-covered-files = false
176
+
177
+ [tool.pyright]
178
+ include = ["src", "tests"]
179
+ extraPaths = ["."]
180
+ venvPath = "."
181
+ venv = ".venv"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,46 @@
1
+ """litert_tunner package."""
2
+
3
+ import keras
4
+
5
+ from litert_tunner import flatbuffer, graph
6
+ from litert_tunner.testing_utils import (
7
+ assert_allclose_with_mismatch_tolerance,
8
+ assert_cosine_similarity,
9
+ )
10
+ from litert_tunner.trainer import Trainer, cosine_similarity, prepare_for_finetuning
11
+
12
# Updated by bump-my-version, whose search pattern matches this exact assignment form.
__version__ = "0.1.1"


def load_model(path: str) -> keras.Model:
    """Read an INT8 ``.tflite`` file and rebuild it as a trainable Keras model.

    The file is parsed with ``flatbuffer.parse_tflite`` and the resulting graph
    description is turned into a Keras model by ``graph.build_keras_model``.

    Args:
        path: Location of the ``.tflite`` file to read.

    Returns:
        The Keras model replica produced by ``graph.build_keras_model``.
    """
    return graph.build_keras_model(flatbuffer.parse_tflite(path))
26
+
27
+
28
def save_model(model: keras.Model, path: str) -> None:
    """Write the model's current parameters back into a ``.tflite`` file.

    Thin wrapper that delegates the actual serialization to
    ``flatbuffer.save_tflite``.

    Args:
        model: Keras model replica whose parameters should be exported.
        path: Destination of the updated ``.tflite`` file.
    """
    write_tflite = flatbuffer.save_tflite
    write_tflite(model, path)
36
+
37
+
38
# Public names exported by ``from litert_tunner import *`` (kept in sorted order).
__all__ = [
    "Trainer",  # from litert_tunner.trainer
    "assert_allclose_with_mismatch_tolerance",  # from litert_tunner.testing_utils
    "assert_cosine_similarity",  # from litert_tunner.testing_utils
    "cosine_similarity",  # from litert_tunner.trainer
    "load_model",  # defined in this module
    "prepare_for_finetuning",  # from litert_tunner.trainer
    "save_model",  # defined in this module
]
@@ -0,0 +1,100 @@
1
+ from collections.abc import Callable, Iterable
2
+ from pathlib import Path
3
+ from typing import Literal, TypeAlias
4
+
5
+ import keras
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
# A single calibration sample: model inputs either as a positional list of arrays or
# as a mapping from input name to array.
_CalibrationSample: TypeAlias = list[np.ndarray] | dict[str, np.ndarray]

# Zero-argument callable that produces the calibration samples, either as a
# ``tf.data.Dataset`` or as any iterable of ``_CalibrationSample`` values.
RepresentativeDataset: TypeAlias = Callable[
    [],
    tf.data.Dataset | Iterable[_CalibrationSample],
]
12
+
13
+
14
def export_litert_model(
    model: keras.Model,
    export_dir: str | Path,
    *,
    quantization: Literal["int8", "float32"] = "int8",
    float_io: bool = True,
    representative_dataset: RepresentativeDataset | None = None,
    run_debugger: bool = False,
    denylisted_ops: list[str] | None = None,
    denylisted_nodes: list[str] | None = None,
) -> tuple[Path, "tf.lite.experimental.QuantizationDebugger | None"]:
    """Export a Keras model to a TFLite model.

    The model is always written to ``<export_dir>/model.tflite``. When the quantization
    debugger runs, its per-layer statistics are also written to
    ``<export_dir>/quantization_stats.csv``.

    Args:
        model: The Keras model to export.
        export_dir: Directory where the exported model and stats will be saved. Created
            (including parents) if it does not exist.
        quantization: The quantization mode. Either "int8" or "float32". Defaults to "int8".
        float_io: If True, use float32 inputs/outputs even for int8 quantization; if False,
            the converter is asked for int8 inputs/outputs. Inputs/outputs are always
            float32 in "float32" mode. Defaults to True.
        representative_dataset: A zero-argument callable returning the sample inputs
            (an iterable or a ``tf.data.Dataset``) used for calibration. Required if
            quantization is "int8".
        run_debugger: If True, run the TFLite QuantizationDebugger and save stats.
            Only supported for int8 quantization. Defaults to False.
        denylisted_ops: List of ops to exclude from quantization. Requires run_debugger.
        denylisted_nodes: List of nodes to exclude from quantization. Requires run_debugger.

    Returns:
        A tuple containing the path to the exported .tflite model and the
        QuantizationDebugger instance (if run_debugger is True, otherwise None).

    Raises:
        ValueError: If ``quantization`` is not "int8"/"float32", if ``representative_dataset``
            is missing for int8, if the debugger is requested for float32, or if deny-lists
            are given without ``run_debugger``.
    """
    # Check the mode first so an unsupported value is always reported as such rather
    # than surfacing through one of the mode-specific checks below.
    if quantization not in ("int8", "float32"):
        msg = f"Unknown quantization: {quantization}"
        raise ValueError(msg)
    if (denylisted_ops or denylisted_nodes) and not run_debugger:
        raise ValueError(
            "denylisted_ops and denylisted_nodes can only be provided when run_debugger is True"
        )
    if quantization == "int8" and representative_dataset is None:
        raise ValueError("representative_dataset must be provided for int8 quantization")
    if quantization == "float32" and run_debugger:
        raise ValueError("Debugger only supported for int8 quantization")

    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    if quantization == "int8":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        if denylisted_ops or denylisted_nodes:
            # Deny-listed ops/nodes are excluded from quantization, so float builtin
            # kernels must stay available alongside the int8 ones.
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
                tf.lite.OpsSet.TFLITE_BUILTINS,
            ]
        else:
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    io_type = tf.float32 if float_io or quantization == "float32" else tf.int8
    converter.inference_input_type = io_type
    converter.inference_output_type = io_type

    export_dir = Path(export_dir)
    export_dir.mkdir(parents=True, exist_ok=True)

    debugger = None
    if run_debugger:
        debug_options = tf.lite.experimental.QuantizationDebugOptions(
            denylisted_ops=denylisted_ops,
            denylisted_nodes=denylisted_nodes,
        )
        debugger = tf.lite.experimental.QuantizationDebugger(
            converter=converter,
            debug_dataset=representative_dataset,
            debug_options=debug_options,
        )
        debugger.run()
        # NOTE(review): this path takes the model from get_nondebug_quantized_model()
        # instead of converter.convert(); confirm it honors float_io=False (int8 I/O).
        tflite_model = debugger.get_nondebug_quantized_model()

        # The stats are written as CSV; the csv module requires files opened with
        # newline="" so rows are not followed by blank lines where "\n" is translated.
        stats_path = export_dir / "quantization_stats.csv"
        with stats_path.open("w", newline="", encoding="utf-8") as f:
            debugger.layer_statistics_dump(f)
    else:
        tflite_model = converter.convert()

    tflite_model_filepath = export_dir / "model.tflite"
    tflite_model_filepath.write_bytes(tflite_model)

    return tflite_model_filepath, debugger
@@ -0,0 +1,9 @@
1
+ """Flatbuffer module for litert_tunner.
2
+
3
+ Exposes parse_tflite and save_tflite.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from litert_tunner.flatbuffer.parser import parse_tflite as parse_tflite
9
+ from litert_tunner.flatbuffer.writer import save_tflite as save_tflite