alphai 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphai/__init__.py +8 -4
- alphai/auth.py +362 -0
- alphai/cli.py +1015 -0
- alphai/client.py +400 -0
- alphai/config.py +88 -0
- alphai/docker.py +764 -0
- alphai/utils.py +192 -0
- alphai-0.1.0.dist-info/METADATA +394 -0
- alphai-0.1.0.dist-info/RECORD +12 -0
- {alphai-0.0.7.dist-info → alphai-0.1.0.dist-info}/WHEEL +2 -1
- alphai-0.1.0.dist-info/entry_points.txt +2 -0
- alphai-0.1.0.dist-info/top_level.txt +1 -0
- alphai/alphai.py +0 -786
- alphai/api/client.py +0 -0
- alphai/benchmarking/benchmarker.py +0 -37
- alphai/client/__init__.py +0 -0
- alphai/client/client.py +0 -382
- alphai/profilers/__init__.py +0 -0
- alphai/profilers/configs_base.py +0 -7
- alphai/profilers/jax.py +0 -37
- alphai/profilers/pytorch.py +0 -83
- alphai/profilers/pytorch_utils.py +0 -419
- alphai/util.py +0 -19
- alphai-0.0.7.dist-info/LICENSE +0 -201
- alphai-0.0.7.dist-info/METADATA +0 -125
- alphai-0.0.7.dist-info/RECORD +0 -16

alphai/api/client.py
DELETED
File without changes

alphai/benchmarking/benchmarker.py
DELETED
@@ -1,37 +0,0 @@
-from typing import Callable
-
-import timeit
-from timeit import default_timer
-
-
-class Benchmarker:
-    def __init__(self):
-        self.start_time = 0
-        self.end_time = 0
-
-    def start(self):
-        self.start_time = default_timer()
-
-    def stop(self, print_results: bool = True):
-        self.end_time = default_timer()
-        delta = self.end_time - self.start_time
-        print(f"Measured: {delta * 1e6:.1f} us")
-        return delta
-
-    def benchmark(
-        self,
-        function: Callable = None,
-        *args,
-        num_iter: int = 1,
-        print_results: bool = True,
-        **kwargs,
-    ):
-        results = {}
-        total_time = timeit.Timer(lambda: function(*args, **kwargs)).timeit(num_iter)
-        avg_run = total_time / num_iter
-        results["total_time_seconds"] = total_time
-        results["avg_run_seconds"] = avg_run
-        if print_results:
-            print(f"Measured total run: {total_time * 1e6:>5.1f} us")
-            print(f"Measured averaged run: {avg_run * 1e6:>5.1f} us")
-        return results
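
For context, a minimal usage sketch of the Benchmarker API removed above; the import path follows the 0.0.7 package layout and the timed function is a made-up example, not anything shipped in the package:

# Illustrative sketch only: exercises the deleted Benchmarker class shown above.
from alphai.benchmarking.benchmarker import Benchmarker  # 0.0.7 module path

def work():
    # hypothetical workload; any zero-argument callable will do
    return sum(i * i for i in range(10_000))

bench = Benchmarker()
bench.start()
work()
elapsed = bench.stop()  # prints the elapsed time in microseconds and returns seconds

stats = bench.benchmark(work, num_iter=100)  # {"total_time_seconds": ..., "avg_run_seconds": ...}
print(stats["avg_run_seconds"])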

alphai/client/__init__.py
DELETED
File without changes

alphai/client/client.py
DELETED
@@ -1,382 +0,0 @@
-# src/alphai/client/client.py
-import requests
-from typing import List
-import urllib.parse
-import urllib.request
-import base64
-import json
-from websocket import create_connection, WebSocketTimeoutException
-import uuid
-import datetime
-
-import nbserv_client
-import os
-import jh_client
-from pprint import pprint
-
-
-class Client:
-    def __init__(
-        self,
-        host: str = "https://lab.amdatascience.com",
-        dashboard_url: str = "https://dashboard.amdatascience.com",
-        access_token=None,
-    ):
-        self.host = host
-        self.dashboard_url = dashboard_url
-        self.access_token = access_token
-        self.configuration = jh_client.Configuration(host=f"{host}/hub/api")
-        self.configuration.access_token = access_token
-
-        # Enter a context with an instance of the API client
-        self.api_client = jh_client.ApiClient(self.configuration)
-        # Create an instance of the API class
-        self.api_instance = jh_client.DefaultApi(self.api_client)
-
-        # Initialize server
-        self.initialize_server()
-
-    def initialize_server(self, server_name=""):
-        self.server_name = server_name
-        try:
-            # Return authenticated user's model
-            self.user_api_response = self.api_instance.user_get()
-            self.user_info = self.user_api_response.to_dict()
-
-        except Exception as e:
-            print("Exception when calling DefaultApi->user_get: %s\n" % e)
-
-        self.user_name = self.user_info["name"]
-        self.server_configuration = nbserv_client.Configuration(
-            host=f"{self.host}/user/{self.user_name}/{server_name}"
-        )
-        self.server_api_client = nbserv_client.ApiClient(
-            self.server_configuration,
-            header_name="Authorization",
-            header_value=f"Token {self.access_token}",
-        )
-        self.server_api_instance = nbserv_client.ContentsApi(self.server_api_client)
-
-    def get_user_info(self):
-        return self.user_info
-
-    def get_servers(self):
-        user_api_response = self.api_instance.users_name_get(
-            self.user_name, include_stopped_servers=True
-        )
-        servers = user_api_response.to_dict()["servers"]
-        return {
-            k: {
-                "name": v["name"],
-                "ready": v["ready"],
-                "url": v["url"],
-                "last_activity": v["last_activity"],
-            }
-            for k, v in servers.items()
-        }
-
-    # Dashboard Client
-    def start_server(
-        self,
-        server_name: str = "",
-        environment: str = "ai",
-        server_request: str = "medium-cpu",
-    ):
-
-        # Start Server given name
-        # Data to be sent in POST request
-        data = {
-            "server_name": server_name,
-            "environment": environment,
-            "server_request": server_request,
-            "port": 5000,
-        }
-
-        url = f"{self.dashboard_url}/api/server"
-        headers = {
-            "apikey": f"{self.access_token}",
-            'Accept': 'application/json'
-        }
-
-        response = requests.post(
-            url,
-            json=data,
-            headers=headers,
-        )
-
-        # If the response is JSON
-        try:
-            response_data = response.json()
-            return response_data
-        except ValueError:
-            print("Response is not in JSON format")
-
-
-    def stop_server(self, server_name: str = ""):
-        # Stop Server given name
-        # Data to be sent in POST request
-        if not server_name:
-            server_name = "default"
-        data = {
-            "stop": True,
-        }
-
-        url = f"{self.dashboard_url}/api/server/{server_name}"
-        headers = {
-            "apikey": f"{self.access_token}",
-            'Accept': 'application/json'
-        }
-
-        response = requests.post(
-            url,
-            json=data,
-            headers=headers,
-        )
-
-        # If the response is JSON
-        try:
-            response_data = response.json()
-            return response_data
-        except ValueError:
-            print("Response is not in JSON format")
-
-    def alph(
-        self,
-        server_name: str = "",
-        messages: str | list = "Hi Alph.",
-        engine: str = "gpt3",
-    ):
-
-        # Agent Alph call
-        # Data to be sent in POST request
-        if isinstance(messages, str):
-            data = {
-                "messages": [
-                    {"role": "user", "content": messages}
-                ],
-            }
-        else:
-            data = {
-                "messages": messages
-            }
-
-        url = f"{self.dashboard_url}/api/alph/{server_name}/{engine}"
-        headers = {
-            "apikey": f"{self.access_token}",
-            'Accept': 'application/json'
-        }
-
-        response = requests.post(
-            url,
-            json=data,
-            headers=headers,
-            stream=True
-        )
-
-        # If the response is JSON
-        #try:
-        if response.encoding is None:
-            response.encoding = 'utf-8'
-
-        full_output = []
-        try:
-            for line in response.iter_lines(decode_unicode=True):
-                if line:
-                    #import pdb; pdb.set_trace()
-                    split_line = line.split(':')
-                    cleaned_line = split_line[1].replace('"', '')
-                    print(cleaned_line, end="")
-                    full_output.append(cleaned_line)
-        except ValueError:
-            print("Response encoded incorrectly.")
-
-        return "".join(full_output)
-
-    # NB server client
-    def get_contents(self, server_name: str = ""):
-        if server_name != self.server_name:
-            self.initialize_server(server_name=server_name)
-        path = "" # str | file path
-        type = None # str | File type ('file', 'directory') (optional)
-        format = "text" # str | How file content should be returned ('text', 'base64') (optional)
-        content = None # int | Return content (0 for no content, 1 for return content) (optional)
-        hash = None # int | May return hash hexdigest string of content and the hash algorithm (0 for no hash - default, 1 for return hash). It may be ignored by the content manager. (optional)
-
-        try:
-            # Get contents of file or directory
-            api_response = self.server_api_instance.api_contents_path_get(
-                path, type=type, format=format, content=content, hash=hash
-            )
-            print("The response of ContentsApi->api_contents_path_get:\n")
-            pprint(api_response)
-        except Exception as e:
-            print("Exception when calling ContentsApi->api_contents_path_get: %s\n" % e)
-
-    def post_contents(
-        self,
-        server_name: str = "",
-        path: str = "",
-        ext: str = "",
-        type: str = "directory",
-    ):
-        # Data to be sent in POST request
-        data = {
-            "type": type,
-            "ext": ext,
-        }
-
-        url = f"{self.host}/user/{self.user_name}/{server_name}/api/contents/{path}"
-        headers = {"Authorization": f"Token {self.access_token}"}
-
-        response = requests.post(
-            url,
-            json=data,
-            headers=headers,
-        )
-
-        # If the response is JSON
-        try:
-            response_data = response.json()
-            return response_data
-        except ValueError:
-            print("Response is not in JSON format")
-
-    def patch_contents(
-        self,
-        server_name: str = "",
-        path: str = "Untitled Folder",
-        new_path: str = "alphai_",
-    ):
-        # Data to be sent in PATCH request
-        data = {"path": new_path}
-
-        url = f"{self.host}/user/{self.user_name}/{server_name}/api/contents/{path}"
-        headers = {"Authorization": f"Token {self.access_token}"}
-
-        response = requests.patch(
-            url,
-            json=data,
-            headers=headers,
-        )
-
-        # If the response is JSON
-        try:
-            response_data = response.json()
-            return response_data
-        except ValueError:
-            print("Response is not in JSON format")
-
-    def put_contents(self, server_name: str = "", path: str = "", file_path: str = ""):
-        # Data to be sent in POST request
-        file_name = file_path[1 + file_path.rfind(os.sep) :]
-        url = f"{self.host}/user/{self.user_name}/{server_name}/api/contents/{path}/{file_name}"
-        headers = {"Authorization": f"Token {self.access_token}"}
-
-        try:
-            with open(file_path, "rb") as f:
-                data = f.read()
-            b64data = base64.b64encode(data)
-            body = json.dumps(
-                {
-                    "content": b64data.decode(),
-                    "name": file_name,
-                    "path": path,
-                    "format": "base64",
-                    "type": "file",
-                }
-            )
-            return requests.put(url, data=body, headers=headers, verify=True)
-        except ValueError:
-            print("Request is invalid")
-
-    def get_kernels(self, server_name=""):
-        # Get initial kernel info
-        url = f"{self.host}/user/{self.user_name}/{server_name}/api/kernels"
-        headers = {"Authorization": f"Token {self.access_token}"}
-        response = requests.get(url, headers=headers)
-        kernels = json.loads(response.text)
-        self.kernels = kernels
-        return kernels
-
-    def shutdown_all_kernels(self, server_name=""):
-        # Get initial kernel info
-        kernels = self.get_kernels(server_name=server_name)
-        # Delete all kernels
-        headers = {"Authorization": f"Token {self.access_token}"}
-        for k in kernels:
-            url = (
-                f"{self.host}/user/{self.user_name}/{server_name}/api/kernels/{k['id']}"
-            )
-            response = requests.delete(url, headers=headers)
-
-    def send_channel_execute(
-        self,
-        server_name="",
-        messages: List[str] = ["print('Hello World!')"],
-        return_full: bool = False,
-    ):
-        # start initial kernel info
-        url = f"{self.host}/user/{self.user_name}/{server_name}/api/kernels"
-        headers = {"Authorization": f"Token {self.access_token}"}
-        response = requests.post(url, headers=headers)
-        kernel = json.loads(response.text)
-
-        # Execution request/reply is done on websockets channels
-        ws_url = f"wss://{self.host.split('https://')[-1]}/user/{self.user_name}/{urllib.parse.quote(server_name)}/api/kernels/{kernel['id']}/channels"
-        ws = create_connection(ws_url, header=headers)
-
-        code = messages
-
-        def execute_request(code):
-            msg_type = "execute_request"
-            content = {"code": code, "silent": False}
-            hdr = {
-                "msg_id": uuid.uuid1().hex,
-                "username": "test",
-                "session": uuid.uuid1().hex,
-                "data": datetime.datetime.now().isoformat(),
-                "msg_type": msg_type,
-                "version": "5.0",
-            }
-            msg = {
-                "header": hdr,
-                "parent_header": hdr,
-                "metadata": {},
-                "content": content,
-            }
-            return msg
-
-        for c in code:
-            if not c.startswith("!"):
-                c += ";print('AlphAI Run Complete')"
-            ws.send(json.dumps(execute_request(c)))
-
-        results = {}
-        for i in range(0, len(code)):
-            msg_type = ""
-            results[code[i]] = []
-            count = 0
-            while msg_type != "stream":
-                try:
-                    rsp = json.loads(ws.recv())
-                    results[code[i]].append(rsp)
-                    # print(rsp["msg_type"])
-                    # print(rsp["content"])
-                    if rsp["msg_type"] == "stream":
-                        print(rsp["content"]["text"])
-                    msg_type = rsp["msg_type"]
-                    if msg_type == "error":
-                        raise Exception(rsp["content"]["traceback"][0])
-                except WebSocketTimeoutException as _e:
-                    print("No output")
-                    return
-        ws.close()
-
-        if return_full:
-            return results
-
-    def get_service(self, server_name: str = ""):
-        server_name = f"--{server_name}" if server_name else ""
-        user_name = self.user_name.replace("@", "-40").replace(".", "-2e")
-        url = f"https://jupyter-{user_name}{server_name}.americandatascience.dev"
-        return url
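
For context, a minimal sketch of how the removed Client was typically driven; the access token value is a placeholder, and the argument defaults come straight from the deleted signatures above:

# Illustrative sketch only: exercises the deleted Client class shown above.
from alphai.client.client import Client  # 0.0.7 module path

client = Client(access_token="YOUR_ACCESS_TOKEN")  # placeholder token
print(client.get_servers())  # {server_name: {"name", "ready", "url", "last_activity"}}

client.start_server(environment="ai", server_request="medium-cpu")
client.send_channel_execute(messages=["print('Hello World!')"])  # runs code over the kernel websocket
client.stop_server()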

alphai/profilers/__init__.py
DELETED
File without changes

alphai/profilers/configs_base.py
DELETED

alphai/profilers/jax.py
DELETED
@@ -1,37 +0,0 @@
-# src/alphai/profilers/jax.py
-import os
-from dataclasses import dataclass, field
-import datetime
-
-from alphai.profilers.configs_base import BaseProfilerConfigs
-from jax.profiler import start_trace, stop_trace, trace
-
-
-@dataclass
-class JaxProfilerConfigs(BaseProfilerConfigs):
-    dir_path: str = "./.alphai"
-    create_perfetto_link: bool = False
-    create_perfetto_trace: bool = False
-
-    def as_dict(self):
-        return {key: vars(self)[key] for key in vars(self) if key != "dir_path"}
-
-
-class JaxProfiler:
-    def __init__(self, configs: JaxProfilerConfigs, **kwargs):
-        self.configs = configs
-        self.dir_path = configs.dir_path
-
-    def start(self, dir_name: str = None):
-        formatted_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        if not dir_name:
-            dir_name = f"jax_trace_{formatted_datetime}"
-        profiler_path = os.path.join(self.dir_path, dir_name)
-        start_trace(
-            log_dir=profiler_path,
-            create_perfetto_link=self.configs.create_perfetto_link,
-            create_perfetto_trace=self.configs.create_perfetto_trace,
-        )
-
-    def stop(self):
-        stop_trace()
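
For context, a minimal sketch of the removed JaxProfiler wrapper around jax.profiler.start_trace/stop_trace; the matrix product is an arbitrary workload chosen for illustration:

# Illustrative sketch only: exercises the deleted JaxProfiler class shown above.
import jax.numpy as jnp
from alphai.profilers.jax import JaxProfiler, JaxProfilerConfigs  # 0.0.7 module path

profiler = JaxProfiler(JaxProfilerConfigs(dir_path="./.alphai"))
profiler.start()  # trace lands in ./.alphai/jax_trace_<timestamp>

x = jnp.ones((1024, 1024))
(x @ x).block_until_ready()  # arbitrary workload captured by the trace

profiler.stop()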

alphai/profilers/pytorch.py
DELETED
@@ -1,83 +0,0 @@
-# src/alphai/profilers/pytorch.py
-import os
-import datetime
-import json
-from typing import Optional, Callable, Iterable, Any
-from dataclasses import dataclass, field
-
-import torch
-from torch.profiler import (
-    profile,
-    ProfilerActivity,
-    ProfilerAction,
-    _ExperimentalConfig,
-)
-
-from alphai.profilers.configs_base import BaseProfilerConfigs
-from alphai.profilers.pytorch_utils import _build_dataframe
-
-
-@dataclass
-class PyTorchProfilerConfigs(BaseProfilerConfigs):
-    dir_path: str = "./.alphai"
-    activities: Optional[Iterable[ProfilerActivity]] = field(
-        default_factory=lambda: [
-            ProfilerActivity.CPU,
-            ProfilerActivity.CUDA,
-        ]
-    )
-    schedule: Optional[Callable[[int], ProfilerAction]] = None
-    on_trace_ready: Optional[
-        Callable[..., Any]
-    ] = torch.profiler.tensorboard_trace_handler(dir_path)
-    record_shapes: bool = False # True
-    profile_memory: bool = True # True
-    with_stack: bool = False
-    with_flops: bool = True
-    with_modules: bool = False
-    experimental_config: Optional[_ExperimentalConfig] = None
-
-    def as_dict(self):
-        return {key: vars(self)[key] for key in vars(self) if key != "dir_path"}
-
-
-class PyTorchProfiler(profile):
-    def __init__(self, configs: PyTorchProfilerConfigs, **kwargs):
-        super().__init__(**configs.as_dict())
-        self.dir_path = configs.dir_path
-
-    def start(self, dir_name: str = None):
-        super().start()
-        if not self._get_distributed_info():
-            self.add_metadata_json("distributedInfo", json.dumps({"rank": 0}))
-        formatted_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        if not dir_name:
-            dir_name = f"pt_trace_{formatted_datetime}"
-        self.profiler_path = os.path.join(self.dir_path, dir_name)
-        self.on_trace_ready = torch.profiler.tensorboard_trace_handler(
-            self.profiler_path
-        )
-
-    def get_averages(
-        self,
-        sort_by="cuda_time_total",
-        header=None,
-        row_limit=100,
-        max_src_column_width=75,
-        max_name_column_width=55,
-        max_shapes_column_width=80,
-        top_level_events_only=False,
-        **kwargs,
-    ):
-        return _build_dataframe(
-            self.events().key_averages(),
-            sort_by=sort_by,
-            header=header,
-            row_limit=row_limit,
-            max_src_column_width=max_src_column_width,
-            max_name_column_width=max_name_column_width,
-            max_shapes_column_width=max_shapes_column_width,
-            with_flops=self.events()._with_flops,
-            profile_memory=self.events()._profile_memory,
-            top_level_events_only=top_level_events_only,
-        )