indexify 0.3.19__py3-none-any.whl → 0.3.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indexify/cli/cli.py +12 -0
- indexify/executor/api_objects.py +11 -6
- indexify/executor/blob_store/blob_store.py +69 -0
- indexify/executor/blob_store/local_fs_blob_store.py +48 -0
- indexify/executor/blob_store/metrics/blob_store.py +33 -0
- indexify/executor/blob_store/s3_blob_store.py +88 -0
- indexify/executor/downloader.py +192 -27
- indexify/executor/executor.py +29 -13
- indexify/executor/function_executor/function_executor.py +1 -1
- indexify/executor/function_executor/function_executor_states_container.py +5 -0
- indexify/executor/function_executor/function_executor_status.py +2 -0
- indexify/executor/function_executor/health_checker.py +7 -2
- indexify/executor/function_executor/invocation_state_client.py +4 -2
- indexify/executor/function_executor/single_task_runner.py +2 -0
- indexify/executor/function_executor/task_output.py +8 -1
- indexify/executor/grpc/channel_manager.py +4 -3
- indexify/executor/grpc/function_executor_controller.py +163 -193
- indexify/executor/grpc/metrics/state_reconciler.py +17 -0
- indexify/executor/grpc/metrics/task_controller.py +8 -0
- indexify/executor/grpc/state_reconciler.py +305 -188
- indexify/executor/grpc/state_reporter.py +18 -10
- indexify/executor/grpc/task_controller.py +247 -189
- indexify/executor/metrics/task_reporter.py +17 -0
- indexify/executor/task_reporter.py +217 -94
- indexify/executor/task_runner.py +1 -0
- indexify/proto/executor_api.proto +37 -11
- indexify/proto/executor_api_pb2.py +49 -47
- indexify/proto/executor_api_pb2.pyi +55 -15
- {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/METADATA +2 -1
- {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/RECORD +32 -27
- indexify/executor/grpc/completed_tasks_container.py +0 -26
- {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/WHEEL +0 -0
- {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/entry_points.txt +0 -0
indexify/cli/cli.py
CHANGED
@@ -25,6 +25,9 @@ from rich.theme import Theme
 from tensorlake.functions_sdk.image import Image
 
 from indexify.executor.api_objects import FunctionURI
+from indexify.executor.blob_store.blob_store import BLOBStore
+from indexify.executor.blob_store.local_fs_blob_store import LocalFSBLOBStore
+from indexify.executor.blob_store.s3_blob_store import S3BLOBStore
 from indexify.executor.executor import Executor
 from indexify.executor.executor_flavor import ExecutorFlavor
 from indexify.executor.function_executor.server.subprocess_function_executor_server_factory import (
@@ -197,6 +200,14 @@ def executor(
         )
         exit(1)
 
+    # Enable all available blob stores in OSS because we don't know which one is going to be used.
+    blob_store: BLOBStore = BLOBStore(
+        # Local FS mode is used in tests and in cases when user wants to store data on NFS.
+        local=LocalFSBLOBStore(),
+        # S3 is initiliazed lazily so it's okay to create it even if the user is not going to use it.
+        s3=S3BLOBStore(),
+    )
+
     prometheus_client.Info("cli", "CLI information").info(
         {
             "package": "indexify",
@@ -222,6 +233,7 @@ def executor(
         monitoring_server_host=monitoring_server_host,
         monitoring_server_port=monitoring_server_port,
         enable_grpc_state_reconciler=enable_grpc_state_reconciler,
+        blob_store=blob_store,
     ).run()
 
 
indexify/executor/api_objects.py
CHANGED
@@ -3,6 +3,13 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel
 
 
+class DataPayload(BaseModel):
+    path: str
+    size: int
+    sha256_hash: str
+    content_type: Optional[str] = None
+
+
 class Task(BaseModel):
     id: str
     namespace: str
@@ -16,6 +23,10 @@ class Task(BaseModel):
     "image_uri defines the URI of the image of this task. Optional since some executors do not require it."
     secret_names: Optional[List[str]] = None
     "secret_names defines the names of the secrets to set on function executor. Optional for backward compatibility."
+    graph_payload: Optional[DataPayload] = None
+    input_payload: Optional[DataPayload] = None
+    reducer_input_payload: Optional[DataPayload] = None
+    output_payload_uri_prefix: Optional[str] = None
 
 
 class FunctionURI(BaseModel):
@@ -49,12 +60,6 @@ class TaskResult(BaseModel):
     reducer: bool = False
 
 
-class DataPayload(BaseModel):
-    path: str
-    size: int
-    sha256_hash: str
-
-
 class IngestFnOutputsResponse(BaseModel):
     data_payloads: List[DataPayload]
     stdout: Optional[DataPayload] = None
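
For orientation, a minimal sketch of how the new optional payload fields on Task can be populated. The models below are trimmed copies of the ones in this diff and every value is invented for illustration.

# Hypothetical illustration: trimmed copies of the models above, invented values.
from typing import Optional

from pydantic import BaseModel


class DataPayload(BaseModel):
    path: str
    size: int
    sha256_hash: str
    content_type: Optional[str] = None


class Task(BaseModel):
    id: str
    namespace: str
    graph_payload: Optional[DataPayload] = None
    input_payload: Optional[DataPayload] = None
    reducer_input_payload: Optional[DataPayload] = None
    output_payload_uri_prefix: Optional[str] = None


task = Task(
    id="task-1",
    namespace="default",
    # When these fields are set, the executor reads the data directly from blob
    # storage instead of fetching it from the server (see downloader.py below).
    input_payload=DataPayload(
        path="s3://example-bucket/invocations/inv-1/input",  # invented URI
        size=1024,
        sha256_hash="0" * 64,
        content_type="application/octet-stream",
    ),
    output_payload_uri_prefix="s3://example-bucket/outputs/inv-1",  # invented URI
)
print(task.input_payload.path)
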
indexify/executor/blob_store/blob_store.py
ADDED
@@ -0,0 +1,69 @@
+from typing import Any, Optional
+
+from .local_fs_blob_store import LocalFSBLOBStore
+from .metrics.blob_store import (
+    metric_get_blob_errors,
+    metric_get_blob_latency,
+    metric_get_blob_requests,
+    metric_put_blob_errors,
+    metric_put_blob_latency,
+    metric_put_blob_requests,
+)
+from .s3_blob_store import S3BLOBStore
+
+
+class BLOBStore:
+    """Dispatches generic BLOB store calls to their real backends."""
+
+    def __init__(
+        self, local: Optional[LocalFSBLOBStore] = None, s3: Optional[S3BLOBStore] = None
+    ):
+        """Creates a BLOB store that uses the supplied BLOB stores."""
+        self._local: Optional[LocalFSBLOBStore] = local
+        self._s3: Optional[S3BLOBStore] = s3
+
+    async def get(self, uri: str, logger: Any) -> bytes:
+        """Returns binary value stored in BLOB with the supplied URI.
+
+        Raises Exception on error. Raises KeyError if the BLOB doesn't exist.
+        """
+        with (
+            metric_get_blob_errors.count_exceptions(),
+            metric_get_blob_latency.time(),
+        ):
+            metric_get_blob_requests.inc()
+            if _is_file_uri(uri):
+                self._check_local_is_available()
+                return await self._local.get(uri, logger)
+            else:
+                self._check_s3_is_available()
+                return await self._s3.get(uri, logger)
+
+    async def put(self, uri: str, value: bytes, logger: Any) -> None:
+        """Stores the supplied binary value in a BLOB with the supplied URI.
+
+        Overwrites existing BLOB. Raises Exception on error.
+        """
+        with (
+            metric_put_blob_errors.count_exceptions(),
+            metric_put_blob_latency.time(),
+        ):
+            metric_put_blob_requests.inc()
+            if _is_file_uri(uri):
+                self._check_local_is_available()
+                await self._local.put(uri, value, logger)
+            else:
+                self._check_s3_is_available()
+                await self._s3.put(uri, value, logger)
+
+    def _check_local_is_available(self):
+        if self._local is None:
+            raise RuntimeError("Local file system BLOB store is not available")
+
+    def _check_s3_is_available(self):
+        if self._s3 is None:
+            raise RuntimeError("S3 BLOB store is not available")
+
+
+def _is_file_uri(uri: str) -> bool:
+    return uri.startswith("file://")
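
The dispatcher above picks a backend purely from the URI scheme: "file://" URIs go to the local filesystem store, everything else goes to S3. A usage sketch (not part of the package), assuming indexify 0.3.21 is installed; the stub logger stands in for the structlog-style logger the executor normally passes, and the paths are invented.

import asyncio

from indexify.executor.blob_store.blob_store import BLOBStore
from indexify.executor.blob_store.local_fs_blob_store import LocalFSBLOBStore
from indexify.executor.blob_store.s3_blob_store import S3BLOBStore


class _StubLogger:
    # Stands in for the structlog-style logger the executor passes around.
    def error(self, msg, **kwargs):
        print("error:", msg, kwargs)


async def main() -> None:
    # Mirrors the wiring in cli.py: both backends enabled, S3 client created lazily.
    blob_store = BLOBStore(local=LocalFSBLOBStore(), s3=S3BLOBStore())
    logger = _StubLogger()

    # "file://" URIs must use absolute paths; they are served by LocalFSBLOBStore.
    await blob_store.put("file:///tmp/indexify-demo/blob", b"hello", logger)
    assert await blob_store.get("file:///tmp/indexify-demo/blob", logger) == b"hello"

    # Any other URI (e.g. "s3://bucket/key") is routed to S3BLOBStore, which
    # requires boto3 credentials to be configured.


asyncio.run(main())
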
indexify/executor/blob_store/local_fs_blob_store.py
ADDED
@@ -0,0 +1,48 @@
+import asyncio
+import os
+import os.path
+from typing import Any
+
+
+class LocalFSBLOBStore:
+    """BLOB store that stores BLOBs in local file system."""
+
+    async def get(self, uri: str, logger: Any) -> bytes:
+        """Returns binary value stored in file at the supplied URI.
+
+        The URI must be a file URI (starts with "file://"). The path must be absolute.
+        Raises Exception on error. Raises KeyError if the file doesn't exist.
+        """
+        # Run synchronous code in a thread to not block the event loop.
+        return await asyncio.to_thread(self._sync_get, _path_from_file_uri(uri))
+
+    async def put(self, uri: str, value: bytes, logger: Any) -> None:
+        """Stores the supplied binary value in a file at the supplied URI.
+
+        The URI must be a file URI (starts with "file://"). The path must be absolute.
+        Overwrites existing file. Raises Exception on error.
+        """
+        # Run synchronous code in a thread to not block the event loop.
+        return await asyncio.to_thread(self._sync_put, _path_from_file_uri(uri), value)
+
+    def _sync_get(self, path: str) -> bytes:
+        if not os.path.isabs(path):
+            raise ValueError(f"Path {path} must be absolute")
+
+        if os.path.exists(path):
+            with open(path, mode="rb") as blob_file:
+                return blob_file.read()
+        else:
+            raise KeyError(f"File at {path} does not exist")
+
+    def _sync_put(self, path: str, value: bytes) -> None:
+        if not os.path.isabs(path):
+            raise ValueError(f"Path {path} must be absolute")
+
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        with open(path, mode="wb") as blob_file:
+            blob_file.write(value)
+
+
+def _path_from_file_uri(uri: str) -> str:
+    return uri[7:]  # strip "file://" prefix
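
Two worked examples of the "file://" handling above (paths invented). _path_from_file_uri only strips the prefix, so a URI needs three slashes ("file:///absolute/path") to survive the absolute-path check in _sync_get/_sync_put.

# Standalone copy of the helper above, used only to illustrate the URI-to-path rule.
def _path_from_file_uri(uri: str) -> str:
    return uri[7:]  # strip "file://" prefix


assert _path_from_file_uri("file:///var/lib/indexify/blobs/abc") == "/var/lib/indexify/blobs/abc"
# A two-slash URI yields a relative path, which the store rejects with ValueError.
assert _path_from_file_uri("file://blobs/abc") == "blobs/abc"
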
indexify/executor/blob_store/metrics/blob_store.py
ADDED
@@ -0,0 +1,33 @@
+import prometheus_client
+
+from ...monitoring.metrics import latency_metric_for_fast_operation
+
+metric_get_blob_requests: prometheus_client.Counter = prometheus_client.Counter(
+    "get_blob_requests",
+    "Number of get blob requests",
+)
+metric_get_blob_errors: prometheus_client.Counter = prometheus_client.Counter(
+    "get_blob_request_errors",
+    "Number of get blob request errors",
+)
+metric_get_blob_latency: prometheus_client.Histogram = (
+    latency_metric_for_fast_operation(
+        "get_blob_request",
+        "get blob request",
+    )
+)
+
+metric_put_blob_requests: prometheus_client.Counter = prometheus_client.Counter(
+    "put_blob_requests",
+    "Number of put blob requests",
+)
+metric_put_blob_errors: prometheus_client.Counter = prometheus_client.Counter(
+    "put_blob_request_errors",
+    "Number of put blob request errors",
+)
+metric_put_blob_latency: prometheus_client.Histogram = (
+    latency_metric_for_fast_operation(
+        "put_blob_request",
+        "put blob request",
+    )
+)
indexify/executor/blob_store/s3_blob_store.py
ADDED
@@ -0,0 +1,88 @@
+import asyncio
+from typing import Any, Optional
+
+import boto3
+from botocore.config import Config as BotoConfig
+from botocore.exceptions import ClientError as BotoClientError
+
+_MAX_RETRIES = 3
+
+
+class S3BLOBStore:
+    def __init__(self):
+        self._s3_client: Optional[Any] = None
+
+    def _lazy_create_client(self):
+        """Creates S3 client if it doesn't exist.
+
+        We create the client lazily only if S3 is used.
+        This is because S3 BLOB store is always created by Executor
+        and the creation will fail if user didn't configure S3 credentials and etc.
+        """
+        if self._s3_client is not None:
+            return
+
+        # The credentials and etc are fetched by boto3 library automatically following
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials
+        # This provides a lot of flexibility for the user and follows a well-known and documented logic.
+        self._s3_client = boto3.client(
+            "s3",
+            config=BotoConfig(
+                # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#standard-retry-mode
+                retries={
+                    "max_attempts": _MAX_RETRIES,
+                    "mode": "standard",
+                }
+            ),
+        )
+
+    async def get(self, uri: str, logger: Any) -> bytes:
+        """Returns binary value stored in S3 object at the supplied URI.
+
+        The URI must be S3 URI (starts with "s3://").
+        Raises Exception on error. Raises KeyError if the object doesn't exist.
+        """
+        try:
+            self._lazy_create_client()
+            bucket_name, key = _bucket_name_and_object_key_from_uri(uri)
+            response = await asyncio.to_thread(
+                self._s3_client.get_object, Bucket=bucket_name, Key=key
+            )
+            return response["Body"].read()
+        except BotoClientError as e:
+            logger.error("failed to get S3 object", uri=uri, exc_info=e)
+
+            if e.response["Error"]["Code"] == "NoSuchKey":
+                raise KeyError(f"Object {key} does not exist in bucket {bucket_name}")
+            raise
+        except Exception as e:
+            logger.error("failed to get S3 object", uri=uri, exc_info=e)
+            raise
+
+    async def put(self, uri: str, value: bytes, logger: Any) -> None:
+        """Stores the supplied binary value in a S3 object at the supplied URI.
+
+        The URI must be S3 URI (starts with "s3://").
+        Overwrites existing object. Raises Exception on error.
+        """
+        try:
+            self._lazy_create_client()
+            bucket_name, key = _bucket_name_and_object_key_from_uri(uri)
+            await asyncio.to_thread(
+                self._s3_client.put_object, Bucket=bucket_name, Key=key, Body=value
+            )
+        except Exception as e:
+            logger.error("failed to set S3 object", uri=uri, exc_info=e)
+            raise
+
+
+def _bucket_name_and_object_key_from_uri(uri: str) -> tuple[str, str]:
+    # Example S3 object URI:
+    # s3://test-indexify-server-blob-store-eugene-20250411/225b83f4-2aed-40a7-adee-b7a681f817f2
+    if not uri.startswith("s3://"):
+        raise ValueError(f"S3 URI '{uri}' is missing 's3://' prefix")
+
+    parts = uri[5:].split("/", 1)
+    if len(parts) != 2:
+        raise ValueError(f"Failed parsing bucket name from S3 URI '{uri}'")
+    return parts[0], parts[1]  # bucket_name, key
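
A worked example of the bucket/key split above (the URI is invented). The split is limited to one "/", so the key keeps any further separators.

# Standalone copy of the parser above, shown with an invented URI.
def _bucket_name_and_object_key_from_uri(uri: str) -> tuple[str, str]:
    if not uri.startswith("s3://"):
        raise ValueError(f"S3 URI '{uri}' is missing 's3://' prefix")
    parts = uri[5:].split("/", 1)
    if len(parts) != 2:
        raise ValueError(f"Failed parsing bucket name from S3 URI '{uri}'")
    return parts[0], parts[1]  # bucket_name, key


bucket, key = _bucket_name_and_object_key_from_uri("s3://example-bucket/graphs/g1/payload")
assert bucket == "example-bucket"
assert key == "graphs/g1/payload"
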
indexify/executor/downloader.py
CHANGED
@@ -1,13 +1,20 @@
 import asyncio
 import os
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import httpx
 import nanoid
 from tensorlake.function_executor.proto.function_executor_pb2 import SerializedObject
+from tensorlake.function_executor.proto.message_validator import MessageValidator
 from tensorlake.utils.http_client import get_httpx_client
 
-from .
+from indexify.proto.executor_api_pb2 import DataPayload as DataPayloadProto
+from indexify.proto.executor_api_pb2 import (
+    DataPayloadEncoding,
+)
+
+from .api_objects import DataPayload
+from .blob_store.blob_store import BLOBStore
 from .metrics.downloader import (
     metric_graph_download_errors,
     metric_graph_download_latency,
@@ -27,14 +34,24 @@ from .metrics.downloader import (
 
 class Downloader:
     def __init__(
-        self,
+        self,
+        code_path: str,
+        base_url: str,
+        blob_store: BLOBStore,
+        config_path: Optional[str] = None,
     ):
-        self.
+        self._code_path = code_path
         self._base_url = base_url
         self._client = get_httpx_client(config_path, make_async=True)
+        self._blob_store: BLOBStore = blob_store
 
     async def download_graph(
-        self,
+        self,
+        namespace: str,
+        graph_name: str,
+        graph_version: str,
+        data_payload: Optional[Union[DataPayload, DataPayloadProto]],
+        logger: Any,
     ) -> SerializedObject:
         logger = logger.bind(module=__name__)
         with (
@@ -47,6 +64,7 @@ class Downloader:
                 namespace=namespace,
                 graph_name=graph_name,
                 graph_version=graph_version,
+                data_payload=data_payload,
                 logger=logger,
            )
 
@@ -56,6 +74,7 @@ class Downloader:
         graph_name: str,
         graph_invocation_id: str,
         input_key: str,
+        data_payload: Optional[DataPayload],
         logger: Any,
     ) -> SerializedObject:
         logger = logger.bind(module=__name__)
@@ -70,6 +89,7 @@ class Downloader:
                 graph_name=graph_name,
                 graph_invocation_id=graph_invocation_id,
                 input_key=input_key,
+                data_payload=data_payload,
                 logger=logger,
            )
 
@@ -80,6 +100,7 @@ class Downloader:
         function_name: str,
         graph_invocation_id: str,
         reducer_output_key: str,
+        data_payload: Optional[Union[DataPayload, DataPayloadProto]],
         logger: Any,
     ) -> SerializedObject:
         logger = logger.bind(module=__name__)
@@ -89,21 +110,27 @@ class Downloader:
             metric_reducer_init_value_download_latency.time(),
         ):
             metric_reducer_init_value_downloads.inc()
-            return await self.
+            return await self._download_init_value(
                 namespace=namespace,
                 graph_name=graph_name,
                 function_name=function_name,
                 graph_invocation_id=graph_invocation_id,
                 reducer_output_key=reducer_output_key,
+                data_payload=data_payload,
                 logger=logger,
            )
 
     async def _download_graph(
-        self,
+        self,
+        namespace: str,
+        graph_name: str,
+        graph_version: str,
+        data_payload: Optional[Union[DataPayload, DataPayloadProto]],
+        logger: Any,
     ) -> SerializedObject:
         # Cache graph to reduce load on the server.
         graph_path = os.path.join(
-            self.
+            self._code_path,
             "graph_cache",
             namespace,
             graph_name,
@@ -118,17 +145,41 @@ class Downloader:
             metric_graphs_from_cache.inc()
             return graph
 
-
-
-
-
-
-
+        if data_payload is None:
+            graph: SerializedObject = await self._fetch_graph_from_server(
+                namespace=namespace,
+                graph_name=graph_name,
+                graph_version=graph_version,
+                logger=logger,
+            )
+        elif isinstance(data_payload, DataPayloadProto):
+            (
+                MessageValidator(data_payload)
+                .required_field("uri")
+                .required_field("encoding")
+            )
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.uri, logger=logger
+            )
+            return _serialized_object_from_data_payload_proto(
+                data_payload=data_payload,
+                data=data,
+            )
+        elif isinstance(data_payload, DataPayload):
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.path, logger=logger
+            )
+            return _serialized_object_from_data_payload(
+                data_payload=data_payload,
+                data=data,
+            )
+
         # Filesystem operations are synchronous.
         # Run in a separate thread to not block the main event loop.
         # We don't need to wait for the write completion so we use create_task.
         asyncio.create_task(
-            asyncio.to_thread(self._write_cached_graph, graph_path, graph)
+            asyncio.to_thread(self._write_cached_graph, graph_path, graph),
+            name="graph cache write",
        )
 
         return graph
@@ -145,7 +196,7 @@ class Downloader:
             # Another task already cached the graph.
             return None
 
-        tmp_path = os.path.join(self.
+        tmp_path = os.path.join(self._code_path, "task_graph_cache", nanoid.generate())
         os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
         with open(tmp_path, "wb") as f:
             f.write(graph.SerializeToString())
@@ -162,21 +213,87 @@ class Downloader:
         graph_name: str,
         graph_invocation_id: str,
         input_key: str,
+        data_payload: Optional[Union[DataPayload, DataPayloadProto]],
         logger: Any,
     ) -> SerializedObject:
-
-
-
-
+        if data_payload is None:
+            first_function_in_graph = graph_invocation_id == input_key.split("|")[-1]
+            if first_function_in_graph:
+                # The first function in Graph gets its input from graph invocation payload.
+                return await self._fetch_graph_invocation_payload_from_server(
+                    namespace=namespace,
+                    graph_name=graph_name,
+                    graph_invocation_id=graph_invocation_id,
+                    logger=logger,
+                )
+            else:
+                return await self._fetch_function_input_from_server(
+                    input_key=input_key, logger=logger
+                )
+        elif isinstance(data_payload, DataPayloadProto):
+            (
+                MessageValidator(data_payload)
+                .required_field("uri")
+                .required_field("encoding")
+            )
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.uri, logger=logger
+            )
+            return _serialized_object_from_data_payload_proto(
+                data_payload=data_payload,
+                data=data,
+            )
+        elif isinstance(data_payload, DataPayload):
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.path, logger=logger
+            )
+            return _serialized_object_from_data_payload(
+                data_payload=data_payload,
+                data=data,
+            )
+
+    async def _download_init_value(
+        self,
+        namespace: str,
+        graph_name: str,
+        function_name: str,
+        graph_invocation_id: str,
+        reducer_output_key: str,
+        data_payload: Optional[Union[DataPayload, DataPayloadProto]],
+        logger: Any,
+    ) -> SerializedObject:
+        if data_payload is None:
+            return await self._fetch_function_init_value_from_server(
                namespace=namespace,
                graph_name=graph_name,
+                function_name=function_name,
                graph_invocation_id=graph_invocation_id,
+                reducer_output_key=reducer_output_key,
                logger=logger,
            )
-
-
+        elif isinstance(data_payload, DataPayloadProto):
+            (
+                MessageValidator(data_payload)
+                .required_field("uri")
+                .required_field("encoding")
+            )
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.uri, logger=logger
+            )
+            return _serialized_object_from_data_payload_proto(
+                data_payload=data_payload,
+                data=data,
+            )
+        elif isinstance(data_payload, DataPayload):
+            data: bytes = await self._blob_store.get(
+                uri=data_payload.path, logger=logger
+            )
+            return _serialized_object_from_data_payload(
+                data_payload=data_payload,
+                data=data,
+            )
 
-    async def
+    async def _fetch_graph_from_server(
         self, namespace: str, graph_name: str, graph_version: str, logger: Any
     ) -> SerializedObject:
         """Downloads the compute graph for the task and returns it."""
@@ -186,7 +303,7 @@ class Downloader:
             logger=logger,
         )
 
-    async def
+    async def _fetch_graph_invocation_payload_from_server(
         self, namespace: str, graph_name: str, graph_invocation_id: str, logger: Any
     ) -> SerializedObject:
         return await self._fetch_url(
@@ -195,7 +312,7 @@ class Downloader:
             logger=logger,
         )
 
-    async def
+    async def _fetch_function_input_from_server(
         self, input_key: str, logger: Any
     ) -> SerializedObject:
         return await self._fetch_url(
@@ -204,7 +321,7 @@ class Downloader:
             logger=logger,
        )
 
-    async def
+    async def _fetch_function_init_value_from_server(
         self,
         namespace: str,
         graph_name: str,
@@ -223,7 +340,11 @@ class Downloader:
     async def _fetch_url(
         self, url: str, resource_description: str, logger: Any
     ) -> SerializedObject:
-        logger.
+        logger.warning(
+            f"downloading resource from Server",
+            url=url,
+            resource_description=resource_description,
+        )
         response: httpx.Response = await self._client.get(url)
         try:
             response.raise_for_status()
@@ -252,3 +373,47 @@ def serialized_object_from_http_response(response: httpx.Response) -> SerializedObject:
     return SerializedObject(
         string=response.text, content_type=response.headers["content-type"]
     )
+
+
+def _serialized_object_from_data_payload(
+    data_payload: DataPayload, data: bytes
+) -> SerializedObject:
+    """Converts the given data payload and its data into SerializedObject accepted by Function Executor."""
+    if data_payload.content_type in [
+        "application/octet-stream",
+        "application/pickle",
+    ]:
+        return SerializedObject(bytes=data, content_type=data_payload.content_type)
+    else:
+        return SerializedObject(
+            string=data.decode("utf-8"), content_type=data_payload.content_type
+        )
+
+
+def _serialized_object_from_data_payload_proto(
+    data_payload: DataPayloadProto, data: bytes
+) -> SerializedObject:
+    """Converts the given data payload and its data into SerializedObject accepted by Function Executor.
+
+    Raises ValueError if the supplied data payload can't be converted into serialized object.
+    """
+    if data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_BINARY_PICKLE:
+        return SerializedObject(
+            bytes=data,
+            content_type="application/octet-stream",
+        )
+    elif data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_UTF8_TEXT:
+        return SerializedObject(
+            content_type="text/plain",
+            string=data.decode("utf-8"),
+        )
+    elif data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_UTF8_JSON:
+        result = SerializedObject(
+            content_type="application/json",
+            string=data.decode("utf-8"),
+        )
+        return result
+
+    raise ValueError(
+        f"Can't convert data payload {data_payload} into serialized object"
    )
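
To summarize the conversion helpers added at the end of downloader.py: the HTTP-era DataPayload is mapped to a SerializedObject by its content_type, while the gRPC DataPayloadProto is mapped by its encoding enum. A condensed view of the proto mapping (field names and content types as they appear in the diff):

# Condensed summary of _serialized_object_from_data_payload_proto; any other
# encoding raises ValueError. Tuples are (SerializedObject field, content_type).
ENCODING_TO_SERIALIZED_OBJECT = {
    "DATA_PAYLOAD_ENCODING_BINARY_PICKLE": ("bytes", "application/octet-stream"),
    "DATA_PAYLOAD_ENCODING_UTF8_TEXT": ("string", "text/plain"),
    "DATA_PAYLOAD_ENCODING_UTF8_JSON": ("string", "application/json"),
}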