indexify 0.0.43__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. indexify/__init__.py +15 -14
  2. indexify/base_client.py +48 -21
  3. indexify/cli.py +247 -0
  4. indexify/client.py +18 -790
  5. indexify/error.py +3 -30
  6. indexify/executor/agent.py +364 -0
  7. indexify/executor/api_objects.py +43 -0
  8. indexify/executor/downloader.py +124 -0
  9. indexify/executor/executor_tasks.py +72 -0
  10. indexify/executor/function_worker.py +177 -0
  11. indexify/executor/indexify_executor.py +32 -0
  12. indexify/executor/runtime_probes.py +48 -0
  13. indexify/executor/task_reporter.py +110 -0
  14. indexify/executor/task_store.py +113 -0
  15. indexify/foo +72 -0
  16. indexify/functions_sdk/data_objects.py +37 -0
  17. indexify/functions_sdk/graph.py +281 -0
  18. indexify/functions_sdk/graph_validation.py +66 -0
  19. indexify/functions_sdk/image.py +34 -0
  20. indexify/functions_sdk/indexify_functions.py +188 -0
  21. indexify/functions_sdk/local_cache.py +46 -0
  22. indexify/functions_sdk/object_serializer.py +60 -0
  23. indexify/local_client.py +183 -0
  24. indexify/remote_client.py +319 -0
  25. indexify-0.2.1.dist-info/METADATA +151 -0
  26. indexify-0.2.1.dist-info/RECORD +33 -0
  27. indexify-0.2.1.dist-info/entry_points.txt +3 -0
  28. indexify/exceptions.py +0 -3
  29. indexify/extraction_policy.py +0 -75
  30. indexify/extractor_sdk/__init__.py +0 -14
  31. indexify/extractor_sdk/data.py +0 -100
  32. indexify/extractor_sdk/extractor.py +0 -225
  33. indexify/extractor_sdk/utils.py +0 -102
  34. indexify/extractors/__init__.py +0 -0
  35. indexify/extractors/embedding.py +0 -55
  36. indexify/extractors/pdf_parser.py +0 -93
  37. indexify/graph.py +0 -133
  38. indexify/local_runner.py +0 -128
  39. indexify/runner.py +0 -22
  40. indexify/utils.py +0 -7
  41. indexify-0.0.43.dist-info/METADATA +0 -66
  42. indexify-0.0.43.dist-info/RECORD +0 -25
  43. {indexify-0.0.43.dist-info → indexify-0.2.1.dist-info}/LICENSE.txt +0 -0
  44. {indexify-0.0.43.dist-info → indexify-0.2.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,281 @@
1
+ import inspect
2
+ from collections import defaultdict
3
+ from typing import (
4
+ Annotated,
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ List,
9
+ Optional,
10
+ Set,
11
+ Type,
12
+ Union,
13
+ get_args,
14
+ get_origin,
15
+ )
16
+
17
+ import cloudpickle
18
+ import msgpack
19
+ from pydantic import BaseModel
20
+ from typing_extensions import get_args, get_origin
21
+
22
+ from .data_objects import IndexifyData, RouterOutput
23
+ from .graph_validation import validate_node, validate_route
24
+ from .indexify_functions import (
25
+ IndexifyFunction,
26
+ IndexifyFunctionWrapper,
27
+ IndexifyRouter,
28
+ )
29
+ from .object_serializer import CloudPickleSerializer, get_serializer
30
+
31
# Callable form of a router: receives the upstream IndexifyData and returns
# the functions to route to, or None.
_RouterCallable = Callable[[IndexifyData], Optional[List[IndexifyFunction]]]
RouterFn = Annotated[_RouterCallable, "RouterFn"]

# A graph node is either a wrapped Indexify function or a router callable.
_GraphNodeUnion = Union[IndexifyFunctionWrapper, RouterFn]
GraphNode = Annotated[_GraphNodeUnion, "GraphNode"]
35
+
36
+
37
def is_pydantic_model_from_annotation(type_annotation):
    """Return True if ``type_annotation`` refers to a pydantic ``BaseModel``.

    Three annotation shapes are handled:

    * a string (forward reference): it cannot be resolved without the
      defining module's namespace, so it is conservatively reported as
      ``False``;
    * a parameterised generic such as ``List[File]`` or ``Optional[File]``:
      only its first type argument is inspected (recursively);
    * a plain class, which is checked with ``issubclass``.

    Anything else (``None``, ``ForwardRef``, arbitrary objects) yields
    ``False``.
    """
    if isinstance(type_annotation, str):
        # Do not try to parse a class name out of the string: plain names
        # such as "File" contain no quotes, and the name could not be
        # resolved here anyway.
        return False

    if get_origin(type_annotation) is not None:
        args = get_args(type_annotation)
        if args:
            return is_pydantic_model_from_annotation(args[0])

    if isinstance(type_annotation, type):
        return issubclass(type_annotation, BaseModel)

    return False
62
+
63
+
64
class FunctionMetadata(BaseModel):
    """Serializable description of a compute-function node in a graph."""

    # Name under which the function is registered in the graph.
    name: str
    # ``__name__`` of the decorated Python function.
    fn_name: str
    description: str
    # True when the function declares an accumulator type.
    reducer: bool = False
    # Name of the image (``Image._image_name``) the function runs in.
    image_name: str
    # Payload encoding understood by ``get_serializer``.
    payload_encoder: str = "cloudpickle"
71
+
72
+
73
class RouterMetadata(BaseModel):
    """Serializable description of a dynamic-router node in a graph."""

    name: str
    description: str
    # Name of the router node itself.
    source_fn: str
    # Names of the nodes this router may send its input to.
    target_fns: List[str]
    # Name of the image (``Image._image_name``) the router runs in.
    image_name: str
    # Payload encoding understood by ``get_serializer``.
    payload_encoder: str = "cloudpickle"
80
+
81
+
82
class NodeMetadata(BaseModel):
    """Metadata for one graph node: either a router or a compute function.

    ``Graph.definition`` populates exactly one of the two fields.
    """

    dynamic_router: Optional[RouterMetadata] = None
    compute_fn: Optional[FunctionMetadata] = None
85
+
86
+
87
class ComputeGraphMetadata(BaseModel):
    """Serializable description of a whole compute graph."""

    name: str
    description: str
    start_node: NodeMetadata
    # Node metadata keyed by node name.
    nodes: Dict[str, NodeMetadata]
    # Static edges: source node name -> list of target node names.
    edges: Dict[str, List[str]]
    accumulator_zero_values: Dict[str, bytes] = {}

    def get_input_payload_serializer(self):
        """Return the serializer matching the start node's payload encoder.

        Assumes the start node is a compute function (``compute_fn`` set).
        """
        start_fn = self.start_node.compute_fn
        encoder_name = start_fn.payload_encoder
        return get_serializer(encoder_name)
97
+
98
+
99
class Graph:
    """A compute graph built from Indexify function and router classes.

    ``nodes`` maps node names to their classes, ``edges`` maps a node name to
    the names of its static successors, and ``routers`` maps a router's name
    to the names of the nodes it may route to. ``accumulator_zero_values``
    holds, for every node that declares an accumulator, the ``model_dump()``
    of a default-constructed accumulator.
    """

    def __init__(
        self, name: str, start_node: IndexifyFunction, description: Optional[str] = None
    ):
        self.name = name
        self.description = description
        self.nodes: Dict[str, Union[IndexifyFunction, IndexifyRouter]] = {}
        self.routers: Dict[str, List[str]] = defaultdict(list)
        self.edges: Dict[str, List[str]] = defaultdict(list)
        self.accumulator_zero_values: Dict[str, Any] = {}

        self.add_node(start_node)
        self._start_node: str = start_node.name

    def get_function(self, name: str) -> IndexifyFunctionWrapper:
        """Return a wrapper around a fresh instance of node ``name``.

        Raises:
            ValueError: if the graph has no node called ``name``.
        """
        if name not in self.nodes:
            raise ValueError(f"Function {name} not found in graph")
        return IndexifyFunctionWrapper(self.nodes[name])

    def get_accumulators(self) -> Dict[str, Any]:
        """Return the accumulator zero values keyed by node name."""
        return self.accumulator_zero_values

    def deserialize_fn_output(self, name: str, output: IndexifyData) -> Any:
        """Decode ``output.payload`` with node ``name``'s payload encoder."""
        serializer = get_serializer(self.nodes[name].payload_encoder)
        return serializer.deserialize(output.payload)

    def add_node(
        self, indexify_fn: Union[Type[IndexifyFunction], Type[IndexifyRouter]]
    ) -> "Graph":
        """Validate and register ``indexify_fn``; a no-op if its name is taken.

        Returns ``self`` so calls can be chained.
        """
        validate_node(indexify_fn=indexify_fn)

        if indexify_fn.name in self.nodes:
            return self

        if issubclass(indexify_fn, IndexifyFunction) and indexify_fn.accumulate:
            # Remember the accumulator's initial state so reducers can start
            # from it when no accumulated value is supplied.
            self.accumulator_zero_values[
                indexify_fn.name
            ] = indexify_fn.accumulate().model_dump()

        self.nodes[indexify_fn.name] = indexify_fn
        return self

    def route(
        self, from_node: Type[IndexifyRouter], to_nodes: List[Type[IndexifyFunction]]
    ) -> "Graph":
        """Register ``from_node`` as a router that may forward to ``to_nodes``."""
        validate_route(from_node=from_node, to_nodes=to_nodes)

        print(
            f"Adding router {from_node.name} to nodes {[node.name for node in to_nodes]}"
        )
        self.add_node(from_node)
        for node in to_nodes:
            self.add_node(node)
            self.routers[from_node.name].append(node.name)
        return self

    def serialize(self):
        """Return the whole graph pickled with cloudpickle."""
        return cloudpickle.dumps(self)

    @staticmethod
    def deserialize(graph: bytes) -> "Graph":
        """Rebuild a graph from ``serialize()`` output.

        Unpickling can execute arbitrary code: only load trusted data.
        """
        return cloudpickle.loads(graph)

    @staticmethod
    def from_path(path: str) -> "Graph":
        """Load a cloudpickled graph from the file at ``path``.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with open(path, "rb") as f:
            return cloudpickle.load(f)

    def add_edge(
        self,
        from_node: Type[IndexifyFunction],
        to_node: Union[Type[IndexifyFunction], RouterFn],
    ) -> "Graph":
        """Add a single static edge ``from_node -> to_node``."""
        self.add_edges(from_node, [to_node])
        return self

    def invoke_fn_ser(
        self, name: str, input: IndexifyData, acc: Optional[Any] = None
    ) -> List[IndexifyData]:
        """Run compute function ``name`` on serialized ``input``.

        Args:
            name: Node name of the function to run.
            input: Serialized input, decoded per the node's payload encoder.
            acc: Serialized accumulator (an ``IndexifyData``). When omitted
                for a function that declares an accumulator, the recorded
                zero value is used instead.

        Returns:
            One ``IndexifyData`` per output, encoded with the node's encoder.
        """
        fn_wrapper = self.get_function(name)
        input = self.deserialize_input(name, input)
        function = fn_wrapper.indexify_function
        serializer = get_serializer(function.payload_encoder)
        accumulator_type = function.accumulate
        if acc is not None:
            acc = accumulator_type.model_validate(serializer.deserialize(acc.payload))
        elif accumulator_type is not None:
            acc = accumulator_type.model_validate(self.accumulator_zero_values[name])
        outputs: List[Any] = fn_wrapper.run_fn(input, acc=acc)
        return [
            IndexifyData(payload=serializer.serialize(output)) for output in outputs
        ]

    def invoke_router(self, name: str, input: IndexifyData) -> Optional[RouterOutput]:
        """Run router ``name`` on serialized ``input``; return chosen edges."""
        fn_wrapper = self.get_function(name)
        input = self.deserialize_input(name, input)
        return RouterOutput(edges=fn_wrapper.run_router(input))

    def deserialize_input(self, compute_fn: str, indexify_data: IndexifyData) -> Any:
        """Decode ``indexify_data`` into the input for node ``compute_fn``.

        cloudpickle nodes get the unpickled object. For msgpack nodes the
        payload is unpacked and, when ``run``'s single annotated
        non-accumulator parameter is a pydantic model, validated into it.

        Raises:
            ValueError: if the node is unknown, or ``run`` does not have
                exactly one annotated non-accumulator parameter.
        """
        node = self.nodes.get(compute_fn)
        if node is None:
            raise ValueError(f"Compute function {compute_fn} not found in graph")
        if node.payload_encoder == "cloudpickle":
            return CloudPickleSerializer.deserialize(indexify_data.payload)

        payload = msgpack.unpackb(indexify_data.payload)
        signature = inspect.signature(node.run)
        # The parameter annotated with the accumulator type receives the
        # accumulator, not the input, so it is excluded here.
        accumulator_type = getattr(node, "accumulate", None)
        arg_types = {
            param_name: param.annotation
            for param_name, param in signature.parameters.items()
            if param.annotation != inspect.Parameter.empty
            and param.annotation != accumulator_type
        }
        if len(arg_types) > 1:
            raise ValueError(
                f"Compute function {node} has multiple arguments, but only one is supported"
            )
        elif len(arg_types) == 0:
            raise ValueError(f"Compute function {node} has no arguments")
        arg_name, arg_type = next(iter(arg_types.items()))
        if arg_type is None:
            raise ValueError(f"Argument {arg_name} has no type annotation")
        if is_pydantic_model_from_annotation(arg_type):
            # Unwrap a single-key mapping whose value is itself a mapping
            # (presumably ``{arg_name: {...fields}}`` -- confirm with the
            # client-side encoder). Non-dict payloads are left untouched.
            if isinstance(payload, dict) and len(payload) == 1:
                (inner,) = payload.values()
                if isinstance(inner, dict):
                    payload = inner
            return arg_type.model_validate(payload)
        return payload

    def add_edges(
        self,
        from_node: Union[Type[IndexifyFunction], Type[IndexifyRouter]],
        to_node: List[Union[Type[IndexifyFunction], Type[IndexifyRouter]]],
    ) -> "Graph":
        """Add static edges from ``from_node`` to every node in ``to_node``."""
        self.add_node(from_node)
        from_node_name = from_node.name
        for node in to_node:
            self.add_node(node)
            self.edges[from_node_name].append(node.name)
        return self

    def definition(self) -> ComputeGraphMetadata:
        """Build the serializable ``ComputeGraphMetadata`` for this graph.

        Nodes with routing targets are described as dynamic routers; all
        others as compute functions. Each node's ``payload_encoder`` is
        carried into its metadata.
        """
        start_fn = self.nodes[self._start_node]
        start_node = FunctionMetadata(
            name=start_fn.name,
            fn_name=start_fn.fn_name,
            description=start_fn.description,
            reducer=start_fn.accumulate is not None,
            image_name=start_fn.image._image_name,
            # Forward the node's encoder; FunctionMetadata would otherwise
            # fall back to its "cloudpickle" default for every node.
            payload_encoder=start_fn.payload_encoder,
        )
        metadata_edges = self.edges.copy()
        metadata_nodes = {}
        for node_name, node in self.nodes.items():
            if node_name in self.routers:
                metadata_nodes[node_name] = NodeMetadata(
                    dynamic_router=RouterMetadata(
                        name=node_name,
                        description=node.description or "",
                        source_fn=node_name,
                        target_fns=self.routers[node_name],
                        payload_encoder=node.payload_encoder,
                        image_name=node.image._image_name,
                    )
                )
            else:
                metadata_nodes[node_name] = NodeMetadata(
                    compute_fn=FunctionMetadata(
                        name=node_name,
                        fn_name=node.fn_name,
                        description=node.description,
                        reducer=node.accumulate is not None,
                        image_name=node.image._image_name,
                        payload_encoder=node.payload_encoder,
                    )
                )
        return ComputeGraphMetadata(
            name=self.name,
            description=self.description or "",
            start_node=NodeMetadata(compute_fn=start_node),
            nodes=metadata_nodes,
            edges=metadata_edges,
        )
@@ -0,0 +1,66 @@
1
+ import inspect
2
+ import re
3
+ from typing import List, Type, Union
4
+
5
+ from .indexify_functions import IndexifyFunction, IndexifyRouter
6
+
7
+
8
def validate_node(indexify_fn: "Union[Type[IndexifyFunction], Type[IndexifyRouter]]"):
    """Check that ``indexify_fn`` can be added to a graph as a node.

    The node must be an ``IndexifyFunction`` or ``IndexifyRouter`` subclass
    whose ``run`` method annotates every parameter except ``self``, and also
    annotates its return type.

    Raises:
        Exception: if any of these requirements is not met.
    """
    # Plain functions, instances and other non-class objects are rejected
    # here: issubclass() would otherwise raise an unrelated TypeError.
    if not inspect.isclass(indexify_fn):
        raise Exception(
            f"Unable to add node of type `{type(indexify_fn)}`. "
            f"Required, `IndexifyFunction` or `IndexifyRouter`"
        )
    if not issubclass(indexify_fn, (IndexifyFunction, IndexifyRouter)):
        raise Exception(
            f"Unable to add node of type `{indexify_fn.__name__}`. "
            f"Required, `IndexifyFunction` or `IndexifyRouter`"
        )

    signature = inspect.signature(indexify_fn.run)

    for param in signature.parameters.values():
        if param.name == "self":
            continue
        if param.annotation == inspect.Parameter.empty:
            raise Exception(
                f"Input param {param.name} in {indexify_fn.name} has empty"
                f" type annotation"
            )

    if signature.return_annotation == inspect.Signature.empty:
        raise Exception(f"Function {indexify_fn.name} has empty return type annotation")
36
+
37
+
38
def validate_route(
    from_node: "Type[IndexifyRouter]", to_nodes: "List[Type[IndexifyFunction]]"
):
    """Check that every node a router can return is listed in ``to_nodes``.

    The router's ``run`` return annotation must be either a union of node
    classes (``Union[A, B]``, or ``A | B`` on Python 3.10+), or a ``List[...]``
    whose element type is such a union or a single node class. Union members
    without a ``name`` attribute (e.g. ``None`` from ``Optional``) are
    ignored.

    Raises:
        Exception: if the return annotation is missing or has another shape,
            or if it mentions a node that is not in ``to_nodes``.
    """
    import types
    from typing import get_args, get_origin

    signature = inspect.signature(from_node.run)

    if signature.return_annotation == inspect.Signature.empty:
        raise Exception(f"Function {from_node.name} has empty return type annotation")

    return_annotation = signature.return_annotation
    # ``X | Y`` unions report ``types.UnionType`` as their origin (3.10+).
    union_origins = (Union, getattr(types, "UnionType", Union))
    origin = get_origin(return_annotation)
    list_args = get_args(return_annotation) if origin is list else ()

    if origin in union_origins:
        route_targets = get_args(return_annotation)
    elif list_args:
        element_type = list_args[0]
        if get_origin(element_type) in union_origins:
            route_targets = get_args(element_type)
        else:
            route_targets = (element_type,)
    else:
        raise Exception(f"Return type of {from_node.name} is not a Union")

    target_names = [node.name for node in to_nodes]
    for target in route_targets:
        if hasattr(target, "name") and target not in to_nodes:
            raise Exception(
                f"Unable to find {target.name} in to_nodes {target_names}"
            )
66
+
@@ -0,0 +1,34 @@
1
class Image:
    """Fluent builder describing the image a function runs in.

    Every setter returns the same ``Image`` so calls can be chained, e.g.
    ``Image().name("x").base_image("python:3.10").run("pip install foo")``.
    """

    def __init__(self):
        self._image_name, self._tag, self._base_image = None, "latest", None
        # Commands recorded by ``run``, in insertion order.
        self._run_strs = []

    def _assign(self, attribute, value):
        # Shared body of the simple setters: store and return self.
        setattr(self, attribute, value)
        return self

    def name(self, image_name):
        """Set the image name."""
        return self._assign("_image_name", image_name)

    def tag(self, tag):
        """Set the image tag (defaults to "latest")."""
        return self._assign("_tag", tag)

    def base_image(self, base_image):
        """Set the base image to build from."""
        return self._assign("_base_image", base_image)

    def run(self, run_str):
        """Append a command to the list of commands for this image."""
        commands = self._run_strs
        commands.append(run_str)
        return self
26
+
27
+
28
# Image used by functions and routers that do not specify one: a slim
# Python 3.10 base with the indexify package installed.
DEFAULT_IMAGE = Image()
DEFAULT_IMAGE.name("indexify-executor-default")
DEFAULT_IMAGE.base_image("python:3.10.15-slim-bookworm")
DEFAULT_IMAGE.tag("latest")
DEFAULT_IMAGE.run("pip install indexify")
@@ -0,0 +1,188 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import update_wrapper
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ List,
8
+ Optional,
9
+ Type,
10
+ Union,
11
+ get_args,
12
+ get_origin,
13
+ )
14
+
15
+ from pydantic import BaseModel
16
+ from typing_extensions import get_type_hints
17
+
18
+ from .data_objects import IndexifyData, RouterOutput
19
+ from .image import DEFAULT_IMAGE, Image
20
+
21
+
22
class EmbeddingIndexes(BaseModel):
    """Settings for an embedding index (usage not visible in this module).

    NOTE(review): field meanings below are inferred from names; confirm.
    """

    # Presumably the embedding vector dimension.
    dim: int
    # Presumably the similarity metric name; defaults to "cosine".
    distance: Optional[str] = "cosine"
    database_url: Optional[str] = None
+
27
+
28
class PlacementConstraints(BaseModel):
    """Constraints attached to a function or router via ``placement_constraints``.

    NOTE(review): how these are enforced is not visible in this module.
    """

    min_python_version: Optional[str] = "3.9"
    max_python_version: Optional[str] = None
    platform: Optional[str] = None
    image_name: Optional[str] = None
+
34
+
35
class IndexifyFunction(ABC):
    """Base class for compute functions used as graph nodes.

    Subclasses implement ``run``; the class attributes describe the node.
    """

    name: str = ""
    description: str = ""
    image: Optional[Image] = DEFAULT_IMAGE
    placement_constraints: List[PlacementConstraints] = []
    # Pydantic model type of the accumulator for reducer functions;
    # ``None`` for ordinary functions.
    accumulate: Optional[Type[Any]] = None
    # Payload encoding understood by the serializers ("cloudpickle"/"msgpack").
    payload_encoder: Optional[str] = "cloudpickle"

    @abstractmethod
    def run(self, *args, **kwargs) -> Union[List[Any], Any]:
        ...

    def partial(self, **kwargs) -> Callable:
        """Return ``self.run`` with ``kwargs`` pre-bound as keyword arguments."""
        import functools

        bound_run = functools.partial(self.run, **kwargs)
        return bound_run
51
+
52
+
53
class IndexifyRouter(ABC):
    """Base class for routers used as graph nodes.

    ``run`` returns the function class (or list of classes) the input should
    be forwarded to, or ``None``.
    """

    name: str = ""
    description: str = ""
    image: Optional[Image] = DEFAULT_IMAGE
    placement_constraints: List[PlacementConstraints] = []
    # Payload encoding understood by the serializers ("cloudpickle"/"msgpack").
    payload_encoder: Optional[str] = "cloudpickle"

    @abstractmethod
    def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]:
        ...
63
+
64
+
65
def indexify_router(
    name: Optional[str] = None,
    description: Optional[str] = "",
    image: Optional[Image] = DEFAULT_IMAGE,
    placement_constraints: List[PlacementConstraints] = [],
    payload_encoder: Optional[str] = "cloudpickle",
):
    """Decorator that turns a plain function into an ``IndexifyRouter`` class.

    Args:
        name: Node name; defaults to the decorated function's ``__name__``.
        description: Node description; defaults to the function's docstring
            with surrounding whitespace stripped and newlines removed.
        image: Image the router runs in.
        placement_constraints: Placement constraints for the router (copied).
        payload_encoder: Payload encoding ("cloudpickle" or "msgpack").

    Returns:
        A decorator taking the routing function and returning a new
        ``IndexifyRouter`` subclass whose ``run`` calls it.
    """

    def construct(fn):
        # Read the decorator arguments directly: ``locals()`` inside this
        # nested function only contains the closure variables it references,
        # which would silently drop ``name``, ``description`` and
        # ``placement_constraints``.
        attributes = {
            "name": name if name else fn.__name__,
            "fn_name": fn.__name__,
            "description": (
                description
                if description
                else (fn.__doc__ or "").strip().replace("\n", "")
            ),
            "image": image,
            # Copy so routers never share (and mutate) the default list.
            "placement_constraints": list(placement_constraints or []),
            "payload_encoder": payload_encoder,
        }

        class IndexifyRo(IndexifyRouter):
            def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]:
                return fn(*args, **kwargs)

            # Copies fn's metadata and sets __wrapped__, so
            # inspect.signature(run) reports fn's parameters and annotations.
            update_wrapper(run, fn)

        for key, value in attributes.items():
            setattr(IndexifyRo, key, value)
        return IndexifyRo

    return construct
97
+
98
+
99
def indexify_function(
    name: Optional[str] = None,
    description: Optional[str] = "",
    image: Optional[Image] = DEFAULT_IMAGE,
    accumulate: Optional[Type[BaseModel]] = None,
    payload_encoder: Optional[str] = "cloudpickle",
    placement_constraints: List[PlacementConstraints] = [],
):
    """Decorator that turns a plain function into an ``IndexifyFunction`` class.

    Args:
        name: Node name; defaults to the decorated function's ``__name__``.
        description: Node description; defaults to the function's docstring
            with surrounding whitespace stripped and newlines removed.
        image: Image the function runs in.
        accumulate: Pydantic model type used as the accumulator, making the
            function a reducer; ``None`` for ordinary functions.
        payload_encoder: Payload encoding ("cloudpickle" or "msgpack").
        placement_constraints: Placement constraints for the function (copied).

    Returns:
        A decorator taking the function and returning a new
        ``IndexifyFunction`` subclass whose ``run`` calls it.
    """

    def construct(fn):
        # Read the decorator arguments directly: ``locals()`` inside this
        # nested function only contains the closure variables it references,
        # which would silently drop ``name``, ``description`` and
        # ``placement_constraints``.
        attributes = {
            "name": name if name else fn.__name__,
            "fn_name": fn.__name__,
            "description": (
                description
                if description
                else (fn.__doc__ or "").strip().replace("\n", "")
            ),
            "image": image,
            "accumulate": accumulate,
            "payload_encoder": payload_encoder,
            # Copy so functions never share (and mutate) the default list.
            "placement_constraints": list(placement_constraints or []),
        }

        class IndexifyFn(IndexifyFunction):
            def run(self, *args, **kwargs) -> Union[List[Any], Any]:
                return fn(*args, **kwargs)

            # Copies fn's metadata and sets __wrapped__, so
            # inspect.signature(run) reports fn's parameters and annotations.
            update_wrapper(run, fn)

        for key, value in attributes.items():
            setattr(IndexifyFn, key, value)
        return IndexifyFn

    return construct
133
+
134
+
135
class IndexifyFunctionWrapper:
    """Instantiate an Indexify function/router class and call its ``run``.

    A ``dict`` input is spread into ``run`` as keyword arguments; any other
    input is passed as a single positional argument.
    """

    def __init__(self, indexify_function: Union[IndexifyFunction, IndexifyRouter]):
        # The argument is the node class; an instance is created here.
        self.indexify_function: Union[
            IndexifyFunction, IndexifyRouter
        ] = indexify_function()

    def get_output_model(self) -> Any:
        """Return the item type declared by ``run``'s return annotation.

        ``List[T]`` yields ``T`` and ``Optional[T]`` yields ``T``; any other
        annotation is returned as-is (``Any`` when there is none).

        Raises:
            TypeError: if the wrapped object is not an ``IndexifyFunction``.
        """
        if not isinstance(self.indexify_function, IndexifyFunction):
            raise TypeError("Input must be an instance of IndexifyFunction")

        extract_method = self.indexify_function.run
        type_hints = get_type_hints(extract_method)
        return_type = type_hints.get("return", Any)
        if get_origin(return_type) is list:
            return_type = get_args(return_type)[0]
        elif get_origin(return_type) is Union:
            inner_types = get_args(return_type)
            if len(inner_types) == 2 and type(None) in inner_types:
                return_type = (
                    inner_types[0] if inner_types[1] is type(None) else inner_types[1]
                )
        return return_type

    @staticmethod
    def _split_input(input: Union[Dict, Type[BaseModel]]):
        """Return ``(args, kwargs)`` for calling ``run`` with ``input``."""
        if isinstance(input, dict):
            return [], input
        return [input], {}

    def run_router(self, input: Union[Dict, Type[BaseModel]]) -> List[str]:
        """Call the router and return the names of the chosen nodes.

        A single returned node yields a one-element list; ``None`` or an
        empty list yields ``[]``.
        """
        args, kwargs = self._split_input(input)
        extracted_data = self.indexify_function.run(*args, **kwargs)
        if extracted_data is None:
            return []
        if not isinstance(extracted_data, list):
            return [extracted_data.name]
        return [fn.name for fn in extracted_data]

    def run_fn(
        self, input: Union[Dict, Type[BaseModel]], acc: Type[Any] = None
    ) -> List[IndexifyData]:
        """Call the function and always return its output(s) as a list.

        When ``acc`` is given it is passed as the first positional argument,
        ahead of a non-dict input.
        """
        args, kwargs = self._split_input(input)
        if acc is not None:
            args.insert(0, acc)

        extracted_data = self.indexify_function.run(*args, **kwargs)

        return extracted_data if isinstance(extracted_data, list) else [extracted_data]
@@ -0,0 +1,46 @@
1
+ import os
2
+ from hashlib import sha256
3
+ from typing import List, Optional
4
+
5
+
6
class CacheAwareFunctionWrapper:
    """On-disk cache of function outputs.

    Entries live under ``<cache_dir>/<graph>/<node_name>/<sha256(input)>/``
    and each output item is stored in its own ``<index>.cbor`` file.
    """

    def __init__(self, cache_dir: str):
        self._cache_dir = cache_dir
        if not os.path.exists(cache_dir):
            # exist_ok guards against a concurrent creator winning the race.
            os.makedirs(cache_dir, exist_ok=True)

    def _get_key(self, input: bytes) -> str:
        """Return the hex SHA-256 digest of ``input`` (the entry key)."""
        return sha256(input).hexdigest()

    @staticmethod
    def _output_sort_key(file_name: str):
        # Order "<index>.cbor" files numerically ("2" before "10"); any other
        # file sorts after them, by name.
        stem = os.path.splitext(file_name)[0]
        if stem.isdecimal():
            return (0, int(stem), file_name)
        return (1, 0, file_name)

    def get(self, graph: str, node_name: str, input: bytes) -> Optional[List[bytes]]:
        """Look up the cached output stored for ``input``.

        Returns:
            ``None`` when nothing is cached for ``input``; ``[]`` when the
            entry directory exists but is empty; otherwise the raw bytes of
            the lowest-indexed output file (the item ``set`` wrote at
            index 0).

        NOTE(review): the annotation says ``List[bytes]`` but a hit returns a
        single ``bytes`` object. Callers are not visible here; confirm which
        shape they expect before changing either.
        """
        key = self._get_key(input)
        dir_path = os.path.join(self._cache_dir, graph, node_name, key)
        if not os.path.exists(dir_path):
            return None

        # os.listdir() order is arbitrary; sort so the file read is
        # deterministic.
        files = sorted(os.listdir(dir_path), key=self._output_sort_key)
        if not files:
            return []
        with open(os.path.join(dir_path, files[0]), "rb") as f:
            return f.read()

    def set(
        self,
        graph: str,
        node_name: str,
        input: bytes,
        output: List[bytes],
    ):
        """Store ``output`` for ``input``, one ``<i>.cbor`` file per item."""
        key = self._get_key(input)
        dir_path = os.path.join(self._cache_dir, graph, node_name, key)
        os.makedirs(dir_path, exist_ok=True)

        for i, output_item in enumerate(output):
            file_path = os.path.join(dir_path, f"{i}.cbor")
            with open(file_path, "wb") as f:
                f.write(output_item)
@@ -0,0 +1,60 @@
1
+ from typing import Any, List
2
+
3
+ import cloudpickle
4
+ import msgpack
5
+ from pydantic import BaseModel
6
+
7
+ from .data_objects import IndexifyData
8
+
9
+
10
def get_serializer(serializer_type: str) -> Any:
    """Return a new serializer for ``serializer_type``.

    Supported values are "cloudpickle" and "msgpack".

    Raises:
        ValueError: for any other value.
    """
    if serializer_type == "msgpack":
        return MsgPackSerializer()
    if serializer_type == "cloudpickle":
        return CloudPickleSerializer()
    raise ValueError(f"Unknown serializer type: {serializer_type}")
17
+
18
+
19
class CloudPickleSerializer:
    """Serializer backed by cloudpickle; lists are pickled as a whole.

    Deserializing runs the unpickler, which can execute arbitrary code:
    only use it on trusted data.
    """

    @staticmethod
    def serialize(data: Any) -> bytes:
        """Pickle ``data`` with cloudpickle."""
        encoded = cloudpickle.dumps(data)
        return encoded

    @staticmethod
    def deserialize(data: bytes) -> Any:
        """Unpickle ``data`` produced by ``serialize``."""
        decoded = cloudpickle.loads(data)
        return decoded

    @staticmethod
    def serialize_list(data: List[Any]) -> bytes:
        """Pickle a whole list in one payload (same as ``serialize``)."""
        return CloudPickleSerializer.serialize(data)

    @staticmethod
    def deserialize_list(data: bytes) -> List[Any]:
        """Unpickle a list produced by ``serialize_list``."""
        return CloudPickleSerializer.deserialize(data)
35
+
36
+
37
class MsgPackSerializer:
    """Serializer backed by msgpack."""

    @staticmethod
    def serialize(data: Any) -> bytes:
        """Pack ``data``; pydantic model instances are packed via ``model_dump()``."""
        # Only instances can be dumped: ``model_dump`` is an instance method,
        # so calling it on a model *class* always failed. A class now goes
        # straight to msgpack, which raises TypeError for unsupported types.
        if isinstance(data, BaseModel):
            data = data.model_dump()
        return msgpack.packb(data)

    @staticmethod
    def deserialize(data: bytes) -> IndexifyData:
        """Unpack ``data`` and build an ``IndexifyData`` from the mapping.

        NOTE(review): this is not the inverse of ``serialize``, which packs
        arbitrary values; the unpacked mapping must match ``IndexifyData``'s
        fields. Confirm with callers.
        """
        cached_output = msgpack.unpackb(data)
        return IndexifyData(**cached_output)

    @staticmethod
    def serialize_list(data: List[IndexifyData]) -> bytes:
        """Pack a list of ``IndexifyData`` as a list of their dumps."""
        return msgpack.packb([item.model_dump() for item in data])

    @staticmethod
    def deserialize_list(data: bytes) -> List[IndexifyData]:
        """Inverse of ``serialize_list``."""
        return [IndexifyData(**item) for item in msgpack.unpackb(data)]