alphai 0.0.7-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphai/__init__.py +8 -4
- alphai/auth.py +362 -0
- alphai/cli.py +1015 -0
- alphai/client.py +400 -0
- alphai/config.py +88 -0
- alphai/docker.py +764 -0
- alphai/utils.py +192 -0
- alphai-0.1.0.dist-info/METADATA +394 -0
- alphai-0.1.0.dist-info/RECORD +12 -0
- {alphai-0.0.7.dist-info → alphai-0.1.0.dist-info}/WHEEL +2 -1
- alphai-0.1.0.dist-info/entry_points.txt +2 -0
- alphai-0.1.0.dist-info/top_level.txt +1 -0
- alphai/alphai.py +0 -786
- alphai/api/client.py +0 -0
- alphai/benchmarking/benchmarker.py +0 -37
- alphai/client/__init__.py +0 -0
- alphai/client/client.py +0 -382
- alphai/profilers/__init__.py +0 -0
- alphai/profilers/configs_base.py +0 -7
- alphai/profilers/jax.py +0 -37
- alphai/profilers/pytorch.py +0 -83
- alphai/profilers/pytorch_utils.py +0 -419
- alphai/util.py +0 -19
- alphai-0.0.7.dist-info/LICENSE +0 -201
- alphai-0.0.7.dist-info/METADATA +0 -125
- alphai-0.0.7.dist-info/RECORD +0 -16
alphai/alphai.py
DELETED
@@ -1,786 +0,0 @@
-import os
-from typing import List, Callable, Union
-import warnings
-import logging
-import json
-import gc
-import datetime
-
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-from hta.trace_analysis import TraceAnalysis
-
-from alphai.util import is_package_installed, extract_param_value
-from alphai.profilers.configs_base import BaseProfilerConfigs
-from alphai.benchmarking.benchmarker import Benchmarker
-from alphai.client.client import Client
-
-
-class AlphAI:
-    """
-    The AlphAI class provides a high-level interface for benchmarking, memory estimation,
-    and interaction with remote Jupyter Lab servers. It supports various tensor-based models
-    and integrates with American Data Science Labs for managing GPU resources.
-
-    Attributes:
-        output_path (str): The path where output files are stored.
-        supported_backends (List[str]): List of supported tensor backends (e.g., 'torch', 'jax').
-        profiler_started (bool): Flag to indicate if the profiler has started.
-        server_name (str): The name of the server for remote operations.
-        api_key (str): API key for authentication with remote services.
-        client (Client): Client instance for interacting with remote services.
-        pt_profiler (PyTorchProfiler): Profiler instance for PyTorch.
-        jax_profiler (JaxProfiler): Profiler instance for JAX.
-        benchmarker (Benchmarker): Benchmarker instance for performance measurements.
-        model (torch.nn.Module): The loaded PyTorch model.
-        model_name_or_path (str): The name or path of the model.
-    """
-
-    def __init__(
-        self,
-        *,
-        api_key: Union[str, None] = None,
-        organization: Union[str, None] = None,
-        base_url: str = None,
-        output_path: str = "./alphai_profiler_store",
-        server_name: str = "",
-        pt_profiler_configs: BaseProfilerConfigs = None,
-        jax_profiler_configs: BaseProfilerConfigs = None,
-        **kwargs,
-    ):
-        """
-        Initializes the AlphAI instance with provided configurations.
-
-        Args:
-            api_key (Union[str, None]): API key for authentication. If None, will try to read from environment.
-            organization (Union[str, None]): The name of the organization. If None, will try to read from environment.
-            base_url (str): The base URL for remote services. If None, defaults to a predefined URL.
-            output_path (str): The path where output files are stored. Defaults to './alphai_profiler_store'.
-            server_name (str): The name of the server for remote operations.
-            pt_profiler_configs (BaseProfilerConfigs): Configuration for the PyTorch profiler.
-            jax_profiler_configs (BaseProfilerConfigs): Configuration for the JAX profiler.
-        """
-
-        self.output_path = output_path
-        self.supported_backends = ["torch", "jax", "tensorflow"]
-        self.profiler_started = False
-        self.server_name = server_name
-
-        # Api
-        if api_key is None:
-            api_key = os.environ.get("ALPHAI_API_KEY")
-        if api_key is None:
-            logging.info(
-                "Optional: Set the API key api_key parameter init or by setting the ALPHAI_API_KEY environment variable"
-            )
-        self.api_key = api_key
-        if api_key:
-            self.client = Client(access_token=api_key)
-
-        if organization is None:
-            organization = os.environ.get("ALPHAI_ORGANIZATION_NAME")
-        self.organization = organization
-
-        if base_url is None:
-            base_url = os.environ.get("ALPHAI_BASE_URL")
-        if base_url is None:
-            base_url = f"https://lab.amdatascience.com"
-        self.base_url = base_url
-
-        # Directory ops
-        self.pt_trace_dirs = self.get_pt_traces()
-
-        # Profilers
-        self.dict_idle_time = None
-        self.dict_averages = None
-
-        if is_package_installed("torch") and not pt_profiler_configs:
-            from alphai.profilers.pytorch import PyTorchProfilerConfigs, PyTorchProfiler
-
-            pt_profiler_configs = PyTorchProfilerConfigs()
-            pt_profiler_configs.dir_path = output_path
-            self.pt_profiler = PyTorchProfiler(pt_profiler_configs)
-
-        if is_package_installed("jax") and not jax_profiler_configs:
-            from alphai.profilers.jax import JaxProfilerConfigs, JaxProfiler
-
-            jax_profiler_configs = JaxProfilerConfigs()
-            jax_profiler_configs.dir_path = output_path
-            self.jax_profiler = JaxProfiler(jax_profiler_configs)
-
-        # Benchmarker
-        self.benchmarker = Benchmarker()
-
-        # HF Generate
-        self.model_name_or_path = None
-        self.model = None
-
-    def start(self, tensor_backend: str = None):
-        """
-        Starts the profiler for the specified tensor backend.
-
-        Args:
-            tensor_backend (str): The backend to use for profiling ('torch', 'jax', 'tensorflow').
-                If None, defaults to an available backend.
-        """
-        # Handle if none, not installed, or unknown tensor_backend given
-        # Default to torch tensorbackend or whatever's available
-        if not tensor_backend:
-            if is_package_installed("torch"):
-                tensor_backend = "torch"
-            elif is_package_installed("jax"):
-                tensor_backend = "jax"
-            elif is_package_installed("tensorflow"):
-                tensor_backend = "tensorflow"
-            else:
-                warnings.warn(
-                    f"Tensor framework must first be installed from a supported library: {self.supported_backends} to enable profiling."
-                )
-                return
-        if tensor_backend not in self.supported_backends:
-            warnings.warn(
-                f"Tensor framework is not supported, must be one of {self.supported_backends} to enable profiling."
-            )
-            return
-        if not is_package_installed(tensor_backend):
-            warnings.warn(f"You need to install '{tensor_backend}' to start profiling")
-
-        if tensor_backend == "torch":
-            try:
-                self.pt_profiler.start()
-            except:
-                # Try to stop hanging profiler and try again
-                self.pt_profiler.stop()
-                self.pt_profiler.start()
-        elif tensor_backend == "jax":
-            try:
-                self.jax_profiler.start()
-            except:
-                # Try to stop hanging profiler and try again
-                self.jax_profiler.stop()
-                self.jax_profiler.start()
-        elif tensor_backend == "tensorflow":
-            pass
-
-        self.tensor_backend = tensor_backend
-        self.profiler_started = True
-
-    def stop(self):
-        """
-        Stops the currently running profiler.
-        """
-        if not self.profiler_started or not self.tensor_backend:
-            warnings.warn(f"Profiler never started")
-            return
-
-        if self.tensor_backend == "torch":
-            self.pt_profiler.stop()
-        elif self.tensor_backend == "jax":
-            self.jax_profiler.stop()
-        elif self.tensor_backend == "tensorflow":
-            pass
-
-        self.profiler_started = False
-
-    def step(self):
-        """
-        Advances the profiler by one step. Mainly used for the PyTorch profiler.
-        """
-        self.pt_profiler.step()
-
-    def __call__(self, tensor_backend: str = None):
-        # Allows for param in context manager
-        # self.tensor_backend only set with context manager or in start()
-        self.tensor_backend = tensor_backend
-        return self
-
-    def __enter__(self):
-        self.start(tensor_backend=self.tensor_backend)
-
-    def __exit__(self, exc_type, exc_val, exc_t):
-        self.stop()
-
-    # API Methods
-    def get_servers(self):
-        """
-        Retrieves the list of available servers from the remote service.
-
-        Returns:
-            A list of servers if successful, or raises an exception if the user is not authenticated.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        return self.client.get_servers()
-
-    def start_server(
-        self,
-        server_name: str = None,
-        environment: str = "ai",
-        server_request: str = "medium-cpu",
-    ):
-        """
-        Starts a server with the given name.
-
-        Args:
-            server_name (str): The name of the server to start. If None, uses the server name set in the instance.
-
-        Returns:
-            Response from the server start request.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        # Use set self.server_name if not provided
-        if server_name is None:
-            server_name = self.server_name
-        return self.client.start_server(server_name=server_name, environment=environment, server_request=server_request)
-
-    def stop_server(self, server_name: str = None):
-        """
-        Stops a server with the given name.
-
-        Args:
-            server_name (str): The name of the server to stop. If None, uses the server name set in the instance.
-
-        Returns:
-            Response from the server stop request.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        # Use set self.server_name if not provided
-        if server_name is None:
-            server_name = self.server_name
-        return self.client.stop_server(server_name=server_name)
-
-    def alph(
-        self,
-        server_name: str = None,
-        messages: str = "ls",
-        engine: str = "gpt3",
-    ):
-        """
-        Gives alph commands to help you and run on the server.
-
-        Args:
-            server_name (str): The name of the server to stop. If None, uses the server name set in the instance.
-
-        Returns:
-            Response from the server stop request.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        # Use set self.server_name if not provided
-        if server_name is None:
-            server_name = self.server_name
-        return self.client.alph(server_name=server_name, messages=messages, engine=engine)
-
-    def upload(self, server_name: str = None, file_path: str = "", remote_path=""):
-        """
-        Uploads a file to a remote server.
-
-        Args:
-            server_name (str): The name of the server to which the file will be uploaded. If None, uses the server name set in the instance.
-            file_path (str): The local path to the file.
-            remote_path (str): The remote path where the file will be stored.
-
-        Returns:
-            The response from the upload request.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        # Use set self.server_name if not provided
-        if server_name is None:
-            server_name = self.server_name
-        return self.client.put_contents(
-            server_name=server_name, path=remote_path, file_path=file_path
-        )
-
-    def run_code(
-        self,
-        code: str = "print('Hello world!')",
-        server_name: str = None,
-        clear_other_kernels: bool = True,
-        return_full: bool = False,
-    ):
-        """
-        Executes the given code on a remote server.
-
-        Args:
-            code (str): The code to execute. If a file path is provided, the code in the file is executed.
-            server_name (str): The name of the server where the code will be executed. If None, uses the server name set in the instance.
-            clear_other_kernels (bool): Whether to shut down other kernels on the server before executing the code.
-            return_full (bool): Whether to return the full response from the server.
-
-        Returns:
-            The output from the code execution.
-        """
-        # Use set self.server_name if not provided
-        if server_name is None:
-            server_name = self.server_name
-        if clear_other_kernels:
-            self.client.shutdown_all_kernels(server_name=server_name)
-        if os.path.isfile(code):
-            if os.path.splitext(code)[1] != ".py":
-                warnings.warn(
-                    "This doesn't seem to be a python file, but will try to run it anyway."
-                )
-            with open(code, "r") as f:
-                code = f.read()
-        return self.client.send_channel_execute(
-            server_name=server_name, messages=[code], return_full=return_full
-        )
-
-    def get_service(self, server_name: str = None):
-        """
-        Retrieves the service URL for a running service or app on the server.
-
-        Args:
-            server_name (str): The name of the server. If None, uses the server name set in the instance.
-
-        Returns:
-            The URL to access the running service or app on the server.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        if server_name is None:
-            server_name = self.server_name
-        return f"If you have running service or app in your server, check it out here -> {self.client.get_service(server_name=server_name)}"
-
-    # Profilers
-    def get_profiler_stats(self):
-        """
-        Retrieves statistics from the PyTorch profiler.
-
-        Returns:
-            A table containing key averages of profiler statistics, particularly focusing on CUDA time.
-        """
-        stat_table = self.pt_profiler.key_averages().table(
-            sort_by="cuda_time_total", row_limit=10
-        )
-        return stat_table
-
-    def get_averages(
-        self,
-        sort_by="cuda_time_total",
-        header=None,
-        row_limit=100,
-        max_src_column_width=75,
-        max_name_column_width=55,
-        max_shapes_column_width=80,
-        top_level_events_only=False,
-    ):
-        """
-        Retrieves a DataFrame of average statistics from the PyTorch profiler powered by Kineto.
-
-        Args:
-            sort_by (str): The attribute to sort the data by. Defaults to 'cuda_time_total'.
-            header (str, optional): Header for the DataFrame. Defaults to None.
-            row_limit (int): The maximum number of rows to return. Defaults to 100.
-            max_src_column_width (int): Maximum width for the source column. Defaults to 75.
-            max_name_column_width (int): Maximum width for the name column. Defaults to 55.
-            max_shapes_column_width (int): Maximum width for the shapes column. Defaults to 80.
-            top_level_events_only (bool): Whether to include only top-level events. Defaults to False.
-
-        Returns:
-            pandas.DataFrame: A DataFrame containing the averaged profiler statistics.
-        """
-        df_averages, self.dict_averages, str_averages = self.pt_profiler.get_averages(
-            sort_by="cuda_time_total",
-            header=None,
-            row_limit=100,
-            max_src_column_width=75,
-            max_name_column_width=55,
-            max_shapes_column_width=80,
-            top_level_events_only=False,
-        )
-        return df_averages
-
-    def run_profiler_analysis(self, trace_path: str = None, visualize: bool = False):
-        """
-        Runs an analysis of the profiler data and optionally visualizes the results.
-
-        Args:
-            trace_path (str, optional): The path to the trace data. If None, uses the latest trace. Defaults to None.
-            visualize (bool): Whether to visualize the analysis results. Defaults to False.
-
-        Returns:
-            A tuple of DataFrames containing various analysis results, such as idle time, temporal breakdown, and GPU kernel breakdown.
-        """
-        if trace_path:
-            pt_trace_dirs = [trace_path]
-        else:
-            pt_trace_dirs = self.get_pt_traces()
-        if pt_trace_dirs:
-            try:
-                trace_dir = os.path.join(self.pt_profiler.dir_path, pt_trace_dirs[-1])
-                self.analyzer = TraceAnalysis(trace_dir=trace_dir)
-                idle_time_df = self.analyzer.get_idle_time_breakdown(
-                    show_idle_interval_stats=True, visualize=visualize
-                )
-                time_spent_df = self.analyzer.get_temporal_breakdown(
-                    visualize=visualize
-                )
-                (
-                    kernel_type_metrics_df,
-                    kernel_metrics_df,
-                ) = self.analyzer.get_gpu_kernel_breakdown()
-                self.dict_idle_time = idle_time_df[0].to_dict()
-                self.dict_time_spent = time_spent_df.to_dict()
-                self.dict_type_metrics = kernel_type_metrics_df.to_dict()
-                self.dict_kernel_metrics = kernel_metrics_df.to_dict()
-                return (
-                    idle_time_df,
-                    time_spent_df,
-                    kernel_type_metrics_df,
-                    kernel_metrics_df,
-                )
-            except:
-                warnings.warn(
-                    "Error running profiler analysis, may not have GPU trace data so will continue without it."
-                )
-                self.dict_idle_time = {}
-                self.dict_time_spent = {}
-                self.dict_type_metrics = {}
-                self.dict_kernel_metrics = {}
-                return
-
-    def save(self, return_results: bool = False):
-        """
-        Saves the profiler data and analysis results to a specified directory.
-
-        Args:
-            return_results (bool): Whether to return the saved data as a dictionary. Defaults to False.
-
-        Returns:
-            dict (optional): A dictionary containing the saved data if return_results is True.
-        """
-        alphai_dict = {}
-        if self.dict_idle_time is None:
-            warnings.warn(
-                "Make sure to run_profiler_analysis() before saving to your analytics."
-            )
-            self.run_profiler_analysis()
-        self.get_averages()
-        alphai_dict["metadata"] = self.analyzer.t.meta_data
-        alphai_dict["idle_time"] = self.dict_idle_time
-        alphai_dict["time_spent"] = self.dict_time_spent
-        alphai_dict["type_metrics"] = self.dict_type_metrics
-        alphai_dict["kernel_metrics"] = self.dict_kernel_metrics
-        alphai_dict["key_averages"] = self.dict_averages
-        with open(
-            os.path.join(self.pt_profiler.profiler_path, "profiling.alphai"), "w"
-        ) as f:
-            json.dump(alphai_dict, f, indent=4)
-        if return_results:
-            return alphai_dict
-
-    def load_view(self, dir_name: str = None):
-        """
-        Loads a view of the profiler data onto your remote server.
-
-        Args:
-            dir_name (str, optional): The directory name to load the view from. If None, generates a timestamp-based directory name. Defaults to None.
-
-        Returns:
-            str: A URL to the GPU usage statistics dashboard.
-        """
-        if not self.api_key:
-            raise ValueError("Requires user authentication with an API Key")
-        formatted_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        if not dir_name:
-            view_path = f"{formatted_datetime}.alphai"
-        else:
-            view_path = dir_name
-        self.client.post_contents(path="", ext=".alphai", type="directory")
-        self.client.patch_contents(path="Untitled Folder.alphai", new_path=view_path)
-        self.client.put_contents(
-            path=view_path,
-            file_path=f"{self.pt_profiler.profiler_path}/profiling.alphai",
-        )
-        return f"Check out your GPU usage statistics at -> https://dashboard.amdatascience.com/agent-alph"
-
-    def get_pt_traces(self):
-        """
-        Retrieves a list of PyTorch trace directories sorted by date.
-
-        Returns:
-            List[str]: A list of directory names containing PyTorch traces.
-        """
-        # List all items in the directory
-        directory_path = self.output_path
-        if not os.path.isdir(directory_path):
-            return []
-        all_items = os.listdir(directory_path)
-
-        # Filter out items that are directories and follow the naming pattern
-        date_directories = []
-        for item in all_items:
-            if os.path.isdir(os.path.join(directory_path, item)) and item.startswith(
-                "pt_trace_"
-            ):
-                # Extract the date and time part from the folder name
-                datetime_part = item.split("pt_trace_")[1]
-
-                # Parse the date and time into a datetime object
-                try:
-                    folder_date = datetime.datetime.strptime(
-                        datetime_part, "%Y-%m-%d_%H-%M-%S"
-                    )
-                    date_directories.append((item, folder_date))
-                except ValueError:
-                    # Handle cases where the date format is incorrect or different
-                    print(f"Skipping {item} due to unexpected date format.")
-
-        # Sort the directories by the parsed datetime
-        date_directories.sort(key=lambda x: x[1])
-
-        # Return only the directory names, in sorted order
-        return [name for name, date in date_directories]
-
-    def get_jax_traces(self):
-        """
-        Retrieves a list of JAX trace directories sorted by date.
-
-        Returns:
-            List[str]: A list of directory names containing JAX traces.
-        """
-        # List all items in the directory
-        directory_path = self.output_path
-        if not os.path.isdir(directory_path):
-            return []
-        all_items = os.listdir(directory_path)
-
-        # Filter out items that are directories and follow the naming pattern
-        date_directories = []
-        for item in all_items:
-            if os.path.isdir(os.path.join(directory_path, item)) and item.startswith(
-                "jax_trace_"
-            ):
-                # Extract the date and time part from the folder name
-                datetime_part = item.split("jax_trace_")[1]
-
-                # Parse the date and time into a datetime object
-                try:
-                    folder_date = datetime.datetime.strptime(
-                        datetime_part, "%Y-%m-%d_%H-%M-%S"
-                    )
-                    date_directories.append((item, folder_date))
-                except ValueError:
-                    # Handle cases where the date format is incorrect or different
-                    print(f"Skipping {item} due to unexpected date format.")
-
-        # Sort the directories by the parsed datetime
-        date_directories.sort(key=lambda x: x[1])
-
-        # Return only the directory names, in sorted order
-        return [name for name, date in date_directories]
-
-    # Benchmarker
-    def start_timer(self):
-        """
-        Starts the benchmarking timer.
-        """
-        self.benchmarker.start()
-
-    def stop_timer(self, print_results: bool = True):
-        """
-        Stops the timer and optionally prints the results.
-
-        Args:
-            print_results (bool): Whether to print the results. Defaults to True.
-
-        Returns:
-            The results of the benchmark.
-        """
-        return self.benchmarker.stop()
-
-    def benchmark(
-        self,
-        function: Callable = None,
-        *args,
-        num_iter: int = 100,
-        print_results: bool = True,
-        **kwargs,
-    ):
-        """
-        Benchmarks a function by running it a specified number of times.
-
-        Args:
-            function (Callable): The function to benchmark.
-            *args: The arguments to pass to the function.
-            num_iter (int): The number of times to run the function. Defaults to 100.
-            print_results (bool): Whether to print the results. Defaults to True.
-            **kwargs: The keyword arguments to pass to the function.
-
-        Returns:
-            The results of the benchmark.
-        """
-        return self.benchmarker.benchmark(
-            function, *args, num_iter=num_iter, print_results=print_results, **kwargs
-        )
-
-    # Hugging Face utility
-
-    def estimate_memory_requirement(
-        self,
-        model_name: str = "stabilityai/stablelm-zephyr-3b",
-    ):
-        """
-        Estimates the memory requirement for a given model.
-
-        Args:
-            model_name (str): The name of the model. Defaults to "stabilityai/stablelm-zephyr-3b".
-
-        Returns:
-            A dictionary with the model name and the estimated memory requirement in MB and GB.
-        """
-        try:
-            param_value = extract_param_value(model_name)
-            megabyte_value = param_value * 2 * 1000
-            gigabyte_value = param_value * 2
-            print(
-                f"Estimated memory requirement assuming float16 dtype for {model_name}: {megabyte_value:.2f} MB or {gigabyte_value:.2f} GB"
-            )
-            return {
-                "model_name_or_path": model_name,
-                "estimate_memory_requirement_mb_float16": f"{megabyte_value:.2f} MB",
-                "estimate_memory_requirement_gb_float16": f"{gigabyte_value:.2f} GB",
-            }
-        except:
-            warnings.warn(
-                "Error parsing model name or path, can't estimate memory requirement."
-            )
-            return
-
-    def memory_requirement(
-        self,
-        model_name_or_path: str = "stabilityai/stablelm-zephyr-3b",
-        device: str = "cuda",
-        trust_remote_code=True,
-        torch_dtype="auto",
-    ):
-        """
-        Estimates and prints the memory requirement for a specified model.
-
-        Args:
-            model_name_or_path (str): The name or path of the model to be loaded. Defaults to 'stabilityai/stablelm-zephyr-3b'.
-            device (str): The device to load the model on ('cuda' or 'cpu'). Defaults to 'cuda'.
-            trust_remote_code (bool): Whether to trust remote code when loading the model. Defaults to True.
-            torch_dtype (str): The data type for the model parameters. Defaults to 'auto'.
-
-        Returns:
-            dict: A dictionary containing the memory requirement in MB and GB.
-        """
-        if not is_package_installed("torch"):
-            warnings.warn(f"You need to install 'torch' to run memory_requirement")
-            return
-        if not self.model_name_or_path or self.model_name_or_path != model_name_or_path:
-            self.model_name_or_path = model_name_or_path
-            try:
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=trust_remote_code,
-                    torch_dtype=torch_dtype,
-                ).to(device)
-            except:
-                warnings.warn(
-                    "Loading model to CPU instead of GPU since GPU is not available."
-                )
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=trust_remote_code,
-                    torch_dtype=torch_dtype,
-                ).to("cpu")
-        try:
-            param_value = self.model.num_parameters()
-        except:
-            param_value = sum(p.numel() for p in self.model.parameters())
-
-        megabyte_value = param_value * 2 / 1000000
-        gigabyte_value = param_value * 2 / 1000000000
-        print(
-            f"Memory requirement assuming float16 dtype for {model_name_or_path}: {megabyte_value:.2f} MB or {gigabyte_value:.2f} GB"
-        )
-        return {
-            "model_name_or_path": model_name_or_path,
-            "memory_requirement_mb_float16": f"{megabyte_value:.2f} MB",
-            "memory_requirement_gb_float16": f"{gigabyte_value:.2f} GB",
-        }
-
-    def generate(
-        self,
-        text: str = "",
-        prompt: List[dict] = None,
-        model_name_or_path: str = "stabilityai/stablelm-zephyr-3b",
-        trust_remote_code=True,
-        torch_dtype="auto",
-        stream: bool = True,
-        device: str = "cuda",
-        **kwargs,
-    ):
-        """
-        Generates text using the specified model based on the given prompt or text.
-
-        Args:
-            text (str): The text to be used as a prompt. Defaults to an empty string.
-            prompt (List[dict]): A list of dictionaries defining the prompt structure. Defaults to None.
-            model_name_or_path (str): The name or path of the model to be used. Defaults to 'stabilityai/stablelm-zephyr-3b'.
-            trust_remote_code (bool): Whether to trust remote code when loading the model. Defaults to True.
-            torch_dtype (str): The data type for the model parameters. Defaults to 'auto'.
-            stream (bool): Whether to use streaming for text generation. Defaults to True.
-            device (str): The device to run the model on. Defaults to 'cuda'.
-
-        Returns:
-            str: The generated text.
-        """
-        if not is_package_installed("torch"):
-            warnings.warn(f"You need to install 'torch' to run generate")
-            return
-        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        streamer = TextStreamer(tokenizer) if stream else None
-        if not self.model_name_or_path or self.model_name_or_path != model_name_or_path:
-            self.model_name_or_path = model_name_or_path
-            try:
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=trust_remote_code,
-                    torch_dtype=torch_dtype,
-                ).to(device)
-            except:
-                warnings.warn(
-                    "Loading model to CPU instead of GPU since GPU is not available."
-                )
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=trust_remote_code,
-                    torch_dtype=torch_dtype,
-                ).to("cpu")
-
-        if not prompt:
-            prompt = [{"role": "user", "content": text}]
-        inputs = tokenizer.apply_chat_template(
-            prompt, add_generation_prompt=True, return_tensors="pt"
-        )
-
-        tokens = self.model.generate(
-            inputs.to(self.model.device),
-            max_new_tokens=1024,
-            temperature=0.8,
-            do_sample=True,
-            streamer=streamer,
-            **kwargs,
-        )
-
-        return tokenizer.decode(tokens[0])
-
-    def clear_cuda_memory(self):
-        """
-        Clears the CUDA memory cache to free up GPU memory.
-
-        Raises:
-            Warning: If PyTorch is not installed.
-        """
-        if not is_package_installed("torch"):
-            warnings.warn(f"You need to install 'torch' to run clear_cuda_memory")
-            return
-        gc.collect()
-        torch.cuda.empty_cache()