mediapipe-nightly 0.10.10.post20240216__cp311-cp311-macosx_11_0_universal2.whl → 0.10.10.post20240220__cp311-cp311-macosx_11_0_universal2.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (22) hide show
  1. mediapipe/__init__.py +1 -1
  2. mediapipe/python/_framework_bindings.cpython-311-darwin.so +0 -0
  3. mediapipe/tasks/python/__init__.py +1 -0
  4. mediapipe/tasks/python/genai/__init__.py +14 -0
  5. mediapipe/tasks/python/genai/converter/__init__.py +24 -0
  6. mediapipe/tasks/python/genai/converter/converter_base.py +172 -0
  7. mediapipe/tasks/python/genai/converter/converter_factory.py +79 -0
  8. mediapipe/tasks/python/genai/converter/llm_converter.py +213 -0
  9. mediapipe/tasks/python/genai/converter/pytorch_converter.py +315 -0
  10. mediapipe/tasks/python/genai/converter/pytorch_converter_test.py +86 -0
  11. mediapipe/tasks/python/genai/converter/quantization_util.py +516 -0
  12. mediapipe/tasks/python/genai/converter/quantization_util_test.py +259 -0
  13. mediapipe/tasks/python/genai/converter/safetensors_converter.py +521 -0
  14. mediapipe/tasks/python/genai/converter/safetensors_converter_test.py +83 -0
  15. mediapipe/tasks/python/genai/converter/weight_bins_writer.py +111 -0
  16. mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py +62 -0
  17. mediapipe/version.txt +1 -1
  18. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/METADATA +1 -1
  19. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/RECORD +21 -8
  20. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/LICENSE +0 -0
  21. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/WHEEL +0 -0
  22. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,521 @@
1
+ # Copyright 2024 The MediaPipe Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """CkptLoader implementation for loading the Safetensors."""
16
+
17
+ import array
18
+ import enum
19
+ import glob
20
+ import json
21
+ import os
22
+ from typing import List, Optional
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from mediapipe.tasks.python.genai.converter import converter_base
28
+
29
+
30
# Safetensors header dtype tags that this converter can decode, keyed to the
# torch dtype used to interpret the raw tensor bytes.
DTYPE_MAP = dict(
    F16=torch.float16,
    BF16=torch.bfloat16,
    F32=torch.float32,
)
35
+
36
+
37
+ class _SafetensorsShardReader:
38
+ """Reads a single safetensors shard."""
39
+
40
+ _HEAD_BYTES = 8
41
+
42
+ def __init__(self, shard_path: str):
43
+ self._shard_path = shard_path
44
+ if not os.path.exists(self._shard_path):
45
+ raise ValueError(f"{self._shard_path} does not exists.")
46
+ with open(self._shard_path, "rb") as f:
47
+ head_bytes = f.read(self._HEAD_BYTES)
48
+ metadata_bytes_num = np.frombuffer(head_bytes, dtype=np.uint64)[0]
49
+ metadata_bytes = f.read(metadata_bytes_num)
50
+ self.layers_info = json.loads(metadata_bytes)
51
+ self.metadata_bytes_num = metadata_bytes_num
52
+
53
+ def read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
54
+ """Reads a tensor from the model file as a numpy array with np.float32 type."""
55
+ tensor_info = self.layers_info[tensor_name]
56
+ with open(self._shard_path, "rb") as f:
57
+ shape = tensor_info["shape"]
58
+ dtype = tensor_info["dtype"]
59
+ if dtype not in DTYPE_MAP:
60
+ raise ValueError(f"{dtype} is not supported.")
61
+ data_offsets = tensor_info["data_offsets"]
62
+ f.seek(int(self._HEAD_BYTES + self.metadata_bytes_num + data_offsets[0]))
63
+ tensor_bytes = f.read(data_offsets[1] - data_offsets[0])
64
+ raw_tensor = torch.frombuffer(
65
+ array.array("b", tensor_bytes), dtype=DTYPE_MAP[dtype]
66
+ ).reshape(shape)
67
+ return raw_tensor.float().t().contiguous().numpy()
68
+
69
+ def get_tensor_names(self) -> List[str]:
70
+ names = list(self.layers_info.keys())
71
+ if "__metadata__" in names:
72
+ names.remove("__metadata__")
73
+ return names
74
+
75
+
76
+ class _SafetensorsReader:
77
+ """Reads all the safetensors shards."""
78
+
79
+ def __init__(self, ckpt_path: str):
80
+ shards = []
81
+ if os.path.isdir(ckpt_path):
82
+ # Read all safetensors files within checkpoint
83
+ for shard_path in glob.glob(os.path.join(ckpt_path, "*.safetensors")):
84
+ shards.append(_SafetensorsShardReader(shard_path))
85
+ else:
86
+ # Assume the ckpt_path is a file or a file pattern to match.
87
+ for shard_path in glob.glob(ckpt_path):
88
+ shards.append(_SafetensorsShardReader(shard_path))
89
+ assert shards is not None
90
+
91
+ self._ckpt_path = ckpt_path
92
+ self._tensors_map = {}
93
+ for shard in shards:
94
+ tensor_names = shard.get_tensor_names()
95
+ for tensor_name in tensor_names:
96
+ if tensor_name in self._tensors_map:
97
+ raise ValueError(f"Duplicate tensor name: {tensor_name}")
98
+ self._tensors_map[tensor_name] = shard
99
+
100
+ def get_tensor_names(self) -> List[str]:
101
+ return list(self._tensors_map.keys())
102
+
103
+ def read_tensor_as_numpy(self, tensor_name: str) -> np.ndarray:
104
+ return self._tensors_map[tensor_name].read_tensor_as_numpy(tensor_name)
105
+
106
+
107
class LayerType(enum.Enum):
  """Coarse category of a checkpoint layer, derived from its name."""

  NONE = 0
  ATTENTION = 1  # Layer is part of the attention module.
  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
  LAYER_NORM = 4  # Layer is layer normalization before and after attention.

  @classmethod
  def get_layer_type(cls, layer_name: str):
    """Classifies `layer_name` by the first category whose marker it contains.

    Categories are tried in the order attention, feedforward, embedding,
    layer norm; a name matching none of them is `LayerType.NONE`.

    Args:
      layer_name: The tensor/layer name to classify.

    Returns:
      The matching `LayerType` member.
    """
    # Order matters: an earlier row wins when a name matches several rows.
    markers_by_type = (
        (LayerType.ATTENTION, ("self_attn",)),
        (LayerType.FEEDFORWARD, ("mlp",)),
        (LayerType.EMBEDDING, ("embed_tokens", "lm_head")),
        (
            LayerType.LAYER_NORM,
            (
                "input_layernorm",
                "post_attention_layernorm",
                "final_layernorm",
                "model.norm.weight",
            ),
        ),
    )
    for candidate, markers in markers_by_type:
      for marker in markers:
        if marker in layer_name:
          return candidate
    return LayerType.NONE
147
+
148
+
149
class StablelmMapper(converter_base.LayerActionMapperBase):
  """Maps StableLM checkpoint tensors to quantization actions."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    """Stores the quantization settings and the reader used to load tensors.

    Args:
      is_symmetric: Forwarded to the base mapper.
      attention_quant_bits: Quantization bits used for attention weights.
      feedforward_quant_bits: Quantization bits used for feedforward weights.
      embedding_quant_bits: Quantization bits used for embedding weights.
      backend: Target backend name; "cpu" enables the CPU-specific handling.
      reader: Source of the tensor values.
    """
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Builds the single quantization action for `layer_name`.

    Args:
      layer_name: Name of the tensor in the checkpoint.

    Returns:
      A one-element list holding the action for this tensor.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    layer_type = LayerType.get_layer_type(layer_name)
    quantize_axis = None
    quantize_bits = None

    # Only ".weight" tensors outside layer norms get a quantization axis.
    if layer_name.endswith(".weight") and layer_type != LayerType.LAYER_NORM:
      quantize_axis = [0]
      transpose_for_cpu = False
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        transpose_for_cpu = (
            self._backend == "cpu" and ".o_proj." in layer_name
        )
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
        transpose_for_cpu = (
            self._backend == "cpu" and ".embed_tokens." in layer_name
        )
      if transpose_for_cpu:
        # The transposed tensor is quantized along axis 1 instead of 0.
        tensor_value = np.transpose(tensor_value)
        quantize_axis = [1]

    return [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=self.update_target_name(layer_name),
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]

  def update_target_name(self, target_name: str) -> str:
    """Rewrites a checkpoint tensor name into the target naming convention.

    The substitutions are applied strictly in order, because later rules
    match text produced by earlier ones.

    Args:
      target_name: The checkpoint tensor name.

    Returns:
      The renamed tensor name.
    """
    # The post-attention norm is named differently on the CPU backend.
    if self._backend == "cpu":
      post_attention_norm_rules = [
          ("post_attention_layernorm", "ff_layer.pre_layer_norm"),
          ("ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"),
      ]
    else:
      post_attention_norm_rules = [
          ("post_attention_layernorm", "post_layer_norm"),
          ("post_layer_norm.weight", "post_layer_norm.scale"),
      ]

    rules = [
        ("model.layers.", "params.lm.transformer.x_layers_"),
        ("mlp.up_proj", "ff_layer.ffn_layer1"),
        ("mlp.down_proj", "ff_layer.ffn_layer2"),
        ("mlp.gate_proj", "ff_layer.ffn_layer1_gate"),
        ("input_layernorm", "pre_layer_norm"),
        ("pre_layer_norm.weight", "pre_layer_norm.scale"),
        *post_attention_norm_rules,
        ("self_attn.q_proj", "self_attention.q"),
        ("self_attn.k_proj", "self_attention.k"),
        ("self_attn.v_proj", "self_attention.v"),
        ("self_attn.o_proj", "self_attention.post"),
        ("model.embed_tokens", "params.lm.token_embedding"),
        ("model.norm", "params.lm.final_ln"),
        ("final_ln.weight", "final_ln.scale"),
        ("lm_head", "params.lm.softmax.logits_ffn"),
        (".weight", ".w"),
    ]
    renamed = target_name
    for old, new in rules:
      renamed = renamed.replace(old, new)
    return renamed
248
+
249
+
250
class PhiMapper(converter_base.LayerActionMapperBase):
  """Maps Phi checkpoint tensors to quantization actions."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    """Stores the quantization settings and the reader used to load tensors.

    Args:
      is_symmetric: Forwarded to the base mapper.
      attention_quant_bits: Quantization bits used for attention weights.
      feedforward_quant_bits: Quantization bits used for feedforward weights.
      embedding_quant_bits: Quantization bits used for embedding weights.
      backend: Target backend name; "cpu" enables the CPU-specific handling.
      reader: Source of the tensor values.
    """
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Builds the single quantization action for `layer_name`.

    Args:
      layer_name: Name of the tensor in the checkpoint.

    Returns:
      A one-element list holding the action for this tensor.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    layer_type = LayerType.get_layer_type(layer_name)
    quantize_axis = None
    quantize_bits = None

    # Only ".weight" tensors outside layer norms get a quantization axis.
    if layer_name.endswith(".weight") and layer_type != LayerType.LAYER_NORM:
      quantize_axis = [0]
      transpose_for_cpu = False
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        transpose_for_cpu = (
            self._backend == "cpu" and ".dense." in layer_name
        )
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
        transpose_for_cpu = (
            self._backend == "cpu" and ".embed_tokens." in layer_name
        )
      if transpose_for_cpu:
        # The transposed tensor is quantized along axis 1 instead of 0.
        tensor_value = np.transpose(tensor_value)
        quantize_axis = [1]

    return [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=self.update_target_name(layer_name),
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]

  def update_target_name(self, target_name: str) -> str:
    """Rewrites a checkpoint tensor name into the target naming convention.

    After the common "model.layers." prefix is rewritten, the layer type of
    the name selects one ordered list of substitutions; a final ".weight" to
    ".w" substitution is then applied regardless of type.

    Args:
      target_name: The checkpoint tensor name.

    Returns:
      The renamed tensor name.
    """
    renamed = target_name.replace(
        "model.layers.", "params.lm.transformer.x_layers_"
    )

    # Ordered substitutions per layer type; order within a list matters.
    rules_by_type = {
        LayerType.FEEDFORWARD: (
            (".weight", ".linear.w"),
            (".bias", ".bias.b"),
            ("mlp.fc1", "ff_layer.ffn_layer1"),
            ("mlp.fc2", "ff_layer.ffn_layer2"),
        ),
        LayerType.ATTENTION: (
            (".weight", ".linear.w"),
            (".bias", ".bias.b"),
            ("self_attn.q_proj", "self_attention.q"),
            ("self_attn.k_proj", "self_attention.k"),
            ("self_attn.v_proj", "self_attention.v"),
            ("self_attn.dense", "self_attention.post"),
        ),
        LayerType.EMBEDDING: (
            ("model.embed_tokens", "params.lm.token_embedding"),
            ("lm_head", "params.lm.softmax.logits_ffn"),
            ("logits_ffn.weight", "logits_ffn.linear.w"),
            ("logits_ffn.bias", "logits_ffn.bias.b"),
        ),
        LayerType.LAYER_NORM: (
            ("input_layernorm", "pre_layer_norm"),
            ("pre_layer_norm.weight", "pre_layer_norm.scale"),
            ("model.final_layernorm", "params.lm.final_ln"),
            ("final_ln.weight", "final_ln.scale"),
        ),
    }
    for old, new in rules_by_type.get(LayerType.get_layer_type(renamed), ()):
      renamed = renamed.replace(old, new)
    return renamed.replace(".weight", ".w")
352
+
353
+
354
class GemmaMapper(converter_base.LayerActionMapperBase):
  """LayerActionMapper for handling the Gemma model."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    """Stores the quantization settings and the reader used to load tensors.

    Args:
      is_symmetric: Forwarded to the base mapper.
      attention_quant_bits: Quantization bits used for attention weights.
      feedforward_quant_bits: Quantization bits used for feedforward weights.
      embedding_quant_bits: Quantization bits used for embedding weights.
      backend: Target backend name, forwarded to the base mapper.
      reader: Source of the tensor values.
    """
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Map the given layer name to actions.

    Args:
      layer_name: Name of the tensor in the checkpoint.

    Returns:
      A one-element list holding the action for this tensor.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    quantize_axis = None
    quantize_bits = None
    layer_type = LayerType.get_layer_type(layer_name)

    # Only ".weight" tensors outside layer norms get a quantization axis.
    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
      quantize_axis = [0]
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        # Unlike the other mappers, this transpose is not backend-dependent.
        if "o_proj" in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
    target_name = self.update_target_name(layer_name)

    actions = [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=target_name,
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]
    return actions

  def update_target_name(self, target_name: str) -> str:
    """Updates the target name to match the tensor name convention.

    The substitutions are applied strictly in order, because later rules
    match text produced by earlier ones.

    Args:
      target_name: The checkpoint tensor name.

    Returns:
      The renamed tensor name.
    """
    rules = (
        ("model.layers.", "params.lm.transformer.x_layers_"),
        ("mlp.up_proj", "ff_layer.ffn_layer1"),
        ("mlp.down_proj", "ff_layer.ffn_layer2"),
        ("mlp.gate_proj", "ff_layer.ffn_layer1_gate"),
        ("input_layernorm", "pre_layer_norm"),
        ("pre_layer_norm.weight", "pre_layer_norm.scale"),
        ("post_attention_layernorm", "ff_layer.pre_layer_norm"),
        ("ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"),
        ("self_attn.q_proj", "self_attention.q"),
        ("self_attn.k_proj", "self_attention.k"),
        ("self_attn.v_proj", "self_attention.v"),
        ("self_attn.o_proj", "self_attention.post"),
        # The embedding table is written under the logits layer's name.
        ("model.embed_tokens", "params.lm.softmax.logits_ffn"),
        ("model.norm", "params.lm.final_ln"),
        ("final_ln.weight", "final_ln.scale"),
        (".weight", ".w"),
    )
    for old, new in rules:
      target_name = target_name.replace(old, new)
    return target_name
441
+
442
+
443
class SafetensorsCkptLoader(converter_base.CkptLoaderBase):
  """Checkpoint loader that turns safetensors files into quantization actions."""

  def __init__(
      self,
      ckpt_path: str,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      special_model: str,
      backend: str,
  ):
    """Initializes the loader.

    Args:
      ckpt_path: The filepath to the safetensors file.
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the feedforward layers in each Transformer
        blocks.
      embedding_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the embedding (and the final projection) layers.
      special_model: A string that indicates which input model is and whether
        any special treatment is needed.
      backend: A string indicating the backend used when converting this model.
        Valid options are "cpu" and "gpu".

    Raises:
      ValueError: If `special_model` is not one of the supported model names.
    """
    super().__init__(
        ckpt_path,
        is_symmetric,
        attention_quant_bits,
        feedforward_quant_bits,
        embedding_quant_bits,
    )

    self._special_model = special_model
    # The reader is created before the model name is validated.
    self._reader = _SafetensorsReader(ckpt_path)

    if special_model == "STABLELM_4E1T_3B":
      mapper_cls = StablelmMapper
    elif special_model == "PHI_2":
      mapper_cls = PhiMapper
    elif special_model == "GEMMA_2B":
      mapper_cls = GemmaMapper
    else:
      raise ValueError(f"Unknown special model: {special_model}")

    # All mappers take the same constructor arguments.
    self.mapper = mapper_cls(
        is_symmetric,
        attention_quant_bits,
        feedforward_quant_bits,
        embedding_quant_bits,
        backend,
        self._reader,
    )

  def load_to_actions(self) -> List[converter_base.QuantizationAction]:
    """Maps every tensor in the checkpoint to its quantization actions.

    Returns:
      The concatenated actions of all tensors, in reader order; tensors for
      which the mapper returns None contribute nothing.
    """
    actions = []
    for tensor_name in self._reader.get_tensor_names():
      mapped = self.mapper.map_to_actions(tensor_name)
      if mapped is not None:
        actions.extend(mapped)
    return actions
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 The MediaPipe Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Unit tests for safetensors_converter."""
16
+
17
+ import os
18
+
19
+ from absl.testing import absltest
20
+ from absl.testing import parameterized
21
+
22
+ from mediapipe.tasks.python.genai.converter import safetensors_converter
23
+ from mediapipe.tasks.python.test import test_utils
24
+
25
# Location of the text-model test fixtures, relative to the test data root.
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Resolved path of the StableLM safetensors fixture used by every test below.
_SAFETENSORS_FILE = test_utils.get_test_data_path(
    os.path.join(_TEST_DATA_DIR, 'stablelm_3b_4e1t_test_weight.safetensors')
)


class SafetensorsConverterTest(parameterized.TestCase):
  """Tests SafetensorsCkptLoader against the StableLM fixture."""

  VARIABLE_NAMES = [
      'model.embed_tokens.weight',
      'model.layers.0.input_layernorm.bias',
      'model.layers.0.input_layernorm.weight',
      'model.layers.0.mlp.down_proj.weight',
      'model.layers.0.mlp.gate_proj.weight',
      'model.layers.0.mlp.up_proj.weight',
      'model.layers.0.post_attention_layernorm.bias',
      'model.layers.0.post_attention_layernorm.weight',
      'model.layers.0.self_attn.k_proj.weight',
      'model.layers.0.self_attn.o_proj.weight',
      'model.layers.0.self_attn.q_proj.weight',
      'model.layers.0.self_attn.v_proj.weight',
      'model.norm.bias',
      'model.norm.weight',
      'lm_head.weight',
  ]

  def _make_loader(self, feedforward_quant_bits=8):
    """Builds a symmetric GPU StableLM loader over the fixture checkpoint."""
    return safetensors_converter.SafetensorsCkptLoader(
        ckpt_path=_SAFETENSORS_FILE,
        is_symmetric=True,
        attention_quant_bits=8,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=8,
        special_model='STABLELM_4E1T_3B',
        backend='gpu',
    )

  def test_init(self):
    loader = self._make_loader()
    self.assertEqual(loader._ckpt_path, _SAFETENSORS_FILE)
    self.assertEqual(loader._is_symmetric, True)
    self.assertEqual(loader._attention_quant_bits, 8)
    self.assertEqual(loader._feedforward_quant_bits, 8)

  @parameterized.product(
      quant_bits=(4, 8),
  )
  def test_load_to_actions(self, quant_bits):
    loader = self._make_loader(feedforward_quant_bits=quant_bits)
    # One action is expected per tensor in the fixture.
    self.assertLen(loader.load_to_actions(), 15)


if __name__ == '__main__':
  absltest.main()