ttnn-visualizer 0.42.0__py3-none-any.whl → 0.43.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttnn_visualizer/__init__.py +0 -1
- ttnn_visualizer/app.py +15 -4
- ttnn_visualizer/csv_queries.py +82 -48
- ttnn_visualizer/decorators.py +38 -15
- ttnn_visualizer/exceptions.py +29 -1
- ttnn_visualizer/file_uploads.py +1 -0
- ttnn_visualizer/instances.py +42 -15
- ttnn_visualizer/models.py +12 -7
- ttnn_visualizer/remote_sqlite_setup.py +37 -30
- ttnn_visualizer/requirements.txt +1 -0
- ttnn_visualizer/serializers.py +1 -0
- ttnn_visualizer/settings.py +9 -5
- ttnn_visualizer/sftp_operations.py +144 -125
- ttnn_visualizer/sockets.py +9 -3
- ttnn_visualizer/static/assets/{allPaths-wwXsGKJ2.js → allPaths-CGmhlOs-.js} +1 -1
- ttnn_visualizer/static/assets/{allPathsLoader-BK9jqlVe.js → allPathsLoader-CH9za42_.js} +2 -2
- ttnn_visualizer/static/assets/index-B-fsa5Ru.js +1 -0
- ttnn_visualizer/static/assets/{index-C1rJBrMl.css → index-C-t6jBt9.css} +1 -1
- ttnn_visualizer/static/assets/{index-Ybr1HJxx.js → index-DEb3r1jy.js} +69 -69
- ttnn_visualizer/static/assets/index-DLOviMB1.js +1 -0
- ttnn_visualizer/static/assets/{splitPathsBySizeLoader-CauQGZHk.js → splitPathsBySizeLoader-CP-kodGu.js} +1 -1
- ttnn_visualizer/static/index.html +2 -2
- ttnn_visualizer/tests/__init__.py +0 -1
- ttnn_visualizer/tests/test_queries.py +0 -1
- ttnn_visualizer/tests/test_serializers.py +2 -2
- ttnn_visualizer/utils.py +7 -3
- ttnn_visualizer/views.py +250 -82
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/METADATA +5 -1
- ttnn_visualizer-0.43.1.dist-info/RECORD +45 -0
- ttnn_visualizer/static/assets/index-BKzgFDAn.js +0 -1
- ttnn_visualizer/static/assets/index-BvSuWPlB.js +0 -1
- ttnn_visualizer-0.42.0.dist-info/RECORD +0 -45
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/LICENSE +0 -0
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/LICENSE_understanding.txt +0 -0
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/WHEEL +0 -0
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/entry_points.txt +0 -0
- {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.1.dist-info}/top_level.txt +0 -0
ttnn_visualizer/__init__.py
CHANGED
ttnn_visualizer/app.py
CHANGED
@@ -21,7 +21,11 @@ from flask_cors import CORS
|
|
21
21
|
from werkzeug.debug import DebuggedApplication
|
22
22
|
from werkzeug.middleware.proxy_fix import ProxyFix
|
23
23
|
|
24
|
-
from ttnn_visualizer.exceptions import
|
24
|
+
from ttnn_visualizer.exceptions import (
|
25
|
+
DatabaseFileNotFoundException,
|
26
|
+
InvalidProfilerPath,
|
27
|
+
InvalidReportPath,
|
28
|
+
)
|
25
29
|
from ttnn_visualizer.instances import create_instance_from_local_paths
|
26
30
|
from ttnn_visualizer.settings import Config, DefaultConfig
|
27
31
|
|
@@ -65,6 +69,7 @@ def create_app(settings_override=None):
|
|
65
69
|
extensions(app)
|
66
70
|
|
67
71
|
if flask_env == "production":
|
72
|
+
|
68
73
|
@app.route(f"{app.config['BASE_PATH']}", defaults={"path": ""})
|
69
74
|
@app.route(f"{app.config['BASE_PATH']}<path:path>")
|
70
75
|
def catch_all(path):
|
@@ -179,9 +184,15 @@ def open_browser(host, port, instance_id=None):
|
|
179
184
|
|
180
185
|
|
181
186
|
def parse_args():
|
182
|
-
parser = argparse.ArgumentParser(
|
183
|
-
|
184
|
-
|
187
|
+
parser = argparse.ArgumentParser(
|
188
|
+
description="A tool for visualizing the Tenstorrent Neural Network model (TT-NN)"
|
189
|
+
)
|
190
|
+
parser.add_argument(
|
191
|
+
"--profiler-path", type=str, help="Specify a profiler path", default=None
|
192
|
+
)
|
193
|
+
parser.add_argument(
|
194
|
+
"--performance-path", help="Specify a performance path", default=None
|
195
|
+
)
|
185
196
|
return parser.parse_args()
|
186
197
|
|
187
198
|
|
ttnn_visualizer/csv_queries.py
CHANGED
@@ -11,16 +11,22 @@ from pathlib import Path
|
|
11
11
|
from typing import List, Dict, Union, Optional
|
12
12
|
|
13
13
|
import pandas as pd
|
14
|
+
import zstd
|
14
15
|
from tt_perf_report import perf_report
|
15
16
|
|
16
17
|
from ttnn_visualizer.exceptions import DataFormatError
|
18
|
+
from ttnn_visualizer.exceptions import (
|
19
|
+
SSHException,
|
20
|
+
AuthenticationException,
|
21
|
+
NoValidConnectionsError,
|
22
|
+
)
|
17
23
|
from ttnn_visualizer.models import Instance, RemoteConnection
|
18
|
-
from ttnn_visualizer.exceptions import SSHException, AuthenticationException, NoValidConnectionsError
|
19
|
-
from ttnn_visualizer.models import Instance
|
20
24
|
from ttnn_visualizer.sftp_operations import read_remote_file
|
21
25
|
|
22
26
|
|
23
|
-
def handle_ssh_subprocess_error(
|
27
|
+
def handle_ssh_subprocess_error(
|
28
|
+
e: subprocess.CalledProcessError, remote_connection: RemoteConnection
|
29
|
+
):
|
24
30
|
"""
|
25
31
|
Convert subprocess SSH errors to appropriate SSH exceptions.
|
26
32
|
|
@@ -31,23 +37,29 @@ def handle_ssh_subprocess_error(e: subprocess.CalledProcessError, remote_connect
|
|
31
37
|
stderr = e.stderr.lower() if e.stderr else ""
|
32
38
|
|
33
39
|
# Check for authentication failures
|
34
|
-
if any(
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
40
|
+
if any(
|
41
|
+
auth_err in stderr
|
42
|
+
for auth_err in [
|
43
|
+
"permission denied",
|
44
|
+
"authentication failed",
|
45
|
+
"publickey",
|
46
|
+
"password",
|
47
|
+
"host key verification failed",
|
48
|
+
]
|
49
|
+
):
|
41
50
|
raise AuthenticationException(f"SSH authentication failed: {e.stderr}")
|
42
51
|
|
43
52
|
# Check for connection failures
|
44
|
-
elif any(
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
53
|
+
elif any(
|
54
|
+
conn_err in stderr
|
55
|
+
for conn_err in [
|
56
|
+
"connection refused",
|
57
|
+
"network is unreachable",
|
58
|
+
"no route to host",
|
59
|
+
"name or service not known",
|
60
|
+
"connection timed out",
|
61
|
+
]
|
62
|
+
):
|
51
63
|
raise NoValidConnectionsError(f"SSH connection failed: {e.stderr}")
|
52
64
|
|
53
65
|
# Check for general SSH protocol errors
|
@@ -58,6 +70,7 @@ def handle_ssh_subprocess_error(e: subprocess.CalledProcessError, remote_connect
|
|
58
70
|
else:
|
59
71
|
raise SSHException(f"SSH command failed: {e.stderr}")
|
60
72
|
|
73
|
+
|
61
74
|
class LocalCSVQueryRunner:
|
62
75
|
def __init__(self, file_path: str, offset: int = 0):
|
63
76
|
self.file_path = file_path
|
@@ -154,24 +167,22 @@ class RemoteCSVQueryRunner:
|
|
154
167
|
|
155
168
|
def _execute_ssh_command(self, command: str) -> str:
|
156
169
|
"""Execute an SSH command and return the output."""
|
157
|
-
ssh_cmd = ["ssh"]
|
158
|
-
|
170
|
+
ssh_cmd = ["ssh", "-o", "PasswordAuthentication=no"]
|
171
|
+
|
159
172
|
# Handle non-standard SSH port
|
160
173
|
if self.remote_connection.port != 22:
|
161
174
|
ssh_cmd.extend(["-p", str(self.remote_connection.port)])
|
162
|
-
|
163
|
-
ssh_cmd.extend(
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
175
|
+
|
176
|
+
ssh_cmd.extend(
|
177
|
+
[
|
178
|
+
f"{self.remote_connection.username}@{self.remote_connection.host}",
|
179
|
+
command,
|
180
|
+
]
|
181
|
+
)
|
182
|
+
|
168
183
|
try:
|
169
184
|
result = subprocess.run(
|
170
|
-
ssh_cmd,
|
171
|
-
capture_output=True,
|
172
|
-
text=True,
|
173
|
-
check=True,
|
174
|
-
timeout=30
|
185
|
+
ssh_cmd, capture_output=True, text=True, check=True, timeout=30
|
175
186
|
)
|
176
187
|
return result.stdout
|
177
188
|
except subprocess.CalledProcessError as e:
|
@@ -269,7 +280,7 @@ class RemoteCSVQueryRunner:
|
|
269
280
|
)
|
270
281
|
output = self._execute_ssh_command(cmd).strip()
|
271
282
|
|
272
|
-
return output.splitlines()[self.offset:]
|
283
|
+
return output.splitlines()[self.offset :]
|
273
284
|
|
274
285
|
def get_csv_header(self) -> Dict[str, int]:
|
275
286
|
"""
|
@@ -324,7 +335,9 @@ class NPEQueries:
|
|
324
335
|
and not instance.remote_connection.useRemoteQuerying
|
325
336
|
):
|
326
337
|
file_path = Path(
|
327
|
-
instance.performance_path,
|
338
|
+
instance.performance_path,
|
339
|
+
NPEQueries.NPE_FOLDER,
|
340
|
+
NPEQueries.MANIFEST_FILE,
|
328
341
|
)
|
329
342
|
with open(file_path, "r") as f:
|
330
343
|
return json.load(f)
|
@@ -338,7 +351,9 @@ class NPEQueries:
|
|
338
351
|
@staticmethod
|
339
352
|
def get_npe_timeline(instance: Instance, filename: str):
|
340
353
|
if not filename:
|
341
|
-
raise ValueError(
|
354
|
+
raise ValueError(
|
355
|
+
"filename parameter is required and cannot be None or empty"
|
356
|
+
)
|
342
357
|
|
343
358
|
if (
|
344
359
|
not instance.remote_connection
|
@@ -347,18 +362,33 @@ class NPEQueries:
|
|
347
362
|
if not instance.performance_path:
|
348
363
|
raise ValueError("instance.performance_path is None")
|
349
364
|
|
350
|
-
file_path = Path(
|
351
|
-
|
352
|
-
)
|
353
|
-
|
354
|
-
|
365
|
+
file_path = Path(instance.performance_path, NPEQueries.NPE_FOLDER, filename)
|
366
|
+
|
367
|
+
if filename.endswith(".zst"):
|
368
|
+
with open(file_path, "rb") as file:
|
369
|
+
compressed_data = file.read()
|
370
|
+
uncompressed_data = zstd.uncompress(compressed_data)
|
371
|
+
return json.loads(uncompressed_data)
|
372
|
+
else:
|
373
|
+
with open(file_path, "r") as f:
|
374
|
+
return json.load(f)
|
375
|
+
|
355
376
|
else:
|
356
377
|
profiler_folder = instance.remote_profile_folder
|
357
|
-
|
358
|
-
|
359
|
-
f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{filename}",
|
378
|
+
remote_path = (
|
379
|
+
f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{filename}"
|
360
380
|
)
|
381
|
+
remote_data = read_remote_file(instance.remote_connection, remote_path)
|
361
382
|
|
383
|
+
if filename.endswith(".zst"):
|
384
|
+
if isinstance(remote_data, str):
|
385
|
+
remote_data = remote_data.encode("utf-8")
|
386
|
+
uncompressed_data = zstd.decompress(remote_data)
|
387
|
+
return json.loads(uncompressed_data)
|
388
|
+
else:
|
389
|
+
if isinstance(remote_data, bytes):
|
390
|
+
remote_data = remote_data.decode("utf-8")
|
391
|
+
return json.loads(remote_data)
|
362
392
|
|
363
393
|
|
364
394
|
class DeviceLogProfilerQueries:
|
@@ -614,7 +644,9 @@ class OpsPerformanceQueries:
|
|
614
644
|
or instance.remote_connection
|
615
645
|
and not instance.remote_connection.useRemoteQuerying
|
616
646
|
):
|
617
|
-
with open(
|
647
|
+
with open(
|
648
|
+
OpsPerformanceQueries.get_local_ops_perf_file_path(instance)
|
649
|
+
) as f:
|
618
650
|
return f.read()
|
619
651
|
else:
|
620
652
|
path = OpsPerformanceQueries.get_remote_ops_perf_file_path(instance)
|
@@ -656,9 +688,7 @@ class OpsPerformanceQueries:
|
|
656
688
|
"""
|
657
689
|
try:
|
658
690
|
return [
|
659
|
-
folder.name
|
660
|
-
for folder in Path(directory).iterdir()
|
661
|
-
if folder.is_dir()
|
691
|
+
folder.name for folder in Path(directory).iterdir() if folder.is_dir()
|
662
692
|
]
|
663
693
|
except Exception as e:
|
664
694
|
raise RuntimeError(f"Error accessing directory: {e}")
|
@@ -688,7 +718,7 @@ class OpsPerformanceReportQueries:
|
|
688
718
|
"output_subblock_w",
|
689
719
|
"global_call_count",
|
690
720
|
"advice",
|
691
|
-
"raw_op_code"
|
721
|
+
"raw_op_code",
|
692
722
|
]
|
693
723
|
|
694
724
|
PASSTHROUGH_COLUMNS = {
|
@@ -734,7 +764,9 @@ class OpsPerformanceReportQueries:
|
|
734
764
|
next(reader, None)
|
735
765
|
for row in reader:
|
736
766
|
processed_row = {
|
737
|
-
column: row[index]
|
767
|
+
column: row[index]
|
768
|
+
for index, column in enumerate(cls.REPORT_COLUMNS)
|
769
|
+
if index < len(row)
|
738
770
|
}
|
739
771
|
if "advice" in processed_row and processed_row["advice"]:
|
740
772
|
processed_row["advice"] = processed_row["advice"].split(" • ")
|
@@ -743,7 +775,9 @@ class OpsPerformanceReportQueries:
|
|
743
775
|
|
744
776
|
for key, value in cls.PASSTHROUGH_COLUMNS.items():
|
745
777
|
op_id = int(row[0])
|
746
|
-
idx =
|
778
|
+
idx = (
|
779
|
+
op_id - 2
|
780
|
+
) # IDs in result column one correspond to row numbers in ops perf results csv
|
747
781
|
processed_row[key] = ops_perf_results[idx][value]
|
748
782
|
|
749
783
|
report.append(processed_row)
|
ttnn_visualizer/decorators.py
CHANGED
@@ -14,6 +14,7 @@ from ttnn_visualizer.exceptions import (
|
|
14
14
|
NoValidConnectionsError,
|
15
15
|
SSHException,
|
16
16
|
RemoteConnectionException,
|
17
|
+
AuthenticationFailedException,
|
17
18
|
NoProjectsException,
|
18
19
|
RemoteSqliteException,
|
19
20
|
)
|
@@ -34,21 +35,23 @@ def with_instance(func):
|
|
34
35
|
abort(404)
|
35
36
|
|
36
37
|
instance_query_data = get_or_create_instance(instance_id=instance_id)
|
37
|
-
|
38
|
+
|
38
39
|
# Handle case where get_or_create_instance returns None due to database error
|
39
40
|
if instance_query_data is None:
|
40
|
-
current_app.logger.error(
|
41
|
+
current_app.logger.error(
|
42
|
+
f"Failed to get or create instance with ID: {instance_id}"
|
43
|
+
)
|
41
44
|
abort(500)
|
42
|
-
|
45
|
+
|
43
46
|
instance = instance_query_data.to_pydantic()
|
44
47
|
|
45
48
|
kwargs["instance"] = instance
|
46
49
|
|
47
|
-
if
|
48
|
-
session[
|
50
|
+
if "instances" not in session:
|
51
|
+
session["instances"] = []
|
49
52
|
|
50
|
-
if instance.instance_id not in session[
|
51
|
-
session[
|
53
|
+
if instance.instance_id not in session["instances"]:
|
54
|
+
session["instances"] = session.get("instances", []) + [instance.instance_id]
|
52
55
|
|
53
56
|
return func(*args, **kwargs)
|
54
57
|
|
@@ -68,10 +71,21 @@ def remote_exception_handler(func):
|
|
68
71
|
try:
|
69
72
|
return func(*args, **kwargs)
|
70
73
|
except AuthenticationException as err:
|
71
|
-
|
72
|
-
|
74
|
+
# Log the detailed error for debugging, but don't show full traceback
|
75
|
+
current_app.logger.warning(
|
76
|
+
f"SSH authentication failed for {connection.username}@{connection.host}: SSH key authentication required"
|
77
|
+
)
|
78
|
+
|
79
|
+
# Return user-friendly error message about SSH keys
|
80
|
+
user_message = (
|
81
|
+
"SSH authentication failed. This application requires SSH key-based authentication. "
|
82
|
+
"Please ensure your SSH public key is added to the authorized_keys file on the remote server. "
|
83
|
+
"Password authentication is not supported."
|
84
|
+
)
|
85
|
+
|
86
|
+
raise AuthenticationFailedException(
|
87
|
+
message=user_message,
|
73
88
|
status=ConnectionTestStates.FAILED,
|
74
|
-
message=f"Unable to authenticate: {str(err)}",
|
75
89
|
)
|
76
90
|
except FileNotFoundError as err:
|
77
91
|
current_app.logger.error(f"File not found: {str(err)}")
|
@@ -86,11 +100,20 @@ def remote_exception_handler(func):
|
|
86
100
|
message=f"No projects found at remote location: {connection.path}",
|
87
101
|
)
|
88
102
|
except NoValidConnectionsError as err:
|
89
|
-
current_app.logger.
|
90
|
-
|
103
|
+
current_app.logger.warning(
|
104
|
+
f"SSH connection failed for {connection.username}@{connection.host}: {str(err)}"
|
105
|
+
)
|
106
|
+
|
107
|
+
# Provide user-friendly message for connection issues
|
108
|
+
user_message = (
|
109
|
+
f"Unable to establish SSH connection to {connection.host}. "
|
110
|
+
"Please check the hostname, port, and network connectivity. "
|
111
|
+
"Ensure SSH key-based authentication is properly configured."
|
112
|
+
)
|
113
|
+
|
91
114
|
raise RemoteConnectionException(
|
92
115
|
status=ConnectionTestStates.FAILED,
|
93
|
-
message=
|
116
|
+
message=user_message,
|
94
117
|
)
|
95
118
|
|
96
119
|
except RemoteSqliteException as err:
|
@@ -111,10 +134,10 @@ def remote_exception_handler(func):
|
|
111
134
|
)
|
112
135
|
except SSHException as err:
|
113
136
|
if str(err) == "No existing session":
|
114
|
-
message = "
|
137
|
+
message = "SSH authentication failed. Please ensure SSH keys are configured and ssh-agent is running."
|
115
138
|
else:
|
116
139
|
err_message = re.sub(r"\[.*?]", "", str(err)).strip()
|
117
|
-
message = f"
|
140
|
+
message = f"SSH connection error to {connection.host}: {err_message}. Ensure SSH key-based authentication is properly configured."
|
118
141
|
|
119
142
|
raise RemoteConnectionException(
|
120
143
|
status=ConnectionTestStates.FAILED, message=message
|
ttnn_visualizer/exceptions.py
CHANGED
@@ -3,24 +3,49 @@
|
|
3
3
|
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
|
4
4
|
|
5
5
|
from http import HTTPStatus
|
6
|
+
from typing import Optional
|
6
7
|
|
7
8
|
from ttnn_visualizer.enums import ConnectionTestStates
|
8
9
|
|
9
10
|
|
10
11
|
class RemoteConnectionException(Exception):
|
11
|
-
def __init__(
|
12
|
+
def __init__(
|
13
|
+
self,
|
14
|
+
message,
|
15
|
+
status: ConnectionTestStates,
|
16
|
+
http_status_code: Optional[HTTPStatus] = None,
|
17
|
+
):
|
12
18
|
super().__init__(message)
|
13
19
|
self.message = message
|
14
20
|
self.status = status
|
21
|
+
self._http_status_code = http_status_code
|
15
22
|
|
16
23
|
@property
|
17
24
|
def http_status(self):
|
25
|
+
# Use custom HTTP status code if provided
|
26
|
+
if self._http_status_code is not None:
|
27
|
+
return self._http_status_code
|
28
|
+
|
29
|
+
# Default behavior
|
18
30
|
if self.status == ConnectionTestStates.FAILED:
|
19
31
|
return HTTPStatus.INTERNAL_SERVER_ERROR
|
20
32
|
if self.status == ConnectionTestStates.OK:
|
21
33
|
return HTTPStatus.OK
|
22
34
|
|
23
35
|
|
36
|
+
class AuthenticationFailedException(RemoteConnectionException):
|
37
|
+
"""Exception for SSH authentication failures that should return HTTP 422"""
|
38
|
+
|
39
|
+
def __init__(
|
40
|
+
self, message, status: ConnectionTestStates = ConnectionTestStates.FAILED
|
41
|
+
):
|
42
|
+
super().__init__(
|
43
|
+
message=message,
|
44
|
+
status=status,
|
45
|
+
http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, # 422
|
46
|
+
)
|
47
|
+
|
48
|
+
|
24
49
|
class NoProjectsException(RemoteConnectionException):
|
25
50
|
pass
|
26
51
|
|
@@ -50,14 +75,17 @@ class InvalidProfilerPath(Exception):
|
|
50
75
|
|
51
76
|
class SSHException(Exception):
|
52
77
|
"""Base SSH exception for subprocess SSH operations"""
|
78
|
+
|
53
79
|
pass
|
54
80
|
|
55
81
|
|
56
82
|
class AuthenticationException(SSHException):
|
57
83
|
"""Raised when SSH authentication fails"""
|
84
|
+
|
58
85
|
pass
|
59
86
|
|
60
87
|
|
61
88
|
class NoValidConnectionsError(SSHException):
|
62
89
|
"""Raised when SSH connection cannot be established"""
|
90
|
+
|
63
91
|
pass
|
ttnn_visualizer/file_uploads.py
CHANGED
ttnn_visualizer/instances.py
CHANGED
@@ -94,8 +94,7 @@ def update_existing_instance(
|
|
94
94
|
else:
|
95
95
|
if active_report.get("npe_name"):
|
96
96
|
instance_data.npe_path = get_npe_path(
|
97
|
-
npe_name=active_report["npe_name"],
|
98
|
-
current_app=current_app
|
97
|
+
npe_name=active_report["npe_name"], current_app=current_app
|
99
98
|
)
|
100
99
|
|
101
100
|
|
@@ -148,17 +147,25 @@ def create_new_instance(
|
|
148
147
|
instance_data = InstanceTable(
|
149
148
|
instance_id=instance_id,
|
150
149
|
active_report=active_report,
|
151
|
-
profiler_path=
|
152
|
-
|
153
|
-
|
154
|
-
|
150
|
+
profiler_path=(
|
151
|
+
profiler_path
|
152
|
+
if profiler_path is not _sentinel
|
153
|
+
else get_profiler_path(
|
154
|
+
active_report["profiler_name"],
|
155
|
+
current_app=current_app,
|
156
|
+
remote_connection=remote_connection,
|
157
|
+
)
|
155
158
|
),
|
156
159
|
remote_connection=(
|
157
160
|
remote_connection.model_dump() if remote_connection else None
|
158
161
|
),
|
159
|
-
remote_profiler_folder=
|
162
|
+
remote_profiler_folder=(
|
163
|
+
remote_profiler_folder.model_dump() if remote_profiler_folder else None
|
164
|
+
),
|
160
165
|
remote_performance_folder=(
|
161
|
-
remote_performance_folder.model_dump()
|
166
|
+
remote_performance_folder.model_dump()
|
167
|
+
if remote_performance_folder
|
168
|
+
else None
|
162
169
|
),
|
163
170
|
)
|
164
171
|
|
@@ -255,10 +262,18 @@ def get_or_create_instance(
|
|
255
262
|
db.session.commit()
|
256
263
|
except IntegrityError:
|
257
264
|
db.session.rollback()
|
258
|
-
instance_data = InstanceTable.query.filter_by(
|
265
|
+
instance_data = InstanceTable.query.filter_by(
|
266
|
+
instance_id=instance_id
|
267
|
+
).first()
|
259
268
|
|
260
269
|
# Update the instance if any new data is provided
|
261
|
-
if
|
270
|
+
if (
|
271
|
+
profiler_name
|
272
|
+
or performance_name
|
273
|
+
or npe_name
|
274
|
+
or remote_connection
|
275
|
+
or remote_profiler_folder
|
276
|
+
):
|
262
277
|
update_instance(
|
263
278
|
instance_id=instance_id,
|
264
279
|
profiler_name=profiler_name,
|
@@ -269,7 +284,9 @@ def get_or_create_instance(
|
|
269
284
|
)
|
270
285
|
|
271
286
|
# Query again to get the updated instance data
|
272
|
-
instance_data = InstanceTable.query.filter_by(
|
287
|
+
instance_data = InstanceTable.query.filter_by(
|
288
|
+
instance_id=instance_id
|
289
|
+
).first()
|
273
290
|
|
274
291
|
return instance_data
|
275
292
|
|
@@ -318,7 +335,7 @@ def init_instances(app):
|
|
318
335
|
|
319
336
|
|
320
337
|
def create_random_instance_id():
|
321
|
-
return
|
338
|
+
return "".join(random.choices(string.ascii_lowercase + string.digits, k=45))
|
322
339
|
|
323
340
|
|
324
341
|
def create_instance_from_local_paths(profiler_path, performance_path):
|
@@ -328,11 +345,21 @@ def create_instance_from_local_paths(profiler_path, performance_path):
|
|
328
345
|
if _profiler_path and (not _profiler_path.exists() or not _profiler_path.is_dir()):
|
329
346
|
raise InvalidReportPath()
|
330
347
|
|
331
|
-
if _performance_path and (
|
348
|
+
if _performance_path and (
|
349
|
+
not _performance_path.exists() or not _performance_path.is_dir()
|
350
|
+
):
|
332
351
|
raise InvalidProfilerPath()
|
333
352
|
|
334
|
-
profiler_name =
|
335
|
-
|
353
|
+
profiler_name = (
|
354
|
+
_profiler_path.parts[-1]
|
355
|
+
if _profiler_path and len(_profiler_path.parts) > 2
|
356
|
+
else ""
|
357
|
+
)
|
358
|
+
performance_name = (
|
359
|
+
_performance_path.parts[-1]
|
360
|
+
if _performance_path and len(_performance_path.parts) > 2
|
361
|
+
else ""
|
362
|
+
)
|
336
363
|
instance_data = InstanceTable(
|
337
364
|
instance_id=create_random_instance_id(),
|
338
365
|
active_report={
|
ttnn_visualizer/models.py
CHANGED
@@ -119,6 +119,7 @@ class Tensor(SerializeableDataclass):
|
|
119
119
|
def __post_init__(self):
|
120
120
|
self.memory_config = parse_memory_config(self.memory_config)
|
121
121
|
|
122
|
+
|
122
123
|
@dataclasses.dataclass
|
123
124
|
class InputTensor(SerializeableDataclass):
|
124
125
|
operation_id: int
|
@@ -246,19 +247,21 @@ class InstanceTable(db.Model):
|
|
246
247
|
"remote_performance_folder": self.remote_performance_folder,
|
247
248
|
"profiler_path": self.profiler_path,
|
248
249
|
"performance_path": self.performance_path,
|
249
|
-
"npe_path": self.npe_path
|
250
|
+
"npe_path": self.npe_path,
|
250
251
|
}
|
251
252
|
|
252
253
|
def to_pydantic(self) -> Instance:
|
253
254
|
return Instance(
|
254
255
|
instance_id=str(self.instance_id),
|
255
|
-
profiler_path=
|
256
|
-
|
257
|
-
str(self.performance_path) if self.performance_path is not None else None
|
256
|
+
profiler_path=(
|
257
|
+
str(self.profiler_path) if self.profiler_path is not None else None
|
258
258
|
),
|
259
|
-
|
260
|
-
str(self.
|
259
|
+
performance_path=(
|
260
|
+
str(self.performance_path)
|
261
|
+
if self.performance_path is not None
|
262
|
+
else None
|
261
263
|
),
|
264
|
+
npe_path=(str(self.npe_path) if self.npe_path is not None else None),
|
262
265
|
active_report=(
|
263
266
|
(ActiveReports(**self.active_report) if self.active_report else None)
|
264
267
|
if isinstance(self.active_report, dict)
|
@@ -270,7 +273,9 @@ class InstanceTable(db.Model):
|
|
270
273
|
else None
|
271
274
|
),
|
272
275
|
remote_profiler_folder=(
|
273
|
-
RemoteReportFolder.model_validate(
|
276
|
+
RemoteReportFolder.model_validate(
|
277
|
+
self.remote_profiler_folder, strict=False
|
278
|
+
)
|
274
279
|
if self.remote_profiler_folder is not None
|
275
280
|
else None
|
276
281
|
),
|