alphai-0.0.7-py3-none-any.whl → alphai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alphai/alphai.py DELETED
@@ -1,786 +0,0 @@
- import os
- from typing import List, Callable, Union
- import warnings
- import logging
- import json
- import gc
- import datetime
-
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
- from hta.trace_analysis import TraceAnalysis
-
- from alphai.util import is_package_installed, extract_param_value
- from alphai.profilers.configs_base import BaseProfilerConfigs
- from alphai.benchmarking.benchmarker import Benchmarker
- from alphai.client.client import Client
-
-
- class AlphAI:
-     """
-     The AlphAI class provides a high-level interface for benchmarking, memory estimation,
-     and interaction with remote Jupyter Lab servers. It supports various tensor-based models
-     and integrates with American Data Science Labs for managing GPU resources.
-
-     Attributes:
-         output_path (str): The path where output files are stored.
-         supported_backends (List[str]): List of supported tensor backends (e.g., 'torch', 'jax').
-         profiler_started (bool): Flag to indicate if the profiler has started.
-         server_name (str): The name of the server for remote operations.
-         api_key (str): API key for authentication with remote services.
-         client (Client): Client instance for interacting with remote services.
-         pt_profiler (PyTorchProfiler): Profiler instance for PyTorch.
-         jax_profiler (JaxProfiler): Profiler instance for JAX.
-         benchmarker (Benchmarker): Benchmarker instance for performance measurements.
-         model (torch.nn.Module): The loaded PyTorch model.
-         model_name_or_path (str): The name or path of the model.
-     """
-
-     def __init__(
-         self,
-         *,
-         api_key: Union[str, None] = None,
-         organization: Union[str, None] = None,
-         base_url: str = None,
-         output_path: str = "./alphai_profiler_store",
-         server_name: str = "",
-         pt_profiler_configs: BaseProfilerConfigs = None,
-         jax_profiler_configs: BaseProfilerConfigs = None,
-         **kwargs,
-     ):
-         """
-         Initializes the AlphAI instance with the provided configurations.
-
-         Args:
-             api_key (Union[str, None]): API key for authentication. If None, will try to read from the environment.
-             organization (Union[str, None]): The name of the organization. If None, will try to read from the environment.
-             base_url (str): The base URL for remote services. If None, defaults to a predefined URL.
-             output_path (str): The path where output files are stored. Defaults to './alphai_profiler_store'.
-             server_name (str): The name of the server for remote operations.
-             pt_profiler_configs (BaseProfilerConfigs): Configuration for the PyTorch profiler.
-             jax_profiler_configs (BaseProfilerConfigs): Configuration for the JAX profiler.
-         """
-
-         self.output_path = output_path
-         self.supported_backends = ["torch", "jax", "tensorflow"]
-         self.profiler_started = False
-         self.tensor_backend = None  # set via __call__() or start()
-         self.server_name = server_name
-
-         # API
-         if api_key is None:
-             api_key = os.environ.get("ALPHAI_API_KEY")
-         if api_key is None:
-             logging.info(
-                 "Optional: set an API key via the api_key init parameter or the ALPHAI_API_KEY environment variable"
-             )
-         self.api_key = api_key
-         if api_key:
-             self.client = Client(access_token=api_key)
-
-         if organization is None:
-             organization = os.environ.get("ALPHAI_ORGANIZATION_NAME")
-         self.organization = organization
-
-         if base_url is None:
-             base_url = os.environ.get("ALPHAI_BASE_URL")
-         if base_url is None:
-             base_url = "https://lab.amdatascience.com"
-         self.base_url = base_url
-
-         # Directory ops
-         self.pt_trace_dirs = self.get_pt_traces()
-
-         # Profilers
-         self.dict_idle_time = None
-         self.dict_averages = None
-
-         if is_package_installed("torch") and not pt_profiler_configs:
-             from alphai.profilers.pytorch import PyTorchProfilerConfigs, PyTorchProfiler
-
-             pt_profiler_configs = PyTorchProfilerConfigs()
-             pt_profiler_configs.dir_path = output_path
-             self.pt_profiler = PyTorchProfiler(pt_profiler_configs)
-
-         if is_package_installed("jax") and not jax_profiler_configs:
-             from alphai.profilers.jax import JaxProfilerConfigs, JaxProfiler
-
-             jax_profiler_configs = JaxProfilerConfigs()
-             jax_profiler_configs.dir_path = output_path
-             self.jax_profiler = JaxProfiler(jax_profiler_configs)
-
-         # Benchmarker
-         self.benchmarker = Benchmarker()
-
-         # HF Generate
-         self.model_name_or_path = None
-         self.model = None
-
-     def start(self, tensor_backend: str = None):
-         """
-         Starts the profiler for the specified tensor backend.
-
-         Args:
-             tensor_backend (str): The backend to use for profiling ('torch', 'jax', 'tensorflow').
-                 If None, defaults to an available backend.
-         """
-         # Handle a None, not-installed, or unknown tensor_backend:
-         # default to the torch backend, or whatever is available.
-         if not tensor_backend:
-             if is_package_installed("torch"):
-                 tensor_backend = "torch"
-             elif is_package_installed("jax"):
-                 tensor_backend = "jax"
-             elif is_package_installed("tensorflow"):
-                 tensor_backend = "tensorflow"
-             else:
-                 warnings.warn(
-                     f"A tensor framework from a supported library ({self.supported_backends}) must first be installed to enable profiling."
-                 )
-                 return
-         if tensor_backend not in self.supported_backends:
-             warnings.warn(
-                 f"Tensor framework is not supported; it must be one of {self.supported_backends} to enable profiling."
-             )
-             return
-         if not is_package_installed(tensor_backend):
-             warnings.warn(f"You need to install '{tensor_backend}' to start profiling")
-             return
-
-         if tensor_backend == "torch":
-             try:
-                 self.pt_profiler.start()
-             except Exception:
-                 # Try to stop a hanging profiler and start again
-                 self.pt_profiler.stop()
-                 self.pt_profiler.start()
-         elif tensor_backend == "jax":
-             try:
-                 self.jax_profiler.start()
-             except Exception:
-                 # Try to stop a hanging profiler and start again
-                 self.jax_profiler.stop()
-                 self.jax_profiler.start()
-         elif tensor_backend == "tensorflow":
-             pass
-
-         self.tensor_backend = tensor_backend
-         self.profiler_started = True
-
-     def stop(self):
-         """
-         Stops the currently running profiler.
-         """
-         if not self.profiler_started or not self.tensor_backend:
-             warnings.warn("Profiler was never started")
-             return
-
-         if self.tensor_backend == "torch":
-             self.pt_profiler.stop()
-         elif self.tensor_backend == "jax":
-             self.jax_profiler.stop()
-         elif self.tensor_backend == "tensorflow":
-             pass
-
-         self.profiler_started = False
-
-     def step(self):
-         """
-         Advances the profiler by one step. Mainly used for the PyTorch profiler.
-         """
-         self.pt_profiler.step()
-
-     def __call__(self, tensor_backend: str = None):
-         # Allows passing a backend parameter via the context manager;
-         # self.tensor_backend is only set here or in start()
-         self.tensor_backend = tensor_backend
-         return self
-
-     def __enter__(self):
-         self.start(tensor_backend=self.tensor_backend)
-         return self
-
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         self.stop()
-
-     # API Methods
-     def get_servers(self):
-         """
-         Retrieves the list of available servers from the remote service.
-
-         Returns:
-             A list of servers if successful, or raises an exception if the user is not authenticated.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         return self.client.get_servers()
-
-     def start_server(
-         self,
-         server_name: str = None,
-         environment: str = "ai",
-         server_request: str = "medium-cpu",
-     ):
-         """
-         Starts a server with the given name.
-
-         Args:
-             server_name (str): The name of the server to start. If None, uses the server name set in the instance.
-             environment (str): The environment to launch. Defaults to 'ai'.
-             server_request (str): The server resources to request. Defaults to 'medium-cpu'.
-
-         Returns:
-             Response from the server start request.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         # Use the instance's server_name if not provided
-         if server_name is None:
-             server_name = self.server_name
-         return self.client.start_server(
-             server_name=server_name, environment=environment, server_request=server_request
-         )
-
-     def stop_server(self, server_name: str = None):
-         """
-         Stops a server with the given name.
-
-         Args:
-             server_name (str): The name of the server to stop. If None, uses the server name set in the instance.
-
-         Returns:
-             Response from the server stop request.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         # Use the instance's server_name if not provided
-         if server_name is None:
-             server_name = self.server_name
-         return self.client.stop_server(server_name=server_name)
-
-     def alph(
-         self,
-         server_name: str = None,
-         messages: str = "ls",
-         engine: str = "gpt3",
-     ):
-         """
-         Sends alph commands to help you and runs them on the server.
-
-         Args:
-             server_name (str): The name of the server to run on. If None, uses the server name set in the instance.
-             messages (str): The command or message to send. Defaults to 'ls'.
-             engine (str): The engine used to interpret the command. Defaults to 'gpt3'.
-
-         Returns:
-             Response from the alph request.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         # Use the instance's server_name if not provided
-         if server_name is None:
-             server_name = self.server_name
-         return self.client.alph(server_name=server_name, messages=messages, engine=engine)
-
-     def upload(self, server_name: str = None, file_path: str = "", remote_path=""):
-         """
-         Uploads a file to a remote server.
-
-         Args:
-             server_name (str): The name of the server to which the file will be uploaded. If None, uses the server name set in the instance.
-             file_path (str): The local path to the file.
-             remote_path (str): The remote path where the file will be stored.
-
-         Returns:
-             The response from the upload request.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         # Use the instance's server_name if not provided
-         if server_name is None:
-             server_name = self.server_name
-         return self.client.put_contents(
-             server_name=server_name, path=remote_path, file_path=file_path
-         )
-
-     def run_code(
-         self,
-         code: str = "print('Hello world!')",
-         server_name: str = None,
-         clear_other_kernels: bool = True,
-         return_full: bool = False,
-     ):
-         """
-         Executes the given code on a remote server.
-
-         Args:
-             code (str): The code to execute. If a file path is provided, the code in the file is executed.
-             server_name (str): The name of the server where the code will be executed. If None, uses the server name set in the instance.
-             clear_other_kernels (bool): Whether to shut down other kernels on the server before executing the code.
-             return_full (bool): Whether to return the full response from the server.
-
-         Returns:
-             The output from the code execution.
-         """
-         # Use the instance's server_name if not provided
-         if server_name is None:
-             server_name = self.server_name
-         if clear_other_kernels:
-             self.client.shutdown_all_kernels(server_name=server_name)
-         if os.path.isfile(code):
-             if os.path.splitext(code)[1] != ".py":
-                 warnings.warn(
-                     "This doesn't seem to be a Python file, but will try to run it anyway."
-                 )
-             with open(code, "r") as f:
-                 code = f.read()
-         return self.client.send_channel_execute(
-             server_name=server_name, messages=[code], return_full=return_full
-         )
-
-     def get_service(self, server_name: str = None):
-         """
-         Retrieves the service URL for a running service or app on the server.
-
-         Args:
-             server_name (str): The name of the server. If None, uses the server name set in the instance.
-
-         Returns:
-             The URL to access the running service or app on the server.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         if server_name is None:
-             server_name = self.server_name
-         return f"If you have a running service or app on your server, check it out here -> {self.client.get_service(server_name=server_name)}"
-
-     # Profilers
-     def get_profiler_stats(self):
-         """
-         Retrieves statistics from the PyTorch profiler.
-
-         Returns:
-             A table containing key averages of profiler statistics, particularly focusing on CUDA time.
-         """
-         stat_table = self.pt_profiler.key_averages().table(
-             sort_by="cuda_time_total", row_limit=10
-         )
-         return stat_table
-
-     def get_averages(
-         self,
-         sort_by="cuda_time_total",
-         header=None,
-         row_limit=100,
-         max_src_column_width=75,
-         max_name_column_width=55,
-         max_shapes_column_width=80,
-         top_level_events_only=False,
-     ):
-         """
-         Retrieves a DataFrame of average statistics from the PyTorch profiler powered by Kineto.
-
-         Args:
-             sort_by (str): The attribute to sort the data by. Defaults to 'cuda_time_total'.
-             header (str, optional): Header for the DataFrame. Defaults to None.
-             row_limit (int): The maximum number of rows to return. Defaults to 100.
-             max_src_column_width (int): Maximum width for the source column. Defaults to 75.
-             max_name_column_width (int): Maximum width for the name column. Defaults to 55.
-             max_shapes_column_width (int): Maximum width for the shapes column. Defaults to 80.
-             top_level_events_only (bool): Whether to include only top-level events. Defaults to False.
-
-         Returns:
-             pandas.DataFrame: A DataFrame containing the averaged profiler statistics.
-         """
-         # Pass the arguments through rather than hardcoding the defaults
-         df_averages, self.dict_averages, str_averages = self.pt_profiler.get_averages(
-             sort_by=sort_by,
-             header=header,
-             row_limit=row_limit,
-             max_src_column_width=max_src_column_width,
-             max_name_column_width=max_name_column_width,
-             max_shapes_column_width=max_shapes_column_width,
-             top_level_events_only=top_level_events_only,
-         )
-         return df_averages
-
-     def run_profiler_analysis(self, trace_path: str = None, visualize: bool = False):
-         """
-         Runs an analysis of the profiler data and optionally visualizes the results.
-
-         Args:
-             trace_path (str, optional): The path to the trace data. If None, uses the latest trace. Defaults to None.
-             visualize (bool): Whether to visualize the analysis results. Defaults to False.
-
-         Returns:
-             A tuple of DataFrames containing various analysis results, such as idle time, temporal breakdown, and GPU kernel breakdown.
-         """
-         if trace_path:
-             pt_trace_dirs = [trace_path]
-         else:
-             pt_trace_dirs = self.get_pt_traces()
-         if pt_trace_dirs:
-             try:
-                 trace_dir = os.path.join(self.pt_profiler.dir_path, pt_trace_dirs[-1])
-                 self.analyzer = TraceAnalysis(trace_dir=trace_dir)
-                 idle_time_df = self.analyzer.get_idle_time_breakdown(
-                     show_idle_interval_stats=True, visualize=visualize
-                 )
-                 time_spent_df = self.analyzer.get_temporal_breakdown(
-                     visualize=visualize
-                 )
-                 (
-                     kernel_type_metrics_df,
-                     kernel_metrics_df,
-                 ) = self.analyzer.get_gpu_kernel_breakdown()
-                 self.dict_idle_time = idle_time_df[0].to_dict()
-                 self.dict_time_spent = time_spent_df.to_dict()
-                 self.dict_type_metrics = kernel_type_metrics_df.to_dict()
-                 self.dict_kernel_metrics = kernel_metrics_df.to_dict()
-                 return (
-                     idle_time_df,
-                     time_spent_df,
-                     kernel_type_metrics_df,
-                     kernel_metrics_df,
-                 )
-             except Exception:
-                 warnings.warn(
-                     "Error running profiler analysis; there may be no GPU trace data, so continuing without it."
-                 )
-                 self.dict_idle_time = {}
-                 self.dict_time_spent = {}
-                 self.dict_type_metrics = {}
-                 self.dict_kernel_metrics = {}
-         return
-
-     def save(self, return_results: bool = False):
-         """
-         Saves the profiler data and analysis results to a specified directory.
-
-         Args:
-             return_results (bool): Whether to return the saved data as a dictionary. Defaults to False.
-
-         Returns:
-             dict (optional): A dictionary containing the saved data if return_results is True.
-         """
-         alphai_dict = {}
-         if self.dict_idle_time is None:
-             warnings.warn(
-                 "Make sure to run run_profiler_analysis() before saving your analytics."
-             )
-             self.run_profiler_analysis()
-             self.get_averages()
-         alphai_dict["metadata"] = self.analyzer.t.meta_data
-         alphai_dict["idle_time"] = self.dict_idle_time
-         alphai_dict["time_spent"] = self.dict_time_spent
-         alphai_dict["type_metrics"] = self.dict_type_metrics
-         alphai_dict["kernel_metrics"] = self.dict_kernel_metrics
-         alphai_dict["key_averages"] = self.dict_averages
-         with open(
-             os.path.join(self.pt_profiler.profiler_path, "profiling.alphai"), "w"
-         ) as f:
-             json.dump(alphai_dict, f, indent=4)
-         if return_results:
-             return alphai_dict
-
-     def load_view(self, dir_name: str = None):
-         """
-         Loads a view of the profiler data onto your remote server.
-
-         Args:
-             dir_name (str, optional): The directory name to load the view from. If None, generates a timestamp-based directory name. Defaults to None.
-
-         Returns:
-             str: A URL to the GPU usage statistics dashboard.
-         """
-         if not self.api_key:
-             raise ValueError("Requires user authentication with an API Key")
-         formatted_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-         if not dir_name:
-             view_path = f"{formatted_datetime}.alphai"
-         else:
-             view_path = dir_name
-         self.client.post_contents(path="", ext=".alphai", type="directory")
-         self.client.patch_contents(path="Untitled Folder.alphai", new_path=view_path)
-         self.client.put_contents(
-             path=view_path,
-             file_path=f"{self.pt_profiler.profiler_path}/profiling.alphai",
-         )
-         return "Check out your GPU usage statistics at -> https://dashboard.amdatascience.com/agent-alph"
-
-     def get_pt_traces(self):
-         """
-         Retrieves a list of PyTorch trace directories sorted by date.
-
-         Returns:
-             List[str]: A list of directory names containing PyTorch traces.
-         """
-         # List all items in the directory
-         directory_path = self.output_path
-         if not os.path.isdir(directory_path):
-             return []
-         all_items = os.listdir(directory_path)
-
-         # Keep items that are directories and follow the naming pattern
-         date_directories = []
-         for item in all_items:
-             if os.path.isdir(os.path.join(directory_path, item)) and item.startswith(
-                 "pt_trace_"
-             ):
-                 # Extract the date and time part from the folder name
-                 datetime_part = item.split("pt_trace_")[1]
-
-                 # Parse the date and time into a datetime object
-                 try:
-                     folder_date = datetime.datetime.strptime(
-                         datetime_part, "%Y-%m-%d_%H-%M-%S"
-                     )
-                     date_directories.append((item, folder_date))
-                 except ValueError:
-                     # Handle cases where the date format is incorrect or different
-                     print(f"Skipping {item} due to unexpected date format.")
-
-         # Sort the directories by the parsed datetime
-         date_directories.sort(key=lambda x: x[1])
-
-         # Return only the directory names, in sorted order
-         return [name for name, date in date_directories]
-
-     def get_jax_traces(self):
-         """
-         Retrieves a list of JAX trace directories sorted by date.
-
-         Returns:
-             List[str]: A list of directory names containing JAX traces.
-         """
-         # List all items in the directory
-         directory_path = self.output_path
-         if not os.path.isdir(directory_path):
-             return []
-         all_items = os.listdir(directory_path)
-
-         # Keep items that are directories and follow the naming pattern
-         date_directories = []
-         for item in all_items:
-             if os.path.isdir(os.path.join(directory_path, item)) and item.startswith(
-                 "jax_trace_"
-             ):
-                 # Extract the date and time part from the folder name
-                 datetime_part = item.split("jax_trace_")[1]
-
-                 # Parse the date and time into a datetime object
-                 try:
-                     folder_date = datetime.datetime.strptime(
-                         datetime_part, "%Y-%m-%d_%H-%M-%S"
-                     )
-                     date_directories.append((item, folder_date))
-                 except ValueError:
-                     # Handle cases where the date format is incorrect or different
-                     print(f"Skipping {item} due to unexpected date format.")
-
-         # Sort the directories by the parsed datetime
-         date_directories.sort(key=lambda x: x[1])
-
-         # Return only the directory names, in sorted order
-         return [name for name, date in date_directories]
-
-     # Benchmarker
-     def start_timer(self):
-         """
-         Starts the benchmarking timer.
-         """
-         self.benchmarker.start()
-
-     def stop_timer(self, print_results: bool = True):
-         """
-         Stops the timer and optionally prints the results.
-
-         Args:
-             print_results (bool): Whether to print the results. Defaults to True.
-
-         Returns:
-             The results of the benchmark.
-         """
-         return self.benchmarker.stop()
-
-     def benchmark(
-         self,
-         function: Callable = None,
-         *args,
-         num_iter: int = 100,
-         print_results: bool = True,
-         **kwargs,
-     ):
-         """
-         Benchmarks a function by running it a specified number of times.
-
-         Args:
-             function (Callable): The function to benchmark.
-             *args: The arguments to pass to the function.
-             num_iter (int): The number of times to run the function. Defaults to 100.
-             print_results (bool): Whether to print the results. Defaults to True.
-             **kwargs: The keyword arguments to pass to the function.
-
-         Returns:
-             The results of the benchmark.
-         """
-         return self.benchmarker.benchmark(
-             function, *args, num_iter=num_iter, print_results=print_results, **kwargs
-         )
-
-     # Hugging Face utility
-
-     def estimate_memory_requirement(
-         self,
-         model_name: str = "stabilityai/stablelm-zephyr-3b",
-     ):
-         """
-         Estimates the memory requirement for a given model.
-
-         Args:
-             model_name (str): The name of the model. Defaults to "stabilityai/stablelm-zephyr-3b".
-
-         Returns:
-             A dictionary with the model name and the estimated memory requirement in MB and GB.
-         """
-         try:
-             # extract_param_value parses the parameter count from the model name;
-             # the arithmetic below assumes it is in billions of parameters and
-             # that float16 uses 2 bytes per parameter.
-             param_value = extract_param_value(model_name)
-             megabyte_value = param_value * 2 * 1000
-             gigabyte_value = param_value * 2
-             print(
-                 f"Estimated memory requirement assuming float16 dtype for {model_name}: {megabyte_value:.2f} MB or {gigabyte_value:.2f} GB"
-             )
-             return {
-                 "model_name_or_path": model_name,
-                 "estimate_memory_requirement_mb_float16": f"{megabyte_value:.2f} MB",
-                 "estimate_memory_requirement_gb_float16": f"{gigabyte_value:.2f} GB",
-             }
-         except Exception:
-             warnings.warn(
-                 "Error parsing the model name or path; can't estimate the memory requirement."
-             )
-             return
-
-     def memory_requirement(
-         self,
-         model_name_or_path: str = "stabilityai/stablelm-zephyr-3b",
-         device: str = "cuda",
-         trust_remote_code=True,
-         torch_dtype="auto",
-     ):
-         """
-         Estimates and prints the memory requirement for a specified model.
-
-         Args:
-             model_name_or_path (str): The name or path of the model to be loaded. Defaults to 'stabilityai/stablelm-zephyr-3b'.
-             device (str): The device to load the model on ('cuda' or 'cpu'). Defaults to 'cuda'.
-             trust_remote_code (bool): Whether to trust remote code when loading the model. Defaults to True.
-             torch_dtype (str): The data type for the model parameters. Defaults to 'auto'.
-
-         Returns:
-             dict: A dictionary containing the memory requirement in MB and GB.
-         """
-         if not is_package_installed("torch"):
-             warnings.warn("You need to install 'torch' to run memory_requirement")
-             return
-         if not self.model_name_or_path or self.model_name_or_path != model_name_or_path:
-             self.model_name_or_path = model_name_or_path
-             try:
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name_or_path,
-                     trust_remote_code=trust_remote_code,
-                     torch_dtype=torch_dtype,
-                 ).to(device)
-             except Exception:
-                 warnings.warn(
-                     "Loading model to CPU instead of GPU since GPU is not available."
-                 )
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name_or_path,
-                     trust_remote_code=trust_remote_code,
-                     torch_dtype=torch_dtype,
-                 ).to("cpu")
-         try:
-             param_value = self.model.num_parameters()
-         except Exception:
-             param_value = sum(p.numel() for p in self.model.parameters())
-
-         megabyte_value = param_value * 2 / 1000000
-         gigabyte_value = param_value * 2 / 1000000000
-         print(
-             f"Memory requirement assuming float16 dtype for {model_name_or_path}: {megabyte_value:.2f} MB or {gigabyte_value:.2f} GB"
-         )
-         return {
-             "model_name_or_path": model_name_or_path,
-             "memory_requirement_mb_float16": f"{megabyte_value:.2f} MB",
-             "memory_requirement_gb_float16": f"{gigabyte_value:.2f} GB",
-         }
-
-     def generate(
-         self,
-         text: str = "",
-         prompt: List[dict] = None,
-         model_name_or_path: str = "stabilityai/stablelm-zephyr-3b",
-         trust_remote_code=True,
-         torch_dtype="auto",
-         stream: bool = True,
-         device: str = "cuda",
-         **kwargs,
-     ):
-         """
-         Generates text using the specified model based on the given prompt or text.
-
-         Args:
-             text (str): The text to be used as a prompt. Defaults to an empty string.
-             prompt (List[dict]): A list of dictionaries defining the prompt structure. Defaults to None.
-             model_name_or_path (str): The name or path of the model to be used. Defaults to 'stabilityai/stablelm-zephyr-3b'.
-             trust_remote_code (bool): Whether to trust remote code when loading the model. Defaults to True.
-             torch_dtype (str): The data type for the model parameters. Defaults to 'auto'.
-             stream (bool): Whether to use streaming for text generation. Defaults to True.
-             device (str): The device to run the model on. Defaults to 'cuda'.
-
-         Returns:
-             str: The generated text.
-         """
-         if not is_package_installed("torch"):
-             warnings.warn("You need to install 'torch' to run generate")
-             return
-         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-         streamer = TextStreamer(tokenizer) if stream else None
-         if not self.model_name_or_path or self.model_name_or_path != model_name_or_path:
-             self.model_name_or_path = model_name_or_path
-             try:
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name_or_path,
-                     trust_remote_code=trust_remote_code,
-                     torch_dtype=torch_dtype,
-                 ).to(device)
-             except Exception:
-                 warnings.warn(
-                     "Loading model to CPU instead of GPU since GPU is not available."
-                 )
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name_or_path,
-                     trust_remote_code=trust_remote_code,
-                     torch_dtype=torch_dtype,
-                 ).to("cpu")
-
-         if not prompt:
-             prompt = [{"role": "user", "content": text}]
-         inputs = tokenizer.apply_chat_template(
-             prompt, add_generation_prompt=True, return_tensors="pt"
-         )
-
-         tokens = self.model.generate(
-             inputs.to(self.model.device),
-             max_new_tokens=1024,
-             temperature=0.8,
-             do_sample=True,
-             streamer=streamer,
-             **kwargs,
-         )
-
-         return tokenizer.decode(tokens[0])
-
-     def clear_cuda_memory(self):
-         """
-         Clears the CUDA memory cache to free up GPU memory.
-
-         Raises:
-             Warning: If PyTorch is not installed.
-         """
-         if not is_package_installed("torch"):
-             warnings.warn("You need to install 'torch' to run clear_cuda_memory")
-             return
-         gc.collect()
-         torch.cuda.empty_cache()