model-runner-client 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_runner_client-0.1.0/PKG-INFO +107 -0
- model_runner_client-0.1.0/README.md +89 -0
- model_runner_client-0.1.0/model_runner_client/__init__.py +0 -0
- model_runner_client-0.1.0/model_runner_client/model_cluster.py +127 -0
- model_runner_client-0.1.0/model_runner_client/model_concurrent_runner.py +65 -0
- model_runner_client-0.1.0/model_runner_client/model_runner.py +76 -0
- model_runner_client-0.1.0/model_runner_client/protos/__init__.py +0 -0
- model_runner_client-0.1.0/model_runner_client/protos/model_runner.proto +29 -0
- model_runner_client-0.1.0/model_runner_client/protos/model_runner_pb2.py +43 -0
- model_runner_client-0.1.0/model_runner_client/protos/model_runner_pb2.pyi +38 -0
- model_runner_client-0.1.0/model_runner_client/protos/model_runner_pb2_grpc.py +227 -0
- model_runner_client-0.1.0/model_runner_client/utils/__init__.py +0 -0
- model_runner_client-0.1.0/model_runner_client/utils/datatype_transformer.py +71 -0
- model_runner_client-0.1.0/model_runner_client/websocket_client.py +134 -0
- model_runner_client-0.1.0/pyproject.toml +27 -0
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: model-runner-client
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Model Runner Client is a Python library that allows Coordinators to manage real-time model synchronization and perform concurrent predictions on distributed model nodes within a Crunch.
|
|
5
|
+
Author: boutrig abdennour
|
|
6
|
+
Author-email: abdennour.boutrig@crunchdao.com
|
|
7
|
+
Requires-Python: >=3.11,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Requires-Dist: grpcio (>=1.70.0,<2.0.0)
|
|
13
|
+
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
14
|
+
Requires-Dist: protobuf (>=5.29.3,<6.0.0)
|
|
15
|
+
Requires-Dist: pyarrow (>=19.0.0,<20.0.0)
|
|
16
|
+
Requires-Dist: websockets (>=14.2,<15.0)
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# Model Runner Client
|
|
20
|
+
|
|
21
|
+
**Model Runner Client** is a Python library that allows you, as a Coordinator, to interact with models participating in your crunch. It tracks which models join or leave through a WebSocket connection to the model nodes.
|
|
22
|
+
|
|
23
|
+
- **Real-Time Model Sync**: Each model participating in your crunch is an instance of `ModelRunner`, maintained via WebSocket in the `ModelCluster`.
|
|
24
|
+
- **Concurrent Predictions (with Timeout Handling)**: Use the `ModelConcurrentRunner` to request predictions from all models simultaneously. Define a timeout to avoid blocking if a model takes too long to predict.
|
|
25
|
+
|
|
26
|
+
## Installation
|
|
27
|
+
|
|
28
|
+
```bash
|
|
29
|
+
pip install model-runner-client
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
> **Note**: Adjust this command (e.g., `pip3` or virtual environments) depending on your setup.
|
|
33
|
+
|
|
34
|
+
## Usage
|
|
35
|
+
|
|
36
|
+
Below is a quick example focusing on the ModelConcurrentRunner. It handles concurrent predictions for you and returns all results in one go.
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
import asyncio
|
|
40
|
+
from model_runner_client.model_concurrent_runner import ModelConcurrentRunner
|
|
41
|
+
from model_runner_client.protos.model_runner_pb2 import DataType
|
|
42
|
+
from model_runner_client.utils.datatype_transformer import encode_data
|
|
43
|
+
|
|
44
|
+
async def main():
|
|
45
|
+
# crunch_id, host, and port are values provided by crunchdao
|
|
46
|
+
concurrent_runner = ModelConcurrentRunner(
|
|
47
|
+
timeout=10,
|
|
48
|
+
crunch_id="bird-game",
|
|
49
|
+
host="localhost",
|
|
50
|
+
port=8000
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
# Initialize communication with the model nodes to fetch
|
|
54
|
+
# models that want to predict and set up the model cluster
|
|
55
|
+
await concurrent_runner.init()
|
|
56
|
+
|
|
57
|
+
async def prediction_call():
|
|
58
|
+
while True:
|
|
59
|
+
# Your data to be predicted (X)
|
|
60
|
+
value = {
|
|
61
|
+
'falcon_location': 21.179864629354732,
|
|
62
|
+
'time': 230.96231205799998,
|
|
63
|
+
'dove_location': 19.164986723324326,
|
|
64
|
+
'falcon_id': 1
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
# Encode data as binary and request a prediction
|
|
68
|
+
result = await concurrent_runner.predict(
|
|
69
|
+
DataType.JSON,
|
|
70
|
+
encode_data(DataType.JSON, value)
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# You receive a dictionary of predictions
|
|
74
|
+
for model_runner, model_predict_result in result.items():
|
|
75
|
+
print(f"{model_runner.model_id}: {model_predict_result}")
|
|
76
|
+
|
|
77
|
+
# This pause (30s) simulates other work
|
|
78
|
+
# the Coordinator might perform between predictions
|
|
79
|
+
await asyncio.sleep(30)
|
|
80
|
+
|
|
81
|
+
# Keep the cluster updated with `concurrent_runner.sync()`,
|
|
82
|
+
# which maintains a permanent WebSocket connection.
|
|
83
|
+
# Then run our prediction process.
|
|
84
|
+
await asyncio.gather(
|
|
85
|
+
asyncio.create_task(concurrent_runner.sync()),
|
|
86
|
+
asyncio.create_task(prediction_call())
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
if __name__ == "__main__":
|
|
90
|
+
try:
|
|
91
|
+
asyncio.run(main())
|
|
92
|
+
except KeyboardInterrupt:
|
|
93
|
+
print("\nReceived exit signal, shutting down gracefully.")
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
### Important Notes
|
|
97
|
+
|
|
98
|
+
- **Prediction Failures & Timeouts**: A prediction may fail or exceed the defined timeout, so be sure to handle these cases appropriately. Refer to `ModelPredictResult.Status` for details.
|
|
99
|
+
- **Custom Implementations**: If you need more control over your workflow, you can manage each model individually. Instead of using `ModelConcurrentRunner`, you can directly leverage `ModelRunner` instances from the `ModelCluster`, customizing how you schedule predictions and handle results.
|
|
100
|
+
|
|
101
|
+
## Contributing
|
|
102
|
+
|
|
103
|
+
Contributions are welcome! Feel free to open issues or submit pull requests if you encounter any bugs or want to suggest improvements.
|
|
104
|
+
|
|
105
|
+
## License
|
|
106
|
+
|
|
107
|
+
This project is distributed under the [MIT License](https://choosealicense.com/licenses/mit/). See the LICENSE file for details.
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# Model Runner Client
|
|
2
|
+
|
|
3
|
+
**Model Runner Client** is a Python library that allows you, as a Coordinator, to interact with models participating in your crunch. It tracks which models join or leave through a WebSocket connection to the model nodes.
|
|
4
|
+
|
|
5
|
+
- **Real-Time Model Sync**: Each model participating in your crunch is an instance of `ModelRunner`, maintained via WebSocket in the `ModelCluster`.
|
|
6
|
+
- **Concurrent Predictions (with Timeout Handling)**: Use the `ModelConcurrentRunner` to request predictions from all models simultaneously. Define a timeout to avoid blocking if a model takes too long to predict.
|
|
7
|
+
|
|
8
|
+
## Installation
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install model-runner-client
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
> **Note**: Adjust this command (e.g., `pip3` or virtual environments) depending on your setup.
|
|
15
|
+
|
|
16
|
+
## Usage
|
|
17
|
+
|
|
18
|
+
Below is a quick example focusing on the ModelConcurrentRunner. It handles concurrent predictions for you and returns all results in one go.
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
import asyncio
|
|
22
|
+
from model_runner_client.model_concurrent_runner import ModelConcurrentRunner
|
|
23
|
+
from model_runner_client.protos.model_runner_pb2 import DataType
|
|
24
|
+
from model_runner_client.utils.datatype_transformer import encode_data
|
|
25
|
+
|
|
26
|
+
async def main():
|
|
27
|
+
# crunch_id, host, and port are values provided by crunchdao
|
|
28
|
+
concurrent_runner = ModelConcurrentRunner(
|
|
29
|
+
timeout=10,
|
|
30
|
+
crunch_id="bird-game",
|
|
31
|
+
host="localhost",
|
|
32
|
+
port=8000
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Initialize communication with the model nodes to fetch
|
|
36
|
+
# models that want to predict and set up the model cluster
|
|
37
|
+
await concurrent_runner.init()
|
|
38
|
+
|
|
39
|
+
async def prediction_call():
|
|
40
|
+
while True:
|
|
41
|
+
# Your data to be predicted (X)
|
|
42
|
+
value = {
|
|
43
|
+
'falcon_location': 21.179864629354732,
|
|
44
|
+
'time': 230.96231205799998,
|
|
45
|
+
'dove_location': 19.164986723324326,
|
|
46
|
+
'falcon_id': 1
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
# Encode data as binary and request a prediction
|
|
50
|
+
result = await concurrent_runner.predict(
|
|
51
|
+
DataType.JSON,
|
|
52
|
+
encode_data(DataType.JSON, value)
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# You receive a dictionary of predictions
|
|
56
|
+
for model_runner, model_predict_result in result.items():
|
|
57
|
+
print(f"{model_runner.model_id}: {model_predict_result}")
|
|
58
|
+
|
|
59
|
+
# This pause (30s) simulates other work
|
|
60
|
+
# the Coordinator might perform between predictions
|
|
61
|
+
await asyncio.sleep(30)
|
|
62
|
+
|
|
63
|
+
# Keep the cluster updated with `concurrent_runner.sync()`,
|
|
64
|
+
# which maintains a permanent WebSocket connection.
|
|
65
|
+
# Then run our prediction process.
|
|
66
|
+
await asyncio.gather(
|
|
67
|
+
asyncio.create_task(concurrent_runner.sync()),
|
|
68
|
+
asyncio.create_task(prediction_call())
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
if __name__ == "__main__":
|
|
72
|
+
try:
|
|
73
|
+
asyncio.run(main())
|
|
74
|
+
except KeyboardInterrupt:
|
|
75
|
+
print("\nReceived exit signal, shutting down gracefully.")
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
### Important Notes
|
|
79
|
+
|
|
80
|
+
- **Prediction Failures & Timeouts**: A prediction may fail or exceed the defined timeout, so be sure to handle these cases appropriately. Refer to `ModelPredictResult.Status` for details.
|
|
81
|
+
- **Custom Implementations**: If you need more control over your workflow, you can manage each model individually. Instead of using `ModelConcurrentRunner`, you can directly leverage `ModelRunner` instances from the `ModelCluster`, customizing how you schedule predictions and handle results.
|
|
82
|
+
|
|
83
|
+
## Contributing
|
|
84
|
+
|
|
85
|
+
Contributions are welcome! Feel free to open issues or submit pull requests if you encounter any bugs or want to suggest improvements.
|
|
86
|
+
|
|
87
|
+
## License
|
|
88
|
+
|
|
89
|
+
This project is distributed under the [MIT License](https://choosealicense.com/licenses/mit/). See the LICENSE file for details.
|
|
File without changes
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from model_runner_client.model_runner import ModelRunner
|
|
4
|
+
from model_runner_client.websocket_client import WebsocketClient
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger("model_runner_client")
|
|
7
|
+
import asyncio
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelCluster:
    """
    Maintains the set of running models (``ModelRunner`` instances) for a
    single crunch, kept in sync through WebSocket ``init``/``update`` events.
    """

    def __init__(self, crunch_id: str, ws_host: str, ws_port: int):
        """
        ModelCluster constructor.

        :param crunch_id: The Crunch ID that this cluster is responsible for.
        :param ws_host: The WebSocket server's host.
        :param ws_port: The WebSocket server's port.
        """
        self.crunch_id = crunch_id
        # model_id -> ModelRunner for every model currently part of the cluster.
        self.models_run = {}
        logger.debug(f"Initializing ModelCluster with Crunch ID: {crunch_id}")
        self.ws_client = WebsocketClient(ws_host, ws_port, crunch_id, event_handler=self.handle_event)

    async def init(self):
        """Connect the WebSocket client and run its initialization handshake."""
        await self.ws_client.connect()
        await self.ws_client.init()

        logger.debug("WebSocket client initialized.")

    async def sync(self):
        """Listen for model state events; blocks for the lifetime of the connection."""
        await self.ws_client.listen()

    async def handle_event(self, event_type: str, data: list[dict]):
        """
        Handle WebSocket events (`init` and `update`) and update the cluster's state.

        :param event_type: The type of the event (`init` or `update`).
        :param data: The event data.
        """
        try:
            if event_type == "init":
                logger.debug(f"**ModelCluster** Processing event type: {event_type}")
                await self.handle_init_event(data)
            elif event_type == "update":
                logger.debug(f"**ModelCluster** Processing event type: {event_type}")
                await self.handle_update_event(data)
            else:
                logger.warning(f"**ModelCluster** Unknown event type received: {event_type}")
        except Exception as e:
            logger.error(f"**ModelCluster** Error processing event {event_type}: {e}", exc_info=True)
            # Bare raise keeps the original traceback intact for the caller.
            raise

    async def handle_init_event(self, data: list[dict]):
        """
        Process the `init` event to initialize models run.

        :param data: List of models with their initial states.
        """
        logger.debug("**ModelCluster** Handling 'init' event.")
        await self.update_model_runs(data)

    async def handle_update_event(self, data: list[dict]):
        """
        Process the `update` event to update model states.

        :param data: List of models with their updated states.
        """
        logger.debug("**ModelCluster** Handling 'update' event.")
        await self.update_model_runs(data)

    async def update_model_runs(self, data):
        """
        Apply a batch of model state changes: add RUNNING models that are new,
        remove STOPPED ones. Additions/removals run concurrently.

        :param data: List of dicts carrying model_id, model_name, state, ip, port.
        """
        tasks = []
        for model_update in data:
            model_id = model_update.get("model_id")
            model_name = model_update.get("model_name")
            state = model_update.get("state")
            ip = model_update.get("ip")
            port = model_update.get("port")
            logger.debug(f"**ModelCluster** Updating model with ID: {model_id}")

            # Find the model in the current state
            model_runner = self.models_run.get(model_id)

            if model_runner:
                if state == "STOPPED":
                    # Remove model if state is "stopped"
                    tasks.append(self.remove_model_runner(model_runner))
                    logger.debug(f"**ModelCluster** Model with ID {model_id} removed due to 'stopped' state.")
                elif state == "RUNNING":
                    logger.debug(f"**ModelCluster** Model with ID {model_id} is already running in the cluster. Skipping update for '{model_name}' with state: {state}.")
                else:
                    logger.warning(f"**ModelCluster** Model updated: {model_id}, with state: {state} => This state is not handled...")
            else:
                if state == "STOPPED":
                    logger.debug(f"**ModelCluster** Model with ID {model_id} is not found in the cluster state, and its state is 'STOPPED'. No action is required.")
                elif state == "RUNNING":
                    logger.debug(f"**ModelCluster** New model with ID {model_id} is running, we add it to the cluster state.")
                    model_runner = ModelRunner(model_id, model_name, ip, port)
                    tasks.append(self.add_model_runner(model_runner))
                else:
                    logger.warning(f"**ModelCluster** Model updated: {model_id}, with state: {state} => This state is not handled...")

        await asyncio.gather(*tasks)

    async def add_model_runner(self, model_runner: ModelRunner):
        """
        Asynchronously initialize a model_runner and add it to the cluster state.
        """
        is_initialized = await model_runner.init()
        # Only register runners whose gRPC connection was actually established.
        if is_initialized:
            self.models_run[model_runner.model_id] = model_runner

    async def remove_model_runner(self, model_runner: ModelRunner):
        """
        Asynchronously close a model_runner and remove it from the cluster state.
        """
        await model_runner.close()
        # pop() instead of del: tolerate runners that never made it into the map
        # (e.g. a STOPPED event racing a failed initialization).
        self.models_run.pop(model_runner.model_id, None)

    async def start(self):
        """
        Start the WebSocket client and handle events.
        """
        try:
            await self.ws_client.connect()
        except Exception as e:
            logger.error(f"**ModelCluster** Failed to start WebSocket client: {e}", exc_info=True)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
from model_runner_client.model_cluster import ModelCluster
|
|
6
|
+
from model_runner_client.model_runner import ModelRunner
|
|
7
|
+
from model_runner_client.protos.model_runner_pb2 import DataType
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("model_runner_client")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelPredictResult:
    """Outcome of a single model prediction: the runner, its payload, and a status."""

    class Status(Enum):
        # Terminal states a prediction attempt can end in.
        SUCCESS = "SUCCESS"
        FAILED = "FAILED"
        TIMEOUT = "TIMEOUT"

    def __init__(self, model_runner: ModelRunner, result: any, status: Status):
        # On SUCCESS `result` holds the decoded prediction; on FAILED the error
        # text; on TIMEOUT it is None (see ModelConcurrentRunner).
        self.model_runner = model_runner
        self.result = result
        self.status = status

    def __str__(self):
        fields = ", ".join((
            f"model_runner={self.model_runner}",
            f"result={self.result}",
            f"status={self.status.name}",
        ))
        return f"ModelPredictResult({fields})"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ModelConcurrentRunner:
    """
    Fans one prediction request out to every model in the cluster
    concurrently, bounding each individual call with a timeout.
    """

    def __init__(self, timeout: int, crunch_id: str, host: str, port: int):
        """
        :param timeout: Per-model prediction timeout, in seconds.
        :param crunch_id: The Crunch ID this runner operates on.
        :param host: The WebSocket server's host.
        :param port: The WebSocket server's port.
        """
        self.timeout = timeout
        self.host = host
        self.port = port
        self.model_cluster = ModelCluster(crunch_id, self.host, self.port)

        # TODO: If the model returns failures exceeding max_consecutive_failures, exclude the model. Maybe also inform the orchestrator to STOP the model ?
        #self.max_consecutive_failures

        # TODO: Implement this. If the option is enabled, allow the model time to recover after a timeout.
        #self.enable_recovery_mode
        #self.recovery_time

    async def init(self):
        """Initialize the underlying model cluster (WebSocket handshake)."""
        await self.model_cluster.init()

    async def sync(self):
        """Keep the cluster in sync; blocks for the lifetime of the connection."""
        await self.model_cluster.sync()

    async def predict(self, argument_type: DataType, argument_value: bytes) -> dict[ModelRunner, ModelPredictResult]:
        """
        Request a prediction from every model in the cluster concurrently.

        :param argument_type: Encoding of ``argument_value`` (see ``DataType``).
        :param argument_value: The encoded prediction input.
        :return: One ``ModelPredictResult`` per model, keyed by its ``ModelRunner``.
        """
        # Snapshot the runners so results can be paired back to their model
        # even if the cluster is mutated while predictions are in flight.
        models = list(self.model_cluster.models_run.values())
        tasks = [self._predict_with_timeout(model, argument_type, argument_value)
                 for model in models]
        logger.debug(f"**ModelConcurrentRunner** predict tasks: {tasks}")
        results = await asyncio.gather(*tasks, return_exceptions=True)

        predictions = {}
        for model, result in zip(models, results):
            if isinstance(result, ModelPredictResult):
                predictions[result.model_runner] = result
            else:
                # gather() handed back a raw exception (e.g. CancelledError,
                # which _predict_with_timeout does not catch): record it as a
                # failure instead of crashing on `.model_runner`.
                predictions[model] = ModelPredictResult(model, str(result), ModelPredictResult.Status.FAILED)
        return predictions

    async def _predict_with_timeout(self, model: ModelRunner, argument_type: DataType, argument_value: bytes):
        """
        Run one model's prediction, mapping timeout/errors onto a status.

        :return: A ``ModelPredictResult`` with SUCCESS, TIMEOUT, or FAILED status.
        """
        try:
            result = await asyncio.wait_for(
                model.predict(argument_type, argument_value), timeout=self.timeout
            )
            return ModelPredictResult(model, result, ModelPredictResult.Status.SUCCESS)
        except asyncio.TimeoutError:
            return ModelPredictResult(model, None, ModelPredictResult.Status.TIMEOUT)
        except Exception as e:
            return ModelPredictResult(model, str(e), ModelPredictResult.Status.FAILED)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import grpc
|
|
6
|
+
from google.protobuf import empty_pb2
|
|
7
|
+
from grpc.aio import AioRpcError
|
|
8
|
+
|
|
9
|
+
from model_runner_client.protos.model_runner_pb2 import DataType, InferRequest
|
|
10
|
+
from model_runner_client.protos.model_runner_pb2_grpc import ModelRunnerStub
|
|
11
|
+
from model_runner_client.utils.datatype_transformer import decode_data
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("model_runner_client")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ModelRunner:
    """
    gRPC client for a single model node: connects with retries, runs
    predictions, and closes the channel when the model leaves the cluster.
    """

    def __init__(self, model_id: str, model_name: str, ip: str, port: int):
        """
        :param model_id: Unique identifier of the model.
        :param model_name: Human-readable model name (used in logs).
        :param ip: Address of the model's gRPC server.
        :param port: Port of the model's gRPC server.
        """
        self.model_id = model_id
        self.model_name = model_name
        self.ip = ip
        self.port = port
        logger.info(f"**ModelRunner** New model runner created: {self.model_id}, {self.model_name}, {self.ip}:{self.port}, let's connect it")

        self.grpc_channel = None
        self.grpc_stub = None
        self.retry_attempts = 5  # args ?
        self.min_retry_interval = 2  # 2 seconds
        self.closed = False
        # Best-effort cleanup at interpreter exit. Registered here (not in
        # __del__, which only runs at destruction — too late to be useful);
        # unregistered again once close() has run.
        atexit.register(self.close_sync)

    def __del__(self):
        logger.debug(f"**ModelRunner** Model runner {self.model_id} is destroyed")

    async def init(self) -> bool:
        """
        Open the gRPC channel and call Setup, retrying with exponential backoff.

        :return: True once the model is connected and ready, False otherwise.
        """
        for attempt in range(1, self.retry_attempts + 1):
            if self.closed:
                logger.debug(f"**ModelRunner** Model runner {self.model_id} closed, aborting initialization")
                return False
            try:
                self.grpc_channel = grpc.aio.insecure_channel(f"{self.ip}:{self.port}")
                self.grpc_stub = ModelRunnerStub(self.grpc_channel)
                await self.grpc_stub.Setup(empty_pb2.Empty())  # maybe orchestrator has to do that ?
                logger.info(f"**ModelRunner** model runner: {self.model_id}, {self.model_name}, is connected and ready")
                return True
            except (AioRpcError, asyncio.TimeoutError) as e:
                logger.error(f"**ModelRunner** Model {self.model_id} initialization failed on attempt {attempt}/{self.retry_attempts}: {e}")
            except Exception as e:
                logger.error(f"**ModelRunner** Unexpected error during initialization of model {self.model_id}: {e}", exc_info=True)

            # Release the failed channel before retrying so each attempt does
            # not leak a socket.
            if self.grpc_channel is not None:
                await self.grpc_channel.close()
                self.grpc_channel = None
                self.grpc_stub = None

            if attempt < self.retry_attempts:
                backoff_time = 2 ** attempt  # Backoff with exponential delay
                logger.warning(f"**ModelRunner** Retrying in {backoff_time} seconds...")
                await asyncio.sleep(backoff_time)
            else:
                logger.error(f"**ModelRunner** Model {self.model_id} failed to initialize after {self.retry_attempts} attempts.")
                # todo what is the behavior here ? remove it locally ?
        return False

    async def predict(self, argument_type: DataType, argument_value: bytes):
        """
        Run a single inference on the remote model.

        :param argument_type: Encoding of ``argument_value`` (see ``DataType``).
        :param argument_value: The encoded prediction input.
        :return: The decoded prediction produced by the model.
        """
        logger.debug(f"**ModelRunner** Doing prediction of model_id:{self.model_id}, name:{self.model_name}, argument_type:{argument_type}")
        prediction_request = InferRequest(type=argument_type, argument=argument_value)

        response = await self.grpc_stub.Infer(prediction_request)

        return decode_data(response.prediction, response.type)

    def close_sync(self):
        """
        Synchronous wrapper around close(), for atexit / non-async callers.
        """
        if self.closed:
            return
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No usable event loop in this thread (common at interpreter exit).
            asyncio.run(self.close())
            return
        if loop.is_running():
            # Cannot block inside a running loop; schedule the close instead.
            loop.create_task(self.close())
        else:
            loop.run_until_complete(self.close())

    async def close(self):
        """
        Close the gRPC channel and mark this runner as closed (idempotent).
        """
        self.closed = True
        if self.grpc_channel:
            await self.grpc_channel.close()
            logger.debug(f"**ModelRunner** Model runner {self.model_id} grpc connection closed")
        # Nothing left to clean up at exit (no-op if not registered).
        atexit.unregister(self.close_sync)
|
|
File without changes
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
// Inference contract between the coordinator (gRPC client) and a model node
// (gRPC server). Mirrors the generated code in model_runner_pb2*.py.
syntax = "proto3";

import "google/protobuf/empty.proto";

service ModelRunner {
  // Streaming variant: a stream of requests answered by a stream of responses.
  rpc InferStream(stream InferRequest) returns (stream InferResponse);
  // Single request / single response inference (used by the Python client's predict()).
  rpc Infer(InferRequest) returns (InferResponse);
  // Called once right after connecting, before any prediction (see ModelRunner.init in the client).
  rpc Setup(google.protobuf.Empty) returns (google.protobuf.Empty);
  // Presumably resets the model's state; not invoked by this client library — TODO confirm intended semantics.
  rpc Reinitialize(google.protobuf.Empty) returns (google.protobuf.Empty);
}

// Wire encodings supported for request arguments and predictions.
enum DataType {
  DOUBLE = 0;
  INT = 1;
  STRING = 2;
  PARQUET = 3;
  ARROW = 4;
  JSON = 5;
}

// Prediction input: `argument` bytes encoded as described by `type`.
message InferRequest {
  DataType type = 1;
  bytes argument = 2;
}

// Prediction output: `prediction` bytes encoded as described by `type`.
message InferResponse {
  DataType type = 1;
  bytes prediction = 2;
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: model_runner.proto
# Protobuf Python Version: 5.29.0
# NOTE(review): regenerate with grpcio-tools from protos/model_runner.proto
# rather than editing this file by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Fail fast if the installed protobuf runtime is older than this gencode.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'model_runner.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12model_runner.proto\x1a\x1bgoogle/protobuf/empty.proto\"9\n\x0cInferRequest\x12\x17\n\x04type\x18\x01 \x01(\x0e\x32\t.DataType\x12\x10\n\x08\x61rgument\x18\x02 \x01(\x0c\"<\n\rInferResponse\x12\x17\n\x04type\x18\x01 \x01(\x0e\x32\t.DataType\x12\x12\n\nprediction\x18\x02 \x01(\x0c*M\n\x08\x44\x61taType\x12\n\n\x06\x44OUBLE\x10\x00\x12\x07\n\x03INT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x0b\n\x07PARQUET\x10\x03\x12\t\n\x05\x41RROW\x10\x04\x12\x08\n\x04JSON\x10\x05\x32\xe0\x01\n\x0bModelRunner\x12\x30\n\x0bInferStream\x12\r.InferRequest\x1a\x0e.InferResponse(\x01\x30\x01\x12&\n\x05Infer\x12\r.InferRequest\x1a\x0e.InferResponse\x12\x37\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\x12>\n\x0cReinitialize\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Emptyb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model_runner_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  # Byte offsets of each top-level symbol inside the serialized file above.
  _globals['_DATATYPE']._serialized_start=172
  _globals['_DATATYPE']._serialized_end=249
  _globals['_INFERREQUEST']._serialized_start=51
  _globals['_INFERREQUEST']._serialized_end=108
  _globals['_INFERRESPONSE']._serialized_start=110
  _globals['_INFERRESPONSE']._serialized_end=170
  _globals['_MODELRUNNER']._serialized_start=252
  _globals['_MODELRUNNER']._serialized_end=476
# @@protoc_insertion_point(module_scope)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from google.protobuf import empty_pb2 as _empty_pb2
|
|
2
|
+
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
|
3
|
+
from google.protobuf import descriptor as _descriptor
|
|
4
|
+
from google.protobuf import message as _message
|
|
5
|
+
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
|
6
|
+
|
|
7
|
+
DESCRIPTOR: _descriptor.FileDescriptor
|
|
8
|
+
|
|
9
|
+
class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
|
10
|
+
__slots__ = ()
|
|
11
|
+
DOUBLE: _ClassVar[DataType]
|
|
12
|
+
INT: _ClassVar[DataType]
|
|
13
|
+
STRING: _ClassVar[DataType]
|
|
14
|
+
PARQUET: _ClassVar[DataType]
|
|
15
|
+
ARROW: _ClassVar[DataType]
|
|
16
|
+
JSON: _ClassVar[DataType]
|
|
17
|
+
DOUBLE: DataType
|
|
18
|
+
INT: DataType
|
|
19
|
+
STRING: DataType
|
|
20
|
+
PARQUET: DataType
|
|
21
|
+
ARROW: DataType
|
|
22
|
+
JSON: DataType
|
|
23
|
+
|
|
24
|
+
class InferRequest(_message.Message):
|
|
25
|
+
__slots__ = ("type", "argument")
|
|
26
|
+
TYPE_FIELD_NUMBER: _ClassVar[int]
|
|
27
|
+
ARGUMENT_FIELD_NUMBER: _ClassVar[int]
|
|
28
|
+
type: DataType
|
|
29
|
+
argument: bytes
|
|
30
|
+
def __init__(self, type: _Optional[_Union[DataType, str]] = ..., argument: _Optional[bytes] = ...) -> None: ...
|
|
31
|
+
|
|
32
|
+
class InferResponse(_message.Message):
|
|
33
|
+
__slots__ = ("type", "prediction")
|
|
34
|
+
TYPE_FIELD_NUMBER: _ClassVar[int]
|
|
35
|
+
PREDICTION_FIELD_NUMBER: _ClassVar[int]
|
|
36
|
+
type: DataType
|
|
37
|
+
prediction: bytes
|
|
38
|
+
def __init__(self, type: _Optional[_Union[DataType, str]] = ..., prediction: _Optional[bytes] = ...) -> None: ...
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
2
|
+
"""Client and server classes corresponding to protobuf-defined services."""
|
|
3
|
+
import grpc
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
|
|
7
|
+
from . import model_runner_pb2 as model__runner__pb2
|
|
8
|
+
|
|
9
|
+
# Generated guard: these stubs require grpcio at least as new as the version
# they were generated with.
GRPC_GENERATED_VERSION = '1.69.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # Older grpcio without the comparison helper: treat as incompatible.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in model_runner_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ModelRunnerStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # One callable per RPC declared in model_runner.proto.
        self.InferStream = channel.stream_stream(
                '/ModelRunner/InferStream',
                request_serializer=model__runner__pb2.InferRequest.SerializeToString,
                response_deserializer=model__runner__pb2.InferResponse.FromString,
                _registered_method=True)
        self.Infer = channel.unary_unary(
                '/ModelRunner/Infer',
                request_serializer=model__runner__pb2.InferRequest.SerializeToString,
                response_deserializer=model__runner__pb2.InferResponse.FromString,
                _registered_method=True)
        self.Setup = channel.unary_unary(
                '/ModelRunner/Setup',
                request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
                response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
                _registered_method=True)
        self.Reinitialize = channel.unary_unary(
                '/ModelRunner/Reinitialize',
                request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
                response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
                _registered_method=True)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Generated server-side base class; a real server subclasses this and
# overrides the RPC methods. Not used by the client library itself.
class ModelRunnerServicer(object):
    """Missing associated documentation comment in .proto file."""

    def InferStream(self, request_iterator, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Infer(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Setup(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Reinitialize(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def add_ModelRunnerServicer_to_server(servicer, server):
    """Register *servicer*'s RPC handlers with *server* under the 'ModelRunner' service name.

    Auto-generated by the gRPC Python protoc plugin -- do not edit by hand.
    """
    # One handler per RPC, pairing the servicer method with its (de)serializers.
    rpc_method_handlers = {
            'InferStream': grpc.stream_stream_rpc_method_handler(
                    servicer.InferStream,
                    request_deserializer=model__runner__pb2.InferRequest.FromString,
                    response_serializer=model__runner__pb2.InferResponse.SerializeToString,
            ),
            'Infer': grpc.unary_unary_rpc_method_handler(
                    servicer.Infer,
                    request_deserializer=model__runner__pb2.InferRequest.FromString,
                    response_serializer=model__runner__pb2.InferResponse.SerializeToString,
            ),
            'Setup': grpc.unary_unary_rpc_method_handler(
                    servicer.Setup,
                    request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
                    response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
            ),
            'Reinitialize': grpc.unary_unary_rpc_method_handler(
                    servicer.Reinitialize,
                    request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
                    response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'ModelRunner', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    # Registered handlers enable the faster "registered method" call path.
    server.add_registered_method_handlers('ModelRunner', rpc_method_handlers)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# This class is part of an EXPERIMENTAL API.
class ModelRunner(object):
    """Static convenience wrappers for calling ModelRunner RPCs.

    Each method creates/uses a channel to *target* via grpc.experimental,
    so no explicit channel or stub is needed. Auto-generated by the gRPC
    Python protoc plugin -- do not edit by hand.
    """

    @staticmethod
    def InferStream(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the bidirectional-streaming /ModelRunner/InferStream RPC."""
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            '/ModelRunner/InferStream',
            model__runner__pb2.InferRequest.SerializeToString,
            model__runner__pb2.InferResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Infer(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the unary /ModelRunner/Infer RPC."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/ModelRunner/Infer',
            model__runner__pb2.InferRequest.SerializeToString,
            model__runner__pb2.InferResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Setup(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the unary /ModelRunner/Setup RPC (Empty -> Empty)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/ModelRunner/Setup',
            google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
            google_dot_protobuf_dot_empty__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Reinitialize(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the unary /ModelRunner/Reinitialize RPC (Empty -> Empty)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/ModelRunner/Reinitialize',
            google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
            google_dot_protobuf_dot_empty__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
|
|
File without changes
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import io
import json
import struct

import pandas
import pyarrow
import pyarrow.parquet as pq

from model_runner_client.protos.model_runner_pb2 import DataType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Encoder: Converts data to bytes
def encode_data(data_type: DataType, data) -> bytes:
    """Serialize *data* into its wire representation for *data_type*.

    :param data_type: Wire-type tag (DOUBLE, INT, STRING, PARQUET or JSON).
    :param data: Value matching *data_type*: float, int, str,
                 pandas.DataFrame, or a JSON-serializable object respectively.
    :return: The serialized payload bytes (inverse of decode_data).
    :raises ValueError: On an unsupported *data_type* or non-serializable JSON data.
    """
    if data_type == DataType.DOUBLE:
        # 8-byte IEEE-754 double (mirrors struct.unpack("d", ...) in decode_data).
        return struct.pack("d", data)
    elif data_type == DataType.INT:
        # 64-bit little-endian signed integer.
        return data.to_bytes(8, byteorder="little", signed=True)
    elif data_type == DataType.STRING:
        return data.encode("utf-8")
    elif data_type == DataType.PARQUET:
        # BUG FIX: the original called pandas.Table.from_pandas / pandas.write_table,
        # neither of which exists (they are pyarrow APIs), so this branch always
        # raised AttributeError. Use pyarrow, matching pq.read_table in decode_data.
        table = pyarrow.Table.from_pandas(data)
        sink = io.BytesIO()
        pq.write_table(table, sink)
        return sink.getvalue()
    elif data_type == DataType.JSON:
        try:
            json_data = json.dumps(data)  # Convert the object to a JSON string
            return json_data.encode("utf-8")  # Return the JSON string as bytes
        except TypeError as e:
            raise ValueError(f"Data cannot be serialized to JSON: {e}")
    else:
        raise ValueError(f"Unsupported data type: {data_type}")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Decoder: Converts bytes to data
def decode_data(data_bytes: bytes, data_type: DataType):
    """Deserialize *data_bytes* according to *data_type* (inverse of encode_data).

    :param data_bytes: Raw payload produced by encode_data.
    :param data_type: Wire-type tag (DOUBLE, INT, STRING, PARQUET or JSON).
    :return: The reconstructed Python value.
    :raises ValueError: On an unsupported *data_type* or malformed JSON payload.
    """
    if data_type == DataType.DOUBLE:
        (value,) = struct.unpack("d", data_bytes)
        return value

    if data_type == DataType.INT:
        return int.from_bytes(data_bytes, byteorder="little", signed=True)

    if data_type == DataType.STRING:
        return data_bytes.decode("utf-8")

    if data_type == DataType.PARQUET:
        # Read the parquet payload through an in-memory buffer back into a DataFrame.
        return pq.read_table(io.BytesIO(data_bytes)).to_pandas()

    if data_type == DataType.JSON:
        try:
            return json.loads(data_bytes.decode("utf-8"))
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to decode JSON data: {e}")

    raise ValueError(f"Unsupported data type: {data_type}")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def detect_data_type(data) -> DataType:
|
|
56
|
+
"""
|
|
57
|
+
Detects the data type based on the Python object and returns
|
|
58
|
+
the corresponding DataType enum.
|
|
59
|
+
"""
|
|
60
|
+
if isinstance(data, float):
|
|
61
|
+
return DataType.DOUBLE
|
|
62
|
+
elif isinstance(data, int):
|
|
63
|
+
return DataType.INT
|
|
64
|
+
elif isinstance(data, str):
|
|
65
|
+
return DataType.STRING
|
|
66
|
+
elif isinstance(data, pandas.DataFrame):
|
|
67
|
+
return DataType.PARQUET
|
|
68
|
+
elif isinstance(data, dict) or isinstance(data, list):
|
|
69
|
+
return DataType.JSON
|
|
70
|
+
else:
|
|
71
|
+
raise ValueError(f"Unsupported data type: {type(data)}")
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import websockets
|
|
7
|
+
from websockets import State
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("model_runner_client")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class WebsocketClient:
    def __init__(self, host, port, crunch_id, event_handler=None):
        """
        WebsocketClient constructor.

        :param host: WebSocket server host.
        :param port: WebSocket server port.
        :param crunch_id: Crunch ID used to connect to the server.
        :param event_handler: Optional handler for WebSocket events (delegated to ModelCluster).
        """
        self.retry_interval = 10  # seconds between reconnection attempts
        self.max_retries = 5      # attempts before connect() gives up
        self.listening = None     # set True by listen(), False by disconnect()
        self.host = host
        self.port = port
        self.crunch_id = crunch_id
        self.websocket = None
        self.event_handler = event_handler  # Delegate to ModelCluster
        # BUG FIX: the original registered this atexit hook from __del__, i.e.
        # only once the instance was already being garbage-collected (if ever).
        # Register it at construction time so the connection is always closed
        # on interpreter exit.
        atexit.register(self.disconnect_sync)

    async def connect(self):
        """
        Establish a WebSocket connection, retrying up to self.max_retries times.
        """
        retry_count = 0
        while self.max_retries > retry_count:
            try:
                uri = f"ws://{self.host}:{self.port}/{self.crunch_id}"
                logger.info(f"Connecting to WebSocket server at {uri}")
                self.websocket = await websockets.connect(uri)
                logger.info(f"Connected to WebSocket server at {uri}")
                break
            except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError, OSError, asyncio.TimeoutError) as e:
                logger.warning(f"Connection error ({e.__class__.__name__}): Retrying in {self.retry_interval} seconds...")
            except Exception as e:
                logger.error(f"Unexpected error ({e.__class__.__name__}): {e}", exc_info=True)
                logger.warning(f"Retrying in {self.retry_interval} seconds...")

            await asyncio.sleep(self.retry_interval)
            retry_count += 1

    async def init(self):
        """
        Receive the first message (the "init" event) from the WebSocket server
        and dispatch it to the event handler.
        """
        # No retry loop here: init() runs right after connect(), which already
        # handles connection retries.
        try:
            message = await self.websocket.recv()
            await self.handle_event(message)
        except websockets.exceptions.ConnectionClosed:
            logger.warning("WebSocket connection closed by the server.")
        except Exception as e:
            logger.error(f"Error while listening to WebSocket messages: {e}", exc_info=True)

    async def listen(self):
        """
        Listen for messages from the WebSocket server, reconnecting on failure
        until disconnect() flips self.listening to False.
        """
        self.listening = True
        while self.listening:
            try:
                if not await self.is_connected():
                    await self.connect()

                logger.info("Listening for messages...")
                async for message in self.websocket:
                    await self.handle_event(message)
            except (websockets.exceptions.ConnectionClosed, asyncio.TimeoutError):
                logger.warning(f"Connection lost. Retrying in {self.retry_interval} seconds...")
                await asyncio.sleep(self.retry_interval)
            except Exception as e:
                logger.error(f"An unexpected error occurred: {e}", exc_info=True)
                logger.warning(f"Retrying in {self.retry_interval} seconds...")
                await asyncio.sleep(self.retry_interval)

            finally:
                # BUG FIX: the original tested `not self.websocket.state`, which is
                # only true for State.CONNECTING (value 0) -- it never cleared a
                # closed socket. Drop the reference once the connection is CLOSED
                # so the next iteration reconnects.
                if self.websocket and self.websocket.state == State.CLOSED:
                    self.websocket = None

    async def is_connected(self):
        """
        Check if the WebSocket connection is open.
        :return: True if the client is connected, False otherwise.
        """
        return self.websocket is not None and self.websocket.state == State.OPEN

    async def handle_event(self, message):
        """
        Handle incoming WebSocket messages and forward to the event handler.

        :param message: Message received from the server.
        """
        try:
            event = json.loads(message)
            event_type = event.get("event")
            data = event.get("data")

            logger.debug(f"Received event: {event_type}, data: {data}")

            # Delegate to the event handler, if available
            if self.event_handler:
                await self.event_handler(event_type, data)
            else:
                logger.warning("No event handler defined. Event will be ignored.")
        except json.JSONDecodeError:
            logger.error(f"Failed to decode WebSocket message: {message}")

    def disconnect_sync(self):
        """Synchronous wrapper around disconnect(), suitable as an atexit hook."""
        # BUG FIX: the original used asyncio.get_event_loop().run_until_complete,
        # which is deprecated outside a running loop and raises if a loop is
        # already running in this thread.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop running (the usual case at interpreter exit): drive the
            # coroutine to completion ourselves.
            asyncio.run(self.disconnect())
        else:
            # A loop is running; schedule the close on it instead of blocking.
            loop.create_task(self.disconnect())

    async def disconnect(self):
        """
        Close the WebSocket connection.
        """
        self.listening = False
        if self.websocket:
            await self.websocket.close()
            logger.info("WebSocket connection closed.")
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "model-runner-client"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Model Runner Client is a Python library that allows Coordinators to manage real-time model synchronization and perform concurrent predictions on distributed model nodes within a Crunch."
|
|
5
|
+
authors = ["boutrig abdennour <abdennour.boutrig@crunchdao.com>"]
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
|
|
8
|
+
[tool.poetry.dependencies]
|
|
9
|
+
python = "^3.11"
|
|
10
|
+
grpcio = "^1.70.0"
|
|
11
|
+
pyarrow = "^19.0.0"
|
|
12
|
+
protobuf = "^5.29.3"
|
|
13
|
+
pandas = "^2.2.3"
|
|
14
|
+
websockets = "^14.2"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
[tool.poetry.group.dev.dependencies]
|
|
18
|
+
pytest = "^8.3.4"
|
|
19
|
+
pytest-asyncio = "^0.25.3"
|
|
20
|
+
|
|
21
|
+
[build-system]
|
|
22
|
+
requires = ["poetry-core"]
|
|
23
|
+
build-backend = "poetry.core.masonry.api"
|
|
24
|
+
|
|
25
|
+
[tool.poetry.urls]
|
|
26
|
+
"Homepage" = "https://github.com/crunchdao/model-runner-client"
|
|
27
|
+
"Bug Tracker" = "https://github.com/crunchdao/model-runner-client/issues"
|