ttnn-visualizer 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. ttnn_visualizer/__init__.py +4 -0
  2. ttnn_visualizer/app.py +193 -0
  3. ttnn_visualizer/bin/docker-entrypoint-web +16 -0
  4. ttnn_visualizer/bin/pip3-install +17 -0
  5. ttnn_visualizer/csv_queries.py +618 -0
  6. ttnn_visualizer/decorators.py +117 -0
  7. ttnn_visualizer/enums.py +12 -0
  8. ttnn_visualizer/exceptions.py +40 -0
  9. ttnn_visualizer/extensions.py +14 -0
  10. ttnn_visualizer/file_uploads.py +78 -0
  11. ttnn_visualizer/models.py +275 -0
  12. ttnn_visualizer/queries.py +388 -0
  13. ttnn_visualizer/remote_sqlite_setup.py +91 -0
  14. ttnn_visualizer/requirements.txt +24 -0
  15. ttnn_visualizer/serializers.py +249 -0
  16. ttnn_visualizer/sessions.py +245 -0
  17. ttnn_visualizer/settings.py +118 -0
  18. ttnn_visualizer/sftp_operations.py +486 -0
  19. ttnn_visualizer/sockets.py +118 -0
  20. ttnn_visualizer/ssh_client.py +85 -0
  21. ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
  22. ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
  23. ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
  24. ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
  25. ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
  26. ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
  27. ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
  28. ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
  29. ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
  30. ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
  31. ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
  32. ttnn_visualizer/static/favicon/favicon.svg +3 -0
  33. ttnn_visualizer/static/index.html +36 -0
  34. ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
  35. ttnn_visualizer/tests/__init__.py +4 -0
  36. ttnn_visualizer/tests/test_queries.py +444 -0
  37. ttnn_visualizer/tests/test_serializers.py +582 -0
  38. ttnn_visualizer/utils.py +185 -0
  39. ttnn_visualizer/views.py +794 -0
  40. ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
  41. ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
  42. ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
  43. ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
  44. ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
  45. ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
  46. ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,117 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import re
6
+ from ttnn_visualizer.enums import ConnectionTestStates
7
+
8
+
9
+ from functools import wraps
10
+ from flask import request, abort
11
+ from paramiko.ssh_exception import (
12
+ AuthenticationException,
13
+ NoValidConnectionsError,
14
+ SSHException,
15
+ )
16
+
17
+ from ttnn_visualizer.exceptions import (
18
+ RemoteConnectionException,
19
+ NoProjectsException,
20
+ RemoteSqliteException,
21
+ )
22
+ from ttnn_visualizer.sessions import get_or_create_tab_session
23
+
24
+
25
def with_session(func):
    """Inject the caller's tab session into a Flask view as ``session``.

    The tab id is taken from the ``tabId`` query-string argument. The matching
    session row is fetched (or created) with ``get_or_create_tab_session`` and
    converted via ``to_pydantic()``; the result is passed to the wrapped view
    as the ``session`` keyword argument (overriding any existing one).

    Aborts with HTTP 404 when the request has no ``tabId`` or when the session
    has no active report.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        from flask import current_app

        tab_id = request.args.get("tabId")
        if not tab_id:
            current_app.logger.error("No tabId present on request, returning 404")
            abort(404)

        tab_session = get_or_create_tab_session(tab_id=tab_id).to_pydantic()

        # Without an active report there is nothing for the view to serve.
        if not tab_session.active_report:
            current_app.logger.error(
                f"No active report exists for tabId {tab_id}, returning 404"
            )
            abort(404)

        return func(*args, **{**kwargs, "session": tab_session})

    return wrapper
50
+
51
+
52
def _find_connection(args, kwargs):
    """Return the connection object a wrapped remote call received, if any.

    Checks the ``connection`` then ``remote_connection`` keyword arguments
    (skipping falsy values) and falls back to the first positional argument.
    Returns ``None`` when none of these is available.
    """
    for key in ("connection", "remote_connection"):
        if kwargs.get(key):
            return kwargs[key]
    return args[0] if args else None


def _connection_path(connection):
    """Best-effort remote path of *connection*, for error messages only.

    ``RemoteConnection`` has no ``path`` field (it exposes ``reportPath``), so
    reading ``connection.path`` directly would raise AttributeError and hide
    the original error. ``path`` is still honoured for other objects.
    """
    return getattr(connection, "path", None) or getattr(
        connection, "reportPath", None
    )


def _strip_bracketed(text):
    """Remove every ``[...]`` segment from *text* and trim surrounding whitespace."""
    return re.sub(r"\[.*?]", "", text).strip()


def remote_exception_handler(func):
    """Wrap a remote (SSH/SFTP/remote-SQLite) operation with error translation.

    Calls *func* and converts the exceptions below into
    ``RemoteConnectionException(status=ConnectionTestStates.FAILED, ...)``
    with a user-facing message, chaining the original as ``__cause__``:

    - ``AuthenticationException`` -> "Unable to authenticate: ..."
    - ``FileNotFoundError`` -> "Unable to open path <path>: ..."
    - ``NoProjectsException`` -> "No projects found at remote location: <path>"
    - ``NoValidConnectionsError`` -> its text with ``[...]`` segments removed
    - ``RemoteSqliteException`` -> its message, or a hint about the SQLite
      binary path when the error text contains "No such file"
    - other ``IOError``/``OSError`` -> "Error opening remote folder ...", or a
      hostname hint when the text contains "Name or service not known"
    - ``SSHException`` -> a credentials/ssh-agent hint for
      "No existing session", otherwise "Error connecting to host ..."

    The connection named in messages comes from the ``connection`` or
    ``remote_connection`` keyword argument, else the first positional
    argument. Any other exception propagates unchanged.
    """
    from flask import current_app

    @wraps(func)
    def remote_handler(*args, **kwargs):
        connection = _find_connection(args, kwargs)
        try:
            return func(*args, **kwargs)
        # AuthenticationException subclasses SSHException, so it must be
        # handled before the generic SSHException branch below.
        except AuthenticationException as err:
            current_app.logger.error(f"Authentication failed {err}")
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED,
                message=f"Unable to authenticate: {str(err)}",
            ) from err
        except FileNotFoundError as err:
            current_app.logger.error(f"File not found: {str(err)}")
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED,
                message=(
                    f"Unable to open path {_connection_path(connection)}: {str(err)}"
                ),
            ) from err
        except NoProjectsException as err:
            current_app.logger.error(f"No projects: {str(err)}")
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED,
                message=(
                    "No projects found at remote location: "
                    f"{_connection_path(connection)}"
                ),
            ) from err
        except NoValidConnectionsError as err:
            current_app.logger.error(f"No valid connections: {str(err)}")
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED,
                message=_strip_bracketed(str(err)),
            ) from err
        except RemoteSqliteException as err:
            current_app.logger.error(f"Remote Sqlite exception: {str(err)}")
            message = err.message
            if "No such file" in str(err):
                message = "Unable to open SQLite binary, check path"
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED, message=message
            ) from err
        # Generic OSError handler; FileNotFoundError is an OSError subclass and
        # is handled above with a more specific message.
        except IOError as err:
            current_app.logger.error(f"IO error: {str(err)}")
            message = (
                f"Error opening remote folder {_connection_path(connection)}: "
                f"{str(err)}"
            )
            if "Name or service not known" in str(err):
                host = getattr(connection, "host", None)
                message = f"Unable to connect to {host} - check hostname"
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED,
                message=message,
            ) from err
        except SSHException as err:
            current_app.logger.error(f"SSH exception: {str(err)}")
            if str(err) == "No existing session":
                message = "Authentication failed - check credentials and ssh-agent"
            else:
                host = getattr(connection, "host", None)
                message = (
                    f"Error connecting to host {host}: {_strip_bracketed(str(err))}"
                )
            raise RemoteConnectionException(
                status=ConnectionTestStates.FAILED, message=message
            ) from err

    return remote_handler
@@ -0,0 +1,12 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import enum
6
+
7
+
8
class ConnectionTestStates(enum.Enum):
    """States of a remote-connection test.

    Remote-operation errors in this package are reported with ``FAILED``.
    NOTE(review): IDLE/PROGRESS presumably describe a test not yet started /
    in flight — confirm with the frontend that consumes these values.
    """

    IDLE = 0
    PROGRESS = 1
    FAILED = 2
    OK = 3
@@ -0,0 +1,40 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ from http import HTTPStatus
6
+
7
+ from ttnn_visualizer.enums import ConnectionTestStates
8
+
9
+
10
+ class RemoteConnectionException(Exception):
11
+ def __init__(self, message, status: ConnectionTestStates):
12
+ super().__init__(message)
13
+ self.message = message
14
+ self.status = status
15
+
16
+ @property
17
+ def http_status(self):
18
+ if self.status == ConnectionTestStates.FAILED:
19
+ return HTTPStatus.INTERNAL_SERVER_ERROR
20
+ if self.status == ConnectionTestStates.OK:
21
+ return HTTPStatus.OK
22
+
23
+
24
class NoProjectsException(RemoteConnectionException):
    """Raised when no projects are found at a remote location.

    Takes the same ``(message, status)`` arguments as
    ``RemoteConnectionException``.
    """

    pass
26
+
27
+
28
class RemoteSqliteException(Exception):
    """Error raised by remote SQLite operations.

    :param message: Human-readable description; also used as ``str(exc)``.
    :param status: Status value supplied by the raiser (stored as-is).
    """

    def __init__(self, message, status):
        self.message, self.status = message, status
        super().__init__(message)
33
+
34
+
35
class DatabaseFileNotFoundException(Exception):
    """Signals that an expected database file is missing.

    Raise sites are outside this file.
    """

    pass
37
+
38
+
39
class DataFormatError(Exception):
    """Signals that input data is not in the expected format.

    Raise sites are outside this file.
    """

    pass
@@ -0,0 +1,14 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ from flask_socketio import SocketIO
6
+ from flask_static_digest import FlaskStaticDigest
7
+ from flask_sqlalchemy import SQLAlchemy
8
+
9
+
10
# Extension singletons are created at import time without an app so other
# modules can import them; binding to the Flask app presumably happens via
# ``init_app`` in the app factory — confirm.
flask_static_digest = FlaskStaticDigest()
# Initialize Flask SQLAlchemy
db = SQLAlchemy()

# Socket.IO server using the gevent async mode.
# NOTE(review): cors_allowed_origins="*" accepts connections from any origin —
# confirm this is intended outside local/dev use.
socketio = SocketIO(cors_allowed_origins="*", async_mode="gevent")
@@ -0,0 +1,78 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
4
+
5
+ from pathlib import Path
6
+ import logging
7
+
8
# Module-level logger shared by the upload helpers in this module.
logger = logging.getLogger(__name__)
9
+
10
+
11
def validate_files(files, required_files, pattern=None):
    """Validate uploaded files against required file names and an optional pattern.

    A file "matches" when its base name is in ``required_files`` or starts
    with ``pattern``. Every matching file must sit directly inside exactly one
    top-level folder (e.g. ``report/db.sqlite``).

    :param files: Uploaded file objects exposing a ``filename`` attribute (the
        client-supplied relative path). Entries with an empty or ``None``
        filename are ignored.
    :param required_files: Names that must all be present (any iterable of str).
    :param pattern: Optional name prefix; when given, at least one file whose
        name starts with it must be present.
    :return: ``True`` when all requirements are met, otherwise ``False`` (the
        reason is logged as a warning).
    """
    # Accept any iterable (list, tuple, set) without mutating the caller's object.
    required = set(required_files)
    found_files = set()

    for file in files:
        if not file.filename:
            continue
        file_path = Path(file.filename)
        name = file_path.name

        if name in required or (pattern and name.startswith(pattern)):
            found_files.add(name)
            # "report/x" has parents ("report", ".") -> exactly one folder level.
            if len(file_path.parents) != 2:
                logger.warning(
                    "File %s is not under a single parent folder.", file.filename
                )
                return False

    missing_files = required - found_files
    if pattern and not any(found.startswith(pattern) for found in found_files):
        missing_files.add(f"{pattern}*")

    if missing_files:
        # Sorted so the log message is deterministic across runs.
        logger.warning("Missing required files: %s", ", ".join(sorted(missing_files)))
        return False

    return True
37
+
38
+
39
def extract_report_name(files):
    """Return the top-level folder name of the first uploaded file.

    The first file's ``filename`` is converted with ``str`` and everything
    before the first ``/`` is returned (the whole string if there is no ``/``).
    Returns ``None`` when *files* is empty.
    """
    if files:
        report_name, _, _ = str(files[0].filename).partition("/")
        return report_name
    return None
45
+
46
+
47
def save_uploaded_files(
    files,
    target_directory,
    report_name=None,
):
    """
    Save uploaded files under the target directory, preserving their relative paths.

    Each file's client-supplied ``filename`` (e.g. ``report/db.sqlite``) is
    joined onto ``target_directory`` and missing parent directories are
    created. Because the filename is untrusted, any path that resolves outside
    ``target_directory`` (absolute paths, ``..`` segments) is skipped with a
    warning instead of being written.

    :param files: Uploaded file objects exposing ``filename`` and ``save(path)``.
    :param target_directory: The base directory for saving the files.
    :param report_name: Currently unused; kept for backward compatibility.
    """
    base_directory = Path(target_directory).resolve()

    for file in files:
        current_file_name = str(file.filename)
        logger.info(f"Processing file: {current_file_name}")

        # Resolve so ".." segments and absolute filenames are normalised before
        # the containment check (joinpath with an absolute path discards the base).
        destination_file = base_directory.joinpath(current_file_name).resolve()
        if not destination_file.is_relative_to(base_directory):
            logger.warning(
                f"Skipping {current_file_name}: path resolves outside {base_directory}"
            )
            continue

        logger.info(f"Writing file to {destination_file}")

        # Create directory if it doesn't exist
        if not destination_file.parent.exists():
            logger.info(
                f"{destination_file.parent.name} does not exist. Creating directory"
            )
            destination_file.parent.mkdir(exist_ok=True, parents=True)

        file.save(destination_file)
@@ -0,0 +1,275 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import dataclasses
6
+ import enum
7
+ import json
8
+ from json import JSONDecodeError
9
+ from typing import Optional, Any
10
+
11
+ from pydantic import BaseModel, Field
12
+ from sqlalchemy import Integer, Column, String, JSON
13
+ from sqlalchemy.ext.mutable import MutableDict
14
+
15
+ from ttnn_visualizer.utils import SerializeableDataclass
16
+ from ttnn_visualizer.enums import ConnectionTestStates
17
+ from ttnn_visualizer.extensions import db
18
+
19
+ from ttnn_visualizer.utils import parse_memory_config
20
+
21
+
22
class BufferType(enum.Enum):
    """Kind of memory a buffer lives in (used by Buffer, BufferPage and Tensor).

    NOTE(review): the integer values presumably mirror the buffer-type codes
    written into report data — confirm before renumbering.
    """

    DRAM = 0
    L1 = 1
    SYSTEM_MEMORY = 2
    L1_SMALL = 3
    TRACE = 4
28
+
29
+
30
@dataclasses.dataclass
class Operation(SerializeableDataclass):
    """A recorded operation: its id, name and duration."""

    operation_id: int
    name: str
    duration: float  # units not established in this file — TODO confirm
35
+
36
+
37
@dataclasses.dataclass
class Device(SerializeableDataclass):
    """Per-device description of the core grid and L1 memory layout.

    NOTE(review): size/address/limit fields are presumably in bytes — confirm
    against the report producer.
    """

    device_id: int
    # Core grid dimensions: all cores, then compute cores only.
    num_y_cores: int
    num_x_cores: int
    num_y_compute_cores: int
    num_x_compute_cores: int
    # L1 size and banking.
    worker_l1_size: int
    l1_num_banks: int
    l1_bank_size: int
    address_at_first_l1_bank: int
    address_at_first_l1_cb_buffer: int
    num_banks_per_storage_core: int
    num_compute_cores: int
    num_storage_cores: int
    # L1 capacity totals, split by usage.
    total_l1_memory: int
    total_l1_for_tensors: int
    total_l1_for_interleaved_buffers: int
    total_l1_for_sharded_buffers: int
    cb_limit: int  # presumably a circular-buffer limit — TODO confirm
57
+
58
+
59
@dataclasses.dataclass
class DeviceOperation(SerializeableDataclass):
    """Captured device-level graph for one operation.

    ``captured_graph`` is given as a JSON string. On construction it is parsed
    and each node's ``"counter"`` key is renamed to ``"id"``. If the string is
    not valid JSON, the field is replaced with the JSON text ``"{}"``.
    """

    operation_id: int
    captured_graph: str

    def __post_init__(self):
        try:
            nodes = json.loads(self.captured_graph)
        except JSONDecodeError:
            # NOTE(review): this fallback is a JSON *string*, whereas the success
            # path stores parsed Python objects — confirm which type consumers expect.
            self.captured_graph = json.dumps({})
            return

        for node in nodes:
            node["id"] = node.pop("counter")
        self.captured_graph = nodes
75
+
76
+
77
@dataclasses.dataclass
class Buffer(SerializeableDataclass):
    """A buffer record tied to an operation and a device."""

    operation_id: int
    device_id: int
    address: int
    max_size_per_bank: int
    buffer_type: BufferType
84
+
85
+
86
@dataclasses.dataclass
class BufferPage(SerializeableDataclass):
    """One page of a buffer and where it is placed (core, bank, page address)."""

    operation_id: int
    device_id: int
    address: int  # address of the owning buffer
    # Core coordinates and bank holding this page.
    core_y: int
    core_x: int
    bank_id: int
    page_index: int
    page_address: int
    page_size: int
    buffer_type: BufferType
98
+
99
+
100
@dataclasses.dataclass
class ProducersConsumers(SerializeableDataclass):
    """Producers and consumers of a tensor.

    NOTE(review): the ids are presumably operation ids — confirm.
    """

    tensor_id: int
    producers: list[int]
    consumers: list[int]
105
+
106
+
107
@dataclasses.dataclass
class Tensor(SerializeableDataclass):
    """A tensor record: shape, dtype, layout, memory config and placement."""

    tensor_id: int
    shape: str
    dtype: str
    layout: str
    # Normalised by parse_memory_config in __post_init__.
    memory_config: str | dict[str, Any] | None
    device_id: int
    address: int
    buffer_type: BufferType
    device_addresses: list[int]

    def __post_init__(self):
        # Replace the raw value with parse_memory_config's result; per the
        # annotation it may be a str, a dict or None.
        self.memory_config = parse_memory_config(self.memory_config)
121
+
122
@dataclasses.dataclass
class InputTensor(SerializeableDataclass):
    """Links an operation to the tensor used at a given input position."""

    operation_id: int
    input_index: int
    tensor_id: int
127
+
128
+
129
@dataclasses.dataclass
class OutputTensor(SerializeableDataclass):
    """Links an operation to the tensor produced at a given output position."""

    operation_id: int
    output_index: int
    tensor_id: int
134
+
135
+
136
@dataclasses.dataclass
class TensorComparisonRecord(SerializeableDataclass):
    """Comparison of a tensor against a golden (reference) tensor."""

    tensor_id: int
    golden_tensor_id: int
    matches: bool
    # NOTE(review): annotated bool, but paired with the float actual_pcc this
    # presumably holds a float threshold — confirm before relying on the type.
    desired_pcc: bool
    actual_pcc: float
143
+
144
+
145
@dataclasses.dataclass
class OperationArgument(SerializeableDataclass):
    """A named argument of an operation, with its value kept as a string."""

    operation_id: int
    name: str
    value: str
150
+
151
+
152
@dataclasses.dataclass
class StackTrace(SerializeableDataclass):
    """Stack-trace text recorded for an operation."""

    operation_id: int
    stack_trace: str
156
+
157
+
158
+ # Non Data Models
159
+
160
+
161
class SerializeableModel(BaseModel):
    """Base class for the pydantic API models below.

    ``use_enum_values`` makes enum-typed fields hold the enum's ``.value``
    (e.g. ``ConnectionTestStates.FAILED`` -> ``2``) so they serialize as plain
    values.
    """

    # Pydantic v2 configuration (this module already uses the v2
    # ``model_validate`` API); replaces the deprecated class-based ``Config``.
    model_config = {"use_enum_values": True}
164
+
165
+
166
class RemoteConnection(SerializeableModel):
    """SSH connection settings for a remote machine that holds reports.

    NOTE(review): camelCase field names presumably mirror the frontend's JSON —
    confirm before renaming.
    """

    name: str
    username: str
    host: str
    port: int = Field(ge=1, le=65535)  # validated to the TCP port range 1-65535
    reportPath: str
    performancePath: Optional[str] = None
    sqliteBinaryPath: Optional[str] = None
    # NOTE(review): presumably selects querying the remote SQLite database in
    # place (see sqliteBinaryPath) instead of syncing files — confirm.
    useRemoteQuerying: bool = False
175
+
176
+
177
class StatusMessage(SerializeableModel):
    """A connection-test status paired with a message.

    The inherited ``use_enum_values`` setting stores ``status`` as the enum's
    value.
    """

    status: ConnectionTestStates
    message: str
180
+
181
+
182
class ActiveReport(SerializeableModel):
    """Names of the report and profile currently selected in a tab session."""

    report_name: Optional[str] = None
    profile_name: Optional[str] = None
185
+
186
+
187
class RemoteReportFolder(SerializeableModel):
    """A report folder on a remote host."""

    testName: str
    remotePath: str
    # NOTE(review): presumably Unix timestamps — confirm the unit.
    lastModified: int
    lastSynced: Optional[int] = None
192
+
193
+
194
class TabSession(BaseModel):
    """Validated view of one tab's session; built by ``TabSessionTable.to_pydantic``.

    Inherits ``BaseModel`` directly, so ``SerializeableModel``'s
    ``use_enum_values`` setting does not apply to this class's own fields.
    """

    tab_id: str
    report_path: Optional[str] = None
    profiler_path: Optional[str] = None
    active_report: Optional[ActiveReport] = None
    remote_connection: Optional[RemoteConnection] = None
    remote_folder: Optional[RemoteReportFolder] = None
    remote_profile_folder: Optional[RemoteReportFolder] = None
202
+
203
+
204
class TabSessionTable(db.Model):
    """ORM row holding one browser tab's session state (table ``tab_sessions``).

    The JSON columns store plain dicts; ``to_pydantic`` validates them into a
    ``TabSession`` and its nested models.
    """

    __tablename__ = "tab_sessions"

    id = Column(Integer, primary_key=True)
    tab_id = Column(String, unique=True, nullable=False)
    report_path = Column(String)
    profiler_path = Column(String, nullable=True)
    # MutableDict makes in-place key changes on ``active_report`` mark the row
    # dirty. ``default=dict`` builds a fresh dict per insert instead of sharing
    # one mutable ``{}`` literal between rows.
    active_report = db.Column(
        MutableDict.as_mutable(JSON), nullable=False, default=dict
    )
    remote_connection = Column(JSON, nullable=True)
    remote_folder = Column(JSON, nullable=True)
    remote_profile_folder = Column(JSON, nullable=True)

    def __init__(
        self,
        tab_id,
        active_report,
        remote_connection=None,
        remote_folder=None,
        report_path=None,
        profiler_path=None,
        remote_profile_folder=None,
    ):
        self.tab_id = tab_id
        self.active_report = active_report
        self.report_path = report_path
        self.remote_connection = remote_connection
        self.remote_folder = remote_folder
        self.profiler_path = profiler_path
        self.remote_profile_folder = remote_profile_folder

    def to_dict(self):
        """Return every column as a plain dict (JSON columns as stored)."""
        return {
            "id": self.id,
            "tab_id": self.tab_id,
            "active_report": self.active_report,
            "remote_connection": self.remote_connection,
            "remote_folder": self.remote_folder,
            "remote_profile_folder": self.remote_profile_folder,
            "report_path": self.report_path,
            "profiler_path": self.profiler_path,
        }

    @staticmethod
    def _validate_optional(model, data):
        """Validate *data* into *model* non-strictly; ``None`` passes through."""
        return model.model_validate(data, strict=False) if data is not None else None

    def to_pydantic(self) -> TabSession:
        """Return this row as a validated ``TabSession``.

        Path columns are stringified when set. ``active_report`` becomes
        ``None`` unless it is a non-empty dict. The JSON connection/folder
        columns are validated non-strictly into their models when not ``None``.
        """
        active_report = self.active_report
        return TabSession(
            tab_id=str(self.tab_id),
            report_path=str(self.report_path) if self.report_path is not None else None,
            profiler_path=(
                str(self.profiler_path) if self.profiler_path is not None else None
            ),
            # An empty dict (the column default) means "no active report".
            active_report=(
                ActiveReport(**active_report)
                if isinstance(active_report, dict) and active_report
                else None
            ),
            remote_connection=self._validate_optional(
                RemoteConnection, self.remote_connection
            ),
            remote_folder=self._validate_optional(
                RemoteReportFolder, self.remote_folder
            ),
            remote_profile_folder=self._validate_optional(
                RemoteReportFolder, self.remote_profile_folder
            ),
        )