indexify 0.0.42__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indexify/__init__.py +13 -14
- indexify/base_client.py +48 -21
- indexify/cli.py +235 -0
- indexify/client.py +18 -790
- indexify/error.py +3 -30
- indexify/executor/agent.py +362 -0
- indexify/executor/api_objects.py +43 -0
- indexify/executor/downloader.py +124 -0
- indexify/executor/executor_tasks.py +72 -0
- indexify/executor/function_worker.py +177 -0
- indexify/executor/indexify_executor.py +32 -0
- indexify/executor/task_reporter.py +110 -0
- indexify/executor/task_store.py +113 -0
- indexify/foo +72 -0
- indexify/functions_sdk/data_objects.py +37 -0
- indexify/functions_sdk/graph.py +276 -0
- indexify/functions_sdk/graph_validation.py +69 -0
- indexify/functions_sdk/image.py +26 -0
- indexify/functions_sdk/indexify_functions.py +192 -0
- indexify/functions_sdk/local_cache.py +46 -0
- indexify/functions_sdk/object_serializer.py +61 -0
- indexify/local_client.py +183 -0
- indexify/remote_client.py +319 -0
- indexify-0.2.dist-info/METADATA +151 -0
- indexify-0.2.dist-info/RECORD +32 -0
- indexify-0.2.dist-info/entry_points.txt +3 -0
- indexify/exceptions.py +0 -3
- indexify/extraction_policy.py +0 -75
- indexify/extractor_sdk/__init__.py +0 -14
- indexify/extractor_sdk/data.py +0 -100
- indexify/extractor_sdk/extractor.py +0 -223
- indexify/extractor_sdk/utils.py +0 -102
- indexify/extractors/__init__.py +0 -0
- indexify/extractors/embedding.py +0 -55
- indexify/extractors/pdf_parser.py +0 -93
- indexify/graph.py +0 -133
- indexify/local_runner.py +0 -128
- indexify/runner.py +0 -22
- indexify/utils.py +0 -7
- indexify-0.0.42.dist-info/METADATA +0 -66
- indexify-0.0.42.dist-info/RECORD +0 -25
- {indexify-0.0.42.dist-info → indexify-0.2.dist-info}/LICENSE.txt +0 -0
- {indexify-0.0.42.dist-info → indexify-0.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,276 @@
|
|
1
|
+
import inspect
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import (
|
4
|
+
Annotated,
|
5
|
+
Any,
|
6
|
+
Callable,
|
7
|
+
Dict,
|
8
|
+
List,
|
9
|
+
Optional,
|
10
|
+
Set,
|
11
|
+
Type,
|
12
|
+
Union,
|
13
|
+
get_args,
|
14
|
+
get_origin,
|
15
|
+
)
|
16
|
+
|
17
|
+
import cloudpickle
|
18
|
+
import msgpack
|
19
|
+
from pydantic import BaseModel
|
20
|
+
from typing_extensions import get_args, get_origin
|
21
|
+
|
22
|
+
from .data_objects import IndexifyData, RouterOutput
|
23
|
+
from .graph_validation import validate_node, validate_route
|
24
|
+
from .indexify_functions import (
|
25
|
+
IndexifyFunction,
|
26
|
+
IndexifyFunctionWrapper,
|
27
|
+
IndexifyRouter,
|
28
|
+
)
|
29
|
+
from .object_serializer import CloudPickleSerializer, get_serializer
|
30
|
+
|
31
|
+
# Annotated alias for a router callable: it receives an IndexifyData and
# returns either a list of IndexifyFunction targets or None.
RouterFn = Annotated[
    Callable[[IndexifyData], Optional[List[IndexifyFunction]]],
    "RouterFn",
]

# Annotated alias for anything that can sit in the graph: a wrapped function
# or a router callable.
GraphNode = Annotated[
    Union[IndexifyFunctionWrapper, RouterFn],
    "GraphNode",
]
|
35
|
+
|
36
|
+
|
37
|
+
def is_pydantic_model_from_annotation(type_annotation):
    """Report whether a type annotation refers to a pydantic ``BaseModel`` class.

    Args:
        type_annotation: A parameter annotation: a class, a generic alias such
            as ``List[File]`` / ``Optional[File]``, or a string.

    Returns:
        True when the annotation is a ``BaseModel`` subclass, or a generic
        alias whose *first* type argument is one; False otherwise. String
        annotations always yield False.
    """
    # A string annotation cannot be resolved to a class here without importing
    # or evaluating it, so it is conservatively treated as "not a model".
    # (Previously this branch parsed the string with split("'")[-2], which
    # raised IndexError for plain forward references such as "File".)
    if isinstance(type_annotation, str):
        return False

    # Generic aliases (List[File], Optional[File], ...): only the first type
    # argument is inspected.
    if get_origin(type_annotation) is not None:
        type_args = get_args(type_annotation)
        if type_args:
            return is_pydantic_model_from_annotation(type_args[0])

    # Direct class reference.
    if isinstance(type_annotation, type):
        return issubclass(type_annotation, BaseModel)

    return False
|
62
|
+
|
63
|
+
|
64
|
+
class FunctionMetadata(BaseModel):
    """Serializable description of a compute-function node in a graph."""

    # Name the node is registered under in the graph.
    name: str
    # Value of the node's ``fn_name`` attribute.
    fn_name: str
    # Free-form description of the node.
    description: str
    # Whether the node is a reducer; defaults to False.
    reducer: bool = False
    # Name of the payload encoder; defaults to "cloudpickle".
    payload_encoder: str = "cloudpickle"
|
70
|
+
|
71
|
+
|
72
|
+
class RouterMetadata(BaseModel):
    """Serializable description of a router node in a graph."""

    # Name the router is registered under in the graph.
    name: str
    # Free-form description of the router.
    description: str
    # Name of the function acting as the routing source.
    source_fn: str
    # Names of the functions the router may dispatch to.
    target_fns: List[str]
    # Name of the payload encoder; defaults to "cloudpickle".
    payload_encoder: str = "cloudpickle"
|
78
|
+
|
79
|
+
|
80
|
+
class NodeMetadata(BaseModel):
    """Metadata for one graph node: a router, a compute function, or neither."""

    # Set when the node is a router; None otherwise.
    dynamic_router: Optional[RouterMetadata] = None
    # Set when the node is a compute function; None otherwise.
    compute_fn: Optional[FunctionMetadata] = None
|
83
|
+
|
84
|
+
|
85
|
+
class ComputeGraphMetadata(BaseModel):
    """Serializable description of a whole compute graph."""

    # Graph name and free-form description.
    name: str
    description: str
    # Metadata of the node the graph starts from.
    start_node: NodeMetadata
    # Node metadata keyed by node name.
    nodes: Dict[str, NodeMetadata]
    # Adjacency list: source node name -> target node names.
    edges: Dict[str, List[str]]
    # Per-node accumulator zero values; empty by default.
    accumulator_zero_values: Dict[str, bytes] = {}

    def get_input_payload_serializer(self):
        """Return the serializer matching the start node's payload encoder."""
        encoder_name = self.start_node.compute_fn.payload_encoder
        return get_serializer(encoder_name)
|
95
|
+
|
96
|
+
|
97
|
+
class Graph:
    """A named collection of function/router node classes and their wiring.

    Nodes are stored as classes keyed by their ``name`` attribute. Static
    edges live in ``edges``; router targets live in ``routers``.
    """

    def __init__(
        self, name: str, start_node: IndexifyFunction, description: Optional[str] = None
    ):
        """Create a graph and register ``start_node`` as its entry point."""
        self.name = name
        self.description = description
        self.nodes: Dict[str, Union[IndexifyFunction, IndexifyRouter]] = {}
        self.routers: Dict[str, List[str]] = defaultdict(list)
        self.edges: Dict[str, List[str]] = defaultdict(list)
        self.accumulator_zero_values: Dict[str, Any] = {}

        self.add_node(start_node)
        self._start_node: str = start_node.name

    def get_function(self, name: str) -> IndexifyFunctionWrapper:
        """Return a fresh wrapper around the node registered as ``name``.

        Raises:
            ValueError: If no node with that name exists.
        """
        if name not in self.nodes:
            raise ValueError(f"Function {name} not found in graph")
        return IndexifyFunctionWrapper(self.nodes[name])

    def get_accumulators(self) -> Dict[str, Any]:
        """Return the accumulator zero values keyed by node name."""
        return self.accumulator_zero_values

    def deserialize_fn_output(self, name: str, output: IndexifyData) -> Any:
        """Decode ``output.payload`` with the encoder of node ``name``."""
        serializer = get_serializer(self.nodes[name].payload_encoder)
        return serializer.deserialize(output.payload)

    def add_node(
        self, indexify_fn: Union[Type[IndexifyFunction], Type[IndexifyRouter]]
    ) -> "Graph":
        """Validate and register a node class; already-known names are kept as-is."""
        validate_node(indexify_fn=indexify_fn)

        if indexify_fn.name in self.nodes:
            return self

        # Functions declaring an accumulator type get a zero value built from a
        # default-constructed instance of that type.
        if issubclass(indexify_fn, IndexifyFunction) and indexify_fn.accumulate:
            self.accumulator_zero_values[
                indexify_fn.name
            ] = indexify_fn.accumulate().model_dump()

        self.nodes[indexify_fn.name] = indexify_fn
        return self

    def route(
        self, from_node: Type[IndexifyRouter], to_nodes: List[Type[IndexifyFunction]]
    ) -> "Graph":
        """Register ``from_node`` as a router that may dispatch to ``to_nodes``."""
        validate_route(from_node=from_node, to_nodes=to_nodes)

        print(
            f"Adding router {from_node.name} to nodes {[node.name for node in to_nodes]}"
        )
        self.add_node(from_node)
        for node in to_nodes:
            self.add_node(node)
            self.routers[from_node.name].append(node.name)
        return self

    def serialize(self):
        """Return this graph pickled with cloudpickle."""
        return cloudpickle.dumps(self)

    @staticmethod
    def deserialize(graph: bytes) -> "Graph":
        """Rebuild a graph from ``serialize()`` output.

        NOTE(review): unpickling executes arbitrary code; only pass trusted bytes.
        """
        return cloudpickle.loads(graph)

    @staticmethod
    def from_path(path: str) -> "Graph":
        """Load a cloudpickled graph from the file at ``path``.

        NOTE(review): unpickling executes arbitrary code; only load trusted files.
        """
        with open(path, "rb") as f:
            return cloudpickle.load(f)

    def add_edge(
        self,
        from_node: Type[IndexifyFunction],
        to_node: Union[Type[IndexifyFunction], RouterFn],
    ) -> "Graph":
        """Add a single edge; shorthand for ``add_edges(from_node, [to_node])``."""
        self.add_edges(from_node, [to_node])
        return self

    def invoke_fn_ser(
        self, name: str, input: IndexifyData, acc: Optional[Any] = None
    ) -> List[IndexifyData]:
        """Run node ``name`` on a serialized input and serialize each output.

        Args:
            name: Registered node name.
            input: Serialized input payload.
            acc: Optional serialized accumulator; its ``payload`` is decoded
                and validated against the node's ``accumulate`` type.

        Returns:
            One ``IndexifyData`` per output produced by the function.
        """
        fn_wrapper = self.get_function(name)
        input = self.deserialize_input(name, input)
        serializer = get_serializer(fn_wrapper.indexify_function.payload_encoder)
        if acc is not None:
            acc = fn_wrapper.indexify_function.accumulate.model_validate(
                serializer.deserialize(acc.payload)
            )
        # No accumulator supplied: start reducers from their zero value.
        if acc is None and fn_wrapper.indexify_function.accumulate is not None:
            acc = fn_wrapper.indexify_function.accumulate.model_validate(
                self.accumulator_zero_values[name]
            )
        outputs: List[Any] = fn_wrapper.run_fn(input, acc=acc)
        return [
            IndexifyData(payload=serializer.serialize(output)) for output in outputs
        ]

    def invoke_router(self, name: str, input: IndexifyData) -> Optional[RouterOutput]:
        """Run router ``name`` on a serialized input and wrap the chosen edges."""
        fn_wrapper = self.get_function(name)
        input = self.deserialize_input(name, input)
        return RouterOutput(edges=fn_wrapper.run_router(input))

    def deserialize_input(self, compute_fn: str, indexify_data: IndexifyData) -> Any:
        """Decode the input payload for the node named ``compute_fn``.

        cloudpickle payloads are simply unpickled. msgpack payloads are
        unpacked and, when the node's single annotated ``run`` argument is a
        pydantic model, validated into that model.

        Raises:
            ValueError: If the node is unknown, or its ``run`` signature has
                zero or more than one annotated (non-accumulator) argument.
        """
        # Use .get so an unknown name reaches the ValueError below instead of
        # escaping as a KeyError; keep the name for error messages.
        node = self.nodes.get(compute_fn)
        if node is None:
            raise ValueError(f"Compute function {compute_fn} not found in graph")
        if node.payload_encoder == "cloudpickle":
            return CloudPickleSerializer.deserialize(indexify_data.payload)

        payload = msgpack.unpackb(indexify_data.payload)
        signature = inspect.signature(node.run)
        accumulate_type = getattr(node, "accumulate", None)
        # Candidate input arguments: annotated parameters whose annotation is
        # not the accumulator type (un-annotated ``self`` drops out here too).
        arg_types = {
            param_name: param.annotation
            for param_name, param in signature.parameters.items()
            if param.annotation != inspect.Parameter.empty
            and param.annotation != accumulate_type
        }
        if len(arg_types) > 1:
            raise ValueError(
                f"Compute function {compute_fn} has multiple arguments, but only one is supported"
            )
        elif len(arg_types) == 0:
            raise ValueError(f"Compute function {compute_fn} has no arguments")
        arg_name, arg_type = next(iter(arg_types.items()))
        if arg_type is None:
            raise ValueError(f"Argument {arg_name} has no type annotation")
        if is_pydantic_model_from_annotation(arg_type):
            # A payload shaped {"<arg>": {...}} is unwrapped to the inner dict.
            # Guard on dict first: msgpack may also yield lists or scalars.
            if isinstance(payload, dict) and len(payload) == 1:
                (only_value,) = payload.values()
                if isinstance(only_value, dict):
                    payload = only_value
            return arg_type.model_validate(payload)
        return payload

    def add_edges(
        self,
        from_node: Union[Type[IndexifyFunction], Type[IndexifyRouter]],
        to_node: List[Union[Type[IndexifyFunction], Type[IndexifyRouter]]],
    ) -> "Graph":
        """Register ``from_node`` and every target, adding an edge to each target."""
        self.add_node(from_node)
        from_node_name = from_node.name
        for node in to_node:
            self.add_node(node)
            self.edges[from_node_name].append(node.name)
        return self

    def definition(self) -> ComputeGraphMetadata:
        """Build the serializable metadata describing this graph."""
        start_node = self.nodes[self._start_node]
        # payload_encoder is passed explicitly so the metadata reflects the
        # node's real encoder rather than the model's "cloudpickle" default.
        start_node = FunctionMetadata(
            name=start_node.name,
            fn_name=start_node.fn_name,
            description=start_node.description,
            reducer=start_node.accumulate is not None,
            payload_encoder=start_node.payload_encoder,
        )
        metadata_edges = self.edges.copy()
        metadata_nodes = {}
        for node_name, node in self.nodes.items():
            if node_name in self.routers:
                metadata_nodes[node_name] = NodeMetadata(
                    dynamic_router=RouterMetadata(
                        name=node_name,
                        description=node.description or "",
                        source_fn=node_name,
                        target_fns=self.routers[node_name],
                        payload_encoder=node.payload_encoder,
                    )
                )
            else:
                metadata_nodes[node_name] = NodeMetadata(
                    compute_fn=FunctionMetadata(
                        name=node_name,
                        fn_name=node.fn_name,
                        description=node.description,
                        reducer=node.accumulate is not None,
                        payload_encoder=node.payload_encoder,
                    )
                )
        return ComputeGraphMetadata(
            name=self.name,
            description=self.description or "",
            start_node=NodeMetadata(compute_fn=start_node),
            nodes=metadata_nodes,
            edges=metadata_edges,
        )
|
@@ -0,0 +1,69 @@
|
|
1
|
+
import inspect
|
2
|
+
import re
|
3
|
+
from typing import List, Type, Union
|
4
|
+
|
5
|
+
from .indexify_functions import IndexifyFunction, IndexifyRouter
|
6
|
+
|
7
|
+
|
8
|
+
def validate_node(indexify_fn: Union[Type[IndexifyFunction], Type[IndexifyRouter]]):
    """Check that ``indexify_fn`` is a usable graph node class.

    The node must be a class derived from ``IndexifyFunction`` or
    ``IndexifyRouter`` whose ``run`` has a type annotation on every parameter
    other than ``self`` and a return annotation.

    Raises:
        Exception: If any of the above does not hold.
    """
    # Anything that is not a class (plain functions, instances, ...) is
    # rejected up front; otherwise issubclass() below would fail with an
    # unrelated TypeError.
    if not inspect.isclass(indexify_fn):
        raise Exception(
            f"Unable to add node of type `{type(indexify_fn)}`. "
            f"Required, `IndexifyFunction` or `IndexifyRouter`"
        )
    if not issubclass(indexify_fn, (IndexifyFunction, IndexifyRouter)):
        raise Exception(
            f"Unable to add node of type `{indexify_fn.__name__}`. "
            f"Required, `IndexifyFunction` or `IndexifyRouter`"
        )

    signature = inspect.signature(indexify_fn.run)

    for param in signature.parameters.values():
        if param.name == "self":
            continue
        if param.annotation == inspect.Parameter.empty:
            raise Exception(
                f"Input param {param.name} in {indexify_fn.name} has empty"
                f" type annotation"
            )

    if signature.return_annotation == inspect.Signature.empty:
        raise Exception(f"Function {indexify_fn.name} has empty return type annotation")
|
36
|
+
|
37
|
+
|
38
|
+
def validate_route(
    from_node: Type[IndexifyRouter], to_nodes: List[Type[IndexifyFunction]]
):
    """Check that a router's declared targets are all present in ``to_nodes``.

    The targets are read from the text of the first ``Union[...]`` found in
    the source of ``from_node.run``; at least two names are required and each
    must equal the ``name`` of some entry in ``to_nodes``.

    Raises:
        Exception: On a missing return annotation, a missing or single-entry
            ``Union[...]``, or a declared target absent from ``to_nodes``.
    """
    if inspect.signature(from_node.run).return_annotation == inspect.Signature.empty:
        raise Exception(f"Function {from_node.name} has empty return type annotation")

    # The annotation object no longer carries the exact type string, so the
    # source text is inspected instead.
    run_source = inspect.getsource(from_node.run)
    found = re.search(r"Union\[((?:\w+(?:,\s*)?)+)\]", run_source)
    if not found:
        raise Exception(
            f"Invalid router for {from_node.name}, cannot find output nodes"
        )

    declared_targets = [part.strip() for part in found.group(1).split(",")]
    if len(declared_targets) <= 1:
        raise Exception(f"Invalid router for {from_node.name}, lte 1 route.")

    to_node_names = [candidate.name for candidate in to_nodes]

    for declared in declared_targets:
        if declared in to_node_names:
            continue
        raise Exception(
            f"Unable to find {declared} in to_nodes " f"{to_node_names}"
        )
|
@@ -0,0 +1,26 @@
|
|
1
|
+
class Image:
    """Chainable holder for an image name, tag, base image and run strings.

    Every setter stores its argument and returns the same instance so calls
    can be chained.
    """

    def __init__(self):
        """Start with no name, tag "latest", no base image and no run strings."""
        self._image_name = None
        self._tag = "latest"
        self._base_image = None
        self._run_strs = []

    def image_name(self, image_name):
        """Store the image name; returns ``self``."""
        self._image_name = image_name
        return self

    def tag(self, tag):
        """Store the tag, replacing the default; returns ``self``."""
        self._tag = tag
        return self

    def base_image(self, base_image):
        """Store the base image; returns ``self``."""
        self._base_image = base_image
        return self

    def run(self, run_str):
        """Append one run string to the recorded list; returns ``self``."""
        self._run_strs.append(run_str)
        return self
|
@@ -0,0 +1,192 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from functools import update_wrapper
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
Dict,
|
7
|
+
List,
|
8
|
+
Optional,
|
9
|
+
Type,
|
10
|
+
Union,
|
11
|
+
get_args,
|
12
|
+
get_origin,
|
13
|
+
)
|
14
|
+
|
15
|
+
from pydantic import BaseModel
|
16
|
+
from typing_extensions import get_type_hints
|
17
|
+
|
18
|
+
from .data_objects import IndexifyData, RouterOutput
|
19
|
+
from .image import Image
|
20
|
+
|
21
|
+
|
22
|
+
class EmbeddingIndexes(BaseModel):
    """Settings model for an embedding index."""

    # Required dimension value.
    dim: int
    # Distance name; defaults to "cosine".
    distance: Optional[str] = "cosine"
    # Optional database URL; unset by default.
    database_url: Optional[str] = None
|
26
|
+
|
27
|
+
|
28
|
+
class PlacementConstraints(BaseModel):
    """Optional constraints describing where a function may be placed."""

    # Lowest acceptable Python version; defaults to "3.9".
    min_python_version: Optional[str] = "3.9"
    # Highest acceptable Python version; unset by default.
    max_python_version: Optional[str] = None
    # Platform constraint; unset by default.
    platform: Optional[str] = None
    # Image name constraint; unset by default.
    image_name: Optional[str] = None
|
33
|
+
|
34
|
+
|
35
|
+
class IndexifyFunction(ABC):
    """Abstract base for compute-function nodes; subclasses implement ``run``."""

    # Class-level configuration with its defaults.
    name: str = ""
    base_image: Optional[str] = None
    description: str = ""
    placement_constraints: List[PlacementConstraints] = []
    # Accumulator type, or None when the function does not accumulate.
    accumulate: Optional[Type[Any]] = None
    payload_encoder: Optional[str] = "cloudpickle"

    @abstractmethod
    def run(self, *args, **kwargs) -> Union[List[Any], Any]:
        """Execute the function; must be provided by subclasses."""

    def partial(self, **kwargs) -> Callable:
        """Return ``self.run`` with the given keyword arguments pre-bound."""
        import functools

        return functools.partial(self.run, **kwargs)
|
51
|
+
|
52
|
+
|
53
|
+
class IndexifyRouter(ABC):
    """Abstract base for router nodes; subclasses implement ``run``."""

    # Class-level configuration with its defaults.
    name: str = ""
    description: str = ""
    image: Image = None
    placement_constraints: List[PlacementConstraints] = []
    payload_encoder: Optional[str] = "cloudpickle"

    @abstractmethod
    def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]:
        """Pick the next function(s) to run; must be provided by subclasses."""
|
63
|
+
|
64
|
+
|
65
|
+
def indexify_router(
    name: Optional[str] = None,
    description: Optional[str] = "",
    image: Image = None,
    placement_constraints: List[PlacementConstraints] = [],
    output_encoder: Optional[str] = "cloudpickle",
):
    """Decorator factory turning a plain function into an ``IndexifyRouter`` class.

    Args:
        name: Node name; falls back to the function's ``__name__`` when empty.
        description: Node description; falls back to the function's docstring
            (stripped, newlines removed) when empty.
        image: Stored on the generated class as ``image``.
        placement_constraints: Stored (as a copy) on the generated class.
        output_encoder: Stored as both ``output_encoder`` and ``payload_encoder``.

    Returns:
        A decorator that returns the generated ``IndexifyRouter`` subclass.
    """

    def construct(fn):
        # Attributes are collected explicitly. The previous implementation read
        # them from locals(), which inside this nested function only contains
        # the enclosing variables it references syntactically, so `name`,
        # `description` and `placement_constraints` were silently ignored.
        attrs = {
            "name": name if name else fn.__name__,
            "fn_name": fn.__name__,
            "description": (
                description
                if description
                else (fn.__doc__ or "").strip().replace("\n", "")
            ),
            # Copy so generated classes never share the mutable default list.
            "placement_constraints": list(placement_constraints or []),
            "image": image,
            "output_encoder": output_encoder,
            "payload_encoder": output_encoder,
        }

        class IndexifyRo(IndexifyRouter):
            def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]:
                return fn(*args, **kwargs)

            # Make run() carry fn's name, docstring and signature metadata.
            update_wrapper(run, fn)

        for key, value in attrs.items():
            setattr(IndexifyRo, key, value)

        return IndexifyRo

    return construct
|
97
|
+
|
98
|
+
|
99
|
+
def indexify_function(
    name: Optional[str] = None,
    description: Optional[str] = "",
    image: Image = None,
    accumulate: Optional[Type[BaseModel]] = None,
    min_batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    output_encoder: Optional[str] = "cloudpickle",
    placement_constraints: List[PlacementConstraints] = [],
):
    """Decorator factory turning a plain function into an ``IndexifyFunction`` class.

    Args:
        name: Node name; falls back to the function's ``__name__`` when empty.
        description: Node description; falls back to the function's docstring
            (stripped, newlines removed) when empty.
        image: Stored on the generated class as ``image``.
        accumulate: Accumulator model type, stored as ``accumulate``.
        min_batch_size: Stored on the generated class.
        max_batch_size: Stored on the generated class.
        output_encoder: Stored as both ``output_encoder`` and ``payload_encoder``.
        placement_constraints: Stored (as a copy) on the generated class.

    Returns:
        A decorator that returns the generated ``IndexifyFunction`` subclass.
    """

    def construct(fn):
        # Attributes are collected explicitly. The previous implementation read
        # them from locals(), which inside this nested function only contains
        # the enclosing variables it references syntactically, so `name`,
        # `description` and `placement_constraints` were silently ignored.
        attrs = {
            "name": name if name else fn.__name__,
            "fn_name": fn.__name__,
            "description": (
                description
                if description
                else (fn.__doc__ or "").strip().replace("\n", "")
            ),
            # Copy so generated classes never share the mutable default list.
            "placement_constraints": list(placement_constraints or []),
            "image": image,
            "accumulate": accumulate,
            "min_batch_size": min_batch_size,
            "max_batch_size": max_batch_size,
            "output_encoder": output_encoder,
            "payload_encoder": output_encoder,
        }

        class IndexifyFn(IndexifyFunction):
            def run(self, *args, **kwargs) -> Union[List[Any], Any]:
                return fn(*args, **kwargs)

            # Make run() carry fn's name, docstring and signature metadata.
            update_wrapper(run, fn)

        for key, value in attrs.items():
            setattr(IndexifyFn, key, value)

        return IndexifyFn

    return construct
|
137
|
+
|
138
|
+
|
139
|
+
class IndexifyFunctionWrapper:
    """Instantiates a function/router class and adapts inputs for its ``run``."""

    def __init__(self, indexify_function: Union[IndexifyFunction, IndexifyRouter]):
        """Create the wrapped instance.

        Despite the annotation, ``indexify_function`` is called with no
        arguments here, i.e. it is expected to be the node *class*.
        """
        self.indexify_function: Union[
            IndexifyFunction, IndexifyRouter
        ] = indexify_function()

    def get_output_model(self) -> Any:
        """Return the element type implied by ``run``'s return annotation.

        ``list[X]`` and ``Optional[X]`` are unwrapped to ``X``; any other
        annotation is returned unchanged, and ``Any`` is used when there is no
        return annotation.

        Raises:
            TypeError: If the wrapped object is not an ``IndexifyFunction``.
        """
        if not isinstance(self.indexify_function, IndexifyFunction):
            raise TypeError("Input must be an instance of IndexifyFunction")

        type_hints = get_type_hints(self.indexify_function.run)
        return_type = type_hints.get("return", Any)
        origin = get_origin(return_type)
        if origin is list:
            return_type = get_args(return_type)[0]
        elif origin is Union:
            inner_types = get_args(return_type)
            # Only the two-member Optional[X] form is unwrapped.
            if len(inner_types) == 2 and type(None) in inner_types:
                return_type = (
                    inner_types[0] if inner_types[1] is type(None) else inner_types[1]
                )
        return return_type

    def run_router(self, input: Union[Dict, Type[BaseModel]]) -> List[str]:
        """Run the wrapped router and return the names of the chosen targets.

        A dict input is expanded into keyword arguments; anything else is
        passed as the single positional argument. A non-list, non-None result
        is treated as one target; ``None`` yields an empty list.
        """
        args = []
        kwargs = {}
        if isinstance(input, dict):
            kwargs = input
        else:
            args.append(input)
        extracted_data = self.indexify_function.run(*args, **kwargs)
        if not isinstance(extracted_data, list) and extracted_data is not None:
            return [extracted_data.name]
        return [fn.name for fn in extracted_data or []]

    def run_fn(
        self, input: Union[Dict, Type[BaseModel]], acc: Type[Any] = None
    ) -> List[IndexifyData]:
        """Run the wrapped function and return its outputs as a list.

        ``acc``, when given, is passed as the first positional argument. A
        dict input is expanded into keyword arguments; anything else is passed
        positionally. A non-list result is wrapped in a one-element list.
        """
        args = []
        kwargs = {}
        if acc is not None:
            args.append(acc)
        if isinstance(input, dict):
            kwargs = input
        else:
            args.append(input)

        extracted_data = self.indexify_function.run(*args, **kwargs)

        return extracted_data if isinstance(extracted_data, list) else [extracted_data]
|
@@ -0,0 +1,46 @@
|
|
1
|
+
import os
|
2
|
+
from hashlib import sha256
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
|
6
|
+
class CacheAwareFunctionWrapper:
    """File-system cache of function outputs keyed by graph, node and input.

    Each entry is a directory ``<cache_dir>/<graph>/<node_name>/<sha256(input)>``
    holding one ``<index>.cbor`` file per output item.
    """

    def __init__(self, cache_dir: str):
        """Remember ``cache_dir`` and create it if it does not exist."""
        self._cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _get_key(self, input: bytes) -> str:
        """Return the hex SHA-256 digest of ``input``."""
        return sha256(input).hexdigest()

    def _entry_dir(self, graph: str, node_name: str, input: bytes) -> str:
        """Return the directory path of the cache entry for this input."""
        return os.path.join(self._cache_dir, graph, node_name, self._get_key(input))

    @staticmethod
    def _output_order(file_name: str):
        """Sort key placing ``<index>.cbor`` files in numeric index order."""
        stem = os.path.splitext(file_name)[0]
        if stem.isdigit():
            return (0, int(stem), file_name)
        # Unexpected names sort after all indexed files, alphabetically.
        return (1, 0, file_name)

    def get(self, graph: str, node_name: str, input: bytes) -> Optional[List[bytes]]:
        """Return the cached outputs for ``input``, or None when not cached.

        Outputs are returned in the order they were given to ``set``.
        """
        dir_path = self._entry_dir(graph, node_name, input)
        if not os.path.exists(dir_path):
            return None

        outputs = []
        # os.listdir order is arbitrary, so sort by the numeric file index.
        for file_name in sorted(os.listdir(dir_path), key=self._output_order):
            with open(os.path.join(dir_path, file_name), "rb") as f:
                # Collect every file; returning inside the loop would hand
                # back a single bytes object instead of the list.
                outputs.append(f.read())

        return outputs

    def set(
        self,
        graph: str,
        node_name: str,
        input: bytes,
        output: List[bytes],
    ):
        """Store ``output`` as the cache entry for ``input``, replacing any old one."""
        dir_path = self._entry_dir(graph, node_name, input)
        os.makedirs(dir_path, exist_ok=True)

        written = set()
        for i, output_item in enumerate(output):
            file_name = f"{i}.cbor"
            with open(os.path.join(dir_path, file_name), "wb") as f:
                f.write(output_item)
            written.add(file_name)

        # Drop leftovers from a previous, longer entry so get() does not
        # return stale items.
        for existing in os.listdir(dir_path):
            existing_path = os.path.join(dir_path, existing)
            if existing not in written and os.path.isfile(existing_path):
                os.remove(existing_path)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
from typing import Any, List
|
2
|
+
|
3
|
+
import cloudpickle
|
4
|
+
import msgpack
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from .data_objects import IndexifyData
|
8
|
+
|
9
|
+
|
10
|
+
def get_serializer(serializer_type: str) -> Any:
    """Return a new serializer instance for the given encoder name.

    Args:
        serializer_type: Either "cloudpickle" or "msgpack".

    Raises:
        ValueError: For any other value.
    """
    # Reject unknown names first, then pick between the two known ones.
    if serializer_type != "cloudpickle" and serializer_type != "msgpack":
        raise ValueError(f"Unknown serializer type: {serializer_type}")
    if serializer_type == "cloudpickle":
        return CloudPickleSerializer()
    return MsgPackSerializer()
|
17
|
+
|
18
|
+
|
19
|
+
class CloudPickleSerializer:
    """Serializer backed by cloudpickle; all methods are static.

    NOTE(review): unpickling executes arbitrary code, so the deserialize
    methods should only be given trusted bytes.
    """

    @staticmethod
    def serialize(data: Any) -> bytes:
        """Pickle ``data`` and return the resulting bytes."""
        pickled = cloudpickle.dumps(data)
        return pickled

    @staticmethod
    def deserialize(data: bytes) -> Any:
        """Unpickle ``data`` and return the resulting object."""
        restored = cloudpickle.loads(data)
        return restored

    @staticmethod
    def serialize_list(data: List[Any]) -> bytes:
        """Pickle a whole list in one go and return the bytes."""
        pickled = cloudpickle.dumps(data)
        return pickled

    @staticmethod
    def deserialize_list(data: bytes) -> List[Any]:
        """Unpickle bytes produced by ``serialize_list``."""
        restored = cloudpickle.loads(data)
        return restored
|
35
|
+
|
36
|
+
|
37
|
+
class MsgPackSerializer:
    """Serializer backed by msgpack; all methods are static."""

    @staticmethod
    def serialize(data: Any) -> bytes:
        """Pack ``data`` with msgpack, dumping pydantic model instances first.

        Only model *instances* are dumped. A model class has no bound
        ``model_dump`` and is passed to msgpack unchanged.
        """
        if isinstance(data, BaseModel):
            data = data.model_dump()
        return msgpack.packb(data)

    @staticmethod
    def deserialize(data: bytes) -> IndexifyData:
        """Unpack ``data`` and build an ``IndexifyData`` from the resulting mapping."""
        # The stray debug print of the unpacked payload was removed.
        cached_output = msgpack.unpackb(data)
        return IndexifyData(**cached_output)

    @staticmethod
    def serialize_list(data: List[IndexifyData]) -> bytes:
        """Dump every item with ``model_dump`` and pack the resulting list."""
        dumped_items = [item.model_dump() for item in data]
        return msgpack.packb(dumped_items)

    @staticmethod
    def deserialize_list(data: bytes) -> List[IndexifyData]:
        """Unpack a list and build one ``IndexifyData`` per item."""
        return [IndexifyData(**item) for item in msgpack.unpackb(data)]
|