model-runner-client 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,107 @@
+ Metadata-Version: 2.1
+ Name: model-runner-client
+ Version: 0.1.0
+ Summary: Model Runner Client is a Python library that allows Coordinators to manage real-time model synchronization and perform concurrent predictions on distributed model nodes within a Crunch.
+ Author: boutrig abdennour
+ Author-email: abdennour.boutrig@crunchdao.com
+ Requires-Python: >=3.11,<4.0
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Dist: grpcio (>=1.70.0,<2.0.0)
+ Requires-Dist: pandas (>=2.2.3,<3.0.0)
+ Requires-Dist: protobuf (>=5.29.3,<6.0.0)
+ Requires-Dist: pyarrow (>=19.0.0,<20.0.0)
+ Requires-Dist: websockets (>=14.2,<15.0)
+ Description-Content-Type: text/markdown
+
+ # Model Runner Client
+
+ **Model Runner Client** is a Python library that allows you, as a Coordinator, to interact with models participating in your crunch. It tracks which models join or leave through a WebSocket connection to the model nodes.
+
+ - **Real-Time Model Sync**: Each model participating in your crunch is an instance of `ModelRunner`, maintained via WebSocket in the `ModelCluster`.
+ - **Concurrent Predictions (with Timeout Handling)**: Use the `ModelConcurrentRunner` to request predictions from all models simultaneously. Define a timeout to avoid blocking if a model takes too long to predict.
+
+ ## Installation
+
+ ```bash
+ pip install model-runner-client
+ ```
+
+ > **Note**: Adjust this command (e.g., `pip3` or virtual environments) depending on your setup.
+
+ ## Usage
+
+ Below is a quick example focusing on the `ModelConcurrentRunner`. It handles concurrent predictions for you and returns all results in one go.
+
+ ```python
+ import asyncio
+ from model_runner_client.model_concurrent_runner import ModelConcurrentRunner
+ from model_runner_client.protos.model_runner_pb2 import DataType
+ from model_runner_client.utils.datatype_transformer import encode_data
+
+ async def main():
+     # crunch_id, host, and port are values provided by crunchdao
+     concurrent_runner = ModelConcurrentRunner(
+         timeout=10,
+         crunch_id="bird-game",
+         host="localhost",
+         port=8000
+     )
+
+     # Initialize communication with the model nodes to fetch
+     # models that want to predict and set up the model cluster
+     await concurrent_runner.init()
+
+     async def prediction_call():
+         while True:
+             # Your data to be predicted (X)
+             value = {
+                 'falcon_location': 21.179864629354732,
+                 'time': 230.96231205799998,
+                 'dove_location': 19.164986723324326,
+                 'falcon_id': 1
+             }
+
+             # Encode data as binary and request a prediction
+             result = await concurrent_runner.predict(
+                 DataType.JSON,
+                 encode_data(DataType.JSON, value)
+             )
+
+             # You receive a dictionary of predictions
+             for model_runner, model_predict_result in result.items():
+                 print(f"{model_runner.model_id}: {model_predict_result}")
+
+             # This pause (30s) simulates other work
+             # the Coordinator might perform between predictions
+             await asyncio.sleep(30)
+
+     # Keep the cluster updated with `concurrent_runner.sync()`,
+     # which maintains a permanent WebSocket connection.
+     # Then run our prediction process.
+     await asyncio.gather(
+         asyncio.create_task(concurrent_runner.sync()),
+         asyncio.create_task(prediction_call())
+     )
+
+ if __name__ == "__main__":
+     try:
+         asyncio.run(main())
+     except KeyboardInterrupt:
+         print("\nReceived exit signal, shutting down gracefully.")
+ ```
+
+ ### Important Notes
+
+ - **Prediction Failures & Timeouts**: A prediction may fail or exceed the defined timeout, so be sure to handle these cases appropriately. Refer to `ModelPredictResult.Status` for details.
+ - **Custom Implementations**: If you need more control over your workflow, you can manage each model individually. Instead of using `ModelConcurrentRunner`, you can directly leverage `ModelRunner` instances from the `ModelCluster`, customizing how you schedule predictions and handle results.
+
+ ## Contributing
+
+ Contributions are welcome! Feel free to open issues or submit pull requests if you encounter any bugs or want to suggest improvements.
+
+ ## License
+
+ This project is distributed under the [MIT License](https://choosealicense.com/licenses/mit/). See the LICENSE file for details.
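The status handling called out under Important Notes can be sketched with a stand-in result type. The `SUCCESS`/`FAILED`/`TIMEOUT` values below match the library's `ModelPredictResult.Status`; the dictionary shape and the `split_results` helper are illustrative only:

```python
from enum import Enum

# Stand-in mirroring ModelPredictResult.Status (illustrative, not the library class)
class Status(Enum):
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    TIMEOUT = "TIMEOUT"

def split_results(results: dict) -> tuple[dict, dict]:
    """Separate successful predictions from failures and timeouts."""
    ok = {m: r for m, (r, s) in results.items() if s is Status.SUCCESS}
    bad = {m: s for m, (r, s) in results.items() if s is not Status.SUCCESS}
    return ok, bad

# Hypothetical per-model outcomes, keyed by model id
results = {
    "model-a": (0.42, Status.SUCCESS),
    "model-b": (None, Status.TIMEOUT),
    "model-c": ("boom", Status.FAILED),
}
ok, bad = split_results(results)
```

In the real library the keys are `ModelRunner` instances and each value is a `ModelPredictResult`, but the triage pattern is the same.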
@@ -0,0 +1,127 @@
+ import asyncio
+ import logging
+
+ from model_runner_client.model_runner import ModelRunner
+ from model_runner_client.websocket_client import WebsocketClient
+
+ logger = logging.getLogger("model_runner_client")
+
+
+ class ModelCluster:
+     def __init__(self, crunch_id: str, ws_host: str, ws_port: int):
+         """
+         ModelCluster constructor.
+
+         :param crunch_id: The Crunch ID that this cluster is responsible for.
+         :param ws_host: The WebSocket server's host.
+         :param ws_port: The WebSocket server's port.
+         """
+         self.crunch_id = crunch_id
+         self.models_run: dict[str, ModelRunner] = {}
+         logger.debug(f"Initializing ModelCluster with Crunch ID: {crunch_id}")
+         self.ws_client = WebsocketClient(ws_host, ws_port, crunch_id, event_handler=self.handle_event)
+
+     async def init(self):
+         await self.ws_client.connect()
+         await self.ws_client.init()
+
+         logger.debug("WebSocket client initialized.")
+
+     async def sync(self):
+         await self.ws_client.listen()
+
+     async def handle_event(self, event_type: str, data: list[dict]):
+         """
+         Handle WebSocket events (`init` and `update`) and update the cluster's state.
+
+         :param event_type: The type of the event (`init` or `update`).
+         :param data: The event data.
+         """
+         try:
+             if event_type == "init":
+                 logger.debug(f"**ModelCluster** Processing event type: {event_type}")
+                 await self.handle_init_event(data)
+             elif event_type == "update":
+                 logger.debug(f"**ModelCluster** Processing event type: {event_type}")
+                 await self.handle_update_event(data)
+             else:
+                 logger.warning(f"**ModelCluster** Unknown event type received: {event_type}")
+         except Exception as e:
+             logger.error(f"**ModelCluster** Error processing event {event_type}: {e}", exc_info=True)
+             raise
+
+     async def handle_init_event(self, data: list[dict]):
+         """
+         Process the `init` event to initialize the models run.
+
+         :param data: List of models with their initial states.
+         """
+         logger.debug("**ModelCluster** Handling 'init' event.")
+         await self.update_model_runs(data)
+
+     async def handle_update_event(self, data: list[dict]):
+         """
+         Process the `update` event to update model states.
+
+         :param data: List of models with their updated states.
+         """
+         logger.debug("**ModelCluster** Handling 'update' event.")
+         await self.update_model_runs(data)
+
+     async def update_model_runs(self, data):
+         tasks = []
+         for model_update in data:
+             model_id = model_update.get("model_id")
+             model_name = model_update.get("model_name")
+             state = model_update.get("state")
+             ip = model_update.get("ip")
+             port = model_update.get("port")
+             logger.debug(f"**ModelCluster** Updating model with ID: {model_id}")
+
+             # Find the model in the current state
+             model_runner = self.models_run.get(model_id)
+
+             if model_runner:
+                 if state == "STOPPED":
+                     # Remove the model when its state is "STOPPED"
+                     tasks.append(self.remove_model_runner(model_runner))
+                     logger.debug(f"**ModelCluster** Model with ID {model_id} removed due to 'STOPPED' state.")
+                 elif state == "RUNNING":
+                     logger.debug(f"**ModelCluster** Model with ID {model_id} is already running in the cluster. Skipping update for '{model_name}' with state: {state}.")
+                 else:
+                     logger.warning(f"**ModelCluster** Model updated: {model_id}, with state: {state} => This state is not handled.")
+             else:
+                 if state == "STOPPED":
+                     logger.debug(f"**ModelCluster** Model with ID {model_id} is not found in the cluster state, and its state is 'STOPPED'. No action is required.")
+                 elif state == "RUNNING":
+                     logger.debug(f"**ModelCluster** New model with ID {model_id} is running; adding it to the cluster state.")
+                     model_runner = ModelRunner(model_id, model_name, ip, port)
+                     tasks.append(self.add_model_runner(model_runner))
+                 else:
+                     logger.warning(f"**ModelCluster** Model updated: {model_id}, with state: {state} => This state is not handled.")
+
+         await asyncio.gather(*tasks)
+
+     async def add_model_runner(self, model_runner: ModelRunner):
+         """
+         Asynchronously initialize a model_runner and add it to the cluster state.
+         """
+         is_initialized = await model_runner.init()
+         if is_initialized:
+             self.models_run[model_runner.model_id] = model_runner
+
+     async def remove_model_runner(self, model_runner: ModelRunner):
+         """
+         Asynchronously close a model_runner and remove it from the cluster state.
+         """
+         await model_runner.close()
+         del self.models_run[model_runner.model_id]
+
+     async def start(self):
+         """
+         Start the WebSocket client and handle events.
+         """
+         try:
+             await self.ws_client.connect()
+         except Exception as e:
+             logger.error(f"**ModelCluster** Failed to start WebSocket client: {e}", exc_info=True)
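The RUNNING/STOPPED reconciliation implemented by `update_model_runs` boils down to the following plain-dict sketch (names are illustrative; the real code creates, initializes, and closes `ModelRunner` instances rather than storing names):

```python
def reconcile(cluster: dict, updates: list[dict]) -> dict:
    """Apply RUNNING/STOPPED updates to a {model_id: model_name} cluster state."""
    for u in updates:
        model_id, state = u["model_id"], u["state"]
        if state == "RUNNING":
            # Add only if not already tracked; existing runners are left alone
            cluster.setdefault(model_id, u["model_name"])
        elif state == "STOPPED":
            # Remove if present; a STOPPED model we never tracked needs no action
            cluster.pop(model_id, None)
        # Any other state is logged as unhandled in the real implementation
    return cluster

state = reconcile({}, [
    {"model_id": "m1", "model_name": "alpha", "state": "RUNNING"},
    {"model_id": "m2", "model_name": "beta", "state": "RUNNING"},
    {"model_id": "m1", "model_name": "alpha", "state": "STOPPED"},
])
```

In the library, `setdefault`'s role is played by `add_model_runner`, which only registers the runner when its gRPC initialization succeeds.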
@@ -0,0 +1,65 @@
+ import asyncio
+ import logging
+ from enum import Enum
+ from typing import Any
+
+ from model_runner_client.model_cluster import ModelCluster
+ from model_runner_client.model_runner import ModelRunner
+ from model_runner_client.protos.model_runner_pb2 import DataType
+
+ logger = logging.getLogger("model_runner_client")
+
+
+ class ModelPredictResult:
+     class Status(Enum):
+         SUCCESS = "SUCCESS"
+         FAILED = "FAILED"
+         TIMEOUT = "TIMEOUT"
+
+     def __init__(self, model_runner: ModelRunner, result: Any, status: Status):
+         self.model_runner = model_runner
+         self.result = result
+         self.status = status
+
+     def __str__(self):
+         return f"ModelPredictResult(model_runner={self.model_runner}, result={self.result}, status={self.status.name})"
+
+
+ class ModelConcurrentRunner:
+     def __init__(self, timeout: int, crunch_id: str, host: str, port: int):
+         self.timeout = timeout
+         self.host = host
+         self.port = port
+         self.model_cluster = ModelCluster(crunch_id, self.host, self.port)
+
+         # TODO: If a model's failures exceed max_consecutive_failures, exclude the model. Maybe also inform the orchestrator to STOP it?
+         # self.max_consecutive_failures
+
+         # TODO: Implement this. If the option is enabled, allow the model time to recover after a timeout.
+         # self.enable_recovery_mode
+         # self.recovery_time
+
+     async def init(self):
+         await self.model_cluster.init()
+
+     async def sync(self):
+         await self.model_cluster.sync()
+
+     async def predict(self, argument_type: DataType, argument_value: bytes) -> dict[ModelRunner, ModelPredictResult]:
+         tasks = [self._predict_with_timeout(model, argument_type, argument_value)
+                  for model in self.model_cluster.models_run.values()]
+         logger.debug(f"**ModelConcurrentRunner** predict tasks: {tasks}")
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         return {result.model_runner: result for result in results}
+
+     async def _predict_with_timeout(self, model: ModelRunner, argument_type: DataType, argument_value: bytes):
+         try:
+             result = await asyncio.wait_for(
+                 model.predict(argument_type, argument_value), timeout=self.timeout
+             )
+             return ModelPredictResult(model, result, ModelPredictResult.Status.SUCCESS)
+         except asyncio.TimeoutError:
+             return ModelPredictResult(model, None, ModelPredictResult.Status.TIMEOUT)
+         except Exception as e:
+             return ModelPredictResult(model, str(e), ModelPredictResult.Status.FAILED)
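The timeout wrapping used by `_predict_with_timeout` can be demonstrated standalone; the stub models and the 0.1 s timeout below are illustrative:

```python
import asyncio

async def predict_with_timeout(name: str, coro, timeout: float):
    """Wrap one model call, mapping timeouts and errors to a status tuple."""
    try:
        result = await asyncio.wait_for(coro, timeout=timeout)
        return name, "SUCCESS", result
    except asyncio.TimeoutError:
        return name, "TIMEOUT", None
    except Exception as e:
        return name, "FAILED", str(e)

# Stand-ins for model.predict(): one answers immediately, one hangs
async def fast():
    return 1

async def slow():
    await asyncio.sleep(1)
    return 2

async def main():
    tasks = [predict_with_timeout("fast", fast(), 0.1),
             predict_with_timeout("slow", slow(), 0.1)]
    # gather preserves task order, so results can be keyed by name
    return {n: (s, r) for n, s, r in await asyncio.gather(*tasks)}

results = asyncio.run(main())
```

Because every exception is converted into a status tuple inside the wrapper, `gather` never sees a raised exception and the caller always gets one entry per model.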
@@ -0,0 +1,76 @@
+ import asyncio
+ import atexit
+ import logging
+
+ import grpc
+ from google.protobuf import empty_pb2
+ from grpc.aio import AioRpcError
+
+ from model_runner_client.protos.model_runner_pb2 import DataType, InferRequest
+ from model_runner_client.protos.model_runner_pb2_grpc import ModelRunnerStub
+ from model_runner_client.utils.datatype_transformer import decode_data
+
+ logger = logging.getLogger("model_runner_client")
+
+
+ class ModelRunner:
+     def __init__(self, model_id: str, model_name: str, ip: str, port: int):
+         self.model_id = model_id
+         self.model_name = model_name
+         self.ip = ip
+         self.port = port
+         logger.info(f"**ModelRunner** New model runner created: {self.model_id}, {self.model_name}, {self.ip}:{self.port}, let's connect it")
+
+         self.grpc_channel = None
+         self.grpc_stub = None
+         self.retry_attempts = 5  # make configurable?
+         self.min_retry_interval = 2  # 2 seconds
+         self.closed = False
+         # Register the cleanup hook at creation time (not in __del__),
+         # so the channel is closed at interpreter exit even if the
+         # object is never garbage-collected.
+         atexit.register(self.close_sync)
+
+     def __del__(self):
+         logger.debug(f"**ModelRunner** Model runner {self.model_id} is destroyed")
+
+     async def init(self) -> bool:
+         for attempt in range(1, self.retry_attempts + 1):
+             if self.closed:
+                 logger.debug(f"**ModelRunner** Model runner {self.model_id} closed, aborting initialization")
+                 return False
+             try:
+                 self.grpc_channel = grpc.aio.insecure_channel(f"{self.ip}:{self.port}")
+                 self.grpc_stub = ModelRunnerStub(self.grpc_channel)
+                 await self.grpc_stub.Setup(empty_pb2.Empty())  # maybe the orchestrator has to do that?
+                 logger.info(f"**ModelRunner** model runner: {self.model_id}, {self.model_name}, is connected and ready")
+                 return True
+             except (AioRpcError, asyncio.TimeoutError) as e:
+                 logger.error(f"**ModelRunner** Model {self.model_id} initialization failed on attempt {attempt}/{self.retry_attempts}: {e}")
+             except Exception as e:
+                 logger.error(f"**ModelRunner** Unexpected error during initialization of model {self.model_id}: {e}", exc_info=True)
+
+             if attempt < self.retry_attempts:
+                 backoff_time = 2 ** attempt  # exponential backoff
+                 logger.warning(f"**ModelRunner** Retrying in {backoff_time} seconds...")
+                 await asyncio.sleep(backoff_time)
+             else:
+                 logger.error(f"**ModelRunner** Model {self.model_id} failed to initialize after {self.retry_attempts} attempts.")
+                 # TODO: What is the behavior here? Remove it locally?
+         return False
+
+     async def predict(self, argument_type: DataType, argument_value: bytes):
+         logger.debug(f"**ModelRunner** Doing prediction of model_id:{self.model_id}, name:{self.model_name}, argument_type:{argument_type}")
+         prediction_request = InferRequest(type=argument_type, argument=argument_value)
+
+         response = await self.grpc_stub.Infer(prediction_request)
+
+         return decode_data(response.prediction, response.type)
+
+     def close_sync(self):
+         # Synchronous wrapper used by the atexit hook; no loop is running at exit.
+         if not self.closed:
+             asyncio.run(self.close())
+
+     async def close(self):
+         self.closed = True
+         if self.grpc_channel:
+             await self.grpc_channel.close()
+         logger.debug(f"**ModelRunner** Model runner {self.model_id} grpc connection closed")
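The retry loop in `init()` uses `2 ** attempt` exponential backoff. A standalone sketch of that pattern (the `connect` callable and `ConnectionError` are stand-ins for the gRPC setup; the real sleep is skipped so the sketch runs instantly):

```python
import asyncio

async def init_with_retry(connect, retry_attempts: int = 5):
    """Retry an async connect() with 2**attempt backoff, mirroring init()."""
    delays = []
    for attempt in range(1, retry_attempts + 1):
        try:
            await connect()
            return True, delays
        except ConnectionError:
            if attempt < retry_attempts:
                delay = 2 ** attempt  # 2, 4, 8, ... seconds in the real client
                delays.append(delay)
                await asyncio.sleep(0)  # sketch only: skip the actual wait
    return False, delays

calls = {"n": 0}

async def flaky():
    # Stand-in endpoint that succeeds on the third attempt
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("not ready")

ok, delays = asyncio.run(init_with_retry(flaky))
```

As in `init()`, a success returns immediately and failures back off with doubling delays until the attempt budget is exhausted.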
@@ -0,0 +1,29 @@
+ syntax = "proto3";
+
+ import "google/protobuf/empty.proto";
+
+ service ModelRunner {
+   rpc InferStream(stream InferRequest) returns (stream InferResponse);
+   rpc Infer(InferRequest) returns (InferResponse);
+   rpc Setup(google.protobuf.Empty) returns (google.protobuf.Empty);
+   rpc Reinitialize(google.protobuf.Empty) returns (google.protobuf.Empty);
+ }
+
+ enum DataType {
+   DOUBLE = 0;
+   INT = 1;
+   STRING = 2;
+   PARQUET = 3;
+   ARROW = 4;
+   JSON = 5;
+ }
+
+ message InferRequest {
+   DataType type = 1;
+   bytes argument = 2;
+ }
+
+ message InferResponse {
+   DataType type = 1;
+   bytes prediction = 2;
+ }
@@ -0,0 +1,43 @@
+ # -*- coding: utf-8 -*-
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
+ # NO CHECKED-IN PROTOBUF GENCODE
+ # source: model_runner.proto
+ # Protobuf Python Version: 5.29.0
+ """Generated protocol buffer code."""
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import descriptor_pool as _descriptor_pool
+ from google.protobuf import runtime_version as _runtime_version
+ from google.protobuf import symbol_database as _symbol_database
+ from google.protobuf.internal import builder as _builder
+ _runtime_version.ValidateProtobufRuntimeVersion(
+     _runtime_version.Domain.PUBLIC,
+     5,
+     29,
+     0,
+     '',
+     'model_runner.proto'
+ )
+ # @@protoc_insertion_point(imports)
+
+ _sym_db = _symbol_database.Default()
+
+
+ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+
+
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12model_runner.proto\x1a\x1bgoogle/protobuf/empty.proto\"9\n\x0cInferRequest\x12\x17\n\x04type\x18\x01 \x01(\x0e\x32\t.DataType\x12\x10\n\x08\x61rgument\x18\x02 \x01(\x0c\"<\n\rInferResponse\x12\x17\n\x04type\x18\x01 \x01(\x0e\x32\t.DataType\x12\x12\n\nprediction\x18\x02 \x01(\x0c*M\n\x08\x44\x61taType\x12\n\n\x06\x44OUBLE\x10\x00\x12\x07\n\x03INT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x0b\n\x07PARQUET\x10\x03\x12\t\n\x05\x41RROW\x10\x04\x12\x08\n\x04JSON\x10\x05\x32\xe0\x01\n\x0bModelRunner\x12\x30\n\x0bInferStream\x12\r.InferRequest\x1a\x0e.InferResponse(\x01\x30\x01\x12&\n\x05Infer\x12\r.InferRequest\x1a\x0e.InferResponse\x12\x37\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\x12>\n\x0cReinitialize\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Emptyb\x06proto3')
+
+ _globals = globals()
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model_runner_pb2', _globals)
+ if not _descriptor._USE_C_DESCRIPTORS:
+     DESCRIPTOR._loaded_options = None
+     _globals['_DATATYPE']._serialized_start=172
+     _globals['_DATATYPE']._serialized_end=249
+     _globals['_INFERREQUEST']._serialized_start=51
+     _globals['_INFERREQUEST']._serialized_end=108
+     _globals['_INFERRESPONSE']._serialized_start=110
+     _globals['_INFERRESPONSE']._serialized_end=170
+     _globals['_MODELRUNNER']._serialized_start=252
+     _globals['_MODELRUNNER']._serialized_end=476
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,38 @@
+ from google.protobuf import empty_pb2 as _empty_pb2
+ from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import message as _message
+ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
+
+ DESCRIPTOR: _descriptor.FileDescriptor
+
+ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
+     __slots__ = ()
+     DOUBLE: _ClassVar[DataType]
+     INT: _ClassVar[DataType]
+     STRING: _ClassVar[DataType]
+     PARQUET: _ClassVar[DataType]
+     ARROW: _ClassVar[DataType]
+     JSON: _ClassVar[DataType]
+ DOUBLE: DataType
+ INT: DataType
+ STRING: DataType
+ PARQUET: DataType
+ ARROW: DataType
+ JSON: DataType
+
+ class InferRequest(_message.Message):
+     __slots__ = ("type", "argument")
+     TYPE_FIELD_NUMBER: _ClassVar[int]
+     ARGUMENT_FIELD_NUMBER: _ClassVar[int]
+     type: DataType
+     argument: bytes
+     def __init__(self, type: _Optional[_Union[DataType, str]] = ..., argument: _Optional[bytes] = ...) -> None: ...
+
+ class InferResponse(_message.Message):
+     __slots__ = ("type", "prediction")
+     TYPE_FIELD_NUMBER: _ClassVar[int]
+     PREDICTION_FIELD_NUMBER: _ClassVar[int]
+     type: DataType
+     prediction: bytes
+     def __init__(self, type: _Optional[_Union[DataType, str]] = ..., prediction: _Optional[bytes] = ...) -> None: ...
@@ -0,0 +1,227 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+ import warnings
+
+ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+ from . import model_runner_pb2 as model__runner__pb2
+
+ GRPC_GENERATED_VERSION = '1.69.0'
+ GRPC_VERSION = grpc.__version__
+ _version_not_supported = False
+
+ try:
+     from grpc._utilities import first_version_is_lower
+     _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+ except ImportError:
+     _version_not_supported = True
+
+ if _version_not_supported:
+     raise RuntimeError(
+         f'The grpc package installed is at version {GRPC_VERSION},'
+         + f' but the generated code in model_runner_pb2_grpc.py depends on'
+         + f' grpcio>={GRPC_GENERATED_VERSION}.'
+         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+     )
+
+
+ class ModelRunnerStub(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def __init__(self, channel):
+         """Constructor.
+
+         Args:
+             channel: A grpc.Channel.
+         """
+         self.InferStream = channel.stream_stream(
+                 '/ModelRunner/InferStream',
+                 request_serializer=model__runner__pb2.InferRequest.SerializeToString,
+                 response_deserializer=model__runner__pb2.InferResponse.FromString,
+                 _registered_method=True)
+         self.Infer = channel.unary_unary(
+                 '/ModelRunner/Infer',
+                 request_serializer=model__runner__pb2.InferRequest.SerializeToString,
+                 response_deserializer=model__runner__pb2.InferResponse.FromString,
+                 _registered_method=True)
+         self.Setup = channel.unary_unary(
+                 '/ModelRunner/Setup',
+                 request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+                 response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+                 _registered_method=True)
+         self.Reinitialize = channel.unary_unary(
+                 '/ModelRunner/Reinitialize',
+                 request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+                 response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+                 _registered_method=True)
+
+
+ class ModelRunnerServicer(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def InferStream(self, request_iterator, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+     def Infer(self, request, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+     def Setup(self, request, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+     def Reinitialize(self, request, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+
+ def add_ModelRunnerServicer_to_server(servicer, server):
+     rpc_method_handlers = {
+             'InferStream': grpc.stream_stream_rpc_method_handler(
+                     servicer.InferStream,
+                     request_deserializer=model__runner__pb2.InferRequest.FromString,
+                     response_serializer=model__runner__pb2.InferResponse.SerializeToString,
+             ),
+             'Infer': grpc.unary_unary_rpc_method_handler(
+                     servicer.Infer,
+                     request_deserializer=model__runner__pb2.InferRequest.FromString,
+                     response_serializer=model__runner__pb2.InferResponse.SerializeToString,
+             ),
+             'Setup': grpc.unary_unary_rpc_method_handler(
+                     servicer.Setup,
+                     request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+                     response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+             ),
+             'Reinitialize': grpc.unary_unary_rpc_method_handler(
+                     servicer.Reinitialize,
+                     request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+                     response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+             ),
+     }
+     generic_handler = grpc.method_handlers_generic_handler(
+             'ModelRunner', rpc_method_handlers)
+     server.add_generic_rpc_handlers((generic_handler,))
+     server.add_registered_method_handlers('ModelRunner', rpc_method_handlers)
+
+
+ # This class is part of an EXPERIMENTAL API.
+ class ModelRunner(object):
+     """Missing associated documentation comment in .proto file."""
+
+     @staticmethod
+     def InferStream(request_iterator,
+             target,
+             options=(),
+             channel_credentials=None,
+             call_credentials=None,
+             insecure=False,
+             compression=None,
+             wait_for_ready=None,
+             timeout=None,
+             metadata=None):
+         return grpc.experimental.stream_stream(
+             request_iterator,
+             target,
+             '/ModelRunner/InferStream',
+             model__runner__pb2.InferRequest.SerializeToString,
+             model__runner__pb2.InferResponse.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
+
+     @staticmethod
+     def Infer(request,
+             target,
+             options=(),
+             channel_credentials=None,
+             call_credentials=None,
+             insecure=False,
+             compression=None,
+             wait_for_ready=None,
+             timeout=None,
+             metadata=None):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             '/ModelRunner/Infer',
+             model__runner__pb2.InferRequest.SerializeToString,
+             model__runner__pb2.InferResponse.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
+
+     @staticmethod
+     def Setup(request,
+             target,
+             options=(),
+             channel_credentials=None,
+             call_credentials=None,
+             insecure=False,
+             compression=None,
+             wait_for_ready=None,
+             timeout=None,
+             metadata=None):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             '/ModelRunner/Setup',
+             google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+             google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
+
+     @staticmethod
+     def Reinitialize(request,
+             target,
+             options=(),
+             channel_credentials=None,
+             call_credentials=None,
+             insecure=False,
+             compression=None,
+             wait_for_ready=None,
+             timeout=None,
+             metadata=None):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             '/ModelRunner/Reinitialize',
+             google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+             google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
@@ -0,0 +1,71 @@
+ import json
+ import struct
+ import io
+ import pandas
+ import pyarrow
+ import pyarrow.parquet as pq
+
+ from model_runner_client.protos.model_runner_pb2 import DataType
+
+
+ # Encoder: converts data to bytes
+ def encode_data(data_type: DataType, data) -> bytes:
+     if data_type == DataType.DOUBLE:
+         return struct.pack("d", data)
+     elif data_type == DataType.INT:
+         return data.to_bytes(8, byteorder="little", signed=True)
+     elif data_type == DataType.STRING:
+         return data.encode("utf-8")
+     elif data_type == DataType.PARQUET:
+         # Convert the DataFrame to an Arrow table, then serialize it as Parquet
+         table = pyarrow.Table.from_pandas(data)
+         sink = io.BytesIO()
+         pq.write_table(table, sink)
+         return sink.getvalue()
+     elif data_type == DataType.JSON:
+         try:
+             json_data = json.dumps(data)  # Convert the object to a JSON string
+             return json_data.encode("utf-8")  # Return the JSON string as bytes
+         except TypeError as e:
+             raise ValueError(f"Data cannot be serialized to JSON: {e}")
+     else:
+         raise ValueError(f"Unsupported data type: {data_type}")
+
+
+ # Decoder: converts bytes to data
+ def decode_data(data_bytes: bytes, data_type: DataType):
+     if data_type == DataType.DOUBLE:
+         return struct.unpack("d", data_bytes)[0]
+     elif data_type == DataType.INT:
+         return int.from_bytes(data_bytes, byteorder="little", signed=True)
+     elif data_type == DataType.STRING:
+         return data_bytes.decode("utf-8")
+     elif data_type == DataType.PARQUET:
+         buffer = io.BytesIO(data_bytes)
+         return pq.read_table(buffer).to_pandas()
+     elif data_type == DataType.JSON:
+         try:
+             json_data = data_bytes.decode("utf-8")
+             return json.loads(json_data)
+         except json.JSONDecodeError as e:
+             raise ValueError(f"Failed to decode JSON data: {e}")
+     else:
+         raise ValueError(f"Unsupported data type: {data_type}")
+
+
+ def detect_data_type(data) -> DataType:
+     """
+     Detects the data type based on the Python object and returns
+     the corresponding DataType enum.
+     """
+     if isinstance(data, float):
+         return DataType.DOUBLE
+     elif isinstance(data, int):
+         return DataType.INT
+     elif isinstance(data, str):
+         return DataType.STRING
+     elif isinstance(data, pandas.DataFrame):
+         return DataType.PARQUET
+     elif isinstance(data, (dict, list)):
+         return DataType.JSON
+     else:
+         raise ValueError(f"Unsupported data type: {type(data)}")
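The wire formats used by the codec above are simple: DOUBLE is an 8-byte IEEE-754 value, INT an 8-byte little-endian signed integer, and STRING/JSON are UTF-8 bytes. A minimal standard-library round-trip sketch of those formats (the `DOUBLE`/`INT`/`STRING`/`JSON` constants here are stand-ins for the protobuf `DataType` enum, not the real values):

```python
import json
import struct

# Stand-ins for the protobuf DataType enum (illustrative only).
DOUBLE, INT, STRING, JSON = range(4)

def encode(data_type, data):
    if data_type == DOUBLE:
        return struct.pack("d", data)                    # 8-byte IEEE-754 double
    if data_type == INT:
        return data.to_bytes(8, "little", signed=True)   # 8-byte little-endian int
    if data_type == STRING:
        return data.encode("utf-8")
    if data_type == JSON:
        return json.dumps(data).encode("utf-8")
    raise ValueError(f"Unsupported data type: {data_type}")

def decode(data_type, payload):
    if data_type == DOUBLE:
        return struct.unpack("d", payload)[0]
    if data_type == INT:
        return int.from_bytes(payload, "little", signed=True)
    if data_type == STRING:
        return payload.decode("utf-8")
    if data_type == JSON:
        return json.loads(payload.decode("utf-8"))
    raise ValueError(f"Unsupported data type: {data_type}")

# Each value survives an encode/decode round trip.
for data_type, value in [(DOUBLE, 1.5), (INT, -42), (STRING, "ok"), (JSON, {"a": [1, 2]})]:
    assert decode(data_type, encode(data_type, value)) == value
```

The PARQUET path is omitted here because it depends on pyarrow rather than the wire format alone.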
@@ -0,0 +1,134 @@
+ import asyncio
+ import atexit
+ import json
+ import logging
+
+ import websockets
+ from websockets import State
+
+ logger = logging.getLogger("model_runner_client")
+
+
+ class WebsocketClient:
+     def __init__(self, host, port, crunch_id, event_handler=None):
+         """
+         WebsocketClient constructor.
+
+         :param host: WebSocket server host.
+         :param port: WebSocket server port.
+         :param crunch_id: Crunch ID used to connect to the server.
+         :param event_handler: Optional handler for WebSocket events (delegated to ModelCluster).
+         """
+         self.retry_interval = 10
+         self.max_retries = 5
+         self.listening = None
+         self.host = host
+         self.port = port
+         self.crunch_id = crunch_id
+         self.websocket = None
+         self.event_handler = event_handler  # Delegate to ModelCluster
+         # Ensure the connection is closed cleanly at interpreter exit.
+         atexit.register(self.disconnect_sync)
+
+     async def connect(self):
+         """
+         Establish a WebSocket connection, retrying up to max_retries times.
+         """
+         retry_count = 0
+         while self.max_retries > retry_count:
+             try:
+                 uri = f"ws://{self.host}:{self.port}/{self.crunch_id}"
+                 logger.info(f"Connecting to WebSocket server at {uri}")
+                 self.websocket = await websockets.connect(uri)
+                 logger.info(f"Connected to WebSocket server at {uri}")
+                 break
+             except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError, OSError, asyncio.TimeoutError) as e:
+                 logger.warning(f"Connection error ({e.__class__.__name__}): Retrying in {self.retry_interval} seconds...")
+             except Exception as e:
+                 logger.error(f"Unexpected error ({e.__class__.__name__}): {e}", exc_info=True)
+                 logger.warning(f"Retrying in {self.retry_interval} seconds...")
+
+             await asyncio.sleep(self.retry_interval)
+             retry_count += 1
+         else:
+             logger.error(f"Failed to connect after {self.max_retries} attempts.")
+
+     async def init(self):
+         """
+         Receive the first message (the init event) from the WebSocket server.
+         """
+         # Retrying here makes no sense: init() runs right after connect(), which already handles retries.
+         try:
+             message = await self.websocket.recv()
+             await self.handle_event(message)
+         except websockets.exceptions.ConnectionClosed:
+             logger.warning("WebSocket connection closed by the server.")
+         except Exception as e:
+             logger.error(f"Error while listening to WebSocket messages: {e}", exc_info=True)
+
+     async def listen(self):
+         """
+         Listen for messages from the WebSocket server.
+         """
+         self.listening = True
+         while self.listening:
+             try:
+                 if not await self.is_connected():
+                     await self.connect()
+
+                 logger.info("Listening for messages...")
+                 async for message in self.websocket:
+                     await self.handle_event(message)
+             except (websockets.exceptions.ConnectionClosed, asyncio.TimeoutError):
+                 logger.warning(f"Connection lost. Retrying in {self.retry_interval} seconds...")
+                 await asyncio.sleep(self.retry_interval)
+             except Exception as e:
+                 logger.error(f"An unexpected error occurred: {e}", exc_info=True)
+                 logger.warning(f"Retrying in {self.retry_interval} seconds...")
+                 await asyncio.sleep(self.retry_interval)
+             finally:
+                 # Clear the websocket reference if the connection is no longer open.
+                 if self.websocket and self.websocket.state != State.OPEN:
+                     self.websocket = None
+
+     async def is_connected(self):
+         """
+         Check if the WebSocket connection is open.
+         :return: True if the client is connected, False otherwise.
+         """
+         return self.websocket is not None and self.websocket.state == State.OPEN
+
+     async def handle_event(self, message):
+         """
+         Handle incoming WebSocket messages and forward them to the event handler.
+
+         :param message: Message received from the server.
+         """
+         try:
+             event = json.loads(message)
+             event_type = event.get("event")
+             data = event.get("data")
+
+             logger.debug(f"Received event: {event_type}, data: {data}")
+
+             # Delegate to the event handler, if available
+             if self.event_handler:
+                 await self.event_handler(event_type, data)
+             else:
+                 logger.warning("No event handler defined. Event will be ignored.")
+         except json.JSONDecodeError:
+             logger.error(f"Failed to decode WebSocket message: {message}")
+
+     def disconnect_sync(self):
+         # Synchronous wrapper used at interpreter exit, where no event loop is running.
+         asyncio.run(self.disconnect())
+
+     async def disconnect(self):
+         """
+         Close the WebSocket connection.
+         """
+         self.listening = False
+         if self.websocket:
+             await self.websocket.close()
+             logger.info("WebSocket connection closed.")
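As `handle_event` shows, every server message is a JSON object of the form `{"event": ..., "data": ...}`, forwarded to the coordinator-supplied handler as an `(event_type, data)` pair. A self-contained sketch of such a handler and the dispatch it receives (the `model_joined`/`model_left` event names and `model_id` field are illustrative assumptions, not the documented protocol):

```python
import asyncio
import json

class ClusterTracker:
    """Toy async event handler mirroring the (event_type, data) callback contract."""

    def __init__(self):
        self.models = {}

    async def __call__(self, event_type, data):
        # Hypothetical event names; the real coordinator defines its own.
        if event_type == "model_joined":
            self.models[data["model_id"]] = data
        elif event_type == "model_left":
            self.models.pop(data["model_id"], None)

async def main():
    tracker = ClusterTracker()
    # Raw messages as they would arrive over the WebSocket.
    raw_messages = [
        '{"event": "model_joined", "data": {"model_id": "m1", "host": "10.0.0.5"}}',
        '{"event": "model_joined", "data": {"model_id": "m2", "host": "10.0.0.6"}}',
        '{"event": "model_left", "data": {"model_id": "m1"}}',
    ]
    # Same parse-and-forward step WebsocketClient.handle_event performs.
    for message in raw_messages:
        event = json.loads(message)
        await tracker(event.get("event"), event.get("data"))
    return sorted(tracker.models)

remaining = asyncio.run(main())
assert remaining == ["m2"]
```

In real use the tracker instance would be passed as `event_handler` to `WebsocketClient`, which invokes it from `handle_event` for each incoming message.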
@@ -0,0 +1,27 @@
+ [tool.poetry]
+ name = "model-runner-client"
+ version = "0.1.0"
+ description = "Model Runner Client is a Python library that allows Coordinators to manage real-time model synchronization and perform concurrent predictions on distributed model nodes within a Crunch."
+ authors = ["boutrig abdennour <abdennour.boutrig@crunchdao.com>"]
+ readme = "README.md"
+
+ [tool.poetry.urls]
+ "Homepage" = "https://github.com/crunchdao/model-runner-client"
+ "Bug Tracker" = "https://github.com/crunchdao/model-runner-client/issues"
+
+ [tool.poetry.dependencies]
+ python = "^3.11"
+ grpcio = "^1.70.0"
+ pyarrow = "^19.0.0"
+ protobuf = "^5.29.3"
+ pandas = "^2.2.3"
+ websockets = "^14.2"
+
+ [tool.poetry.group.dev.dependencies]
+ pytest = "^8.3.4"
+ pytest-asyncio = "^0.25.3"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"