ttnn-visualizer 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ttnn_visualizer/__init__.py +4 -0
  2. ttnn_visualizer/app.py +193 -0
  3. ttnn_visualizer/bin/docker-entrypoint-web +16 -0
  4. ttnn_visualizer/bin/pip3-install +17 -0
  5. ttnn_visualizer/csv_queries.py +618 -0
  6. ttnn_visualizer/decorators.py +117 -0
  7. ttnn_visualizer/enums.py +12 -0
  8. ttnn_visualizer/exceptions.py +40 -0
  9. ttnn_visualizer/extensions.py +14 -0
  10. ttnn_visualizer/file_uploads.py +78 -0
  11. ttnn_visualizer/models.py +275 -0
  12. ttnn_visualizer/queries.py +388 -0
  13. ttnn_visualizer/remote_sqlite_setup.py +91 -0
  14. ttnn_visualizer/requirements.txt +24 -0
  15. ttnn_visualizer/serializers.py +249 -0
  16. ttnn_visualizer/sessions.py +245 -0
  17. ttnn_visualizer/settings.py +118 -0
  18. ttnn_visualizer/sftp_operations.py +486 -0
  19. ttnn_visualizer/sockets.py +118 -0
  20. ttnn_visualizer/ssh_client.py +85 -0
  21. ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
  22. ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
  23. ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
  24. ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
  25. ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
  26. ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
  27. ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
  28. ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
  29. ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
  30. ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
  31. ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
  32. ttnn_visualizer/static/favicon/favicon.svg +3 -0
  33. ttnn_visualizer/static/index.html +36 -0
  34. ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
  35. ttnn_visualizer/tests/__init__.py +4 -0
  36. ttnn_visualizer/tests/test_queries.py +444 -0
  37. ttnn_visualizer/tests/test_serializers.py +582 -0
  38. ttnn_visualizer/utils.py +185 -0
  39. ttnn_visualizer/views.py +794 -0
  40. ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
  41. ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
  42. ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
  43. ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
  44. ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
  45. ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
  46. ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,794 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
4
+
5
+ import dataclasses
6
+ import json
7
+ import logging
8
+ import time
9
+ from http import HTTPStatus
10
+ from pathlib import Path
11
+ from typing import List
12
+
13
+ import yaml
14
+ from flask import Blueprint, Response, jsonify
15
+ from flask import request, current_app
16
+
17
+ from ttnn_visualizer.csv_queries import DeviceLogProfilerQueries, OpsPerformanceQueries, OpsPerformanceReportQueries
18
+ from ttnn_visualizer.decorators import with_session
19
+ from ttnn_visualizer.exceptions import DataFormatError
20
+ from ttnn_visualizer.enums import ConnectionTestStates
21
+ from ttnn_visualizer.exceptions import RemoteConnectionException
22
+ from ttnn_visualizer.file_uploads import (
23
+ extract_report_name,
24
+ save_uploaded_files,
25
+ validate_files,
26
+ )
27
+ from ttnn_visualizer.models import (
28
+ RemoteReportFolder,
29
+ RemoteConnection,
30
+ StatusMessage,
31
+ TabSession,
32
+ )
33
+ from ttnn_visualizer.queries import DatabaseQueries
34
+ from ttnn_visualizer.remote_sqlite_setup import get_sqlite_path, check_sqlite_path
35
+ from ttnn_visualizer.serializers import (
36
+ serialize_operations,
37
+ serialize_tensors,
38
+ serialize_operation,
39
+ serialize_buffer_pages,
40
+ serialize_operation_buffers,
41
+ serialize_operations_buffers,
42
+ serialize_devices,
43
+ )
44
+ from ttnn_visualizer.sessions import (
45
+ update_tab_session,
46
+ )
47
+ from ttnn_visualizer.sftp_operations import (
48
+ sync_remote_folders,
49
+ read_remote_file,
50
+ check_remote_path_for_reports,
51
+ get_remote_report_folders,
52
+ check_remote_path_exists,
53
+ get_remote_profiler_folders,
54
+ sync_remote_profiler_folders,
55
+ get_cluster_desc,
56
+ )
57
+ from ttnn_visualizer.ssh_client import get_client
58
+ from ttnn_visualizer.utils import (
59
+ read_last_synced_file,
60
+ timer,
61
+ )
62
+
63
# Module-scoped logger; records are attributed to this module's dotted name.
logger = logging.getLogger(__name__)

# Every view defined below is registered on this blueprint, so all of them
# are served beneath the "/api" URL prefix.
api = Blueprint("api", __name__, url_prefix="/api")
66
+
67
+
68
@api.route("/operations", methods=["GET"])
@with_session
@timer
def operation_list(session):
    """Return every operation in the active report, serialized for the UI.

    All related rows (arguments, device operations, stack traces, input and
    output tensors, tensors, devices and producer/consumer links) are read
    eagerly inside one ``DatabaseQueries`` context. Everything is then handed
    to ``serialize_operations``. Operations are ordered by ``operation_id``.
    """
    with DatabaseQueries(session) as db:
        ops = sorted(db.query_operations(), key=lambda op: op.operation_id)
        args = list(db.query_operation_arguments())
        device_ops = list(db.query_device_operations())
        traces = list(db.query_stack_traces())
        output_tensors = list(db.query_output_tensors())
        all_tensors = list(db.query_tensors())
        input_tensors = list(db.query_input_tensors())
        device_list = list(db.query_devices())
        links = list(db.query_producers_consumers())

        return serialize_operations(
            input_tensors,
            args,
            ops,
            output_tensors,
            traces,
            all_tensors,
            device_list,
            links,
            device_ops,
        )
95
+
96
+
97
@api.route("/operations/<operation_id>", methods=["GET"])
@with_session
@timer
def operation_detail(operation_id, session):
    """Return a single operation together with everything needed to render it.

    Query parameters:
        device_id: optional; forwarded as the ``device_id`` filter when
            fetching the operation's buffers.

    Returns:
        The payload built by ``serialize_operation``, or ``404`` when no
        operation with ``operation_id`` exists.
    """
    with DatabaseQueries(session) as db:
        device_id = request.args.get("device_id", None)
        operations = list(db.query_operations(filters={"operation_id": operation_id}))

        if not operations:
            return Response(status=HTTPStatus.NOT_FOUND)

        operation = operations[0]

        buffers = list(
            db.query_buffers(
                filters={"operation_id": operation_id, "device_id": device_id}
            )
        )
        operation_arguments = list(
            db.query_operation_arguments(filters={"operation_id": operation_id})
        )

        # Only the first stack trace recorded for the operation is returned.
        stack_traces = list(
            db.query_stack_traces(filters={"operation_id": operation_id})
        )
        stack_trace = stack_traces[0] if stack_traces else None

        inputs = list(db.query_input_tensors(filters={"operation_id": operation_id}))
        outputs = list(db.query_output_tensors(filters={"operation_id": operation_id}))

        tensor_ids = [i.tensor_id for i in inputs] + [o.tensor_id for o in outputs]
        tensors = list(db.query_tensors(filters={"tensor_id": tensor_ids}))
        local_comparisons = list(
            db.query_tensor_comparisons(filters={"tensor_id": tensor_ids})
        )
        global_comparisons = list(
            db.query_tensor_comparisons(local=False, filters={"tensor_id": tensor_ids})
        )

        # Materialize like every other query result here, so the serializer
        # receives a reusable list rather than a one-shot iterator.
        device_operations = list(
            db.query_device_operations(filters={"operation_id": operation_id})
        )

        # Set membership keeps this filter linear in the number of links.
        wanted_tensor_ids = set(tensor_ids)
        producers_consumers = [
            pc
            for pc in db.query_producers_consumers()
            if pc.tensor_id in wanted_tensor_ids
        ]

        devices = list(db.query_devices())

        return serialize_operation(
            buffers,
            inputs,
            operation,
            operation_arguments,
            outputs,
            stack_trace,
            tensors,
            global_comparisons,
            local_comparisons,
            devices,
            producers_consumers,
            device_operations,
        )
168
+
169
+
170
@api.route("operation-history", methods=["GET"])
@with_session
@timer
def operation_history(session: TabSession):
    """Return the decoded contents of the report's ``operation_history.json``.

    When the session uses remote querying, the file is read from the remote
    report folder. Otherwise it is read from the directory containing the
    local report. An empty list is returned when the file (or the remote
    folder) is unavailable.
    """
    history_name = "operation_history.json"
    remote = session.remote_connection

    if not (remote and remote.useRemoteQuerying):
        local_file = Path(str(session.report_path)).parent / history_name
        if not local_file.exists():
            return []
        with open(local_file, "r") as fh:
            return json.load(fh)

    if not session.remote_folder:
        return []
    raw = read_remote_file(
        remote_connection=remote,
        remote_path=Path(session.remote_folder.remotePath, history_name),
    )
    return json.loads(raw) if raw else []
200
+
201
+
202
@api.route("/config")
@with_session
@timer
def get_config(session: TabSession):
    """Return the report's ``config.json`` as a decoded JSON value.

    With remote querying enabled, the file is read from the session's remote
    report folder. Otherwise it is read next to the local report. An empty
    dict is returned whenever the file (or the remote folder) is unavailable.
    """
    if session.remote_connection and session.remote_connection.useRemoteQuerying:
        if not session.remote_folder:
            return {}
        config = read_remote_file(
            remote_connection=session.remote_connection,
            remote_path=Path(session.remote_folder.remotePath, "config.json"),
        )
        if not config:
            return {}
        # Decode the remote file so this branch returns the same shape as the
        # local branch below. operation_history treats its remote file the
        # same way.
        return json.loads(config)

    config_file = Path(str(session.report_path)).parent.joinpath("config.json")
    if not config_file.exists():
        return {}
    with open(config_file, "r") as file:
        return json.load(file)
222
+
223
+
224
@api.route("/tensors", methods=["GET"])
@with_session
@timer
def tensors_list(session: TabSession):
    """Return tensors (optionally filtered by ``device_id``) with related data.

    Local and global tensor comparisons and producer/consumer links are
    loaded unfiltered and passed to ``serialize_tensors`` alongside the tensors.
    """
    with DatabaseQueries(session) as db:
        device_filter = {"device_id": request.args.get("device_id", None)}
        found_tensors = list(db.query_tensors(filters=device_filter))
        local_cmp = list(db.query_tensor_comparisons())
        global_cmp = list(db.query_tensor_comparisons(local=False))
        links = list(db.query_producers_consumers())
        return serialize_tensors(found_tensors, links, local_cmp, global_cmp)
237
+
238
+
239
@api.route("/buffer", methods=["GET"])
@with_session
@timer
def buffer_detail(session: TabSession):
    """Return the buffer found by ``query_next_buffer`` for an address/operation.

    Query parameters (both required):
        address: buffer address, passed through unchanged.
        operation_id: decimal integer.

    Returns the buffer as a dict, ``400`` for missing/invalid parameters, or
    ``404`` when no buffer is found.
    """
    address = request.args.get("address")
    operation_id = request.args.get("operation_id")

    # str.isdecimal() only admits characters int() accepts. str.isdigit() is
    # also True for e.g. "²", which would make int() raise (a 500).
    if not address or not operation_id or not operation_id.isdecimal():
        return Response(status=HTTPStatus.BAD_REQUEST)
    operation_id = int(operation_id)

    with DatabaseQueries(session) as db:
        buffer = db.query_next_buffer(operation_id, address)
        if not buffer:
            return Response(status=HTTPStatus.NOT_FOUND)
        return dataclasses.asdict(buffer)
259
+
260
+
261
@api.route("/buffer-pages", methods=["GET"])
@with_session
@timer
def buffer_pages(session: TabSession):
    """Return buffer pages matching the optional query-string filters.

    Query parameters (all optional):
        address: comma-separated list of addresses; each entry is stripped.
        operation_id: passed through unchanged.
        buffer_type: decimal integer; ignored when not a plain decimal number.
        device_id: passed through unchanged.
    """
    address = request.args.get("address")
    operation_id = request.args.get("operation_id")
    buffer_type = request.args.get("buffer_type", "")
    device_id = request.args.get("device_id", None)

    addresses = [addr.strip() for addr in address.split(",")] if address else None

    # isdecimal() (unlike isdigit()) guarantees int() succeeds, so inputs
    # such as "²" are ignored instead of raising ValueError.
    buffer_type = int(buffer_type) if buffer_type.isdecimal() else None

    with DatabaseQueries(session) as db:
        buffers = list(
            db.query_buffer_pages(
                filters={
                    "operation_id": operation_id,
                    "device_id": device_id,
                    "address": addresses,
                    "buffer_type": buffer_type,
                }
            )
        )
        return serialize_buffer_pages(buffers)
294
+
295
+
296
@api.route("/tensors/<tensor_id>", methods=["GET"])
@with_session
@timer
def tensor_detail(tensor_id, session: TabSession):
    """Return the first tensor whose id matches ``tensor_id`` as a dict, else 404."""
    with DatabaseQueries(session) as db:
        matches = list(db.query_tensors(filters={"tensor_id": tensor_id}))
        if matches:
            return dataclasses.asdict(matches[0])
        return Response(status=HTTPStatus.NOT_FOUND)
306
+
307
+
308
@api.route("/operation-buffers", methods=["GET"])
@with_session
def get_operations_buffers(session: TabSession):
    """Return every operation together with the buffers matching the filters.

    Query parameters (all optional):
        buffer_type: decimal integer; ignored when not a plain decimal number.
        device_id: passed through unchanged.
    """
    raw_buffer_type = request.args.get("buffer_type", "")
    device_id = request.args.get("device_id", None)
    # isdecimal() guarantees int() succeeds. isdigit() also accepts characters
    # such as "²" that int() rejects.
    buffer_type = int(raw_buffer_type) if raw_buffer_type.isdecimal() else None

    with DatabaseQueries(session) as db:
        buffers = list(
            db.query_buffers(
                filters={"buffer_type": buffer_type, "device_id": device_id}
            )
        )
        operations = list(db.query_operations())
        return serialize_operations_buffers(operations, buffers)
326
+
327
+
328
@api.route("/operation-buffers/<operation_id>", methods=["GET"])
@with_session
def get_operation_buffers(operation_id, session: TabSession):
    """Return one operation together with the buffers matching the filters.

    Query parameters (all optional):
        buffer_type: decimal integer; ignored when not a plain decimal number.
        device_id: passed through unchanged.

    Returns ``404`` when no operation with ``operation_id`` exists.
    """
    raw_buffer_type = request.args.get("buffer_type", "")
    device_id = request.args.get("device_id", None)
    # isdecimal() guarantees int() succeeds. isdigit() also accepts characters
    # such as "²" that int() rejects.
    buffer_type = int(raw_buffer_type) if raw_buffer_type.isdecimal() else None

    with DatabaseQueries(session) as db:
        operations = list(db.query_operations(filters={"operation_id": operation_id}))
        if not operations:
            return Response(status=HTTPStatus.NOT_FOUND)
        operation = operations[0]

        buffers = list(
            db.query_buffers(
                filters={
                    "operation_id": operation_id,
                    "buffer_type": buffer_type,
                    "device_id": device_id,
                }
            )
        )
        return serialize_operation_buffers(operation, buffers)
355
+
356
+
357
@api.route("/profiler/device-log", methods=["GET"])
@with_session
def get_profiler_data(session: TabSession):
    """Return up to 100 device-log profiler entries as JSON objects.

    Responds ``404`` when the session has no profiler path.
    """
    if session.profiler_path:
        with DeviceLogProfilerQueries(session) as device_log:
            entries = device_log.get_all_entries(as_dict=True, limit=100)
        return jsonify(entries)
    return Response(status=HTTPStatus.NOT_FOUND)
365
+
366
+
367
@api.route("/profiler/perf-results", methods=["GET"])
@with_session
def get_profiler_performance_data(session: TabSession):
    """Return up to 100 ops-performance result rows as JSON objects.

    Responds ``404`` when the session has no profiler path.
    """
    if session.profiler_path:
        with OpsPerformanceQueries(session) as perf_results:
            rows = perf_results.get_all_entries(as_dict=True, limit=100)
        return jsonify(rows)
    return Response(status=HTTPStatus.NOT_FOUND)
376
+
377
+
378
@api.route("/profiler/perf-results/raw", methods=["GET"])
@with_session
def get_profiler_perf_results_data_raw(session: TabSession):
    """Return the raw ops-performance CSV as a file attachment.

    Responds ``404`` when the session has no profiler path.
    """
    if not session.profiler_path:
        return Response(status=HTTPStatus.NOT_FOUND)
    attachment = {"Content-Disposition": "attachment; filename=op_perf_results.csv"}
    csv_body = OpsPerformanceQueries.get_raw_csv(session)
    return Response(csv_body, mimetype="text/csv", headers=attachment)
389
+
390
+
391
@api.route("/profiler/perf-results/report", methods=["GET"])
@with_session
def get_profiler_perf_results_report(session: TabSession):
    """Return the generated ops-performance report as JSON.

    Responds ``404`` without a profiler path and ``422`` when the report
    generator raises ``DataFormatError``.
    """
    if not session.profiler_path:
        return Response(status=HTTPStatus.NOT_FOUND)

    try:
        generated = OpsPerformanceReportQueries.generate_report(session)
    except DataFormatError:
        return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY)
    else:
        return jsonify(generated), 200
403
+
404
+
405
@api.route("/profiler/device-log/raw", methods=["GET"])
@with_session
def get_profiler_data_raw(session: TabSession):
    """Return the raw device-log CSV as a file attachment.

    Responds ``404`` when the session has no profiler path.
    """
    if not session.profiler_path:
        return Response(status=HTTPStatus.NOT_FOUND)
    attachment = {"Content-Disposition": "attachment; filename=profile_log_device.csv"}
    csv_body = DeviceLogProfilerQueries.get_raw_csv(session)
    return Response(csv_body, mimetype="text/csv", headers=attachment)
416
+
417
+
418
@api.route("/profiler/device-log/zone/<zone>", methods=["GET"])
@with_session
def get_zone_statistics(zone, session: TabSession):
    """Return the result of ``query_zone_statistics`` for one zone as JSON.

    ``zone`` comes from the URL and is passed as ``zone_name``. Responds
    ``404`` when the session has no profiler path.
    """
    if not session.profiler_path:
        return Response(status=HTTPStatus.NOT_FOUND)
    with DeviceLogProfilerQueries(session) as device_log:
        stats = device_log.query_zone_statistics(zone_name=zone, as_dict=True)
    return jsonify(stats)
426
+
427
+
428
@api.route("/devices", methods=["GET"])
@with_session
def get_devices(session: TabSession):
    """Return every device recorded in the active report, serialized."""
    with DatabaseQueries(session) as db:
        return serialize_devices(list(db.query_devices()))
434
+
435
+
436
@api.route("/local/upload/report", methods=["POST"])
def create_report_files():
    """Store an uploaded report folder and make it the tab's active report.

    The multipart field ``files`` must contain ``db.sqlite`` and
    ``config.json`` (as checked by ``validate_files``). The ``tabId`` query
    parameter selects the tab session to update. Returns a serialized
    ``StatusMessage``.
    """
    uploaded = request.files.getlist("files")
    data_dir = current_app.config["LOCAL_DATA_DIRECTORY"]

    required_files = {"db.sqlite", "config.json"}
    if not validate_files(uploaded, required_files):
        failure = StatusMessage(
            status=ConnectionTestStates.FAILED,
            message="Invalid project directory.",
        )
        return failure.model_dump()

    name = extract_report_name(uploaded)
    logger.info(f"Writing report files to {data_dir}/{name}")
    save_uploaded_files(uploaded, data_dir, name)

    update_tab_session(
        tab_id=request.args.get("tabId"), report_name=name, clear_remote=True
    )

    success = StatusMessage(status=ConnectionTestStates.OK, message="Success.")
    return success.model_dump()
458
+
459
+
460
@api.route("/local/upload/profile", methods=["POST"])
def create_profile_files():
    """Store an uploaded profiler folder and make it the tab's active profile.

    Each file in the multipart field ``files`` is written below
    ``<LOCAL_DATA_DIRECTORY>/profiles`` at the relative path sent by the
    client. The first path component of the first file becomes the profile
    name. The ``tabId`` query parameter selects the tab session to update.

    Returns:
        A serialized ``StatusMessage``. It is FAILED when validation fails
        or an upload path is unsafe, OK otherwise.
    """
    files = request.files.getlist("files")
    report_directory = Path(current_app.config["LOCAL_DATA_DIRECTORY"])
    tab_id = request.args.get("tabId")

    if not validate_files(
        files,
        {"profile_log_device.csv", "tracy_profile_log_host.tracy"},
        pattern="ops_perf_results",
    ):
        return StatusMessage(
            status=ConnectionTestStates.FAILED,
            message="Invalid project directory.",
        ).model_dump()

    # Upload filenames are client-controlled. Joining an anchored path
    # (absolute, or drive-qualified on Windows) discards target_directory
    # entirely, and ".." walks out of it. Reject both before any directory
    # is created.
    relative_paths = [Path(file.filename) for file in files]
    if any(p.anchor or ".." in p.parts for p in relative_paths):
        return StatusMessage(
            status=ConnectionTestStates.FAILED,
            message="Invalid file path in upload.",
        ).model_dump()

    # All profiles are stored under a shared "profiles" subdirectory.
    target_directory = report_directory / "profiles"
    logger.info(f"Writing profile files to {target_directory}")
    target_directory.mkdir(parents=True, exist_ok=True)

    profiler_folder_name = relative_paths[0].parts[0] if relative_paths else None

    updated_files = []
    for file, original_path in zip(files, relative_paths):
        updated_path = target_directory / original_path
        updated_path.parent.mkdir(parents=True, exist_ok=True)
        # save_uploaded_files is given the full destination via filename.
        file.filename = str(updated_path)
        updated_files.append(file)

    save_uploaded_files(
        updated_files,
        str(report_directory),
    )

    update_tab_session(
        tab_id=tab_id, profile_name=profiler_folder_name, clear_remote=True
    )

    return StatusMessage(
        status=ConnectionTestStates.OK, message="Success."
    ).model_dump()
508
+
509
+
510
@api.route("/remote/folder", methods=["POST"])
def get_remote_folders():
    """List report folders on a remote host, annotated with last-sync times.

    The JSON body is validated as a ``RemoteConnection``. Each folder's local
    mirror is looked up at ``<REMOTE_DATA_DIRECTORY>/<host>/<folder name>``,
    and ``lastSynced`` is filled from ``read_last_synced_file``.

    Returns:
        A list of serialized ``RemoteReportFolder`` objects, or the status
        and message carried by a ``RemoteConnectionException``.
    """
    connection = RemoteConnection.model_validate(request.json, strict=False)
    try:
        # ``connection`` is already a validated model; pass it straight through.
        remote_folders: List[RemoteReportFolder] = get_remote_report_folders(
            connection
        )

        # Loop-invariant: all folders for this host share one parent directory.
        host_directory = Path(current_app.config["REMOTE_DATA_DIRECTORY"]).joinpath(
            connection.host
        )
        for rf in remote_folders:
            directory_name = Path(rf.remotePath).name
            logger.info(f"Checking last synced for {directory_name}")
            rf.lastSynced = read_last_synced_file(
                str(host_directory.joinpath(directory_name))
            )
            if not rf.lastSynced:
                logger.info(f"{directory_name} not yet synced")

        return [r.model_dump() for r in remote_folders]
    except RemoteConnectionException as e:
        return Response(status=e.http_status, response=e.message)
534
+
535
+
536
@api.route("/remote/profiles", methods=["POST"])
def get_remote_profile_folders():
    """List profiler folders on a remote host, annotated with last-sync times.

    The JSON body's ``connection`` key is validated as a ``RemoteConnection``.
    Each profile's local mirror is looked up at
    ``<REMOTE_DATA_DIRECTORY>/<host>/profiler/<profile name>``.

    Returns:
        A list of serialized ``RemoteReportFolder`` objects, or the status
        and message carried by a ``RemoteConnectionException``.
    """
    request_body = request.get_json()
    connection = RemoteConnection.model_validate(
        request_body.get("connection"), strict=False
    )

    try:
        # ``connection`` is already a validated model; pass it straight through.
        remote_profile_folders: List[RemoteReportFolder] = get_remote_profiler_folders(
            connection
        )

        # Loop-invariant: all profiles for this host share one parent directory.
        profiler_directory = (
            Path(current_app.config["REMOTE_DATA_DIRECTORY"])
            .joinpath(connection.host)
            .joinpath("profiler")
        )
        for rf in remote_profile_folders:
            profile_name = Path(rf.remotePath).name
            logger.info(f"Checking last synced for {profile_name}")
            rf.lastSynced = read_last_synced_file(
                str(profiler_directory.joinpath(profile_name))
            )
            if not rf.lastSynced:
                logger.info(f"{profile_name} not yet synced")

        return [r.model_dump() for r in remote_profile_folders]
    except RemoteConnectionException as e:
        return Response(status=e.http_status, response=e.message)
565
+
566
+
567
+ from flask import Response, jsonify
568
+ import yaml
569
+
570
+
571
@api.route("/cluster_desc", methods=["GET"])
@with_session
def get_cluster_description_file(session: TabSession):
    """Fetch the remote cluster descriptor YAML and return it as JSON.

    Responses:
        200 with the parsed YAML document;
        404 when the session has no remote connection or nothing was returned;
        400 when the content is not valid YAML;
        the exception's own status for a ``RemoteConnectionException``;
        500 for any other error (the traceback is logged server-side).
    """
    if not session.remote_connection:
        return jsonify({"error": "Remote connection not found"}), 404

    try:
        cluster_desc_file = get_cluster_desc(session.remote_connection)
        if not cluster_desc_file:
            return jsonify({"error": "cluster_descriptor.yaml not found"}), 404
        # safe_load: the document comes from a remote host, so never let it
        # construct arbitrary Python objects.
        yaml_data = yaml.safe_load(cluster_desc_file.decode("utf-8"))
        return jsonify(yaml_data), 200

    except yaml.YAMLError as e:
        return jsonify({"error": f"Failed to parse YAML: {str(e)}"}), 400

    except RemoteConnectionException as e:
        return jsonify({"error": e.message}), e.http_status

    except Exception as e:
        # Last-resort boundary: record the traceback so the failure can be
        # diagnosed instead of disappearing into a bare 500 response.
        logger.exception("Unexpected error while reading cluster description")
        return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500
592
+
593
+
594
@api.route("/remote/test", methods=["POST"])
def test_remote_folder():
    """Run connectivity checks for a remote connection and report each result.

    Checks run in order, and each is skipped once an earlier one has recorded
    a failure:
      1. SSH login;
      2. report path exists;
      3. performance path exists (only if one is configured);
      4. reports are present (records a status only on failure);
      5. SQLite binary path (only when remote querying is enabled).

    Returns:
        A list of serialized ``StatusMessage`` objects, one per recorded result.
    """
    connection_data = request.json
    connection = RemoteConnection.model_validate(connection_data)
    statuses = []

    def add_status(status, message):
        statuses.append(StatusMessage(status=status, message=message))

    def has_failures():
        # Accept either the enum member or its raw value as "OK", so the check
        # does not depend on how StatusMessage stores the status.
        ok_states = (ConnectionTestStates.OK, ConnectionTestStates.OK.value)
        return any(status.status not in ok_states for status in statuses)

    # Test SSH Connection
    try:
        get_client(connection)
        add_status(ConnectionTestStates.OK.value, "SSH connection established")
    except RemoteConnectionException as e:
        add_status(ConnectionTestStates.FAILED.value, e.message)

    # Test Directory Configuration
    if not has_failures():
        try:
            check_remote_path_exists(connection, "reportPath")
            add_status(ConnectionTestStates.OK.value, "Report folder path exists")
        except RemoteConnectionException as e:
            add_status(ConnectionTestStates.FAILED.value, e.message)

    # Test Directory Configuration (perf)
    if not has_failures() and connection.performancePath:
        try:
            check_remote_path_exists(connection, "performancePath")
            add_status(ConnectionTestStates.OK.value, "Performance folder path exists")
        except RemoteConnectionException as e:
            add_status(ConnectionTestStates.FAILED.value, e.message)

    # Check for Project Configurations
    if not has_failures():
        try:
            check_remote_path_for_reports(connection)
        except RemoteConnectionException as e:
            add_status(ConnectionTestStates.FAILED.value, e.message)

    # Test Sqlite binary path configuration
    if not has_failures() and connection.useRemoteQuerying:
        if not connection.sqliteBinaryPath:
            add_status(
                ConnectionTestStates.FAILED.value, "SQLite binary path not provided"
            )
        else:
            try:
                check_sqlite_path(connection)
                add_status(ConnectionTestStates.OK.value, "SQLite binary found.")
            except RemoteConnectionException as e:
                add_status(ConnectionTestStates.FAILED.value, e.message)

    return [status.model_dump() for status in statuses]
650
+
651
+
652
@api.route("/remote/read", methods=["POST"])
def read_remote_folder():
    """Return the raw content of the remote file at ``connection.path``.

    The JSON body is validated as a ``RemoteConnection``. Remote errors are
    returned with the status and message of the ``RemoteConnectionException``.
    """
    connection = RemoteConnection.model_validate(request.json, strict=False)
    try:
        body = read_remote_file(connection, remote_path=connection.path)
    except RemoteConnectionException as err:
        return Response(status=err.http_status, response=err.message)
    else:
        return Response(status=200, response=body)
660
+
661
+
662
@api.route("/remote/sync", methods=["POST"])
def sync_remote_folder():
    """Copy a remote report folder, or a remote profiler folder, locally.

    JSON body:
        connection: ``RemoteConnection`` fields (required).
        folder: ``RemoteReportFolder`` to sync when no ``profile`` is given.
        profile: optional ``RemoteReportFolder``; when present only the
            profiler folder is synced and ``folder`` is ignored.

    Query parameters:
        tabId: forwarded to the sync helpers as ``sid``.

    Returns:
        The synced folder model with ``lastSynced`` set to the current Unix
        time, ``400`` for a malformed body, or the status carried by a
        ``RemoteConnectionException``.
    """
    remote_dir = current_app.config["REMOTE_DATA_DIRECTORY"]
    request_body = request.get_json()

    # Check if request_body is None or not a dictionary
    if not request_body or not isinstance(request_body, dict):
        return jsonify({"error": "Invalid or missing JSON data"}), 400

    folder = request_body.get("folder")
    profile = request_body.get("profile", None)
    tab_id = request.args.get("tabId", None)
    connection_data = request_body.get("connection")

    # Reject missing pieces up front. Otherwise model_validate(None) below
    # raises a validation error that this view does not catch.
    if not connection_data:
        return jsonify({"error": "Missing connection"}), 400
    if not profile and not folder:
        return jsonify({"error": "Missing folder or profile"}), 400

    connection = RemoteConnection.model_validate(connection_data, strict=False)

    if profile:
        profile_folder = RemoteReportFolder.model_validate(profile, strict=False)
        try:
            sync_remote_profiler_folders(
                connection,
                remote_dir,
                profile=profile_folder,
                # Skip any path with a "tensors" directory component.
                exclude_patterns=[r"/tensors(/|$)"],
                sid=tab_id,
            )
            profile_folder.lastSynced = int(time.time())
            return profile_folder.model_dump()
        except RemoteConnectionException as e:
            return Response(status=e.http_status, response=e.message)

    try:
        remote_folder = RemoteReportFolder.model_validate(folder, strict=False)

        sync_remote_folders(
            connection,
            remote_folder.remotePath,
            remote_dir,
            # Skip any path with a "tensors" directory component.
            exclude_patterns=[r"/tensors(/|$)"],
            sid=tab_id,
        )

        remote_folder.lastSynced = int(time.time())
        return remote_folder.model_dump()
    except RemoteConnectionException as e:
        return Response(status=e.http_status, response=e.message)
713
+
714
+
715
@api.route("/remote/sqlite/detect-path", methods=["POST"])
def detect_sqlite_path():
    """Try to locate the sqlite3 binary on the remote host.

    The JSON body is validated as a ``RemoteConnection``.

    Returns:
        A serialized ``StatusMessage``:
        - OK, with the detected path, or with "Unable to Detect Path" when
          nothing was found;
        - FAILED when a ``RemoteConnectionException`` occurred (details are
          logged).
        Any other exception propagates instead of being silently discarded.
    """
    connection = RemoteConnection.model_validate(request.json, strict=False)
    try:
        path = get_sqlite_path(connection=connection)
    except RemoteConnectionException as e:
        current_app.logger.error(f"Unable to detect SQLite3 path {str(e)}")
        status_message = StatusMessage(
            status=ConnectionTestStates.FAILED,
            message="Unable to detect SQLite3 path. See logs",
        )
    else:
        status_message = StatusMessage(
            status=ConnectionTestStates.OK,
            message=path if path else "Unable to Detect Path",
        )
    # Returning outside `finally` means unexpected errors are no longer
    # swallowed (see PEP 765).
    return status_message.model_dump()
738
+
739
+
740
@api.route("/remote/use", methods=["POST"])
def use_remote_folder():
    """Make a remote report (and optionally a remote profile) the tab's active data.

    JSON body (parsed with ``force=True``):
        connection, folder: required; ``400`` if either is missing.
        profile: optional remote profiler folder; its ``testName`` becomes the
            session's profile name.

    Query parameters:
        tabId: the tab session to update.
    """
    payload = request.get_json(force=True)
    raw_connection = payload.get("connection", None)
    raw_folder = payload.get("folder", None)
    raw_profile = payload.get("profile", None)

    if not (raw_connection and raw_folder):
        return Response(status=HTTPStatus.BAD_REQUEST)

    connection = RemoteConnection.model_validate(raw_connection, strict=False)
    folder = RemoteReportFolder.model_validate(raw_folder, strict=False)

    if raw_profile:
        remote_profile_folder = RemoteReportFolder.model_validate(
            raw_profile, strict=False
        )
        profile_name = remote_profile_folder.testName
    else:
        remote_profile_folder = None
        profile_name = None

    data_root = current_app.config["REMOTE_DATA_DIRECTORY"]
    report_folder = Path(folder.remotePath).name
    connection_directory = Path(data_root, connection.host, report_folder)

    # Without remote querying, the report must already exist locally at
    # connection_directory.
    if not connection.useRemoteQuerying and not connection_directory.exists():
        return Response(
            status=HTTPStatus.INTERNAL_SERVER_ERROR,
            response=f"{connection_directory} does not exist.",
        )

    remote_path = f"{Path(data_root).name}/{connection.host}/{connection_directory.name}"
    tab_id = request.args.get("tabId")
    current_app.logger.info(f"Setting active report for {tab_id} - {remote_path}")

    update_tab_session(
        tab_id=tab_id,
        report_name=report_folder,
        profile_name=profile_name,
        remote_connection=connection,
        remote_folder=folder,
        remote_profile_folder=remote_profile_folder,
    )

    return Response(status=HTTPStatus.OK)
783
+
784
+
785
@api.route("/up", methods=["GET", "POST"])
def health_check():
    """Liveness probe: always answers 200 OK with an empty body."""
    ok = HTTPStatus.OK
    return Response(status=ok)
788
+
789
+
790
@api.route("/session", methods=["GET"])
@with_session
def get_tab_session(session: TabSession):
    """Return the calling tab's session state as a dict.

    The UI uses this to gate functionality when no report is active.
    """
    session_state = session.model_dump()
    return session_state