ttnn-visualizer 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttnn_visualizer/__init__.py +4 -0
- ttnn_visualizer/app.py +193 -0
- ttnn_visualizer/bin/docker-entrypoint-web +16 -0
- ttnn_visualizer/bin/pip3-install +17 -0
- ttnn_visualizer/csv_queries.py +618 -0
- ttnn_visualizer/decorators.py +117 -0
- ttnn_visualizer/enums.py +12 -0
- ttnn_visualizer/exceptions.py +40 -0
- ttnn_visualizer/extensions.py +14 -0
- ttnn_visualizer/file_uploads.py +78 -0
- ttnn_visualizer/models.py +275 -0
- ttnn_visualizer/queries.py +388 -0
- ttnn_visualizer/remote_sqlite_setup.py +91 -0
- ttnn_visualizer/requirements.txt +24 -0
- ttnn_visualizer/serializers.py +249 -0
- ttnn_visualizer/sessions.py +245 -0
- ttnn_visualizer/settings.py +118 -0
- ttnn_visualizer/sftp_operations.py +486 -0
- ttnn_visualizer/sockets.py +118 -0
- ttnn_visualizer/ssh_client.py +85 -0
- ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
- ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
- ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
- ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
- ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
- ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
- ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
- ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
- ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
- ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
- ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
- ttnn_visualizer/static/favicon/favicon.svg +3 -0
- ttnn_visualizer/static/index.html +36 -0
- ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
- ttnn_visualizer/tests/__init__.py +4 -0
- ttnn_visualizer/tests/test_queries.py +444 -0
- ttnn_visualizer/tests/test_serializers.py +582 -0
- ttnn_visualizer/utils.py +185 -0
- ttnn_visualizer/views.py +794 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
- ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
- ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
- ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
- ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
- ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,618 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
#
|
3
|
+
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
|
4
|
+
import csv
import logging
import os
import shlex
import tempfile
from io import StringIO
from pathlib import Path
from typing import List, Dict, Union, Optional

import pandas as pd
from tt_perf_report import perf_report

from ttnn_visualizer.exceptions import DataFormatError
from ttnn_visualizer.models import TabSession
from ttnn_visualizer.ssh_client import get_client
|
17
|
+
|
18
|
+
|
19
|
+
class LocalCSVQueryRunner:
    """
    Run simple select/filter queries against a CSV file on the local disk.

    The file is loaded into a pandas DataFrame when the runner is entered as a
    context manager and the DataFrame is dropped again on exit; the query
    methods raise ``RuntimeError`` outside of that window.
    """

    def __init__(self, file_path: str, offset: int = 0):
        """
        :param file_path: Path of the CSV file to load.
        :param offset: Number of leading lines to skip; the first line after
            them is used as the header row.
        """
        self.file_path = file_path
        self.offset = offset
        self.df: Optional[pd.DataFrame] = None

    def __enter__(self):
        """Load the CSV file into ``self.df`` and return the runner."""
        self.df = pd.read_csv(self.file_path, skiprows=self.offset)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the loaded DataFrame."""
        self.df = None

    def _loaded_frame(self) -> pd.DataFrame:
        """Return the loaded DataFrame or raise if the runner was not entered."""
        if self.df is None:
            raise RuntimeError(
                "DataFrame is not loaded. Ensure the runner is used within a context."
            )
        return self.df

    def get_csv_header(self) -> Dict[str, int]:
        """
        Map every column name to its 1-based position.

        :return: Dictionary of column name to index.
        :raises RuntimeError: If the runner is used outside of a context.
        """
        frame = self._loaded_frame()
        return {col: idx for idx, col in enumerate(frame.columns, start=1)}

    def execute_query(
        self,
        columns: List[str],
        filters: Optional[Dict[str, Union[str, None]]] = None,
        as_dict: bool = False,
        limit: Optional[int] = None,
    ) -> Union[
        List[List[Optional[Union[str, float, int]]]],
        List[Dict[str, Optional[Union[str, float, int]]]],
    ]:
        """
        Executes a query on the loaded DataFrame with optional limit.

        :param columns: List of columns to select; an empty list (or None) selects all columns.
        :param filters: Dictionary of column-value pairs to filter the rows.
            A value of None matches rows where the column is missing (NaN).
        :param as_dict: Whether to return results as a list of dictionaries.
            Dictionary keys are the column names with spaces replaced by underscores.
        :param limit: Maximum number of rows to return (None means no limit).
        :return: List of lists or dictionaries containing the result rows,
            with every missing value reported as None.
        :raises RuntimeError: If the runner is used outside of a context.
        """
        df_filtered = self._loaded_frame()

        # Apply filters if provided
        if filters:
            for col, value in filters.items():
                if value is None:
                    df_filtered = df_filtered[df_filtered[col].isna()]
                else:
                    df_filtered = df_filtered[df_filtered[col] == value]

        # Select specified columns
        result_df = df_filtered[columns] if columns else df_filtered

        # Apply limit if specified
        if limit is not None:
            result_df = result_df.head(limit)

        # Replace NaN with None on plain Python values. Doing this outside the
        # DataFrame avoids the deprecated DataFrame.applymap and keeps pandas'
        # dtype inference from turning None back into NaN in numeric columns.
        rows = [
            [None if pd.isna(cell) else cell for cell in row]
            for row in result_df.itertuples(index=False, name=None)
        ]

        if as_dict:
            keys = [col.replace(" ", "_") for col in result_df.columns]
            return [dict(zip(keys, row)) for row in rows]

        return rows
|
94
|
+
|
95
|
+
|
96
|
+
class RemoteCSVQueryRunner:
    """
    Query a CSV file on a remote host by running ``head``/``awk``/``cat`` over SSH.

    Rows are split on ``sep`` without any CSV quoting rules, both by the remote
    ``awk`` filter and when parsing the output, so quoted fields that contain the
    separator are not supported.
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self, file_path: str, remote_connection, sep: str = ",", offset: int = 0
    ):
        """
        Initialize the RemoteCSVQueryRunner.

        :param file_path: Path to the remote file.
        :param remote_connection: RemoteConnection object for SSH access.
        :param sep: Separator used in the CSV file.
        :param offset: Number of lines to skip before treating the first valid line as headers.
        """
        self.file_path = file_path
        self.remote_connection = remote_connection
        self.sep = sep
        self.offset = offset
        self.ssh_client = get_client(remote_connection)

    @staticmethod
    def _sanitize(name: str) -> str:
        """Normalize a column name: trimmed, spaces as underscores, lower-case."""
        return name.strip().replace(" ", "_").lower()

    @staticmethod
    def _awk_literal(value) -> str:
        """Render a value as a double-quoted awk string literal with escapes applied."""
        text = (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
        )
        return f'"{text}"'

    def _quoted_path(self) -> str:
        """Return the remote file path quoted for safe use in a shell command."""
        return shlex.quote(str(self.file_path))

    def _header_command(self) -> str:
        """Build the shell command that prints only the header line."""
        return f"head -n {self.offset + 1} {self._quoted_path()} | tail -n 1"

    def _build_query_command(
        self,
        headers: List[str],
        filters: Optional[Dict[str, str]],
        limit: Optional[int],
    ) -> str:
        """Build the awk command that selects data rows matching the filters."""
        # Data rows start after the skipped lines and the header line.
        conditions = [f"NR > {self.offset + 1}"]
        for unsanitized_col, value in (filters or {}).items():
            sanitized_col = self._sanitize(unsanitized_col)
            if sanitized_col in headers:
                col_idx = headers.index(sanitized_col) + 1
                conditions.append(f"${col_idx} == {self._awk_literal(value)}")
            else:
                # Best effort: an unknown column is reported and its filter dropped.
                self._logger.warning(
                    "Column '%s' (sanitized: '%s') not found in headers.",
                    unsanitized_col,
                    sanitized_col,
                )

        program = f"{' && '.join(conditions)} {{print}}"
        # Every user-influenced piece is shell-quoted so that spaces or quotes
        # in the path, separator or filter values cannot alter the command.
        command = (
            f"awk -F{shlex.quote(self.sep)} {shlex.quote(program)} "
            f"{self._quoted_path()}"
        )
        if limit is not None:
            command += f" | head -n {int(limit)}"
        return command

    def _run(self, command: str, error_prefix: str) -> str:
        """
        Execute a command over SSH and return its stripped standard output.

        :raises RuntimeError: If the command wrote anything to standard error.
        """
        _, stdout, stderr = self.ssh_client.exec_command(command)
        output = stdout.read().decode("utf-8").strip()
        error = stderr.read().decode("utf-8").strip()
        if error:
            raise RuntimeError(f"{error_prefix}: {error}")
        return output

    def execute_query(
        self,
        filters: Optional[Dict[str, str]] = None,  # Allow unsanitized filter keys
        as_dict: bool = False,  # Convert rows to dictionaries if True
        limit: Optional[int] = None,
        columns=None,
    ) -> Union[List[List[str]], List[Dict[str, str]]]:
        """
        Fetch rows with optional filtering and limit, returning either raw rows or dictionaries.

        :param filters: Dictionary of unsanitized column filters (e.g., {"zone name": "BRISC-FW"}).
            Filters on columns missing from the header are logged and ignored.
        :param as_dict: Whether to return results as a list of dictionaries keyed by
            sanitized (lower-case, underscore separated) header names.
        :param limit: Maximum number of rows to return (None means no limit).
        :param columns: Optional column names to keep; only applied when as_dict is True.
        :return: List of rows as lists or dictionaries.
        :raises RuntimeError: If a remote command reports an error.
        """
        # Fetch header row, accounting for the offset
        raw_header = self._run(self._header_command(), "Error fetching header row")
        headers = [self._sanitize(col) for col in raw_header.split(self.sep)]

        output = self._run(
            self._build_query_command(headers, filters, limit),
            "Error executing AWK command",
        )

        # Split rows into lists of strings
        rows = [
            [field.strip().strip('"') for field in line.split(self.sep)]
            for line in output.splitlines()
        ]
        if not as_dict:
            return rows

        result = [dict(zip(headers, row)) for row in rows]
        if columns:
            wanted = {self._sanitize(col) for col in columns}
            result = [
                {key: value for key, value in row.items() if key in wanted}
                for row in result
            ]
        return result

    def execute_query_raw(self, limit: Optional[int] = None) -> List[str]:
        """
        Fetch raw lines from the remote CSV file, accounting for the offset.

        :param limit: Maximum number of lines to return after the skipped offset
            lines (the header line counts towards the limit); None returns the whole file.
        :return: List of raw rows as strings.
        :raises RuntimeError: If the remote command reports an error.
        """
        if limit is not None:
            cmd = f"head -n {self.offset + int(limit)} {self._quoted_path()}"
        else:
            cmd = f"cat {self._quoted_path()}"
        output = self._run(cmd, "Error fetching raw rows")
        return output.splitlines()[self.offset :]

    def get_csv_header(self) -> Dict[str, int]:
        """
        Retrieve the CSV headers as a dictionary mapping column names to their indices (1-based).

        :return: Dictionary of headers.
        :raises RuntimeError: If the remote command reports an error.
        """
        header = self._run(self._header_command(), "Error reading CSV header")
        # Trim spaces in header names
        column_names = [name.strip() for name in header.split(self.sep)]
        return {name: idx for idx, name in enumerate(column_names, start=1)}

    def build_awk_filter(
        self, column_indices: Dict[str, int], filters: Dict[str, str]
    ) -> str:
        """
        Build an awk condition that requires every filter to match.

        :param column_indices: Mapping of column name to 1-based field index.
        :param filters: Mapping of column name to the exact value required.
        :return: Conditions joined with ``&&``, or an empty string without filters.
        """
        if not filters:
            return ""
        conditions = [
            f"${column_indices[col]} == {self._awk_literal(val)}"
            for col, val in filters.items()
        ]
        return " && ".join(conditions)

    def build_awk_columns(
        self, column_indices: Dict[str, int], columns: List[str]
    ) -> str:
        """
        Build the awk field list (e.g. ``$1, $3``) for the given columns.

        :param column_indices: Mapping of column name to 1-based field index.
        :param columns: Column names to select, in output order.
        """
        return ", ".join(f"${column_indices[col]}" for col in columns)

    def __enter__(self):
        """
        Enable usage with context management.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Clean up the SSH connection when exiting context.
        """
        if self.ssh_client:
            self.ssh_client.close()
|
260
|
+
|
261
|
+
|
262
|
+
class DeviceLogProfilerQueries:
    """
    Query helper for a session's device profiler log (``profile_log_device.csv``).

    Use it as a context manager: entering selects and opens a CSV query runner
    (remote over SSH or local on disk), leaving closes that runner again.
    """

    DEVICE_LOG_FILE = "profile_log_device.csv"
    LOCAL_PROFILER_DIRECTORY = "profiler"
    # Column labels imposed on the locally loaded device log.
    DEVICE_LOG_COLUMNS = [
        "PCIe slot",
        "core_x",
        "core_y",
        "RISC processor type",
        "timer_id",
        "time[cycles since reset]",
        "stat value",
        "run ID",
        "run host ID",
        "zone name",
        "zone phase",
        "source line",
        "source file",
    ]

    def __init__(self, session: TabSession):
        """
        Remember the session; the runner itself is only created on ``__enter__``.

        :param session: Session describing where the profiler data lives.
        """
        self.runner = None
        self.session = session

    def __enter__(self):
        """
        Create and enter the query runner that matches the session.

        A remote runner is used only when the session has a remote connection
        with ``useRemoteQuerying`` enabled; every other session reads the log
        from the local profiler path and relabels its columns with
        ``DEVICE_LOG_COLUMNS``.
        """
        connection = self.session.remote_connection

        if connection and connection.useRemoteQuerying:
            remote_folder = self.session.remote_profile_folder
            self.runner = RemoteCSVQueryRunner(
                file_path=f"{remote_folder.remotePath}/{self.DEVICE_LOG_FILE}",
                remote_connection=self.session.remote_connection,
                offset=1,  # Skip the first line for device log files
            )
            self.runner.__enter__()
            return self

        local_log = Path(self.session.profiler_path) / self.DEVICE_LOG_FILE
        self.runner = LocalCSVQueryRunner(
            file_path=local_log,
            offset=1,  # Skip the first line for device log files
        )
        self.runner.__enter__()

        # Only the local runner holds a DataFrame whose labels can be replaced.
        frame = self.runner.df
        frame.columns = self.DEVICE_LOG_COLUMNS
        frame.columns = frame.columns.str.strip()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Leave the runner's context, if a runner was created.
        """
        if not self.runner:
            return
        self.runner.__exit__(exc_type, exc_val, exc_tb)

    def query_by_timer_id(
        self, timer_id: str, as_dict: bool = False
    ) -> Union[List[List[str]], List[Dict[str, str]]]:
        """
        Return all rows whose ``timer_id`` column equals the given value.

        :param timer_id: Value the ``timer_id`` column must match.
        :param as_dict: Return dictionaries instead of lists when True.
        """
        timer_filter = {"timer_id": timer_id}
        return self.runner.execute_query(
            columns=[], filters=timer_filter, as_dict=as_dict
        )

    def query_zone_statistics(
        self, zone_name: str, as_dict: bool = False, limit: int = None
    ) -> Union[List[List[str]], List[Dict[str, str]]]:
        """
        Return the rows recorded for one zone name.

        :param zone_name: Value the ``zone name`` column must match.
        :param as_dict: Return dictionaries instead of lists when True.
        :param limit: Maximum number of rows to return.
        """
        zone_filter = {"zone name": zone_name}
        return self.runner.execute_query(
            columns=[], filters=zone_filter, as_dict=as_dict, limit=limit
        )

    def get_all_entries(
        self, as_dict: bool = False, limit: int = None
    ) -> List[List[str]]:
        """
        Return every row of the device log, restricted to ``DEVICE_LOG_COLUMNS``.

        :param as_dict: Return dictionaries instead of lists when True.
        :param limit: Maximum number of rows to return.
        """
        return self.runner.execute_query(
            columns=self.DEVICE_LOG_COLUMNS,
            as_dict=as_dict,
            limit=limit,
        )

    @staticmethod
    def get_raw_csv(session: TabSession):
        """
        Return the whole device log as text.

        The file is read over the remote connection when the session has one
        with ``useRemoteQuerying`` enabled, otherwise from the local profiler path.
        """
        from ttnn_visualizer.sftp_operations import read_remote_file

        connection = session.remote_connection
        if connection and connection.useRemoteQuerying:
            profiler_folder = session.remote_profile_folder
            return read_remote_file(
                session.remote_connection,
                f"{profiler_folder.remotePath}/{DeviceLogProfilerQueries.DEVICE_LOG_FILE}",
            )

        local_log = Path(
            session.profiler_path, DeviceLogProfilerQueries.DEVICE_LOG_FILE
        )
        with open(local_log, "r") as handle:
            return handle.read()
|
387
|
+
|
388
|
+
|
389
|
+
class OpsPerformanceQueries:
    """
    Query helper for a session's ops performance results CSV
    (``ops_perf_results_*.csv`` inside the profiler folder).

    Use it as a context manager: entering loads the newest local results file
    into a ``LocalCSVQueryRunner``, leaving releases it.
    """

    PERF_RESULTS_PREFIX = "ops_perf_results"
    # Column labels imposed on the loaded results file.
    PERF_RESULTS_COLUMNS = [
        "OP CODE",
        "OP TYPE",
        "GLOBAL CALL COUNT",
        "DEVICE ID",
        "ATTRIBUTES",
        "MATH FIDELITY",
        "CORE COUNT",
        "PARALLELIZATION STRATEGY",
        "HOST START TS",
        "HOST END TS",
        "HOST DURATION [ns]",
        "DEVICE FW START CYCLE",
        "DEVICE FW END CYCLE",
        "OP TO OP LATENCY [ns]",
        "DEVICE FW DURATION [ns]",
        "DEVICE KERNEL DURATION [ns]",
        "DEVICE BRISC KERNEL DURATION [ns]",
        "DEVICE NCRISC KERNEL DURATION [ns]",
        "DEVICE TRISC0 KERNEL DURATION [ns]",
        "DEVICE TRISC1 KERNEL DURATION [ns]",
        "DEVICE TRISC2 KERNEL DURATION [ns]",
        "DEVICE ERISC KERNEL DURATION [ns]",
        "DEVICE COMPUTE CB WAIT FRONT [ns]",
        "DEVICE COMPUTE CB RESERVE BACK [ns]",
        "INPUT_0_W",
        "INPUT_0_Z",
        "INPUT_0_Y",
        "INPUT_0_X",
        "INPUT_0_LAYOUT",
        "INPUT_0_DATATYPE",
        "INPUT_0_MEMORY",
        "INPUT_1_W",
        "INPUT_1_Z",
        "INPUT_1_Y",
        "INPUT_1_X",
        "INPUT_1_LAYOUT",
        "INPUT_1_DATATYPE",
        "INPUT_1_MEMORY",
        "INPUT_2_W",
        "INPUT_2_Z",
        "INPUT_2_Y",
        "INPUT_2_X",
        "INPUT_2_LAYOUT",
        "INPUT_2_DATATYPE",
        "INPUT_2_MEMORY",
        "OUTPUT_0_W",
        "OUTPUT_0_Z",
        "OUTPUT_0_Y",
        "OUTPUT_0_X",
        "OUTPUT_0_LAYOUT",
        "OUTPUT_0_DATATYPE",
        "OUTPUT_0_MEMORY",
        "COMPUTE KERNEL SOURCE",
        "COMPUTE KERNEL HASH",
        "DATA MOVEMENT KERNEL SOURCE",
        "DATA MOVEMENT KERNEL HASH",
        "PM IDEAL [ns]",
        "PM COMPUTE [ns]",
        "PM BANDWIDTH [ns]",
        "PM REQ I BW",
        "PM REQ O BW",
        "CompileProgram_TT_HOST_FUNC [ns]",
        "HWCommandQueue_write_buffer_TT_HOST_FUNC [ns]",
    ]

    def __init__(self, session: TabSession):
        """
        Remember the session; the runner itself is only created on ``__enter__``.

        :param session: Session describing where the profiler data lives.
        """
        self.runner = None
        self.session = session

    def __enter__(self):
        """
        Load the newest local results file and relabel its columns.

        :return: This query object, ready for use.
        """
        newest_results = OpsPerformanceQueries.get_local_ops_perf_file_path(
            self.session
        )
        # NOTE(review): offset=1 makes the runner skip the file's first line and
        # read the following line as the header, whose labels are then replaced
        # below. If the results file starts directly with its header line, the
        # first data row is lost here - confirm against the file layout.
        self.runner = LocalCSVQueryRunner(file_path=newest_results, offset=1)
        self.runner.__enter__()

        # Set up columns
        frame = self.runner.df
        frame.columns = self.PERF_RESULTS_COLUMNS
        frame.columns = frame.columns.str.strip()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Leave the runner's context, if a runner was created.
        """
        if not self.runner:
            return
        self.runner.__exit__(exc_type, exc_val, exc_tb)

    @staticmethod
    def get_local_ops_perf_file_path(session):
        """
        Return the path of the results file with the greatest ``os.path.getctime``.

        :param session: Object whose ``profiler_path`` is the folder to search.
        :return: Path of the selected file, as a string.
        :raises FileNotFoundError: If no ``ops_perf_results_*.csv`` file exists.
        """
        pattern = f"{OpsPerformanceQueries.PERF_RESULTS_PREFIX}_*.csv"
        candidates = [*Path(session.profiler_path).glob(pattern)]
        if not candidates:
            raise FileNotFoundError("No performance results file found.")

        newest = max(candidates, key=os.path.getctime)
        return str(newest)

    @staticmethod
    def get_remote_ops_perf_file_path(session):
        """
        Resolve the results file inside the session's remote profiler folder.

        Delegates the ``ops_perf_results*`` glob to ``resolve_file_path``.
        """
        from ttnn_visualizer.sftp_operations import resolve_file_path

        remote_dir = session.remote_profile_folder.remotePath
        remote_pattern = f"{remote_dir}/{OpsPerformanceQueries.PERF_RESULTS_PREFIX}*"
        return resolve_file_path(session.remote_connection, remote_pattern)

    @staticmethod
    def get_raw_csv(session):
        """
        Return the whole results file as text.

        The file is read over the remote connection when the session has one
        with ``useRemoteQuerying`` enabled, otherwise from the local profiler path.
        """
        from ttnn_visualizer.sftp_operations import read_remote_file

        connection = session.remote_connection
        if connection and connection.useRemoteQuerying:
            remote_path = OpsPerformanceQueries.get_remote_ops_perf_file_path(session)
            return read_remote_file(session.remote_connection, remote_path)

        with open(OpsPerformanceQueries.get_local_ops_perf_file_path(session)) as handle:
            return handle.read()

    def query_by_op_code(
        self, op_code: str, as_dict: bool = False
    ) -> Union[List[List[str]], List[Dict[str, str]]]:
        """
        Return all rows whose ``OP CODE`` column equals the given value.

        :param op_code: Value the ``OP CODE`` column must match.
        :param as_dict: Return dictionaries instead of lists when True.
        """
        op_filter = {"OP CODE": op_code}
        return self.runner.execute_query(
            filters=op_filter, as_dict=as_dict, columns=None
        )

    def get_all_entries(
        self, as_dict: bool = False, limit: int = None
    ) -> List[List[str]]:
        """
        Return every row of the results file, restricted to ``PERF_RESULTS_COLUMNS``.

        :param as_dict: Return dictionaries instead of lists when True.
        :param limit: Maximum number of rows to return.
        """
        return self.runner.execute_query(
            columns=self.PERF_RESULTS_COLUMNS,
            as_dict=as_dict,
            limit=limit,
        )
|
545
|
+
|
546
|
+
|
547
|
+
class OpsPerformanceReportQueries:
    """
    Build the ops performance report for a session by running
    ``perf_report.generate_perf_report`` on the raw ops performance CSV and
    reading back the CSV it writes.
    """

    # Keys assigned, in order, to the cells of every row of the generated report.
    REPORT_COLUMNS = [
        "id",
        "total_percent",
        "bound",
        "op_code",
        "device_time",
        "op_to_op_gap",
        "cores",
        "dram",
        "dram_percent",
        "flops",
        "flops_percent",
        "math_fidelity",
        "output_datatype",
        "input_0_datatype",
        "input_1_datatype",
        "dram_sharded",
        "input_0_memory",
        "inner_dim_block_size",
        "output_subblock_h",
        "output_subblock_w",
        "advice",
        "raw_op_code",
    ]

    # Values forwarded positionally to perf_report.generate_perf_report.
    DEFAULT_SIGNPOST = None
    DEFAULT_IGNORE_SIGNPOSTS = None
    DEFAULT_MIN_PERCENTAGE = 0.5
    DEFAULT_ID_RANGE = None
    DEFAULT_NO_ADVICE = False
    DEFAULT_TRACING_MODE = False

    @classmethod
    def generate_report(cls, session):
        """
        Generate the performance report for a session.

        :param session: Session passed to ``OpsPerformanceQueries.get_raw_csv``.
        :return: List of dictionaries keyed by ``REPORT_COLUMNS``; ``advice`` is
            always a list of strings.
        :raises DataFormatError: If the generated report is not parseable as CSV.
        """
        raw_csv = OpsPerformanceQueries.get_raw_csv(session)
        csv_file = StringIO(raw_csv)

        # mkstemp creates the file atomically with a unique name; the deprecated
        # mktemp only returned a name that another process could claim first.
        fd, csv_output_file = tempfile.mkstemp(suffix=".csv")
        os.close(fd)

        # Generation happens inside the try block so the temporary file is
        # removed even when the report generator itself fails.
        try:
            perf_report.generate_perf_report(
                csv_file,
                cls.DEFAULT_SIGNPOST,
                cls.DEFAULT_IGNORE_SIGNPOSTS,
                cls.DEFAULT_MIN_PERCENTAGE,
                cls.DEFAULT_ID_RANGE,
                csv_output_file,
                cls.DEFAULT_NO_ADVICE,
                cls.DEFAULT_TRACING_MODE,
                # NOTE(review): the meaning of the last two positional flags is
                # defined by tt_perf_report - confirm against its signature.
                True,
                True,
            )
            return cls._read_report_rows(csv_output_file)
        finally:
            os.unlink(csv_output_file)

    @classmethod
    def _read_report_rows(cls, csv_path):
        """
        Parse a generated report CSV into dictionaries keyed by ``REPORT_COLUMNS``.

        The first line is skipped as the header. Rows shorter than
        ``REPORT_COLUMNS`` only get the keys they have cells for, and the
        ``advice`` cell is split on `` • `` into a list (empty when absent).

        :raises DataFormatError: If the file is not parseable as CSV.
        """
        report = []
        try:
            # NOTE(review): the file is opened with the platform default encoding;
            # confirm it matches how tt_perf_report writes the " • " separator.
            with open(csv_path, newline="") as csvfile:
                reader = csv.reader(csvfile, delimiter=",")
                next(reader, None)  # skip the header line
                for row in reader:
                    # zip stops at the shorter sequence, so short rows are fine.
                    processed_row = dict(zip(cls.REPORT_COLUMNS, row))
                    advice = processed_row.get("advice")
                    processed_row["advice"] = advice.split(" • ") if advice else []
                    report.append(processed_row)
        except csv.Error as e:
            raise DataFormatError() from e
        return report
|