leap-model-parser 0.1.250.dev3__tar.gz → 0.1.250.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26):
  1. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/PKG-INFO +1 -1
  2. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/leap_graph_editor.py +11 -4
  3. leap_model_parser-0.1.250.dev5/leap_model_parser/torch_graph_editor.py +93 -0
  4. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/torch_model_parser.py +17 -2
  5. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/pyproject.toml +1 -1
  6. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/LICENSE +0 -0
  7. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/README.md +0 -0
  8. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/__init__.py +0 -0
  9. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/contract/__init__.py +0 -0
  10. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/contract/graph.py +0 -0
  11. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/contract/importmodelresponse.py +0 -0
  12. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/contract/nodedata.py +0 -0
  13. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/contract/ui_components.json +0 -0
  14. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/keras_json_model_import.py +0 -0
  15. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/model_parser.py +0 -0
  16. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/torch_utils.py +0 -0
  17. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/__init__.py +0 -0
  18. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/layerpedia/__init__.py +0 -0
  19. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/layerpedia/layerpedia.py +0 -0
  20. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/tlinspection/__init__.py +0 -0
  21. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/tlinspection/leapinspection.py +0 -0
  22. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/uicomponents/__init__.py +0 -0
  23. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/uicomponents/generatenodedata.py +0 -0
  24. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/uicomponents/tensorflowinscpection.py +0 -0
  25. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/uicomponents/ui_components.json +0 -0
  26. {leap_model_parser-0.1.250.dev3 → leap_model_parser-0.1.250.dev5}/leap_model_parser/utils/uicomponents/ui_components_config.yaml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: leap-model-parser
3
- Version: 0.1.250.dev3
3
+ Version: 0.1.250.dev5
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/leap-model-parser
6
6
  License: MIT
@@ -2,14 +2,13 @@ from enum import Enum
2
2
  from typing import Optional, Dict, Any, List
3
3
 
4
4
  from code_loader.contract.mapping import NodeConnection, NodeMappingType, NodeMapping # type: ignore
5
- from keras import Model # type: ignore
6
5
 
7
6
  from leap_model_parser.contract.graph import Node as Node, OutputData, ConnectionOutput, ConnectionInput, InputData
8
7
 
9
8
 
10
9
 
11
10
  class LeapGraphEditor:
12
- def __init__(self, model_graph: Dict[str, Node], keras_model: Model):
11
+ def __init__(self, model_graph: Dict[str, Node], keras_model=None):
13
12
  self.model_graph = model_graph
14
13
  self.keras_model = keras_model
15
14
 
@@ -80,6 +79,14 @@ class LeapGraphEditor:
80
79
  return node
81
80
  return None
82
81
 
82
def _resolve_prediction_origin_name(self, prediction_index: int) -> str:
    """Return the Keras layer name that produces model output ``prediction_index``.

    Args:
        prediction_index: index into ``self.keras_model.outputs``.

    Returns:
        The ``name`` of the layer attached to that output tensor.

    Raises:
        ValueError: if the editor was constructed without a Keras model.
    """
    # No keras import is needed: ``self.keras_model`` is already a model
    # instance, and importing keras here would make this module depend on
    # keras again even though ``keras_model`` is now optional.
    if self.keras_model is None:
        raise ValueError(
            f"Cannot resolve Prediction{prediction_index}: no Keras model was provided."
        )
    return self.keras_model.outputs[prediction_index].node.layer.name
85
+
86
def _resolve_input_origin_name(self, input_index: int) -> str:
    """Return the Keras layer name that owns model input ``input_index``.

    Args:
        input_index: index into ``self.keras_model.inputs``.

    Returns:
        The ``name`` of the layer attached to that input tensor.

    Raises:
        ValueError: if the editor was constructed without a Keras model.
    """
    # No keras import is needed: ``self.keras_model`` is already a model
    # instance, and importing keras here would make this module depend on
    # keras again even though ``keras_model`` is now optional.
    if self.keras_model is None:
        raise ValueError(
            f"Cannot resolve Input{input_index}: no Keras model was provided."
        )
    return self.keras_model.inputs[input_index].node.layer.name
89
+
83
90
  def _replace_prediction_node_name_with_correct_name(self, connections: List[NodeConnection]) -> List[NodeConnection]:
84
91
  for connection in connections:
85
92
  if connection.node_inputs is None:
@@ -87,7 +94,7 @@ class LeapGraphEditor:
87
94
  for input_name, input_node in connection.node_inputs.items():
88
95
  if 'Prediction' in input_node.type.value:
89
96
  prediction_index = int(input_node.type.value.replace('Prediction', ''))
90
- origin_name = self.keras_model.outputs[prediction_index].node.layer.name
97
+ origin_name = self._resolve_prediction_origin_name(prediction_index)
91
98
  input_node.name = origin_name
92
99
 
93
100
  return connections
@@ -227,7 +234,7 @@ class LeapGraphEditor:
227
234
 
228
235
  def _handle_input_node_with_index(self, input_node: NodeMapping) -> str:
229
236
  input_index = int(input_node.type.value.replace('Input', ''))
230
- origin_name = self.keras_model.inputs[input_index].node.layer.name
237
+ origin_name = self._resolve_input_origin_name(input_index)
231
238
  input_node_by_origin = self._find_input_node_by_origin_name(origin_name)
232
239
  assert input_node_by_origin is not None, f"Input node with origin name {origin_name} not found in model graph"
233
240
  input_node_id = input_node_by_origin.id
@@ -0,0 +1,93 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ from torch.export import ExportedProgram
4
+
5
+ from leap_model_parser.contract.graph import Node
6
+ from leap_model_parser.leap_graph_editor import LeapGraphEditor
7
+
8
+
9
class TorchLeapGraphEditor(LeapGraphEditor):
    """
    PT2 (``torch.export``) variant of LeapGraphEditor.

    Resolves ``Prediction{N}`` / ``Input{N}`` indices against the
    ExportedProgram's graph signature instead of ``keras_model.outputs`` /
    ``keras_model.inputs``. Only *user* inputs and outputs are indexed.
    Lifted parameters, buffers, constants, buffer mutations, etc. are skipped,
    so that index N refers to the N-th user input / user output.
    """

    def __init__(self, model_graph: Dict[str, Node], ep: ExportedProgram):
        # Let the base class set up its own state; there is no Keras model on
        # the PT2 path.
        super().__init__(model_graph, keras_model=None)
        self._next_node_id_index = 0
        self.ep = ep

    def _resolve_prediction_origin_name(self, prediction_index: int) -> str:
        """Return the module path that produces user output ``prediction_index``.

        Raises:
            ValueError: if the output cannot be traced back to a module.
        """
        module_path = self._find_pt2_output_module_path(prediction_index)
        if module_path is None:
            n = self._count_pt2_outputs()
            raise ValueError(
                f"Cannot resolve PT2 output module for Prediction{prediction_index}. "
                f"ExportedProgram has {n} user outputs."
            )
        return module_path

    def _handle_input_node_with_index(self, input_node) -> str:
        # On the PT2 path an indexed input is not matched against a model-graph
        # input node; it is added through the base-class helper as an encoder
        # node that is not connected to the model.
        return self._add_input_encoder_not_connected_to_the_model_node(input_node.name)

    def _user_input_names(self) -> List[str]:
        """Return placeholder names of the user inputs, in graph order."""
        from torch.export.graph_signature import InputKind

        # getattr: not every argument kind is guaranteed to carry a ``name``.
        spec_by_name = {
            getattr(spec.arg, "name", None): spec
            for spec in self.ep.graph_signature.input_specs
        }
        names: List[str] = []
        for node in self.ep.graph_module.graph.nodes:
            # Assumes exported graphs list all placeholders first.
            if node.op != "placeholder":
                break
            spec = spec_by_name.get(node.name)
            # Only true user inputs count; parameters, buffers, lifted
            # constants, custom objects and tokens would shift the index.
            if spec is not None and spec.kind == InputKind.USER_INPUT:
                names.append(node.name)
        return names

    def _resolve_input_origin_name(self, input_index: int) -> str:
        """Return the placeholder name of user input ``input_index``.

        Raises:
            ValueError: if ``input_index`` is out of range.
        """
        user_inputs = self._user_input_names()
        if not 0 <= input_index < len(user_inputs):
            raise ValueError(
                f"Cannot resolve PT2 input at index {input_index}. "
                f"ExportedProgram has {len(user_inputs)} user inputs."
            )
        return user_inputs[input_index]

    def _user_output_refs(self) -> Optional[List]:
        """Return the values returned as *user* outputs.

        Returns ``None`` when the graph has no output node.
        """
        from torch.export.graph_signature import OutputKind

        output_node = next(
            (node for node in self.ep.graph_module.graph.nodes if node.op == "output"),
            None,
        )
        if output_node is None:
            return None

        refs = output_node.args[0]
        if not isinstance(refs, (list, tuple)):
            refs = [refs]

        specs = self.ep.graph_signature.output_specs
        if len(specs) != len(refs):
            # The signature does not line up with the graph outputs, so it
            # cannot be used to filter; fall back to the raw output order.
            return list(refs)
        # Drop buffer mutations and other non-user outputs so Prediction{N}
        # addresses the N-th user output.
        return [ref for ref, spec in zip(refs, specs) if spec.kind == OutputKind.USER_OUTPUT]

    @staticmethod
    def _is_fx_node(obj) -> bool:
        """Duck-typed check for an FX graph node (has ``op`` and ``meta``)."""
        return hasattr(obj, "op") and hasattr(obj, "meta")

    @staticmethod
    def _innermost_module_path(node) -> Optional[str]:
        """Return the module path recorded last in ``node``'s ``nn_module_stack``, if any."""
        stack = node.meta.get("nn_module_stack") or {}
        if not stack:
            return None
        # The last entry is taken as the innermost module; the first element
        # of its value is used as the module path.
        return list(stack.values())[-1][0]

    def _find_pt2_output_module_path(self, output_index: int) -> Optional[str]:
        """Trace user output ``output_index`` back to the module that produced it.

        Starting from the output value, follow ``call_function`` nodes through
        their first argument until one carries an ``nn_module_stack``.

        Returns:
            The module path, or ``None`` if it cannot be determined.
        """
        refs = self._user_output_refs()
        if refs is None or not 0 <= output_index < len(refs):
            return None

        current = refs[output_index]
        if not self._is_fx_node(current):
            return None

        visited = set()
        while current.op == "call_function" and current not in visited:
            module_path = self._innermost_module_path(current)
            if module_path is not None:
                return module_path
            visited.add(current)
            parent = current.args[0] if current.args else None
            # The first argument may be a constant or a container rather than
            # a node; stop walking instead of failing on ``.op``.
            if not self._is_fx_node(parent):
                break
            current = parent

        return self._innermost_module_path(current)

    def _count_pt2_outputs(self) -> int:
        """Return the number of user outputs (0 if the graph has no output node)."""
        refs = self._user_output_refs()
        return 0 if refs is None else len(refs)
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Tuple
1
+ from typing import Dict, List, Optional, Tuple
2
2
 
3
3
  import torch
4
4
  from torch.export import ExportedProgram, load
@@ -23,11 +23,26 @@ class TorchModelParser:
23
23
  id=mod_path,
24
24
  name=mod_path.split(".")[-1] if "." in mod_path else mod_path,
25
25
  position=[i * 100, 0],
26
- data={"type": class_name, "framework": "torch"},
26
+ data={
27
+ "type": "Layer",
28
+ "class_name": class_name,
29
+ "framework": "torch",
30
+ "origin_name": mod_path,
31
+ },
27
32
  )
28
33
  nodes[mod_path] = node
29
34
  return nodes
30
35
 
36
+ def apply_mapping_connections(
37
+ self,
38
+ ui_nodes: Dict[str, Node],
39
+ ep: ExportedProgram,
40
+ mapping_connections: List,
41
+ ) -> None:
42
+ from leap_model_parser.torch_graph_editor import TorchLeapGraphEditor
43
+ editor = TorchLeapGraphEditor(ui_nodes, ep)
44
+ editor.add_connections_to_graph(mapping_connections)
45
+
31
46
  def _extract_input_infos(self, ep: ExportedProgram) -> List[InputInfo]:
32
47
  node_to_spec = {s.arg.name: s for s in ep.graph_signature.input_specs}
33
48
  infos: List[InputInfo] = []
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "leap-model-parser"
3
- version = "0.1.250.dev3"
3
+ version = "0.1.250.dev5"
4
4
  description = ""
5
5
  authors = ["idan <idan.yogev@tensorleap.ai>"]
6
6
  license = "MIT"