leap-model-parser 0.1.250.dev8.tar.gz → 0.1.250.dev10.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/PKG-INFO +2 -3
  2. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/contract/graph.py +1 -2
  3. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/contract/importmodelresponse.py +0 -1
  4. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/leap_graph_editor.py +4 -11
  5. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/pyproject.toml +2 -3
  6. leap_model_parser-0.1.250.dev8/leap_model_parser/torch_graph_editor.py +0 -93
  7. leap_model_parser-0.1.250.dev8/leap_model_parser/torch_model_parser.py +0 -57
  8. leap_model_parser-0.1.250.dev8/leap_model_parser/torch_utils.py +0 -149
  9. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/LICENSE +0 -0
  10. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/README.md +0 -0
  11. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/__init__.py +0 -0
  12. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/contract/__init__.py +0 -0
  13. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/contract/nodedata.py +0 -0
  14. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/contract/ui_components.json +0 -0
  15. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/keras_json_model_import.py +0 -0
  16. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/model_parser.py +0 -0
  17. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/__init__.py +0 -0
  18. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/layerpedia/__init__.py +0 -0
  19. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/layerpedia/layerpedia.py +0 -0
  20. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/tlinspection/__init__.py +0 -0
  21. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/tlinspection/leapinspection.py +0 -0
  22. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/uicomponents/__init__.py +0 -0
  23. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/uicomponents/generatenodedata.py +0 -0
  24. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/uicomponents/tensorflowinscpection.py +0 -0
  25. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/uicomponents/ui_components.json +0 -0
  26. {leap_model_parser-0.1.250.dev8 → leap_model_parser-0.1.250.dev10}/leap_model_parser/utils/uicomponents/ui_components_config.yaml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: leap-model-parser
3
- Version: 0.1.250.dev8
3
+ Version: 0.1.250.dev10
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/leap-model-parser
6
6
  License: MIT
@@ -17,11 +17,10 @@ Requires-Dist: keras-data-format-converter (==0.1.24)
17
17
  Requires-Dist: leap-model-rebuilder (==0.1.7)
18
18
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
19
19
  Requires-Dist: onnx (==1.13.0)
20
- Requires-Dist: onnx2kerastl (==0.0.191)
20
+ Requires-Dist: onnx2kerastl (==0.0.192.dev1)
21
21
  Requires-Dist: tensorflow (==2.12.0) ; platform_machine == "x86_64"
22
22
  Requires-Dist: tensorflow-io-gcs-filesystem (==0.34.0)
23
23
  Requires-Dist: tensorflow-macos (==2.12.0) ; platform_machine == "arm64"
24
- Requires-Dist: torch (>=2.3.0,<2.4.0)
25
24
  Project-URL: Repository, https://github.com/tensorleap/leap-model-parser
26
25
  Description-Content-Type: text/markdown
27
26
 
@@ -63,5 +63,4 @@ class Node:
63
63
  @dataclass
64
64
  class ModelGraph:
65
65
  id: str
66
- nodes: Dict[str, Node] = field(default_factory=dict)
67
- framework: str = "keras"
66
+ nodes: Dict[str, Node] = field(default_factory=dict)
@@ -6,4 +6,3 @@ class ImportModelTypeEnum(Enum):
6
6
  ONNX = "ONNX"
7
7
  PB_TF2 = "PB_TF2"
8
8
  H5_TF2 = "H5_TF2"
9
- PYTORCH_PT2 = "PYTORCH_PT2"
@@ -2,13 +2,14 @@ from enum import Enum
2
2
  from typing import Optional, Dict, Any, List
3
3
 
4
4
  from code_loader.contract.mapping import NodeConnection, NodeMappingType, NodeMapping # type: ignore
5
+ from keras import Model # type: ignore
5
6
 
6
7
  from leap_model_parser.contract.graph import Node as Node, OutputData, ConnectionOutput, ConnectionInput, InputData
7
8
 
8
9
 
9
10
 
10
11
  class LeapGraphEditor:
11
- def __init__(self, model_graph: Dict[str, Node], keras_model=None):
12
+ def __init__(self, model_graph: Dict[str, Node], keras_model: Model):
12
13
  self.model_graph = model_graph
13
14
  self.keras_model = keras_model
14
15
 
@@ -79,14 +80,6 @@ class LeapGraphEditor:
79
80
  return node
80
81
  return None
81
82
 
82
- def _resolve_prediction_origin_name(self, prediction_index: int) -> str:
83
- from keras import Model # deferred — only Keras path calls this
84
- return self.keras_model.outputs[prediction_index].node.layer.name
85
-
86
- def _resolve_input_origin_name(self, input_index: int) -> str:
87
- from keras import Model # deferred — only Keras path calls this
88
- return self.keras_model.inputs[input_index].node.layer.name
89
-
90
83
  def _replace_prediction_node_name_with_correct_name(self, connections: List[NodeConnection]) -> List[NodeConnection]:
91
84
  for connection in connections:
92
85
  if connection.node_inputs is None:
@@ -94,7 +87,7 @@ class LeapGraphEditor:
94
87
  for input_name, input_node in connection.node_inputs.items():
95
88
  if 'Prediction' in input_node.type.value:
96
89
  prediction_index = int(input_node.type.value.replace('Prediction', ''))
97
- origin_name = self._resolve_prediction_origin_name(prediction_index)
90
+ origin_name = self.keras_model.outputs[prediction_index].node.layer.name
98
91
  input_node.name = origin_name
99
92
 
100
93
  return connections
@@ -234,7 +227,7 @@ class LeapGraphEditor:
234
227
 
235
228
  def _handle_input_node_with_index(self, input_node: NodeMapping) -> str:
236
229
  input_index = int(input_node.type.value.replace('Input', ''))
237
- origin_name = self._resolve_input_origin_name(input_index)
230
+ origin_name = self.keras_model.inputs[input_index].node.layer.name
238
231
  input_node_by_origin = self._find_input_node_by_origin_name(origin_name)
239
232
  assert input_node_by_origin is not None, f"Input node with origin name {origin_name} not found in model graph"
240
233
  input_node_id = input_node_by_origin.id
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "leap-model-parser"
3
- version = "0.1.250.dev8"
3
+ version = "0.1.250.dev10"
4
4
  description = ""
5
5
  authors = ["idan <idan.yogev@tensorleap.ai>"]
6
6
  license = "MIT"
@@ -18,12 +18,11 @@ tensorflow = {version = "2.12.0", markers = "platform_machine == 'x86_64'"}
18
18
  tensorflow-macos = {version = "2.12.0", markers = "platform_machine == 'arm64'"}
19
19
  numpy = "^1.22.3"
20
20
  onnx = "1.13.0"
21
- onnx2kerastl = "0.0.191"
21
+ onnx2kerastl = "0.0.192.dev1"
22
22
  keras-data-format-converter = "0.1.24"
23
23
  leap-model-rebuilder = "0.1.7"
24
24
  tensorflow-io-gcs-filesystem = "0.34.0"
25
25
  code-loader = ">=1.0.127"
26
- torch = "~=2.3.0"
27
26
 
28
27
  [tool.poetry.dev-dependencies]
29
28
  pytest = "^7.1.1"
@@ -1,93 +0,0 @@
1
- from typing import Dict, List, Optional
2
-
3
- from torch.export import ExportedProgram
4
-
5
- from leap_model_parser.contract.graph import Node
6
- from leap_model_parser.leap_graph_editor import LeapGraphEditor
7
-
8
-
9
- class TorchLeapGraphEditor(LeapGraphEditor):
10
- """
11
- PT2 variant of LeapGraphEditor. Resolves Prediction{N} indices from
12
- ExportedProgram output specs instead of keras_model.outputs.
13
- """
14
-
15
- def __init__(self, model_graph: Dict[str, Node], ep: ExportedProgram):
16
- self.model_graph = model_graph
17
- self.keras_model = None
18
- self._next_node_id_index = 0
19
- self.ep = ep
20
-
21
- def _resolve_prediction_origin_name(self, prediction_index: int) -> str:
22
- module_path = self._find_pt2_output_module_path(prediction_index)
23
- if module_path is None:
24
- n = self._count_pt2_outputs()
25
- raise ValueError(
26
- f"Cannot resolve PT2 output module for Prediction{prediction_index}. "
27
- f"ExportedProgram has {n} outputs."
28
- )
29
- return module_path
30
-
31
- def _handle_input_node_with_index(self, input_node) -> str:
32
- return self._add_input_encoder_not_connected_to_the_model_node(input_node.name)
33
-
34
- def _resolve_input_origin_name(self, input_index: int) -> str:
35
- from torch.export.graph_signature import InputKind
36
- user_inputs = []
37
- node_to_spec = {s.arg.name: s for s in self.ep.graph_signature.input_specs}
38
- for node in self.ep.graph_module.graph.nodes:
39
- if node.op != "placeholder":
40
- break
41
- spec = node_to_spec.get(node.name)
42
- if spec and spec.kind not in (InputKind.PARAMETER, InputKind.BUFFER):
43
- user_inputs.append(node.name)
44
- if input_index >= len(user_inputs):
45
- raise ValueError(
46
- f"Cannot resolve PT2 input at index {input_index}. "
47
- f"ExportedProgram has {len(user_inputs)} user inputs."
48
- )
49
- return user_inputs[input_index]
50
-
51
- def _find_pt2_output_module_path(self, output_index: int) -> Optional[str]:
52
- output_node = None
53
- for node in self.ep.graph_module.graph.nodes:
54
- if node.op == "output":
55
- output_node = node
56
- break
57
- if output_node is None:
58
- return None
59
-
60
- args = output_node.args[0]
61
- if not isinstance(args, (list, tuple)):
62
- args = [args]
63
-
64
- if output_index >= len(args):
65
- return None
66
-
67
- output_ref = args[output_index]
68
- if output_ref is None or not hasattr(output_ref, "meta"):
69
- return None
70
-
71
- visited = set()
72
- current = output_ref
73
- while current.op == "call_function" and current not in visited:
74
- stack = current.meta.get("nn_module_stack", {})
75
- if stack:
76
- return list(stack.items())[-1][1][0]
77
- if current.args:
78
- visited.add(current)
79
- current = current.args[0]
80
- else:
81
- break
82
-
83
- stack = current.meta.get("nn_module_stack", {})
84
- if stack:
85
- return list(stack.items())[-1][1][0]
86
- return None
87
-
88
- def _count_pt2_outputs(self) -> int:
89
- for node in self.ep.graph_module.graph.nodes:
90
- if node.op == "output":
91
- args = node.args[0]
92
- return len(args) if isinstance(args, (list, tuple)) else 1
93
- return 0
@@ -1,57 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple
2
-
3
- import torch
4
- from torch.export import ExportedProgram, load
5
- from torch.export.graph_signature import InputKind
6
-
7
- from leap_model_parser.contract.graph import InputInfo, Node
8
- from leap_model_parser.torch_utils import build_module_to_node_map
9
-
10
-
11
- class TorchModelParser:
12
- def parse(self, file_path: str) -> Tuple[Dict[str, Node], List[InputInfo], ExportedProgram, str]:
13
- ep = load(file_path)
14
- nodes = self._build_node_graph(ep)
15
- input_infos = self._extract_input_infos(ep)
16
- return nodes, input_infos, ep, "torch"
17
-
18
- def _build_node_graph(self, ep: ExportedProgram) -> Dict[str, Node]:
19
- m2n = build_module_to_node_map(ep)
20
- nodes: Dict[str, Node] = {}
21
- for i, (mod_path, (node_name, class_name)) in enumerate(m2n.items()):
22
- node = Node(
23
- id=mod_path,
24
- name=mod_path.split(".")[-1] if "." in mod_path else mod_path,
25
- position=[i * 100, 0],
26
- data={
27
- "type": "Layer",
28
- "class_name": class_name,
29
- "framework": "torch",
30
- "origin_name": mod_path,
31
- },
32
- )
33
- nodes[mod_path] = node
34
- return nodes
35
-
36
- def apply_mapping_connections(
37
- self,
38
- ui_nodes: Dict[str, Node],
39
- ep: ExportedProgram,
40
- mapping_connections: List,
41
- ) -> None:
42
- from leap_model_parser.torch_graph_editor import TorchLeapGraphEditor
43
- editor = TorchLeapGraphEditor(ui_nodes, ep)
44
- editor.add_connections_to_graph(mapping_connections)
45
-
46
- def _extract_input_infos(self, ep: ExportedProgram) -> List[InputInfo]:
47
- node_to_spec = {s.arg.name: s for s in ep.graph_signature.input_specs}
48
- infos: List[InputInfo] = []
49
- for node in ep.graph_module.graph.nodes:
50
- if node.op != "placeholder":
51
- break
52
- spec = node_to_spec.get(node.name)
53
- if spec and spec.kind not in (InputKind.PARAMETER, InputKind.BUFFER):
54
- val = node.meta.get("val")
55
- shape = list(val.shape) if val is not None else []
56
- infos.append(InputInfo(name=node.name, shape=shape))
57
- return infos
@@ -1,149 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.fx as fx
4
- from torch.export import ExportedProgram
5
- from torch.export.graph_signature import InputKind
6
- from typing import Any, Dict, List, Optional, Tuple
7
-
8
-
9
- def get_full_args(ep: ExportedProgram, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
10
- node_to_spec = {s.arg.name: s for s in ep.graph_signature.input_specs}
11
- user_iter = iter(user_inputs)
12
- constants = getattr(ep, "constants", {}) or {}
13
- full_args = []
14
- for node in ep.graph_module.graph.nodes:
15
- if node.op != "placeholder":
16
- break
17
- spec = node_to_spec.get(node.name)
18
- if spec and spec.kind in (InputKind.PARAMETER, InputKind.BUFFER):
19
- if spec.target in ep.state_dict:
20
- full_args.append(ep.state_dict[spec.target].clone())
21
- else:
22
- full_args.append(constants[spec.target].clone())
23
- else:
24
- full_args.append(next(user_iter))
25
- return full_args
26
-
27
-
28
- def build_module_to_node_map(ep: ExportedProgram) -> Dict[str, Tuple[str, str]]:
29
- import operator as _op
30
- mapping: Dict[str, Tuple[str, str]] = {}
31
- for node in ep.graph_module.graph.nodes:
32
- if node.op != "call_function":
33
- continue
34
- if node.target is _op.getitem:
35
- continue
36
- stack = node.meta.get("nn_module_stack", {})
37
- if not stack:
38
- continue
39
- for mod_path, mod_type in stack.values():
40
- class_name = mod_type.__name__ if not isinstance(mod_type, str) else mod_type.split(".")[-1]
41
- mapping[mod_path] = (node.name, class_name)
42
- return mapping
43
-
44
-
45
- def extract_tensor(output: Any) -> torch.Tensor:
46
- if isinstance(output, torch.Tensor):
47
- return output
48
- if isinstance(output, (tuple, list)):
49
- t = output[0]
50
- return t if isinstance(t, torch.Tensor) else extract_tensor(t)
51
- if hasattr(output, "pooler_output") and output.pooler_output is not None:
52
- return output.pooler_output
53
- if hasattr(output, "last_hidden_state"):
54
- return output.last_hidden_state[:, 0, :]
55
- if hasattr(output, "logits"):
56
- return output.logits
57
- return output[0]
58
-
59
-
60
- class ActivationInterpreter(fx.Interpreter):
61
- def __init__(self, gm: fx.GraphModule, grad_mode: bool = False):
62
- super().__init__(gm)
63
- self.cache: Dict[str, torch.Tensor] = {}
64
- self._grad_mode = grad_mode
65
-
66
- def _store_tensor(self, name: str, value: Any) -> None:
67
- if isinstance(value, torch.Tensor):
68
- if self._grad_mode:
69
- self.cache[name] = value
70
- if value.requires_grad:
71
- value.retain_grad()
72
- else:
73
- self.cache[name] = value.detach().clone()
74
- elif isinstance(value, (tuple, list)):
75
- for item in value:
76
- if isinstance(item, torch.Tensor):
77
- if self._grad_mode:
78
- self.cache[name] = item
79
- if item.requires_grad:
80
- item.retain_grad()
81
- else:
82
- self.cache[name] = item.detach().clone()
83
- break
84
-
85
- def run_node(self, node: fx.Node) -> Any:
86
- result = super().run_node(node)
87
- self._store_tensor(node.name, result)
88
- return result
89
-
90
-
91
- def get_pt2_output_module_paths(ep: ExportedProgram) -> List[Optional[str]]:
92
- """Return the nn.Module path responsible for each output slot in the FX graph."""
93
- output_node = None
94
- for node in ep.graph_module.graph.nodes:
95
- if node.op == "output":
96
- output_node = node
97
- break
98
- if output_node is None:
99
- return []
100
-
101
- args = output_node.args[0]
102
- if not isinstance(args, (list, tuple)):
103
- args = [args]
104
-
105
- result: List[Optional[str]] = []
106
- for output_ref in args:
107
- if output_ref is None or not hasattr(output_ref, "meta"):
108
- result.append(None)
109
- continue
110
-
111
- visited: set = set()
112
- current = output_ref
113
- mod_path = None
114
- while current.op == "call_function" and current not in visited:
115
- stack = current.meta.get("nn_module_stack", {})
116
- if stack:
117
- mod_path = list(stack.items())[-1][1][0]
118
- break
119
- if current.args:
120
- visited.add(current)
121
- current = current.args[0]
122
- else:
123
- break
124
-
125
- if mod_path is None:
126
- stack = current.meta.get("nn_module_stack", {})
127
- if stack:
128
- mod_path = list(stack.items())[-1][1][0]
129
-
130
- result.append(mod_path)
131
-
132
- return result
133
-
134
-
135
- def get_module_activations(
136
- ep: ExportedProgram,
137
- user_inputs: List[torch.Tensor],
138
- ) -> Dict[str, np.ndarray]:
139
- tensors = [torch.from_numpy(x.astype(np.float32)) if isinstance(x, np.ndarray) else x
140
- for x in user_inputs]
141
- m2n = build_module_to_node_map(ep)
142
- interp = ActivationInterpreter(ep.graph_module)
143
- with torch.no_grad():
144
- interp.run(*get_full_args(ep, tensors))
145
- result: Dict[str, np.ndarray] = {}
146
- for mod_path, (node_name, _) in m2n.items():
147
- if node_name in interp.cache:
148
- result[mod_path] = interp.cache[node_name].numpy()
149
- return result