ttnn-visualizer 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ttnn_visualizer/__init__.py +4 -0
  2. ttnn_visualizer/app.py +193 -0
  3. ttnn_visualizer/bin/docker-entrypoint-web +16 -0
  4. ttnn_visualizer/bin/pip3-install +17 -0
  5. ttnn_visualizer/csv_queries.py +618 -0
  6. ttnn_visualizer/decorators.py +117 -0
  7. ttnn_visualizer/enums.py +12 -0
  8. ttnn_visualizer/exceptions.py +40 -0
  9. ttnn_visualizer/extensions.py +14 -0
  10. ttnn_visualizer/file_uploads.py +78 -0
  11. ttnn_visualizer/models.py +275 -0
  12. ttnn_visualizer/queries.py +388 -0
  13. ttnn_visualizer/remote_sqlite_setup.py +91 -0
  14. ttnn_visualizer/requirements.txt +24 -0
  15. ttnn_visualizer/serializers.py +249 -0
  16. ttnn_visualizer/sessions.py +245 -0
  17. ttnn_visualizer/settings.py +118 -0
  18. ttnn_visualizer/sftp_operations.py +486 -0
  19. ttnn_visualizer/sockets.py +118 -0
  20. ttnn_visualizer/ssh_client.py +85 -0
  21. ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
  22. ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
  23. ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
  24. ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
  25. ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
  26. ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
  27. ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
  28. ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
  29. ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
  30. ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
  31. ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
  32. ttnn_visualizer/static/favicon/favicon.svg +3 -0
  33. ttnn_visualizer/static/index.html +36 -0
  34. ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
  35. ttnn_visualizer/tests/__init__.py +4 -0
  36. ttnn_visualizer/tests/test_queries.py +444 -0
  37. ttnn_visualizer/tests/test_serializers.py +582 -0
  38. ttnn_visualizer/utils.py +185 -0
  39. ttnn_visualizer/views.py +794 -0
  40. ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
  41. ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
  42. ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
  43. ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
  44. ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
  45. ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
  46. ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
@@ -0,0 +1,444 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+
6
+ import tempfile
7
+ import sqlite3
8
+ import tempfile
9
+ import unittest
10
+ from unittest.mock import Mock
11
+ from unittest.mock import patch
12
+
13
+ from ttnn_visualizer.models import (
14
+ DeviceOperation,
15
+ TensorComparisonRecord,
16
+ )
17
+ from ttnn_visualizer.queries import DatabaseQueries
18
+ from ttnn_visualizer.queries import LocalQueryRunner
19
+ from ttnn_visualizer.queries import RemoteQueryRunner
20
+
21
+
22
class TestQueryTable(unittest.TestCase):
    """Tests query construction logic.

    ``_query_table`` is exercised against a mocked query runner, so each test
    inspects the exact SQL text and bound parameters handed to
    ``execute_query`` instead of running anything against a database.
    """

    def setUp(self):
        # The runner is the object under observation; the connection Mock only
        # satisfies the constructor before the runner is swapped out.
        runner = Mock()
        queries = DatabaseQueries(connection=Mock())
        queries.query_runner = runner
        self.mock_query_runner = runner
        self.db_queries = queries

    def _assert_sql(self, expected_sql, expected_params):
        """Assert ``execute_query`` was invoked once with this SQL and params."""
        self.mock_query_runner.execute_query.assert_called_once_with(
            expected_sql, expected_params
        )

    def test_query_table_no_filters_or_conditions(self):
        self.db_queries._query_table("test_table")
        self._assert_sql("SELECT * FROM test_table WHERE 1=1", [])

    def test_query_table_single_filter(self):
        self.db_queries._query_table("test_table", {"column1": "value1"})
        self._assert_sql(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ?", ["value1"]
        )

    def test_query_table_multiple_filters(self):
        self.db_queries._query_table(
            "test_table", {"column1": "value1", "column2": 42}
        )
        self._assert_sql(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ? AND column2 = ?",
            ["value1", 42],
        )

    def test_query_table_filter_with_none_value(self):
        # A None-valued filter is expected to be dropped from the WHERE clause.
        self.db_queries._query_table(
            "test_table", {"column1": "value1", "column2": None}
        )
        self._assert_sql(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ?", ["value1"]
        )

    def test_query_table_empty_list_filter(self):
        # An empty list filter is expected to add no condition at all.
        self.db_queries._query_table("test_table", {"column1": []})
        self._assert_sql("SELECT * FROM test_table WHERE 1=1", [])

    def test_query_table_list_based_filter(self):
        # A list filter is expected to expand into an IN (...) placeholder list.
        self.db_queries._query_table("test_table", {"column1": [1, 2, 3]})
        self._assert_sql(
            "SELECT * FROM test_table WHERE 1=1 AND column1 IN (?, ?, ?)", [1, 2, 3]
        )

    def test_query_table_with_additional_conditions(self):
        self.db_queries._query_table(
            "test_table",
            additional_conditions="AND column3 > ?",
            additional_params=[100],
        )
        self._assert_sql("SELECT * FROM test_table WHERE 1=1 AND column3 > ?", [100])

    def test_query_table_with_filters_and_conditions(self):
        # Filter params are expected to precede the additional params.
        self.db_queries._query_table(
            "test_table", {"column1": "value1"}, "AND column3 > ?", [100]
        )
        self._assert_sql(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ? AND column3 > ?",
            ["value1", 100],
        )

    def tearDown(self):
        self.mock_query_runner.reset_mock()
99
+
100
+
101
class TestDatabaseQueries(unittest.TestCase):
    """
    Tests specific table querying with filters and conditions.

    Each test runs against a fresh in-memory SQLite database holding the
    tables created by ``_create_tables``; rows are inserted per test.
    """

    def setUp(self):
        self.connection = sqlite3.connect(":memory:")
        # Registered immediately: cleanups still run when setUp itself fails
        # (e.g. in _create_tables), whereas tearDown would be skipped.
        self.addCleanup(self.connection.close)
        self.db_queries = DatabaseQueries(connection=self.connection)
        self._create_tables()

    def _create_tables(self):
        """Create the report tables that the tests below insert into and query."""
        schema = """
            CREATE TABLE devices (
                device_id int,
                num_y_cores int,
                num_x_cores int,
                num_y_compute_cores int,
                num_x_compute_cores int,
                worker_l1_size int,
                l1_num_banks int,
                l1_bank_size int,
                address_at_first_l1_bank int,
                address_at_first_l1_cb_buffer int,
                num_banks_per_storage_core int,
                num_compute_cores int,
                num_storage_cores int,
                total_l1_memory int,
                total_l1_for_tensors int,
                total_l1_for_interleaved_buffers int,
                total_l1_for_sharded_buffers int,
                cb_limit int
            );
            CREATE TABLE captured_graph (
                operation_id int,
                captured_graph text
            );
            CREATE TABLE buffers (
                operation_id int,
                device_id int,
                address int,
                max_size_per_bank int,
                buffer_type int
            );
            CREATE TABLE tensors (
                tensor_id int UNIQUE,
                shape text,
                dtype text,
                layout text,
                memory_config text,
                device_id int,
                address int,
                buffer_type int
            );
            CREATE TABLE operation_arguments (
                operation_id int,
                name text,
                value text
            );
            CREATE TABLE stack_traces (
                operation_id int,
                stack_trace text
            );
            CREATE TABLE input_tensors (
                operation_id int,
                input_index int,
                tensor_id int
            );
            CREATE TABLE output_tensors (
                operation_id int,
                output_index int,
                tensor_id int
            );
            CREATE TABLE operations (
                operation_id int UNIQUE,
                name text,
                duration float
            );
            CREATE TABLE buffer_pages (
                operation_id INT,
                device_id INT,
                address INT,
                core_y INT,
                core_x INT,
                bank_id INT,
                page_index INT,
                page_address INT,
                page_size INT,
                buffer_type INT
            );
            CREATE TABLE local_tensor_comparison_records (
                tensor_id int,
                golden_tensor_id int,
                matches int,
                desired_pcc float,
                actual_pcc float
            );
            CREATE TABLE global_tensor_comparison_records (
                tensor_id int,
                golden_tensor_id int,
                matches int,
                desired_pcc float,
                actual_pcc float
            );

        """
        self.connection.executescript(schema)

    def test_init_with_valid_connection(self):
        connection = sqlite3.connect(":memory:")
        # Close the connection even if the assertion below fails.
        self.addCleanup(connection.close)
        db_queries = DatabaseQueries(connection=connection)
        self.assertIsInstance(db_queries.query_runner, LocalQueryRunner)

    def test_init_with_missing_session_and_connection(self):
        with self.assertRaises(ValueError) as context:
            DatabaseQueries(session=None, connection=None)
        # Checked outside the context manager: code after the raising call
        # inside the `with` body would never execute.
        self.assertIn(
            "Must provide either an existing connection or session",
            str(context.exception),
        )

    @patch("ttnn_visualizer.queries.get_client")
    def test_init_with_valid_remote_session(self, _mock_client):
        mock_session = Mock()
        mock_session.remote_connection = Mock(useRemoteQuerying=True)
        mock_session.remote_connection.sqliteBinaryPath = "/usr/bin/sqlite3"
        mock_session.remote_folder = Mock(remotePath="/remote/path")
        db_queries = DatabaseQueries(session=mock_session)
        self.assertIsInstance(db_queries.query_runner, RemoteQueryRunner)

    def test_init_with_valid_local_session(self):
        # NOTE: reopening a NamedTemporaryFile by name while it is still open
        # works on POSIX but not on Windows (per the tempfile documentation).
        with tempfile.NamedTemporaryFile(suffix=".sqlite") as temp_db_file:
            seed_connection = sqlite3.connect(temp_db_file.name)
            try:
                seed_connection.execute(
                    "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);"
                )
            finally:
                seed_connection.close()

            mock_session = Mock()
            mock_session.report_path = temp_db_file.name
            mock_session.remote_connection = None

            # NOTE(review): the LocalQueryRunner created here may hold its own
            # connection to the temp file; it is not closed by this test.
            db_queries = DatabaseQueries(session=mock_session)
            self.assertIsInstance(db_queries.query_runner, LocalQueryRunner)

    def test_init_with_invalid_session(self):
        mock_session = Mock()
        mock_session.report_path = None
        mock_session.remote_connection = None
        with self.assertRaises(ValueError) as context:
            DatabaseQueries(session=mock_session)
        self.assertIn(
            "Report path must be provided for local queries", str(context.exception)
        )

    def test_check_table_exists(self):
        self.assertTrue(self.db_queries._check_table_exists("devices"))
        self.assertFalse(self.db_queries._check_table_exists("nonexistent_table"))

    def test_query_device_operations(self):
        self.connection.execute(
            'INSERT INTO captured_graph VALUES (1, \'[{"counter": 1, "data": "value1"}]\')'
        )
        results = self.db_queries.query_device_operations(filters={"operation_id": 1})
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], DeviceOperation)

    def test_query_device_operations_table_missing(self):
        self.connection.execute("DROP TABLE captured_graph")
        results = self.db_queries.query_device_operations(filters={"operation_id": 1})
        self.assertEqual(results, [])

    def test_query_operation_arguments(self):
        self.connection.execute(
            "INSERT INTO operation_arguments VALUES (1, 'arg_name', 'arg_value')"
        )
        results = list(
            self.db_queries.query_operation_arguments(filters={"operation_id": 1})
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].name, "arg_name")

    def test_query_operations(self):
        self.connection.execute("INSERT INTO operations VALUES (1, 'op1', 2.0)")
        results = list(self.db_queries.query_operations(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].name, "op1")

    def test_query_buffers(self):
        self.connection.execute("INSERT INTO buffers VALUES (1, 1, 100, 1024, 0)")
        results = list(self.db_queries.query_buffers(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].address, 100)

    def test_query_stack_traces(self):
        self.connection.execute("INSERT INTO stack_traces VALUES (1, 'trace_data')")
        results = list(self.db_queries.query_stack_traces(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].stack_trace, "trace_data")

    def test_query_tensor_comparisons(self):
        self.connection.execute(
            """
            INSERT INTO local_tensor_comparison_records
            (tensor_id, golden_tensor_id, matches, desired_pcc, actual_pcc)
            VALUES (1, 10, 1, 0.9, 0.8)
            """
        )
        results = list(
            self.db_queries.query_tensor_comparisons(local=True, filters={"matches": 1})
        )

        self.assertEqual(len(results), 1)
        comparison = results[0]
        self.assertIsInstance(comparison, TensorComparisonRecord)
        self.assertEqual(comparison.tensor_id, 1)
        self.assertEqual(comparison.golden_tensor_id, 10)
        self.assertTrue(comparison.matches)
        self.assertAlmostEqual(comparison.desired_pcc, 0.9)
        self.assertAlmostEqual(comparison.actual_pcc, 0.8)

    def test_query_buffer_pages(self):
        self.connection.execute(
            "INSERT INTO buffer_pages VALUES (1, 1, 100, 0, 0, 1, 0, 1000, 4096, 0)"
        )
        results = list(self.db_queries.query_buffer_pages(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].address, 100)

    def test_query_tensors(self):
        self.connection.execute(
            "INSERT INTO tensors VALUES (1, '(2,2)', 'float32', 'NCHW', 'default', 1, 100, 0)"
        )
        results = list(self.db_queries.query_tensors(filters={"tensor_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].tensor_id, 1)

    def test_query_input_tensors(self):
        self.connection.execute("INSERT INTO input_tensors VALUES (1, 0, 1)")
        results = list(self.db_queries.query_input_tensors(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].operation_id, 1)

    def test_query_output_tensors(self):
        self.connection.execute("INSERT INTO output_tensors VALUES (1, 0, 1)")
        results = list(
            self.db_queries.query_output_tensors(filters={"operation_id": 1})
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].operation_id, 1)

    def test_query_devices(self):
        self.connection.execute(
            "INSERT INTO devices VALUES (1, 4, 4, 2, 2, 1024, 4, 256, 0, 0, 1, 2, 2, 4096, 2048, 2048, 2048, 256)"
        )
        results = list(self.db_queries.query_devices(filters={"device_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].device_id, 1)

    def test_query_producers_consumers(self):
        # Tensor 1 is an output of operation 1 and an input of operation 2.
        self.connection.execute(
            "INSERT INTO tensors VALUES (1, '(2,2)', 'float32', 'NCHW', 'default', 1, 100, 0)"
        )
        self.connection.execute("INSERT INTO input_tensors VALUES (2, 0, 1)")
        self.connection.execute("INSERT INTO output_tensors VALUES (1, 0, 1)")

        results = list(self.db_queries.query_producers_consumers())

        self.assertEqual(len(results), 1)
        pc = results[0]
        self.assertEqual(pc.tensor_id, 1)
        self.assertIn(1, pc.producers)
        self.assertIn(2, pc.consumers)

    def test_query_next_buffer(self):
        # Two buffers share address 100; the "next" one after op 1 is op 2.
        self.connection.execute("INSERT INTO buffers VALUES (1, 1, 100, 1024, 0)")
        self.connection.execute("INSERT INTO buffers VALUES (2, 1, 100, 2048, 0)")
        result = self.db_queries.query_next_buffer(operation_id=1, address=100)
        self.assertIsNotNone(result)
        self.assertEqual(result.operation_id, 2)
382
+
383
+
384
class TestRemoteQueryRunner(unittest.TestCase):
    """Exercise ``RemoteQueryRunner`` with ``get_client`` patched to a mock."""

    def setUp(self):
        session = Mock()
        remote = session.remote_connection
        remote.sqliteBinaryPath = "/usr/bin/sqlite3"
        remote.host = "mockhost"
        remote.user = "mockuser"
        session.remote_folder.remotePath = "/remote/db"
        self.mock_session = session

    @staticmethod
    def _install_ssh_client(mock_get_client):
        """Make the patched ``get_client`` return a fresh SSH client mock."""
        ssh_client = Mock()
        mock_get_client.return_value = ssh_client
        return ssh_client

    @patch("ttnn_visualizer.queries.get_client")
    def test_init_with_mock_get_client(self, mock_get_client):
        ssh_client = self._install_ssh_client(mock_get_client)

        runner = RemoteQueryRunner(session=self.mock_session)

        self.assertEqual(runner.ssh_client, ssh_client)
        mock_get_client.assert_called_once_with(
            remote_connection=self.mock_session.remote_connection
        )

    @patch("ttnn_visualizer.queries.get_client")
    def test_execute_query(self, mock_get_client):
        ssh_client = self._install_ssh_client(mock_get_client)

        # Simulate the remote sqlite3 binary printing one JSON row and no errors.
        stdout = Mock()
        stdout.read.return_value = b'[{"col1": "value1", "col2": "value2"}]'
        stderr = Mock()
        stderr.read.return_value = b""
        ssh_client.exec_command.return_value = (None, stdout, stderr)

        runner = RemoteQueryRunner(session=self.mock_session)
        results = runner.execute_query("SELECT * FROM table WHERE id = ?", [1])

        self.assertEqual(results, [("value1", "value2")])
        mock_get_client.assert_called_once()
        ssh_client.exec_command.assert_called_once()

    @patch("ttnn_visualizer.queries.get_client")
    def test_close(self, mock_get_client):
        ssh_client = self._install_ssh_client(mock_get_client)

        RemoteQueryRunner(session=self.mock_session).close()

        ssh_client.close.assert_called_once()
441
+
442
+
443
if __name__ == "__main__":
    # Allow running this module directly (python test_queries.py) in addition
    # to discovery via a test runner.
    unittest.main()