modelinfo-cli 1.0.0__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {modelinfo_cli-1.0.0/src/modelinfo_cli.egg-info → modelinfo_cli-1.1.0}/PKG-INFO +3 -3
  2. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/README.md +2 -2
  3. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/pyproject.toml +1 -1
  4. modelinfo_cli-1.1.0/src/modelinfo/architecture.py +109 -0
  5. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/calculator.py +4 -3
  6. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/cli.py +24 -2
  7. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/ui.py +11 -4
  8. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0/src/modelinfo_cli.egg-info}/PKG-INFO +3 -3
  9. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/tests/test_calculator.py +36 -0
  10. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/tests/test_parsers.py +15 -0
  11. modelinfo_cli-1.0.0/src/modelinfo/architecture.py +0 -45
  12. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/LICENSE +0 -0
  13. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/setup.cfg +0 -0
  14. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/__init__.py +0 -0
  15. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/__main__.py +0 -0
  16. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/parsers/__init__.py +0 -0
  17. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/parsers/base.py +0 -0
  18. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/parsers/gguf.py +0 -0
  19. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/parsers/pytorch.py +0 -0
  20. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo/parsers/safetensors.py +0 -0
  21. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo_cli.egg-info/SOURCES.txt +0 -0
  22. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo_cli.egg-info/dependency_links.txt +0 -0
  23. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo_cli.egg-info/entry_points.txt +0 -0
  24. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo_cli.egg-info/requires.txt +0 -0
  25. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/src/modelinfo_cli.egg-info/top_level.txt +0 -0
  26. {modelinfo_cli-1.0.0 → modelinfo_cli-1.1.0}/tests/test_constraints.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: modelinfo-cli
3
- Version: 1.0.0
3
+ Version: 1.1.0
4
4
  Summary: A sub-100ms, zero-dependency CLI tool to inspect ML model checkpoints and dynamically calculate VRAM requirements.
5
5
  Author: ModelInfo Contributors
6
6
  License: MIT
@@ -28,9 +28,9 @@ It reads binary headers directly using the Python standard library. By bypassing
28
28
 
29
29
  ## Features
30
30
 
31
- - **Zero-Dependency Parsing**: Reads the 8-byte JSON prefix of `.safetensors` files and the binary key-value metadata of `.gguf` directly via `struct` and `json`.
31
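+ - **Zero-Dependency Parsing**: Reads the JSON header of `.safetensors` files (located via its 8-byte length prefix) and the binary key-value metadata of `.gguf` directly via `struct` and `json`. For `.safetensors` checkpoints, an adjacent `config.json` is also read when present to obtain layer and attention-head counts; without it, the tool falls back to tensor-shape heuristics.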
32
32
  - **Sharded Model Support**: Transparently parses `model.safetensors.index.json` to detect multi-file checkpoint distributions, gracefully guarding against partial downloads without crashing.
33
- - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths.
33
+ - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths. Actively warns users if requested context exceeds the model's native limit.
34
34
  - **Precise Block Quantization**: Factors in exact byte-scaling coefficients for GGUF formats (e.g., Q8, Q6, Q4) rather than naive averages, eliminating VRAM under-reporting.
35
35
  - **Secure Pickling**: Inspects legacy `.pt` files without executing arbitrary code by using a highly restricted `pickle.Unpickler`.
36
36
  - **Terminal UI**: Groups repetitive structural layers and color-codes VRAM heatmaps using `rich`.
@@ -10,9 +10,9 @@ It reads binary headers directly using the Python standard library. By bypassing
10
10
 
11
11
  ## Features
12
12
 
13
- - **Zero-Dependency Parsing**: Reads the 8-byte JSON prefix of `.safetensors` files and the binary key-value metadata of `.gguf` directly via `struct` and `json`.
13
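+ - **Zero-Dependency Parsing**: Reads the JSON header of `.safetensors` files (located via its 8-byte length prefix) and the binary key-value metadata of `.gguf` directly via `struct` and `json`. For `.safetensors` checkpoints, an adjacent `config.json` is also read when present to obtain layer and attention-head counts; without it, the tool falls back to tensor-shape heuristics.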
14
14
  - **Sharded Model Support**: Transparently parses `model.safetensors.index.json` to detect multi-file checkpoint distributions, gracefully guarding against partial downloads without crashing.
15
- - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths.
15
+ - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths. Actively warns users if requested context exceeds the model's native limit.
16
16
  - **Precise Block Quantization**: Factors in exact byte-scaling coefficients for GGUF formats (e.g., Q8, Q6, Q4) rather than naive averages, eliminating VRAM under-reporting.
17
17
  - **Secure Pickling**: Inspects legacy `.pt` files without executing arbitrary code by using a highly restricted `pickle.Unpickler`.
18
18
  - **Terminal UI**: Groups repetitive structural layers and color-codes VRAM heatmaps using `rich`.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "modelinfo-cli"
7
- version = "1.0.0"
7
+ version = "1.1.0"
8
8
  description = "A sub-100ms, zero-dependency CLI tool to inspect ML model checkpoints and dynamically calculate VRAM requirements."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -0,0 +1,109 @@
1
+ import os
2
+ import json
3
+ from typing import Any, Dict, Tuple
4
+
5
+ def extract_architecture(tensors: Dict[str, Any], config: Dict[str, Any] = None) -> Tuple[int, int, bool]:
6
+ """
7
+ Extracts the number of layers and KV cache dimension (kv_heads * head_dim).
8
+ Returns (num_layers, kv_dim, is_estimate).
9
+ """
10
+ num_layers = 0
11
+ kv_dim = 0
12
+ is_estimate = False
13
+
14
+ metadata = tensors.get("__metadata__", {})
15
+ gen_arch = metadata.get("general.architecture")
16
+
17
+ # 1. Attempt explicit GGUF metadata
18
+ if gen_arch:
19
+ arch_str = str(gen_arch)
20
+ num_layers = metadata.get(f"{arch_str}.block_count", 0)
21
+ kv_heads = metadata.get(f"{arch_str}.attention.head_count_kv", 0)
22
+
23
+ key_length = metadata.get(f"{arch_str}.attention.key_length")
24
+ if not key_length:
25
+ embed_len = metadata.get(f"{arch_str}.embedding_length", 0)
26
+ q_heads = metadata.get(f"{arch_str}.attention.head_count", 1)
27
+ if q_heads > 0:
28
+ key_length = embed_len // q_heads
29
+ else:
30
+ key_length = 0
31
+
32
+ if kv_heads > 0 and key_length > 0:
33
+ kv_dim = kv_heads * key_length
34
+ if num_layers > 0:
35
+ return num_layers, kv_dim, False
36
+
37
+ # 2. Attempt explicit SafeTensors config.json
38
+ if config:
39
+ num_layers = config.get("num_hidden_layers", 0)
40
+ num_attention_heads = config.get("num_attention_heads", 1)
41
+ num_key_value_heads = config.get("num_key_value_heads", num_attention_heads)
42
+ hidden_size = config.get("hidden_size", 0)
43
+
44
+ if num_attention_heads > 0:
45
+ head_dim = hidden_size // num_attention_heads
46
+ kv_dim = num_key_value_heads * head_dim
47
+ if num_layers > 0 and kv_dim > 0:
48
+ return num_layers, kv_dim, False
49
+
50
+ # 3. Fallback to shape guessing
51
+ layers_set = set()
52
+ found_fused = False
53
+ found_k_proj = False
54
+
55
+ for name, meta in tensors.items():
56
+ if name == "__metadata__":
57
+ continue
58
+
59
+ parts = name.split(".")
60
+ if "layers" in parts:
61
+ idx = parts.index("layers")
62
+ if len(parts) > idx + 1 and parts[idx+1].isdigit():
63
+ layers_set.add(int(parts[idx+1]))
64
+ elif "h" in parts:
65
+ idx = parts.index("h")
66
+ if len(parts) > idx + 1 and parts[idx+1].isdigit():
67
+ layers_set.add(int(parts[idx+1]))
68
+
69
+ if name.endswith("k_proj.weight") or name.endswith("attn.k.weight") or name.endswith("k_proj.w"):
70
+ found_k_proj = True
71
+ shape = meta.get("shape", [])
72
+ if len(shape) >= 2:
73
+ kv_dim = shape[0]
74
+
75
+ if "qkv_proj.weight" in name or "c_attn.weight" in name:
76
+ found_fused = True
77
+ if not found_k_proj:
78
+ shape = meta.get("shape", [])
79
+ if len(shape) >= 2:
80
+ kv_dim = shape[0] // 3
81
+
82
+ num_layers = len(layers_set)
83
+ if found_fused and not found_k_proj and kv_dim > 0:
84
+ is_estimate = True
85
+
86
+ return num_layers, kv_dim, is_estimate
87
+
88
def identify_architecture_name(tensors: Dict[str, Any], num_layers: int) -> str:
    """Return a human-readable architecture label for a parsed checkpoint.

    Resolution order:
      1. GGUF ``general.architecture`` metadata, title-cased.
      2. The first tensor name containing "llama", "mistral" or "qwen"
         (case-insensitive; for each name the families are tested in that order).
      3. A generic label based on ``num_layers``.

    For the first two cases, " (N transformer layers)" is appended whenever
    ``num_layers`` is truthy.
    """
    def _label(family: str) -> str:
        if not num_layers:
            return family
        return f"{family} ({num_layers} transformer layers)"

    declared = tensors.get("__metadata__", {}).get("general.architecture")
    if declared:
        return _label(str(declared).title())

    families = (("llama", "Llama"), ("mistral", "Mistral"), ("qwen", "Qwen"))
    for tensor_name in tensors:
        if tensor_name == "__metadata__":
            continue
        lowered = tensor_name.lower()
        for needle, family in families:
            if needle in lowered:
                return _label(family)

    if num_layers > 0:
        return f"Generic Transformer ({num_layers} layers)"
    return "Unknown Architecture"
@@ -29,7 +29,7 @@ def _get_bytes_per_param(dtype: str) -> float:
29
29
  """Return the size in bytes for a given data type."""
30
30
  return DTYPE_BYTES.get(dtype.upper(), 2.0)
31
31
 
32
- def calculate_footprint(tensors: Dict[str, Any], context_length: int = 0, batch_size: int = 1) -> Dict[str, Any]:
32
+ def calculate_footprint(tensors: Dict[str, Any], context_length: int = 0, batch_size: int = 1, config: Dict[str, Any] = None) -> Dict[str, Any]:
33
33
  """
34
34
  Calculate the memory footprint of a model based on its tensors and context length.
35
35
  """
@@ -54,7 +54,7 @@ def calculate_footprint(tensors: Dict[str, Any], context_length: int = 0, batch_
54
54
  bytes_per_param = _get_bytes_per_param(dtype)
55
55
  base_memory_bytes += param_count * bytes_per_param
56
56
 
57
- num_layers, kv_dim = extract_architecture(tensors)
57
+ num_layers, kv_dim, is_estimate = extract_architecture(tensors, config)
58
58
 
59
59
  # Formula: 2 * Layers * (KV_Heads * Head_Dim) * Context_Length * Batch_Size * Bytes_per_param
60
60
  # Assume FP16 (2 bytes) for KV cache
@@ -69,7 +69,8 @@ def calculate_footprint(tensors: Dict[str, Any], context_length: int = 0, batch_
69
69
  "total_memory_bytes": base_memory_bytes + kv_cache_bytes,
70
70
  "num_layers": num_layers,
71
71
  "kv_dim": kv_dim,
72
- "primary_dtype": primary_dtype
72
+ "primary_dtype": primary_dtype,
73
+ "kv_is_estimate": is_estimate
73
74
  }
74
75
 
75
76
  def format_bytes(size_bytes: float) -> str:
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import json
2
3
  import os
3
4
  import sys
4
5
  from typing import Sequence
@@ -37,10 +38,21 @@ def main(argv: Sequence[str] | None = None) -> int:
37
38
 
38
39
  file_path = args.file.lower()
39
40
  tensors = {}
41
+ config = None
40
42
 
41
43
  if file_path.endswith(".safetensors") or file_path.endswith(".index.json"):
42
44
  tensors = parse_safetensors_header(args.file)
43
45
  format_name = "SafeTensors"
46
+
47
+ # Read config.json to maintain pure math engines
48
+ config_path = os.path.join(os.path.dirname(args.file), "config.json")
49
+ if os.path.exists(config_path):
50
+ try:
51
+ with open(config_path, "r", encoding="utf-8") as f:
52
+ config = json.load(f)
53
+ except (json.JSONDecodeError, OSError):
54
+ pass
55
+
44
56
  elif file_path.endswith(".gguf"):
45
57
  tensors = parse_gguf_header(args.file)
46
58
  format_name = "GGUF"
@@ -53,10 +65,19 @@ def main(argv: Sequence[str] | None = None) -> int:
53
65
  )
54
66
  return 1
55
67
 
56
- footprint = calculate_footprint(tensors, context_length=args.context)
68
+ footprint = calculate_footprint(tensors, context_length=args.context, config=config)
57
69
  num_layers = footprint["num_layers"]
58
70
  arch_name = identify_architecture_name(tensors, num_layers)
59
71
 
72
+ max_context = None
73
+ if config:
74
+ max_context = config.get("max_position_embeddings")
75
+ else:
76
+ metadata = tensors.get("__metadata__", {})
77
+ gen_arch = metadata.get("general.architecture")
78
+ if gen_arch:
79
+ max_context = metadata.get(f"{gen_arch}.context_length")
80
+
60
81
  disk_size = os.path.getsize(args.file) if os.path.exists(args.file) else 0.0
61
82
  tensor_count = len([k for k in tensors.keys() if k != "__metadata__"])
62
83
 
@@ -67,7 +88,8 @@ def main(argv: Sequence[str] | None = None) -> int:
67
88
  footprint=footprint,
68
89
  disk_size=disk_size,
69
90
  context_length=args.context,
70
- tensors=tensors
91
+ tensors=tensors,
92
+ max_context=max_context
71
93
  )
72
94
 
73
95
  return 0
@@ -43,7 +43,8 @@ def print_model_info(
43
43
  footprint: Dict[str, Any],
44
44
  disk_size: float,
45
45
  context_length: int,
46
- tensors: Dict[str, Any]
46
+ tensors: Dict[str, Any],
47
+ max_context: int | None = None
47
48
  ) -> None:
48
49
  summary = Table(box=None, show_header=False, pad_edge=False, padding=(0, 2))
49
50
  summary.add_column("Property", style="bold")
@@ -65,7 +66,10 @@ def print_model_info(
65
66
 
66
67
  vram_text = f"~{format_bytes(vram_bytes)}"
67
68
  if context_length > 0:
68
- vram_text += f" ({footprint['primary_dtype']}, KV cache for {context_length} tokens)"
69
+ if footprint.get("kv_is_estimate"):
70
+ vram_text += f" ({footprint['primary_dtype']}, Estimated KV Cache - Missing Config)"
71
+ else:
72
+ vram_text += f" ({footprint['primary_dtype']}, KV cache for {context_length} tokens)"
69
73
  else:
70
74
  vram_text += f" ({footprint['primary_dtype']}, no KV cache)"
71
75
  vram_display = f"[{vram_color}]{vram_text}[/{vram_color}]"
@@ -81,8 +85,11 @@ def print_model_info(
81
85
  console.print(summary)
82
86
 
83
87
  if missing_shards > 0:
84
- console.print(f"[bold yellow]⚠️ Partial Model: Missing {missing_shards} of {total_shards} shards on disk. Totals are incomplete.[/bold yellow]")
88
+ console.print(f"[bold yellow]WARNING: Partial Model. Missing {missing_shards} of {total_shards} shards on disk. Totals are incomplete.[/bold yellow]")
85
89
 
90
+ if context_length > 0 and max_context is not None and context_length > max_context:
91
+ console.print(f"[bold yellow]WARNING: Requested context ({context_length:,}) exceeds model's native limit ({max_context:,}).[/bold yellow]")
92
+
86
93
  console.print()
87
94
 
88
95
  console.print("Top Tensors by Size:", style="bold")
@@ -90,7 +97,7 @@ def print_model_info(
90
97
  grouped_tensors = group_tensors_by_size(tensors)
91
98
 
92
99
  tensor_table = Table(box=None, show_header=False, pad_edge=False, padding=(0, 2))
93
- tensor_table.add_column("Name")
100
+ tensor_table.add_column("Name", no_wrap=True, overflow="fold")
94
101
  tensor_table.add_column("Shape", justify="right")
95
102
  tensor_table.add_column("Dtype", justify="left")
96
103
  tensor_table.add_column("Params", justify="right")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: modelinfo-cli
3
- Version: 1.0.0
3
+ Version: 1.1.0
4
4
  Summary: A sub-100ms, zero-dependency CLI tool to inspect ML model checkpoints and dynamically calculate VRAM requirements.
5
5
  Author: ModelInfo Contributors
6
6
  License: MIT
@@ -28,9 +28,9 @@ It reads binary headers directly using the Python standard library. By bypassing
28
28
 
29
29
  ## Features
30
30
 
31
- - **Zero-Dependency Parsing**: Reads the 8-byte JSON prefix of `.safetensors` files and the binary key-value metadata of `.gguf` directly via `struct` and `json`.
31
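+ - **Zero-Dependency Parsing**: Reads the JSON header of `.safetensors` files (located via its 8-byte length prefix) and the binary key-value metadata of `.gguf` directly via `struct` and `json`. For `.safetensors` checkpoints, an adjacent `config.json` is also read when present to obtain layer and attention-head counts; without it, the tool falls back to tensor-shape heuristics.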
32
32
  - **Sharded Model Support**: Transparently parses `model.safetensors.index.json` to detect multi-file checkpoint distributions, gracefully guarding against partial downloads without crashing.
33
- - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths.
33
+ - **Dynamic VRAM Estimation**: Extracts underlying model architecture (layers, heads, dimensions) to calculate exact VRAM limits, including dynamic KV cache footprints based on user-specified context lengths. Actively warns users if requested context exceeds the model's native limit.
34
34
  - **Precise Block Quantization**: Factors in exact byte-scaling coefficients for GGUF formats (e.g., Q8, Q6, Q4) rather than naive averages, eliminating VRAM under-reporting.
35
35
  - **Secure Pickling**: Inspects legacy `.pt` files without executing arbitrary code by using a highly restricted `pickle.Unpickler`.
36
36
  - **Terminal UI**: Groups repetitive structural layers and color-codes VRAM heatmaps using `rich`.
@@ -63,3 +63,39 @@ def test_dynamic_kv_cache():
63
63
  assert footprint["num_layers"] == 2
64
64
  assert footprint["kv_dim"] == 1024
65
65
  assert footprint["kv_cache_bytes"] == 8192000
66
+
67
def test_safetensors_config_fallback():
    """A config.json dict must drive layer/KV extraction for a SafeTensors checkpoint."""
    from modelinfo.architecture import extract_architecture

    # Only a fused QKV tensor is present, so shapes alone would give a rough guess.
    fused_only = {
        "model.layers.0.qkv_proj.weight": {"shape": [6144, 4096], "dtype": "F16"},
    }
    hf_config = {
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "hidden_size": 4096,
    }

    layers, kv_width, estimated = extract_architecture(fused_only, config=hf_config)

    assert layers == 32
    assert kv_width == 1024  # 8 KV heads * (4096 // 32) head_dim
    assert estimated is False
90
+
91
def test_kv_cache_is_fp16():
    """The KV cache must be sized at 2 bytes per element (FP16) even for Q4 weights."""
    q4_tensors = {
        "model.layers.0.attn.weight": {"shape": [4096, 4096], "dtype": "Q4"},
        "model.layers.0.attn.k.weight": {"shape": [1024, 4096], "dtype": "Q4"},
    }

    result = calculate_footprint(q4_tensors, context_length=8192)

    # 2 (K and V) * 1 layer * 1024 kv_dim * 8192 tokens * 2 bytes (FP16)
    expected_kv_bytes = 2 * 1 * 1024 * 8192 * 2
    assert result["kv_cache_bytes"] == expected_kv_bytes
    assert result["primary_dtype"] == "Q4"
@@ -30,3 +30,18 @@ def test_missing_shard_handling():
30
30
  # it fails safely when a file truly doesn't exist.
31
31
  with pytest.raises(FileNotFoundError):
32
32
  parse_safetensors_header(os.path.join(FIXTURES_DIR, "does_not_exist.safetensors"))
33
+
34
def test_gguf_parser_metadata():
    """The GGUF parser must expose file-level metadata under the "__metadata__" key."""
    from modelinfo.parsers.gguf import parse_gguf_header
    from modelinfo.architecture import identify_architecture_name

    parsed = parse_gguf_header(os.path.join(FIXTURES_DIR, "mock_model.gguf"))

    assert "__metadata__" in parsed
    assert parsed["__metadata__"]["general.architecture"] == "qwen2"

    # general.architecture should be title-cased rather than yielding "Unknown Architecture".
    label = identify_architecture_name(parsed, num_layers=1)
    assert label == "Qwen2 (1 transformer layers)"
@@ -1,45 +0,0 @@
1
- from typing import Any, Dict, Tuple
2
-
3
def extract_architecture(tensors: Dict[str, Any]) -> Tuple[int, int]:
    """
    Extracts the number of layers and KV cache dimension (kv_heads * head_dim)
    from tensor metadata.

    Returns (num_layers, kv_dim):
      - num_layers: count of distinct integer indices found right after a
        "layers" name component or, failing that, an "h" component.
      - kv_dim: shape[0] of the last tensor whose name ends in one of the
        K-projection suffixes checked below, or 0 if none matched.
    """
    layers = set()
    kv_dim = 0

    for name, metadata in tensors.items():
        # The reserved "__metadata__" entry is file-level metadata, not a tensor.
        if name == "__metadata__":
            continue

        parts = name.split(".")

        # Layer index follows "layers" (e.g. model.layers.N....); the "h" form
        # (e.g. h.N....) is only consulted when "layers" is absent.
        if "layers" in parts:
            idx = parts.index("layers")
            if len(parts) > idx + 1 and parts[idx+1].isdigit():
                layers.add(int(parts[idx+1]))
        elif "h" in parts:
            idx = parts.index("h")
            if len(parts) > idx + 1 and parts[idx+1].isdigit():
                layers.add(int(parts[idx+1]))

        # NOTE(review): a name ending in "qkv_proj.weight" also satisfies
        # endswith("k_proj.weight"), so a fused QKV tensor is treated as a K
        # projection here and kv_dim becomes the full fused width.
        if name.endswith("k_proj.weight") or name.endswith("attn.k.weight") or name.endswith("k_proj.w"):
            shape = metadata.get("shape", [])
            if len(shape) >= 2:
                # Typically [out_features, in_features], so out_features is shape[0]
                kv_dim = shape[0]

    return len(layers), kv_dim
33
-
34
def identify_architecture_name(tensors: Dict[str, Any], num_layers: int) -> str:
    """Attempt to identify the architecture family based on tensor names.

    Returns the first family ("Llama", "Mistral", "Qwen") whose lowercase name
    occurs in a tensor name (families tested in that order for each name), with
    " (N transformer layers)" appended when num_layers is truthy. Otherwise
    returns "Generic Transformer (N layers)" when num_layers > 0, else
    "Unknown Architecture".
    """
    for name in tensors.keys():
        # Case-insensitive substring match. The "__metadata__" key is not skipped,
        # but it contains none of the family names, so it never matches.
        name_lower = name.lower()
        if "llama" in name_lower:
            return f"Llama ({num_layers} transformer layers)" if num_layers else "Llama"
        if "mistral" in name_lower:
            return f"Mistral ({num_layers} transformer layers)" if num_layers else "Mistral"
        if "qwen" in name_lower:
            return f"Qwen ({num_layers} transformer layers)" if num_layers else "Qwen"

    return f"Generic Transformer ({num_layers} layers)" if num_layers > 0 else "Unknown Architecture"
File without changes
File without changes