NeMo-Export-Deploy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nemo_deploy/__init__.py +26 -0
- nemo_deploy/deploy_base.py +71 -0
- nemo_deploy/deploy_pytriton.py +173 -0
- nemo_deploy/deploy_ray.py +140 -0
- nemo_deploy/multimodal/__init__.py +16 -0
- nemo_deploy/multimodal/query_multimodal.py +184 -0
- nemo_deploy/nlp/__init__.py +23 -0
- nemo_deploy/nlp/hf_deployable.py +466 -0
- nemo_deploy/nlp/hf_deployable_ray.py +375 -0
- nemo_deploy/nlp/inference/inference_base.py +491 -0
- nemo_deploy/nlp/inference/tron_utils.py +485 -0
- nemo_deploy/nlp/megatronllm_deployable.py +538 -0
- nemo_deploy/nlp/megatronllm_deployable_ray.py +374 -0
- nemo_deploy/nlp/query_llm.py +518 -0
- nemo_deploy/nlp/trtllm_api_deployable.py +194 -0
- nemo_deploy/package_info.py +33 -0
- nemo_deploy/ray_utils.py +47 -0
- nemo_deploy/service/__init__.py +16 -0
- nemo_deploy/service/fastapi_interface_to_pytriton.py +319 -0
- nemo_deploy/service/rest_model_api.py +140 -0
- nemo_deploy/triton_deployable.py +32 -0
- nemo_deploy/utils.py +227 -0
- nemo_export/__init__.py +28 -0
- nemo_export/multimodal/__init__.py +13 -0
- nemo_export/multimodal/build.py +679 -0
- nemo_export/multimodal/run.py +989 -0
- nemo_export/onnx_llm_exporter.py +543 -0
- nemo_export/package_info.py +33 -0
- nemo_export/sentencepiece_tokenizer.py +286 -0
- nemo_export/tarutils.py +236 -0
- nemo_export/tensorrt_llm.py +1134 -0
- nemo_export/tensorrt_llm_deployable_ray.py +293 -0
- nemo_export/tensorrt_mm_exporter.py +377 -0
- nemo_export/tiktoken_tokenizer.py +120 -0
- nemo_export/trt_llm/__init__.py +13 -0
- nemo_export/trt_llm/nemo_ckpt_loader/__init__.py +13 -0
- nemo_export/trt_llm/nemo_ckpt_loader/nemo_file.py +433 -0
- nemo_export/trt_llm/qnemo/__init__.py +17 -0
- nemo_export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +128 -0
- nemo_export/trt_llm/qnemo/utils.py +32 -0
- nemo_export/trt_llm/tensorrt_llm_run.py +565 -0
- nemo_export/trt_llm/utils.py +69 -0
- nemo_export/utils/__init__.py +45 -0
- nemo_export/utils/_mock_import.py +78 -0
- nemo_export/utils/constants.py +16 -0
- nemo_export/utils/lora_converter.py +236 -0
- nemo_export/utils/model_loader.py +213 -0
- nemo_export/utils/utils.py +167 -0
- nemo_export/vllm/__init__.py +13 -0
- nemo_export/vllm/model_config.py +279 -0
- nemo_export/vllm/model_converters.py +480 -0
- nemo_export/vllm/model_loader.py +111 -0
- nemo_export/vllm_exporter.py +583 -0
- nemo_export/vllm_hf_exporter.py +154 -0
- nemo_export_deploy-0.1.0.dist-info/METADATA +472 -0
- nemo_export_deploy-0.1.0.dist-info/RECORD +62 -0
- nemo_export_deploy-0.1.0.dist-info/WHEEL +5 -0
- nemo_export_deploy-0.1.0.dist-info/licenses/LICENSE +201 -0
- nemo_export_deploy-0.1.0.dist-info/top_level.txt +3 -0
- nemo_export_deploy_common/__init__.py +19 -0
- nemo_export_deploy_common/import_utils.py +450 -0
- nemo_export_deploy_common/package_info.py +45 -0
nemo_deploy/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from nemo_deploy.deploy_base import DeployBase
|
|
16
|
+
from nemo_deploy.deploy_pytriton import DeployPyTriton
|
|
17
|
+
from nemo_deploy.triton_deployable import ITritonDeployable
|
|
18
|
+
from nemo_export_deploy_common.package_info import __package_name__, __version__
|
|
19
|
+
|
|
20
|
+
# Names re-exported as the public package-level API of ``nemo_deploy``.
__all__ = ["DeployBase", "DeployPyTriton", "ITritonDeployable", "__version__", "__package_name__"]
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
|
|
18
|
+
from nemo_deploy.triton_deployable import ITritonDeployable
|
|
19
|
+
|
|
20
|
+
LOGGER = logging.getLogger("NeMo")


class DeployBase(ABC):
    """Abstract base class for model deployment backends.

    This class only stores the configuration shared by all backends; concrete
    subclasses implement the lifecycle methods ``deploy``, ``serve``, ``run``
    and ``stop``.
    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
    ):
        """Store the deployment settings.

        Args:
            triton_model_name (str): Name under which the model is served.
            triton_model_version (int): Version of the served model.
            model: The model object to serve (subclasses may require it to
                implement ``ITritonDeployable``).
            max_batch_size (int): Maximum batch size accepted by the server.
            http_port (int): Port for the HTTP endpoint.
            grpc_port (int): Port for the gRPC endpoint.
            address (str): Address the server endpoints bind to.
            allow_grpc (bool): Whether to enable the gRPC endpoint.
            allow_http (bool): Whether to enable the HTTP endpoint.
            streaming (bool): Whether the model should be served in streaming mode.
        """
        self.triton_model_name = triton_model_name
        self.triton_model_version = triton_model_version
        self.max_batch_size = max_batch_size
        self.model = model
        self.http_port = http_port
        self.grpc_port = grpc_port
        self.address = address
        # Server handle; stays None until a subclass creates the server.
        self.triton = None
        self.allow_grpc = allow_grpc
        self.allow_http = allow_http
        self.streaming = streaming

    @abstractmethod
    def deploy(self):
        """Set up the deployment (implemented by subclasses)."""

    @abstractmethod
    def serve(self):
        """Start serving requests (implemented by subclasses)."""

    @abstractmethod
    def run(self):
        """Start serving requests without the blocking semantics of ``serve`` (implemented by subclasses)."""

    @abstractmethod
    def stop(self):
        """Stop serving requests (implemented by subclasses)."""

    def _is_model_deployable(self):
        """Check that ``self.model`` implements ``ITritonDeployable``.

        Returns:
            bool: ``True`` when the check passes.

        Raises:
            Exception: If ``self.model`` is not an ``ITritonDeployable`` instance.
        """
        if not isinstance(self.model, ITritonDeployable):
            raise Exception(
                "This model is not deployable to Triton. nemo_deploy.ITritonDeployable class should be inherited"
            )
        return True
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from nemo_deploy.deploy_base import DeployBase
|
|
19
|
+
from nemo_export_deploy_common.import_utils import MISSING_TRITON_MSG, UnavailableError
|
|
20
|
+
|
|
21
|
+
LOGGER = logging.getLogger("NeMo")

try:
    from pytriton.model_config import ModelConfig
    from pytriton.triton import Triton, TritonConfig

    HAVE_TRITON = True
except (ImportError, ModuleNotFoundError):
    HAVE_TRITON = False


class DeployPyTriton(DeployBase):
    """Deploys any models to Triton Inference Server that implements ITritonDeployable interface in nemo_deploy.

    Example:
        from nemo_deploy import DeployPyTriton
        from nemo_deploy.nlp import NemoQueryLLM
        from nemo_export.tensorrt_llm import TensorRTLLM

        trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
        trt_llm_exporter.export(
            nemo_checkpoint_path="/path/for/nemo/checkpoint",
            model_type="llama",
            tensor_parallelism_size=1,
        )

        nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", http_port=8000)
        nm.deploy()
        nm.run()
        nq = NemoQueryLLM(url="localhost", model_name="model_name")

        prompts = ["hello, testing GPT inference", "another GPT inference test?"]
        output = nq.query_llm(prompts=prompts, max_output_len=100)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

        prompts = ["Give me some info about Paris", "Do you think London is a good city to visit?", "What do you think about Rome?"]
        output = nq.query_llm(prompts=prompts, max_output_len=250)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
        pytriton_log_verbose=0,
    ):
        """A nemo checkpoint or model is expected for serving on Triton Inference Server.

        Args:
            triton_model_name (str): Name for the service.
            triton_model_version (int): Version for the service.
            model (ITritonDeployable): Model exposing ``triton_infer_fn``,
                ``get_triton_input`` and ``get_triton_output`` (plus
                ``triton_infer_fn_streaming`` when ``streaming`` is True).
            max_batch_size (int): Max batch size (used in non-streaming mode).
            http_port (int): Port for the HTTP endpoint.
            grpc_port (int): Port for the gRPC endpoint.
            address (str): Address the HTTP and gRPC endpoints bind to.
            allow_grpc (bool): Whether to enable the gRPC endpoint.
            allow_http (bool): Whether to enable the HTTP endpoint.
            streaming (bool): Bind ``triton_infer_fn_streaming`` as a decoupled
                model instead of ``triton_infer_fn``.
            pytriton_log_verbose (int): Log verbosity passed to the Triton server.
        """
        super().__init__(
            triton_model_name=triton_model_name,
            triton_model_version=triton_model_version,
            model=model,
            max_batch_size=max_batch_size,
            http_port=http_port,
            grpc_port=grpc_port,
            address=address,
            allow_grpc=allow_grpc,
            allow_http=allow_http,
            streaming=streaming,
        )
        self.pytriton_log_verbose = pytriton_log_verbose

    def deploy(self):
        """Deploys any models to Triton Inference Server.

        Creates the Triton server and binds the model to it. Errors raised while
        configuring or binding are logged (with traceback) and leave
        ``self.triton`` as ``None``, so a later ``serve``/``run``/``stop``
        raises.

        Raises:
            UnavailableError: If PyTriton is not installed.
        """
        if not HAVE_TRITON:
            raise UnavailableError(MISSING_TRITON_MSG)

        try:
            # One server config for both modes so the user-supplied ports, bind
            # address and log verbosity are honoured regardless of ``streaming``.
            triton_config = TritonConfig(
                http_address=self.address,
                http_port=self.http_port,
                grpc_address=self.address,
                grpc_port=self.grpc_port,
                allow_grpc=self.allow_grpc,
                allow_http=self.allow_http,
                log_verbose=self.pytriton_log_verbose,
            )
            if self.streaming:
                # Streaming inference is bound as a decoupled model.
                infer_func = self.model.triton_infer_fn_streaming
                model_config = ModelConfig(decoupled=True)
            else:
                infer_func = self.model.triton_infer_fn
                model_config = ModelConfig(max_batch_size=self.max_batch_size)

            self.triton = Triton(config=triton_config)
            self.triton.bind(
                model_name=self.triton_model_name,
                model_version=self.triton_model_version,
                infer_func=infer_func,
                inputs=self.model.get_triton_input,
                outputs=self.model.get_triton_output,
                config=model_config,
            )
        except Exception:
            self.triton = None
            LOGGER.exception("Failed to deploy model '%s' to Triton.", self.triton_model_name)

    def serve(self):
        """Starts serving the model and waits for the requests.

        Errors raised by the server are logged (with traceback) and reset
        ``self.triton`` to ``None``.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        try:
            self.triton.serve()
        except Exception:
            self.triton = None
            LOGGER.exception("Triton server for model '%s' failed while serving.", self.triton_model_name)

    def run(self):
        """Starts serving the model asynchronously.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.run()

    def stop(self):
        """Stops serving the model.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.stop()
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from nemo_deploy.ray_utils import find_available_port
|
|
19
|
+
from nemo_export_deploy_common.import_utils import MISSING_RAY_MSG, UnavailableError
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import ray
|
|
23
|
+
from ray import serve
|
|
24
|
+
from ray.serve import Application
|
|
25
|
+
|
|
26
|
+
HAVE_RAY = True
|
|
27
|
+
except (ImportError, ModuleNotFoundError):
|
|
28
|
+
from unittest.mock import MagicMock
|
|
29
|
+
|
|
30
|
+
ray = MagicMock()
|
|
31
|
+
serve = MagicMock()
|
|
32
|
+
Application = MagicMock()
|
|
33
|
+
HAVE_RAY = False
|
|
34
|
+
|
|
35
|
+
LOGGER = logging.getLogger("NeMo")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DeployRay:
|
|
39
|
+
"""A class for managing Ray deployment and serving of models.
|
|
40
|
+
|
|
41
|
+
This class provides functionality to initialize Ray, start Ray Serve,
|
|
42
|
+
deploy models, and manage the lifecycle of the Ray cluster.
|
|
43
|
+
|
|
44
|
+
Attributes:
|
|
45
|
+
address (str): The address of the Ray cluster to connect to.
|
|
46
|
+
num_cpus (int): Number of CPUs to allocate for the Ray cluster.
|
|
47
|
+
num_gpus (int): Number of GPUs to allocate for the Ray cluster.
|
|
48
|
+
include_dashboard (bool): Whether to include the Ray dashboard.
|
|
49
|
+
ignore_reinit_error (bool): Whether to ignore errors when reinitializing Ray.
|
|
50
|
+
runtime_env (dict): Runtime environment configuration for Ray.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
def __init__(
|
|
54
|
+
self,
|
|
55
|
+
address: str = "auto",
|
|
56
|
+
num_cpus: int = 1,
|
|
57
|
+
num_gpus: int = 1,
|
|
58
|
+
include_dashboard: bool = False,
|
|
59
|
+
ignore_reinit_error: bool = True,
|
|
60
|
+
runtime_env: dict = None,
|
|
61
|
+
):
|
|
62
|
+
"""Initialize the DeployRay instance and set up the Ray cluster.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
address (str, optional): Address of the Ray cluster. Defaults to "auto".
|
|
66
|
+
num_cpus (int, optional): Number of CPUs to allocate. Defaults to 1.
|
|
67
|
+
num_gpus (int, optional): Number of GPUs to allocate. Defaults to 1.
|
|
68
|
+
include_dashboard (bool, optional): Whether to include the dashboard. Defaults to False.
|
|
69
|
+
ignore_reinit_error (bool, optional): Whether to ignore reinit errors. Defaults to True.
|
|
70
|
+
runtime_env (dict, optional): Runtime environment configuration. Defaults to None.
|
|
71
|
+
|
|
72
|
+
Raises:
|
|
73
|
+
Exception: If Ray is not installed.
|
|
74
|
+
"""
|
|
75
|
+
if not HAVE_RAY:
|
|
76
|
+
raise UnavailableError(MISSING_RAY_MSG)
|
|
77
|
+
|
|
78
|
+
# Initialize Ray with proper configuration
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
# Try to connect to existing Ray cluster
|
|
82
|
+
ray.init(
|
|
83
|
+
address=address,
|
|
84
|
+
ignore_reinit_error=ignore_reinit_error,
|
|
85
|
+
runtime_env=runtime_env,
|
|
86
|
+
)
|
|
87
|
+
except ConnectionError:
|
|
88
|
+
# If no cluster exists, start a local one
|
|
89
|
+
LOGGER.info("No existing Ray cluster found. Starting a local Ray cluster...")
|
|
90
|
+
ray.init(
|
|
91
|
+
num_cpus=num_cpus,
|
|
92
|
+
num_gpus=num_gpus,
|
|
93
|
+
include_dashboard=include_dashboard,
|
|
94
|
+
ignore_reinit_error=ignore_reinit_error,
|
|
95
|
+
runtime_env=runtime_env,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
def start(self, host: str = "0.0.0.0", port: int = None):
|
|
99
|
+
"""Start Ray Serve with the specified host and port.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
host (str, optional): Host address to bind to. Defaults to "0.0.0.0".
|
|
103
|
+
port (int, optional): Port number to use. If None, an available port will be found.
|
|
104
|
+
"""
|
|
105
|
+
if not port:
|
|
106
|
+
port = find_available_port(8000, host)
|
|
107
|
+
serve.start(
|
|
108
|
+
http_options={
|
|
109
|
+
"host": host,
|
|
110
|
+
"port": port,
|
|
111
|
+
}
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
def run(self, app: Application, model_name: str):
|
|
115
|
+
"""Deploy and start serving a model using Ray Serve.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
app (Application): The Ray Serve application to deploy.
|
|
119
|
+
model_name (str): Name to give to the deployed model.
|
|
120
|
+
"""
|
|
121
|
+
serve.run(app, name=model_name)
|
|
122
|
+
|
|
123
|
+
def stop(self):
|
|
124
|
+
"""Stop the Ray Serve deployment and shutdown the Ray cluster.
|
|
125
|
+
|
|
126
|
+
This method attempts to gracefully shutdown both Ray Serve and the Ray cluster.
|
|
127
|
+
If any errors occur during shutdown, they are logged as warnings.
|
|
128
|
+
"""
|
|
129
|
+
try:
|
|
130
|
+
# First try to gracefully shutdown Ray Serve
|
|
131
|
+
LOGGER.info("Shutting down Ray Serve...")
|
|
132
|
+
serve.shutdown()
|
|
133
|
+
except Exception as e:
|
|
134
|
+
LOGGER.warning(f"Error during serve.shutdown(): {str(e)}")
|
|
135
|
+
try:
|
|
136
|
+
# Then try to gracefully shutdown Ray
|
|
137
|
+
LOGGER.info("Shutting down Ray...")
|
|
138
|
+
ray.shutdown()
|
|
139
|
+
except Exception as e:
|
|
140
|
+
LOGGER.warning(f"Error during ray.shutdown(): {str(e)}")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from nemo_deploy.multimodal.query_multimodal import NemoQueryMultimodal
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from io import BytesIO
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import requests
|
|
19
|
+
|
|
20
|
+
from nemo_deploy.utils import str_list2numpy
|
|
21
|
+
from nemo_export_deploy_common.import_utils import (
|
|
22
|
+
MISSING_DECORD_MSG,
|
|
23
|
+
MISSING_PIL_MSG,
|
|
24
|
+
MISSING_TRITON_MSG,
|
|
25
|
+
UnavailableError,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from PIL import Image
|
|
30
|
+
|
|
31
|
+
HAVE_PIL = True
|
|
32
|
+
except (ImportError, ModuleNotFoundError):
|
|
33
|
+
HAVE_PIL = False
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
from decord import VideoReader
|
|
37
|
+
|
|
38
|
+
HAVE_DECORD = True
|
|
39
|
+
except (ImportError, ModuleNotFoundError):
|
|
40
|
+
HAVE_DECORD = False
|
|
41
|
+
|
|
42
|
+
try:
|
|
43
|
+
from pytriton.client import ModelClient
|
|
44
|
+
|
|
45
|
+
HAVE_TRITON = True
|
|
46
|
+
except (ImportError, ModuleNotFoundError):
|
|
47
|
+
from unittest.mock import MagicMock
|
|
48
|
+
|
|
49
|
+
ModelClient = MagicMock()
|
|
50
|
+
HAVE_TRITON = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class NemoQueryMultimodal:
    """Sends a query to Triton for Multimodal inference.

    Example:
        from nemo_deploy.multimodal import NemoQueryMultimodal

        nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva")

        input_text = "Hi! What is in this image?"
        output = nq.query(
            input_text=input_text,
            input_media="/path/to/image.jpg",
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )
        print("output: ", output)
    """

    def __init__(self, url, model_name, model_type):
        """Store the connection settings.

        Args:
            url (str): Triton server URL passed to ``ModelClient``.
            model_name (str): Name of the deployed model.
            model_type (str): Selects how ``input_media`` is loaded; one of
                "video-neva", "lita", "vita", "neva", "vila" or "mllama".
        """
        self.url = url
        self.model_name = model_name
        self.model_type = model_type

    @staticmethod
    def _read_video_frames(input_media):
        """Decode every frame of the video at ``input_media`` into a list of numpy arrays.

        Raises:
            UnavailableError: If decord is not installed.
        """
        if not HAVE_DECORD:
            raise UnavailableError(MISSING_DECORD_MSG)

        reader = VideoReader(input_media)
        return [frame.asnumpy() for frame in reader]

    def setup_media(self, input_media):
        """Setup input media.

        Loads ``input_media`` into a numpy array according to ``self.model_type``:
        all frames for "video-neva", uniformly subsampled frames for "lita"/"vita",
        and a single image (from a local path or an http(s) URL) with a leading
        axis of size 1 for "neva"/"vila"/"mllama".

        Raises:
            UnavailableError: If the required decoder (decord or PIL) is missing.
            requests.HTTPError: If downloading an image URL returns an error status.
            RuntimeError: If ``self.model_type`` is not supported.
        """
        if self.model_type == "video-neva":
            return np.array(self._read_video_frames(input_media))
        elif self.model_type in ("lita", "vita"):
            frames = self._read_video_frames(input_media)
            sub_frames = self.get_subsampled_frames(frames, self.frame_len(frames))
            return np.array(sub_frames)
        elif self.model_type in ("neva", "vila", "mllama"):
            if not HAVE_PIL:
                raise UnavailableError(MISSING_PIL_MSG)

            # Match the URL scheme exactly so local files whose names merely start
            # with "http" are still opened from disk.
            if input_media.startswith(("http://", "https://")):
                response = requests.get(input_media, timeout=5)
                # Surface HTTP failures instead of an opaque image-decoding error.
                response.raise_for_status()
                media = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                media = Image.open(input_media).convert("RGB")
            # Leading axis of size 1 so the image is indexed like a stack of frames in query().
            return np.expand_dims(np.array(media), axis=0)
        else:
            raise RuntimeError(f"Invalid model type {self.model_type}")

    def frame_len(self, frames, max_frames=256):
        """Get frame len.

        Returns how many frames to keep so that at most ``max_frames`` remain.
        Short videos keep every frame; longer ones use an integer stride
        ``ceil(len(frames) / max_frames)`` and return the rounded number of
        frames that stride yields.

        Args:
            frames (Sequence): Decoded video frames.
            max_frames (int): Upper bound on the number of frames kept. Defaults to 256.

        Returns:
            int: Number of frames to keep.
        """
        num_frames = len(frames)
        if num_frames <= max_frames:
            return num_frames
        stride = int(np.ceil(float(num_frames) / max_frames))
        return int(np.round(float(num_frames) / stride))

    def get_subsampled_frames(self, frames, subsample_len):
        """Get subsampled frames.

        Picks ``subsample_len`` frames at evenly spaced (rounded) indices between
        the first and last frame, inclusive.

        Returns:
            list: The selected frames in order.
        """
        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
        sub_frames = [frames[i] for i in idx]
        return sub_frames

    def query(
        self,
        input_text,
        input_media,
        batch_size=1,
        max_output_len=30,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        num_beams=1,
        init_timeout=60.0,
        lora_uids=None,
    ):
        """Run query.

        Sends ``input_text`` together with the media loaded by ``setup_media``
        to the Triton model. Each sampling argument that is not ``None`` is sent
        as an array shaped like the prompt array.

        Args:
            input_text (str): Prompt text.
            input_media (str): Path or URL of the image/video to send.
            batch_size, max_output_len, top_k, num_beams (int): Integer inference options.
            top_p, temperature, repetition_penalty (float): Float inference options.
            init_timeout (float): Seconds ``ModelClient`` waits for the server to be ready.
            lora_uids (list[str] | None): LoRA adapter ids to apply.

        Returns:
            numpy.ndarray: Decoded strings when the model's first output is bytes,
            otherwise the raw "outputs" array.

        Raises:
            UnavailableError: If PyTriton is not installed.
        """
        if not HAVE_TRITON:
            raise UnavailableError(MISSING_TRITON_MSG)

        prompts = str_list2numpy([input_text])
        inputs = {"input_text": prompts}

        media = self.setup_media(input_media)
        if isinstance(media, dict):
            # Allows setup_media overrides to supply several named media inputs.
            inputs.update(media)
        else:
            # One copy of the (4-D) media per prompt.
            inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0)

        if batch_size is not None:
            inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_)

        if max_output_len is not None:
            inputs["max_output_len"] = np.full(prompts.shape, max_output_len, dtype=np.int_)

        if top_k is not None:
            inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_)

        if top_p is not None:
            inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single)

        if temperature is not None:
            inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single)

        if repetition_penalty is not None:
            inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single)

        if num_beams is not None:
            inputs["num_beams"] = np.full(prompts.shape, num_beams, dtype=np.int_)

        if lora_uids is not None:
            lora_uids = np.char.encode(lora_uids, "utf-8")
            inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
            result_dict = client.infer_batch(**inputs)
            output_type = client.model_config.outputs[0].dtype

        if output_type == np.bytes_:
            sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
            return sentences
        else:
            return result_dict["outputs"]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from nemo_deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMHF, NemoQueryLLMPyTorch, NemoQueryTRTLLMAPI
|
|
17
|
+
|
|
18
|
+
# Query clients re-exported as the public API of ``nemo_deploy.nlp``.
__all__ = ["NemoQueryLLM", "NemoQueryLLMHF", "NemoQueryLLMPyTorch", "NemoQueryTRTLLMAPI"]
|