centml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
centml/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from centml.compiler.main import compile
centml/cli/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from centml.cli.main import cli, ccluster
centml/cli/cluster.py ADDED
@@ -0,0 +1,192 @@
1
+ import sys
2
+ from functools import wraps
3
+ from typing import Dict
4
+ import click
5
+ from tabulate import tabulate
6
+ from centml.sdk import DeploymentType, DeploymentStatus, HealthStatus, ApiException, HardwareInstanceResponse
7
+ from centml.sdk.api import get_centml_client
8
+
9
+
10
# CLI deployment-type names mapped to the SDK ``DeploymentType`` they select.
depl_name_to_type_map = dict(
    inference=DeploymentType.INFERENCE_V2,
    compute=DeploymentType.COMPUTE_V2,
    cserve=DeploymentType.CSERVE,
)
# Reverse lookup: SDK ``DeploymentType`` -> CLI name.
depl_type_to_name_map = {depl_type: name for name, depl_type in depl_name_to_type_map.items()}
16
+
17
+
18
def handle_exception(func):
    """Wrap ``func`` so that an ``ApiException`` is echoed instead of propagating.

    The wrapper forwards all arguments to ``func`` and returns its result.
    If ``func`` raises ``ApiException``, ``"Error: <reason>"`` is printed with
    ``click.echo`` and ``None`` is returned.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except ApiException as api_error:
            click.echo(f"Error: {api_error.reason}")
            outcome = None
        return outcome

    return wrapper
28
+
29
+
30
+ def _get_hw_to_id_map(cclient, cluster_id):
31
+ response = cclient.get_hardware_instances(cluster_id)
32
+
33
+ # Initialize hashmap for hardware to id or vice versa mapping
34
+ hw_to_id_map: Dict[str, int] = {}
35
+ id_to_hw_map: Dict[int, HardwareInstanceResponse] = {}
36
+
37
+ for hw in response:
38
+ hw_to_id_map[hw.name] = hw.id
39
+ id_to_hw_map[hw.id] = hw
40
+ return hw_to_id_map, id_to_hw_map
41
+
42
+
43
+ def _format_ssh_key(ssh_key):
44
+ if not ssh_key:
45
+ return "No SSH Key Found"
46
+ return ssh_key[:10] + '...'
47
+
48
+
49
def _get_ready_status(cclient, deployment):
    """Return a colour-styled readiness label for ``deployment``.

    The label is chosen from the deployment's API status and, for ACTIVE
    deployments only, the service health reported by
    ``cclient.get_status(deployment.id)``. Combinations that are not listed
    fall back to ``"unknown"``.

    Returns:
        The label wrapped by ``click.style`` with its foreground/background colours.
    """
    api_status = deployment.status
    # Health is only queried for active deployments; otherwise it stays None.
    service_status = None
    if api_status == DeploymentStatus.ACTIVE:
        service_status = cclient.get_status(deployment.id).service_status

    # (api status, service health) -> (label, foreground, background).
    # Every entry must be a 3-tuple; a background of None means "no background".
    status_styles = {
        (DeploymentStatus.PAUSED, None): ("paused", "yellow", "black"),
        (DeploymentStatus.DELETED, None): ("deleted", "white", "black"),
        (DeploymentStatus.ACTIVE, HealthStatus.HEALTHY): ("ready", "green", "black"),
        (DeploymentStatus.ACTIVE, HealthStatus.PROGRESSING): ("starting", "black", "white"),
        (DeploymentStatus.ACTIVE, HealthStatus.DEGRADED): ("starting", "black", "white"),
        # Previously a 2-tuple, which made the background lookup raise IndexError.
        (DeploymentStatus.ACTIVE, HealthStatus.MISSING): ("not found", "cyan", None),
    }

    label, foreground, background = status_styles.get(
        (api_status, service_status), ("unknown", "black", "white")
    )
    return click.style(label, fg=foreground, bg=background)
67
+
68
+
69
@click.command(help="List all deployments")
@click.argument("type", type=click.Choice(list(depl_name_to_type_map.keys())), required=False, default=None)
@handle_exception
def ls(type):
    """Print a table of deployments, optionally filtered by deployment type.

    ``type`` is one of the CLI deployment names, or ``None`` to list every
    deployment. API errors are reported through ``handle_exception``, like the
    other cluster commands, instead of surfacing as a traceback.
    """
    with get_centml_client() as cclient:
        # A missing/None type maps to None, which is passed through as "no filter".
        depl_type = depl_name_to_type_map.get(type)
        deployments = cclient.get(depl_type)
        rows = [
            [
                d.id,
                d.name,
                # Fall back to the raw type so a deployment whose type has no
                # CLI name does not abort the whole listing with a KeyError.
                depl_type_to_name_map.get(d.type, str(d.type)),
                d.status.value,
                d.created_at.strftime("%Y-%m-%d %H:%M:%S"),
            ]
            for d in deployments
        ]

        click.echo(
            tabulate(
                rows,
                headers=["ID", "Name", "Type", "Status", "Created at"],
                tablefmt="rounded_outline",
                disable_numparse=True,
            )
        )
88
+
89
+
90
@click.command(help="Get deployment details")
@click.argument("type", type=click.Choice(list(depl_name_to_type_map.keys())))
@click.argument("id", type=int)
@handle_exception
def get(type, id):
    """Print a summary table and a type-specific configuration table for a deployment.

    ``type`` is one of the CLI deployment names and ``id`` is the deployment id.
    ``ApiException`` errors are reported by ``handle_exception``.
    """
    table_options = {"tablefmt": "rounded_outline", "disable_numparse": True}

    with get_centml_client() as cclient:
        depl_type = depl_name_to_type_map[type]

        # Each deployment type has its own fetch call on the client.
        if depl_type == DeploymentType.INFERENCE_V2:
            deployment = cclient.get_inference(id)
        elif depl_type == DeploymentType.COMPUTE_V2:
            deployment = cclient.get_compute(id)
        elif depl_type == DeploymentType.CSERVE:
            deployment = cclient.get_cserve(id)
        else:
            sys.exit("Please enter correct deployment type")

        ready_status = _get_ready_status(cclient, deployment)
        _, id_to_hw_map = _get_hw_to_id_map(cclient, deployment.cluster_id)
        hw = id_to_hw_map[deployment.hardware_instance_id]

        summary_rows = [
            ("Name", deployment.name),
            ("Status", ready_status),
            ("Endpoint", deployment.endpoint_url),
            ("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
            ("Hardware", f"{hw.name} ({hw.num_gpu}x {hw.gpu_type})"),
            ("Cost", f"{hw.cost_per_hr/100} credits/hr"),
        ]
        click.echo(tabulate(summary_rows, **table_options))

        click.echo("Additional deployment configurations:")
        # Rows for the second table depend on the deployment type.
        config_rows = None
        if depl_type == DeploymentType.INFERENCE_V2:
            config_rows = [
                ("Image", deployment.image_url),
                ("Container port", deployment.container_port),
                ("Healthcheck", deployment.healthcheck or "/"),
                ("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
                ("Environment variables", deployment.env_vars or "None"),
                ("Max concurrency", deployment.concurrency or "None"),
            ]
        elif depl_type == DeploymentType.COMPUTE_V2:
            config_rows = [("Username", "centml"), ("SSH key", _format_ssh_key(deployment.ssh_public_key))]
        elif depl_type == DeploymentType.CSERVE:
            config_rows = [
                ("Hugging face model", deployment.model),
                (
                    "Parallelism",
                    {"tensor": deployment.tensor_parallel_size, "pipeline": deployment.pipeline_parallel_size},
                ),
                ("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
                ("Max concurrency", deployment.concurrency or "None"),
            ]

        if config_rows is not None:
            click.echo(tabulate(config_rows, **table_options))
166
+
167
+
168
@click.command(help="Delete a deployment")
@click.argument("id", type=int)
@handle_exception
def delete(id):
    """Delete the deployment with the given ``id`` and print a confirmation."""
    with get_centml_client() as client:
        client.delete(id)
        click.echo("Deployment has been deleted")
175
+
176
+
177
@click.command(help="Pause a deployment")
@click.argument("id", type=int)
@handle_exception
def pause(id):
    """Pause the deployment with the given ``id`` and print a confirmation."""
    with get_centml_client() as client:
        client.pause(id)
        click.echo("Deployment has been paused")
184
+
185
+
186
@click.command(help="Resume a deployment")
@click.argument("id", type=int)
@handle_exception
def resume(id):
    """Resume the deployment with the given ``id`` and print a confirmation."""
    with get_centml_client() as client:
        client.resume(id)
        click.echo("Deployment has been resumed")
centml/cli/login.py ADDED
@@ -0,0 +1,30 @@
1
+ import click
2
+
3
+ from centml.sdk import auth
4
+ from centml.sdk.config import settings
5
+
6
+
7
@click.command(help="Login to CentML")
@click.argument("token_file", required=False)
def login(token_file):
    """Store credentials from ``token_file`` (if given) and report the login state.

    When no credentials can be loaded, usage help is printed and the user is
    asked whether to open the CentML web page to download a token.
    """
    if token_file:
        auth.store_centml_cred(token_file)

    # Guard clause: credentials are available, nothing more to do.
    if auth.load_centml_cred():
        click.echo(f"Authenticating with credentials from {settings.CENTML_CRED_FILE_PATH}\n")
        click.echo("Login successful")
        return

    click.echo("Login with CentML authentication token")
    click.echo("Usage: centml login TOKEN_FILE\n")

    wants_download = click.confirm("Do you want to download the token?")
    if not wants_download:
        click.echo("Login unsuccessful")
        return

    click.launch(f"{settings.CENTML_WEB_URL}?isCliAuthenticated=true")
25
+
26
+
27
@click.command(help="Logout from CentML")
def logout():
    """Remove the stored CentML credentials and print a confirmation."""
    auth.remove_centml_cred()
    click.echo("Logout successful")
centml/cli/main.py ADDED
@@ -0,0 +1,35 @@
1
+ import click
2
+
3
+ from centml.cli.login import login, logout
4
+ from centml.cli.cluster import ls, get, delete, pause, resume
5
+
6
+
7
# Root command group of the CLI. It has no options of its own; commands are
# attached to it below. (Kept without a docstring: click would use one as the
# group's help text.)
@click.group()
def cli():
    pass


# Register the authentication commands imported from centml.cli.login.
cli.add_command(login)
cli.add_command(logout)
14
+
15
+
16
@cli.command(help="Start remote compilation server")
def server():
    """Run the remote compilation server."""
    # Imported lazily so the server module is only loaded when this command runs.
    import centml.compiler.server as compilation_server

    compilation_server.run()
21
+
22
+
23
# Command group that holds the deployment-management subcommands.
@click.group(help="CentML cluster CLI tool")
def ccluster():
    pass


# Attach the deployment commands imported from centml.cli.cluster.
ccluster.add_command(ls)
ccluster.add_command(get)
ccluster.add_command(delete)
ccluster.add_command(pause)
ccluster.add_command(resume)


# Expose the group on the root CLI under the name "cluster".
cli.add_command(ccluster, name="cluster")
@@ -0,0 +1,3 @@
1
+ from centml.compiler.main import compile
2
+
3
# Public API of this package. The name must be ``__all__`` for
# ``from <package> import *`` to honour it; a plain ``all`` only shadowed the
# builtin of the same name.
__all__ = ["compile"]
@@ -0,0 +1,194 @@
1
+ import os
2
+ import gc
3
+ import time
4
+ import hashlib
5
+ import logging
6
+ import threading as th
7
+ from http import HTTPStatus
8
+ from weakref import ReferenceType, ref
9
+ from tempfile import TemporaryDirectory
10
+ from typing import List, Callable, Optional
11
+ import requests
12
+ import torch
13
+ from torch.fx import GraphModule
14
+ from centml.compiler.config import settings, CompilationStatus
15
+ from centml.compiler.utils import get_backend_compiled_forward_path
16
+
17
+
18
class Runner:
    """Callable that runs a graph module while it is being compiled remotely.

    Construction starts a background thread that serializes the module and its
    example inputs, asks the compilation server to compile them (or reuses a
    locally cached result) and stores the compiled forward function. Until that
    function is available, calls fall back to the uncompiled module's
    ``forward``.
    """

    def __init__(self, module: GraphModule, inputs: List[torch.Tensor]):
        """Keep a weak reference to ``module``, store ``inputs`` and start compiling.

        Raises:
            Exception: if ``module`` is falsy.
        """
        if not module:
            raise Exception("No module provided")

        # Weak reference; set to None once the module has been released.
        self._module: Optional[ReferenceType[GraphModule]] = ref(module)
        self._inputs: List[torch.Tensor] = inputs
        self.compiled_forward_function: Optional[Callable[[torch.Tensor], tuple]] = None
        self.lock = th.Lock()
        self.child_thread = th.Thread(target=self.remote_compilation_starter)

        self.serialized_model_dir: Optional[TemporaryDirectory] = None
        self.serialized_model_path: Optional[str] = None
        self.serialized_input_path: Optional[str] = None

        try:
            self.child_thread.start()
        except Exception as e:
            logging.getLogger(__name__).exception(f"Failed to start compilation thread\n{e}")

    @property
    def module(self) -> Optional[GraphModule]:
        """The referenced module, or ``None`` if it was released or collected."""
        module_ref = self._module
        # After ``del runner.module`` the reference itself is None.
        return module_ref() if module_ref is not None else None

    @module.deleter
    def module(self):
        """Detach the graph from its owning module and drop the weak reference."""
        module = self.module
        # The module may already be gone; only touch the graph when it is alive.
        if module is not None:
            module.graph.owning_module = None
        self._module = None

    @property
    def inputs(self) -> List[torch.Tensor]:
        """The example inputs given at construction (``None`` after deletion)."""
        return self._inputs

    @inputs.deleter
    def inputs(self):
        self._inputs = None

    @staticmethod
    def _error_detail(response):
        """Return the ``detail`` field of an error response, or its raw text.

        Falls back to ``response.text`` when the body is not JSON or is not a
        JSON object, so building the error message cannot itself raise.
        """
        try:
            return response.json().get("detail")
        except (ValueError, AttributeError):
            return response.text

    def _serialize_model_and_inputs(self):
        """Save the module and inputs with ``torch.save`` into a fresh temp directory.

        Raises:
            Exception: if ``torch.save`` fails for either file.
        """
        self.serialized_model_dir = TemporaryDirectory()  # pylint: disable=consider-using-with
        self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_MODEL_FILE)
        self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_INPUT_FILE)

        # torch.save saves a zip file full of pickled files with the model's states.
        try:
            torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
            torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL)
        except Exception as e:
            raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e

    def _cleanup_serialized_files(self):
        """Delete the temporary serialization directory and forget its paths."""
        if self.serialized_model_dir is not None:
            self.serialized_model_dir.cleanup()
        self.serialized_model_dir = None
        self.serialized_model_path = None
        self.serialized_input_path = None

    def _get_model_id(self) -> str:
        """Return the SHA-256 hex digest of the serialized model file.

        Raises:
            Exception: if the model has not been serialized to disk.
        """
        if not self.serialized_model_path or not os.path.isfile(self.serialized_model_path):
            raise Exception(f"Model not saved at path {self.serialized_model_path}")

        sha_hash = hashlib.sha256()
        with open(self.serialized_model_path, "rb") as serialized_model_file:
            # Read in chunks to not load too much into memory
            for block in iter(lambda: serialized_model_file.read(settings.CENTML_HASH_CHUNK_SIZE), b""):
                sha_hash.update(block)

        model_id = sha_hash.hexdigest()
        logging.getLogger(__name__).info(f"Model has id {model_id}")
        return model_id

    def _download_model(self, model_id: str):
        """Download the compiled forward for ``model_id``, cache it on disk and load it.

        Raises:
            Exception: on a non-OK response or an empty response body.
        """
        download_response = requests.get(
            url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
        )
        if download_response.status_code != HTTPStatus.OK:
            raise Exception(
                f"Download: request failed, exception from server:\n{self._error_detail(download_response)}"
            )
        if download_response.content == b"":
            raise Exception("Download: empty response from server")
        download_path = get_backend_compiled_forward_path(model_id)
        with open(download_path, "wb") as f:
            f.write(download_response.content)
        # NOTE(security): torch.load unpickles the downloaded bytes, so the
        # configured compilation server must be trusted.
        return torch.load(download_path)

    def _compile_model(self, model_id: str):
        """Upload the serialized model and inputs to the server for compilation.

        Raises:
            Exception: if the files are missing or the server rejects the request.
        """
        # The model should have been saved using torch.save when we found the model_id
        if not self.serialized_model_path or not self.serialized_input_path:
            raise Exception("Model or inputs not serialized")
        if not os.path.isfile(self.serialized_model_path):
            raise Exception(f"Model not saved at path {self.serialized_model_path}")
        if not os.path.isfile(self.serialized_input_path):
            raise Exception(f"Inputs not saved at path {self.serialized_input_path}")

        with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file:
            compile_response = requests.post(
                url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}",
                files={"model": model_file, "inputs": input_file},
                timeout=settings.CENTML_COMPILER_TIMEOUT,
            )
        if compile_response.status_code != HTTPStatus.OK:
            raise Exception(
                f"Compile model: request failed, exception from server:\n{self._error_detail(compile_response)}\n"
            )

    def _wait_for_status(self, model_id: str) -> bool:
        """Poll the server until compilation of ``model_id`` is done.

        A NOT_FOUND status triggers a submission. Failed submissions and
        unrecognised or failed status checks count as tries; a COMPILING
        status does not.

        Returns:
            ``True`` once the server reports DONE.

        Raises:
            Exception: when the number of tries exceeds the configured maximum.
        """
        tries = 0
        while True:
            # get server compilation status
            status = None
            try:
                status_response = requests.get(
                    f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT
                )
                if status_response.status_code != HTTPStatus.OK:
                    raise Exception(
                        f"Status check: request failed, exception from server:\n{self._error_detail(status_response)}"
                    )
                status = status_response.json().get("status")
            except Exception as e:
                logging.getLogger(__name__).exception(f"Status check failed:\n{e}")

            if status == CompilationStatus.DONE.value:
                return True
            elif status == CompilationStatus.COMPILING.value:
                pass
            elif status == CompilationStatus.NOT_FOUND.value:
                logging.getLogger(__name__).info("Submitting model to server for compilation.")
                try:
                    self._compile_model(model_id)
                except Exception as e:
                    logging.getLogger(__name__).exception(f"Submitting compilation failed:\n{e}")
                    tries += 1
            else:
                tries += 1

            if tries > settings.CENTML_COMPILER_MAX_RETRIES:
                raise Exception("Waiting for status: compilation failed too many times.\n")

            time.sleep(settings.CENTML_COMPILER_SLEEP_TIME)

    def remote_compilation_starter(self):
        """Thread entry point: run ``remote_compilation`` and log any failure."""
        try:
            self.remote_compilation()
        except Exception as e:
            logging.getLogger(__name__).exception(f"Compilation thread failed:\n{e}")

    def remote_compilation(self):
        """Obtain the compiled forward function and release the uncompiled model.

        The compiled forward is loaded from the local cache when present,
        otherwise compiled on the server and downloaded. The temporary
        serialization directory is removed whether or not this succeeds.
        """
        try:
            self._serialize_model_and_inputs()

            model_id = self._get_model_id()

            # check if compiled forward is saved locally
            compiled_forward_path = get_backend_compiled_forward_path(model_id)
            if os.path.isfile(compiled_forward_path):
                logging.getLogger(__name__).info("Compiled model found in local cache. Not submitting to server.")
                compiled_forward = torch.load(compiled_forward_path)
            else:
                self._wait_for_status(model_id)
                compiled_forward = self._download_model(model_id)
        finally:
            # The serialized copies are only needed for hashing and upload.
            self._cleanup_serialized_files()

        self.compiled_forward_function = compiled_forward

        logging.getLogger(__name__).info("Compilation successful.")

        # Let garbage collector free the memory used by the uncompiled model
        with self.lock:
            del self.inputs
            if self.module:
                del self.module
            gc.collect()
            torch.cuda.empty_cache()

    def __call__(self, *args, **kwargs):
        """Run the compiled forward if available, else the uncompiled module.

        Keyword arguments are only forwarded to the uncompiled module; the
        compiled forward function is called with positional arguments only.
        """
        # If model is currently compiling, return the uncompiled forward function
        with self.lock:
            if not self.compiled_forward_function:
                return self.module.forward(*args, **kwargs)

        return self.compiled_forward_function(*args)
191
+
192
+
193
def centml_dynamo_backend(gm: GraphModule, example_inputs: List[torch.Tensor]):
    """Backend entry point: wrap ``gm`` and ``example_inputs`` in a ``Runner``."""
    runner = Runner(gm, example_inputs)
    return runner
@@ -0,0 +1,45 @@
1
+ import os
2
+ from enum import Enum
3
+ from pydantic_settings import BaseSettings
4
+
5
+
6
# Compilation states of a model; each member's value is the string form of its name.
class CompilationStatus(Enum):
    NOT_FOUND = "NOT_FOUND"
    COMPILING = "COMPILING"
    DONE = "DONE"
10
+
11
+
12
# Operating modes selectable through the CENTML_MODE setting.
class OperationMode(Enum):
    PREDICTION = "PREDICTION"
    REMOTE_COMPILATION = "REMOTE_COMPILATION"
15
+
16
+
17
# Settings for the CentML compiler, declared as pydantic BaseSettings fields.
class Config(BaseSettings):
    # Used as the `timeout` of HTTP requests and as the retry limit / poll
    # interval (passed to time.sleep) when waiting for a compilation.
    CENTML_COMPILER_TIMEOUT: int = 10
    CENTML_COMPILER_MAX_RETRIES: int = 3
    CENTML_COMPILER_SLEEP_TIME: int = 15

    # NOTE(review): the two derived defaults below are computed once from the
    # default CENTML_BASE_CACHE_DIR; presumably overriding only the base
    # directory does not move them - confirm.
    CENTML_BASE_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
    CENTML_BACKEND_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "backend")
    CENTML_SERVER_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "server")

    # Base URL of the compilation server.
    CENTML_SERVER_URL: str = "http://0.0.0.0:8090"

    # Use a constant path since torch.save uses the given file name in its zipfile.
    # Using a different filename would result in a different hash.
    CENTML_SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
    CENTML_SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
    CENTML_PICKLE_PROTOCOL: int = 4

    # Number of bytes read per chunk when hashing a file.
    CENTML_HASH_CHUNK_SIZE: int = 4096

    # If the server response is smaller than this, don't gzip it
    CENTML_MINIMUM_GZIP_SIZE: int = 1000

    CENTML_MODE: OperationMode = OperationMode.REMOTE_COMPILATION
    CENTML_PREDICTION_DATA_FILE: str = 'tests/sample_data.csv'
    # Comma-separated list of GPU names.
    CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB"
    CENTML_PROMETHEUS_PORT: int = 8000


# Module-level settings instance, created at import time.
settings = Config()
@@ -0,0 +1,57 @@
1
+ import builtins
2
+ from typing import Callable, Dict, Optional, Union
3
+
4
+ from centml.compiler.config import OperationMode, settings
5
+
6
+
7
def compile(
    model: Optional[Callable] = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: Optional[builtins.bool] = None,
    mode: Union[str, None] = None,
    options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> Callable:
    """Compile ``model`` with ``torch.compile`` using the backend for the configured mode.

    The keyword arguments are forwarded unchanged to ``torch.compile``.

    Args:
        model: The callable to compile. When omitted, a decorator is returned
            that compiles the callable it is applied to with the same options.

    Returns:
        In REMOTE_COMPILATION mode, the result of ``torch.compile``. In
        PREDICTION mode, a wrapper that calls the compiled model and then
        updates the metric for every GPU in ``CENTML_PREDICTION_GPUS``.

    Raises:
        RuntimeError: if the returned decorator is applied to ``None``.
        Exception: if ``CENTML_MODE`` is not a known operation mode.
    """
    if model is None:
        # Decorator usage (``@compile(...)``). Handle it here for every mode;
        # otherwise the prediction wrapper below would treat the decorated
        # function as call arguments and update metrics at decoration time.
        def decorator(fn: Callable) -> Callable:
            if fn is None:
                raise RuntimeError("Model can't be None")
            return compile(
                fn, fullgraph=fullgraph, dynamic=dynamic, mode=mode, options=options, disable=disable
            )

        return decorator

    import torch

    if settings.CENTML_MODE == OperationMode.REMOTE_COMPILATION:
        from centml.compiler.backend import centml_dynamo_backend

        # Return the remote-compiled model
        compiled_model = torch.compile(
            model,
            backend=centml_dynamo_backend,  # Compilation backend
            fullgraph=fullgraph,
            dynamic=dynamic,
            mode=mode,
            options=options,
            disable=disable,
        )
        return compiled_model
    elif settings.CENTML_MODE == OperationMode.PREDICTION:
        from centml.compiler.prediction.backend import centml_prediction_backend, get_gauge

        # Proceed with prediction workflow
        compiled_model = torch.compile(
            model,
            backend=centml_prediction_backend,  # Prediction backend
            fullgraph=fullgraph,
            dynamic=dynamic,
            mode=mode,
            options=options,
            disable=disable,
        )

        def centml_wrapper(*args, **kwargs):
            out = compiled_model(*args, **kwargs)
            # Update the prometheus metrics with final values
            gauge = get_gauge()
            for gpu in settings.CENTML_PREDICTION_GPUS.split(','):
                gauge.set_metric_value(gpu)

            return out

        return centml_wrapper
    else:
        raise Exception("Invalid operation mode")
File without changes
@@ -0,0 +1,28 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ from torch._subclasses.fake_tensor import FakeTensorMode
5
+
6
+ from centml.compiler.config import settings
7
+ from centml.compiler.prediction.kdtree import get_tree_db
8
+ from centml.compiler.prediction.metric import get_gauge
9
+ from centml.compiler.prediction.profiler import Profiler
10
+
11
+
12
def centml_prediction_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Return a forward function that profiles ``gm`` for each configured GPU.

    One ``Profiler`` is created per entry of ``CENTML_PREDICTION_GPUS``. The
    returned function converts tensor arguments to fake tensors, runs every
    profiler on them and adds each profiler's reported time to the gauge under
    that profiler's GPU. It returns the output of the last profiler.
    """
    tree_db = get_tree_db()
    gpu_names = settings.CENTML_PREDICTION_GPUS.split(',')
    profilers = [Profiler(gm, gpu_name, tree_db) for gpu_name in gpu_names]

    def forward(*args):
        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

        # Only tensors are converted; other arguments pass through unchanged.
        fake_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                fake_args.append(fake_mode.from_tensor(arg))
            else:
                fake_args.append(arg)

        with fake_mode:
            for profiler in profilers:
                out, elapsed = profiler.propagate(*fake_args)
                get_gauge().increment(profiler.gpu, elapsed)
            return out

    return forward
@@ -0,0 +1,69 @@
1
+ import ast
2
+ import csv
3
+ import logging
4
+
5
+ from sklearn.neighbors import KDTree # type: ignore
6
+
7
+ from centml.compiler.config import settings
8
+
9
+ _tree_db = None
10
+
11
+
12
class KDTreeWithValues:
    """A ``KDTree`` over points, with one stored value per point.

    ``values[i]`` is the value returned when ``points[i]`` is the nearest
    neighbour of a query.
    """

    def __init__(self, points=None, values=None):
        """Build the tree from ``points`` (if any) and remember ``values``.

        The inputs are copied into new lists, so ``add`` never mutates the
        caller's sequences and also works when tuples or arrays were passed.
        """
        self.points = list(points) if points is not None else []
        self.values = list(values) if values is not None else []
        # No tree can be built without at least one point.
        self.tree = KDTree(self.points) if self.points else None

    def add(self, point, value):
        """Append ``point`` with its ``value`` and rebuild the tree."""
        self.points.append(point)
        self.values.append(value)
        self.tree = KDTree(self.points)

    def query(self, point):
        """Return ``(distance, value)`` for the stored point nearest to ``point``.

        Returns ``(None, None)`` when no points have been stored.
        """
        if self.tree is None:
            return None, None

        dist, idx = self.tree.query([point], k=1)
        return dist[0][0], self.values[idx[0][0]]
32
+
33
+
34
class TreeDB:
    """Lookup table of ``KDTreeWithValues`` objects loaded from a CSV file.

    Keys are ``(op, dim, inp_dtypes, out_dtypes, gpu)`` tuples built from the
    CSV columns of the same names.
    """

    def __init__(self, data_csv):
        """Load every parsable row of ``data_csv`` into the database."""
        self.db = {}
        self._populate_db(data_csv)

    def get(self, key, inp):
        """Return the value stored for the point nearest to ``inp`` under ``key``.

        Returns ``float('-inf')`` (after logging a warning) when ``key`` is
        not in the database.
        """
        if key not in self.db:
            # TODO: Handle the case of unfound keys better. For now, return -inf to indicate something went wrong.
            # Ideally, we shouldn't throw away a whole prediction because of one possibly insignificant node.
            logging.getLogger(__name__).warning(f"Key {key} not found in database")
            return float('-inf')

        _, val = self.db[key].query(inp)
        return val

    def _add_from_db(self, key, points, values):
        """Store a new tree built from ``points`` and ``values`` under ``key``."""
        self.db[key] = KDTreeWithValues(points, values)

    def _populate_db(self, data_csv):
        """Read ``data_csv`` and add one tree per row; bad rows are logged and skipped."""
        with open(data_csv, newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                try:
                    key = (row['op'], int(row['dim']), row['inp_dtypes'], row['out_dtypes'], row['gpu'])
                    points = ast.literal_eval(row['points'])
                    values = ast.literal_eval(row['values'])
                    self._add_from_db(key, points, values)
                except (ValueError, SyntaxError, TypeError, KeyError) as e:
                    # ast.literal_eval raises SyntaxError/TypeError as well as
                    # ValueError on malformed text, and a missing column raises
                    # KeyError; none of these should abort the whole load.
                    logging.getLogger(__name__).exception(f"Error parsing row: {row}\n{e}")
63
+
64
+
65
def get_tree_db():
    """Return the module-wide ``TreeDB``, creating it on first use.

    The database is built from ``settings.CENTML_PREDICTION_DATA_FILE`` and
    then reused by every later call.
    """
    global _tree_db
    if _tree_db is not None:
        return _tree_db

    _tree_db = TreeDB(settings.CENTML_PREDICTION_DATA_FILE)
    return _tree_db