indexify 0.0.39__py3-none-any.whl → 0.0.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indexify/__init__.py +2 -0
- indexify/data_loaders/url_loader.py +0 -1
- indexify/extractor_sdk/data.py +1 -1
- indexify/extractor_sdk/extractor.py +7 -27
- indexify/extractors/__init__.py +0 -0
- indexify/extractors/embedding.py +53 -0
- indexify/graph.py +122 -16
- indexify/local_runner.py +2 -2
- {indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/METADATA +2 -1
- {indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/RECORD +12 -11
- indexify/run_graph.py +0 -122
- {indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/LICENSE.txt +0 -0
- {indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/WHEEL +0 -0
indexify/__init__.py
CHANGED
```diff
@@ -9,9 +9,11 @@ from .client import (
 from . import extractor_sdk
 from .settings import DEFAULT_SERVICE_URL
 from . import data_loaders
+from .graph import Graph
 
 __all__ = [
     "data_loaders",
+    "Graph",
     "Document",
     "extractor_sdk",
     "IndexifyClient",
```
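With this change, `Graph` is re-exported from the package root. A minimal sketch (assuming indexify 0.0.40 is installed):

```python
# Both import paths now resolve to the same class.
from indexify import Graph
from indexify.graph import Graph as GraphDirect

assert Graph is GraphDirect
```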
indexify/extractor_sdk/data.py
CHANGED
```diff
@@ -10,7 +10,8 @@ import requests
 
 class EmbeddingSchema(BaseModel):
     dim: int
-    distance: str = "cosine"
+    distance: Optional[str] = "cosine"
+    database_url: Optional[str] = None
 
 class ExtractorMetadata(BaseModel):
     name: str
```
indexify/extractor_sdk/extractor.py
CHANGED
```diff
@@ -40,8 +41,8 @@ class Extractor(ABC):
 
     input_mime_types = ["text/plain"]
 
-
-
+    embedding_indexes: Dict[str, EmbeddingSchema] = {}
+
     @abstractmethod
     def extract(
         self, input: Type[BaseModel], params: Type[BaseModel] = None
@@ -55,31 +56,9 @@ class Extractor(ABC):
         pass
 
     @classmethod
-    @abstractmethod
     def sample_input(cls) -> Tuple[Content, Type[BaseModel]]:
         pass
 
-    def describe(self) -> ExtractorMetadata:
-        embedding_schemas = {}
-        try:
-            embedding_schemas = self.embedding_schemas
-        except NotImplementedError:
-            pass
-
-        json_schema = (
-            self._param_cls.model_json_schema() if self._param_cls is not None else None
-        )
-        return ExtractorMetadata(
-            name=self.name,
-            version=self.version,
-            description=self.description,
-            system_dependencies=self.system_dependencies,
-            python_dependencies=self.python_dependencies,
-            input_mime_types=self.input_mime_types,
-            embedding_schemas=embedding_schemas,
-            input_params=json.dumps(json_schema),
-        )
-
     def _download_file(self, url, filename):
         if os.path.exists(filename):
             # file exists skip
@@ -190,7 +169,7 @@ def extractor(
     python_dependencies: Optional[List[str]] = None,
     system_dependencies: Optional[List[str]] = None,
     input_mime_types: Optional[List[str]] = None,
-
+    embedding_indexes: Optional[Dict[str, EmbeddingSchema]] = None,
     sample_content: Optional[Callable] = None,
 ):
     args = locals()
@@ -198,7 +177,7 @@ def extractor(
 
     def construct(fn):
         def wrapper():
-
+            description = fn.__doc__ or args.get("description", "")
 
             if not args.get("name"):
                 args[
@@ -220,6 +199,7 @@ def extractor(
 
             for key, val in args.items():
                 setattr(DecoratedFn, key, val)
+            DecoratedFn.description = description
 
             return DecoratedFn
 
```
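Taken together, extractors now declare their indexes up front instead of implementing `describe()`, and a function docstring doubles as the description. A sketch of the decorator's new parameters (the node body and index name are illustrative):

```python
from typing import List

from indexify.extractor_sdk import Content, extractor
from indexify.extractor_sdk.data import EmbeddingSchema

@extractor(embedding_indexes={"chunks": EmbeddingSchema(dim=384)})
def chunk_text(content: Content) -> List[Content]:
    """Split a document into chunks."""  # becomes the description when none is passed
    return [content]
```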
indexify/extractors/__init__.py
ADDED
File without changes
indexify/extractors/embedding.py
ADDED
```diff
@@ -0,0 +1,53 @@
+from typing import List
+
+from indexify.extractor_sdk.data import Feature
+import torch
+import torch.nn.functional as F
+from transformers import AutoModel, AutoTokenizer
+from indexify.extractor_sdk.extractor import Extractor, Feature
+
+class SentenceTransformersEmbedding:
+    def __init__(self, model_name) -> None:
+        self._model_name = model_name
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            f"sentence-transformers/{model_name}"
+        )
+        self._model = AutoModel.from_pretrained(
+            f"sentence-transformers/{model_name}", torchscript=True
+        )
+        self._model.eval()
+
+    def embed_batch(self, inputs: List[str]) -> List[List[float]]:
+        result = self._embed(inputs)
+        return result.tolist()
+
+    def embed(self, query: str) -> List[float]:
+        result = self._embed([query])
+        return result[0].tolist()
+
+    def _embed(self, inputs: List[str]) -> torch.Tensor:
+        encoded_input = self._tokenizer(
+            inputs, padding=True, truncation=True, return_tensors="pt"
+        )
+        sentence_embeddings = self._model(**encoded_input)
+        return F.normalize(sentence_embeddings, p=2, dim=1)
+
+class BasicSentenceTransformerModels(Extractor):
+
+    def __init__(self, model: str):
+        super().__init__()
+        self.model = SentenceTransformersEmbedding(model)
+
+    def extract(self, input: str) -> List[Feature]:
+        embeddings = self.model.embed(input)
+        return [Feature.embedding(values=embeddings)]
+
+class BasicHFTransformerEmbeddingModels(Extractor):
+
+    def __init__(self, model: str):
+        super().__init__()
+        self._model = AutoModel.from_pretrained(model, trust_remote_code=True)
+
+    def extract(self, input: str) -> List[Feature]:
+        embeddings = self.model.embed_query(input)
+        return [Feature.embedding(values=embeddings)]
```
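The intended call pattern for the new module, as a minimal sketch (requires `torch` and `transformers` to be installed; `all-MiniLM-L6-v2` is just an example sentence-transformers checkpoint):

```python
from indexify.extractors.embedding import SentenceTransformersEmbedding

model = SentenceTransformersEmbedding("all-MiniLM-L6-v2")
vector = model.embed("hello world")              # single query -> List[float]
vectors = model.embed_batch(["hello", "world"])  # batch -> List[List[float]]
```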
indexify/graph.py
CHANGED
```diff
@@ -1,23 +1,129 @@
-
+import json
+import itertools
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Type, Union
 
-
+import cloudpickle
 from pydantic import BaseModel
 
-from .
-from .
+from .extractor_sdk import Content, extractor, Extractor
+from .runner import Runner
 
+@extractor(description="id function")
+def _id(content: Content) -> List[Content]:
+    return [content]
 
-def Graph(
-    name: str,
-    input: Type[BaseModel],
-    start_node: Union[extractor, Extractor],
-    run_local: bool,
-) -> RunGraph:
 
-
-
-    else:
-        raise NotImplementedError("Remote runner not supported yet")
+def load_graph(graph: bytes) -> 'Graph':
+    return cloudpickle.loads(graph)
 
-
-
+class Graph:
+    def __init__(self, name: str, input: Type[BaseModel], start_node: extractor, runner: Runner):
+        # TODO check for cycles
+        self.name = name
+
+        self.nodes: Dict[str, Union[extractor, Extractor]] = {}
+        self.params: Dict[str, Any] = {}
+
+        self.edges: Dict[str, List[(str, str)]] = defaultdict(list)
+
+        self.nodes["start"] = _id
+        self.nodes["end"] = _id
+
+        self._topo_counter = defaultdict(int)
+
+        self._start_node = None
+        self._input = input
+
+        self.runner = runner
+
+    def get_extractor(self, name: str) -> Extractor:
+        return self.nodes[name]
+
+    def _node(self, extractor: Extractor, params: Any = None) -> 'Graph':
+        name = extractor.name
+
+        # if you've already inserted a node just ignore the new insertion.
+        if name in self.nodes:
+            return
+
+        self.nodes[name] = extractor
+        self.params[name] = extractor.__dict__.get("params", None)
+
+        # assign each node a rank of 1 to init the graph
+        self._topo_counter[name] = 1
+
+        return self
+
+    def serialize(self):
+        return cloudpickle.dumps(self)
+
+    def add_edge(
+        self,
+        from_node: Type[Extractor],
+        to_node: Type[Extractor],
+        prefilter_predicates: Optional[str] = None,
+    ) -> 'Graph':
+
+        self._node(from_node)
+        self._node(to_node)
+
+        from_node_name = from_node.name
+        to_node_name = to_node.name
+
+        self.edges[from_node_name].append((to_node_name, prefilter_predicates))
+
+        self._topo_counter[to_node_name] += 1
+
+        return self
+
+    """
+    Connect nodes as a fan out from one `from_node` to multiple `to_nodes` and respective `prefilter_predicates`.
+    Note: The user has to match the sizes of the lists to make sure they line up otherwise a None is used as a default.
+    """
+
+    def steps(
+        self,
+        from_node: extractor,
+        to_nodes: List[extractor],
+        prefilter_predicates: List[str] = [],
+    ) -> 'Graph':
+        print(f"{to_nodes}, {prefilter_predicates}, {prefilter_predicates}")
+        for t_n, p in itertools.zip_longest(
+            to_nodes, prefilter_predicates, fillvalue=None
+        ):
+            self.step(from_node=from_node, to_node=t_n, prefilter_predicates=p)
+
+        return self
+
+    def add_param(self, node: extractor, params: Dict[str, Any]):
+        try:
+            # check if the params can be serialized since the server needs this
+            json.dumps(params)
+        except Exception:
+            raise Exception(f"For node {node.name}, cannot serialize params as json.")
+
+        self.params[node.name] = params
+
+    def run(self, wf_input, local):
+        self._assign_start_node()
+        self.runner.run(self, wf_input=wf_input)
+        pass
+
+    def clear_cache_for_node(self, node: Union[extractor, Extractor]):
+        if node.name not in self.nodes.keys():
+            raise Exception(f"Node with name {node.name} not found in graph")
+
+        self.runner.deleted_from_memo(node.name)
+
+    def clear_cache_for_all_nodes(self):
+        for node_name in self.nodes:
+            self.runner.deleted_from_memo(node_name=node_name)
+
+    def get_result(self, node: Union[extractor, Extractor]) -> Any:
+        return self.runner.results[node.name]
+
+    def _assign_start_node(self):
+        # this method should be called before a graph can be run
+        nodes = sorted(self._topo_counter.items(), key=lambda x: x[1])
+        self._start_node = nodes[0][0]
```
indexify/local_runner.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from indexify.extractor_sdk.extractor import extractor, Extractor
 from collections import defaultdict
 from typing import Any, Callable, Dict, Optional, Union
 
-from indexify.
+from indexify.graph import Graph
 from indexify.runner import Runner
 
 
@@ -27,7 +27,7 @@ class LocalRunner(Runner):
     # those bytes have to be a python type
 
     # _input needs to be serializable into python object (ie json for ex) and Feature
-    def _run(self, g:
+    def _run(self, g: Graph, _input: BaseData, node_name: str):
        print(f"---- Starting node {node_name}")
        print(f'node_name {node_name}')
 
```
{indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: indexify
-Version: 0.0.39
+Version: 0.0.40
 Summary: Python Client for Indexify
 Home-page: https://github.com/tensorlakeai/indexify
 License: Apache 2.0
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: cloudpickle (>=3,<4)
 Requires-Dist: httpx[http2] (>=0,<1)
 Requires-Dist: pydantic (>=2.8,<3.0)
 Requires-Dist: pyyaml (>=6,<7)
```
{indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/RECORD
CHANGED
```diff
@@ -1,23 +1,24 @@
-indexify/__init__.py,sha256=
+indexify/__init__.py,sha256=0kUYM2FAVfYp0ZCx_3uQD5HbmDLDdBvEBtwHZrGQKaA,541
 indexify/base_client.py,sha256=Db-BNYQ6yNmOIXPaQN8W5qjTYvfFvPzoxC9206YRc-U,2755
 indexify/client.py,sha256=FPCO2DN6RstKLasmNrPxRhzBXDgM14tbc3eDDxl8J_A,25998
 indexify/data_loaders/__init__.py,sha256=TmOJLgKC5gM7_1n7zxYiuza3fOilIiYYupxBGd31PfA,1339
 indexify/data_loaders/local_directory_loader.py,sha256=0X_FgLS5unisJSij8LICv1htp8IdW09LbTIJ2wvVJg4,1246
-indexify/data_loaders/url_loader.py,sha256=
+indexify/data_loaders/url_loader.py,sha256=1q-uxFHsf5g5u49omzXHfP_zrzMwj-eFs7_1ugdr58g,1531
 indexify/error.py,sha256=3umTeYb0ugtUyehV1ibfvaeACxAONPyWPc-1HRN4d1M,856
 indexify/exceptions.py,sha256=vjd5SPPNFIEW35GorSIodsqvm9RKHQm9kdp8t9gv-WM,111
 indexify/extraction_policy.py,sha256=awNDqwCz0tr4jTQmGf7s8_s6vcEuxMb0xynEl7b7iPI,2076
 indexify/extractor_sdk/__init__.py,sha256=T512UtvFPUXEXlnT9HHHLHPcEau1Acoac_ksByuo7jA,348
-indexify/extractor_sdk/data.py,sha256=
-indexify/extractor_sdk/extractor.py,sha256=
+indexify/extractor_sdk/data.py,sha256=DvNdq8w5XT4cyOR_wjWwyr32FdAfJ5297Hy89TqZcBI,2778
+indexify/extractor_sdk/extractor.py,sha256=D7QshIoYzZaeAJKQlYilSzUeLNpp2innE5RVtEoa06s,9820
 indexify/extractor_sdk/utils.py,sha256=_j8WflgOM0Qkf2NjhK2p1xXuwq4drLxO0mgKVPEHhlw,6594
-indexify/
-indexify/
-indexify/
+indexify/extractors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+indexify/extractors/embedding.py,sha256=LlE2Ti0AJULHqar9a7VbvEnMX8VJ1m88-vFE2n_55M0,1898
+indexify/graph.py,sha256=0pIGOBltNIk9HMfPf3iSwk_kTAzKJSbEGZCcr5PJBgg,3951
+indexify/local_runner.py,sha256=04nYTuPzco0yqzrqAqjbrxNDC_AxLuxJmD7-76VLdUQ,4336
 indexify/runner.py,sha256=M_3_GWYyPpb4lR5KFTpW8OAgp-fm9kYd_5xEqmiCBU4,637
 indexify/settings.py,sha256=LSaWZ0ADIVmUv6o6dHWRC3-Ry5uLbCw2sBSg1e_U7UM,99
 indexify/utils.py,sha256=rDN2lrsAs9noJEIjfx6ukmC2SAIyrlUt7QU-kaBjujM,125
-indexify-0.0.
-indexify-0.0.
-indexify-0.0.
-indexify-0.0.
+indexify-0.0.40.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+indexify-0.0.40.dist-info/METADATA,sha256=YxPEZNNIPhedRKwTmOT555lEVAbokNgO37qbEu_OYXE,1913
+indexify-0.0.40.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+indexify-0.0.40.dist-info/RECORD,,
```
indexify/run_graph.py
DELETED
```diff
@@ -1,122 +0,0 @@
-import json
-
-from .extractor_sdk import Content, extractor, Extractor
-
-from collections import defaultdict
-from typing import Any, Dict, List, Optional, Type, Union
-from pydantic import BaseModel
-
-import itertools
-
-from .runner import Runner
-
-@extractor(description="id function")
-def _id(content: Content) -> List[Content]:
-    return [content]
-
-
-class RunGraph:
-    def __init__(self, name: str, input: Type[BaseModel], start_node: extractor, runner: Runner):
-        # TODO check for cycles
-        self.name = name
-
-        self.nodes: Dict[str, Union[extractor, Extractor]] = {}
-        self.params: Dict[str, Any] = {}
-
-        self.edges: Dict[str, List[(str, str)]] = defaultdict(list)
-
-        self.nodes["start"] = _id
-        self.nodes["end"] = _id
-
-        self._topo_counter = defaultdict(int)
-
-        self._start_node = None
-        self._input = input
-
-        self.runner = runner
-
-    def _node(self, extractor: Union[extractor, Extractor], params: Any = None) -> 'RunGraph':
-        name = extractor.name
-
-        # if you've already inserted a node just ignore the new insertion.
-        if name in self.nodes:
-            return
-
-        self.nodes[name] = extractor
-        self.params[name] = extractor.__dict__.get("params", None)
-
-        # assign each node a rank of 1 to init the graph
-        self._topo_counter[name] = 1
-
-        return self
-
-    def add_edge(
-        self,
-        from_node: extractor,
-        to_node: extractor,
-        prefilter_predicates: Optional[str] = None,
-    ) -> 'RunGraph':
-
-        self._node(from_node)
-        self._node(to_node)
-
-        from_node_name = from_node.name
-        to_node_name = to_node.name
-
-        self.edges[from_node_name].append((to_node_name, prefilter_predicates))
-
-        self._topo_counter[to_node_name] += 1
-
-        return self
-
-    """
-    Connect nodes as a fan out from one `from_node` to multiple `to_nodes` and respective `prefilter_predicates`.
-    Note: The user has to match the sizes of the lists to make sure they line up otherwise a None is used as a default.
-    """
-
-    def steps(
-        self,
-        from_node: extractor,
-        to_nodes: List[extractor],
-        prefilter_predicates: List[str] = [],
-    ) -> 'RunGraph':
-        print(f"{to_nodes}, {prefilter_predicates}, {prefilter_predicates}")
-        for t_n, p in itertools.zip_longest(
-            to_nodes, prefilter_predicates, fillvalue=None
-        ):
-            self.step(from_node=from_node, to_node=t_n, prefilter_predicates=p)
-
-        return self
-
-    def add_param(self, node: extractor, params: Dict[str, Any]):
-        try:
-            # check if the params can be serialized since the server needs this
-            json.dumps(params)
-        except Exception:
-            raise Exception(f"For node {node.name}, cannot serialize params as json.")
-
-        self.params[node.name] = params
-
-    def run(self, wf_input, local):
-        self._assign_start_node()
-        # self.runner = LocalRunner()
-        self.runner.run(self, wf_input=wf_input)
-        pass
-
-    def clear_cache_for_node(self, node: Union[extractor, Extractor]):
-        if node.name not in self.nodes.keys():
-            raise Exception(f"Node with name {node.name} not found in graph")
-
-        self.runner.deleted_from_memo(node.name)
-
-    def clear_cache_for_all_nodes(self):
-        for node_name in self.nodes:
-            self.runner.deleted_from_memo(node_name=node_name)
-
-    def get_result(self, node: Union[extractor, Extractor]) -> Any:
-        return self.runner.results[node.name]
-
-    def _assign_start_node(self):
-        # this method should be called before a graph can be run
-        nodes = sorted(self._topo_counter.items(), key=lambda x: x[1])
-        self._start_node = nodes[0][0]
```
{indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/LICENSE.txt
File without changes
{indexify-0.0.39.dist-info → indexify-0.0.40.dist-info}/WHEEL
File without changes