leap-model-parser 0.1.250.dev5__tar.gz → 0.1.250.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/PKG-INFO +1 -1
  2. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/torch_utils.py +59 -4
  3. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/pyproject.toml +1 -1
  4. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/LICENSE +0 -0
  5. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/README.md +0 -0
  6. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/__init__.py +0 -0
  7. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/contract/__init__.py +0 -0
  8. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/contract/graph.py +0 -0
  9. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/contract/importmodelresponse.py +0 -0
  10. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/contract/nodedata.py +0 -0
  11. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/contract/ui_components.json +0 -0
  12. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/keras_json_model_import.py +0 -0
  13. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/leap_graph_editor.py +0 -0
  14. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/model_parser.py +0 -0
  15. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/torch_graph_editor.py +0 -0
  16. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/torch_model_parser.py +0 -0
  17. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/__init__.py +0 -0
  18. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/layerpedia/__init__.py +0 -0
  19. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/layerpedia/layerpedia.py +0 -0
  20. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/tlinspection/__init__.py +0 -0
  21. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/tlinspection/leapinspection.py +0 -0
  22. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/uicomponents/__init__.py +0 -0
  23. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/uicomponents/generatenodedata.py +0 -0
  24. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/uicomponents/tensorflowinscpection.py +0 -0
  25. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/uicomponents/ui_components.json +0 -0
  26. {leap_model_parser-0.1.250.dev5 → leap_model_parser-0.1.250.dev7}/leap_model_parser/utils/uicomponents/ui_components_config.yaml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: leap-model-parser
3
- Version: 0.1.250.dev5
3
+ Version: 0.1.250.dev7
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/leap-model-parser
6
6
  License: MIT
@@ -3,7 +3,7 @@ import torch
3
3
  import torch.fx as fx
4
4
  from torch.export import ExportedProgram
5
5
  from torch.export.graph_signature import InputKind
6
- from typing import Any, Dict, List, Tuple
6
+ from typing import Any, Dict, List, Optional, Tuple
7
7
 
8
8
 
9
9
  def get_full_args(ep: ExportedProgram, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -55,17 +55,28 @@ def extract_tensor(output: Any) -> torch.Tensor:
55
55
 
56
56
 
57
57
  class ActivationInterpreter(fx.Interpreter):
58
- def __init__(self, gm: fx.GraphModule):
58
# dev7 adds the keyword `grad_mode`; its default (False) keeps the dev5 detach-and-clone caching in _store_tensor.
+ def __init__(self, gm: fx.GraphModule, grad_mode: bool = False):
59
59
  super().__init__(gm)
60
60
  self.cache: Dict[str, torch.Tensor] = {}
61
# Read by _store_tensor: when True, the live tensor (not a detached clone) is cached.
+ self._grad_mode = grad_mode
61
62
 
62
63
  def _store_tensor(self, name: str, value: Any) -> None:
# Cache the tensor produced for node `name` in self.cache; non-tensor results that are not tuple/list are ignored.
63
64
  if isinstance(value, torch.Tensor):
64
- self.cache[name] = value.detach().clone()
65
# grad_mode: keep the live tensor so it stays attached to the autograd graph; retain_grad() asks PyTorch to
# populate .grad on this (possibly non-leaf) tensor during backward().
+ if self._grad_mode:
66
+ self.cache[name] = value
67
+ if value.requires_grad:
68
+ value.retain_grad()
69
# Default: a detached copy, so the cached activation carries no autograd history and has its own storage.
+ else:
70
+ self.cache[name] = value.detach().clone()
65
71
  elif isinstance(value, (tuple, list)):
66
72
  for item in value:
67
73
  if isinstance(item, torch.Tensor):
68
- self.cache[name] = item.detach().clone()
74
# NOTE(review): this grad_mode/detach logic duplicates the single-tensor branch above; a shared helper would keep them in sync.
+ if self._grad_mode:
75
+ self.cache[name] = item
76
+ if item.requires_grad:
77
+ item.retain_grad()
78
+ else:
79
+ self.cache[name] = item.detach().clone()
69
80
# Only the first tensor element of a tuple/list result is cached under `name`.
  break
70
81
 
71
82
  def run_node(self, node: fx.Node) -> Any:
@@ -74,6 +85,50 @@ class ActivationInterpreter(fx.Interpreter):
74
85
  return result
75
86
 
76
87
 
88
+ def get_pt2_output_module_paths(ep: ExportedProgram) -> List[Optional[str]]:
89
+ """Return the nn.Module path responsible for each output slot in the FX graph."""
90
+ output_node = None
91
+ for node in ep.graph_module.graph.nodes:
92
+ if node.op == "output":
93
+ output_node = node
94
+ break
95
+ if output_node is None:
96
+ return []
97
+
98
+ args = output_node.args[0]
99
+ if not isinstance(args, (list, tuple)):
100
+ args = [args]
101
+
102
+ result: List[Optional[str]] = []
103
+ for output_ref in args:
104
+ if output_ref is None or not hasattr(output_ref, "meta"):
105
+ result.append(None)
106
+ continue
107
+
108
+ visited: set = set()
109
+ current = output_ref
110
+ mod_path = None
111
+ while current.op == "call_function" and current not in visited:
112
+ stack = current.meta.get("nn_module_stack", {})
113
+ if stack:
114
+ mod_path = list(stack.items())[-1][1][0]
115
+ break
116
+ if current.args:
117
+ visited.add(current)
118
+ current = current.args[0]
119
+ else:
120
+ break
121
+
122
+ if mod_path is None:
123
+ stack = current.meta.get("nn_module_stack", {})
124
+ if stack:
125
+ mod_path = list(stack.items())[-1][1][0]
126
+
127
+ result.append(mod_path)
128
+
129
+ return result
130
+
131
+
77
132
  def get_module_activations(
78
133
  ep: ExportedProgram,
79
134
  user_inputs: List[torch.Tensor],
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "leap-model-parser"
3
- version = "0.1.250.dev5"
3
+ version = "0.1.250.dev7"
4
4
  description = ""
5
5
  authors = ["idan <idan.yogev@tensorleap.ai>"]
6
6
  license = "MIT"