centml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- centml/__init__.py +1 -0
- centml/cli/__init__.py +1 -0
- centml/cli/cluster.py +192 -0
- centml/cli/login.py +30 -0
- centml/cli/main.py +35 -0
- centml/compiler/__init__.py +3 -0
- centml/compiler/backend.py +194 -0
- centml/compiler/config.py +45 -0
- centml/compiler/main.py +57 -0
- centml/compiler/prediction/__init__.py +0 -0
- centml/compiler/prediction/backend.py +28 -0
- centml/compiler/prediction/kdtree.py +69 -0
- centml/compiler/prediction/metric.py +30 -0
- centml/compiler/prediction/profiler.py +153 -0
- centml/compiler/server.py +118 -0
- centml/compiler/server_compilation.py +51 -0
- centml/compiler/utils.py +28 -0
- centml/sdk/__init__.py +2 -0
- centml/sdk/api.py +74 -0
- centml/sdk/auth.py +64 -0
- centml/sdk/config.py +16 -0
- centml/sdk/utils/__init__.py +0 -0
- centml/sdk/utils/client_certs.py +93 -0
- centml-0.1.0.dist-info/LICENSE +201 -0
- centml-0.1.0.dist-info/METADATA +106 -0
- centml-0.1.0.dist-info/RECORD +34 -0
- centml-0.1.0.dist-info/WHEEL +5 -0
- centml-0.1.0.dist-info/entry_points.txt +3 -0
- centml-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/conftest.py +16 -0
- tests/test_backend.py +260 -0
- tests/test_helpers.py +25 -0
- tests/test_server.py +168 -0
centml/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from centml.compiler.main import compile
|
centml/cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from centml.cli.main import cli, ccluster
|
centml/cli/cluster.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from typing import Dict
|
|
4
|
+
import click
|
|
5
|
+
from tabulate import tabulate
|
|
6
|
+
from centml.sdk import DeploymentType, DeploymentStatus, HealthStatus, ApiException, HardwareInstanceResponse
|
|
7
|
+
from centml.sdk.api import get_centml_client
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
depl_name_to_type_map = {
|
|
11
|
+
"inference": DeploymentType.INFERENCE_V2,
|
|
12
|
+
"compute": DeploymentType.COMPUTE_V2,
|
|
13
|
+
"cserve": DeploymentType.CSERVE,
|
|
14
|
+
}
|
|
15
|
+
depl_type_to_name_map = {v: k for k, v in depl_name_to_type_map.items()}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def handle_exception(func):
|
|
19
|
+
@wraps(func)
|
|
20
|
+
def wrapper(*args, **kwargs):
|
|
21
|
+
try:
|
|
22
|
+
return func(*args, **kwargs)
|
|
23
|
+
except ApiException as e:
|
|
24
|
+
click.echo(f"Error: {e.reason}")
|
|
25
|
+
return None
|
|
26
|
+
|
|
27
|
+
return wrapper
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_hw_to_id_map(cclient, cluster_id):
|
|
31
|
+
response = cclient.get_hardware_instances(cluster_id)
|
|
32
|
+
|
|
33
|
+
# Initialize hashmap for hardware to id or vice versa mapping
|
|
34
|
+
hw_to_id_map: Dict[str, int] = {}
|
|
35
|
+
id_to_hw_map: Dict[int, HardwareInstanceResponse] = {}
|
|
36
|
+
|
|
37
|
+
for hw in response:
|
|
38
|
+
hw_to_id_map[hw.name] = hw.id
|
|
39
|
+
id_to_hw_map[hw.id] = hw
|
|
40
|
+
return hw_to_id_map, id_to_hw_map
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _format_ssh_key(ssh_key):
|
|
44
|
+
if not ssh_key:
|
|
45
|
+
return "No SSH Key Found"
|
|
46
|
+
return ssh_key[:10] + '...'
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_ready_status(cclient, deployment):
|
|
50
|
+
api_status = deployment.status
|
|
51
|
+
service_status = (
|
|
52
|
+
cclient.get_status(deployment.id).service_status if deployment.status == DeploymentStatus.ACTIVE else None
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
status_styles = {
|
|
56
|
+
(DeploymentStatus.PAUSED, None): ("paused", "yellow", "black"),
|
|
57
|
+
(DeploymentStatus.DELETED, None): ("deleted", "white", "black"),
|
|
58
|
+
(DeploymentStatus.ACTIVE, HealthStatus.HEALTHY): ("ready", "green", "black"),
|
|
59
|
+
(DeploymentStatus.ACTIVE, HealthStatus.PROGRESSING): ("starting", "black", "white"),
|
|
60
|
+
(DeploymentStatus.ACTIVE, HealthStatus.DEGRADED): ("starting", "black", "white"),
|
|
61
|
+
(DeploymentStatus.ACTIVE, HealthStatus.MISSING): ("not found", "cyan"),
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
style = status_styles.get((api_status, service_status), ("unknown", "black", "white"))
|
|
65
|
+
# Handle foreground and background colors
|
|
66
|
+
return click.style(style[0], fg=style[1], bg=style[2])
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@click.command(help="List all deployments")
|
|
70
|
+
@click.argument("type", type=click.Choice(list(depl_name_to_type_map.keys())), required=False, default=None)
|
|
71
|
+
def ls(type):
|
|
72
|
+
with get_centml_client() as cclient:
|
|
73
|
+
depl_type = depl_name_to_type_map[type] if type in depl_name_to_type_map else None
|
|
74
|
+
deployments = cclient.get(depl_type)
|
|
75
|
+
rows = [
|
|
76
|
+
[d.id, d.name, depl_type_to_name_map[d.type], d.status.value, d.created_at.strftime("%Y-%m-%d %H:%M:%S")]
|
|
77
|
+
for d in deployments
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
click.echo(
|
|
81
|
+
tabulate(
|
|
82
|
+
rows,
|
|
83
|
+
headers=["ID", "Name", "Type", "Status", "Created at"],
|
|
84
|
+
tablefmt="rounded_outline",
|
|
85
|
+
disable_numparse=True,
|
|
86
|
+
)
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@click.command(help="Get deployment details")
|
|
91
|
+
@click.argument("type", type=click.Choice(list(depl_name_to_type_map.keys())))
|
|
92
|
+
@click.argument("id", type=int)
|
|
93
|
+
@handle_exception
|
|
94
|
+
def get(type, id):
|
|
95
|
+
with get_centml_client() as cclient:
|
|
96
|
+
depl_type = depl_name_to_type_map[type]
|
|
97
|
+
|
|
98
|
+
if depl_type == DeploymentType.INFERENCE_V2:
|
|
99
|
+
deployment = cclient.get_inference(id)
|
|
100
|
+
elif depl_type == DeploymentType.COMPUTE_V2:
|
|
101
|
+
deployment = cclient.get_compute(id)
|
|
102
|
+
elif depl_type == DeploymentType.CSERVE:
|
|
103
|
+
deployment = cclient.get_cserve(id)
|
|
104
|
+
else:
|
|
105
|
+
sys.exit("Please enter correct deployment type")
|
|
106
|
+
|
|
107
|
+
ready_status = _get_ready_status(cclient, deployment)
|
|
108
|
+
_, id_to_hw_map = _get_hw_to_id_map(cclient, deployment.cluster_id)
|
|
109
|
+
hw = id_to_hw_map[deployment.hardware_instance_id]
|
|
110
|
+
|
|
111
|
+
click.echo(
|
|
112
|
+
tabulate(
|
|
113
|
+
[
|
|
114
|
+
("Name", deployment.name),
|
|
115
|
+
("Status", ready_status),
|
|
116
|
+
("Endpoint", deployment.endpoint_url),
|
|
117
|
+
("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
|
|
118
|
+
("Hardware", f"{hw.name} ({hw.num_gpu}x {hw.gpu_type})"),
|
|
119
|
+
("Cost", f"{hw.cost_per_hr/100} credits/hr"),
|
|
120
|
+
],
|
|
121
|
+
tablefmt="rounded_outline",
|
|
122
|
+
disable_numparse=True,
|
|
123
|
+
)
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
click.echo("Additional deployment configurations:")
|
|
127
|
+
if depl_type == DeploymentType.INFERENCE_V2:
|
|
128
|
+
click.echo(
|
|
129
|
+
tabulate(
|
|
130
|
+
[
|
|
131
|
+
("Image", deployment.image_url),
|
|
132
|
+
("Container port", deployment.container_port),
|
|
133
|
+
("Healthcheck", deployment.healthcheck or "/"),
|
|
134
|
+
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
|
|
135
|
+
("Environment variables", deployment.env_vars or "None"),
|
|
136
|
+
("Max concurrency", deployment.concurrency or "None"),
|
|
137
|
+
],
|
|
138
|
+
tablefmt="rounded_outline",
|
|
139
|
+
disable_numparse=True,
|
|
140
|
+
)
|
|
141
|
+
)
|
|
142
|
+
elif depl_type == DeploymentType.COMPUTE_V2:
|
|
143
|
+
click.echo(
|
|
144
|
+
tabulate(
|
|
145
|
+
[("Username", "centml"), ("SSH key", _format_ssh_key(deployment.ssh_public_key))],
|
|
146
|
+
tablefmt="rounded_outline",
|
|
147
|
+
disable_numparse=True,
|
|
148
|
+
)
|
|
149
|
+
)
|
|
150
|
+
elif depl_type == DeploymentType.CSERVE:
|
|
151
|
+
click.echo(
|
|
152
|
+
tabulate(
|
|
153
|
+
[
|
|
154
|
+
("Hugging face model", deployment.model),
|
|
155
|
+
(
|
|
156
|
+
"Parallelism",
|
|
157
|
+
{"tensor": deployment.tensor_parallel_size, "pipeline": deployment.pipeline_parallel_size},
|
|
158
|
+
),
|
|
159
|
+
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
|
|
160
|
+
("Max concurrency", deployment.concurrency or "None"),
|
|
161
|
+
],
|
|
162
|
+
tablefmt="rounded_outline",
|
|
163
|
+
disable_numparse=True,
|
|
164
|
+
)
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@click.command(help="Delete a deployment")
|
|
169
|
+
@click.argument("id", type=int)
|
|
170
|
+
@handle_exception
|
|
171
|
+
def delete(id):
|
|
172
|
+
with get_centml_client() as cclient:
|
|
173
|
+
cclient.delete(id)
|
|
174
|
+
click.echo("Deployment has been deleted")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@click.command(help="Pause a deployment")
|
|
178
|
+
@click.argument("id", type=int)
|
|
179
|
+
@handle_exception
|
|
180
|
+
def pause(id):
|
|
181
|
+
with get_centml_client() as cclient:
|
|
182
|
+
cclient.pause(id)
|
|
183
|
+
click.echo("Deployment has been paused")
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@click.command(help="Resume a deployment")
|
|
187
|
+
@click.argument("id", type=int)
|
|
188
|
+
@handle_exception
|
|
189
|
+
def resume(id):
|
|
190
|
+
with get_centml_client() as cclient:
|
|
191
|
+
cclient.resume(id)
|
|
192
|
+
click.echo("Deployment has been resumed")
|
centml/cli/login.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from centml.sdk import auth
|
|
4
|
+
from centml.sdk.config import settings
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@click.command(help="Login to CentML")
|
|
8
|
+
@click.argument("token_file", required=False)
|
|
9
|
+
def login(token_file):
|
|
10
|
+
if token_file:
|
|
11
|
+
auth.store_centml_cred(token_file)
|
|
12
|
+
|
|
13
|
+
if auth.load_centml_cred():
|
|
14
|
+
click.echo(f"Authenticating with credentials from {settings.CENTML_CRED_FILE_PATH}\n")
|
|
15
|
+
click.echo("Login successful")
|
|
16
|
+
else:
|
|
17
|
+
click.echo("Login with CentML authentication token")
|
|
18
|
+
click.echo("Usage: centml login TOKEN_FILE\n")
|
|
19
|
+
choice = click.confirm("Do you want to download the token?")
|
|
20
|
+
|
|
21
|
+
if choice:
|
|
22
|
+
click.launch(f"{settings.CENTML_WEB_URL}?isCliAuthenticated=true")
|
|
23
|
+
else:
|
|
24
|
+
click.echo("Login unsuccessful")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@click.command(help="Logout from CentML")
|
|
28
|
+
def logout():
|
|
29
|
+
auth.remove_centml_cred()
|
|
30
|
+
click.echo("Logout successful")
|
centml/cli/main.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from centml.cli.login import login, logout
|
|
4
|
+
from centml.cli.cluster import ls, get, delete, pause, resume
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@click.group()
|
|
8
|
+
def cli():
|
|
9
|
+
pass
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
cli.add_command(login)
|
|
13
|
+
cli.add_command(logout)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@cli.command(help="Start remote compilation server")
|
|
17
|
+
def server():
|
|
18
|
+
from centml.compiler.server import run
|
|
19
|
+
|
|
20
|
+
run()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@click.group(help="CentML cluster CLI tool")
|
|
24
|
+
def ccluster():
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
ccluster.add_command(ls)
|
|
29
|
+
ccluster.add_command(get)
|
|
30
|
+
ccluster.add_command(delete)
|
|
31
|
+
ccluster.add_command(pause)
|
|
32
|
+
ccluster.add_command(resume)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
cli.add_command(ccluster, name="cluster")
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import gc
|
|
3
|
+
import time
|
|
4
|
+
import hashlib
|
|
5
|
+
import logging
|
|
6
|
+
import threading as th
|
|
7
|
+
from http import HTTPStatus
|
|
8
|
+
from weakref import ReferenceType, ref
|
|
9
|
+
from tempfile import TemporaryDirectory
|
|
10
|
+
from typing import List, Callable, Optional
|
|
11
|
+
import requests
|
|
12
|
+
import torch
|
|
13
|
+
from torch.fx import GraphModule
|
|
14
|
+
from centml.compiler.config import settings, CompilationStatus
|
|
15
|
+
from centml.compiler.utils import get_backend_compiled_forward_path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Runner:
|
|
19
|
+
def __init__(self, module: GraphModule, inputs: List[torch.Tensor]):
|
|
20
|
+
if not module:
|
|
21
|
+
raise Exception("No module provided")
|
|
22
|
+
|
|
23
|
+
self._module: ReferenceType[GraphModule] = ref(module)
|
|
24
|
+
self._inputs: List[torch.Tensor] = inputs
|
|
25
|
+
self.compiled_forward_function: Optional[Callable[[torch.Tensor], tuple]] = None
|
|
26
|
+
self.lock = th.Lock()
|
|
27
|
+
self.child_thread = th.Thread(target=self.remote_compilation_starter)
|
|
28
|
+
|
|
29
|
+
self.serialized_model_dir: Optional[TemporaryDirectory] = None
|
|
30
|
+
self.serialized_model_path: Optional[str] = None
|
|
31
|
+
self.serialized_input_path: Optional[str] = None
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
self.child_thread.start()
|
|
35
|
+
except Exception as e:
|
|
36
|
+
logging.getLogger(__name__).exception(f"Failed to start compilation thread\n{e}")
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def module(self) -> Optional[GraphModule]:
|
|
40
|
+
return self._module()
|
|
41
|
+
|
|
42
|
+
@module.deleter
|
|
43
|
+
def module(self):
|
|
44
|
+
self._module().graph.owning_module = None
|
|
45
|
+
self._module = None
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def inputs(self) -> List[torch.Tensor]:
|
|
49
|
+
return self._inputs
|
|
50
|
+
|
|
51
|
+
@inputs.deleter
|
|
52
|
+
def inputs(self):
|
|
53
|
+
self._inputs = None
|
|
54
|
+
|
|
55
|
+
def _serialize_model_and_inputs(self):
|
|
56
|
+
self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with
|
|
57
|
+
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_MODEL_FILE)
|
|
58
|
+
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_INPUT_FILE)
|
|
59
|
+
|
|
60
|
+
# torch.save saves a zip file full of pickled files with the model's states.
|
|
61
|
+
try:
|
|
62
|
+
torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
|
|
63
|
+
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
|
|
64
|
+
except Exception as e:
|
|
65
|
+
raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e
|
|
66
|
+
|
|
67
|
+
def _get_model_id(self) -> str:
|
|
68
|
+
if not self.serialized_model_path or not os.path.isfile(self.serialized_model_path):
|
|
69
|
+
raise Exception(f"Model not saved at path {self.serialized_model_path}")
|
|
70
|
+
|
|
71
|
+
sha_hash = hashlib.sha256()
|
|
72
|
+
with open(self.serialized_model_path, "rb") as serialized_model_file:
|
|
73
|
+
# Read in chunks to not load too much into memory
|
|
74
|
+
for block in iter(lambda: serialized_model_file.read(settings.CENTML_HASH_CHUNK_SIZE), b""):
|
|
75
|
+
sha_hash.update(block)
|
|
76
|
+
|
|
77
|
+
model_id = sha_hash.hexdigest()
|
|
78
|
+
logging.info(f"Model has id {model_id}")
|
|
79
|
+
return model_id
|
|
80
|
+
|
|
81
|
+
def _download_model(self, model_id: str):
|
|
82
|
+
download_response = requests.get(
|
|
83
|
+
url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
|
|
84
|
+
)
|
|
85
|
+
if download_response.status_code != HTTPStatus.OK:
|
|
86
|
+
raise Exception(
|
|
87
|
+
f"Download: request failed, exception from server:\n{download_response.json().get('detail')}"
|
|
88
|
+
)
|
|
89
|
+
if download_response.content == b"":
|
|
90
|
+
raise Exception("Download: empty response from server")
|
|
91
|
+
download_path = get_backend_compiled_forward_path(model_id)
|
|
92
|
+
with open(download_path, "wb") as f:
|
|
93
|
+
f.write(download_response.content)
|
|
94
|
+
return torch.load(download_path)
|
|
95
|
+
|
|
96
|
+
def _compile_model(self, model_id: str):
|
|
97
|
+
# The model should have been saved using torch.save when we found the model_id
|
|
98
|
+
if not self.serialized_model_path or not self.serialized_input_path:
|
|
99
|
+
raise Exception("Model or inputs not serialized")
|
|
100
|
+
if not os.path.isfile(self.serialized_model_path):
|
|
101
|
+
raise Exception(f"Model not saved at path {self.serialized_model_path}")
|
|
102
|
+
if not os.path.isfile(self.serialized_input_path):
|
|
103
|
+
raise Exception(f"Inputs not saved at path {self.serialized_input_path}")
|
|
104
|
+
|
|
105
|
+
with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file:
|
|
106
|
+
compile_response = requests.post(
|
|
107
|
+
url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}",
|
|
108
|
+
files={"model": model_file, "inputs": input_file},
|
|
109
|
+
timeout=settings.CENTML_COMPILER_TIMEOUT,
|
|
110
|
+
)
|
|
111
|
+
if compile_response.status_code != HTTPStatus.OK:
|
|
112
|
+
raise Exception(
|
|
113
|
+
f"Compile model: request failed, exception from server:\n{compile_response.json().get('detail')}\n"
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def _wait_for_status(self, model_id: str) -> bool:
|
|
117
|
+
tries = 0
|
|
118
|
+
while True:
|
|
119
|
+
# get server compilation status
|
|
120
|
+
status = None
|
|
121
|
+
try:
|
|
122
|
+
status_response = requests.get(
|
|
123
|
+
f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
|
|
124
|
+
)
|
|
125
|
+
if status_response.status_code != HTTPStatus.OK:
|
|
126
|
+
raise Exception(
|
|
127
|
+
f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"
|
|
128
|
+
)
|
|
129
|
+
status = status_response.json().get("status")
|
|
130
|
+
except Exception as e:
|
|
131
|
+
logging.getLogger(__name__).exception(f"Status check failed:\n{e}")
|
|
132
|
+
|
|
133
|
+
if status == CompilationStatus.DONE.value:
|
|
134
|
+
return True
|
|
135
|
+
elif status == CompilationStatus.COMPILING.value:
|
|
136
|
+
pass
|
|
137
|
+
elif status == CompilationStatus.NOT_FOUND.value:
|
|
138
|
+
logging.info("Submitting model to server for compilation.")
|
|
139
|
+
try:
|
|
140
|
+
self._compile_model(model_id)
|
|
141
|
+
except Exception as e:
|
|
142
|
+
logging.getLogger(__name__).exception(f"Submitting compilation failed:\n{e}")
|
|
143
|
+
tries += 1
|
|
144
|
+
else:
|
|
145
|
+
tries += 1
|
|
146
|
+
|
|
147
|
+
if tries > settings.CENTML_COMPILER_MAX_RETRIES:
|
|
148
|
+
raise Exception("Waiting for status: compilation failed too many times.\n")
|
|
149
|
+
|
|
150
|
+
time.sleep(settings.CENTML_COMPILER_SLEEP_TIME)
|
|
151
|
+
|
|
152
|
+
def remote_compilation_starter(self):
|
|
153
|
+
try:
|
|
154
|
+
self.remote_compilation()
|
|
155
|
+
except Exception as e:
|
|
156
|
+
logging.getLogger(__name__).exception(f"Compilation thread failed:\n{e}")
|
|
157
|
+
|
|
158
|
+
def remote_compilation(self):
|
|
159
|
+
self._serialize_model_and_inputs()
|
|
160
|
+
|
|
161
|
+
model_id = self._get_model_id()
|
|
162
|
+
|
|
163
|
+
# check if compiled forward is saved locally
|
|
164
|
+
compiled_forward_path = get_backend_compiled_forward_path(model_id)
|
|
165
|
+
if os.path.isfile(compiled_forward_path):
|
|
166
|
+
logging.info("Compiled model found in local cache. Not submitting to server.")
|
|
167
|
+
compiled_forward = torch.load(compiled_forward_path)
|
|
168
|
+
else:
|
|
169
|
+
self._wait_for_status(model_id)
|
|
170
|
+
compiled_forward = self._download_model(model_id)
|
|
171
|
+
|
|
172
|
+
self.compiled_forward_function = compiled_forward
|
|
173
|
+
|
|
174
|
+
logging.info("Compilation successful.")
|
|
175
|
+
|
|
176
|
+
# Let garbage collector free the memory used by the uncompiled model
|
|
177
|
+
with self.lock:
|
|
178
|
+
del self.inputs
|
|
179
|
+
if self.module:
|
|
180
|
+
del self.module
|
|
181
|
+
gc.collect()
|
|
182
|
+
torch.cuda.empty_cache()
|
|
183
|
+
|
|
184
|
+
def __call__(self, *args, **kwargs):
|
|
185
|
+
# If model is currently compiling, return the uncompiled forward function
|
|
186
|
+
with self.lock:
|
|
187
|
+
if not self.compiled_forward_function:
|
|
188
|
+
return self.module.forward(*args, **kwargs)
|
|
189
|
+
|
|
190
|
+
return self.compiled_forward_function(*args)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def centml_dynamo_backend(gm: GraphModule, example_inputs: List[torch.Tensor]):
|
|
194
|
+
return Runner(gm, example_inputs)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from pydantic_settings import BaseSettings
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CompilationStatus(Enum):
|
|
7
|
+
NOT_FOUND = "NOT_FOUND"
|
|
8
|
+
COMPILING = "COMPILING"
|
|
9
|
+
DONE = "DONE"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OperationMode(Enum):
|
|
13
|
+
PREDICTION = "PREDICTION"
|
|
14
|
+
REMOTE_COMPILATION = "REMOTE_COMPILATION"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Config(BaseSettings):
|
|
18
|
+
CENTML_COMPILER_TIMEOUT: int = 10
|
|
19
|
+
CENTML_COMPILER_MAX_RETRIES: int = 3
|
|
20
|
+
CENTML_COMPILER_SLEEP_TIME: int = 15
|
|
21
|
+
|
|
22
|
+
CENTML_BASE_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
|
|
23
|
+
CENTML_BACKEND_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "backend")
|
|
24
|
+
CENTML_SERVER_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "server")
|
|
25
|
+
|
|
26
|
+
CENTML_SERVER_URL: str = "http://0.0.0.0:8090"
|
|
27
|
+
|
|
28
|
+
# Use a constant path since torch.save uses the given file name in it's zipfile.
|
|
29
|
+
# Using a different filename would result in a different hash.
|
|
30
|
+
CENTML_SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
|
|
31
|
+
CENTML_SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
|
|
32
|
+
CENTML_PICKLE_PROTOCOL: int = 4
|
|
33
|
+
|
|
34
|
+
CENTML_HASH_CHUNK_SIZE: int = 4096
|
|
35
|
+
|
|
36
|
+
# If the server response is smaller than this, don't gzip it
|
|
37
|
+
CENTML_MINIMUM_GZIP_SIZE: int = 1000
|
|
38
|
+
|
|
39
|
+
CENTML_MODE: OperationMode = OperationMode.REMOTE_COMPILATION
|
|
40
|
+
CENTML_PREDICTION_DATA_FILE: str = 'tests/sample_data.csv'
|
|
41
|
+
CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB"
|
|
42
|
+
CENTML_PROMETHEUS_PORT: int = 8000
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
settings = Config()
|
centml/compiler/main.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
from typing import Callable, Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from centml.compiler.config import OperationMode, settings
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compile(
|
|
8
|
+
model: Optional[Callable] = None,
|
|
9
|
+
*,
|
|
10
|
+
fullgraph: builtins.bool = False,
|
|
11
|
+
dynamic: Optional[builtins.bool] = None,
|
|
12
|
+
mode: Union[str, None] = None,
|
|
13
|
+
options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
|
|
14
|
+
disable: builtins.bool = False,
|
|
15
|
+
) -> Callable:
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
if settings.CENTML_MODE == OperationMode.REMOTE_COMPILATION:
|
|
19
|
+
from centml.compiler.backend import centml_dynamo_backend
|
|
20
|
+
|
|
21
|
+
# Return the remote-compiled model
|
|
22
|
+
compiled_model = torch.compile(
|
|
23
|
+
model,
|
|
24
|
+
backend=centml_dynamo_backend, # Compilation backend
|
|
25
|
+
fullgraph=fullgraph,
|
|
26
|
+
dynamic=dynamic,
|
|
27
|
+
mode=mode,
|
|
28
|
+
options=options,
|
|
29
|
+
disable=disable,
|
|
30
|
+
)
|
|
31
|
+
return compiled_model
|
|
32
|
+
elif settings.CENTML_MODE == OperationMode.PREDICTION:
|
|
33
|
+
from centml.compiler.prediction.backend import centml_prediction_backend, get_gauge
|
|
34
|
+
|
|
35
|
+
# Proceed with prediction workflow
|
|
36
|
+
compiled_model = torch.compile(
|
|
37
|
+
model,
|
|
38
|
+
backend=centml_prediction_backend, # Prediction backend
|
|
39
|
+
fullgraph=fullgraph,
|
|
40
|
+
dynamic=dynamic,
|
|
41
|
+
mode=mode,
|
|
42
|
+
options=options,
|
|
43
|
+
disable=disable,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def centml_wrapper(*args, **kwargs):
|
|
47
|
+
out = compiled_model(*args, **kwargs)
|
|
48
|
+
# Update the prometheus metrics with final values
|
|
49
|
+
gauge = get_gauge()
|
|
50
|
+
for gpu in settings.CENTML_PREDICTION_GPUS.split(','):
|
|
51
|
+
gauge.set_metric_value(gpu)
|
|
52
|
+
|
|
53
|
+
return out
|
|
54
|
+
|
|
55
|
+
return centml_wrapper
|
|
56
|
+
else:
|
|
57
|
+
raise Exception("Invalid operation mode")
|
|
File without changes
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
5
|
+
|
|
6
|
+
from centml.compiler.config import settings
|
|
7
|
+
from centml.compiler.prediction.kdtree import get_tree_db
|
|
8
|
+
from centml.compiler.prediction.metric import get_gauge
|
|
9
|
+
from centml.compiler.prediction.profiler import Profiler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def centml_prediction_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
|
13
|
+
profilers = []
|
|
14
|
+
tree_db = get_tree_db()
|
|
15
|
+
for gpu in settings.CENTML_PREDICTION_GPUS.split(','):
|
|
16
|
+
profilers.append(Profiler(gm, gpu, tree_db))
|
|
17
|
+
|
|
18
|
+
def forward(*args):
|
|
19
|
+
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
|
|
20
|
+
fake_args = [fake_mode.from_tensor(arg) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
|
21
|
+
with fake_mode:
|
|
22
|
+
for prof in profilers:
|
|
23
|
+
out, t = prof.propagate(*fake_args)
|
|
24
|
+
gauge = get_gauge()
|
|
25
|
+
gauge.increment(prof.gpu, t)
|
|
26
|
+
return out
|
|
27
|
+
|
|
28
|
+
return forward
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import csv
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from sklearn.neighbors import KDTree # type: ignore
|
|
6
|
+
|
|
7
|
+
from centml.compiler.config import settings
|
|
8
|
+
|
|
9
|
+
_tree_db = None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class KDTreeWithValues:
|
|
13
|
+
def __init__(self, points=None, values=None):
|
|
14
|
+
self.points = points if points else []
|
|
15
|
+
self.values = values if values else []
|
|
16
|
+
if self.points:
|
|
17
|
+
self.tree = KDTree(self.points)
|
|
18
|
+
else:
|
|
19
|
+
self.tree = None
|
|
20
|
+
|
|
21
|
+
def add(self, point, value):
|
|
22
|
+
self.points.append(point)
|
|
23
|
+
self.values.append(value)
|
|
24
|
+
self.tree = KDTree(self.points)
|
|
25
|
+
|
|
26
|
+
def query(self, point):
|
|
27
|
+
if self.tree is None:
|
|
28
|
+
return None, None
|
|
29
|
+
|
|
30
|
+
dist, idx = self.tree.query([point], k=1)
|
|
31
|
+
return dist[0][0], self.values[idx[0][0]]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TreeDB:
|
|
35
|
+
def __init__(self, data_csv):
|
|
36
|
+
self.db = {}
|
|
37
|
+
self._populate_db(data_csv)
|
|
38
|
+
|
|
39
|
+
def get(self, key, inp):
|
|
40
|
+
if key not in self.db:
|
|
41
|
+
logging.getLogger(__name__).warning(f"Key {key} not found in database")
|
|
42
|
+
return float('-inf')
|
|
43
|
+
# TODO: Handle the case of unfound keys better. For now, return -inf to indicate something went wrong.
|
|
44
|
+
# Ideally, we shouldn't throw away a whole prediction because of one possibly insignificant node.
|
|
45
|
+
|
|
46
|
+
_, val = self.db[key].query(inp)
|
|
47
|
+
return val
|
|
48
|
+
|
|
49
|
+
def _add_from_db(self, key, points, values):
|
|
50
|
+
self.db[key] = KDTreeWithValues(points, values)
|
|
51
|
+
|
|
52
|
+
def _populate_db(self, data_csv):
|
|
53
|
+
with open(data_csv, newline='') as f:
|
|
54
|
+
reader = csv.DictReader(f)
|
|
55
|
+
for row in reader:
|
|
56
|
+
try:
|
|
57
|
+
key = (row['op'], int(row['dim']), row['inp_dtypes'], row['out_dtypes'], row['gpu'])
|
|
58
|
+
points = ast.literal_eval(row['points'])
|
|
59
|
+
values = ast.literal_eval(row['values'])
|
|
60
|
+
self._add_from_db(key, points, values)
|
|
61
|
+
except ValueError as e:
|
|
62
|
+
logging.getLogger(__name__).exception(f"Error parsing row: {row}\n{e}")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_tree_db():
|
|
66
|
+
global _tree_db
|
|
67
|
+
if _tree_db is None:
|
|
68
|
+
_tree_db = TreeDB(settings.CENTML_PREDICTION_DATA_FILE)
|
|
69
|
+
return _tree_db
|