EasySmallEmbeddingModel 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
easysmallembeddingmodel-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,161 @@
Metadata-Version: 2.4
Name: EasySmallEmbeddingModel
Version: 0.1.0
Summary: Compress large embedding models into small, fast students via layer pruning, vocab pruning, hidden dim reduction, and knowledge distillation.
Author: gomyk
License: Apache-2.0
Keywords: embedding,model-compression,distillation,pruning,sentence-transformers
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0.0
Requires-Dist: transformers>=4.40.0
Requires-Dist: sentence-transformers>=3.0.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: tqdm
Requires-Dist: sentencepiece
Requires-Dist: protobuf
Provides-Extra: eval
Requires-Dist: mteb>=1.14.0; extra == "eval"
Requires-Dist: pandas>=2.0.0; extra == "eval"
Provides-Extra: export
Requires-Dist: onnxruntime>=1.16.0; extra == "export"
Requires-Dist: onnx>=1.15.0; extra == "export"
Requires-Dist: optimum[onnxruntime]>=1.16.0; extra == "export"
Provides-Extra: hub
Requires-Dist: huggingface-hub>=0.20.0; extra == "hub"
Provides-Extra: web
Requires-Dist: flask>=3.0.0; extra == "web"
Provides-Extra: all
Requires-Dist: EasySmallEmbeddingModel[eval,export,hub,web]; extra == "all"

# SmallModel

Compress large embedding models into small, fast students via layer pruning, vocab pruning, hidden dim reduction, and knowledge distillation.

## Features

- **Layer Pruning** - Select which transformer layers to keep
- **Vocab Pruning** - Remove unused tokens based on corpus frequency
- **Hidden Dim Reduction** - Shrink internal dimensions (slicing or PCA)
- **Knowledge Distillation** - MSE + Cosine loss alignment with teacher
- **Auto Compress** - Find optimal config within size constraints
- **2-Stage Distillation** - Progressive distillation for 10x+ compression
- **Interactive Web UI** - Visual layer editor with real-time size estimation
- **MTEB Evaluation** - Benchmark on Classification, Clustering, STS tasks

## Installation

```bash
pip install "EasySmallEmbeddingModel[all]"
```

Or install specific extras (the distribution name on PyPI is `EasySmallEmbeddingModel`; the import and CLI name is `smallmodel`):

```bash
pip install EasySmallEmbeddingModel            # core only
pip install "EasySmallEmbeddingModel[web]"     # + Flask web UI
pip install "EasySmallEmbeddingModel[eval]"    # + MTEB evaluation
pip install "EasySmallEmbeddingModel[export]"  # + ONNX export
pip install "EasySmallEmbeddingModel[hub]"     # + HuggingFace Hub upload
```

For development:

```bash
git clone https://github.com/gomyk/smallmodel.git
cd smallmodel
pip install -e ".[all]"
```

## Quick Start

### Python API

```python
from smallmodel import SmallModel

# Auto-compress within 50MB
sm = SmallModel.from_teacher("gte")
sm.compress(max_fp32_mb=50.0)
sm.distill(epochs=10)

# Manual layer selection
sm = SmallModel.from_teacher("gte", layer_indices=[0, 3, 6, 11])
sm.create()

# Register custom teacher
from smallmodel import register_teacher
register_teacher(
    "my-bert",
    model_id="my-org/my-bert-base",
    short_name="MyBERT",
    hidden_dim=768, num_layers=12,
    intermediate_size=3072, vocab_size=30522,
)
```

### Web UI

```python
from smallmodel import SmallModel

sm = SmallModel.from_teacher("gte")
sm.serve()  # http://127.0.0.1:7860
```

Or via CLI:

```bash
smallmodel serve --teacher gte --port 7860
```

The web UI lets you:
- Select a teacher model from 7+ pre-registered models
- Toggle layers on/off with preset configurations
- Adjust hidden dim, FFN size, and vocab size
- See real-time size estimation and compression ratio
- Select distillation datasets and evaluation tasks
- Analyze vocab coverage at different vocab sizes
- Create compressed models with one click

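The vocab-coverage analysis boils down to a cumulative-frequency question: given per-token corpus counts, what fraction of all token occurrences does a top-N vocabulary cover? A minimal sketch (function and variable names here are illustrative, not the package's API):

```python
from collections import Counter

def vocab_coverage(freq: Counter, vocab_size: int) -> float:
    """Fraction of corpus token occurrences covered by the top-N tokens."""
    total = sum(freq.values())
    covered = sum(count for _, count in freq.most_common(vocab_size))
    return covered / total

freq = Counter({"the": 50, "cat": 30, "sat": 15, "mat": 5})
vocab_coverage(freq, 2)  # 0.8 -- the two most frequent tokens cover 80%
```

This mirrors the `vocab_keep_ratio` logic in `smallmodel/arch.py`, which keeps the most frequent tokens until a target cumulative frequency is reached.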
### CLI

```bash
smallmodel list-teachers
smallmodel compress --teacher gte --max-mb 50
smallmodel create --teacher gte --layers 0,3,6,11
smallmodel distill --teacher gte --student output/students/gte/gte_compressed
smallmodel serve --teacher gte
```

## Pre-registered Teachers

| Key | Model | Layers | Hidden | Vocab | FP32 MB |
|---|---|---|---|---|---|
| minilm | paraphrase-multilingual-MiniLM-L12-v2 | 12 | 384 | 250K | 448 |
| modernbert | ModernBERT-base | 22 | 768 | 50K | 496 |
| gte | gte-multilingual-base | 12 | 768 | 250K | 1058 |
| me5 | multilingual-e5-base | 12 | 768 | 250K | 1058 |
| me5s | multilingual-e5-small | 12 | 384 | 250K | 448 |
| gemma_emb | embeddinggemma-300m | 24 | 768 | 262K | 1155 |
| qwen3 | Qwen3-0.6B | 28 | 1024 | 152K | 2274 |
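The FP32 MB column follows from a back-of-envelope parameter count: the embedding table plus, per layer, roughly 4·h² attention weights and 2·h·ffn feed-forward weights, at 4 bytes per weight. This approximation is mine (the package's own sizing code lives in `smallmodel/sizing.py` and may differ), and it ignores biases and LayerNorm weights, so it slightly undershoots:

```python
def approx_fp32_mb(vocab: int, hidden: int, layers: int, ffn: int) -> float:
    # embeddings + per-layer attention (~4*h^2) + feed-forward (~2*h*ffn) weights
    params = vocab * hidden + layers * (4 * hidden**2 + 2 * hidden * ffn)
    return params * 4 / 2**20  # 4 bytes per fp32 weight, reported in MiB

approx_fp32_mb(250_000, 768, 12, 3072)  # ~1056, close to the table's 1058 for gte
approx_fp32_mb(250_000, 384, 12, 1536)  # ~447, close to the table's 448 for minilm
```

The dominant term for the multilingual teachers is the 250K-row embedding table, which is why vocab pruning is so effective for them.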

## How It Works

1. **Layer Pruning** - Copy selected layers from teacher (uniform spacing recommended)
2. **Hidden Dim Reduction** - Shrink dimensions if needed to meet size target
3. **Vocab Pruning** - Remove tokens not seen in training corpus
4. **Knowledge Distillation** - Train student to reproduce teacher's embeddings
5. **Evaluation** - MTEB benchmark (Classification, Clustering, STS)

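The "uniform spacing" in step 1 can be sketched as picking evenly spaced layer indices that always keep the first and last layer. This helper is hypothetical, not the package's own selection code; the Quick Start's `[0, 3, 6, 11]` is a similar near-uniform pick:

```python
def uniform_layer_indices(teacher_layers: int, student_layers: int) -> list[int]:
    # evenly spaced picks across the teacher stack, keeping layer 0 and the last layer
    if student_layers <= 1:
        return [teacher_layers - 1]
    step = (teacher_layers - 1) / (student_layers - 1)
    return [round(i * step) for i in range(student_layers)]

uniform_layer_indices(12, 4)  # [0, 4, 7, 11]
```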
For compression ratios > 10x, a 2-stage distillation pipeline is used:
Teacher → Intermediate (~1/5 teacher) → Final Student
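The "MSE + Cosine" objective from the feature list combines a magnitude term with a direction term. A minimal torch sketch of one plausible combination (the equal weighting `alpha=0.5` is an assumption, not the package's default):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_emb: torch.Tensor, teacher_emb: torch.Tensor,
                 alpha: float = 0.5) -> torch.Tensor:
    # MSE pulls the student embedding toward the teacher's values;
    # the cosine term (1 - cos similarity) aligns directions even when norms differ
    mse = F.mse_loss(student_emb, teacher_emb)
    cos = 1.0 - F.cosine_similarity(student_emb, teacher_emb, dim=-1).mean()
    return alpha * mse + (1.0 - alpha) * cos

emb = torch.randn(8, 384)
distill_loss(emb, emb)  # ~0 when the student matches the teacher exactly
```

In the 2-stage pipeline the same objective applies twice: the intermediate model distills from the teacher, then the final student distills from the intermediate.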

## License

Apache-2.0
easysmallembeddingmodel-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
smallmodel/__init__.py,sha256=GgJw-ldxMcQBOv4200CmXoMrQ7jGvDlMys3esMHuO_E,257
smallmodel/arch.py,sha256=rUusfhJ2H8-FgQ1MseYAoy-jTuZm2xQgRpvbdK2H-_c,17205
smallmodel/cli.py,sha256=LhSbl_MyJYhs_l32B6gbT5teyT87D7a-0kwYhodg6NY,5341
smallmodel/core.py,sha256=_Z9wsEWlbEnQS2u9ppuWmHBM-bU70xrrtztBh7Rf79w,19222
smallmodel/data.py,sha256=udjwpEe0MH1TTYTAMy533pDmh3ASqVDbaRv0xNIDTYU,5734
smallmodel/distill.py,sha256=sXdV-3n4y_3At2i0qQUV8xcIFet9NMmszrn0RRVeJ-U,7739
smallmodel/sizing.py,sha256=RyrnJWHtWy7Ycl9pZurU6lEqu6Cx18GmqOuO-k3i7e8,5265
smallmodel/teachers.py,sha256=gRbDMyVAB_wcWaaE2t4BhEEY_mVekrUIxECEq1vAmHU,4474
smallmodel/web/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
smallmodel/web/app.py,sha256=Hy8r4cHGR6xqS6stry93TJYHCCnAXx5PQrn9LYRRMxg,20721
smallmodel/web/static/css/style.css,sha256=6t3wLF4wrBmZfKnX-Ex3PMt_YaTWVpnkFtSEmC1TDS0,15101
smallmodel/web/static/js/app.js,sha256=gpkOGrtzjtQFtx5lhFEwKRcX6sfH3g5Q82vHzwd7d9g,22382
smallmodel/web/templates/index.html,sha256=WciG04lmVwrLkC5VcuRQVZT0n3PsIq-9H0h-pwC0ZUc,16366
easysmallembeddingmodel-0.1.0.dist-info/METADATA,sha256=Br9F10VeA1SEkXKNnKsC2bX03edP_bhslYpBirACvIg,5227
easysmallembeddingmodel-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
easysmallembeddingmodel-0.1.0.dist-info/entry_points.txt,sha256=HNo0Be_rqqYaeULejKSs-AZ5QDchylEn9vbZMRK_cQ8,51
easysmallembeddingmodel-0.1.0.dist-info/top_level.txt,sha256=Spct7pFIJwbc3SqefL9ZYf6muQQk-IGYp8Jv5qkcnHM,11
easysmallembeddingmodel-0.1.0.dist-info/RECORD,,
easysmallembeddingmodel-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (82.0.1)
Root-Is-Purelib: true
Tag: py3-none-any

easysmallembeddingmodel-0.1.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
[console_scripts]
smallmodel = smallmodel.cli:main
easysmallembeddingmodel-0.1.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
smallmodel
smallmodel/__init__.py ADDED
@@ -0,0 +1,7 @@
"""SmallModel - Compress large embedding models into small, fast students."""

from smallmodel.core import SmallModel
from smallmodel.teachers import TEACHERS, register_teacher

__version__ = "0.1.0"
__all__ = ["SmallModel", "TEACHERS", "register_teacher"]
smallmodel/arch.py ADDED
@@ -0,0 +1,441 @@
"""Architecture utilities: layer pruning, hidden dim reduction, vocab pruning."""

from __future__ import annotations

import copy
import json
import os
from collections import Counter

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoConfig


# ── Layer Access ─────────────────────────────────────────────────

def get_layers(model, layer_accessor: str):
    obj = model
    for attr in layer_accessor.split("."):
        obj = getattr(obj, attr)
    return obj


def set_layers(model, layer_accessor: str, new_layers):
    parts = layer_accessor.split(".")
    obj = model
    for attr in parts[:-1]:
        obj = getattr(obj, attr)
    setattr(obj, parts[-1], new_layers)


def discover_layer_accessor(model) -> str:
    candidates = [
        "encoder.layer", "encoder.layers", "layers",
        "transformer.layer", "transformer.layers",
    ]
    for path in candidates:
        try:
            layers = get_layers(model, path)
            if isinstance(layers, nn.ModuleList) and len(layers) > 0:
                return path
        except AttributeError:
            continue
    raise ValueError(f"Could not detect layer accessor for {type(model).__name__}")


# ── Layer Pruning ────────────────────────────────────────────────

def prune_layers(model, layer_indices: list[int], layer_accessor: str | None = None):
    if layer_accessor is None:
        layer_accessor = discover_layer_accessor(model)
    layers = get_layers(model, layer_accessor)
    kept = nn.ModuleList([layers[i] for i in layer_indices])
    set_layers(model, layer_accessor, kept)
    model.config.num_hidden_layers = len(layer_indices)
    return model


def create_pruned_student(teacher_model_id: str, layer_indices: list[int],
                          layer_accessor: str | None = None,
                          trust_remote_code: bool = False):
    """Load teacher and prune layers to create a student.

    Returns (student_model, tokenizer).
    """
    model = AutoModel.from_pretrained(teacher_model_id, trust_remote_code=trust_remote_code)
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_id, trust_remote_code=trust_remote_code)

    if layer_accessor is None:
        layer_accessor = discover_layer_accessor(model)

    student = copy.deepcopy(model)
    student = prune_layers(student, layer_indices, layer_accessor)
    return student, tokenizer


# ── Hidden Dimension Reduction ───────────────────────────────────

def reduce_hidden_dim(model, new_hidden_dim: int, new_intermediate_size: int | None = None,
                      trust_remote_code: bool = False):
    old_hidden = model.config.hidden_size
    if new_hidden_dim >= old_hidden:
        return model

    old_inter = getattr(model.config, 'intermediate_size', old_hidden * 4)
    if new_intermediate_size is None:
        ratio = new_hidden_dim / old_hidden
        new_intermediate_size = max(64, (int(old_inter * ratio) // 64) * 64)

    new_config = copy.deepcopy(model.config)
    new_config.hidden_size = new_hidden_dim
    new_config.intermediate_size = new_intermediate_size

    ratio = new_hidden_dim / old_hidden
    old_n_kv = getattr(new_config, 'num_key_value_heads', None)

    # Read the head count once so it is defined even when the config only
    # carries num_key_value_heads (avoids a NameError below).
    n_heads = getattr(new_config, 'num_attention_heads', None)
    if n_heads is not None:
        n_heads = max(1, int(n_heads * ratio))
        while new_hidden_dim % n_heads != 0 and n_heads > 1:
            n_heads -= 1
        new_config.num_attention_heads = n_heads

    if old_n_kv is not None and n_heads is not None:
        n_kv = max(1, int(old_n_kv * ratio))
        while n_kv > 1 and (n_heads % n_kv != 0 or new_hidden_dim % n_kv != 0):
            n_kv -= 1
        new_config.num_key_value_heads = n_kv

    if hasattr(new_config, 'head_dim') and new_config.head_dim is not None:
        new_heads = getattr(new_config, 'num_attention_heads', 1)
        new_config.head_dim = new_hidden_dim // new_heads

    new_model = AutoModel.from_config(new_config, trust_remote_code=trust_remote_code)

    old_sd = model.state_dict()
    new_sd = new_model.state_dict()

    for key in new_sd:
        if key not in old_sd:
            continue
        old_t = old_sd[key]
        new_t = new_sd[key]
        if old_t.shape == new_t.shape:
            new_sd[key] = old_t.clone()
        else:
            slices = tuple(
                slice(0, min(s_new, s_old))
                for s_new, s_old in zip(new_t.shape, old_t.shape)
            )
            sliced = old_t[slices]
            if sliced.shape == new_t.shape:
                new_sd[key] = sliced.clone()
            else:
                target_slices = tuple(slice(0, s) for s in sliced.shape)
                new_sd[key][target_slices] = sliced.clone()

    new_model.load_state_dict(new_sd)

    for attr in ["pad_token_id", "bos_token_id", "eos_token_id",
                 "cls_token_id", "sep_token_id", "unk_token_id", "mask_token_id"]:
        old_val = getattr(model.config, attr, None)
        if old_val is not None:
            setattr(new_model.config, attr, old_val)

    return new_model


# ── Tokenizer Type Detection ────────────────────────────────────

def detect_tokenizer_type(tokenizer) -> str:
    tok_json = json.loads(tokenizer.backend_tokenizer.to_str())
    return tok_json["model"]["type"]


# ── Vocab Pruning ────────────────────────────────────────────────

def collect_corpus_tokens(tokenizer, texts: list[str] | None = None,
                          max_vocab: int | None = None,
                          vocab_keep_ratio: float | None = None) -> list[int]:
    """Collect tokens used in a corpus, returning sorted IDs to keep."""
    if texts is None:
        texts = _get_default_multilingual_samples()

    freq = Counter()
    batch_size = 1000
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        encoded = tokenizer(batch, add_special_tokens=True, truncation=True, max_length=128)
        for ids in encoded["input_ids"]:
            freq.update(ids)

    keep_ids = set(tokenizer.all_special_ids)

    basic_chars = list("0123456789.,!?;:'\"-()[]{}/@#$%^&*+=<>~_ \t\n")
    for ch in basic_chars:
        ids = tokenizer.encode(ch, add_special_tokens=False)
        keep_ids.update(ids)

    tok_type = detect_tokenizer_type(tokenizer)
    if tok_type == "BPE":
        tok_json = json.loads(tokenizer.backend_tokenizer.to_str())
        vocab = tok_json["model"]["vocab"]
        for token, tid in vocab.items():
            if tid < 256 or len(token) <= 1:
                keep_ids.add(tid)

    if vocab_keep_ratio is not None:
        total_freq = sum(freq.values())
        target_freq = total_freq * vocab_keep_ratio
        corpus_tokens = sorted(freq.keys(), key=lambda t: freq[t], reverse=True)
        cumsum = 0
        for tid in corpus_tokens:
            keep_ids.add(tid)
            cumsum += freq[tid]
            if cumsum >= target_freq:
                break
    elif max_vocab is not None:
        remaining = max_vocab - len(keep_ids)
        if remaining > 0:
            for tid, _ in freq.most_common():
                if tid not in keep_ids:
                    keep_ids.add(tid)
                    if len(keep_ids) >= max_vocab:
                        break
    else:
        keep_ids.update(freq.keys())

    return sorted(keep_ids)


def prune_tokenizer_and_embeddings(model, tokenizer, keep_ids: list[int], save_dir: str):
    """Prune tokenizer vocab and model embeddings simultaneously."""
    os.makedirs(save_dir, exist_ok=True)
    old_to_new = {old_id: new_id for new_id, old_id in enumerate(keep_ids)}

    tok_json = json.loads(tokenizer.backend_tokenizer.to_str())
    model_type = tok_json["model"]["type"]

    if model_type == "Unigram":
        tok_json = _prune_unigram(tok_json, keep_ids, old_to_new)
    elif model_type == "BPE":
        tok_json = _prune_bpe(tok_json, keep_ids, old_to_new)
    elif model_type == "WordPiece":
        tok_json = _prune_wordpiece(tok_json, keep_ids, old_to_new)
    else:
        tokenizer.save_pretrained(save_dir)
        model = _prune_embeddings(model, keep_ids)
        return model

    if "added_tokens" in tok_json:
        new_added = []
        for at in tok_json["added_tokens"]:
            old_id = at["id"]
            if old_id in old_to_new:
                at["id"] = old_to_new[old_id]
                new_added.append(at)
        tok_json["added_tokens"] = new_added

    pp = tok_json.get("post_processor")
    if pp and "special_tokens" in pp:
        for token_name, token_info in pp["special_tokens"].items():
            if "ids" in token_info:
                token_info["ids"] = [
                    old_to_new[oid] for oid in token_info["ids"] if oid in old_to_new
                ]

    tokenizer.save_pretrained(save_dir)

    tok_json_path = os.path.join(save_dir, "tokenizer.json")
    with open(tok_json_path, "w", encoding="utf-8") as f:
        json.dump(tok_json, f, ensure_ascii=False)

    added_tokens_path = os.path.join(save_dir, "added_tokens.json")
    if os.path.exists(added_tokens_path):
        with open(added_tokens_path, "r", encoding="utf-8") as f:
            added_tokens = json.load(f)
        new_added_tokens = {}
        for token_str, old_id in added_tokens.items():
            if old_id in old_to_new:
                new_added_tokens[token_str] = old_to_new[old_id]
        with open(added_tokens_path, "w", encoding="utf-8") as f:
            json.dump(new_added_tokens, f, ensure_ascii=False)

    tok_config_path = os.path.join(save_dir, "tokenizer_config.json")
    if os.path.exists(tok_config_path):
        with open(tok_config_path, "r", encoding="utf-8") as f:
            tok_config = json.load(f)
        if "added_tokens_decoder" in tok_config:
            new_decoder = {}
            for old_id_str, token_info in tok_config["added_tokens_decoder"].items():
                old_id = int(old_id_str)
                if old_id in old_to_new:
                    new_decoder[str(old_to_new[old_id])] = token_info
            tok_config["added_tokens_decoder"] = new_decoder
        with open(tok_config_path, "w", encoding="utf-8") as f:
            json.dump(tok_config, f, ensure_ascii=False, indent=2)

    model = _prune_embeddings(model, keep_ids)
    return model


def save_as_sentence_transformer(model, tokenizer, save_path: str):
    """Save HF model as SentenceTransformer format."""
    from sentence_transformers import SentenceTransformer, models as st_models
    import shutil

    hf_tmp = os.path.join(save_path, "_hf_tmp")
    os.makedirs(hf_tmp, exist_ok=True)
    model.save_pretrained(hf_tmp)
    tokenizer.save_pretrained(hf_tmp)

    config_path = os.path.join(hf_tmp, "config.json")
    is_custom = False
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        is_custom = "auto_map" in config

    if is_custom:
        _copy_custom_code_files(model, hf_tmp)
    else:
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            config.pop("_name_or_path", None)
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2, ensure_ascii=False)

    os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1"
    word_model = st_models.Transformer(
        hf_tmp,
        config_args={"trust_remote_code": True},
        model_args={"trust_remote_code": True},
        tokenizer_args={"trust_remote_code": True},
    )
    pool_model = st_models.Pooling(
        word_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
    )
    st_model = SentenceTransformer(modules=[word_model, pool_model])
    st_model.save(save_path)
    shutil.rmtree(hf_tmp, ignore_errors=True)
    return st_model


# ── Internal helpers ─────────────────────────────────────────────

def _prune_unigram(tok_json, keep_ids, old_to_new):
    old_vocab = tok_json["model"]["vocab"]
    new_vocab = []
    for old_id in keep_ids:
        if old_id < len(old_vocab):
            new_vocab.append(old_vocab[old_id])
    tok_json["model"]["vocab"] = new_vocab
    old_unk_id = tok_json["model"].get("unk_id")
    if old_unk_id is not None and old_unk_id in old_to_new:
        tok_json["model"]["unk_id"] = old_to_new[old_unk_id]
    return tok_json


def _prune_bpe(tok_json, keep_ids, old_to_new):
    old_vocab = tok_json["model"]["vocab"]
    keep_ids_set = set(keep_ids)
    new_vocab = {}
    for token, old_id in old_vocab.items():
        if old_id in keep_ids_set:
            new_vocab[token] = old_to_new[old_id]
    tok_json["model"]["vocab"] = new_vocab
    kept_tokens = set(new_vocab.keys())
    if "merges" in tok_json["model"]:
        new_merges = []
        for merge in tok_json["model"]["merges"]:
            parts = merge if isinstance(merge, list) else merge.split(" ")
            if len(parts) == 2:
                merged_token = parts[0] + parts[1]
                if (parts[0] in kept_tokens and parts[1] in kept_tokens
                        and merged_token in kept_tokens):
                    new_merges.append(merge)
        tok_json["model"]["merges"] = new_merges
    return tok_json


def _prune_wordpiece(tok_json, keep_ids, old_to_new):
    old_vocab = tok_json["model"]["vocab"]
    keep_ids_set = set(keep_ids)
    new_vocab = {}
    for token, old_id in old_vocab.items():
        if old_id in keep_ids_set:
            new_vocab[token] = old_to_new[old_id]
    tok_json["model"]["vocab"] = new_vocab
    return tok_json


def _prune_embeddings(model, keep_ids):
    old_emb = model.get_input_embeddings()
    old_weight = old_emb.weight.data
    new_vocab_size = len(keep_ids)
    old_to_new = {old_id: new_id for new_id, old_id in enumerate(keep_ids)}

    for attr in ["pad_token_id", "bos_token_id", "eos_token_id",
                 "cls_token_id", "sep_token_id", "unk_token_id",
                 "mask_token_id", "decoder_start_token_id"]:
        old_id = getattr(model.config, attr, None)
        if old_id is not None:
            if old_id in old_to_new:
                setattr(model.config, attr, old_to_new[old_id])
            else:
                setattr(model.config, attr, None)

    padding_idx = getattr(old_emb, 'padding_idx', None)
    if padding_idx is not None:
        padding_idx = old_to_new.get(padding_idx, None)
    new_emb = nn.Embedding(new_vocab_size, old_weight.shape[1], padding_idx=padding_idx)
    for new_id, old_id in enumerate(keep_ids):
        if old_id < old_weight.shape[0]:
            new_emb.weight.data[new_id] = old_weight[old_id]

    model.set_input_embeddings(new_emb)
    model.config.vocab_size = new_vocab_size
    return model


def _copy_custom_code_files(model, target_dir):
    import shutil
    import glob as _glob

    source_path = getattr(model.config, '_name_or_path', None)
    if not source_path:
        return

    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
    model_cache_name = "models--" + source_path.replace("/", "--")
    model_cache_dir = os.path.join(cache_dir, model_cache_name)

    if os.path.exists(model_cache_dir):
        snapshots_dir = os.path.join(model_cache_dir, "snapshots")
        if os.path.exists(snapshots_dir):
            snapshots = os.listdir(snapshots_dir)
            if snapshots:
                # pick the most recently modified snapshot; os.listdir order is arbitrary
                latest = max(
                    (os.path.join(snapshots_dir, s) for s in snapshots),
                    key=os.path.getmtime,
                )
                for py_file in _glob.glob(os.path.join(latest, "*.py")):
                    fname = os.path.basename(py_file)
                    dest = os.path.join(target_dir, fname)
                    if not os.path.exists(dest):
                        shutil.copy2(py_file, dest)


def _get_default_multilingual_samples() -> list[str]:
    return [
        "예약 좀 해줘", "지난번 주문 뭐였지?", "안녕하세요 반갑습니다",
        "Book a table for me", "What did I order last time?", "Hello how are you",
        "予約をお願いします", "前回の注文は何でしたか", "こんにちは元気ですか",
        "帮我预约一下", "上次我点了什么", "你好你好吗",
        "Reserva una mesa", "Qué pedí la última vez", "Hola cómo estás",
        "Réservez une table", "Qu'est-ce que j'ai commandé", "Bonjour comment allez-vous",
        "Reservieren Sie einen Tisch", "Was habe ich bestellt", "Hallo wie geht es",
    ] * 10