ttnn-visualizer 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ttnn_visualizer/__init__.py +4 -0
  2. ttnn_visualizer/app.py +193 -0
  3. ttnn_visualizer/bin/docker-entrypoint-web +16 -0
  4. ttnn_visualizer/bin/pip3-install +17 -0
  5. ttnn_visualizer/csv_queries.py +618 -0
  6. ttnn_visualizer/decorators.py +117 -0
  7. ttnn_visualizer/enums.py +12 -0
  8. ttnn_visualizer/exceptions.py +40 -0
  9. ttnn_visualizer/extensions.py +14 -0
  10. ttnn_visualizer/file_uploads.py +78 -0
  11. ttnn_visualizer/models.py +275 -0
  12. ttnn_visualizer/queries.py +388 -0
  13. ttnn_visualizer/remote_sqlite_setup.py +91 -0
  14. ttnn_visualizer/requirements.txt +24 -0
  15. ttnn_visualizer/serializers.py +249 -0
  16. ttnn_visualizer/sessions.py +245 -0
  17. ttnn_visualizer/settings.py +118 -0
  18. ttnn_visualizer/sftp_operations.py +486 -0
  19. ttnn_visualizer/sockets.py +118 -0
  20. ttnn_visualizer/ssh_client.py +85 -0
  21. ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
  22. ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
  23. ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
  24. ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
  25. ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
  26. ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
  27. ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
  28. ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
  29. ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
  30. ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
  31. ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
  32. ttnn_visualizer/static/favicon/favicon.svg +3 -0
  33. ttnn_visualizer/static/index.html +36 -0
  34. ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
  35. ttnn_visualizer/tests/__init__.py +4 -0
  36. ttnn_visualizer/tests/test_queries.py +444 -0
  37. ttnn_visualizer/tests/test_serializers.py +582 -0
  38. ttnn_visualizer/utils.py +185 -0
  39. ttnn_visualizer/views.py +794 -0
  40. ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
  41. ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
  42. ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
  43. ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
  44. ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
  45. ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
  46. ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,388 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import json
6
+ from typing import Generator, Dict, Any, Union
7
+
8
+ from ttnn_visualizer.exceptions import (
9
+ DatabaseFileNotFoundException,
10
+ )
11
+ from ttnn_visualizer.models import (
12
+ Operation,
13
+ DeviceOperation,
14
+ Buffer,
15
+ BufferPage,
16
+ TabSession,
17
+ Tensor,
18
+ OperationArgument,
19
+ StackTrace,
20
+ InputTensor,
21
+ OutputTensor,
22
+ Device,
23
+ ProducersConsumers,
24
+ TensorComparisonRecord,
25
+ )
26
+ from ttnn_visualizer.ssh_client import get_client
27
+ import sqlite3
28
+ from typing import List, Optional
29
+ from pathlib import Path
30
+ import paramiko
31
+
32
+
33
class LocalQueryRunner:
    """Execute SQL against a SQLite database on the local filesystem.

    The runner either adopts a caller-supplied connection or opens the file
    named by ``session.report_path``.
    """

    def __init__(self, session: Optional[TabSession] = None, connection=None):
        """Set up ``self.connection``.

        :param session: Session whose ``report_path`` names the database file;
            only consulted when ``connection`` is falsy.
        :param connection: Existing connection to reuse as-is.
        :raises ValueError: no connection was given and the session is missing
            or has no report path.
        :raises DatabaseFileNotFoundException: the report path does not exist.
        """
        if connection:
            self.connection = connection
            return

        if not session or not session.report_path:
            raise ValueError("Report path must be provided for local queries")

        report_location = str(session.report_path)
        if not Path(report_location).exists():
            raise DatabaseFileNotFoundException(
                f"Database not found at path: {report_location}"
            )

        # isolation_level=None -> sqlite3 autocommit mode; timeout is how many
        # seconds to wait when the database file is locked.
        self.connection = sqlite3.connect(
            session.report_path, timeout=30, isolation_level=None
        )

    def execute_query(self, query: str, params: Optional[List] = None) -> List:
        """Run *query* with *params* bound and return every resulting row."""
        bindings = params or []
        cur = self.connection.cursor()
        try:
            cur.execute(query, bindings)
            result = cur.fetchall()
        finally:
            # Release the cursor even when execution or fetching fails.
            cur.close()
        return result

    def close(self):
        """Close the underlying connection, if one is set."""
        conn = self.connection
        if conn:
            conn.close()
64
+
65
+
66
class RemoteQueryRunner:
    """Run SQL against a report database on a remote host.

    Queries are executed by invoking the remote ``sqlite3`` CLI over SSH. The
    session must provide ``remote_connection.sqliteBinaryPath`` and
    ``remote_folder.remotePath``. The database is expected at
    ``<remotePath>/db.sqlite``.
    """

    # Not referenced inside this class; kept for backward compatibility.
    column_delimiter = "|||"

    def __init__(self, session: "TabSession"):
        """Validate *session*, open the SSH client and resolve remote paths.

        :raises ValueError: if the session lacks the remote settings listed
            on the class.
        """
        from pathlib import PurePosixPath

        self.session = session
        self._validate_session()
        self.ssh_client = self._get_ssh_client(self.session.remote_connection)
        self.sqlite_binary = self.session.remote_connection.sqliteBinaryPath
        # The path is used on the remote host (assumed POSIX), so build it
        # with POSIX separators regardless of the local OS; pathlib.Path
        # would produce backslashes when this server runs on Windows.
        self.remote_db_path = str(
            PurePosixPath(self.session.remote_folder.remotePath, "db.sqlite")
        )

    def _validate_session(self):
        """
        Validate that the session has all required remote connection attributes.

        :raises ValueError: if the remote connection, its sqliteBinaryPath,
            the remote folder or its remotePath is missing/empty.
        """
        if (
            not self.session.remote_connection
            or not self.session.remote_connection.sqliteBinaryPath
            or not self.session.remote_folder
            or not self.session.remote_folder.remotePath
        ):
            raise ValueError(
                "Remote connections require remote path and sqliteBinaryPath"
            )

    def _get_ssh_client(self, remote_connection) -> "paramiko.SSHClient":
        """
        Retrieve the SSH client for the given remote connection.
        """
        return get_client(remote_connection=remote_connection)

    @staticmethod
    def _quote_param(param: Any) -> str:
        """Render one bound parameter as an SQLite literal.

        ``None`` becomes ``NULL``. Strings are single-quoted with embedded
        single quotes doubled, per SQL string-literal rules. Anything else
        is rendered with ``str()``, as before.
        """
        if param is None:
            return "NULL"
        if isinstance(param, str):
            return "'" + param.replace("'", "''") + "'"
        return str(param)

    def _format_query(self, query: str, params: Optional[List] = None) -> str:
        """
        Substitute each ``?`` placeholder in *query* with a quoted parameter.

        The sqlite3 CLI has no parameter binding, so values are inlined as
        escaped literals. Placeholders are located by splitting on ``?``;
        this does not use ``str.format``, so braces in the SQL are safe. A
        ``?`` inside a string literal of *query* itself would still be
        treated as a placeholder.

        :raises ValueError: if the number of placeholders differs from the
            number of parameters.
        """
        if not params:
            return query

        pieces = query.split("?")
        placeholder_count = len(pieces) - 1
        if placeholder_count != len(params):
            raise ValueError(
                f"Query has {placeholder_count} placeholders but "
                f"{len(params)} parameters were supplied"
            )

        parts = [pieces[0]]
        for param, tail in zip(params, pieces[1:]):
            parts.append(self._quote_param(param))
            parts.append(tail)
        return "".join(parts)

    def _build_command(self, formatted_query: str) -> str:
        """
        Build the remote SQLite shell command.

        The database path and the SQL text are shell-quoted, so quotes,
        ``$(...)``, backticks or spaces in them cannot break out of their
        argument. The binary path is left as configured so values such as
        ``~/bin/sqlite3`` keep expanding.
        """
        import shlex

        return (
            f"{self.sqlite_binary} {shlex.quote(self.remote_db_path)} "
            f"{shlex.quote(formatted_query)} -json"
        )

    def _execute_ssh_command(self, command: str) -> tuple:
        """
        Execute the SSH command and return ``(stdout, stderr)`` as stripped text.
        """
        stdin, stdout, stderr = self.ssh_client.exec_command(command)
        output = stdout.read().decode("utf-8").strip()
        error_output = stderr.read().decode("utf-8").strip()
        return output, error_output

    def _parse_output(self, output: str, command: str) -> List:
        """
        Parse the output from the SQLite command.

        JSON parsing (``-json`` mode: a list of row objects) is tried first.
        If that fails, each line is split on ``|``. Empty output yields ``[]``.
        """
        if not output.strip():
            return []

        try:
            rows = json.loads(output)
            return [tuple(row.values()) for row in rows]
        except json.JSONDecodeError:
            print(
                f"Output is not valid JSON, attempting manual parsing.\nCommand: {command}"
            )
            return [tuple(line.split("|")) for line in output.splitlines()]

    def execute_query(self, query: str, params: Optional[List] = None) -> List:
        """
        Execute a remote SQLite query using the session's SSH client.

        :raises RuntimeError: if the remote command wrote anything to stderr.
        """
        self._validate_session()
        formatted_query = self._format_query(query, params)
        command = self._build_command(formatted_query)
        output, error_output = self._execute_ssh_command(command)

        if error_output:
            raise RuntimeError(
                f"Error executing query remotely: {error_output}\nCommand: {command}"
            )

        return self._parse_output(output, command)

    def close(self):
        """
        Close the SSH connection.
        """
        if self.ssh_client:
            self.ssh_client.close()
164
+
165
+
166
class DatabaseQueries:
    """Typed read access to the tables of a ttnn report database.

    SQL runs through ``query_runner``. It is a ``RemoteQueryRunner`` when the
    session's remote connection has ``useRemoteQuerying`` set, otherwise a
    ``LocalQueryRunner``. Rows are converted into the model classes from
    ``ttnn_visualizer.models``. Leaving a ``with`` block closes the runner.
    """

    session: Optional[TabSession] = None
    # Not read or written by this class; kept so external references still work.
    ssh_client = None
    query_runner: LocalQueryRunner | RemoteQueryRunner

    def __init__(self, session: Optional[TabSession] = None, connection=None):
        """Select and construct the query runner.

        :param session: Tab session describing where the report lives.
        :param connection: Already-open SQLite connection. When given it is
            used directly and ``session`` is only stored.
        :raises ValueError: if neither ``connection`` nor ``session`` is given.
        """
        self.session = session

        if connection:
            self.query_runner = LocalQueryRunner(connection=connection)
            return

        if not session:
            raise ValueError("Must provide either an existing connection or session")

        remote_connection = session.remote_connection
        if remote_connection and remote_connection.useRemoteQuerying:
            self.query_runner = RemoteQueryRunner(session=session)
        else:
            self.query_runner = LocalQueryRunner(session=session)

    def _check_table_exists(self, table_name: str) -> bool:
        """
        Return True if *table_name* is a table in the database.

        This goes through ``query_runner``, so it works for both local and
        remote databases.
        """
        query = "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?"
        rows = self.query_runner.execute_query(query, [table_name])
        return bool(rows)

    def _query_table(
        self,
        table_name: str,
        filters: Optional[Dict[str, Union[Any, List[Any]]]] = None,
        additional_conditions: Optional[str] = None,
        additional_params: Optional[List[Any]] = None,
    ) -> List[Any]:
        """Run ``SELECT * FROM table_name`` narrowed by optional filters.

        :param table_name: Table to read. Interpolated into the SQL verbatim,
            so it must come from trusted code.
        :param filters: Mapping of column -> value.
            - A list value becomes ``column IN (...)``.
            - A scalar value becomes ``column = ?``.
            - ``None`` values and empty lists are skipped.
            - Column names are interpolated verbatim (same trust requirement).
            - Values are passed as bound parameters.
        :param additional_conditions: Raw SQL appended after the filters; it
            must supply its own connective, e.g. ``"AND x > ?"``.
        :param additional_params: Parameters for placeholders in
            ``additional_conditions``; only used when that is non-empty.
        :returns: The raw rows returned by the runner.
        """
        query = f"SELECT * FROM {table_name} WHERE 1=1"
        params: List[Any] = []

        if filters:
            for column, value in filters.items():
                if value is None:  # Skip filters with None values
                    continue

                if isinstance(value, list):  # Handle list-based filters
                    if not value:  # Skip empty lists
                        continue
                    placeholders = ", ".join(["?"] * len(value))
                    query += f" AND {column} IN ({placeholders})"
                    params.extend(value)
                else:
                    query += f" AND {column} = ?"
                    params.append(value)

        if additional_conditions:
            query += f" {additional_conditions}"
            if additional_params:
                params.extend(additional_params)

        return self.query_runner.execute_query(query, params)

    def query_device_operations(
        self, filters: Optional[Dict[str, Union[Any, List[Any]]]] = None
    ) -> List[DeviceOperation]:
        """Return rows of ``captured_graph``, or ``[]`` if that table is absent."""
        if not self._check_table_exists("captured_graph"):
            return []
        rows = self._query_table("captured_graph", filters)
        return [DeviceOperation(*row) for row in rows]

    def query_operation_arguments(
        self, filters: Optional[Dict[str, Union[Any, List[Any]]]] = None
    ) -> Generator[OperationArgument, None, None]:
        """Yield rows of ``operation_arguments`` matching *filters*."""
        rows = self._query_table("operation_arguments", filters)
        for row in rows:
            yield OperationArgument(*row)

    def query_operations(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[Operation, None, None]:
        """Yield rows of ``operations`` matching *filters*."""
        rows = self._query_table("operations", filters)
        for row in rows:
            yield Operation(*row)

    def query_buffers(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[Buffer, None, None]:
        """Yield rows of ``buffers`` matching *filters*."""
        rows = self._query_table("buffers", filters)
        for row in rows:
            yield Buffer(*row)

    def query_stack_traces(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[StackTrace, None, None]:
        """Yield rows of ``stack_traces`` (operation id, stack trace text)."""
        rows = self._query_table("stack_traces", filters)
        for row in rows:
            operation_id, stack_trace = row
            yield StackTrace(operation_id, stack_trace=stack_trace)

    def query_tensor_comparisons(
        self, local: bool = True, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[TensorComparisonRecord, None, None]:
        """Yield tensor comparison records.

        Reads ``local_tensor_comparison_records`` when *local* is true,
        otherwise ``global_tensor_comparison_records``.
        """
        if local:
            table_name = "local_tensor_comparison_records"
        else:
            table_name = "global_tensor_comparison_records"
        rows = self._query_table(table_name, filters)
        for row in rows:
            yield TensorComparisonRecord(*row)

    def query_buffer_pages(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[BufferPage, None, None]:
        """Yield rows of ``buffer_pages`` matching *filters*."""
        rows = self._query_table("buffer_pages", filters)
        for row in rows:
            yield BufferPage(*row)

    def query_tensors(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[Tensor, None, None]:
        """Yield tensors matching *filters*, each with per-device addresses.

        ``device_addresses`` is indexed by the device id found in
        ``device_tensors`` (second column). Gaps are filled with ``None``.
        When the ``device_tensors`` table does not exist the list is empty.
        """
        rows = self._query_table("tensors", filters)
        # Check for the table once, through the runner. The previous approach
        # caught sqlite3.OperationalError per row, which only worked for the
        # local runner; the remote runner reports a missing table as
        # RuntimeError and crashed.
        has_device_tensors = self._check_table_exists("device_tensors")

        for row in rows:
            device_addresses = []

            if has_device_tensors:
                device_tensors = self._query_table(
                    "device_tensors", filters={"tensor_id": row[0]}
                )
                for device_tensor in sorted(device_tensors, key=lambda x: x[1]):
                    # Pad so the address lands at index == device id.
                    while len(device_addresses) < device_tensor[1]:
                        device_addresses.append(None)
                    device_addresses.append(device_tensor[2])

            yield Tensor(*row, device_addresses)

    def query_input_tensors(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[InputTensor, None, None]:
        """Yield rows of ``input_tensors`` matching *filters*."""
        rows = self._query_table("input_tensors", filters)
        for row in rows:
            yield InputTensor(*row)

    def query_output_tensors(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[OutputTensor, None, None]:
        """Yield rows of ``output_tensors`` matching *filters*."""
        rows = self._query_table("output_tensors", filters)
        for row in rows:
            yield OutputTensor(*row)

    def query_devices(
        self, filters: Optional[Dict[str, Any]] = None
    ) -> Generator[Device, None, None]:
        """Yield rows of ``devices`` matching *filters*."""
        rows = self._query_table("devices", filters)
        for row in rows:
            yield Device(*row)

    @staticmethod
    def _parse_operation_ids(data: Optional[str]) -> List[int]:
        """Turn a GROUP_CONCAT result such as ``"3, 7, 3"`` into sorted unique ints."""
        if not data:
            return []
        return sorted(set(map(int, data.strip('"').split(","))))

    def query_producers_consumers(self) -> Generator[ProducersConsumers, None, None]:
        """Yield, per tensor, the operations that produce and consume it.

        Producers are operations listing the tensor in ``output_tensors``.
        Consumers are operations listing it in ``input_tensors``.
        """
        # Column order (tensor_id, producers, consumers) is what the unpacking
        # below relies on; the aliases now match that order (they were swapped).
        query = """
            SELECT
                t.tensor_id,
                GROUP_CONCAT(ot.operation_id, ', ') AS producers,
                GROUP_CONCAT(it.operation_id, ', ') AS consumers
            FROM
                tensors t
            LEFT JOIN
                input_tensors it ON t.tensor_id = it.tensor_id
            LEFT JOIN
                output_tensors ot ON t.tensor_id = ot.tensor_id
            GROUP BY
                t.tensor_id
        """
        rows = self.query_runner.execute_query(query)
        for row in rows:
            tensor_id, producers_data, consumers_data = row
            # The double LEFT JOIN repeats ids; _parse_operation_ids de-duplicates.
            yield ProducersConsumers(
                tensor_id,
                self._parse_operation_ids(producers_data),
                self._parse_operation_ids(consumers_data),
            )

    def query_next_buffer(self, operation_id: int, address: str) -> Optional[Buffer]:
        """Return the first buffer at *address* allocated after *operation_id*.

        Returns ``None`` if no later buffer uses that address.
        """
        # LIMIT 1: only the earliest matching row is used, so don't fetch the rest.
        query = """
            SELECT
                buffers.operation_id,
                buffers.device_id,
                buffers.address,
                buffers.max_size_per_bank,
                buffers.buffer_type
            FROM
                buffers
            WHERE
                buffers.address = ?
                AND buffers.operation_id > ?
            ORDER BY buffers.operation_id
            LIMIT 1
        """
        rows = self.query_runner.execute_query(query, [address, operation_id])
        return Buffer(*rows[0]) if rows else None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Both runner types expose close(); exceptions are not suppressed.
        self.query_runner.close()
@@ -0,0 +1,91 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import re
6
+
7
+ from ttnn_visualizer.decorators import remote_exception_handler
8
+ from ttnn_visualizer.enums import ConnectionTestStates
9
+ from ttnn_visualizer.exceptions import RemoteSqliteException
10
+ from ttnn_visualizer.models import RemoteConnection
11
+ from ttnn_visualizer.ssh_client import get_client
12
+
13
# Oldest sqlite3 release accepted on the remote host; is_sqlite_executable()
# compares the version reported by `<binary> --version` against this value.
MINIMUM_SQLITE_VERSION: str = "3.38.0"
14
+
15
+
16
def find_sqlite_binary(connection):
    """Locate ``sqlite3`` on the remote host by running ``which sqlite3`` over SSH.

    Returns the path printed by ``which``. Returns ``None`` when nothing is
    printed to stdout; in that case any stderr text is printed first.

    Raises:
        RemoteSqliteException: if running or reading the remote command fails.
    """
    ssh_client = get_client(connection)
    try:
        _stdin, out_stream, err_stream = ssh_client.exec_command("which sqlite3")
        located = out_stream.read().decode().strip()
        err_text = err_stream.read().decode().strip()

        if not located:
            if err_text:
                print(f"Error checking SQLite binary: {err_text}")
            return None

        print(f"SQLite binary found at: {located}")
        return located
    except Exception as e:
        raise RemoteSqliteException(
            message=f"Error finding SQLite binary: {str(e)}",
            status=ConnectionTestStates.FAILED,
        )
34
+
35
+
36
def is_sqlite_executable(ssh_client, binary_path):
    """Run ``<binary_path> --version`` remotely and check the reported version.

    Returns True when the command writes nothing to stderr and reports a
    version of at least MINIMUM_SQLITE_VERSION.

    Raises:
        Exception: for any failure (stderr output, unparseable or too-old
            version, SSH error). The message is prefixed with
            "Error checking SQLite executability: ".
    """
    try:
        _stdin, out_stream, err_stream = ssh_client.exec_command(
            f"{binary_path} --version"
        )
        version_text = out_stream.read().decode().strip()
        err_text = err_stream.read().decode().strip()
        # Block until the remote command exits; its exit code is not inspected.
        out_stream.channel.recv_exit_status()

        if err_text:
            raise Exception(f"Error while trying to run SQLite binary: {err_text}")

        version = get_sqlite_version(version_text)
        if is_version_at_least(version, MINIMUM_SQLITE_VERSION):
            print(f"SQLite binary at {binary_path} is executable. Version: {version}")
            return True

        raise Exception(
            f"SQLite version {version} is below the required minimum of {MINIMUM_SQLITE_VERSION}."
        )
    except Exception as e:
        raise Exception(f"Error checking SQLite executability: {str(e)}")
57
+
58
+
59
def get_sqlite_version(version_output):
    """Return the first ``X.Y.Z`` number found in *version_output*.

    Raises:
        ValueError: if no such number is present.
    """
    found = re.search(r"(\d+\.\d+\.\d+)", version_output)
    if found is None:
        raise ValueError("Could not parse SQLite version from output.")
    return found.group(1)
66
+
67
+
68
def is_version_at_least(version, minimum_version):
    """Return True if dotted *version* is greater than or equal to *minimum_version*.

    Components are compared numerically from left to right. A version with
    fewer components is treated as if padded with zeros, so "3.38" equals
    "3.38.0".

    Raises:
        ValueError: if a component is not an integer.
    """
    version_parts = [int(part) for part in version.split(".")]
    minimum_parts = [int(part) for part in minimum_version.split(".")]

    # Without padding, list comparison ranks [3, 38] below [3, 38, 0]
    # even though they denote the same version.
    width = max(len(version_parts), len(minimum_parts))
    version_parts.extend([0] * (width - len(version_parts)))
    minimum_parts.extend([0] * (width - len(minimum_parts)))

    return version_parts >= minimum_parts
74
+
75
+
76
@remote_exception_handler
def check_sqlite_path(remote_connection: RemoteConnection):
    """Verify that ``remote_connection.sqliteBinaryPath`` runs on the remote host.

    Any failure is re-raised as RemoteSqliteException with status FAILED.
    Further handling is done by ``remote_exception_handler`` (defined
    elsewhere).
    """
    try:
        ssh = get_client(remote_connection)
        is_sqlite_executable(ssh, remote_connection.sqliteBinaryPath)
    except Exception as exc:
        raise RemoteSqliteException(
            message=str(exc), status=ConnectionTestStates.FAILED
        )
83
+
84
+
85
def get_sqlite_path(connection: RemoteConnection):
    """Return the remote ``sqlite3`` path found by find_sqlite_binary, else None.

    Any exception raised while searching is re-raised as
    RemoteSqliteException with status FAILED.
    """
    try:
        located = find_sqlite_binary(connection)
    except Exception as exc:
        raise RemoteSqliteException(
            message=str(exc), status=ConnectionTestStates.FAILED
        )
    # Falsy results (None or empty string) are normalized to None.
    return located or None
@@ -0,0 +1,24 @@
1
+ Flask==3.1.0
2
+ gunicorn~=22.0.0
3
+ uvicorn==0.30.1
4
+ paramiko~=3.4.0
5
+ flask_cors==4.0.1
6
+ pydantic==2.7.3
7
+ pydantic_core==2.18.4
8
+ flask_static_digest==0.4.1
9
+ setuptools==65.5.0
10
+ python-dotenv==1.0.1
11
+ flask-sqlalchemy
12
+ flask-socketio
13
+ gevent==24.10.2
14
+ flask-session
15
+ pandas==2.2.3
16
+ wheel
17
+ build
18
+ PyYAML==6.0.2
19
+ python-dotenv==1.0.1
20
+ tt-perf-report==1.0.6
21
+
22
+ # Dev dependencies
23
+ mypy
24
+ types-paramiko