gst-python-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- engine/__init__.py +0 -0
- engine/gst_device_queue_pool.py +121 -0
- engine/gst_engine_factory.py +130 -0
- engine/gst_ml_engine.py +75 -0
- engine/gst_onnx_engine.py +99 -0
- engine/gst_openvino_engine.py +150 -0
- engine/gst_pytorch_engine.py +376 -0
- engine/gst_pytorch_yolo_engine.py +74 -0
- engine/gst_tensorflow_engine.py +72 -0
- engine/gst_tflite_engine.py +105 -0
- gst_python_ml-0.1.0.dist-info/COPYING +16 -0
- gst_python_ml-0.1.0.dist-info/METADATA +363 -0
- gst_python_ml-0.1.0.dist-info/RECORD +15 -0
- gst_python_ml-0.1.0.dist-info/WHEEL +5 -0
- gst_python_ml-0.1.0.dist-info/top_level.txt +1 -0
engine/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# GstDeviceQueuePool
|
|
2
|
+
# Copyright (C) 2024-2025 Collabora Ltd.
|
|
3
|
+
#
|
|
4
|
+
# This library is free software; you can redistribute it and/or
|
|
5
|
+
# modify it under the terms of the GNU Library General Public
|
|
6
|
+
# License as published by the Free Software Foundation; either
|
|
7
|
+
# version 2 of the License, or (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# This library is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
12
|
+
# Library General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU Library General Public
|
|
15
|
+
# License along with this library; if not, write to the
|
|
16
|
+
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
|
|
17
|
+
# Boston, MA 02110-1301, USA.
|
|
18
|
+
|
|
19
|
+
import gi
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
|
|
22
|
+
gi.require_version("Gst", "1.0")
|
|
23
|
+
from gi.repository import Gst # noqa: E402
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GstDeviceQueue(ABC):
    """Abstract wrapper around a device-side queue handle."""

    def __init__(self, queue_handle):
        """
        Store the handle identifying the underlying device queue.

        :param queue_handle: An integer representing the queue handle.
        """
        self.queue_handle = queue_handle

    @abstractmethod
    def synchronize(self):
        """
        Block until work queued on this device queue has completed.

        Concrete subclasses must provide the device-specific implementation.
        """

    def __repr__(self):
        return "DeviceQueue(handle={})".format(self.queue_handle)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class GstDeviceQueuePool:
    """Registry mapping queue IDs to DeviceQueue instances."""

    def __init__(self):
        """Create an empty pool with no registered queues."""
        self.queues = {}

    def add_queue(self, queue_id, device_queue):
        """
        Register *device_queue* under *queue_id*.

        Duplicate IDs are rejected with a warning; the existing entry wins.

        :param queue_id: Unique ID for the DeviceQueue.
        :param device_queue: A DeviceQueue object to add to the pool.
        """
        if queue_id not in self.queues:
            self.queues[queue_id] = device_queue
            return
        Gst.warning(
            f"DeviceQueue with ID {queue_id} already exists. Not adding again."
        )

    def get_queue(self, queue_id):
        """
        Look up the DeviceQueue registered under *queue_id*.

        :param queue_id: Unique ID of the queue in the pool.
        :return: The DeviceQueue if registered, otherwise None.
        """
        found = self.queues.get(queue_id)
        if found is None:
            Gst.warning(f"No DeviceQueue found for ID {queue_id}.")
        return found

    def __repr__(self):
        return f"GstDeviceQueuePool(queues={self.queues})"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class DeviceQueueManager:
    """Process-wide singleton mapping device names to DeviceQueuePools."""

    _instance = None
    _device_queue_pools = {}

    def __new__(cls):
        # Lazily create the one shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._device_queue_pools = {}
        return cls._instance

    def add_pool(self, device, queue_pool):
        """
        Associate *queue_pool* with *device*, unless one is already registered.

        :param device: Unique identifier for the device (e.g., "cuda:0").
        :param queue_pool: A DeviceQueuePool object to associate with the device.
        """
        if device not in self._device_queue_pools:
            self._device_queue_pools[device] = queue_pool
            Gst.info(f"Added DeviceQueuePool for device {device}.")

    def get_pool(self, device):
        """
        Fetch the DeviceQueuePool registered for *device*.

        :param device: Unique identifier for the device.
        :return: The DeviceQueuePool if registered, otherwise None.
        """
        result = self._device_queue_pools.get(device)
        if result is None:
            Gst.warning(f"No DeviceQueuePool found for device {device}.")
        return result

    def __repr__(self):
        return f"DeviceQueueManager(device_queue_pools={self._device_queue_pools})"
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# GstEngineFactory
|
|
2
|
+
# Copyright (C) 2024-2025 Collabora Ltd.
|
|
3
|
+
#
|
|
4
|
+
# This library is free software; you can redistribute it and/or
|
|
5
|
+
# modify it under the terms of the GNU Library General Public
|
|
6
|
+
# License as published by the Free Software Foundation; either
|
|
7
|
+
# version 2 of the License, or (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# This library is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
12
|
+
# Library General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU Library General Public
|
|
15
|
+
# License along with this library; if not, write to the
|
|
16
|
+
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
|
|
17
|
+
# Boston, MA 02110-1301, USA.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from .gst_pytorch_engine import GstPyTorchEngine
|
|
22
|
+
|
|
23
|
+
_pytorch_engine_available = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
_pytorch_engine_available = False
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
from .gst_pytorch_yolo_engine import GstPyTorchYoloEngine
|
|
29
|
+
|
|
30
|
+
_pytorch_yolo_engine_available = True
|
|
31
|
+
except ImportError:
|
|
32
|
+
_pytorch_yolo_engine_available = False
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from .gst_tflite_engine import GstTFLiteEngine
|
|
36
|
+
|
|
37
|
+
_tflite_engine_available = True
|
|
38
|
+
except ImportError:
|
|
39
|
+
_tflite_engine_available = False
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
from .gst_tensorflow_engine import GstTensorFlowEngine
|
|
43
|
+
|
|
44
|
+
_tensorflow_engine_available = True
|
|
45
|
+
except ImportError:
|
|
46
|
+
_tensorflow_engine_available = False
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
from .gst_onnx_engine import GstONNXEngine
|
|
50
|
+
|
|
51
|
+
_onnx_engine_available = True
|
|
52
|
+
except ImportError:
|
|
53
|
+
_onnx_engine_available = False
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
from .gst_openvino_engine import GstOpenVinoEngine
|
|
57
|
+
|
|
58
|
+
_openvino_engine_available = True
|
|
59
|
+
except ImportError:
|
|
60
|
+
_openvino_engine_available = False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class GstEngineFactory:
    """Factory for constructing ML inference engines by type name.

    Each backend is imported at module load behind a try/except; a backend
    whose import failed is reported via ImportError when requested.
    """

    # Engine type identifiers accepted by create_engine().
    PYTORCH_ENGINE = "pytorch"
    PYTORCH_YOLO_ENGINE = "pytorch-yolo"
    TFLITE_ENGINE = "tflite"
    TENSORFLOW_ENGINE = "tensorflow"
    ONNX_ENGINE = "onnx"
    OPENVINO_ENGINE = "openvino"

    @staticmethod
    def create_engine(engine_type, device="cpu"):
        """
        Factory method to create the appropriate engine based on the engine type.

        :param engine_type: The type of the ML engine, e.g., "pytorch" or "tflite".
        :param device: The device to run the engine on (default is "cpu").
        :return: An instance of the appropriate ML engine class.
        :raises ImportError: if the requested backend failed to import.
        :raises ValueError: if engine_type is not a known engine identifier.
        """
        # One uniform if/elif chain. The original mixed two stand-alone
        # `if` statements into the chain, which only behaved correctly
        # because every branch returns or raises; this form makes the
        # mutual exclusion of the branches explicit.
        if engine_type == GstEngineFactory.PYTORCH_ENGINE:
            if _pytorch_engine_available:
                return GstPyTorchEngine(device)
            raise ImportError(
                f"{GstEngineFactory.PYTORCH_ENGINE} engine is not available."
            )

        elif engine_type == GstEngineFactory.PYTORCH_YOLO_ENGINE:
            if _pytorch_yolo_engine_available:
                return GstPyTorchYoloEngine(device)
            raise ImportError(
                f"{GstEngineFactory.PYTORCH_YOLO_ENGINE} engine is not available."
            )

        elif engine_type == GstEngineFactory.TFLITE_ENGINE:
            if _tflite_engine_available:
                return GstTFLiteEngine(device)
            raise ImportError(
                f"{GstEngineFactory.TFLITE_ENGINE} engine is not available."
            )

        elif engine_type == GstEngineFactory.TENSORFLOW_ENGINE:
            if _tensorflow_engine_available:
                return GstTensorFlowEngine(device)
            raise ImportError(
                f"{GstEngineFactory.TENSORFLOW_ENGINE} engine is not available."
            )

        elif engine_type == GstEngineFactory.ONNX_ENGINE:
            if _onnx_engine_available:
                return GstONNXEngine(device)
            raise ImportError(
                f"{GstEngineFactory.ONNX_ENGINE} engine is not available."
            )

        elif engine_type == GstEngineFactory.OPENVINO_ENGINE:
            if _openvino_engine_available:
                return GstOpenVinoEngine(device)
            raise ImportError(
                f"{GstEngineFactory.OPENVINO_ENGINE} engine is not available."
            )

        else:
            raise ValueError(f"Unsupported engine type: {engine_type}")
|
engine/gst_ml_engine.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# GstMLEngine
|
|
2
|
+
# Copyright (C) 2024-2025 Collabora Ltd.
|
|
3
|
+
#
|
|
4
|
+
# This library is free software; you can redistribute it and/or
|
|
5
|
+
# modify it under the terms of the GNU Library General Public
|
|
6
|
+
# License as published by the Free Software Foundation; either
|
|
7
|
+
# version 2 of the License, or (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# This library is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
12
|
+
# Library General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU Library General Public
|
|
15
|
+
# License along with this library; if not, write to the
|
|
16
|
+
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
|
|
17
|
+
# Boston, MA 02110-1301, USA.
|
|
18
|
+
|
|
19
|
+
from abc import ABC, abstractmethod
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GstMLEngine(ABC):
    """Common base for GStreamer ML inference engines.

    Holds state shared by every backend: device selection, the loaded
    model/tokenizer/processor, batching buffers, and the prompt used by
    vision-language models.
    """

    def __init__(self, device="cpu"):
        # Device and model state.
        self.device = device
        self.device_index = 0
        self.model = None
        self.vision_language_model = False
        self.tokenizer = None
        self.image_processor = None
        # Batching / tracking state.
        self.batch_size = 1
        self.frame_buffer = []
        self.counter = 0
        self.device_queue_id = None
        self.track = False
        self.prompt = "What is shown in this image?"  # Default prompt

    @abstractmethod
    def load_model(self, model_name, **kwargs):
        """Load a model by name or path, with additional options."""

    def set_prompt(self, prompt):
        """Set the custom prompt for generating responses."""
        self.prompt = prompt

    def get_prompt(self):
        """Return the custom prompt."""
        return self.prompt

    def get_device(self):
        """Return the device string currently selected for inference."""
        return self.device

    @abstractmethod
    def set_device(self, device):
        """Set the device (e.g., cpu, cuda)."""

    def get_model(self):
        """Return the loaded model for use in inference."""
        return self.model

    def set_model(self, model):
        """Set the model directly (useful for loading pre-built models)."""
        self.model = model

    @abstractmethod
    def forward(self, frame):
        """Execute inference (usually object detection)."""

    @abstractmethod
    def generate(self, input_text, max_length=100):
        """Generate LLM text."""
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# GstONNXEngine
|
|
2
|
+
# Copyright (C) 2024-2025 Collabora Ltd.
|
|
3
|
+
#
|
|
4
|
+
# This library is free software; you can redistribute it and/or
|
|
5
|
+
# modify it under the terms of the GNU Library General Public
|
|
6
|
+
# License as published by the Free Software Foundation; either
|
|
7
|
+
# version 2 of the License, or (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# This library is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
12
|
+
# Library General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU Library General Public
|
|
15
|
+
# License along with this library; if not, write to the
|
|
16
|
+
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
|
|
17
|
+
# Boston, MA 02110-1301, USA.
|
|
18
|
+
|
|
19
|
+
import gi
|
|
20
|
+
import numpy as np
|
|
21
|
+
import onnxruntime as ort # ONNX Runtime for executing ONNX models
|
|
22
|
+
from .gst_ml_engine import GstMLEngine
|
|
23
|
+
|
|
24
|
+
gi.require_version("Gst", "1.0")
|
|
25
|
+
gi.require_version("GstBase", "1.0")
|
|
26
|
+
gi.require_version("GLib", "2.0")
|
|
27
|
+
from gi.repository import Gst # noqa: E402
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GstONNXEngine(GstMLEngine):
    """ONNX Runtime backend for GstMLEngine (vision inference only)."""

    def __init__(self, device="cpu"):
        """
        Initialize the ONNX engine with the specified device.

        :param device: "cpu" selects CPUExecutionProvider; any other value
            selects CUDAExecutionProvider.
        """
        super().__init__(device)
        self.session = None      # ort.InferenceSession, set by load_model()
        self.model_path = None   # path the current session was created from
        self.input_names = None
        self.output_names = None

    def load_model(self, model_name, **kwargs):
        """
        Load the ONNX model from the specified file path.

        Failures are logged rather than raised, keeping the original
        best-effort behavior.

        :param model_name: Path to the .onnx model file.
        """
        try:
            # Create the ONNX runtime session with device (CPU, CUDA, etc.)
            providers = (
                ["CPUExecutionProvider"]
                if self.device == "cpu"
                else ["CUDAExecutionProvider"]
            )
            self.session = ort.InferenceSession(model_name, providers=providers)
            # Remember the path so set_device() can rebuild the session.
            self.model_path = model_name

            # Extract input and output names for reference
            self.input_names = [inp.name for inp in self.session.get_inputs()]
            self.output_names = [out.name for out in self.session.get_outputs()]

            Gst.info(f"ONNX model '{model_name}' loaded successfully on {self.device}.")
        except Exception as e:
            Gst.error(f"Failed to load ONNX model '{model_name}'. Error: {e}")

    def set_device(self, device):
        """
        Set the device for inference.

        ONNX Runtime binds execution providers at session creation, so an
        existing session is rebuilt from the remembered model path.

        :param device: "cpu" or a CUDA device identifier.
        """
        if device == self.device:
            return  # nothing to do; avoids a pointless session rebuild
        self.device = device
        # Fix: the original "reloaded" from session.get_modelmeta().producer_name,
        # which is the name of the tool that produced the model, not its path.
        if self.session is not None and self.model_path is not None:
            self.load_model(self.model_path)

    def forward(self, frame):
        """
        Perform inference on the given frame using the ONNX model.

        :param frame: NumPy array holding one frame; a batch dimension is
            prepended and the data is cast to float32 before running.
        :return: A single output array, a list of arrays for multi-output
            models, or None on failure.
        """
        try:
            # Preprocess the frame as required by the model (add batch dim).
            input_tensor = np.expand_dims(frame.astype(np.float32), axis=0)

            # Feed the (single) model input by name.
            input_dict = {self.input_names[0]: input_tensor}

            # Run inference.
            output_data = self.session.run(self.output_names, input_dict)

            # Convert outputs to NumPy arrays if needed.
            results = [np.array(out) for out in output_data]

            # Return single result or list of results.
            return results if len(results) > 1 else results[0]

        except Exception as e:
            Gst.error(f"Error during ONNX inference: {e}")
            return None

    def generate(self, input_text, max_length=100):
        """
        Text generation is not supported by the ONNX engine.

        Implemented here (rather than left abstract) so the class is
        concrete and instantiable by GstEngineFactory; the base class
        declares generate() as abstract, which previously made
        GstONNXEngine itself abstract.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("GstONNXEngine does not support text generation.")
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
# GstOpenVinoEngine
|
|
2
|
+
# Copyright (C) 2024-2025 Collabora Ltd.
|
|
3
|
+
#
|
|
4
|
+
# This library is free software; you can redistribute it and/or
|
|
5
|
+
# modify it under the terms of the GNU Library General Public
|
|
6
|
+
# License as published by the Free Software Foundation; either
|
|
7
|
+
# version 2 of the License, or (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# This library is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
12
|
+
# Library General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU Library General Public
|
|
15
|
+
# License along with this library; if not, write to the
|
|
16
|
+
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
|
|
17
|
+
# Boston, MA 02110-1301, USA.
|
|
18
|
+
|
|
19
|
+
import gi
|
|
20
|
+
import numpy as np
|
|
21
|
+
from openvino.runtime import Core
|
|
22
|
+
from .gst_ml_engine import GstMLEngine
|
|
23
|
+
|
|
24
|
+
gi.require_version("Gst", "1.0")
|
|
25
|
+
from gi.repository import Gst # noqa: E402
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GstOpenVinoEngine(GstMLEngine):
    """OpenVINO backend for GstMLEngine, handling vision and LLM models."""

    def __init__(self, device="CPU"):
        """
        Initialize the OpenVINO engine.

        :param device: OpenVINO device name (e.g. "CPU", "GPU").
        """
        # Fix: initialize the GstMLEngine base state (model, tokenizer,
        # prompt, batching buffers, ...) which the original skipped,
        # leaving attributes like self.model undefined before load_model().
        super().__init__(device)
        self.core = Core()
        self.compiled_model = None
        self.is_vision_model = False
        self.is_llm = False

    def load_model(self, model_name, **kwargs):
        """
        Load the model using OpenVINO and differentiate between vision and LLM models.

        :param model_name: Model base name; "<model_name>.xml" is read.
        :raises ValueError: if reading or compiling the model fails.
        """
        try:
            # Load and compile the IR model for the selected device.
            model_path = f"{model_name}.xml"
            self.model = self.core.read_model(model=model_path)
            self.compiled_model = self.core.compile_model(self.model, self.device)
            Gst.info(f"Model '{model_name}' loaded successfully on {self.device}")

            # Classify the model by its input rank: 4-D with 3 channels
            # (batch, 3, height, width) means vision; 2-D
            # (batch, sequence_length) means a language model.
            input_shape = self.compiled_model.input(0).shape

            if len(input_shape) == 4 and input_shape[1] == 3:
                self.is_vision_model = True
                Gst.info("Model identified as a vision model.")
            elif len(input_shape) == 2:
                self.is_llm = True
                Gst.info("Model identified as a large language model (LLM).")

            if self.is_llm:
                # Tokenizer is only needed for text models; import lazily.
                from transformers import AutoTokenizer

                self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        except Exception as e:
            raise ValueError(f"Failed to load model '{model_name}'. Error: {e}") from e

    def set_device(self, device):
        """
        Select the OpenVINO device and recompile any loaded model for it.

        Implemented here (rather than left abstract) so the class is
        concrete and instantiable by GstEngineFactory; the base class
        declares set_device() as abstract, which previously made
        GstOpenVinoEngine itself abstract.

        :param device: OpenVINO device name (e.g. "CPU", "GPU").
        """
        self.device = device
        if self.model is not None:
            self.compiled_model = self.core.compile_model(self.model, self.device)

    def generate(self, input_data, max_length=100):
        """
        Preprocess input for the respective model type.

        :param input_data: Image array (vision) or text (LLM).
        :param max_length: Maximum token length for LLM tokenization.
        :return: Preprocessed input ready for forward().
        :raises ValueError: if no model has been loaded/classified yet.
        """
        if self.is_vision_model:
            return self.preprocess_vision_input(input_data)
        elif self.is_llm:
            return self.preprocess_llm_input(input_data, max_length)
        else:
            raise ValueError("Unknown model type. Please load a model first.")

    def preprocess_vision_input(self, image):
        """
        Preprocess input image for vision models.

        :param image: Input image array.
        :return: Normalized float32 array with a leading batch dimension.
        """
        input_shape = self.compiled_model.input(0).shape
        # NOTE(review): np.resize to input_shape[2:] yields an (H, W) array
        # and drops the channel axis a (N, 3, H, W) model expects — confirm
        # against real model inputs before relying on this path.
        resized_image = np.resize(image, input_shape[2:])
        resized_image = resized_image.astype(np.float32) / 255.0  # Normalize
        resized_image = np.expand_dims(resized_image, axis=0)  # Add batch dimension
        return resized_image

    def preprocess_llm_input(self, input_text, max_length):
        """
        Tokenize input text for LLM models.

        :param input_text: Text to tokenize.
        :param max_length: Pad/truncate target length.
        :return: (input_ids, attention_mask) NumPy arrays.
        :raises RuntimeError: if no tokenizer has been initialized.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized. Load an LLM model first.")

        tokens = self.tokenizer(
            input_text,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="np",  # Return as NumPy array
        )
        return tokens["input_ids"], tokens["attention_mask"]

    def forward(self, input_data):
        """
        Perform inference using the loaded OpenVINO model.

        :param input_data: Preprocessed image array (vision) or an
            (input_ids, attention_mask) pair (LLM).
        :return: Output tensor data, or None if the model type is unknown.
        :raises RuntimeError: if no model has been loaded.
        """
        if self.compiled_model is None:
            raise RuntimeError("Model not loaded. Please load the model first.")

        if self.is_vision_model:
            return self.perform_vision_inference(input_data)
        elif self.is_llm:
            return self.perform_llm_inference(input_data)

    def perform_vision_inference(self, input_data):
        """
        Perform inference on vision models.

        :param input_data: Batched image tensor.
        :return: Output tensor data from the first model output.
        """
        infer_request = self.compiled_model.create_infer_request()
        infer_request.infer({self.compiled_model.input(0): input_data})
        return infer_request.get_output_tensor(self.compiled_model.output(0)).data

    def perform_llm_inference(self, input_data):
        """
        Perform inference on LLM models.

        :param input_data: Tuple of (input_ids, attention_mask).
        :return: Output tensor data from the first model output.
        """
        input_ids, attention_mask = input_data
        infer_request = self.compiled_model.create_infer_request()
        infer_request.infer(
            {
                self.compiled_model.input(0): input_ids,
                self.compiled_model.input(1): attention_mask,
            }
        )
        return infer_request.get_output_tensor(self.compiled_model.output(0)).data
|