ttnn-visualizer 0.24.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ttnn_visualizer/__init__.py +4 -0
- ttnn_visualizer/app.py +193 -0
- ttnn_visualizer/bin/docker-entrypoint-web +16 -0
- ttnn_visualizer/bin/pip3-install +17 -0
- ttnn_visualizer/csv_queries.py +618 -0
- ttnn_visualizer/decorators.py +117 -0
- ttnn_visualizer/enums.py +12 -0
- ttnn_visualizer/exceptions.py +40 -0
- ttnn_visualizer/extensions.py +14 -0
- ttnn_visualizer/file_uploads.py +78 -0
- ttnn_visualizer/models.py +275 -0
- ttnn_visualizer/queries.py +388 -0
- ttnn_visualizer/remote_sqlite_setup.py +91 -0
- ttnn_visualizer/requirements.txt +24 -0
- ttnn_visualizer/serializers.py +249 -0
- ttnn_visualizer/sessions.py +245 -0
- ttnn_visualizer/settings.py +118 -0
- ttnn_visualizer/sftp_operations.py +486 -0
- ttnn_visualizer/sockets.py +118 -0
- ttnn_visualizer/ssh_client.py +85 -0
- ttnn_visualizer/static/assets/allPaths-CKt4gwo3.js +1 -0
- ttnn_visualizer/static/assets/allPathsLoader-Dzw0zTnr.js +2 -0
- ttnn_visualizer/static/assets/index-BXlT2rEV.js +5247 -0
- ttnn_visualizer/static/assets/index-CsS_OkTl.js +1 -0
- ttnn_visualizer/static/assets/index-DTKBo2Os.css +7 -0
- ttnn_visualizer/static/assets/index-DxLGmC6o.js +1 -0
- ttnn_visualizer/static/assets/site-BTBrvHC5.webmanifest +19 -0
- ttnn_visualizer/static/assets/splitPathsBySizeLoader-HHqSPeQM.js +1 -0
- ttnn_visualizer/static/favicon/android-chrome-192x192.png +0 -0
- ttnn_visualizer/static/favicon/android-chrome-512x512.png +0 -0
- ttnn_visualizer/static/favicon/favicon-32x32.png +0 -0
- ttnn_visualizer/static/favicon/favicon.svg +3 -0
- ttnn_visualizer/static/index.html +36 -0
- ttnn_visualizer/static/sample-data/cluster-desc.yaml +763 -0
- ttnn_visualizer/tests/__init__.py +4 -0
- ttnn_visualizer/tests/test_queries.py +444 -0
- ttnn_visualizer/tests/test_serializers.py +582 -0
- ttnn_visualizer/utils.py +185 -0
- ttnn_visualizer/views.py +794 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE +202 -0
- ttnn_visualizer-0.24.0.dist-info/LICENSE_understanding.txt +3 -0
- ttnn_visualizer-0.24.0.dist-info/METADATA +144 -0
- ttnn_visualizer-0.24.0.dist-info/RECORD +46 -0
- ttnn_visualizer-0.24.0.dist-info/WHEEL +5 -0
- ttnn_visualizer-0.24.0.dist-info/entry_points.txt +2 -0
- ttnn_visualizer-0.24.0.dist-info/top_level.txt +1 -0
ttnn_visualizer/serializers.py
@@ -0,0 +1,249 @@
+import dataclasses
+from collections import defaultdict
+from typing import List
+
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+
+from ttnn_visualizer.models import BufferType, Operation, TensorComparisonRecord
+
+
+def serialize_operations(
+    inputs,
+    operation_arguments,
+    operations,
+    outputs,
+    stack_traces,
+    tensors,
+    devices,
+    producers_consumers,
+    device_operations,
+):
+    tensors_dict = {t.tensor_id: t for t in tensors}
+    device_operations_dict = {
+        do.operation_id: do.captured_graph
+        for do in device_operations
+        if hasattr(do, "operation_id")
+    }
+
+    stack_traces_dict = {st.operation_id: st.stack_trace for st in stack_traces}
+
+    arguments_dict = defaultdict(list)
+    for argument in operation_arguments:
+        arguments_dict[argument.operation_id].append(argument)
+
+    inputs_dict, outputs_dict = serialize_inputs_outputs(
+        inputs, outputs, producers_consumers, tensors_dict
+    )
+
+    results = []
+    for operation in operations:
+
+        inputs = inputs_dict[operation.operation_id]
+        outputs = outputs_dict[operation.operation_id]
+        arguments = [a.to_dict() for a in arguments_dict[operation.operation_id]]
+        operation_data = operation.to_dict()
+        operation_data["id"] = operation.operation_id
+        operation_device_operations = device_operations_dict.get(
+            operation.operation_id, []
+        )
+        id = operation_data.pop("operation_id", None)
+
+        results.append(
+            {
+                **operation_data,
+                "id": id,
+                "stack_trace": stack_traces_dict.get(operation.operation_id),
+                "device_operations": operation_device_operations,
+                "arguments": arguments,
+                "inputs": inputs,
+                "outputs": outputs,
+            }
+        )
+    return results
+
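Each entry appended above is the operation's own to_dict() payload with the joined data layered on top. For orientation, a sketch of one result entry; the keys shown are the ones assembled in the loop, but all values (and the "name" field, which would come from Operation.to_dict()) are illustrative, not taken from the package:

    # Hypothetical shape of a single serialize_operations() entry.
    example_entry = {
        "id": 42,                 # the popped "operation_id"
        "name": "ttnn::add",      # assumed to come from Operation.to_dict()
        "stack_trace": None,      # no trace recorded for this operation
        "device_operations": [],  # no captured_graph entry for this id
        "arguments": [],          # argument.to_dict() items
        "inputs": [],             # enriched by serialize_inputs_outputs()
        "outputs": [],
    }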
+
+def serialize_inputs_outputs(
+    inputs,
+    outputs,
+    producers_consumers,
+    tensors_dict,
+    comparisons=None,
+):
+    producers_consumers_dict = {pc.tensor_id: pc for pc in producers_consumers}
+
+    def attach_tensor_data(values):
+        values_dict = defaultdict(list)
+        for value in values:
+            tensor = tensors_dict.get(value.tensor_id)
+            tensor_dict = tensor.to_dict()
+            pc = producers_consumers_dict.get(value.tensor_id)
+            value_dict = dataclasses.asdict(value)
+            value_dict.pop("tensor_id", None)
+            value_dict.update(
+                {
+                    "id": tensor_dict.pop("tensor_id"),
+                    "consumers": pc.consumers if pc else [],
+                    "producers": pc.producers if pc else [],
+                }
+            )
+            if comparisons:
+                comparison = comparisons.get(value.tensor_id)
+                value_dict.update({"comparison": comparison})
+
+            values_dict[value.operation_id].append({**value_dict, **tensor_dict})
+        return values_dict
+
+    inputs_dict = attach_tensor_data(inputs)
+    outputs_dict = attach_tensor_data(outputs)
+    return inputs_dict, outputs_dict
+
+
+def serialize_buffer_pages(buffer_pages):
+    # Collect device-specific data if needed
+
+    # Serialize each buffer page to a dictionary via its to_dict method
+    buffer_pages_list = [page.to_dict() for page in buffer_pages]
+
+    # Optionally, modify or adjust the serialized data as needed
+    for page_data in buffer_pages_list:
+        # Set a custom id field if needed
+        page_data["id"] = f"{page_data['operation_id']}_{page_data['page_index']}"
+
+        # If the buffer_type is handled by an enum, adjust it similarly to the BufferPage model
+        if "buffer_type" in page_data and isinstance(
+            page_data["buffer_type"], BufferType
+        ):
+            page_data["buffer_type"] = page_data["buffer_type"].value
+
+    return buffer_pages_list
+
+
+def comparisons_by_tensor_id(
+    local_comparisons: List[TensorComparisonRecord],
+    global_comparisons: List[TensorComparisonRecord],
+):
+    comparisons = defaultdict(dict)
+    for local_comparison in local_comparisons:
+        comparisons[local_comparison.tensor_id].update({"local": local_comparison})
+    for global_comparison in global_comparisons:
+        comparisons[global_comparison.tensor_id].update({"global": global_comparison})
+    return comparisons
+
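comparisons_by_tensor_id merges the two comparison lists into one dict keyed by tensor id, so a tensor present in both lists carries both a "local" and a "global" record. A quick sketch using a hypothetical stand-in for TensorComparisonRecord:

    from dataclasses import dataclass

    from ttnn_visualizer.serializers import comparisons_by_tensor_id


    @dataclass
    class _StubComparison:  # hypothetical stand-in for TensorComparisonRecord
        tensor_id: int


    merged = comparisons_by_tensor_id(
        [_StubComparison(1)],
        [_StubComparison(1), _StubComparison(2)],
    )
    assert set(merged[1]) == {"local", "global"}  # tensor 1 is in both lists
    assert set(merged[2]) == {"global"}           # tensor 2 only has a global record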
+
+def serialize_operation(
+    buffers,
+    inputs,
+    operation,
+    operation_arguments,
+    outputs,
+    stack_trace,
+    tensors,
+    global_tensor_comparisons,
+    local_tensor_comparisons,
+    devices,
+    producers_consumers,
+    device_operations,
+):
+    tensors_dict = {t.tensor_id: t for t in tensors}
+    comparisons = comparisons_by_tensor_id(
+        local_tensor_comparisons, global_tensor_comparisons
+    )
+
+    inputs_dict, outputs_dict = serialize_inputs_outputs(
+        inputs,
+        outputs,
+        producers_consumers,
+        tensors_dict,
+        comparisons,
+    )
+
+    buffer_list = [buffer.to_dict() for buffer in buffers]
+
+    l1_sizes = [d.worker_l1_size for d in devices]
+    arguments_data = [argument.to_dict() for argument in operation_arguments]
+    operation_data = operation.to_dict()
+    operation_data["id"] = operation.operation_id
+
+    inputs_data = inputs_dict.get(operation.operation_id)
+    outputs_data = outputs_dict.get(operation.operation_id)
+
+    id = operation_data.pop("operation_id", None)
+
+    device_operations_data = []
+    for do in device_operations:
+        if do.operation_id == operation.operation_id:
+            device_operations_data = do.captured_graph
+            break
+
+    return {
+        **operation_data,
+        "id": id,
+        "l1_sizes": l1_sizes,
+        "device_operations": device_operations_data,
+        "stack_trace": stack_trace.stack_trace if stack_trace else "",
+        "buffers": buffer_list,
+        "arguments": arguments_data,
+        "inputs": inputs_data or [],
+        "outputs": outputs_data or [],
+    }
+
+
+def serialize_operation_buffers(operation: Operation, operation_buffers):
+    buffer_data = [b.to_dict() for b in operation_buffers]
+    for b in buffer_data:
+        b.pop("operation_id")
+        b.update({"size": b.pop("max_size_per_bank")})
+    return {
+        "id": operation.operation_id,
+        "name": operation.name,
+        "buffers": list(buffer_data),
+    }
+
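serialize_operation_buffers reshapes each buffer dict: operation_id is dropped (it is redundant with the wrapping "id") and max_size_per_bank is renamed to size. Assuming a hypothetical buffer whose to_dict() yields {"operation_id": 3, "address": 1024, "max_size_per_bank": 2048}, attached to an operation named "matmul" (also illustrative), the result would be:

    # Illustrative output only; field values are made up for the sketch.
    serialized = {
        "id": 3,
        "name": "matmul",
        "buffers": [{"address": 1024, "size": 2048}],
    }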
+
+def serialize_devices(devices):
+    return [d.to_dict() for d in devices]
+
+
+def serialize_operations_buffers(operations, buffers):
+
+    buffer_dict = defaultdict(list)
+    for b in buffers:
+        buffer_dict[b.operation_id].append(b)
+
+    results = []
+    for operation in operations:
+        operation_buffers = buffer_dict.get(operation.operation_id, [])
+        results.append(serialize_operation_buffers(operation, operation_buffers))
+    return results
+
+
+def serialize_tensors(
+    tensors, producers_consumers, local_comparisons, global_comparisons
+):
+    producers_consumers_dict = {pc.tensor_id: pc for pc in producers_consumers}
+    results = []
+    comparisons = comparisons_by_tensor_id(local_comparisons, global_comparisons)
+    for tensor in tensors:
+        tensor_data = tensor.to_dict()
+        tensor_id = tensor_data.pop("tensor_id")
+        tensor_data.update(
+            {
+                "id": tensor_id,
+                "comparison": comparisons.get(tensor_id),
+                "consumers": (
+                    producers_consumers_dict[tensor_id].consumers
+                    if tensor_id in producers_consumers_dict
+                    else []
+                ),
+                "producers": (
+                    producers_consumers_dict[tensor_id].producers
+                    if tensor_id in producers_consumers_dict
+                    else []
+                ),
+            }
+        )
+        results.append(tensor_data)
+
+    return results
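To see one of these serializers in isolation, here is a self-contained sketch that feeds serialize_tensors stub objects; both stub classes are hypothetical stand-ins for the real models in ttnn_visualizer.models:

    from dataclasses import dataclass, field

    from ttnn_visualizer.serializers import serialize_tensors


    @dataclass
    class _StubTensor:  # hypothetical stand-in for the real tensor model
        tensor_id: int
        shape: str

        def to_dict(self):
            return {"tensor_id": self.tensor_id, "shape": self.shape}


    @dataclass
    class _StubProducersConsumers:  # hypothetical stand-in for the join record
        tensor_id: int
        producers: list = field(default_factory=list)
        consumers: list = field(default_factory=list)


    tensors = [_StubTensor(tensor_id=1, shape="[1, 32]")]
    pcs = [_StubProducersConsumers(tensor_id=1, producers=[0], consumers=[2])]
    result = serialize_tensors(tensors, pcs, [], [])
    # -> [{"shape": "[1, 32]", "id": 1, "comparison": None,
    #      "consumers": [2], "producers": [0]}]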
ttnn_visualizer/sessions.py
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+from logging import getLogger
+
+from flask import request
+
+from ttnn_visualizer.utils import get_report_path, get_profiler_path
+from ttnn_visualizer.models import (
+    TabSessionTable,
+)
+from ttnn_visualizer.extensions import db
+
+logger = getLogger(__name__)
+
+from flask import jsonify, current_app
+from sqlalchemy.exc import SQLAlchemyError
+import json
+
+
+def update_existing_tab_session(
+    session_data,
+    report_name,
+    profile_name,
+    remote_connection,
+    remote_folder,
+    remote_profile_folder,
+    clear_remote,
+):
+    active_report = session_data.active_report or {}
+
+    if report_name:
+        active_report["report_name"] = report_name
+    if profile_name:
+        active_report["profile_name"] = profile_name
+
+    session_data.active_report = active_report
+
+    if remote_connection:
+        session_data.remote_connection = remote_connection.model_dump()
+    if remote_folder:
+        session_data.remote_folder = remote_folder.model_dump()
+    if remote_profile_folder:
+        session_data.remote_profile_folder = remote_profile_folder.model_dump()
+
+    if clear_remote:
+        clear_remote_data(session_data)
+
+    update_paths(
+        session_data, active_report, remote_connection, report_name, profile_name
+    )
+
+
+def clear_remote_data(session_data):
+    session_data.remote_connection = None
+    session_data.remote_folder = None
+    session_data.remote_profile_folder = None
+
+
+def handle_sqlalchemy_error(error):
+    current_app.logger.error(f"Failed to update tab session: {str(error)}")
+    db.session.rollback()
+
+
+def commit_and_log_session(session_data, tab_id):
+    db.session.commit()
+
+    session_data = TabSessionTable.query.filter_by(tab_id=tab_id).first()
+    current_app.logger.info(
+        f"Session data for tab {tab_id}: {json.dumps(session_data.to_dict(), indent=4)}"
+    )
+
+
+def update_paths(
+    session_data, active_report, remote_connection, report_name, profile_name
+):
+    if active_report.get("profile_name"):
+        session_data.profiler_path = get_profiler_path(
+            profile_name=active_report["profile_name"],
+            current_app=current_app,
+            remote_connection=remote_connection,
+        )
+
+    if active_report.get("report_name"):
+        session_data.report_path = get_report_path(
+            active_report=active_report,
+            current_app=current_app,
+            remote_connection=remote_connection,
+        )
+
+
+def create_new_tab_session(
+    tab_id,
+    report_name,
+    profile_name,
+    remote_connection,
+    remote_folder,
+    remote_profile_folder,
+    clear_remote,
+):
+    active_report = {}
+    if report_name:
+        active_report["report_name"] = report_name
+    if profile_name:
+        active_report["profile_name"] = profile_name
+
+    if clear_remote:
+        remote_connection = None
+        remote_folder = None
+        remote_profile_folder = None
+
+    session_data = TabSessionTable(
+        tab_id=tab_id,
+        active_report=active_report,
+        report_path=get_report_path(
+            active_report,
+            current_app=current_app,
+            remote_connection=remote_connection,
+        ),
+        remote_connection=(
+            remote_connection.model_dump() if remote_connection else None
+        ),
+        remote_folder=remote_folder.model_dump() if remote_folder else None,
+        remote_profile_folder=(
+            remote_profile_folder.model_dump() if remote_profile_folder else None
+        ),
+    )
+    db.session.add(session_data)
+    return session_data
+
+
+def update_tab_session(
+    tab_id,
+    report_name=None,
+    profile_name=None,
+    remote_connection=None,
+    remote_folder=None,
+    remote_profile_folder=None,
+    clear_remote=False,
+):
+    try:
+        session_data = get_or_create_tab_session(tab_id)
+
+        if session_data:
+            update_existing_tab_session(
+                session_data,
+                report_name,
+                profile_name,
+                remote_connection,
+                remote_folder,
+                remote_profile_folder,
+                clear_remote,
+            )
+        else:
+            session_data = create_new_tab_session(
+                tab_id,
+                report_name,
+                profile_name,
+                remote_connection,
+                remote_folder,
+                remote_profile_folder,
+                clear_remote,
+            )
+
+        commit_and_log_session(session_data, tab_id)
+        return jsonify({"message": "Tab session updated successfully"}), 200
+
+    except SQLAlchemyError as e:
+        handle_sqlalchemy_error(e)
+        return jsonify({"error": "Failed to update tab session"}), 500
+
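Since update_tab_session returns a (jsonify(...), status) tuple, it is written to run inside a request context and be handed back as a view's return value. A minimal sketch of such wiring; the app object, route, and JSON field names are assumptions for illustration (the wheel's real routes likely live in ttnn_visualizer/views.py and may differ):

    from flask import Flask, request

    from ttnn_visualizer.sessions import update_tab_session

    app = Flask(__name__)  # illustrative; the package builds its app in app.py


    @app.route("/session", methods=["PUT"])
    def put_session():
        payload = request.get_json(silent=True) or {}
        # Field names below are assumptions for the sketch.
        return update_tab_session(
            tab_id=request.args.get("tabId"),
            report_name=payload.get("report_name"),
            profile_name=payload.get("profile_name"),
        )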
+
+def get_or_create_tab_session(
+    tab_id,
+    report_name=None,
+    profile_name=None,
+    remote_connection=None,
+    remote_folder=None,
+):
+    """
+    Retrieve an existing tab session or create a new one if it doesn't exist.
+    Uses the TabSession model to manage session data and supports conditional updates.
+    """
+    try:
+        # Query the database for the tab session
+        session_data = TabSessionTable.query.filter_by(tab_id=tab_id).first()
+
+        # If session doesn't exist, initialize it
+        if not session_data:
+            session_data = TabSessionTable(
+                tab_id=tab_id,
+                active_report={},
+                remote_connection=None,
+                remote_folder=None,
+            )
+            db.session.add(session_data)
+            db.session.commit()
+
+        # Update the session if any new data is provided
+        if report_name or profile_name or remote_connection or remote_folder:
+            update_tab_session(
+                tab_id=tab_id,
+                report_name=report_name,
+                profile_name=profile_name,
+                remote_connection=remote_connection,
+                remote_folder=remote_folder,
+            )
+
+        # Query again to get the updated session data
+        session_data = TabSessionTable.query.filter_by(tab_id=tab_id).first()
+
+        return session_data
+
+    except SQLAlchemyError as e:
+        current_app.logger.error(f"Failed to get or create tab session: {str(e)}")
+        db.session.rollback()
+        return None
+
+
+def get_tab_session():
+    """
+    Middleware to retrieve or create a tab session based on the tab_id.
+    """
+    tab_id = request.args.get("tabId", None)
+
+    current_app.logger.info(f"get_tab_session: Received tab_id: {tab_id}")
+    if not tab_id:
+        current_app.logger.error("get_tab_session: No tab_id found")
+        return jsonify({"error": "tabId is required"}), 400
+
+    active_report = get_or_create_tab_session(tab_id)
+    current_app.logger.info(
+        f"get_tab_session: Session retrieved: {active_report.active_report}"
+    )
+
+    return jsonify({"active_report": active_report.active_report}), 200
+
+
+def init_sessions(app):
+    """
+    Initializes session middleware and hooks it into Flask.
+    """
+    app.before_request(get_tab_session)
+    app.logger.info("Sessions middleware initialized.")
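init_sessions registers get_tab_session as a Flask before_request handler. Worth noting: in Flask, a before_request callback that returns a non-None value short-circuits the request with that response, and get_tab_session always returns a (response, status) tuple, so any scoping of this middleware presumably happens elsewhere (perhaps in decorators.py). A minimal sketch of the wiring, assuming an in-memory database (the factory and URI here are illustrative, not the package's own setup):

    from flask import Flask

    from ttnn_visualizer.extensions import db
    from ttnn_visualizer.sessions import init_sessions


    def make_app():
        # Illustrative factory; the wheel's real one is ttnn_visualizer.app:create_app().
        app = Flask(__name__)
        app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"  # assumption
        db.init_app(app)
        with app.app_context():
            db.create_all()  # creates the TabSessionTable table
        init_sessions(app)   # every request now passes through get_tab_session
        return app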
ttnn_visualizer/settings.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+from ttnn_visualizer.utils import str_to_bool
+
+load_dotenv()
+
+class DefaultConfig(object):
+    # General Settings
+    SECRET_KEY = os.getenv("SECRET_KEY", "90909")
+    DEBUG = bool(str_to_bool(os.getenv("FLASK_DEBUG", "false")))
+    TESTING = False
+    PRINT_ENV = True
+
+    # Path Settings
+    REPORT_DATA_DIRECTORY = Path(__file__).parent.absolute().joinpath("data")
+    VERSION = "0.14.2"
+    LOCAL_DATA_DIRECTORY = Path(REPORT_DATA_DIRECTORY).joinpath("local")
+    REMOTE_DATA_DIRECTORY = Path(REPORT_DATA_DIRECTORY).joinpath("remote")
+    APPLICATION_DIR = os.path.abspath(os.path.join(__file__, "..", os.pardir))
+    STATIC_ASSETS_DIR = Path(APPLICATION_DIR).joinpath("ttnn_visualizer", "static")
+    SEND_FILE_MAX_AGE_DEFAULT = 0
+
+    LAUNCH_BROWSER_ON_START = str_to_bool(os.getenv("LAUNCH_BROWSER_ON_START", "true"))
+
+    # File Name Configs
+    TEST_CONFIG_FILE = "config.json"
+    SQLITE_DB_PATH = "db.sqlite"
+
+    # For development you may want to disable sockets
+    USE_WEBSOCKETS = str_to_bool(os.getenv("USE_WEBSOCKETS", "true"))
+
+    # SQL Alchemy Settings
+    SQLALCHEMY_DATABASE_URI = (
+        f"sqlite:///{os.path.join(APPLICATION_DIR, f'ttnn_{VERSION}.db')}"
+    )
+    SQLALCHEMY_ENGINE_OPTIONS = {
+        "pool_size": 10,  # Adjust pool size as needed (default is 5)
+        "max_overflow": 20,  # Allow overflow of the pool size if necessary
+        "pool_timeout": 30,  # Timeout in seconds before giving up on getting a connection
+    }
+    SQLALCHEMY_TRACK_MODIFICATIONS = False
+
+    # Gunicorn settings
+    GUNICORN_WORKER_CLASS = os.getenv("GUNICORN_WORKER_CLASS", "gevent")
+    GUNICORN_WORKERS = os.getenv("GUNICORN_WORKERS", "1")
+    PORT = os.getenv("PORT", "8000")
+    HOST = "localhost"
+    DEV_SERVER_PORT = "5173"
+    DEV_SERVER_HOST = "localhost"
+
+    GUNICORN_BIND = f"{HOST}:{PORT}"
+    GUNICORN_APP_MODULE = os.getenv(
+        "GUNICORN_APP_MODULE", "ttnn_visualizer.app:create_app()"
+    )
+
+    # Session Settings
+    SESSION_COOKIE_SAMESITE = "Lax"
+    SESSION_COOKIE_SECURE = False  # For development on HTTP
+
+    def override_with_env_variables(self):
+        """Override config values with environment variables."""
+        for key, value in self.__class__.__dict__.items():
+            if not key.startswith("_"):  # Skip private/protected attributes
+                env_value = os.getenv(key)
+                if env_value is not None:
+                    setattr(self, key, env_value)
+
+    def to_dict(self):
+        """Return all config values as a dictionary, including inherited attributes."""
+        return {
+            key: getattr(self, key)
+            for key in dir(self)
+            if not key.startswith("_") and not callable(getattr(self, key))
+        }
+
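Two details of DefaultConfig are easy to miss. The database URI nests one f-string inside another (the inner literal needs quotes distinct from the outer ones to parse on Python versions before 3.12), and it resolves to an absolute sqlite path beside the installed package; note also that the hard-coded VERSION, "0.14.2", lags the wheel's own 0.24.0 and, within this file, feeds only that file name. A rough sketch of the expansion with an illustrative install directory:

    import os

    APPLICATION_DIR = "/opt/ttnn"  # illustrative install location
    VERSION = "0.14.2"
    uri = f"sqlite:///{os.path.join(APPLICATION_DIR, f'ttnn_{VERSION}.db')}"
    assert uri == "sqlite:////opt/ttnn/ttnn_0.14.2.db"

Separately, override_with_env_variables assigns whatever os.getenv returns, so instance-level overrides land as raw strings: an override such as DEBUG=false would leave the truthy string "false" on the instance, unlike the class-level defaults, which pass through str_to_bool.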
+
+class DevelopmentConfig(DefaultConfig):
+    pass
+
+
+class TestingConfig(DefaultConfig):
+    DEBUG = bool(str_to_bool(os.getenv("FLASK_DEBUG", "True")))
+    TESTING = True
+
+
+class ProductionConfig(DefaultConfig):
+    DEBUG = False
+    TESTING = False
+
+
+class Config:
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(Config, cls).__new__(cls)
+            cls._instance = cls._determine_config()
+            cls._instance.override_with_env_variables()
+        return cls._instance
+
+    @staticmethod
+    def _determine_config():
+        # Determine the environment
+        flask_env = os.getenv("FLASK_ENV", "development").lower()
+
+        # Choose the correct configuration class based on FLASK_ENV
+        if flask_env == "production":
+            return ProductionConfig()
+        elif flask_env == "testing":
+            return TestingConfig()
+        else:
+            return DevelopmentConfig()
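Config ties the module together as a lazy singleton: the first Config() call picks the concrete class from FLASK_ENV, applies environment overrides, and caches the instance; later calls return the cached object. One subtlety: because override_with_env_variables walks self.__class__.__dict__, it only sees attributes redefined on the selected subclass, so DevelopmentConfig (a bare pass) gets none of its inherited keys overridden this way. A usage sketch with illustrative environment values:

    import os

    os.environ["FLASK_ENV"] = "testing"  # illustrative

    from ttnn_visualizer.settings import Config

    config = Config()          # resolves to a TestingConfig instance
    assert Config() is config  # cached; _determine_config() ran only once
    assert config.TESTING is True
    print(config.to_dict())    # full settings, including inherited DefaultConfig keys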