ttnn-visualizer 0.42.0__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. ttnn_visualizer/__init__.py +0 -1
  2. ttnn_visualizer/app.py +15 -4
  3. ttnn_visualizer/csv_queries.py +82 -48
  4. ttnn_visualizer/decorators.py +38 -15
  5. ttnn_visualizer/exceptions.py +29 -1
  6. ttnn_visualizer/file_uploads.py +1 -0
  7. ttnn_visualizer/instances.py +42 -15
  8. ttnn_visualizer/models.py +12 -7
  9. ttnn_visualizer/remote_sqlite_setup.py +37 -30
  10. ttnn_visualizer/requirements.txt +1 -0
  11. ttnn_visualizer/serializers.py +1 -0
  12. ttnn_visualizer/settings.py +9 -5
  13. ttnn_visualizer/sftp_operations.py +144 -125
  14. ttnn_visualizer/sockets.py +9 -3
  15. ttnn_visualizer/static/assets/{allPaths-wwXsGKJ2.js → allPaths-BQN_j7ek.js} +1 -1
  16. ttnn_visualizer/static/assets/{allPathsLoader-BK9jqlVe.js → allPathsLoader-BvkkQ77q.js} +2 -2
  17. ttnn_visualizer/static/assets/index-B-fsa5Ru.js +1 -0
  18. ttnn_visualizer/static/assets/{index-Ybr1HJxx.js → index-Bng0kcmi.js} +69 -69
  19. ttnn_visualizer/static/assets/{index-C1rJBrMl.css → index-C-t6jBt9.css} +1 -1
  20. ttnn_visualizer/static/assets/index-DLOviMB1.js +1 -0
  21. ttnn_visualizer/static/assets/{splitPathsBySizeLoader-CauQGZHk.js → splitPathsBySizeLoader-Cl0NRdfL.js} +1 -1
  22. ttnn_visualizer/static/index.html +2 -2
  23. ttnn_visualizer/tests/__init__.py +0 -1
  24. ttnn_visualizer/tests/test_queries.py +0 -1
  25. ttnn_visualizer/tests/test_serializers.py +2 -2
  26. ttnn_visualizer/utils.py +7 -3
  27. ttnn_visualizer/views.py +250 -82
  28. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/METADATA +5 -1
  29. ttnn_visualizer-0.43.0.dist-info/RECORD +45 -0
  30. ttnn_visualizer/static/assets/index-BKzgFDAn.js +0 -1
  31. ttnn_visualizer/static/assets/index-BvSuWPlB.js +0 -1
  32. ttnn_visualizer-0.42.0.dist-info/RECORD +0 -45
  33. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/LICENSE +0 -0
  34. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/LICENSE_understanding.txt +0 -0
  35. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/WHEEL +0 -0
  36. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/entry_points.txt +0 -0
  37. {ttnn_visualizer-0.42.0.dist-info → ttnn_visualizer-0.43.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
  #
3
3
  # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
4
-
ttnn_visualizer/app.py CHANGED
@@ -21,7 +21,11 @@ from flask_cors import CORS
21
21
  from werkzeug.debug import DebuggedApplication
22
22
  from werkzeug.middleware.proxy_fix import ProxyFix
23
23
 
24
- from ttnn_visualizer.exceptions import DatabaseFileNotFoundException, InvalidProfilerPath, InvalidReportPath
24
+ from ttnn_visualizer.exceptions import (
25
+ DatabaseFileNotFoundException,
26
+ InvalidProfilerPath,
27
+ InvalidReportPath,
28
+ )
25
29
  from ttnn_visualizer.instances import create_instance_from_local_paths
26
30
  from ttnn_visualizer.settings import Config, DefaultConfig
27
31
 
@@ -65,6 +69,7 @@ def create_app(settings_override=None):
65
69
  extensions(app)
66
70
 
67
71
  if flask_env == "production":
72
+
68
73
  @app.route(f"{app.config['BASE_PATH']}", defaults={"path": ""})
69
74
  @app.route(f"{app.config['BASE_PATH']}<path:path>")
70
75
  def catch_all(path):
@@ -179,9 +184,15 @@ def open_browser(host, port, instance_id=None):
179
184
 
180
185
 
181
186
  def parse_args():
182
- parser = argparse.ArgumentParser(description="A tool for visualizing the Tenstorrent Neural Network model (TT-NN)")
183
- parser.add_argument("--profiler-path", type=str, help="Specify a profiler path", default=None)
184
- parser.add_argument("--performance-path", help="Specify a performance path", default=None)
187
+ parser = argparse.ArgumentParser(
188
+ description="A tool for visualizing the Tenstorrent Neural Network model (TT-NN)"
189
+ )
190
+ parser.add_argument(
191
+ "--profiler-path", type=str, help="Specify a profiler path", default=None
192
+ )
193
+ parser.add_argument(
194
+ "--performance-path", help="Specify a performance path", default=None
195
+ )
185
196
  return parser.parse_args()
186
197
 
187
198
 
@@ -11,16 +11,22 @@ from pathlib import Path
11
11
  from typing import List, Dict, Union, Optional
12
12
 
13
13
  import pandas as pd
14
+ import zstd
14
15
  from tt_perf_report import perf_report
15
16
 
16
17
  from ttnn_visualizer.exceptions import DataFormatError
18
+ from ttnn_visualizer.exceptions import (
19
+ SSHException,
20
+ AuthenticationException,
21
+ NoValidConnectionsError,
22
+ )
17
23
  from ttnn_visualizer.models import Instance, RemoteConnection
18
- from ttnn_visualizer.exceptions import SSHException, AuthenticationException, NoValidConnectionsError
19
- from ttnn_visualizer.models import Instance
20
24
  from ttnn_visualizer.sftp_operations import read_remote_file
21
25
 
22
26
 
23
- def handle_ssh_subprocess_error(e: subprocess.CalledProcessError, remote_connection: RemoteConnection):
27
+ def handle_ssh_subprocess_error(
28
+ e: subprocess.CalledProcessError, remote_connection: RemoteConnection
29
+ ):
24
30
  """
25
31
  Convert subprocess SSH errors to appropriate SSH exceptions.
26
32
 
@@ -31,23 +37,29 @@ def handle_ssh_subprocess_error(e: subprocess.CalledProcessError, remote_connect
31
37
  stderr = e.stderr.lower() if e.stderr else ""
32
38
 
33
39
  # Check for authentication failures
34
- if any(auth_err in stderr for auth_err in [
35
- "permission denied",
36
- "authentication failed",
37
- "publickey",
38
- "password",
39
- "host key verification failed"
40
- ]):
40
+ if any(
41
+ auth_err in stderr
42
+ for auth_err in [
43
+ "permission denied",
44
+ "authentication failed",
45
+ "publickey",
46
+ "password",
47
+ "host key verification failed",
48
+ ]
49
+ ):
41
50
  raise AuthenticationException(f"SSH authentication failed: {e.stderr}")
42
51
 
43
52
  # Check for connection failures
44
- elif any(conn_err in stderr for conn_err in [
45
- "connection refused",
46
- "network is unreachable",
47
- "no route to host",
48
- "name or service not known",
49
- "connection timed out"
50
- ]):
53
+ elif any(
54
+ conn_err in stderr
55
+ for conn_err in [
56
+ "connection refused",
57
+ "network is unreachable",
58
+ "no route to host",
59
+ "name or service not known",
60
+ "connection timed out",
61
+ ]
62
+ ):
51
63
  raise NoValidConnectionsError(f"SSH connection failed: {e.stderr}")
52
64
 
53
65
  # Check for general SSH protocol errors
@@ -58,6 +70,7 @@ def handle_ssh_subprocess_error(e: subprocess.CalledProcessError, remote_connect
58
70
  else:
59
71
  raise SSHException(f"SSH command failed: {e.stderr}")
60
72
 
73
+
61
74
  class LocalCSVQueryRunner:
62
75
  def __init__(self, file_path: str, offset: int = 0):
63
76
  self.file_path = file_path
@@ -154,24 +167,22 @@ class RemoteCSVQueryRunner:
154
167
 
155
168
  def _execute_ssh_command(self, command: str) -> str:
156
169
  """Execute an SSH command and return the output."""
157
- ssh_cmd = ["ssh"]
158
-
170
+ ssh_cmd = ["ssh", "-o", "PasswordAuthentication=no"]
171
+
159
172
  # Handle non-standard SSH port
160
173
  if self.remote_connection.port != 22:
161
174
  ssh_cmd.extend(["-p", str(self.remote_connection.port)])
162
-
163
- ssh_cmd.extend([
164
- f"{self.remote_connection.username}@{self.remote_connection.host}",
165
- command
166
- ])
167
-
175
+
176
+ ssh_cmd.extend(
177
+ [
178
+ f"{self.remote_connection.username}@{self.remote_connection.host}",
179
+ command,
180
+ ]
181
+ )
182
+
168
183
  try:
169
184
  result = subprocess.run(
170
- ssh_cmd,
171
- capture_output=True,
172
- text=True,
173
- check=True,
174
- timeout=30
185
+ ssh_cmd, capture_output=True, text=True, check=True, timeout=30
175
186
  )
176
187
  return result.stdout
177
188
  except subprocess.CalledProcessError as e:
@@ -269,7 +280,7 @@ class RemoteCSVQueryRunner:
269
280
  )
270
281
  output = self._execute_ssh_command(cmd).strip()
271
282
 
272
- return output.splitlines()[self.offset:]
283
+ return output.splitlines()[self.offset :]
273
284
 
274
285
  def get_csv_header(self) -> Dict[str, int]:
275
286
  """
@@ -324,7 +335,9 @@ class NPEQueries:
324
335
  and not instance.remote_connection.useRemoteQuerying
325
336
  ):
326
337
  file_path = Path(
327
- instance.performance_path, NPEQueries.NPE_FOLDER, NPEQueries.MANIFEST_FILE
338
+ instance.performance_path,
339
+ NPEQueries.NPE_FOLDER,
340
+ NPEQueries.MANIFEST_FILE,
328
341
  )
329
342
  with open(file_path, "r") as f:
330
343
  return json.load(f)
@@ -338,7 +351,9 @@ class NPEQueries:
338
351
  @staticmethod
339
352
  def get_npe_timeline(instance: Instance, filename: str):
340
353
  if not filename:
341
- raise ValueError("filename parameter is required and cannot be None or empty")
354
+ raise ValueError(
355
+ "filename parameter is required and cannot be None or empty"
356
+ )
342
357
 
343
358
  if (
344
359
  not instance.remote_connection
@@ -347,18 +362,33 @@ class NPEQueries:
347
362
  if not instance.performance_path:
348
363
  raise ValueError("instance.performance_path is None")
349
364
 
350
- file_path = Path(
351
- instance.performance_path, NPEQueries.NPE_FOLDER, filename
352
- )
353
- with open(file_path, "r") as f:
354
- return json.load(f)
365
+ file_path = Path(instance.performance_path, NPEQueries.NPE_FOLDER, filename)
366
+
367
+ if filename.endswith(".zst"):
368
+ with open(file_path, "rb") as file:
369
+ compressed_data = file.read()
370
+ uncompressed_data = zstd.uncompress(compressed_data)
371
+ return json.loads(uncompressed_data)
372
+ else:
373
+ with open(file_path, "r") as f:
374
+ return json.load(f)
375
+
355
376
  else:
356
377
  profiler_folder = instance.remote_profile_folder
357
- return read_remote_file(
358
- instance.remote_connection,
359
- f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{filename}",
378
+ remote_path = (
379
+ f"{profiler_folder.remotePath}/{NPEQueries.NPE_FOLDER}/{filename}"
360
380
  )
381
+ remote_data = read_remote_file(instance.remote_connection, remote_path)
361
382
 
383
+ if filename.endswith(".zst"):
384
+ if isinstance(remote_data, str):
385
+ remote_data = remote_data.encode("utf-8")
386
+ uncompressed_data = zstd.decompress(remote_data)
387
+ return json.loads(uncompressed_data)
388
+ else:
389
+ if isinstance(remote_data, bytes):
390
+ remote_data = remote_data.decode("utf-8")
391
+ return json.loads(remote_data)
362
392
 
363
393
 
364
394
  class DeviceLogProfilerQueries:
@@ -614,7 +644,9 @@ class OpsPerformanceQueries:
614
644
  or instance.remote_connection
615
645
  and not instance.remote_connection.useRemoteQuerying
616
646
  ):
617
- with open(OpsPerformanceQueries.get_local_ops_perf_file_path(instance)) as f:
647
+ with open(
648
+ OpsPerformanceQueries.get_local_ops_perf_file_path(instance)
649
+ ) as f:
618
650
  return f.read()
619
651
  else:
620
652
  path = OpsPerformanceQueries.get_remote_ops_perf_file_path(instance)
@@ -656,9 +688,7 @@ class OpsPerformanceQueries:
656
688
  """
657
689
  try:
658
690
  return [
659
- folder.name
660
- for folder in Path(directory).iterdir()
661
- if folder.is_dir()
691
+ folder.name for folder in Path(directory).iterdir() if folder.is_dir()
662
692
  ]
663
693
  except Exception as e:
664
694
  raise RuntimeError(f"Error accessing directory: {e}")
@@ -688,7 +718,7 @@ class OpsPerformanceReportQueries:
688
718
  "output_subblock_w",
689
719
  "global_call_count",
690
720
  "advice",
691
- "raw_op_code"
721
+ "raw_op_code",
692
722
  ]
693
723
 
694
724
  PASSTHROUGH_COLUMNS = {
@@ -734,7 +764,9 @@ class OpsPerformanceReportQueries:
734
764
  next(reader, None)
735
765
  for row in reader:
736
766
  processed_row = {
737
- column: row[index] for index, column in enumerate(cls.REPORT_COLUMNS) if index < len(row)
767
+ column: row[index]
768
+ for index, column in enumerate(cls.REPORT_COLUMNS)
769
+ if index < len(row)
738
770
  }
739
771
  if "advice" in processed_row and processed_row["advice"]:
740
772
  processed_row["advice"] = processed_row["advice"].split(" • ")
@@ -743,7 +775,9 @@ class OpsPerformanceReportQueries:
743
775
 
744
776
  for key, value in cls.PASSTHROUGH_COLUMNS.items():
745
777
  op_id = int(row[0])
746
- idx = op_id - 2 # IDs in result column one correspond to row numbers in ops perf results csv
778
+ idx = (
779
+ op_id - 2
780
+ ) # IDs in result column one correspond to row numbers in ops perf results csv
747
781
  processed_row[key] = ops_perf_results[idx][value]
748
782
 
749
783
  report.append(processed_row)
@@ -14,6 +14,7 @@ from ttnn_visualizer.exceptions import (
14
14
  NoValidConnectionsError,
15
15
  SSHException,
16
16
  RemoteConnectionException,
17
+ AuthenticationFailedException,
17
18
  NoProjectsException,
18
19
  RemoteSqliteException,
19
20
  )
@@ -34,21 +35,23 @@ def with_instance(func):
34
35
  abort(404)
35
36
 
36
37
  instance_query_data = get_or_create_instance(instance_id=instance_id)
37
-
38
+
38
39
  # Handle case where get_or_create_instance returns None due to database error
39
40
  if instance_query_data is None:
40
- current_app.logger.error(f"Failed to get or create instance with ID: {instance_id}")
41
+ current_app.logger.error(
42
+ f"Failed to get or create instance with ID: {instance_id}"
43
+ )
41
44
  abort(500)
42
-
45
+
43
46
  instance = instance_query_data.to_pydantic()
44
47
 
45
48
  kwargs["instance"] = instance
46
49
 
47
- if 'instances' not in session:
48
- session['instances'] = []
50
+ if "instances" not in session:
51
+ session["instances"] = []
49
52
 
50
- if instance.instance_id not in session['instances']:
51
- session['instances'] = session.get('instances', []) + [instance.instance_id]
53
+ if instance.instance_id not in session["instances"]:
54
+ session["instances"] = session.get("instances", []) + [instance.instance_id]
52
55
 
53
56
  return func(*args, **kwargs)
54
57
 
@@ -68,10 +71,21 @@ def remote_exception_handler(func):
68
71
  try:
69
72
  return func(*args, **kwargs)
70
73
  except AuthenticationException as err:
71
- current_app.logger.error(f"Authentication failed {err}")
72
- raise RemoteConnectionException(
74
+ # Log the detailed error for debugging, but don't show full traceback
75
+ current_app.logger.warning(
76
+ f"SSH authentication failed for {connection.username}@{connection.host}: SSH key authentication required"
77
+ )
78
+
79
+ # Return user-friendly error message about SSH keys
80
+ user_message = (
81
+ "SSH authentication failed. This application requires SSH key-based authentication. "
82
+ "Please ensure your SSH public key is added to the authorized_keys file on the remote server. "
83
+ "Password authentication is not supported."
84
+ )
85
+
86
+ raise AuthenticationFailedException(
87
+ message=user_message,
73
88
  status=ConnectionTestStates.FAILED,
74
- message=f"Unable to authenticate: {str(err)}",
75
89
  )
76
90
  except FileNotFoundError as err:
77
91
  current_app.logger.error(f"File not found: {str(err)}")
@@ -86,11 +100,20 @@ def remote_exception_handler(func):
86
100
  message=f"No projects found at remote location: {connection.path}",
87
101
  )
88
102
  except NoValidConnectionsError as err:
89
- current_app.logger.error(f"No valid connections: {str(err)}")
90
- message = re.sub(r"\[.*?]", "", str(err)).strip()
103
+ current_app.logger.warning(
104
+ f"SSH connection failed for {connection.username}@{connection.host}: {str(err)}"
105
+ )
106
+
107
+ # Provide user-friendly message for connection issues
108
+ user_message = (
109
+ f"Unable to establish SSH connection to {connection.host}. "
110
+ "Please check the hostname, port, and network connectivity. "
111
+ "Ensure SSH key-based authentication is properly configured."
112
+ )
113
+
91
114
  raise RemoteConnectionException(
92
115
  status=ConnectionTestStates.FAILED,
93
- message=f"{message}",
116
+ message=user_message,
94
117
  )
95
118
 
96
119
  except RemoteSqliteException as err:
@@ -111,10 +134,10 @@ def remote_exception_handler(func):
111
134
  )
112
135
  except SSHException as err:
113
136
  if str(err) == "No existing session":
114
- message = "Authentication failed - check credentials and ssh-agent"
137
+ message = "SSH authentication failed. Please ensure SSH keys are configured and ssh-agent is running."
115
138
  else:
116
139
  err_message = re.sub(r"\[.*?]", "", str(err)).strip()
117
- message = f"Error connecting to host {connection.host}: {err_message}"
140
+ message = f"SSH connection error to {connection.host}: {err_message}. Ensure SSH key-based authentication is properly configured."
118
141
 
119
142
  raise RemoteConnectionException(
120
143
  status=ConnectionTestStates.FAILED, message=message
@@ -3,24 +3,49 @@
3
3
  # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
4
4
 
5
5
  from http import HTTPStatus
6
+ from typing import Optional
6
7
 
7
8
  from ttnn_visualizer.enums import ConnectionTestStates
8
9
 
9
10
 
10
11
  class RemoteConnectionException(Exception):
11
- def __init__(self, message, status: ConnectionTestStates):
12
+ def __init__(
13
+ self,
14
+ message,
15
+ status: ConnectionTestStates,
16
+ http_status_code: Optional[HTTPStatus] = None,
17
+ ):
12
18
  super().__init__(message)
13
19
  self.message = message
14
20
  self.status = status
21
+ self._http_status_code = http_status_code
15
22
 
16
23
  @property
17
24
  def http_status(self):
25
+ # Use custom HTTP status code if provided
26
+ if self._http_status_code is not None:
27
+ return self._http_status_code
28
+
29
+ # Default behavior
18
30
  if self.status == ConnectionTestStates.FAILED:
19
31
  return HTTPStatus.INTERNAL_SERVER_ERROR
20
32
  if self.status == ConnectionTestStates.OK:
21
33
  return HTTPStatus.OK
22
34
 
23
35
 
36
+ class AuthenticationFailedException(RemoteConnectionException):
37
+ """Exception for SSH authentication failures that should return HTTP 422"""
38
+
39
+ def __init__(
40
+ self, message, status: ConnectionTestStates = ConnectionTestStates.FAILED
41
+ ):
42
+ super().__init__(
43
+ message=message,
44
+ status=status,
45
+ http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, # 422
46
+ )
47
+
48
+
24
49
  class NoProjectsException(RemoteConnectionException):
25
50
  pass
26
51
 
@@ -50,14 +75,17 @@ class InvalidProfilerPath(Exception):
50
75
 
51
76
  class SSHException(Exception):
52
77
  """Base SSH exception for subprocess SSH operations"""
78
+
53
79
  pass
54
80
 
55
81
 
56
82
  class AuthenticationException(SSHException):
57
83
  """Raised when SSH authentication fails"""
84
+
58
85
  pass
59
86
 
60
87
 
61
88
  class NoValidConnectionsError(SSHException):
62
89
  """Raised when SSH connection cannot be established"""
90
+
63
91
  pass
@@ -54,6 +54,7 @@ def extract_folder_name_from_files(files):
54
54
  unsplit_name = str(files[0].filename)
55
55
  return unsplit_name.split("/")[0]
56
56
 
57
+
57
58
  def extract_npe_name(files):
58
59
  if not files:
59
60
  return None
@@ -94,8 +94,7 @@ def update_existing_instance(
94
94
  else:
95
95
  if active_report.get("npe_name"):
96
96
  instance_data.npe_path = get_npe_path(
97
- npe_name=active_report["npe_name"],
98
- current_app=current_app
97
+ npe_name=active_report["npe_name"], current_app=current_app
99
98
  )
100
99
 
101
100
 
@@ -148,17 +147,25 @@ def create_new_instance(
148
147
  instance_data = InstanceTable(
149
148
  instance_id=instance_id,
150
149
  active_report=active_report,
151
- profiler_path=profiler_path if profiler_path is not _sentinel else get_profiler_path(
152
- active_report["profiler_name"],
153
- current_app=current_app,
154
- remote_connection=remote_connection,
150
+ profiler_path=(
151
+ profiler_path
152
+ if profiler_path is not _sentinel
153
+ else get_profiler_path(
154
+ active_report["profiler_name"],
155
+ current_app=current_app,
156
+ remote_connection=remote_connection,
157
+ )
155
158
  ),
156
159
  remote_connection=(
157
160
  remote_connection.model_dump() if remote_connection else None
158
161
  ),
159
- remote_profiler_folder=remote_profiler_folder.model_dump() if remote_profiler_folder else None,
162
+ remote_profiler_folder=(
163
+ remote_profiler_folder.model_dump() if remote_profiler_folder else None
164
+ ),
160
165
  remote_performance_folder=(
161
- remote_performance_folder.model_dump() if remote_performance_folder else None
166
+ remote_performance_folder.model_dump()
167
+ if remote_performance_folder
168
+ else None
162
169
  ),
163
170
  )
164
171
 
@@ -255,10 +262,18 @@ def get_or_create_instance(
255
262
  db.session.commit()
256
263
  except IntegrityError:
257
264
  db.session.rollback()
258
- instance_data = InstanceTable.query.filter_by(instance_id=instance_id).first()
265
+ instance_data = InstanceTable.query.filter_by(
266
+ instance_id=instance_id
267
+ ).first()
259
268
 
260
269
  # Update the instance if any new data is provided
261
- if profiler_name or performance_name or npe_name or remote_connection or remote_profiler_folder:
270
+ if (
271
+ profiler_name
272
+ or performance_name
273
+ or npe_name
274
+ or remote_connection
275
+ or remote_profiler_folder
276
+ ):
262
277
  update_instance(
263
278
  instance_id=instance_id,
264
279
  profiler_name=profiler_name,
@@ -269,7 +284,9 @@ def get_or_create_instance(
269
284
  )
270
285
 
271
286
  # Query again to get the updated instance data
272
- instance_data = InstanceTable.query.filter_by(instance_id=instance_id).first()
287
+ instance_data = InstanceTable.query.filter_by(
288
+ instance_id=instance_id
289
+ ).first()
273
290
 
274
291
  return instance_data
275
292
 
@@ -318,7 +335,7 @@ def init_instances(app):
318
335
 
319
336
 
320
337
  def create_random_instance_id():
321
- return ''.join(random.choices(string.ascii_lowercase + string.digits, k=45))
338
+ return "".join(random.choices(string.ascii_lowercase + string.digits, k=45))
322
339
 
323
340
 
324
341
  def create_instance_from_local_paths(profiler_path, performance_path):
@@ -328,11 +345,21 @@ def create_instance_from_local_paths(profiler_path, performance_path):
328
345
  if _profiler_path and (not _profiler_path.exists() or not _profiler_path.is_dir()):
329
346
  raise InvalidReportPath()
330
347
 
331
- if _performance_path and (not _performance_path.exists() or not _performance_path.is_dir()):
348
+ if _performance_path and (
349
+ not _performance_path.exists() or not _performance_path.is_dir()
350
+ ):
332
351
  raise InvalidProfilerPath()
333
352
 
334
- profiler_name = _profiler_path.parts[-1] if _profiler_path and len(_profiler_path.parts) > 2 else ""
335
- performance_name = _performance_path.parts[-1] if _performance_path and len(_performance_path.parts) > 2 else ""
353
+ profiler_name = (
354
+ _profiler_path.parts[-1]
355
+ if _profiler_path and len(_profiler_path.parts) > 2
356
+ else ""
357
+ )
358
+ performance_name = (
359
+ _performance_path.parts[-1]
360
+ if _performance_path and len(_performance_path.parts) > 2
361
+ else ""
362
+ )
336
363
  instance_data = InstanceTable(
337
364
  instance_id=create_random_instance_id(),
338
365
  active_report={
ttnn_visualizer/models.py CHANGED
@@ -119,6 +119,7 @@ class Tensor(SerializeableDataclass):
119
119
  def __post_init__(self):
120
120
  self.memory_config = parse_memory_config(self.memory_config)
121
121
 
122
+
122
123
  @dataclasses.dataclass
123
124
  class InputTensor(SerializeableDataclass):
124
125
  operation_id: int
@@ -246,19 +247,21 @@ class InstanceTable(db.Model):
246
247
  "remote_performance_folder": self.remote_performance_folder,
247
248
  "profiler_path": self.profiler_path,
248
249
  "performance_path": self.performance_path,
249
- "npe_path": self.npe_path
250
+ "npe_path": self.npe_path,
250
251
  }
251
252
 
252
253
  def to_pydantic(self) -> Instance:
253
254
  return Instance(
254
255
  instance_id=str(self.instance_id),
255
- profiler_path=str(self.profiler_path) if self.profiler_path is not None else None,
256
- performance_path=(
257
- str(self.performance_path) if self.performance_path is not None else None
256
+ profiler_path=(
257
+ str(self.profiler_path) if self.profiler_path is not None else None
258
258
  ),
259
- npe_path=(
260
- str(self.npe_path) if self.npe_path is not None else None
259
+ performance_path=(
260
+ str(self.performance_path)
261
+ if self.performance_path is not None
262
+ else None
261
263
  ),
264
+ npe_path=(str(self.npe_path) if self.npe_path is not None else None),
262
265
  active_report=(
263
266
  (ActiveReports(**self.active_report) if self.active_report else None)
264
267
  if isinstance(self.active_report, dict)
@@ -270,7 +273,9 @@ class InstanceTable(db.Model):
270
273
  else None
271
274
  ),
272
275
  remote_profiler_folder=(
273
- RemoteReportFolder.model_validate(self.remote_profiler_folder, strict=False)
276
+ RemoteReportFolder.model_validate(
277
+ self.remote_profiler_folder, strict=False
278
+ )
274
279
  if self.remote_profiler_folder is not None
275
280
  else None
276
281
  ),