NeMo-Export-Deploy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. nemo_deploy/__init__.py +26 -0
  2. nemo_deploy/deploy_base.py +71 -0
  3. nemo_deploy/deploy_pytriton.py +173 -0
  4. nemo_deploy/deploy_ray.py +140 -0
  5. nemo_deploy/multimodal/__init__.py +16 -0
  6. nemo_deploy/multimodal/query_multimodal.py +184 -0
  7. nemo_deploy/nlp/__init__.py +23 -0
  8. nemo_deploy/nlp/hf_deployable.py +466 -0
  9. nemo_deploy/nlp/hf_deployable_ray.py +375 -0
  10. nemo_deploy/nlp/inference/inference_base.py +491 -0
  11. nemo_deploy/nlp/inference/tron_utils.py +485 -0
  12. nemo_deploy/nlp/megatronllm_deployable.py +538 -0
  13. nemo_deploy/nlp/megatronllm_deployable_ray.py +374 -0
  14. nemo_deploy/nlp/query_llm.py +518 -0
  15. nemo_deploy/nlp/trtllm_api_deployable.py +194 -0
  16. nemo_deploy/package_info.py +33 -0
  17. nemo_deploy/ray_utils.py +47 -0
  18. nemo_deploy/service/__init__.py +16 -0
  19. nemo_deploy/service/fastapi_interface_to_pytriton.py +319 -0
  20. nemo_deploy/service/rest_model_api.py +140 -0
  21. nemo_deploy/triton_deployable.py +32 -0
  22. nemo_deploy/utils.py +227 -0
  23. nemo_export/__init__.py +28 -0
  24. nemo_export/multimodal/__init__.py +13 -0
  25. nemo_export/multimodal/build.py +679 -0
  26. nemo_export/multimodal/run.py +989 -0
  27. nemo_export/onnx_llm_exporter.py +543 -0
  28. nemo_export/package_info.py +33 -0
  29. nemo_export/sentencepiece_tokenizer.py +286 -0
  30. nemo_export/tarutils.py +236 -0
  31. nemo_export/tensorrt_llm.py +1134 -0
  32. nemo_export/tensorrt_llm_deployable_ray.py +293 -0
  33. nemo_export/tensorrt_mm_exporter.py +377 -0
  34. nemo_export/tiktoken_tokenizer.py +120 -0
  35. nemo_export/trt_llm/__init__.py +13 -0
  36. nemo_export/trt_llm/nemo_ckpt_loader/__init__.py +13 -0
  37. nemo_export/trt_llm/nemo_ckpt_loader/nemo_file.py +433 -0
  38. nemo_export/trt_llm/qnemo/__init__.py +17 -0
  39. nemo_export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +128 -0
  40. nemo_export/trt_llm/qnemo/utils.py +32 -0
  41. nemo_export/trt_llm/tensorrt_llm_run.py +565 -0
  42. nemo_export/trt_llm/utils.py +69 -0
  43. nemo_export/utils/__init__.py +45 -0
  44. nemo_export/utils/_mock_import.py +78 -0
  45. nemo_export/utils/constants.py +16 -0
  46. nemo_export/utils/lora_converter.py +236 -0
  47. nemo_export/utils/model_loader.py +213 -0
  48. nemo_export/utils/utils.py +167 -0
  49. nemo_export/vllm/__init__.py +13 -0
  50. nemo_export/vllm/model_config.py +279 -0
  51. nemo_export/vllm/model_converters.py +480 -0
  52. nemo_export/vllm/model_loader.py +111 -0
  53. nemo_export/vllm_exporter.py +583 -0
  54. nemo_export/vllm_hf_exporter.py +154 -0
  55. nemo_export_deploy-0.1.0.dist-info/METADATA +472 -0
  56. nemo_export_deploy-0.1.0.dist-info/RECORD +62 -0
  57. nemo_export_deploy-0.1.0.dist-info/WHEEL +5 -0
  58. nemo_export_deploy-0.1.0.dist-info/licenses/LICENSE +201 -0
  59. nemo_export_deploy-0.1.0.dist-info/top_level.txt +3 -0
  60. nemo_export_deploy_common/__init__.py +19 -0
  61. nemo_export_deploy_common/import_utils.py +450 -0
  62. nemo_export_deploy_common/package_info.py +45 -0
@@ -0,0 +1,26 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from nemo_deploy.deploy_base import DeployBase
16
+ from nemo_deploy.deploy_pytriton import DeployPyTriton
17
+ from nemo_deploy.triton_deployable import ITritonDeployable
18
+ from nemo_export_deploy_common.package_info import __package_name__, __version__
19
+
20
+ __all__ = [
21
+ "DeployBase",
22
+ "DeployPyTriton",
23
+ "ITritonDeployable",
24
+ "__version__",
25
+ "__package_name__",
26
+ ]
@@ -0,0 +1,71 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ from abc import ABC, abstractmethod
17
+
18
+ from nemo_deploy.triton_deployable import ITritonDeployable
19
+
20
+ LOGGER = logging.getLogger("NeMo")
21
+
22
+
23
class DeployBase(ABC):
    """Abstract base class for serving a model behind an inference server.

    Stores the serving configuration shared by all deployers (model name and
    version, batch size, network endpoints and protocol toggles). Concrete
    subclasses implement the lifecycle methods ``deploy``, ``serve``, ``run``
    and ``stop``.
    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
    ):
        """Store the serving configuration.

        Args:
            triton_model_name (str): Name under which the model is served.
            triton_model_version (int): Version number of the served model.
            model: Object to serve. ``_is_model_deployable`` requires it to
                implement ``ITritonDeployable``.
            max_batch_size (int): Maximum batch size.
            http_port (int): Port for the HTTP endpoint.
            grpc_port (int): Port for the gRPC endpoint.
            address (str): Address the server binds to.
            allow_grpc (bool): Whether to enable the gRPC endpoint.
            allow_http (bool): Whether to enable the HTTP endpoint.
            streaming (bool): Whether the model is served in streaming mode.
        """
        self.triton_model_name = triton_model_name
        self.triton_model_version = triton_model_version
        self.max_batch_size = max_batch_size
        self.model = model
        self.http_port = http_port
        self.grpc_port = grpc_port
        self.address = address
        # Server handle; stays None until a concrete subclass creates it.
        self.triton = None
        self.allow_grpc = allow_grpc
        self.allow_http = allow_http
        self.streaming = streaming

    @abstractmethod
    def deploy(self):
        """Prepare the server and bind the model to it."""
        pass

    @abstractmethod
    def serve(self):
        """Serve the model (implementation-defined blocking behavior)."""
        pass

    @abstractmethod
    def run(self):
        """Start serving the model."""
        pass

    @abstractmethod
    def stop(self):
        """Stop serving the model."""
        pass

    def _is_model_deployable(self):
        """Check that ``self.model`` implements ``ITritonDeployable``.

        Returns:
            bool: ``True`` when the model is deployable.

        Raises:
            Exception: If the model does not implement ``ITritonDeployable``.
        """
        if not isinstance(self.model, ITritonDeployable):
            raise Exception(
                "This model is not deployable to Triton. "
                "nemo_deploy.ITritonDeployable class should be inherited"
            )
        return True
@@ -0,0 +1,173 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import logging
17
+
18
+ from nemo_deploy.deploy_base import DeployBase
19
+ from nemo_export_deploy_common.import_utils import MISSING_TRITON_MSG, UnavailableError
20
+
21
+ LOGGER = logging.getLogger("NeMo")
22
+
23
+ try:
24
+ from pytriton.model_config import ModelConfig
25
+ from pytriton.triton import Triton, TritonConfig
26
+
27
+ HAVE_TRITON = True
28
+ except (ImportError, ModuleNotFoundError):
29
+ HAVE_TRITON = False
30
+
31
+
32
class DeployPyTriton(DeployBase):
    """Deploy a model implementing ``ITritonDeployable`` to Triton Inference Server using PyTriton.

    Example:
        from nemo_deploy import DeployPyTriton
        from nemo_deploy.nlp import NemoQueryLLM
        from nemo_export.tensorrt_llm import TensorRTLLM

        trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
        trt_llm_exporter.export(
            nemo_checkpoint_path="/path/for/nemo/checkpoint",
            model_type="llama",
            tensor_parallelism_size=1,
        )

        nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", http_port=8000)
        nm.deploy()
        nm.run()
        nq = NemoQueryLLM(url="localhost", model_name="model_name")

        prompts = ["hello, testing GPT inference", "another GPT inference test?"]
        output = nq.query_llm(prompts=prompts, max_output_len=100)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

        prompts = ["Give me some info about Paris", "Do you think London is a good city to visit?", "What do you think about Rome?"]
        output = nq.query_llm(prompts=prompts, max_output_len=250)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")
    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
        pytriton_log_verbose=0,
    ):
        """Configure a PyTriton deployment for ``model``.

        Args:
            triton_model_name (str): Name under which the model is served.
            triton_model_version (int): Version of the served model.
            model (ITritonDeployable): Model implementing ``nemo_deploy.ITritonDeployable``.
            max_batch_size (int): Maximum batch size. Only used in non-streaming mode.
            http_port (int): Port for Triton's HTTP endpoint.
            grpc_port (int): Port for Triton's gRPC endpoint.
            address (str): Address that the HTTP and gRPC endpoints bind to.
            allow_grpc (bool): Whether to enable the gRPC endpoint.
            allow_http (bool): Whether to enable the HTTP endpoint.
            streaming (bool): If True, ``model.triton_infer_fn_streaming`` is bound
                as a decoupled model. Otherwise ``model.triton_infer_fn`` is bound.
            pytriton_log_verbose (int): Triton server log verbosity level.
        """
        super().__init__(
            triton_model_name=triton_model_name,
            triton_model_version=triton_model_version,
            model=model,
            max_batch_size=max_batch_size,
            http_port=http_port,
            grpc_port=grpc_port,
            address=address,
            allow_grpc=allow_grpc,
            allow_http=allow_http,
            streaming=streaming,
        )
        self.pytriton_log_verbose = pytriton_log_verbose

    def deploy(self):
        """Create the PyTriton server and bind the model to it.

        Errors raised while configuring or binding are logged with a
        traceback, and ``self.triton`` is left as ``None``. A subsequent call
        to ``serve``, ``run`` or ``stop`` then raises.

        Raises:
            UnavailableError: If PyTriton is not installed.
        """
        if not HAVE_TRITON:
            raise UnavailableError(MISSING_TRITON_MSG)

        try:
            # Endpoint settings and log verbosity are shared by both modes, so
            # user-supplied ports, address and verbosity are honoured whether
            # or not streaming is enabled.
            triton_config = TritonConfig(
                log_verbose=self.pytriton_log_verbose,
                http_address=self.address,
                http_port=self.http_port,
                grpc_address=self.address,
                grpc_port=self.grpc_port,
                allow_grpc=self.allow_grpc,
                allow_http=self.allow_http,
            )
            if self.streaming:
                infer_func = self.model.triton_infer_fn_streaming
                model_config = ModelConfig(decoupled=True)
            else:
                infer_func = self.model.triton_infer_fn
                model_config = ModelConfig(max_batch_size=self.max_batch_size)

            self.triton = Triton(config=triton_config)
            self.triton.bind(
                model_name=self.triton_model_name,
                model_version=self.triton_model_version,
                infer_func=infer_func,
                inputs=self.model.get_triton_input,
                outputs=self.model.get_triton_output,
                config=model_config,
            )
        except Exception as e:
            self.triton = None
            LOGGER.error(e, exc_info=True)

    def serve(self):
        """Start serving the model and wait for requests.

        Errors raised by the server are logged with a traceback, and
        ``self.triton`` is reset to ``None``.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        try:
            self.triton.serve()
        except Exception as e:
            self.triton = None
            LOGGER.error(e, exc_info=True)

    def run(self):
        """Start serving the model asynchronously.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.run()

    def stop(self):
        """Stop serving the model.

        Raises:
            Exception: If ``deploy`` has not successfully created the server.
        """
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.stop()
@@ -0,0 +1,140 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import logging
17
+
18
+ from nemo_deploy.ray_utils import find_available_port
19
+ from nemo_export_deploy_common.import_utils import MISSING_RAY_MSG, UnavailableError
20
+
21
+ try:
22
+ import ray
23
+ from ray import serve
24
+ from ray.serve import Application
25
+
26
+ HAVE_RAY = True
27
+ except (ImportError, ModuleNotFoundError):
28
+ from unittest.mock import MagicMock
29
+
30
+ ray = MagicMock()
31
+ serve = MagicMock()
32
+ Application = MagicMock()
33
+ HAVE_RAY = False
34
+
35
+ LOGGER = logging.getLogger("NeMo")
36
+
37
+
38
class DeployRay:
    """Manage a Ray cluster connection and Ray Serve deployments of models.

    Constructing an instance connects to an existing Ray cluster, or starts a
    local one if none is reachable. ``start`` launches Ray Serve, ``run``
    deploys an application, and ``stop`` shuts down Ray Serve and Ray.
    """

    def __init__(
        self,
        address: str = "auto",
        num_cpus: int = 1,
        num_gpus: int = 1,
        include_dashboard: bool = False,
        ignore_reinit_error: bool = True,
        runtime_env: dict = None,
    ):
        """Connect to a Ray cluster, falling back to starting a local one.

        Args:
            address (str, optional): Address of the Ray cluster to connect to.
                Defaults to "auto".
            num_cpus (int, optional): CPUs for a locally started cluster. Only
                used when connecting to ``address`` raises ``ConnectionError``.
                Defaults to 1.
            num_gpus (int, optional): GPUs for a locally started cluster. Only
                used in the same fallback case. Defaults to 1.
            include_dashboard (bool, optional): Whether a locally started cluster
                includes the dashboard. Defaults to False.
            ignore_reinit_error (bool, optional): Whether to ignore errors if Ray
                is already initialized. Defaults to True.
            runtime_env (dict, optional): Ray runtime environment configuration.
                Defaults to None.

        Raises:
            UnavailableError: If Ray is not installed.
        """
        if not HAVE_RAY:
            raise UnavailableError(MISSING_RAY_MSG)

        try:
            # Prefer an already-running cluster at ``address``. Resource
            # arguments are not passed here, only in the local fallback below.
            ray.init(
                address=address,
                ignore_reinit_error=ignore_reinit_error,
                runtime_env=runtime_env,
            )
        except ConnectionError:
            # No reachable cluster: start a local one with the requested resources.
            LOGGER.info("No existing Ray cluster found. Starting a local Ray cluster...")
            ray.init(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                include_dashboard=include_dashboard,
                ignore_reinit_error=ignore_reinit_error,
                runtime_env=runtime_env,
            )

    def start(self, host: str = "0.0.0.0", port: int = None):
        """Start Ray Serve's HTTP endpoint on ``host``:``port``.

        Args:
            host (str, optional): Host address to bind to. Defaults to "0.0.0.0".
            port (int, optional): Port to bind to. If falsy (``None`` or ``0``),
                the port is obtained from ``find_available_port(8000, host)``
                (presumably the first free port at or above 8000 — confirm).
        """
        if not port:
            port = find_available_port(8000, host)
        serve.start(http_options={"host": host, "port": port})

    def run(self, app: Application, model_name: str):
        """Deploy and start serving a Ray Serve application.

        Args:
            app (Application): The Ray Serve application to deploy.
            model_name (str): Name given to the deployed application.
        """
        serve.run(app, name=model_name)

    def stop(self):
        """Shut down Ray Serve, then Ray itself.

        Each shutdown step is attempted independently. Errors are logged as
        warnings rather than raised, so the second step runs even if the first
        step fails.
        """
        try:
            LOGGER.info("Shutting down Ray Serve...")
            serve.shutdown()
        except Exception as e:
            LOGGER.warning("Error during serve.shutdown(): %s", e)
        try:
            LOGGER.info("Shutting down Ray...")
            ray.shutdown()
        except Exception as e:
            LOGGER.warning("Error during ray.shutdown(): %s", e)
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from nemo_deploy.multimodal.query_multimodal import NemoQueryMultimodal
@@ -0,0 +1,184 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from io import BytesIO
16
+
17
+ import numpy as np
18
+ import requests
19
+
20
+ from nemo_deploy.utils import str_list2numpy
21
+ from nemo_export_deploy_common.import_utils import (
22
+ MISSING_DECORD_MSG,
23
+ MISSING_PIL_MSG,
24
+ MISSING_TRITON_MSG,
25
+ UnavailableError,
26
+ )
27
+
28
+ try:
29
+ from PIL import Image
30
+
31
+ HAVE_PIL = True
32
+ except (ImportError, ModuleNotFoundError):
33
+ HAVE_PIL = False
34
+
35
+ try:
36
+ from decord import VideoReader
37
+
38
+ HAVE_DECORD = True
39
+ except (ImportError, ModuleNotFoundError):
40
+ HAVE_DECORD = False
41
+
42
+ try:
43
+ from pytriton.client import ModelClient
44
+
45
+ HAVE_TRITON = True
46
+ except (ImportError, ModuleNotFoundError):
47
+ from unittest.mock import MagicMock
48
+
49
+ ModelClient = MagicMock()
50
+ HAVE_TRITON = False
51
+
52
+
53
class NemoQueryMultimodal:
    """Send multimodal (text + image/video) inference queries to a Triton server.

    Example:
        from nemo_deploy.multimodal import NemoQueryMultimodal

        nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva")

        input_text = "Hi! What is in this image?"
        output = nq.query(
            input_text=input_text,
            input_media="/path/to/image.jpg",
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )
        print("output: ", output)
    """

    # Prefixes (compared case-insensitively) that mark ``input_media`` as a remote URL.
    _URL_PREFIXES = ("http://", "https://")

    def __init__(self, url, model_name, model_type):
        """Store the connection target and model type.

        Args:
            url (str): Triton server URL.
            model_name (str): Name of the deployed model.
            model_type (str): One of "video-neva", "lita", "vita", "neva",
                "vila" or "mllama". It selects how ``setup_media`` loads media.
        """
        self.url = url
        self.model_name = model_name
        self.model_type = model_type

    @staticmethod
    def _is_url(input_media):
        """Return True when ``input_media`` is an http(s) URL string."""
        return isinstance(input_media, str) and input_media.lower().startswith(NemoQueryMultimodal._URL_PREFIXES)

    def setup_media(self, input_media):
        """Load ``input_media`` into a numpy array suitable for ``model_type``.

        Args:
            input_media (str): Local file path, or, for image models, an
                ``http://``/``https://`` URL.

        Returns:
            np.ndarray:
                - "video-neva": all decoded frames stacked along axis 0.
                - "lita"/"vita": at most 256 uniformly subsampled frames.
                - "neva"/"vila"/"mllama": the RGB image with a leading axis of size 1.

        Raises:
            UnavailableError: If decord (video) or PIL (image) is not installed.
            requests.HTTPError: If downloading a URL image returns an error status.
            RuntimeError: If ``model_type`` is not supported.
        """
        if self.model_type in ("video-neva", "lita", "vita"):
            if not HAVE_DECORD:
                raise UnavailableError(MISSING_DECORD_MSG)

            frames = [frame.asnumpy() for frame in VideoReader(input_media)]
            if self.model_type == "video-neva":
                return np.array(frames)
            return np.array(self.get_subsampled_frames(frames, self.frame_len(frames)))

        if self.model_type in ("neva", "vila", "mllama"):
            if not HAVE_PIL:
                raise UnavailableError(MISSING_PIL_MSG)

            if self._is_url(input_media):
                response = requests.get(input_media, timeout=5)
                # Fail with a clear HTTP error instead of handing an error page to PIL.
                response.raise_for_status()
                source = BytesIO(response.content)
            else:
                source = input_media
            # Context manager guarantees the underlying file is closed.
            with Image.open(source) as image:
                media = image.convert("RGB")
            return np.expand_dims(np.array(media), axis=0)

        raise RuntimeError(f"Invalid model type {self.model_type}")

    def frame_len(self, frames, max_frames=256):
        """Return how many frames to keep so that at most ``max_frames`` remain.

        Args:
            frames (Sequence): Decoded frames. Only ``len(frames)`` is used.
            max_frames (int): Upper bound on the returned count. Defaults to 256.

        Returns:
            int: ``len(frames)`` if it is within the bound. Otherwise the count
            left after keeping every ``ceil(len / max_frames)``-th frame (rounded).
        """
        num_frames = len(frames)
        if num_frames <= max_frames:
            return num_frames
        subsample = int(np.ceil(float(num_frames) / max_frames))
        return int(np.round(float(num_frames) / subsample))

    def get_subsampled_frames(self, frames, subsample_len):
        """Pick ``subsample_len`` frames at evenly spaced indices, including both ends.

        Args:
            frames (Sequence): Frames to sample from.
            subsample_len (int): Number of frames to return.

        Returns:
            list: The selected frames, in their original order.
        """
        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
        return [frames[i] for i in idx]

    def query(
        self,
        input_text,
        input_media,
        batch_size=1,
        max_output_len=30,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        num_beams=1,
        init_timeout=60.0,
        lora_uids=None,
    ):
        """Send one text + media query to the Triton server and return its output.

        Args:
            input_text (str): Prompt text.
            input_media (str): Media path or URL, loaded by ``setup_media``.
            batch_size, max_output_len, top_k, num_beams (int, optional): Integer
                generation parameters. Each is sent only when not ``None``.
            top_p, temperature, repetition_penalty (float, optional): Float
                generation parameters. Each is sent only when not ``None``.
            init_timeout (float): Seconds to wait for the Triton client to initialize.
            lora_uids (list[str], optional): LoRA adapter ids. They are UTF-8 encoded
                and broadcast to every prompt.

        Returns:
            np.ndarray: Decoded strings when the model's first output has dtype
            ``bytes``, otherwise the raw ``"outputs"`` array.

        Raises:
            UnavailableError: If PyTriton is not installed.
        """
        if not HAVE_TRITON:
            raise UnavailableError(MISSING_TRITON_MSG)

        prompts = str_list2numpy([input_text])
        inputs = {"input_text": prompts}

        media = self.setup_media(input_media)
        if isinstance(media, dict):
            inputs.update(media)
        else:
            # Add a leading axis and repeat the media once per prompt.
            inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0)

        # (input name, value, dtype). Parameters left as None are not sent.
        generation_params = (
            ("batch_size", batch_size, np.int_),
            ("max_output_len", max_output_len, np.int_),
            ("top_k", top_k, np.int_),
            ("top_p", top_p, np.single),
            ("temperature", temperature, np.single),
            ("repetition_penalty", repetition_penalty, np.single),
            ("num_beams", num_beams, np.int_),
        )
        for name, value, dtype in generation_params:
            if value is not None:
                inputs[name] = np.full(prompts.shape, value, dtype=dtype)

        if lora_uids is not None:
            encoded_uids = np.char.encode(lora_uids, "utf-8")
            inputs["lora_uids"] = np.full((prompts.shape[0], len(encoded_uids)), encoded_uids)

        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
            result_dict = client.infer_batch(**inputs)
            output_type = client.model_config.outputs[0].dtype

        if output_type == np.bytes_:
            return np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
        return result_dict["outputs"]
@@ -0,0 +1,23 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from nemo_deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMHF, NemoQueryLLMPyTorch, NemoQueryTRTLLMAPI
17
+
18
+ __all__ = [
19
+ "NemoQueryLLM",
20
+ "NemoQueryLLMHF",
21
+ "NemoQueryLLMPyTorch",
22
+ "NemoQueryTRTLLMAPI",
23
+ ]