EasySmallEmbeddingModel 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- easysmallembeddingmodel-0.1.0.dist-info/METADATA +161 -0
- easysmallembeddingmodel-0.1.0.dist-info/RECORD +18 -0
- easysmallembeddingmodel-0.1.0.dist-info/WHEEL +5 -0
- easysmallembeddingmodel-0.1.0.dist-info/entry_points.txt +2 -0
- easysmallembeddingmodel-0.1.0.dist-info/top_level.txt +1 -0
- smallmodel/__init__.py +7 -0
- smallmodel/arch.py +441 -0
- smallmodel/cli.py +152 -0
- smallmodel/core.py +447 -0
- smallmodel/data.py +131 -0
- smallmodel/distill.py +217 -0
- smallmodel/sizing.py +152 -0
- smallmodel/teachers.py +144 -0
- smallmodel/web/__init__.py +0 -0
- smallmodel/web/app.py +573 -0
- smallmodel/web/static/css/style.css +803 -0
- smallmodel/web/static/js/app.js +587 -0
- smallmodel/web/templates/index.html +285 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: EasySmallEmbeddingModel
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Compress large embedding models into small, fast students via layer pruning, vocab pruning, hidden dim reduction, and knowledge distillation.
|
|
5
|
+
Author: gomyk
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Keywords: embedding,model-compression,distillation,pruning,sentence-transformers
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: torch>=2.0.0
|
|
16
|
+
Requires-Dist: transformers>=4.40.0
|
|
17
|
+
Requires-Dist: sentence-transformers>=3.0.0
|
|
18
|
+
Requires-Dist: numpy>=1.24.0
|
|
19
|
+
Requires-Dist: tqdm
|
|
20
|
+
Requires-Dist: sentencepiece
|
|
21
|
+
Requires-Dist: protobuf
|
|
22
|
+
Provides-Extra: eval
|
|
23
|
+
Requires-Dist: mteb>=1.14.0; extra == "eval"
|
|
24
|
+
Requires-Dist: pandas>=2.0.0; extra == "eval"
|
|
25
|
+
Provides-Extra: export
|
|
26
|
+
Requires-Dist: onnxruntime>=1.16.0; extra == "export"
|
|
27
|
+
Requires-Dist: onnx>=1.15.0; extra == "export"
|
|
28
|
+
Requires-Dist: optimum[onnxruntime]>=1.16.0; extra == "export"
|
|
29
|
+
Provides-Extra: hub
|
|
30
|
+
Requires-Dist: huggingface-hub>=0.20.0; extra == "hub"
|
|
31
|
+
Provides-Extra: web
|
|
32
|
+
Requires-Dist: flask>=3.0.0; extra == "web"
|
|
33
|
+
Provides-Extra: all
|
|
34
|
+
Requires-Dist: EasySmallEmbeddingModel[eval,export,hub,web]; extra == "all"
|
|
35
|
+
|
|
36
|
+
# SmallModel
|
|
37
|
+
|
|
38
|
+
Compress large embedding models into small, fast students via layer pruning, vocab pruning, hidden dim reduction, and knowledge distillation.
|
|
39
|
+
|
|
40
|
+
## Features
|
|
41
|
+
|
|
42
|
+
- **Layer Pruning** - Select which transformer layers to keep
|
|
43
|
+
- **Vocab Pruning** - Remove unused tokens based on corpus frequency
|
|
44
|
+
- **Hidden Dim Reduction** - Shrink internal dimensions (slicing or PCA)
|
|
45
|
+
- **Knowledge Distillation** - MSE + Cosine loss alignment with teacher
|
|
46
|
+
- **Auto Compress** - Find optimal config within size constraints
|
|
47
|
+
- **2-Stage Distillation** - Progressive distillation for 10x+ compression
|
|
48
|
+
- **Interactive Web UI** - Visual layer editor with real-time size estimation
|
|
49
|
+
- **MTEB Evaluation** - Benchmark on Classification, Clustering, STS tasks
|
|
50
|
+
|
|
51
|
+
## Installation
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
pip install EasySmallEmbeddingModel[all]
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
Or install specific extras:
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
pip install EasySmallEmbeddingModel            # core only
pip install EasySmallEmbeddingModel[web]       # + Flask web UI
pip install EasySmallEmbeddingModel[eval]      # + MTEB evaluation
pip install EasySmallEmbeddingModel[export]    # + ONNX export
pip install EasySmallEmbeddingModel[hub]       # + HuggingFace Hub upload
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
For development:
|
|
68
|
+
|
|
69
|
+
```bash
|
|
70
|
+
git clone https://github.com/gomyk/smallmodel.git
|
|
71
|
+
cd smallmodel
|
|
72
|
+
pip install -e ".[all]"
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Quick Start
|
|
76
|
+
|
|
77
|
+
### Python API
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
from smallmodel import SmallModel
|
|
81
|
+
|
|
82
|
+
# Auto-compress within 50MB
|
|
83
|
+
sm = SmallModel.from_teacher("gte")
|
|
84
|
+
sm.compress(max_fp32_mb=50.0)
|
|
85
|
+
sm.distill(epochs=10)
|
|
86
|
+
|
|
87
|
+
# Manual layer selection
|
|
88
|
+
sm = SmallModel.from_teacher("gte", layer_indices=[0, 3, 6, 11])
|
|
89
|
+
sm.create()
|
|
90
|
+
|
|
91
|
+
# Register custom teacher
|
|
92
|
+
from smallmodel import register_teacher
|
|
93
|
+
register_teacher(
|
|
94
|
+
"my-bert",
|
|
95
|
+
model_id="my-org/my-bert-base",
|
|
96
|
+
short_name="MyBERT",
|
|
97
|
+
hidden_dim=768, num_layers=12,
|
|
98
|
+
intermediate_size=3072, vocab_size=30522,
|
|
99
|
+
)
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
### Web UI
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
from smallmodel import SmallModel
|
|
106
|
+
|
|
107
|
+
sm = SmallModel.from_teacher("gte")
|
|
108
|
+
sm.serve() # http://127.0.0.1:7860
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
Or via CLI:
|
|
112
|
+
|
|
113
|
+
```bash
|
|
114
|
+
smallmodel serve --teacher gte --port 7860
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
The web UI lets you:
|
|
118
|
+
- Select teacher model from 7+ pre-registered models
|
|
119
|
+
- Toggle layers on/off with preset configurations
|
|
120
|
+
- Adjust hidden dim, FFN size, and vocab size
|
|
121
|
+
- See real-time size estimation and compression ratio
|
|
122
|
+
- Select distillation datasets and evaluation tasks
|
|
123
|
+
- Analyze vocab coverage at different vocab sizes
|
|
124
|
+
- Create compressed models with one click
|
|
125
|
+
|
|
126
|
+
### CLI
|
|
127
|
+
|
|
128
|
+
```bash
|
|
129
|
+
smallmodel list-teachers
|
|
130
|
+
smallmodel compress --teacher gte --max-mb 50
|
|
131
|
+
smallmodel create --teacher gte --layers 0,3,6,11
|
|
132
|
+
smallmodel distill --teacher gte --student output/students/gte/gte_compressed
|
|
133
|
+
smallmodel serve --teacher gte
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
## Pre-registered Teachers
|
|
137
|
+
|
|
138
|
+
| Key | Model | Layers | Hidden | Vocab | FP32 MB |
|
|
139
|
+
|---|---|---|---|---|---|
|
|
140
|
+
| minilm | paraphrase-multilingual-MiniLM-L12-v2 | 12 | 384 | 250K | 448 |
|
|
141
|
+
| modernbert | ModernBERT-base | 22 | 768 | 50K | 496 |
|
|
142
|
+
| gte | gte-multilingual-base | 12 | 768 | 250K | 1058 |
|
|
143
|
+
| me5 | multilingual-e5-base | 12 | 768 | 250K | 1058 |
|
|
144
|
+
| me5s | multilingual-e5-small | 12 | 384 | 250K | 448 |
|
|
145
|
+
| gemma_emb | embeddinggemma-300m | 24 | 768 | 262K | 1155 |
|
|
146
|
+
| qwen3 | Qwen3-0.6B | 28 | 1024 | 152K | 2274 |
|
|
147
|
+
|
|
148
|
+
## How It Works
|
|
149
|
+
|
|
150
|
+
1. **Layer Pruning** - Copy selected layers from teacher (uniform spacing recommended)
|
|
151
|
+
2. **Hidden Dim Reduction** - Shrink dimensions if needed to meet size target
|
|
152
|
+
3. **Vocab Pruning** - Remove tokens not seen in training corpus
|
|
153
|
+
4. **Knowledge Distillation** - Train student to reproduce teacher's embeddings
|
|
154
|
+
5. **Evaluation** - MTEB benchmark (Classification, Clustering, STS)
|
|
155
|
+
|
|
156
|
+
For compression ratios > 10x, a 2-stage distillation pipeline is used:
|
|
157
|
+
Teacher → Intermediate (~1/5 teacher) → Final Student
|
|
158
|
+
|
|
159
|
+
## License
|
|
160
|
+
|
|
161
|
+
Apache-2.0
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
smallmodel/__init__.py,sha256=GgJw-ldxMcQBOv4200CmXoMrQ7jGvDlMys3esMHuO_E,257
|
|
2
|
+
smallmodel/arch.py,sha256=rUusfhJ2H8-FgQ1MseYAoy-jTuZm2xQgRpvbdK2H-_c,17205
|
|
3
|
+
smallmodel/cli.py,sha256=LhSbl_MyJYhs_l32B6gbT5teyT87D7a-0kwYhodg6NY,5341
|
|
4
|
+
smallmodel/core.py,sha256=_Z9wsEWlbEnQS2u9ppuWmHBM-bU70xrrtztBh7Rf79w,19222
|
|
5
|
+
smallmodel/data.py,sha256=udjwpEe0MH1TTYTAMy533pDmh3ASqVDbaRv0xNIDTYU,5734
|
|
6
|
+
smallmodel/distill.py,sha256=sXdV-3n4y_3At2i0qQUV8xcIFet9NMmszrn0RRVeJ-U,7739
|
|
7
|
+
smallmodel/sizing.py,sha256=RyrnJWHtWy7Ycl9pZurU6lEqu6Cx18GmqOuO-k3i7e8,5265
|
|
8
|
+
smallmodel/teachers.py,sha256=gRbDMyVAB_wcWaaE2t4BhEEY_mVekrUIxECEq1vAmHU,4474
|
|
9
|
+
smallmodel/web/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
smallmodel/web/app.py,sha256=Hy8r4cHGR6xqS6stry93TJYHCCnAXx5PQrn9LYRRMxg,20721
|
|
11
|
+
smallmodel/web/static/css/style.css,sha256=6t3wLF4wrBmZfKnX-Ex3PMt_YaTWVpnkFtSEmC1TDS0,15101
|
|
12
|
+
smallmodel/web/static/js/app.js,sha256=gpkOGrtzjtQFtx5lhFEwKRcX6sfH3g5Q82vHzwd7d9g,22382
|
|
13
|
+
smallmodel/web/templates/index.html,sha256=WciG04lmVwrLkC5VcuRQVZT0n3PsIq-9H0h-pwC0ZUc,16366
|
|
14
|
+
easysmallembeddingmodel-0.1.0.dist-info/METADATA,sha256=Br9F10VeA1SEkXKNnKsC2bX03edP_bhslYpBirACvIg,5227
|
|
15
|
+
easysmallembeddingmodel-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
16
|
+
easysmallembeddingmodel-0.1.0.dist-info/entry_points.txt,sha256=HNo0Be_rqqYaeULejKSs-AZ5QDchylEn9vbZMRK_cQ8,51
|
|
17
|
+
easysmallembeddingmodel-0.1.0.dist-info/top_level.txt,sha256=Spct7pFIJwbc3SqefL9ZYf6muQQk-IGYp8Jv5qkcnHM,11
|
|
18
|
+
easysmallembeddingmodel-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
smallmodel
|
smallmodel/__init__.py
ADDED
smallmodel/arch.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
"""Architecture utilities: layer pruning, hidden dim reduction, vocab pruning."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from collections import Counter
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ── Layer Access ─────────────────────────────────────────────────
|
|
16
|
+
|
|
17
|
+
def get_layers(model, layer_accessor: str):
    """Return the object reached by following the dotted *layer_accessor* path.

    E.g. ``get_layers(model, "encoder.layer")`` resolves ``model.encoder.layer``.
    Raises AttributeError if any segment of the path is missing.
    """
    head, _, tail = layer_accessor.partition(".")
    child = getattr(model, head)
    return get_layers(child, tail) if tail else child
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def set_layers(model, layer_accessor: str, new_layers):
    """Assign *new_layers* at the dotted *layer_accessor* path on *model*.

    Walks every segment except the last, then sets the final attribute.
    """
    *parents, leaf = layer_accessor.split(".")
    target = model
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, new_layers)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def discover_layer_accessor(model) -> str:
    """Find the dotted attribute path to the model's transformer layer stack.

    Probes a fixed set of accessor paths common across HF architectures and
    returns the first one that resolves to a non-empty ``nn.ModuleList``.

    Raises:
        ValueError: if none of the known paths matches this architecture.
    """
    candidate_paths = (
        "encoder.layer", "encoder.layers", "layers",
        "transformer.layer", "transformer.layers",
    )
    for path in candidate_paths:
        try:
            node = model
            for attr in path.split("."):
                node = getattr(node, attr)
        except AttributeError:
            continue
        if isinstance(node, nn.ModuleList) and len(node) > 0:
            return path
    raise ValueError(f"Could not detect layer accessor for {type(model).__name__}")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# ── Layer Pruning ────────────────────────────────────────────────
|
|
48
|
+
|
|
49
|
+
def prune_layers(model, layer_indices: list[int], layer_accessor: str | None = None):
    """Keep only the transformer layers at *layer_indices*, in the given order.

    Mutates *model* in place: replaces its layer ModuleList with the selected
    subset and updates ``config.num_hidden_layers`` to match.

    Returns the same model, for chaining.
    """
    if layer_accessor is None:
        layer_accessor = discover_layer_accessor(model)
    original = get_layers(model, layer_accessor)
    selected = nn.ModuleList([original[idx] for idx in layer_indices])
    set_layers(model, layer_accessor, selected)
    model.config.num_hidden_layers = len(layer_indices)
    return model
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def create_pruned_student(teacher_model_id: str, layer_indices: list[int],
                          layer_accessor: str | None = None,
                          trust_remote_code: bool = False):
    """Build a layer-pruned student from a pretrained teacher checkpoint.

    Loads the teacher model and tokenizer from *teacher_model_id*, deep-copies
    the model, and keeps only the layers listed in *layer_indices*.

    Returns (student_model, tokenizer).
    """
    teacher = AutoModel.from_pretrained(teacher_model_id, trust_remote_code=trust_remote_code)
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_id, trust_remote_code=trust_remote_code)

    accessor = layer_accessor if layer_accessor is not None else discover_layer_accessor(teacher)

    # Deep-copy first so the loaded teacher object itself is never mutated.
    student = prune_layers(copy.deepcopy(teacher), layer_indices, accessor)
    return student, tokenizer
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ── Hidden Dimension Reduction ───────────────────────────────────
|
|
78
|
+
|
|
79
|
+
def reduce_hidden_dim(model, new_hidden_dim: int, new_intermediate_size: int | None = None,
                      trust_remote_code: bool = False):
    """Shrink a model's hidden dimension by re-instantiating it with a smaller
    config and slice-copying the overlapping weight regions.

    Args:
        model: source HF model to shrink.
        new_hidden_dim: target hidden size; if >= the current size, *model*
            is returned unchanged.
        new_intermediate_size: target FFN size; defaults to the old size
            scaled by the hidden-dim ratio, rounded down to a multiple of 64
            (minimum 64).
        trust_remote_code: forwarded to ``AutoModel.from_config``.

    Returns:
        A freshly constructed model with the reduced config, or the original
        model when no reduction is needed.
    """
    old_hidden = model.config.hidden_size
    if new_hidden_dim >= old_hidden:
        return model  # nothing to shrink

    ratio = new_hidden_dim / old_hidden
    if new_intermediate_size is None:
        old_inter = getattr(model.config, 'intermediate_size', old_hidden * 4)
        new_intermediate_size = max(64, (int(old_inter * ratio) // 64) * 64)

    new_config = copy.deepcopy(model.config)
    new_config.hidden_size = new_hidden_dim
    new_config.intermediate_size = new_intermediate_size

    old_n_kv = getattr(new_config, 'num_key_value_heads', None)

    # Scale the attention head count, keeping new_hidden_dim divisible by it.
    # BUGFIX: n_heads is initialized up front — previously it was only bound
    # inside the num_attention_heads branch, so a config with
    # num_key_value_heads but a missing/None num_attention_heads raised
    # NameError (or TypeError) below.
    n_heads = 1
    if getattr(new_config, 'num_attention_heads', None) is not None:
        n_heads = max(1, int(new_config.num_attention_heads * ratio))
        while new_hidden_dim % n_heads != 0 and n_heads > 1:
            n_heads -= 1
        new_config.num_attention_heads = n_heads

    # Scale GQA key/value heads so they divide both the head count and the
    # hidden dim.
    if old_n_kv is not None:
        n_kv = max(1, int(old_n_kv * ratio))
        while n_kv > 1 and (n_heads % n_kv != 0 or new_hidden_dim % n_kv != 0):
            n_kv -= 1
        new_config.num_key_value_heads = n_kv

    # Recompute per-head dim when the config tracks it explicitly.
    if getattr(new_config, 'head_dim', None) is not None:
        new_config.head_dim = new_hidden_dim // (getattr(new_config, 'num_attention_heads', None) or 1)

    new_model = AutoModel.from_config(new_config, trust_remote_code=trust_remote_code)

    # Transfer weights: for every tensor both models share, copy the
    # overlapping top-left slice; freshly initialized values fill the rest.
    old_sd = model.state_dict()
    new_sd = new_model.state_dict()
    for key, new_t in new_sd.items():
        old_t = old_sd.get(key)
        if old_t is None:
            continue
        if old_t.shape == new_t.shape:
            new_sd[key] = old_t.clone()
            continue
        overlap = tuple(
            slice(0, min(s_new, s_old))
            for s_new, s_old in zip(new_t.shape, old_t.shape)
        )
        sliced = old_t[overlap]
        if sliced.shape == new_t.shape:
            new_sd[key] = sliced.clone()
        else:
            # Old tensor is smaller along some dim: write into the corner and
            # keep the new model's initialization elsewhere.
            new_sd[key][tuple(slice(0, s) for s in sliced.shape)] = sliced.clone()

    new_model.load_state_dict(new_sd)

    # from_config may not carry special-token ids over; restore them.
    for attr in ("pad_token_id", "bos_token_id", "eos_token_id",
                 "cls_token_id", "sep_token_id", "unk_token_id", "mask_token_id"):
        old_val = getattr(model.config, attr, None)
        if old_val is not None:
            setattr(new_model.config, attr, old_val)

    return new_model
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ── Tokenizer Type Detection ────────────────────────────────────
|
|
152
|
+
|
|
153
|
+
def detect_tokenizer_type(tokenizer) -> str:
    """Return the backend tokenizer's model type (e.g. "BPE", "Unigram", "WordPiece")."""
    serialized = tokenizer.backend_tokenizer.to_str()
    return json.loads(serialized)["model"]["type"]
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# ── Vocab Pruning ────────────────────────────────────────────────
|
|
159
|
+
|
|
160
|
+
def collect_corpus_tokens(tokenizer, texts: list[str] | None = None,
|
|
161
|
+
max_vocab: int | None = None,
|
|
162
|
+
vocab_keep_ratio: float | None = None) -> list[int]:
|
|
163
|
+
"""Collect tokens used in a corpus, returning sorted IDs to keep."""
|
|
164
|
+
if texts is None:
|
|
165
|
+
texts = _get_default_multilingual_samples()
|
|
166
|
+
|
|
167
|
+
freq = Counter()
|
|
168
|
+
batch_size = 1000
|
|
169
|
+
for i in range(0, len(texts), batch_size):
|
|
170
|
+
batch = texts[i:i + batch_size]
|
|
171
|
+
encoded = tokenizer(batch, add_special_tokens=True, truncation=True, max_length=128)
|
|
172
|
+
for ids in encoded["input_ids"]:
|
|
173
|
+
freq.update(ids)
|
|
174
|
+
|
|
175
|
+
keep_ids = set(tokenizer.all_special_ids)
|
|
176
|
+
|
|
177
|
+
basic_chars = list("0123456789.,!?;:'\"-()[]{}/@#$%^&*+=<>~_ \t\n")
|
|
178
|
+
for ch in basic_chars:
|
|
179
|
+
ids = tokenizer.encode(ch, add_special_tokens=False)
|
|
180
|
+
keep_ids.update(ids)
|
|
181
|
+
|
|
182
|
+
tok_type = detect_tokenizer_type(tokenizer)
|
|
183
|
+
if tok_type == "BPE":
|
|
184
|
+
tok_json = json.loads(tokenizer.backend_tokenizer.to_str())
|
|
185
|
+
vocab = tok_json["model"]["vocab"]
|
|
186
|
+
for token, tid in vocab.items():
|
|
187
|
+
if tid < 256 or len(token) <= 1:
|
|
188
|
+
keep_ids.add(tid)
|
|
189
|
+
|
|
190
|
+
if vocab_keep_ratio is not None:
|
|
191
|
+
total_freq = sum(freq.values())
|
|
192
|
+
target_freq = total_freq * vocab_keep_ratio
|
|
193
|
+
corpus_tokens = sorted(freq.keys(), key=lambda t: freq[t], reverse=True)
|
|
194
|
+
cumsum = 0
|
|
195
|
+
for tid in corpus_tokens:
|
|
196
|
+
keep_ids.add(tid)
|
|
197
|
+
cumsum += freq[tid]
|
|
198
|
+
if cumsum >= target_freq:
|
|
199
|
+
break
|
|
200
|
+
elif max_vocab is not None:
|
|
201
|
+
remaining = max_vocab - len(keep_ids)
|
|
202
|
+
if remaining > 0:
|
|
203
|
+
for tid, _ in freq.most_common():
|
|
204
|
+
if tid not in keep_ids:
|
|
205
|
+
keep_ids.add(tid)
|
|
206
|
+
if len(keep_ids) >= max_vocab:
|
|
207
|
+
break
|
|
208
|
+
else:
|
|
209
|
+
keep_ids.update(freq.keys())
|
|
210
|
+
|
|
211
|
+
return sorted(keep_ids)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def prune_tokenizer_and_embeddings(model, tokenizer, keep_ids: list[int], save_dir: str):
    """Prune tokenizer vocab and model embeddings simultaneously.

    The tokenizer's serialized JSON is rewritten so that only *keep_ids*
    remain, with ids renumbered densely from 0 in the order of *keep_ids*;
    the model's input-embedding table is shrunk with the same mapping so
    token ids stay consistent between the two. The pruned tokenizer files are
    written to *save_dir*; the (mutated) model is returned.

    Unsupported tokenizer model types are saved unchanged — only the
    embeddings are pruned in that case.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Dense renumbering: position in keep_ids becomes the new id.
    old_to_new = {old_id: new_id for new_id, old_id in enumerate(keep_ids)}

    tok_json = json.loads(tokenizer.backend_tokenizer.to_str())
    model_type = tok_json["model"]["type"]

    # Each tokenizer family stores its vocab differently; dispatch accordingly.
    if model_type == "Unigram":
        tok_json = _prune_unigram(tok_json, keep_ids, old_to_new)
    elif model_type == "BPE":
        tok_json = _prune_bpe(tok_json, keep_ids, old_to_new)
    elif model_type == "WordPiece":
        tok_json = _prune_wordpiece(tok_json, keep_ids, old_to_new)
    else:
        # Unknown tokenizer type: save as-is and prune only the embeddings.
        tokenizer.save_pretrained(save_dir)
        model = _prune_embeddings(model, keep_ids)
        return model

    # Remap (or drop) entries in the tokenizer's added_tokens list.
    if "added_tokens" in tok_json:
        new_added = []
        for at in tok_json["added_tokens"]:
            old_id = at["id"]
            if old_id in old_to_new:
                at["id"] = old_to_new[old_id]
                new_added.append(at)
        tok_json["added_tokens"] = new_added

    # Remap special-token ids referenced by the post-processor template.
    pp = tok_json.get("post_processor")
    if pp and "special_tokens" in pp:
        for token_name, token_info in pp["special_tokens"].items():
            if "ids" in token_info:
                token_info["ids"] = [
                    old_to_new[oid] for oid in token_info["ids"] if oid in old_to_new
                ]

    # Write the standard tokenizer files first, then overwrite tokenizer.json
    # with the pruned version.
    tokenizer.save_pretrained(save_dir)

    tok_json_path = os.path.join(save_dir, "tokenizer.json")
    with open(tok_json_path, "w", encoding="utf-8") as f:
        json.dump(tok_json, f, ensure_ascii=False)

    # added_tokens.json (slow-tokenizer sidecar) also maps token -> id.
    added_tokens_path = os.path.join(save_dir, "added_tokens.json")
    if os.path.exists(added_tokens_path):
        with open(added_tokens_path, "r", encoding="utf-8") as f:
            added_tokens = json.load(f)
        new_added_tokens = {}
        for token_str, old_id in added_tokens.items():
            if old_id in old_to_new:
                new_added_tokens[token_str] = old_to_new[old_id]
        with open(added_tokens_path, "w", encoding="utf-8") as f:
            json.dump(new_added_tokens, f, ensure_ascii=False)

    # tokenizer_config.json keys its added_tokens_decoder by id-as-string.
    tok_config_path = os.path.join(save_dir, "tokenizer_config.json")
    if os.path.exists(tok_config_path):
        with open(tok_config_path, "r", encoding="utf-8") as f:
            tok_config = json.load(f)
        if "added_tokens_decoder" in tok_config:
            new_decoder = {}
            for old_id_str, token_info in tok_config["added_tokens_decoder"].items():
                old_id = int(old_id_str)
                if old_id in old_to_new:
                    new_decoder[str(old_to_new[old_id])] = token_info
            tok_config["added_tokens_decoder"] = new_decoder
        with open(tok_config_path, "w", encoding="utf-8") as f:
            json.dump(tok_config, f, ensure_ascii=False, indent=2)

    model = _prune_embeddings(model, keep_ids)
    return model
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def save_as_sentence_transformer(model, tokenizer, save_path: str):
    """Save HF model as SentenceTransformer format.

    Serializes *model* and *tokenizer* to a temporary HF directory under
    *save_path*, wraps them in a Transformer + mean-pooling
    SentenceTransformer pipeline, saves that pipeline to *save_path*, and
    removes the temporary directory. Custom-code models (configs containing
    "auto_map") get their remote .py files copied alongside the weights.

    Returns the assembled SentenceTransformer instance.

    Note: removed the unused ``import glob as _glob`` the original carried.
    """
    from sentence_transformers import SentenceTransformer, models as st_models
    import shutil

    hf_tmp = os.path.join(save_path, "_hf_tmp")
    os.makedirs(hf_tmp, exist_ok=True)
    model.save_pretrained(hf_tmp)
    tokenizer.save_pretrained(hf_tmp)

    # Custom architectures loaded with trust_remote_code advertise "auto_map".
    config_path = os.path.join(hf_tmp, "config.json")
    is_custom = False
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        is_custom = "auto_map" in config

    if is_custom:
        _copy_custom_code_files(model, hf_tmp)
    else:
        # Drop the stale source path so the saved config is self-contained.
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            config.pop("_name_or_path", None)
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2, ensure_ascii=False)

    # Allow sentence-transformers to load remote code without prompting.
    os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1"
    word_model = st_models.Transformer(
        hf_tmp,
        config_args={"trust_remote_code": True},
        model_args={"trust_remote_code": True},
        tokenizer_args={"trust_remote_code": True},
    )
    pool_model = st_models.Pooling(
        word_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
    )
    st_model = SentenceTransformer(modules=[word_model, pool_model])
    st_model.save(save_path)
    shutil.rmtree(hf_tmp, ignore_errors=True)
    return st_model
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# ── Internal helpers ─────────────────────────────────────────────
|
|
331
|
+
|
|
332
|
+
def _prune_unigram(tok_json, keep_ids, old_to_new):
    """Rewrite a Unigram tokenizer model in place to contain only *keep_ids*.

    Unigram stores its vocab as a positional list of entries, so keeping an
    id means keeping that list position; ``unk_id`` is remapped if kept.
    """
    model_section = tok_json["model"]
    vocab = model_section["vocab"]
    model_section["vocab"] = [vocab[i] for i in keep_ids if i < len(vocab)]
    unk = model_section.get("unk_id")
    if unk is not None and unk in old_to_new:
        model_section["unk_id"] = old_to_new[unk]
    return tok_json
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _prune_bpe(tok_json, keep_ids, old_to_new):
    """Rewrite a BPE tokenizer model in place, keeping only *keep_ids*.

    Remaps the token->id vocab, then drops any merge rule whose left part,
    right part, or merged result is no longer in the kept vocab.
    """
    keep = set(keep_ids)
    model_section = tok_json["model"]
    model_section["vocab"] = {
        tok: old_to_new[tid] for tok, tid in model_section["vocab"].items() if tid in keep
    }
    surviving = set(model_section["vocab"])
    if "merges" in model_section:
        pruned_merges = []
        for rule in model_section["merges"]:
            # Merges appear either as "left right" strings or [left, right] lists.
            pair = rule if isinstance(rule, list) else rule.split(" ")
            if len(pair) != 2:
                continue
            left, right = pair
            if left in surviving and right in surviving and (left + right) in surviving:
                pruned_merges.append(rule)
        model_section["merges"] = pruned_merges
    return tok_json
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _prune_wordpiece(tok_json, keep_ids, old_to_new):
    """Rewrite a WordPiece tokenizer model in place, keeping only *keep_ids*."""
    keep = set(keep_ids)
    model_section = tok_json["model"]
    model_section["vocab"] = {
        tok: old_to_new[tid] for tok, tid in model_section["vocab"].items() if tid in keep
    }
    return tok_json
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _prune_embeddings(model, keep_ids):
    """Replace the input embedding table with rows for *keep_ids* only.

    Special-token ids in the config are remapped (or cleared when the token
    was dropped), the embedding's padding index is translated, and
    ``config.vocab_size`` is updated to the new table size. Mutates and
    returns *model*.
    """
    embedding = model.get_input_embeddings()
    weights = embedding.weight.data
    remap = {src: dst for dst, src in enumerate(keep_ids)}

    # Translate every special-token id the config tracks; pruned ids -> None.
    special_attrs = ("pad_token_id", "bos_token_id", "eos_token_id",
                     "cls_token_id", "sep_token_id", "unk_token_id",
                     "mask_token_id", "decoder_start_token_id")
    for attr in special_attrs:
        current = getattr(model.config, attr, None)
        if current is not None:
            setattr(model.config, attr, remap.get(current))

    pad_idx = getattr(embedding, 'padding_idx', None)
    if pad_idx is not None:
        pad_idx = remap.get(pad_idx)

    pruned = nn.Embedding(len(keep_ids), weights.shape[1], padding_idx=pad_idx)
    for dst, src in enumerate(keep_ids):
        if src < weights.shape[0]:
            pruned.weight.data[dst] = weights[src]

    model.set_input_embeddings(pruned)
    model.config.vocab_size = len(keep_ids)
    return model
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _copy_custom_code_files(model, target_dir):
    """Copy a remote-code model's *.py files from the HF cache into *target_dir*.

    Models loaded with trust_remote_code keep their modeling code as .py
    files inside the Hugging Face hub cache
    (``~/.cache/huggingface/hub/models--org--name/snapshots/<rev>/``); a
    re-saved checkpoint needs those files next to its config to stay
    loadable. Existing destination files are never overwritten. No-op when
    the model has no ``_name_or_path`` or no cache entry exists.
    """
    import shutil
    import glob as _glob

    source_path = getattr(model.config, '_name_or_path', None)
    if not source_path:
        return

    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
    model_cache_dir = os.path.join(cache_dir, "models--" + source_path.replace("/", "--"))

    snapshots_dir = os.path.join(model_cache_dir, "snapshots")
    if not os.path.isdir(snapshots_dir):
        return
    snapshots = os.listdir(snapshots_dir)
    if not snapshots:
        return

    # BUGFIX: os.listdir order is arbitrary, so snapshots[-1] was not "the
    # latest snapshot" — pick the most recently modified one instead.
    latest = max(
        (os.path.join(snapshots_dir, name) for name in snapshots),
        key=os.path.getmtime,
    )
    for py_file in _glob.glob(os.path.join(latest, "*.py")):
        dest = os.path.join(target_dir, os.path.basename(py_file))
        if not os.path.exists(dest):
            shutil.copy2(py_file, dest)
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _get_default_multilingual_samples() -> list[str]:
    """Return a small built-in multilingual corpus (KO/EN/JA/ZH/ES/FR/DE).

    Three short conversational sentences per language, repeated 10x so batch
    tokenization has enough volume for frequency counting.
    """
    base = [
        "예약 좀 해줘", "지난번 주문 뭐였지?", "안녕하세요 반갑습니다",
        "Book a table for me", "What did I order last time?", "Hello how are you",
        "予約をお願いします", "前回の注文は何でしたか", "こんにちは元気ですか",
        "帮我预约一下", "上次我点了什么", "你好你好吗",
        "Reserva una mesa", "Qué pedí la última vez", "Hola cómo estás",
        "Réservez une table", "Qu'est-ce que j'ai commandé", "Bonjour comment allez-vous",
        "Reservieren Sie einen Tisch", "Was habe ich bestellt", "Hallo wie geht es",
    ]
    return base * 10
|