llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack/providers/inline/inference/meta_reference/parallel_utils.py  (deleted)

@@ -1,353 +0,0 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import copy
import json
import multiprocessing
import os
import tempfile
import time
import uuid
from collections.abc import Callable, Generator
from enum import Enum
from typing import Annotated, Literal

import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_group,
    get_model_parallel_rank,
    get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult

log = get_logger(name=__name__, category="inference")


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"
    end_sentinel = "end_sentinel"
    cancel_sentinel = "cancel_sentinel"
    task_request = "task_request"
    task_response = "task_response"
    exception_response = "exception_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response


class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel


class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    task: tuple[str, list]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
    result: list[GenerationResult]


class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
    error: str


ProcessingMessage = (
    ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
)


class ProcessingMessageWrapper(BaseModel):
    payload: Annotated[
        ProcessingMessage,
        Field(discriminator="type"),
    ]


def mp_rank_0() -> bool:
    return bool(get_model_parallel_rank() == 0)


def encode_msg(msg: ProcessingMessage) -> bytes:
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")


def retrieve_requests(reply_socket_url: str):
    if mp_rank_0():
        context = zmq.Context()
        reply_socket = context.socket(zmq.ROUTER)
        reply_socket.connect(reply_socket_url)

        while True:
            client_id, obj = maybe_get_work(reply_socket)
            if obj is None:
                time.sleep(0.01)
                continue

            ready_response = ReadyResponse()
            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
            break

    def send_obj(obj: ProcessingMessage):
        reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks: list[ProcessingMessage | None] = [None]
        if mp_rank_0():
            client_id, maybe_task_json = maybe_get_work(reply_socket)
            if maybe_task_json is not None:
                task = maybe_parse_message(maybe_task_json)
                # there is still an unknown unclean GeneratorExit happening resulting in a
                # cancel sentinel getting queued _after_ we have finished sending everything :/
                # kind of a hack this is :/
                if task is not None and not isinstance(task, CancelSentinel):
                    tasks = [task]

        torch.distributed.broadcast_object_list(
            tasks,
            src=get_model_parallel_src_rank(),
            group=get_model_parallel_group(),
        )

        task = tasks[0]
        if task is None:
            time.sleep(0.1)
        else:
            try:
                out = yield task
                if out is None:
                    break

                for obj in out:
                    updates: list[ProcessingMessage | None] = [None]
                    if mp_rank_0():
                        _, update_json = maybe_get_work(reply_socket)
                        update = maybe_parse_message(update_json)
                        if isinstance(update, CancelSentinel):
                            updates = [update]
                        else:
                            # only send the update if it's not cancelled otherwise the object sits in the socket
                            # and gets pulled in the next request lol
                            send_obj(TaskResponse(result=obj))

                    torch.distributed.broadcast_object_list(
                        updates,
                        src=get_model_parallel_src_rank(),
                        group=get_model_parallel_group(),
                    )
                    if isinstance(updates[0], CancelSentinel):
                        log.info("quitting generation loop because request was cancelled")
                        break

                if mp_rank_0():
                    send_obj(EndSentinel())
            except Exception as e:
                log.exception("exception in generation loop")

                if mp_rank_0():
                    send_obj(ExceptionResponse(error=str(e)))

    if mp_rank_0():
        send_obj(EndSentinel())


def maybe_get_work(sock: zmq.Socket):
    message = None
    client_id = None
    try:
        client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
        message = obj.decode("utf-8")
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise e

    return client_id, message


def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
    if maybe_json is None:
        return None
    try:
        return parse_message(maybe_json)
    except json.JSONDecodeError:
        return None
    except ValueError:
        return None


def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    return copy.deepcopy(ProcessingMessageWrapper(**data).payload)


def worker_process_entrypoint(
    reply_socket_url: str,
    init_model_cb: Callable,
) -> None:
    model = init_model_cb()
    torch.distributed.barrier()
    time.sleep(1)

    # run the requests co-routine which retrieves requests from the socket
    # and sends responses (we provide) back to the caller
    req_gen = retrieve_requests(reply_socket_url)
    result = None
    while True:
        try:
            task = req_gen.send(result)
            if isinstance(task, EndSentinel):
                break

            assert isinstance(task, TaskRequest), task
            result = model(task.task)
        except StopIteration:
            break

    log.info("[debug] worker process done")


def launch_dist_group(
    reply_socket_url: str,
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
        launch_config = LaunchConfig(
            max_nodes=1,
            min_nodes=1,
            nproc_per_node=model_parallel_size,
            start_method="fork",
            rdzv_backend="c10d",
            rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
            rdzv_configs={"store_type": "file", "timeout": 90},
            max_restarts=0,
            monitor_interval=1,
            run_id=str(uuid.uuid4()),
        )
        elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
            reply_socket_url,
            init_model_cb,
        )


def start_model_parallel_process(
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
):
    context = zmq.Context()
    request_socket = context.socket(zmq.DEALER)

    # Binding the request socket to a random port
    request_socket.bind("tcp://127.0.0.1:0")

    main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)

    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(
        target=launch_dist_group,
        args=(
            main_process_url,
            model_parallel_size,
            init_model_cb,
        ),
        kwargs=kwargs,
    )
    process.start()

    # wait until the model is loaded; rank 0 will send a message to indicate it's ready
    request_socket.send(encode_msg(ReadyRequest()))
    _response = request_socket.recv()
    log.info("Loaded model...")

    return request_socket, process


class ModelParallelProcessGroup:
    def __init__(
        self,
        model_parallel_size: int,
        init_model_cb: Callable,
        **kwargs,
    ):
        self.model_parallel_size = model_parallel_size
        self.init_model_cb = init_model_cb
        self.started = False
        self.running = False

    def start(self):
        assert not self.started, "process group already started"
        self.request_socket, self.process = start_model_parallel_process(
            self.model_parallel_size,
            self.init_model_cb,
        )
        self.started = True

    def stop(self):
        assert self.started, "process group not started"
        if self.process.is_alive():
            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
            self.process.join()
        self.started = False

    def run_inference(
        self,
        req: tuple[str, list],
    ) -> Generator:
        assert not self.running, "inference already running"

        self.running = True
        try:
            self.request_socket.send(encode_msg(TaskRequest(task=req)))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)

                if isinstance(obj, EndSentinel):
                    break

                if isinstance(obj, ExceptionResponse):
                    log.error(f"[debug] got exception {obj.error}")
                    raise Exception(obj.error)

                if isinstance(obj, TaskResponse):
                    yield obj.result

        except GeneratorExit:
            self.request_socket.send(encode_msg(CancelSentinel()))
            while True:
                # drain replies until the worker acknowledges the cancel with an EndSentinel
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)
                if isinstance(obj, EndSentinel):
                    break
        finally:
            self.running = False
The four dist-info files listed above with +0 -0 (WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) have no content changes.
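For reference, the removed parallel_utils.py framed every cross-process message as a pydantic discriminated union (the ProcessingMessageWrapper pattern shown above). The sketch below isolates just that framing so it can be run with pydantic alone, without torch, zmq, or fairscale; the MsgName, MessageWrapper, encode, and decode names are illustrative and are not part of the package.

```python
# Minimal, self-contained sketch of the discriminated-union message framing
# used by the deleted parallel_utils.py. Names here are hypothetical.
from enum import Enum
from typing import Annotated, Literal

from pydantic import BaseModel, Field


class MsgName(str, Enum):
    ready_request = "ready_request"
    task_request = "task_request"


class ReadyRequest(BaseModel):
    type: Literal[MsgName.ready_request] = MsgName.ready_request


class TaskRequest(BaseModel):
    type: Literal[MsgName.task_request] = MsgName.task_request
    task: tuple[str, list]


Message = ReadyRequest | TaskRequest


class MessageWrapper(BaseModel):
    # The "type" field picks the concrete message class on deserialization.
    payload: Annotated[Message, Field(discriminator="type")]


def encode(msg: Message) -> bytes:
    return MessageWrapper(payload=msg).model_dump_json().encode("utf-8")


def decode(raw: bytes) -> Message:
    return MessageWrapper.model_validate_json(raw).payload


if __name__ == "__main__":
    wire = encode(TaskRequest(task=("chat_completion", [{"role": "user", "content": "hi"}])))
    assert isinstance(decode(wire), TaskRequest)
```

Because the `type` field is the discriminator, the wrapper can turn any wire payload back into the right concrete class without a hand-written dispatch table, which is what let the removed module pass ready/task/cancel/exception messages over a single ZeroMQ socket.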