ttnn-visualizer 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttnn_visualizer/__init__.py +4 -0
- ttnn_visualizer/app.py +193 -0
- ttnn_visualizer/bin/docker-entrypoint-web +16 -0
- ttnn_visualizer/bin/pip3-install +17 -0
- ttnn_visualizer/csv_queries.py +618 -0
- ttnn_visualizer/decorators.py +117 -0
- ttnn_visualizer/enums.py +12 -0
- ttnn_visualizer/exceptions.py +40 -0
- ttnn_visualizer/extensions.py +14 -0
- ttnn_visualizer/file_uploads.py +78 -0
- ttnn_visualizer/models.py +275 -0
- ttnn_visualizer/queries.py +388 -0
- ttnn_visualizer/remote_sqlite_setup.py +91 -0
- ttnn_visualizer/requirements.txt +24 -0
- ttnn_visualizer/serializers.py +249 -0
- ttnn_visualizer/sessions.py +245 -0
- ttnn_visualizer/settings.py +118 -0
- ttnn_visualizer/sftp_operations.py +486 -0
- ttnn_visualizer/sockets.py +118 -0
- ttnn_visualizer/ssh_client.py +85 -0
- ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
- ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
- ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
- ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
- ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
- ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
- ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
- ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
- ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
- ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
- ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
- ttnn_visualizer/static/favicon/favicon.svg +3 -0
- ttnn_visualizer/static/index.html +36 -0
- ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
- ttnn_visualizer/tests/__init__.py +4 -0
- ttnn_visualizer/tests/test_queries.py +444 -0
- ttnn_visualizer/tests/test_serializers.py +582 -0
- ttnn_visualizer/utils.py +185 -0
- ttnn_visualizer/views.py +794 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
- ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
- ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
- ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
- ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
- ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,444 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
#
|
3
|
+
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
|
4
|
+
|
5
|
+
|
6
|
+
import tempfile
|
7
|
+
import sqlite3
|
8
|
+
import tempfile
|
9
|
+
import unittest
|
10
|
+
from unittest.mock import Mock
|
11
|
+
from unittest.mock import patch
|
12
|
+
|
13
|
+
from ttnn_visualizer.models import (
|
14
|
+
DeviceOperation,
|
15
|
+
TensorComparisonRecord,
|
16
|
+
)
|
17
|
+
from ttnn_visualizer.queries import DatabaseQueries
|
18
|
+
from ttnn_visualizer.queries import LocalQueryRunner
|
19
|
+
from ttnn_visualizer.queries import RemoteQueryRunner
|
20
|
+
|
21
|
+
|
22
|
+
class TestQueryTable(unittest.TestCase):
    """Tests query construction logic"""

    def setUp(self):
        # Replace the real runner with a mock so each test can inspect the exact
        # SQL text and parameter list that ``_query_table`` hands to it.
        self.db_queries = DatabaseQueries(connection=Mock())
        self.mock_query_runner = Mock()
        self.db_queries.query_runner = self.mock_query_runner

    def tearDown(self):
        self.mock_query_runner.reset_mock()

    def _assert_executed(self, expected_sql, expected_params):
        """Assert the runner received exactly one query with these arguments."""
        self.mock_query_runner.execute_query.assert_called_once_with(
            expected_sql, expected_params
        )

    def test_query_table_no_filters_or_conditions(self):
        self.db_queries._query_table("test_table")
        self._assert_executed("SELECT * FROM test_table WHERE 1=1", [])

    def test_query_table_single_filter(self):
        self.db_queries._query_table("test_table", {"column1": "value1"})
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ?", ["value1"]
        )

    def test_query_table_multiple_filters(self):
        self.db_queries._query_table(
            "test_table", {"column1": "value1", "column2": 42}
        )
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ? AND column2 = ?",
            ["value1", 42],
        )

    def test_query_table_filter_with_none_value(self):
        # A filter whose value is None is expected to produce no clause.
        self.db_queries._query_table(
            "test_table", {"column1": "value1", "column2": None}
        )
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ?", ["value1"]
        )

    def test_query_table_empty_list_filter(self):
        # An empty list filter is expected to be dropped rather than emit "IN ()".
        self.db_queries._query_table("test_table", {"column1": []})
        self._assert_executed("SELECT * FROM test_table WHERE 1=1", [])

    def test_query_table_list_based_filter(self):
        self.db_queries._query_table("test_table", {"column1": [1, 2, 3]})
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column1 IN (?, ?, ?)", [1, 2, 3]
        )

    def test_query_table_with_additional_conditions(self):
        self.db_queries._query_table(
            "test_table",
            additional_conditions="AND column3 > ?",
            additional_params=[100],
        )
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column3 > ?", [100]
        )

    def test_query_table_with_filters_and_conditions(self):
        # Positional form: (table, filters, additional_conditions, additional_params).
        self.db_queries._query_table(
            "test_table", {"column1": "value1"}, "AND column3 > ?", [100]
        )
        self._assert_executed(
            "SELECT * FROM test_table WHERE 1=1 AND column1 = ? AND column3 > ?",
            ["value1", 100],
        )
|
99
|
+
|
100
|
+
|
101
|
+
class TestDatabaseQueries(unittest.TestCase):
    """
    Tests specific table querying with filters and conditions
    """

    def setUp(self):
        self.connection = sqlite3.connect(":memory:")
        # addCleanup (unlike tearDown) also runs when a later setUp step such
        # as _create_tables raises, so the connection is never leaked.
        self.addCleanup(self.connection.close)
        self.db_queries = DatabaseQueries(connection=self.connection)
        self._create_tables()

    def _create_tables(self):
        """Create the report schema used by the query tests in the in-memory DB."""
        schema = """
        CREATE TABLE devices (
            device_id int,
            num_y_cores int,
            num_x_cores int,
            num_y_compute_cores int,
            num_x_compute_cores int,
            worker_l1_size int,
            l1_num_banks int,
            l1_bank_size int,
            address_at_first_l1_bank int,
            address_at_first_l1_cb_buffer int,
            num_banks_per_storage_core int,
            num_compute_cores int,
            num_storage_cores int,
            total_l1_memory int,
            total_l1_for_tensors int,
            total_l1_for_interleaved_buffers int,
            total_l1_for_sharded_buffers int,
            cb_limit int
        );
        CREATE TABLE captured_graph (
            operation_id int,
            captured_graph text
        );
        CREATE TABLE buffers (
            operation_id int,
            device_id int,
            address int,
            max_size_per_bank int,
            buffer_type int
        );
        CREATE TABLE tensors (
            tensor_id int UNIQUE,
            shape text,
            dtype text,
            layout text,
            memory_config text,
            device_id int,
            address int,
            buffer_type int
        );
        CREATE TABLE operation_arguments (
            operation_id int,
            name text,
            value text
        );
        CREATE TABLE stack_traces (
            operation_id int,
            stack_trace text
        );
        CREATE TABLE input_tensors (
            operation_id int,
            input_index int,
            tensor_id int
        );
        CREATE TABLE output_tensors (
            operation_id int,
            output_index int,
            tensor_id int
        );
        CREATE TABLE operations (
            operation_id int UNIQUE,
            name text,
            duration float
        );
        CREATE TABLE buffer_pages (
            operation_id INT,
            device_id INT,
            address INT,
            core_y INT,
            core_x INT,
            bank_id INT,
            page_index INT,
            page_address INT,
            page_size INT,
            buffer_type INT
        );
        CREATE TABLE local_tensor_comparison_records (
            tensor_id int,
            golden_tensor_id int,
            matches int,
            desired_pcc float,
            actual_pcc float
        );
        CREATE TABLE global_tensor_comparison_records (
            tensor_id int,
            golden_tensor_id int,
            matches int,
            desired_pcc float,
            actual_pcc float
        );

        """
        self.connection.executescript(schema)

    def test_init_with_valid_connection(self):
        connection = sqlite3.connect(":memory:")
        # Close the connection even if the assertion below fails.
        self.addCleanup(connection.close)
        db_queries = DatabaseQueries(connection=connection)
        self.assertIsInstance(db_queries.query_runner, LocalQueryRunner)

    def test_init_with_missing_session_and_connection(self):
        with self.assertRaises(ValueError) as context:
            DatabaseQueries(session=None, connection=None)
        self.assertIn(
            "Must provide either an existing connection or session",
            str(context.exception),
        )

    @patch("ttnn_visualizer.queries.get_client")
    def test_init_with_valid_remote_session(self, _mock_client):
        mock_session = Mock()
        mock_session.remote_connection = Mock(useRemoteQuerying=True)
        mock_session.remote_connection.sqliteBinaryPath = "/usr/bin/sqlite3"
        mock_session.remote_folder = Mock(remotePath="/remote/path")
        db_queries = DatabaseQueries(session=mock_session)
        self.assertIsInstance(db_queries.query_runner, RemoteQueryRunner)

    def test_init_with_valid_local_session(self):
        import os

        # Reopening a still-open NamedTemporaryFile by name is not supported on
        # Windows, so the database is created inside a temporary directory.
        temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(temp_dir.cleanup)
        db_path = os.path.join(temp_dir.name, "report.sqlite")

        connection = sqlite3.connect(db_path)
        connection.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);")
        connection.close()

        mock_session = Mock()
        mock_session.report_path = db_path
        mock_session.remote_connection = None

        db_queries = DatabaseQueries(session=mock_session)
        # Cleanups run LIFO: this close runs before the directory is removed,
        # releasing the file handle first.
        self.addCleanup(db_queries.query_runner.close)
        self.assertIsInstance(db_queries.query_runner, LocalQueryRunner)

    def test_init_with_invalid_session(self):
        mock_session = Mock()
        mock_session.report_path = None
        mock_session.remote_connection = None
        with self.assertRaises(ValueError) as context:
            DatabaseQueries(session=mock_session)
        self.assertIn(
            "Report path must be provided for local queries", str(context.exception)
        )

    def test_check_table_exists(self):
        self.assertTrue(self.db_queries._check_table_exists("devices"))
        self.assertFalse(self.db_queries._check_table_exists("nonexistent_table"))

    def test_query_device_operations(self):
        self.connection.execute(
            'INSERT INTO captured_graph VALUES (1, \'[{"counter": 1, "data": "value1"}]\')'
        )
        results = self.db_queries.query_device_operations(filters={"operation_id": 1})
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], DeviceOperation)

    def test_query_device_operations_table_missing(self):
        # A report without a captured_graph table should yield no operations.
        self.connection.execute("DROP TABLE captured_graph")
        results = self.db_queries.query_device_operations(filters={"operation_id": 1})
        self.assertEqual(results, [])

    def test_query_operation_arguments(self):
        self.connection.execute(
            "INSERT INTO operation_arguments VALUES (1, 'arg_name', 'arg_value')"
        )
        results = list(
            self.db_queries.query_operation_arguments(filters={"operation_id": 1})
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].name, "arg_name")

    def test_query_operations(self):
        self.connection.execute("INSERT INTO operations VALUES (1, 'op1', 2.0)")
        results = list(self.db_queries.query_operations(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].name, "op1")

    def test_query_buffers(self):
        self.connection.execute("INSERT INTO buffers VALUES (1, 1, 100, 1024, 0)")
        results = list(self.db_queries.query_buffers(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].address, 100)

    def test_query_stack_traces(self):
        self.connection.execute("INSERT INTO stack_traces VALUES (1, 'trace_data')")
        results = list(self.db_queries.query_stack_traces(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].stack_trace, "trace_data")

    def test_query_tensor_comparisons(self):
        self.connection.execute(
            """
            INSERT INTO local_tensor_comparison_records
            (tensor_id, golden_tensor_id, matches, desired_pcc, actual_pcc)
            VALUES (1, 10, 1, 0.9, 0.8)
            """
        )
        results = list(
            self.db_queries.query_tensor_comparisons(local=True, filters={"matches": 1})
        )

        self.assertEqual(len(results), 1)
        comparison = results[0]
        self.assertIsInstance(comparison, TensorComparisonRecord)
        self.assertEqual(comparison.tensor_id, 1)
        self.assertEqual(comparison.golden_tensor_id, 10)
        self.assertTrue(comparison.matches)
        self.assertAlmostEqual(comparison.desired_pcc, 0.9)
        self.assertAlmostEqual(comparison.actual_pcc, 0.8)

    def test_query_buffer_pages(self):
        self.connection.execute(
            "INSERT INTO buffer_pages VALUES (1, 1, 100, 0, 0, 1, 0, 1000, 4096, 0)"
        )
        results = list(self.db_queries.query_buffer_pages(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].address, 100)

    def test_query_tensors(self):
        self.connection.execute(
            "INSERT INTO tensors VALUES (1, '(2,2)', 'float32', 'NCHW', 'default', 1, 100, 0)"
        )
        results = list(self.db_queries.query_tensors(filters={"tensor_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].tensor_id, 1)

    def test_query_input_tensors(self):
        self.connection.execute("INSERT INTO input_tensors VALUES (1, 0, 1)")
        results = list(self.db_queries.query_input_tensors(filters={"operation_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].operation_id, 1)

    def test_query_output_tensors(self):
        self.connection.execute("INSERT INTO output_tensors VALUES (1, 0, 1)")
        results = list(
            self.db_queries.query_output_tensors(filters={"operation_id": 1})
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].operation_id, 1)

    def test_query_devices(self):
        self.connection.execute(
            "INSERT INTO devices VALUES (1, 4, 4, 2, 2, 1024, 4, 256, 0, 0, 1, 2, 2, 4096, 2048, 2048, 2048, 256)"
        )
        results = list(self.db_queries.query_devices(filters={"device_id": 1}))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].device_id, 1)

    def test_query_producers_consumers(self):
        # Tensor 1 is the output of operation 1 and an input of operation 2.
        self.connection.execute(
            "INSERT INTO tensors VALUES (1, '(2,2)', 'float32', 'NCHW', 'default', 1, 100, 0)"
        )
        self.connection.execute("INSERT INTO input_tensors VALUES (2, 0, 1)")
        self.connection.execute("INSERT INTO output_tensors VALUES (1, 0, 1)")

        results = list(self.db_queries.query_producers_consumers())

        self.assertEqual(len(results), 1)
        pc = results[0]
        self.assertEqual(pc.tensor_id, 1)
        self.assertIn(1, pc.producers)
        self.assertIn(2, pc.consumers)

    def test_query_next_buffer(self):
        # Two buffers at the same address; the one after operation 1 is op 2.
        self.connection.execute("INSERT INTO buffers VALUES (1, 1, 100, 1024, 0)")
        self.connection.execute("INSERT INTO buffers VALUES (2, 1, 100, 2048, 0)")
        result = self.db_queries.query_next_buffer(operation_id=1, address=100)
        self.assertIsNotNone(result)
        self.assertEqual(result.operation_id, 2)
|
382
|
+
|
383
|
+
|
384
|
+
class TestRemoteQueryRunner(unittest.TestCase):
    """Tests RemoteQueryRunner with ``get_client`` patched to a mock SSH client."""

    def setUp(self):
        self.mock_session = Mock()
        self.mock_session.remote_connection.sqliteBinaryPath = "/usr/bin/sqlite3"
        self.mock_session.remote_connection.host = "mockhost"
        self.mock_session.remote_connection.user = "mockuser"
        self.mock_session.remote_folder.remotePath = "/remote/db"

    @staticmethod
    def _install_ssh_client(mock_get_client):
        """Make the patched ``get_client`` return a fresh mock SSH client."""
        mock_ssh_client = Mock()
        mock_get_client.return_value = mock_ssh_client
        return mock_ssh_client

    @patch("ttnn_visualizer.queries.get_client")
    def test_init_with_mock_get_client(self, mock_get_client):
        mock_ssh_client = self._install_ssh_client(mock_get_client)

        runner = RemoteQueryRunner(session=self.mock_session)

        self.assertEqual(runner.ssh_client, mock_ssh_client)
        mock_get_client.assert_called_once_with(
            remote_connection=self.mock_session.remote_connection
        )

    @patch("ttnn_visualizer.queries.get_client")
    def test_execute_query(self, mock_get_client):
        mock_ssh_client = self._install_ssh_client(mock_get_client)

        # exec_command is mocked to return a (stdin, stdout, stderr) triple:
        # stdout carries JSON rows, stderr is empty.
        mock_stdout = Mock()
        mock_stdout.read.return_value = b'[{"col1": "value1", "col2": "value2"}]'
        mock_stderr = Mock()
        mock_stderr.read.return_value = b""
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        runner = RemoteQueryRunner(session=self.mock_session)

        query = "SELECT * FROM table WHERE id = ?"
        params = [1]
        results = runner.execute_query(query, params)

        # Each JSON row object is expected back as a tuple in column order.
        self.assertEqual(results, [("value1", "value2")])
        mock_get_client.assert_called_once()
        mock_ssh_client.exec_command.assert_called_once()

    @patch("ttnn_visualizer.queries.get_client")
    def test_close(self, mock_get_client):
        mock_ssh_client = self._install_ssh_client(mock_get_client)

        runner = RemoteQueryRunner(session=self.mock_session)

        runner.close()
        mock_ssh_client.close.assert_called_once()
|
441
|
+
|
442
|
+
|
443
|
+
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python test_queries.py``.
    unittest.main()
|