gen-worker 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: worker_scheduler.proto
# Protobuf Python Version: 5.27.4
#
# NOTE(review): this module is protoc output. Change worker_scheduler.proto and
# regenerate rather than hand-editing: the serialized descriptor below and the
# per-message byte offsets recorded at the bottom of the file must stay exactly
# consistent with each other.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time compatibility check of the installed protobuf runtime against
# the gencode version this file was produced with (5.27.4, PUBLIC domain,
# no version suffix).
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    27,
    4,
    '',
    'worker_scheduler.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Register the serialized FileDescriptorProto for worker_scheduler.proto
# (proto3, package "scheduler") with the default descriptor pool. It declares
# the messages WorkerResources (including a map<string, int32>
# function_concurrency field), FunctionSchema, WorkerRegistration,
# LoadModelCommand, UnloadModelCommand, InterruptTaskCommand,
# TaskExecutionRequest, TaskExecutionResult, SpawnTasksRequest, WorkerEvent,
# LoadModelResult, UnloadModelResult, DeploymentModelConfig and the
# WorkerSchedulerMessage envelope (fields grouped in a "msg" oneof), plus the
# service SchedulerWorkerService with one client- and server-streaming RPC,
# ConnectWorker(WorkerSchedulerMessage) -> WorkerSchedulerMessage.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16worker_scheduler.proto\x12\tscheduler\"\x82\x05\n\x0fWorkerResources\x12\x11\n\tworker_id\x18\x01 \x01(\t\x12\x11\n\tcpu_cores\x18\x02 \x01(\x05\x12\x14\n\x0cmemory_bytes\x18\x03 \x01(\x03\x12\x11\n\tgpu_count\x18\x04 \x01(\x05\x12\x18\n\x10gpu_memory_bytes\x18\x05 \x01(\x03\x12\x1b\n\x13\x61vailable_functions\x18\x06 \x03(\t\x12\x18\n\x10\x61vailable_models\x18\x07 \x03(\t\x12\x1e\n\x16supports_model_loading\x18\x08 \x01(\x08\x12\x15\n\rdeployment_id\x18\t \x01(\t\x12\x15\n\rrunpod_pod_id\x18\n \x01(\t\x12\x13\n\x0bgpu_is_busy\x18\x0b \x01(\x08\x12\x1d\n\x15gpu_memory_used_bytes\x18\x0c \x01(\x03\x12\x10\n\x08gpu_name\x18\r \x01(\t\x12\x12\n\ngpu_driver\x18\x0e \x01(\t\x12\x1d\n\x15gpu_memory_free_bytes\x18\x0f \x01(\x03\x12\x17\n\x0fmax_concurrency\x18\x10 \x01(\x05\x12Q\n\x14\x66unction_concurrency\x18\x11 \x03(\x0b\x32\x33.scheduler.WorkerResources.FunctionConcurrencyEntry\x12\x14\n\x0c\x63uda_version\x18\x12 \x01(\t\x12\x15\n\rtorch_version\x18\x13 \x01(\t\x12\x33\n\x10\x66unction_schemas\x18\x14 \x03(\x0b\x32\x19.scheduler.FunctionSchema\x1a:\n\x18\x46unctionConcurrencyEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\"U\n\x0e\x46unctionSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x19\n\x11input_schema_json\x18\x02 \x01(\x0c\x12\x1a\n\x12output_schema_json\x18\x03 \x01(\x0c\"Y\n\x12WorkerRegistration\x12-\n\tresources\x18\x01 \x01(\x0b\x32\x1a.scheduler.WorkerResources\x12\x14\n\x0cis_heartbeat\x18\x02 \x01(\x08\"$\n\x10LoadModelCommand\x12\x10\n\x08model_id\x18\x01 \x01(\t\"&\n\x12UnloadModelCommand\x12\x10\n\x08model_id\x18\x01 \x01(\t\"&\n\x14InterruptTaskCommand\x12\x0e\n\x06run_id\x18\x01 \x01(\t\"\xa5\x01\n\x14TaskExecutionRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x15\n\rfunction_name\x18\x02 \x01(\t\x12\x15\n\rinput_payload\x18\x03 \x01(\x0c\x12\x17\n\x0frequired_models\x18\x04 \x03(\t\x12\x12\n\ntimeout_ms\x18\x05 \x01(\x03\x12\x11\n\ttenant_id\x18\x06 \x01(\t\x12\x0f\n\x07user_id\x18\x07 \x01(\t\"e\n\x13TaskExecutionResult\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x16\n\x0eoutput_payload\x18\x03 \x01(\x0c\x12\x15\n\rerror_message\x18\x04 \x01(\t\"`\n\x11SpawnTasksRequest\x12\x15\n\rparent_run_id\x18\x01 \x01(\t\x12\x34\n\x0b\x63hild_tasks\x18\x02 \x03(\x0b\x32\x1f.scheduler.TaskExecutionRequest\"G\n\x0bWorkerEvent\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x12\n\nevent_type\x18\x02 \x01(\t\x12\x14\n\x0cpayload_json\x18\x03 \x01(\x0c\"K\n\x0fLoadModelResult\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\"M\n\x11UnloadModelResult\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\"4\n\x15\x44\x65ploymentModelConfig\x12\x1b\n\x13supported_model_ids\x18\x01 \x03(\t\"\x9b\x05\n\x16WorkerSchedulerMessage\x12<\n\x13worker_registration\x18\x01 \x01(\x0b\x32\x1d.scheduler.WorkerRegistrationH\x00\x12\x34\n\nrun_result\x18\x02 \x01(\x0b\x32\x1e.scheduler.TaskExecutionResultH\x00\x12\x33\n\x0bspawn_tasks\x18\x03 \x01(\x0b\x32\x1c.scheduler.SpawnTasksRequestH\x00\x12\x37\n\x11load_model_result\x18\x04 \x01(\x0b\x32\x1a.scheduler.LoadModelResultH\x00\x12;\n\x13unload_model_result\x18\x05 \x01(\x0b\x32\x1c.scheduler.UnloadModelResultH\x00\x12.\n\x0cworker_event\x18\x06 \x01(\x0b\x32\x16.scheduler.WorkerEventH\x00\x12\x36\n\x0brun_request\x18\n \x01(\x0b\x32\x1f.scheduler.TaskExecutionRequestH\x00\x12\x35\n\x0eload_model_cmd\x18\x0b \x01(\x0b\x32\x1b.scheduler.LoadModelCommandH\x00\x12\x39\n\x10unload_model_cmd\x18\x0c \x01(\x0b\x32\x1d.scheduler.UnloadModelCommandH\x00\x12<\n\x11interrupt_run_cmd\x18\r \x01(\x0b\x32\x1f.scheduler.InterruptTaskCommandH\x00\x12\x43\n\x17\x64\x65ployment_model_config\x18\x0e \x01(\x0b\x32 .scheduler.DeploymentModelConfigH\x00\x42\x05\n\x03msg2s\n\x16SchedulerWorkerService\x12Y\n\rConnectWorker\x12!.scheduler.WorkerSchedulerMessage\x1a!.scheduler.WorkerSchedulerMessage(\x01\x30\x01\x42\x41Z?github.com/cozy-creator/gen-orchestrator/pkg/pb/workerSchedulerb\x06proto3')

# Materialise the descriptors and message classes from DESCRIPTOR into this
# module's namespace (e.g. WorkerSchedulerMessage, and the private
# _WORKERRESOURCES-style descriptor handles referenced below).
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'worker_scheduler_pb2', _globals)
# Only the pure-Python descriptor implementation needs the following: the
# serialized options (go_package for the file, map_entry for the synthesized
# FunctionConcurrencyEntry type) and the start/end byte offsets of every
# message and the service inside the serialized bytes above.
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z?github.com/cozy-creator/gen-orchestrator/pkg/pb/workerScheduler'
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._loaded_options = None
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_options = b'8\001'
  _globals['_WORKERRESOURCES']._serialized_start=38
  _globals['_WORKERRESOURCES']._serialized_end=680
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_start=622
  _globals['_WORKERRESOURCES_FUNCTIONCONCURRENCYENTRY']._serialized_end=680
  _globals['_FUNCTIONSCHEMA']._serialized_start=682
  _globals['_FUNCTIONSCHEMA']._serialized_end=767
  _globals['_WORKERREGISTRATION']._serialized_start=769
  _globals['_WORKERREGISTRATION']._serialized_end=858
  _globals['_LOADMODELCOMMAND']._serialized_start=860
  _globals['_LOADMODELCOMMAND']._serialized_end=896
  _globals['_UNLOADMODELCOMMAND']._serialized_start=898
  _globals['_UNLOADMODELCOMMAND']._serialized_end=936
  _globals['_INTERRUPTTASKCOMMAND']._serialized_start=938
  _globals['_INTERRUPTTASKCOMMAND']._serialized_end=976
  _globals['_TASKEXECUTIONREQUEST']._serialized_start=979
  _globals['_TASKEXECUTIONREQUEST']._serialized_end=1144
  _globals['_TASKEXECUTIONRESULT']._serialized_start=1146
  _globals['_TASKEXECUTIONRESULT']._serialized_end=1247
  _globals['_SPAWNTASKSREQUEST']._serialized_start=1249
  _globals['_SPAWNTASKSREQUEST']._serialized_end=1345
  _globals['_WORKEREVENT']._serialized_start=1347
  _globals['_WORKEREVENT']._serialized_end=1418
  _globals['_LOADMODELRESULT']._serialized_start=1420
  _globals['_LOADMODELRESULT']._serialized_end=1495
  _globals['_UNLOADMODELRESULT']._serialized_start=1497
  _globals['_UNLOADMODELRESULT']._serialized_end=1574
  _globals['_DEPLOYMENTMODELCONFIG']._serialized_start=1576
  _globals['_DEPLOYMENTMODELCONFIG']._serialized_end=1628
  _globals['_WORKERSCHEDULERMESSAGE']._serialized_start=1631
  _globals['_WORKERSCHEDULERMESSAGE']._serialized_end=2298
  _globals['_SCHEDULERWORKERSERVICE']._serialized_start=2300
  _globals['_SCHEDULERWORKERSERVICE']._serialized_end=2415
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,100 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
# `warnings` is emitted by the generator; nothing in this module uses it.
import warnings

# NOTE(review): this is an absolute import of the sibling generated module, so
# it only resolves when the directory containing worker_scheduler_pb2.py is on
# sys.path (presumably arranged by gen_worker/pb/__init__.py -- confirm). Make
# sure the rest of the package imports the same module object, otherwise two
# distinct WorkerSchedulerMessage classes can coexist.
import worker_scheduler_pb2 as worker__scheduler__pb2

# grpcio version this file was generated with; importing this module under an
# older grpcio raises RuntimeError below.
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    # True when the installed grpcio is older than GRPC_GENERATED_VERSION.
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The comparison helper is missing from this grpcio; treat it as unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + ' but the generated code in worker_scheduler_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class SchedulerWorkerServiceStub(object):
    """The single gRPC service with a bidirectional streaming method.

    Client-side stub for ``scheduler.SchedulerWorkerService``.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # ConnectWorker is a stream-stream RPC. Both directions carry
        # WorkerSchedulerMessage, encoded/decoded with the generated class.
        self.ConnectWorker = channel.stream_stream(
                '/scheduler.SchedulerWorkerService/ConnectWorker',
                request_serializer=worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
                response_deserializer=worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
                _registered_method=True)


class SchedulerWorkerServiceServicer(object):
    """The single gRPC service with a bidirectional streaming method.

    Server-side base class; subclass it and override ``ConnectWorker``.
    """

    def ConnectWorker(self, request_iterator, context):
        """Missing associated documentation comment in .proto file.

        ``request_iterator`` yields the client's WorkerSchedulerMessage
        stream; an override produces the response WorkerSchedulerMessage
        stream. The default marks the call UNIMPLEMENTED and raises
        NotImplementedError.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_SchedulerWorkerServiceServicer_to_server(servicer, server):
    """Register ``servicer.ConnectWorker`` on ``server``.

    The method is exposed as the stream-stream RPC
    ``/scheduler.SchedulerWorkerService/ConnectWorker``, with
    WorkerSchedulerMessage used for both requests and responses.
    """
    rpc_method_handlers = {
            'ConnectWorker': grpc.stream_stream_rpc_method_handler(
                    servicer.ConnectWorker,
                    request_deserializer=worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
                    response_serializer=worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'scheduler.SchedulerWorkerService', rpc_method_handlers)
    # The same handler mapping is installed twice: once as a generic handler
    # and once as registered method handlers for the service.
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('scheduler.SchedulerWorkerService', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class SchedulerWorkerService(object):
    """The single gRPC service with a bidirectional streaming method.

    Static, stub-less entry points built on ``grpc.experimental``.
    """

    @staticmethod
    def ConnectWorker(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call ConnectWorker on ``target`` via ``grpc.experimental.stream_stream``.

        ``request_iterator`` supplies outgoing WorkerSchedulerMessage values.
        The return value is whatever ``grpc.experimental.stream_stream``
        returns (presumably an iterator of response messages -- see the grpc
        docs).
        """
        # Positional arguments follow grpc.experimental.stream_stream's own
        # parameter order, where `insecure` precedes `call_credentials`. That
        # differs from the order of this method's parameters.
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            '/scheduler.SchedulerWorkerService/ConnectWorker',
            worker__scheduler__pb2.WorkerSchedulerMessage.SerializeToString,
            worker__scheduler__pb2.WorkerSchedulerMessage.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
gen_worker/py.typed ADDED
File without changes
@@ -0,0 +1 @@
1
+ """Testing helpers for gen_worker."""
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import os
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from gen_worker.model_interface import DownloaderType, ModelManagementInterface
8
+
9
+
class StubModelManager(ModelManagementInterface):
    """Model manager used for end-to-end tests; it never touches a GPU.

    Artifacts are fetched through the downloader handed to
    ``process_supported_models_config`` and cached per model id. "Loading"
    a model into VRAM only records its id. The "pipeline" returned for a
    model is a plain dict with its id and local path.
    """

    def __init__(self) -> None:
        # Download destination, overridable via the MODEL_CACHE_DIR env var.
        self._cache_dir = os.getenv("MODEL_CACHE_DIR", "/models")
        self._downloader: Optional[DownloaderType] = None
        # model_id -> local path reported by the downloader.
        self._models: Dict[str, str] = {}
        # Ids reported as resident, in first-load order, without duplicates.
        self._vram_loaded: List[str] = []
        # Serialises downloads so concurrent callers don't fetch the same
        # model twice.
        self._lock = asyncio.Lock()

    async def process_supported_models_config(
        self,
        supported_model_ids: List[str],
        downloader_instance: Optional[DownloaderType],
    ) -> None:
        """Install ``downloader_instance`` and pre-fetch every listed model.

        Each listed id is downloaded (if not cached already) and marked as
        loaded. An empty or ``None`` id list only installs the downloader.
        """
        self._downloader = downloader_instance
        for wanted in supported_model_ids or ():
            await self._ensure_download(wanted)
            if wanted not in self._vram_loaded:
                self._vram_loaded.append(wanted)

    async def load_model_into_vram(self, model_id: str) -> bool:
        """Ensure ``model_id`` is downloaded and record it as loaded.

        Returns False for an empty id, True otherwise. Download errors
        propagate to the caller.
        """
        if model_id:
            await self._ensure_download(model_id)
            if model_id not in self._vram_loaded:
                self._vram_loaded.append(model_id)
            return True
        return False

    async def get_active_pipeline(self, model_id: str) -> Optional[Any]:
        """Return ``{"model_id", "model_path"}`` for ``model_id``, or None if empty.

        The model is downloaded on demand but not added to the loaded list.
        """
        if not model_id:
            return None
        local_path = await self._ensure_download(model_id)
        pipeline = {
            "model_id": model_id,
            "model_path": local_path,
        }
        return pipeline

    def get_vram_loaded_models(self) -> List[str]:
        """Return a snapshot copy of the ids recorded as loaded."""
        return self._vram_loaded.copy()

    async def _ensure_download(self, model_id: str) -> str:
        """Return the local path for ``model_id``, downloading it once if needed.

        Raises RuntimeError when the model is not cached and no downloader
        has been configured.
        """
        try:
            return self._models[model_id]
        except KeyError:
            pass
        if not self._downloader:
            raise RuntimeError("Model downloader is not configured")
        os.makedirs(self._cache_dir, exist_ok=True)
        async with self._lock:
            # Another task may have finished this download while we waited.
            if model_id in self._models:
                return self._models[model_id]
            fetched = await self._downloader.download(model_id, self._cache_dir)
            self._models[model_id] = fetched
        return fetched
@@ -0,0 +1,4 @@
"""Public API of the ``torch_manager`` package.

Re-exports the default model manager and the config helpers.
"""
from .manager import DefaultModelManager
from .utils.config import (
    load_config,
    set_config,
)

__all__ = [
    "DefaultModelManager",
    "load_config",
    "set_config",
]