gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
|
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: worker_scheduler.proto
# Protobuf Python Version: 5.27.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Fail fast at import time if the installed protobuf runtime is older than
# the gencode version (5.27.4) this module was produced with.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    27,
    4,
    '',
    'worker_scheduler.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Serialized FileDescriptorProto for worker_scheduler.proto. All message
# classes in this module are materialized from this blob at import time by
# the builder calls below; do not edit the bytes by hand.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16worker_scheduler.proto\x12\tscheduler\"\x82\x05\n\x0fWorkerResources\x12\x11\n\tworker_id\x18\x01 \x01(\t\x12\x11\n\tcpu_cores\x18\x02 \x01(\x05\x12\x14\n\x0cmemory_bytes\x18\x03 \x01(\x03\x12\x11\n\tgpu_count\x18\x04 \x01(\x05\x12\x18\n\x10gpu_memory_bytes\x18\x05 \x01(\x03\x12\x1b\n\x13\x61vailable_functions\x18\x06 \x03(\t\x12\x18\n\x10\x61vailable_models\x18\x07 \x03(\t\x12\x1e\n\x16supports_model_loading\x18\x08 \x01(\x08\x12\x15\n\rdeployment_id\x18\t \x01(\t\x12\x15\n\rrunpod_pod_id\x18\n \x01(\t\x12\x13\n\x0bgpu_is_busy\x18\x0b \x01(\x08\x12\x1d\n\x15gpu_memory_used_bytes\x18\x0c \x01(\x03\x12\x10\n\x08gpu_name\x18\r \x01(\t\x12\x12\n\ngpu_driver\x18\x0e \x01(\t\x12\x1d\n\x15gpu_memory_free_bytes\x18\x0f \x01(\x03\x12\x17\n\x0fmax_concurrency\x18\x10 \x01(\x05\x12Q\n\x14\x66unction_concurrency\x18\x11 \x03(\x0b\x32\x33.scheduler.WorkerResources.FunctionConcurrencyEntry\x12\x14\n\x0c\x63uda_version\x18\x12 \x01(\t\x12\x15\n\rtorch_version\x18\x13 \x01(\t\x12\x33\n\x10\x66unction_schemas\x18\x14 \x03(\x0b\x32\x19.scheduler.FunctionSchema\x1a:\n\x18\x46unctionConcurrencyEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\"U\n\x0e\x46unctionSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x19\n\x11input_schema_json\x18\x02 \x01(\x0c\x12\x1a\n\x12output_schema_json\x18\x03 \x01(\x0c\"Y\n\x12WorkerRegistration\x12-\n\tresources\x18\x01 \x01(\x0b\x32\x1a.scheduler.WorkerResources\x12\x14\n\x0cis_heartbeat\x18\x02 \x01(\x08\"$\n\x10LoadModelCommand\x12\x10\n\x08model_id\x18\x01 \x01(\t\"&\n\x12UnloadModelCommand\x12\x10\n\x08model_id\x18\x01 \x01(\t\"&\n\x14InterruptTaskCommand\x12\x0e\n\x06run_id\x18\x01 \x01(\t\"\xa5\x01\n\x14TaskExecutionRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x15\n\rfunction_name\x18\x02 \x01(\t\x12\x15\n\rinput_payload\x18\x03 \x01(\x0c\x12\x17\n\x0frequired_models\x18\x04 \x03(\t\x12\x12\n\ntimeout_ms\x18\x05 \x01(\x03\x12\x11\n\ttenant_id\x18\x06 \x01(\t\x12\x0f\n\x07user_id\x18\x07 \x01(\t\"e\n\x13TaskExecutionResult\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x16\n\x0eoutput_payload\x18\x03 \x01(\x0c\x12\x15\n\rerror_message\x18\x04 \x01(\t\"`\n\x11SpawnTasksRequest\x12\x15\n\rparent_run_id\x18\x01 \x01(\t\x12\x34\n\x0b\x63hild_tasks\x18\x02 \x03(\x0b\x32\x1f.scheduler.TaskExecutionRequest\"G\n\x0bWorkerEvent\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x12\n\nevent_type\x18\x02 \x01(\t\x12\x14\n\x0cpayload_json\x18\x03 \x01(\x0c\"K\n\x0fLoadModelResult\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\"M\n\x11UnloadModelResult\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\"4\n\x15\x44\x65ploymentModelConfig\x12\x1b\n\x13supported_model_ids\x18\x01 \x03(\t\"\x9b\x05\n\x16WorkerSchedulerMessage\x12<\n\x13worker_registration\x18\x01 \x01(\x0b\x32\x1d.scheduler.WorkerRegistrationH\x00\x12\x34\n\nrun_result\x18\x02 \x01(\x0b\x32\x1e.scheduler.TaskExecutionResultH\x00\x12\x33\n\x0bspawn_tasks\x18\x03 \x01(\x0b\x32\x1c.scheduler.SpawnTasksRequestH\x00\x12\x37\n\x11load_model_result\x18\x04 \x01(\x0b\x32\x1a.scheduler.LoadModelResultH\x00\x12;\n\x13unload_model_result\x18\x05 \x01(\x0b\x32\x1c.scheduler.UnloadModelResultH\x00\x12.\n\x0cworker_event\x18\x06 \x01(\x0b\x32\x16.scheduler.WorkerEventH\x00\x12\x36\n\x0brun_request\x18\n \x01(\x0b\x32\x1f.scheduler.TaskExecutionRequestH\x00\x12\x35\n\x0eload_model_cmd\x18\x0b \x01(\x0b\x32\x1b.scheduler.LoadModelCommandH\x00\x12\x39\n\x10unload_model_cmd\x18\x0c \x01(\x0b\x32\x1d.scheduler.UnloadModelCommandH\x00\x12<\n\x11interrupt_run_cmd\x18\r \x01(\x0b\x32\x1f.scheduler.InterruptTaskCommandH\x00\x12\x43\n\x17\x64\x65ployment_model_config\x18\x0e \x01(\x0b\x32 .scheduler.DeploymentModelConfigH\x00\x42\x05\n\x03msg2s\n\x16SchedulerWorkerService\x12Y\n\rConnectWorker\x12!.scheduler.WorkerSchedulerMessage\x1a!.scheduler.WorkerSchedulerMessage(\x01\x30\x01\x42\x41Z?github.com/cozy-creator/gen-orchestrator/pkg/pb/workerSchedulerb\x06proto3')

# Build the Python message classes (WorkerResources, TaskExecutionRequest,
# WorkerSchedulerMessage, ...) into this module's namespace from DESCRIPTOR.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'worker_scheduler_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  # Pure-Python descriptor path only: attach serialized options and record the
  # byte offsets of each message within the serialized file blob above.
  # These offsets are generator-computed; do not adjust them by hand.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z?github.com/cozy-creator/gen-orchestrator/pkg/pb/workerScheduler'
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._loaded_options = None
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_options = b'8\001'
  _globals['_WORKERRESOURCES']._serialized_start=38
  _globals['_WORKERRESOURCES']._serialized_end=680
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_start=622
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_end=680
  _globals['_FUNCTIONSCHEMA']._serialized_start=682
  _globals['_FUNCTIONSCHEMA']._serialized_end=767
  _globals['_WORKERREGISTRATION']._serialized_start=769
  _globals['_WORKERREGISTRATION']._serialized_end=858
  _globals['_LOADMODELCOMMAND']._serialized_start=860
  _globals['_LOADMODELCOMMAND']._serialized_end=896
  _globals['_UNLOADMODELCOMMAND']._serialized_start=898
  _globals['_UNLOADMODELCOMMAND']._serialized_end=936
  _globals['_INTERRUPTTASKCOMMAND']._serialized_start=938
  _globals['_INTERRUPTTASKCOMMAND']._serialized_end=976
  _globals['_TASKEXECUTIONREQUEST']._serialized_start=979
  _globals['_TASKEXECUTIONREQUEST']._serialized_end=1144
  _globals['_TASKEXECUTIONRESULT']._serialized_start=1146
  _globals['_TASKEXECUTIONRESULT']._serialized_end=1247
  _globals['_SPAWNTASKSREQUEST']._serialized_start=1249
  _globals['_SPAWNTASKSREQUEST']._serialized_end=1345
  _globals['_WORKEREVENT']._serialized_start=1347
  _globals['_WORKEREVENT']._serialized_end=1418
  _globals['_LOADMODELRESULT']._serialized_start=1420
  _globals['_LOADMODELRESULT']._serialized_end=1495
  _globals['_UNLOADMODELRESULT']._serialized_start=1497
  _globals['_UNLOADMODELRESULT']._serialized_end=1574
  _globals['_DEPLOYMENTMODELCONFIG']._serialized_start=1576
  _globals['_DEPLOYMENTMODELCONFIG']._serialized_end=1628
  _globals['_WORKERSCHEDULERMESSAGE']._serialized_start=1631
  _globals['_WORKERSCHEDULERMESSAGE']._serialized_end=2298
  _globals['_SCHEDULERWORKERSERVICE']._serialized_start=2300
  _globals['_SCHEDULERWORKERSERVICE']._serialized_end=2415
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

# NOTE(review): absolute sibling import — stock grpcio-tools output assumes a
# flat layout. Inside the gen_worker.pb package this resolves only if the
# package __init__ aliases the module into sys.modules (or sys.path is
# extended); verify against gen_worker/pb/__init__.py before regenerating.
import worker_scheduler_pb2 as worker__scheduler__pb2

# Version of grpcio-tools this file was generated with; the guard below
# refuses to run against an older grpcio runtime.
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # Very old grpcio without the comparison helper — treat as unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + ' but the generated code in worker_scheduler_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
+
class SchedulerWorkerServiceStub(object):
    """The single gRPC service with a bidirectional streaming method.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # ConnectWorker is a bidirectional stream of WorkerSchedulerMessage
        # envelopes in both directions (registration, task requests/results,
        # model load/unload commands, etc. — see worker_scheduler.proto).
        self.ConnectWorker = channel.stream_stream(
                '/scheduler.SchedulerWorkerService/ConnectWorker',
                request_serializer=worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
                response_deserializer=worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
                _registered_method=True)
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SchedulerWorkerServiceServicer(object):
    """The single gRPC service with a bidirectional streaming method.
    """

    def ConnectWorker(self, request_iterator, context):
        """Missing associated documentation comment in .proto file."""
        # Abstract stub: subclasses override this; the default answers
        # UNIMPLEMENTED to any caller.
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def add_SchedulerWorkerServiceServicer_to_server(servicer, server):
    """Register *servicer*'s ConnectWorker handler on *server*."""
    rpc_method_handlers = {
            'ConnectWorker': grpc.stream_stream_rpc_method_handler(
                    servicer.ConnectWorker,
                    request_deserializer=worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
                    response_serializer=worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'scheduler.SchedulerWorkerService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    # Also register under the "registered method" fast path used by newer
    # grpcio runtimes.
    server.add_registered_method_handlers('scheduler.SchedulerWorkerService', rpc_method_handlers)
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# This class is part of an EXPERIMENTAL API.
class SchedulerWorkerService(object):
    """The single gRPC service with a bidirectional streaming method.
    """

    @staticmethod
    def ConnectWorker(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        # One-shot convenience: opens a channel to *target* and invokes the
        # bidirectional ConnectWorker stream without constructing a Stub.
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            '/scheduler.SchedulerWorkerService/ConnectWorker',
            worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
            worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
gen_worker/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Testing helpers for gen_worker."""
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from gen_worker.model_interface import DownloaderType, ModelManagementInterface
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class StubModelManager(ModelManagementInterface):
    """
    Minimal model manager for E2E testing.

    Each model is fetched at most once through the configured downloader;
    the "pipeline" handed back to callers is just a small dict carrying the
    model id and its local path.
    """

    def __init__(self) -> None:
        # Download destination; overridable via MODEL_CACHE_DIR for tests.
        self._cache_dir = os.getenv("MODEL_CACHE_DIR", "/models")
        self._downloader: Optional[DownloaderType] = None
        # model_id -> local path of the downloaded artifact.
        self._models: Dict[str, str] = {}
        # Ids currently marked as resident "in VRAM" (simulated).
        self._vram_loaded: List[str] = []
        # Serializes downloads so concurrent requests fetch a model once.
        self._lock = asyncio.Lock()

    async def process_supported_models_config(
        self,
        supported_model_ids: List[str],
        downloader_instance: Optional[DownloaderType],
    ) -> None:
        """Record the downloader and eagerly fetch every supported model."""
        self._downloader = downloader_instance
        for mid in supported_model_ids or []:
            await self._ensure_download(mid)
            if mid not in self._vram_loaded:
                self._vram_loaded.append(mid)

    async def load_model_into_vram(self, model_id: str) -> bool:
        """Download *model_id* if needed and mark it as loaded."""
        if not model_id:
            return False
        await self._ensure_download(model_id)
        if model_id not in self._vram_loaded:
            self._vram_loaded.append(model_id)
        return True

    async def get_active_pipeline(self, model_id: str) -> Optional[Any]:
        """Return a stand-in pipeline dict for *model_id*, or None if empty."""
        if not model_id:
            return None
        return {
            "model_id": model_id,
            "model_path": await self._ensure_download(model_id),
        }

    def get_vram_loaded_models(self) -> List[str]:
        """Snapshot of model ids currently marked as loaded."""
        return self._vram_loaded[:]

    async def _ensure_download(self, model_id: str) -> str:
        """Download *model_id* once (double-checked) and cache its path."""
        cached = self._models.get(model_id)
        if cached is not None:
            return cached
        if not self._downloader:
            raise RuntimeError("Model downloader is not configured")
        os.makedirs(self._cache_dir, exist_ok=True)
        async with self._lock:
            # Re-check under the lock: another coroutine may have finished
            # the same download while we were waiting.
            cached = self._models.get(model_id)
            if cached is None:
                cached = await self._downloader.download(model_id, self._cache_dir)
                self._models[model_id] = cached
            return cached