ttnn-visualizer 0.41.0__py3-none-any.whl → 0.43.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttnn_visualizer/__init__.py +0 -1
- ttnn_visualizer/app.py +15 -4
- ttnn_visualizer/csv_queries.py +150 -40
- ttnn_visualizer/decorators.py +42 -16
- ttnn_visualizer/exceptions.py +45 -1
- ttnn_visualizer/file_uploads.py +1 -0
- ttnn_visualizer/instances.py +42 -15
- ttnn_visualizer/models.py +12 -7
- ttnn_visualizer/queries.py +3 -109
- ttnn_visualizer/remote_sqlite_setup.py +104 -19
- ttnn_visualizer/requirements.txt +2 -3
- ttnn_visualizer/serializers.py +1 -0
- ttnn_visualizer/settings.py +9 -5
- ttnn_visualizer/sftp_operations.py +657 -220
- ttnn_visualizer/sockets.py +9 -3
- ttnn_visualizer/static/assets/{allPaths-4_pFqSAW.js → allPaths-BQN_j7ek.js} +1 -1
- ttnn_visualizer/static/assets/{allPathsLoader-CpLPTLlt.js → allPathsLoader-BvkkQ77q.js} +2 -2
- ttnn_visualizer/static/assets/index-B-fsa5Ru.js +1 -0
- ttnn_visualizer/static/assets/{index-DFVwehlj.js → index-Bng0kcmi.js} +214 -214
- ttnn_visualizer/static/assets/{index-C1rJBrMl.css → index-C-t6jBt9.css} +1 -1
- ttnn_visualizer/static/assets/index-DLOviMB1.js +1 -0
- ttnn_visualizer/static/assets/{splitPathsBySizeLoader-D-RvsTqO.js → splitPathsBySizeLoader-Cl0NRdfL.js} +1 -1
- ttnn_visualizer/static/index.html +2 -2
- ttnn_visualizer/tests/__init__.py +0 -1
- ttnn_visualizer/tests/test_queries.py +0 -69
- ttnn_visualizer/tests/test_serializers.py +2 -2
- ttnn_visualizer/utils.py +7 -3
- ttnn_visualizer/views.py +315 -52
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/LICENSE +0 -1
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/METADATA +6 -3
- ttnn_visualizer-0.43.0.dist-info/RECORD +45 -0
- ttnn_visualizer/ssh_client.py +0 -85
- ttnn_visualizer/static/assets/index-BKzgFDAn.js +0 -1
- ttnn_visualizer/static/assets/index-BvSuWPlB.js +0 -1
- ttnn_visualizer-0.41.0.dist-info/RECORD +0 -46
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/LICENSE_understanding.txt +0 -0
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/WHEEL +0 -0
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/entry_points.txt +0 -0
- {ttnn_visualizer-0.41.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/top_level.txt +0 -0
ttnn_visualizer/__init__.py
CHANGED
ttnn_visualizer/app.py
CHANGED
@@ -21,7 +21,11 @@ from flask_cors import CORS
|
|
21
21
|
from werkzeug.debug import DebuggedApplication
|
22
22
|
from werkzeug.middleware.proxy_fix import ProxyFix
|
23
23
|
|
24
|
-
from ttnn_visualizer.exceptions import
|
24
|
+
from ttnn_visualizer.exceptions import (
|
25
|
+
DatabaseFileNotFoundException,
|
26
|
+
InvalidProfilerPath,
|
27
|
+
InvalidReportPath,
|
28
|
+
)
|
25
29
|
from ttnn_visualizer.instances import create_instance_from_local_paths
|
26
30
|
from ttnn_visualizer.settings import Config, DefaultConfig
|
27
31
|
|
@@ -65,6 +69,7 @@ def create_app(settings_override=None):
|
|
65
69
|
extensions(app)
|
66
70
|
|
67
71
|
if flask_env == "production":
|
72
|
+
|
68
73
|
@app.route(f"{app.config['BASE_PATH']}", defaults={"path": ""})
|
69
74
|
@app.route(f"{app.config['BASE_PATH']}<path:path>")
|
70
75
|
def catch_all(path):
|
@@ -179,9 +184,15 @@ def open_browser(host, port, instance_id=None):
|
|
179
184
|
|
180
185
|
|
181
186
|
def parse_args():
|
182
|
-
parser = argparse.ArgumentParser(
|
183
|
-
|
184
|
-
|
187
|
+
parser = argparse.ArgumentParser(
|
188
|
+
description="A tool for visualizing the Tenstorrent Neural Network model (TT-NN)"
|
189
|
+
)
|
190
|
+
parser.add_argument(
|
191
|
+
"--profiler-path", type=str, help="Specify a profiler path", default=None
|
192
|
+
)
|
193
|
+
parser.add_argument(
|
194
|
+
"--performance-path", help="Specify a performance path", default=None
|
195
|
+
)
|
185
196
|
return parser.parse_args()
|
186
197
|
|
187
198
|
|
ttnn_visualizer/csv_queries.py
CHANGED
@@ -4,19 +4,73 @@
|
|
4
4
|
import csv
|
5
5
|
import json
|
6
6
|
import os
|
7
|
+
import subprocess
|
7
8
|
import tempfile
|
8
9
|
from io import StringIO
|
9
10
|
from pathlib import Path
|
10
11
|
from typing import List, Dict, Union, Optional
|
11
12
|
|
12
13
|
import pandas as pd
|
14
|
+
import zstd
|
13
15
|
from tt_perf_report import perf_report
|
14
16
|
|
15
17
|
from ttnn_visualizer.exceptions import DataFormatError
|
16
|
-
from ttnn_visualizer.
|
17
|
-
|
18
|
+
from ttnn_visualizer.exceptions import (
|
19
|
+
SSHException,
|
20
|
+
AuthenticationException,
|
21
|
+
NoValidConnectionsError,
|
22
|
+
)
|
23
|
+
from ttnn_visualizer.models import Instance, RemoteConnection
|
18
24
|
from ttnn_visualizer.sftp_operations import read_remote_file
|
19
25
|
|
26
|
+
|
27
|
+
def handle_ssh_subprocess_error(
|
28
|
+
e: subprocess.CalledProcessError, remote_connection: RemoteConnection
|
29
|
+
):
|
30
|
+
"""
|
31
|
+
Convert subprocess SSH errors to appropriate SSH exceptions.
|
32
|
+
|
33
|
+
:param e: The subprocess.CalledProcessError
|
34
|
+
:param remote_connection: The RemoteConnection object for context
|
35
|
+
:raises: SSHException, AuthenticationException, or NoValidConnectionsError
|
36
|
+
"""
|
37
|
+
stderr = e.stderr.lower() if e.stderr else ""
|
38
|
+
|
39
|
+
# Check for authentication failures
|
40
|
+
if any(
|
41
|
+
auth_err in stderr
|
42
|
+
for auth_err in [
|
43
|
+
"permission denied",
|
44
|
+
"authentication failed",
|
45
|
+
"publickey",
|
46
|
+
"password",
|
47
|
+
"host key verification failed",
|
48
|
+
]
|
49
|
+
):
|
50
|
+
raise AuthenticationException(f"SSH authentication failed: {e.stderr}")
|
51
|
+
|
52
|
+
# Check for connection failures
|
53
|
+
elif any(
|
54
|
+
conn_err in stderr
|
55
|
+
for conn_err in [
|
56
|
+
"connection refused",
|
57
|
+
"network is unreachable",
|
58
|
+
"no route to host",
|
59
|
+
"name or service not known",
|
60
|
+
"connection timed out",
|
61
|
+
]
|
62
|
+
):
|
63
|
+
raise NoValidConnectionsError(f"SSH connection failed: {e.stderr}")
|
64
|
+
|
65
|
+
# Check for general SSH protocol errors
|
66
|
+
elif "ssh:" in stderr or "protocol" in stderr:
|
67
|
+
raise SSHException(f"SSH protocol error: {e.stderr}")
|
68
|
+
|
69
|
+
# Default to generic SSH exception
|
70
|
+
else:
|
71
|
+
raise SSHException(f"SSH command failed: {e.stderr}")
|
72
|
+
|
73
|
+
|
20
74
|
class LocalCSVQueryRunner:
|
21
75
|
def __init__(self, file_path: str, offset: int = 0):
|
22
76
|
self.file_path = file_path
|
@@ -110,7 +164,36 @@ class RemoteCSVQueryRunner:
|
|
110
164
|
self.remote_connection = remote_connection
|
111
165
|
self.sep = sep
|
112
166
|
self.offset = offset
|
113
|
-
|
167
|
+
|
168
|
+
def _execute_ssh_command(self, command: str) -> str:
|
169
|
+
"""Execute an SSH command and return the output."""
|
170
|
+
ssh_cmd = ["ssh", "-o", "PasswordAuthentication=no"]
|
171
|
+
|
172
|
+
# Handle non-standard SSH port
|
173
|
+
if self.remote_connection.port != 22:
|
174
|
+
ssh_cmd.extend(["-p", str(self.remote_connection.port)])
|
175
|
+
|
176
|
+
ssh_cmd.extend(
|
177
|
+
[
|
178
|
+
f"{self.remote_connection.username}@{self.remote_connection.host}",
|
179
|
+
command,
|
180
|
+
]
|
181
|
+
)
|
182
|
+
|
183
|
+
try:
|
184
|
+
result = subprocess.run(
|
185
|
+
ssh_cmd, capture_output=True, text=True, check=True, timeout=30
|
186
|
+
)
|
187
|
+
return result.stdout
|
188
|
+
except subprocess.CalledProcessError as e:
|
189
|
+
if e.returncode == 255: # SSH protocol errors
|
190
|
+
handle_ssh_subprocess_error(e, self.remote_connection)
|
191
|
+
# This line should never be reached as handle_ssh_subprocess_error raises an exception
|
192
|
+
raise RuntimeError(f"SSH command failed: {e.stderr}")
|
193
|
+
else:
|
194
|
+
raise RuntimeError(f"SSH command failed: {e.stderr}")
|
195
|
+
except subprocess.TimeoutExpired:
|
196
|
+
raise RuntimeError(f"SSH command timed out: {command}")
|
114
197
|
|
115
198
|
def execute_query(
|
116
199
|
self,
|
@@ -128,12 +211,7 @@ class RemoteCSVQueryRunner:
|
|
128
211
|
"""
|
129
212
|
# Fetch header row, accounting for the offset
|
130
213
|
header_cmd = f"head -n {self.offset + 1} {self.file_path} | tail -n 1"
|
131
|
-
|
132
|
-
raw_header = stdout.read().decode("utf-8").strip()
|
133
|
-
error = stderr.read().decode("utf-8").strip()
|
134
|
-
|
135
|
-
if error:
|
136
|
-
raise RuntimeError(f"Error fetching header row: {error}")
|
214
|
+
raw_header = self._execute_ssh_command(header_cmd).strip()
|
137
215
|
|
138
216
|
# Sanitize headers
|
139
217
|
headers = [
|
@@ -160,12 +238,7 @@ class RemoteCSVQueryRunner:
|
|
160
238
|
limit_clause = f"| head -n {limit}" if limit else ""
|
161
239
|
awk_cmd = f"awk -F'{self.sep}' 'NR > {self.offset + 1} {f'&& {awk_filter}' if awk_filter else ''} {{print}}' {self.file_path} {limit_clause}"
|
162
240
|
|
163
|
-
|
164
|
-
output = stdout.read().decode("utf-8").strip()
|
165
|
-
error = stderr.read().decode("utf-8").strip()
|
166
|
-
|
167
|
-
if error:
|
168
|
-
raise RuntimeError(f"Error executing AWK command: {error}")
|
241
|
+
output = self._execute_ssh_command(awk_cmd).strip()
|
169
242
|
|
170
243
|
# Split rows into lists of strings
|
171
244
|
rows = [
|
@@ -205,14 +278,9 @@ class RemoteCSVQueryRunner:
|
|
205
278
|
if total_lines
|
206
279
|
else f"cat {self.file_path}"
|
207
280
|
)
|
208
|
-
|
209
|
-
output = stdout.read().decode("utf-8").strip()
|
210
|
-
error = stderr.read().decode("utf-8").strip()
|
211
|
-
|
212
|
-
if error:
|
213
|
-
raise RuntimeError(f"Error fetching raw rows: {error}")
|
281
|
+
output = self._execute_ssh_command(cmd).strip()
|
214
282
|
|
215
|
-
return output.splitlines()[self.offset:]
|
283
|
+
return output.splitlines()[self.offset :]
|
216
284
|
|
217
285
|
def get_csv_header(self) -> Dict[str, int]:
|
218
286
|
"""
|
@@ -220,12 +288,7 @@ class RemoteCSVQueryRunner:
|
|
220
288
|
:return: Dictionary of headers.
|
221
289
|
"""
|
222
290
|
header_cmd = f"head -n {self.offset + 1} {self.file_path} | tail -n 1"
|
223
|
-
|
224
|
-
header = stdout.read().decode("utf-8").strip()
|
225
|
-
error = stderr.read().decode("utf-8").strip()
|
226
|
-
|
227
|
-
if error:
|
228
|
-
raise RuntimeError(f"Error reading CSV header: {error}")
|
291
|
+
header = self._execute_ssh_command(header_cmd).strip()
|
229
292
|
|
230
293
|
# Trim spaces in header names
|
231
294
|
column_names = [name.strip() for name in header.split(self.sep)]
|
@@ -254,10 +317,9 @@ class RemoteCSVQueryRunner:
|
|
254
317
|
|
255
318
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
256
319
|
"""
|
257
|
-
Clean up
|
320
|
+
Clean up resources when exiting context.
|
258
321
|
"""
|
259
|
-
|
260
|
-
self.ssh_client.close()
|
322
|
+
pass
|
261
323
|
|
262
324
|
|
263
325
|
class NPEQueries:
|
@@ -267,14 +329,15 @@ class NPEQueries:
|
|
267
329
|
@staticmethod
|
268
330
|
def get_npe_manifest(instance: Instance):
|
269
331
|
|
270
|
-
|
271
332
|
if (
|
272
333
|
not instance.remote_connection
|
273
334
|
or instance.remote_connection
|
274
335
|
and not instance.remote_connection.useRemoteQuerying
|
275
336
|
):
|
276
337
|
file_path = Path(
|
277
|
-
instance.performance_path,
|
338
|
+
instance.performance_path,
|
339
|
+
NPEQueries.NPE_FOLDER,
|
340
|
+
NPEQueries.MANIFEST_FILE,
|
278
341
|
)
|
279
342
|
with open(file_path, "r") as f:
|
280
343
|
return json.load(f)
|
@@ -285,6 +348,48 @@ class NPEQueries:
|
|
285
348
|
f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{NPEQueries.MANIFEST_FILE}",
|
286
349
|
)
|
287
350
|
|
351
|
+
@staticmethod
|
352
|
+
def get_npe_timeline(instance: Instance, filename: str):
|
353
|
+
if not filename:
|
354
|
+
raise ValueError(
|
355
|
+
"filename parameter is required and cannot be None or empty"
|
356
|
+
)
|
357
|
+
|
358
|
+
if (
|
359
|
+
not instance.remote_connection
|
360
|
+
or not instance.remote_connection.useRemoteQuerying
|
361
|
+
):
|
362
|
+
if not instance.performance_path:
|
363
|
+
raise ValueError("instance.performance_path is None")
|
364
|
+
|
365
|
+
file_path = Path(instance.performance_path, NPEQueries.NPE_FOLDER, filename)
|
366
|
+
|
367
|
+
if filename.endswith(".zst"):
|
368
|
+
with open(file_path, "rb") as file:
|
369
|
+
compressed_data = file.read()
|
370
|
+
uncompressed_data = zstd.uncompress(compressed_data)
|
371
|
+
return json.loads(uncompressed_data)
|
372
|
+
else:
|
373
|
+
with open(file_path, "r") as f:
|
374
|
+
return json.load(f)
|
375
|
+
|
376
|
+
else:
|
377
|
+
profiler_folder = instance.remote_profile_folder
|
378
|
+
remote_path = (
|
379
|
+
f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{filename}"
|
380
|
+
)
|
381
|
+
remote_data = read_remote_file(instance.remote_connection, remote_path)
|
382
|
+
|
383
|
+
if filename.endswith(".zst"):
|
384
|
+
if isinstance(remote_data, str):
|
385
|
+
remote_data = remote_data.encode("utf-8")
|
386
|
+
uncompressed_data = zstd.decompress(remote_data)
|
387
|
+
return json.loads(uncompressed_data)
|
388
|
+
else:
|
389
|
+
if isinstance(remote_data, bytes):
|
390
|
+
remote_data = remote_data.decode("utf-8")
|
391
|
+
return json.loads(remote_data)
|
392
|
+
|
288
393
|
|
289
394
|
class DeviceLogProfilerQueries:
|
290
395
|
DEVICE_LOG_FILE = "profile_log_device.csv"
|
@@ -539,7 +644,9 @@ class OpsPerformanceQueries:
|
|
539
644
|
or instance.remote_connection
|
540
645
|
and not instance.remote_connection.useRemoteQuerying
|
541
646
|
):
|
542
|
-
with open(
|
647
|
+
with open(
|
648
|
+
OpsPerformanceQueries.get_local_ops_perf_file_path(instance)
|
649
|
+
) as f:
|
543
650
|
return f.read()
|
544
651
|
else:
|
545
652
|
path = OpsPerformanceQueries.get_remote_ops_perf_file_path(instance)
|
@@ -581,9 +688,7 @@ class OpsPerformanceQueries:
|
|
581
688
|
"""
|
582
689
|
try:
|
583
690
|
return [
|
584
|
-
folder.name
|
585
|
-
for folder in Path(directory).iterdir()
|
586
|
-
if folder.is_dir()
|
691
|
+
folder.name for folder in Path(directory).iterdir() if folder.is_dir()
|
587
692
|
]
|
588
693
|
except Exception as e:
|
589
694
|
raise RuntimeError(f"Error accessing directory: {e}")
|
@@ -611,8 +716,9 @@ class OpsPerformanceReportQueries:
|
|
611
716
|
"inner_dim_block_size",
|
612
717
|
"output_subblock_h",
|
613
718
|
"output_subblock_w",
|
719
|
+
"global_call_count",
|
614
720
|
"advice",
|
615
|
-
"raw_op_code"
|
721
|
+
"raw_op_code",
|
616
722
|
]
|
617
723
|
|
618
724
|
PASSTHROUGH_COLUMNS = {
|
@@ -658,7 +764,9 @@ class OpsPerformanceReportQueries:
|
|
658
764
|
next(reader, None)
|
659
765
|
for row in reader:
|
660
766
|
processed_row = {
|
661
|
-
column: row[index]
|
767
|
+
column: row[index]
|
768
|
+
for index, column in enumerate(cls.REPORT_COLUMNS)
|
769
|
+
if index < len(row)
|
662
770
|
}
|
663
771
|
if "advice" in processed_row and processed_row["advice"]:
|
664
772
|
processed_row["advice"] = processed_row["advice"].split(" • ")
|
@@ -667,7 +775,9 @@ class OpsPerformanceReportQueries:
|
|
667
775
|
|
668
776
|
for key, value in cls.PASSTHROUGH_COLUMNS.items():
|
669
777
|
op_id = int(row[0])
|
670
|
-
idx =
|
778
|
+
idx = (
|
779
|
+
op_id - 2
|
780
|
+
) # IDs in result column one correspond to row numbers in ops perf results csv
|
671
781
|
processed_row[key] = ops_perf_results[idx][value]
|
672
782
|
|
673
783
|
report.append(processed_row)
|
ttnn_visualizer/decorators.py
CHANGED
@@ -9,14 +9,12 @@ from ttnn_visualizer.enums import ConnectionTestStates
|
|
9
9
|
|
10
10
|
from functools import wraps
|
11
11
|
from flask import abort, request, session
|
12
|
-
from
|
12
|
+
from ttnn_visualizer.exceptions import (
|
13
13
|
AuthenticationException,
|
14
14
|
NoValidConnectionsError,
|
15
15
|
SSHException,
|
16
|
-
)
|
17
|
-
|
18
|
-
from ttnn_visualizer.exceptions import (
|
19
16
|
RemoteConnectionException,
|
17
|
+
AuthenticationFailedException,
|
20
18
|
NoProjectsException,
|
21
19
|
RemoteSqliteException,
|
22
20
|
)
|
@@ -37,15 +35,23 @@ def with_instance(func):
|
|
37
35
|
abort(404)
|
38
36
|
|
39
37
|
instance_query_data = get_or_create_instance(instance_id=instance_id)
|
38
|
+
|
39
|
+
# Handle case where get_or_create_instance returns None due to database error
|
40
|
+
if instance_query_data is None:
|
41
|
+
current_app.logger.error(
|
42
|
+
f"Failed to get or create instance with ID: {instance_id}"
|
43
|
+
)
|
44
|
+
abort(500)
|
45
|
+
|
40
46
|
instance = instance_query_data.to_pydantic()
|
41
47
|
|
42
48
|
kwargs["instance"] = instance
|
43
49
|
|
44
|
-
if
|
45
|
-
session[
|
50
|
+
if "instances" not in session:
|
51
|
+
session["instances"] = []
|
46
52
|
|
47
|
-
if instance.instance_id not in session[
|
48
|
-
session[
|
53
|
+
if instance.instance_id not in session["instances"]:
|
54
|
+
session["instances"] = session.get("instances", []) + [instance.instance_id]
|
49
55
|
|
50
56
|
return func(*args, **kwargs)
|
51
57
|
|
@@ -65,10 +71,21 @@ def remote_exception_handler(func):
|
|
65
71
|
try:
|
66
72
|
return func(*args, **kwargs)
|
67
73
|
except AuthenticationException as err:
|
68
|
-
|
69
|
-
|
74
|
+
# Log the detailed error for debugging, but don't show full traceback
|
75
|
+
current_app.logger.warning(
|
76
|
+
f"SSH authentication failed for {connection.username}@{connection.host}: SSH key authentication required"
|
77
|
+
)
|
78
|
+
|
79
|
+
# Return user-friendly error message about SSH keys
|
80
|
+
user_message = (
|
81
|
+
"SSH authentication failed. This application requires SSH key-based authentication. "
|
82
|
+
"Please ensure your SSH public key is added to the authorized_keys file on the remote server. "
|
83
|
+
"Password authentication is not supported."
|
84
|
+
)
|
85
|
+
|
86
|
+
raise AuthenticationFailedException(
|
87
|
+
message=user_message,
|
70
88
|
status=ConnectionTestStates.FAILED,
|
71
|
-
message=f"Unable to authenticate: {str(err)}",
|
72
89
|
)
|
73
90
|
except FileNotFoundError as err:
|
74
91
|
current_app.logger.error(f"File not found: {str(err)}")
|
@@ -83,11 +100,20 @@ def remote_exception_handler(func):
|
|
83
100
|
message=f"No projects found at remote location: {connection.path}",
|
84
101
|
)
|
85
102
|
except NoValidConnectionsError as err:
|
86
|
-
current_app.logger.
|
87
|
-
|
103
|
+
current_app.logger.warning(
|
104
|
+
f"SSH connection failed for {connection.username}@{connection.host}: {str(err)}"
|
105
|
+
)
|
106
|
+
|
107
|
+
# Provide user-friendly message for connection issues
|
108
|
+
user_message = (
|
109
|
+
f"Unable to establish SSH connection to {connection.host}. "
|
110
|
+
"Please check the hostname, port, and network connectivity. "
|
111
|
+
"Ensure SSH key-based authentication is properly configured."
|
112
|
+
)
|
113
|
+
|
88
114
|
raise RemoteConnectionException(
|
89
115
|
status=ConnectionTestStates.FAILED,
|
90
|
-
message=
|
116
|
+
message=user_message,
|
91
117
|
)
|
92
118
|
|
93
119
|
except RemoteSqliteException as err:
|
@@ -108,10 +134,10 @@ def remote_exception_handler(func):
|
|
108
134
|
)
|
109
135
|
except SSHException as err:
|
110
136
|
if str(err) == "No existing session":
|
111
|
-
message = "
|
137
|
+
message = "SSH authentication failed. Please ensure SSH keys are configured and ssh-agent is running."
|
112
138
|
else:
|
113
139
|
err_message = re.sub(r"\[.*?]", "", str(err)).strip()
|
114
|
-
message = f"
|
140
|
+
message = f"SSH connection error to {connection.host}: {err_message}. Ensure SSH key-based authentication is properly configured."
|
115
141
|
|
116
142
|
raise RemoteConnectionException(
|
117
143
|
status=ConnectionTestStates.FAILED, message=message
|
ttnn_visualizer/exceptions.py
CHANGED
@@ -3,24 +3,49 @@
|
|
3
3
|
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
|
4
4
|
|
5
5
|
from http import HTTPStatus
|
6
|
+
from typing import Optional
|
6
7
|
|
7
8
|
from ttnn_visualizer.enums import ConnectionTestStates
|
8
9
|
|
9
10
|
|
10
11
|
class RemoteConnectionException(Exception):
|
11
|
-
def __init__(
|
12
|
+
def __init__(
|
13
|
+
self,
|
14
|
+
message,
|
15
|
+
status: ConnectionTestStates,
|
16
|
+
http_status_code: Optional[HTTPStatus] = None,
|
17
|
+
):
|
12
18
|
super().__init__(message)
|
13
19
|
self.message = message
|
14
20
|
self.status = status
|
21
|
+
self._http_status_code = http_status_code
|
15
22
|
|
16
23
|
@property
|
17
24
|
def http_status(self):
|
25
|
+
# Use custom HTTP status code if provided
|
26
|
+
if self._http_status_code is not None:
|
27
|
+
return self._http_status_code
|
28
|
+
|
29
|
+
# Default behavior
|
18
30
|
if self.status == ConnectionTestStates.FAILED:
|
19
31
|
return HTTPStatus.INTERNAL_SERVER_ERROR
|
20
32
|
if self.status == ConnectionTestStates.OK:
|
21
33
|
return HTTPStatus.OK
|
22
34
|
|
23
35
|
|
36
|
+
class AuthenticationFailedException(RemoteConnectionException):
|
37
|
+
"""Exception for SSH authentication failures that should return HTTP 422"""
|
38
|
+
|
39
|
+
def __init__(
|
40
|
+
self, message, status: ConnectionTestStates = ConnectionTestStates.FAILED
|
41
|
+
):
|
42
|
+
super().__init__(
|
43
|
+
message=message,
|
44
|
+
status=status,
|
45
|
+
http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, # 422
|
46
|
+
)
|
47
|
+
|
48
|
+
|
24
49
|
class NoProjectsException(RemoteConnectionException):
|
25
50
|
pass
|
26
51
|
|
@@ -43,5 +68,24 @@ class DataFormatError(Exception):
|
|
43
68
|
class InvalidReportPath(Exception):
|
44
69
|
pass
|
45
70
|
|
71
|
+
|
46
72
|
class InvalidProfilerPath(Exception):
|
47
73
|
pass
|
74
|
+
|
75
|
+
|
76
|
+
class SSHException(Exception):
|
77
|
+
"""Base SSH exception for subprocess SSH operations"""
|
78
|
+
|
79
|
+
pass
|
80
|
+
|
81
|
+
|
82
|
+
class AuthenticationException(SSHException):
|
83
|
+
"""Raised when SSH authentication fails"""
|
84
|
+
|
85
|
+
pass
|
86
|
+
|
87
|
+
|
88
|
+
class NoValidConnectionsError(SSHException):
|
89
|
+
"""Raised when SSH connection cannot be established"""
|
90
|
+
|
91
|
+
pass
|
ttnn_visualizer/file_uploads.py
CHANGED
ttnn_visualizer/instances.py
CHANGED
@@ -94,8 +94,7 @@ def update_existing_instance(
|
|
94
94
|
else:
|
95
95
|
if active_report.get("npe_name"):
|
96
96
|
instance_data.npe_path = get_npe_path(
|
97
|
-
npe_name=active_report["npe_name"],
|
98
|
-
current_app=current_app
|
97
|
+
npe_name=active_report["npe_name"], current_app=current_app
|
99
98
|
)
|
100
99
|
|
101
100
|
|
@@ -148,17 +147,25 @@ def create_new_instance(
|
|
148
147
|
instance_data = InstanceTable(
|
149
148
|
instance_id=instance_id,
|
150
149
|
active_report=active_report,
|
151
|
-
profiler_path=
|
152
|
-
|
153
|
-
|
154
|
-
|
150
|
+
profiler_path=(
|
151
|
+
profiler_path
|
152
|
+
if profiler_path is not _sentinel
|
153
|
+
else get_profiler_path(
|
154
|
+
active_report["profiler_name"],
|
155
|
+
current_app=current_app,
|
156
|
+
remote_connection=remote_connection,
|
157
|
+
)
|
155
158
|
),
|
156
159
|
remote_connection=(
|
157
160
|
remote_connection.model_dump() if remote_connection else None
|
158
161
|
),
|
159
|
-
remote_profiler_folder=
|
162
|
+
remote_profiler_folder=(
|
163
|
+
remote_profiler_folder.model_dump() if remote_profiler_folder else None
|
164
|
+
),
|
160
165
|
remote_performance_folder=(
|
161
|
-
remote_performance_folder.model_dump()
|
166
|
+
remote_performance_folder.model_dump()
|
167
|
+
if remote_performance_folder
|
168
|
+
else None
|
162
169
|
),
|
163
170
|
)
|
164
171
|
|
@@ -255,10 +262,18 @@ def get_or_create_instance(
|
|
255
262
|
db.session.commit()
|
256
263
|
except IntegrityError:
|
257
264
|
db.session.rollback()
|
258
|
-
instance_data = InstanceTable.query.filter_by(
|
265
|
+
instance_data = InstanceTable.query.filter_by(
|
266
|
+
instance_id=instance_id
|
267
|
+
).first()
|
259
268
|
|
260
269
|
# Update the instance if any new data is provided
|
261
|
-
if
|
270
|
+
if (
|
271
|
+
profiler_name
|
272
|
+
or performance_name
|
273
|
+
or npe_name
|
274
|
+
or remote_connection
|
275
|
+
or remote_profiler_folder
|
276
|
+
):
|
262
277
|
update_instance(
|
263
278
|
instance_id=instance_id,
|
264
279
|
profiler_name=profiler_name,
|
@@ -269,7 +284,9 @@ def get_or_create_instance(
|
|
269
284
|
)
|
270
285
|
|
271
286
|
# Query again to get the updated instance data
|
272
|
-
instance_data = InstanceTable.query.filter_by(
|
287
|
+
instance_data = InstanceTable.query.filter_by(
|
288
|
+
instance_id=instance_id
|
289
|
+
).first()
|
273
290
|
|
274
291
|
return instance_data
|
275
292
|
|
@@ -318,7 +335,7 @@ def init_instances(app):
|
|
318
335
|
|
319
336
|
|
320
337
|
def create_random_instance_id():
|
321
|
-
return
|
338
|
+
return "".join(random.choices(string.ascii_lowercase + string.digits, k=45))
|
322
339
|
|
323
340
|
|
324
341
|
def create_instance_from_local_paths(profiler_path, performance_path):
|
@@ -328,11 +345,21 @@ def create_instance_from_local_paths(profiler_path, performance_path):
|
|
328
345
|
if _profiler_path and (not _profiler_path.exists() or not _profiler_path.is_dir()):
|
329
346
|
raise InvalidReportPath()
|
330
347
|
|
331
|
-
if _performance_path and (
|
348
|
+
if _performance_path and (
|
349
|
+
not _performance_path.exists() or not _performance_path.is_dir()
|
350
|
+
):
|
332
351
|
raise InvalidProfilerPath()
|
333
352
|
|
334
|
-
profiler_name =
|
335
|
-
|
353
|
+
profiler_name = (
|
354
|
+
_profiler_path.parts[-1]
|
355
|
+
if _profiler_path and len(_profiler_path.parts) > 2
|
356
|
+
else ""
|
357
|
+
)
|
358
|
+
performance_name = (
|
359
|
+
_performance_path.parts[-1]
|
360
|
+
if _performance_path and len(_performance_path.parts) > 2
|
361
|
+
else ""
|
362
|
+
)
|
336
363
|
instance_data = InstanceTable(
|
337
364
|
instance_id=create_random_instance_id(),
|
338
365
|
active_report={
|