nucliadb 6.4.0.post4127__py3-none-any.whl → 6.4.0.post4132__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/common/cluster/grpc_node_dummy.py +1 -18
- nucliadb/common/cluster/manager.py +26 -21
- nucliadb/common/cluster/rebalance.py +7 -7
- nucliadb/common/cluster/rollover.py +12 -5
- nucliadb/common/nidx.py +0 -44
- nucliadb/ingest/consumer/auditing.py +5 -5
- nucliadb/ingest/consumer/shard_creator.py +5 -4
- nucliadb/ingest/orm/entities.py +4 -5
- nucliadb/metrics_exporter.py +0 -19
- nucliadb/purge/orphan_shards.py +17 -14
- nucliadb/search/api/v1/knowledgebox.py +6 -14
- nucliadb/search/api/v1/resource/search.py +2 -5
- nucliadb/search/api/v1/search.py +2 -6
- nucliadb/search/api/v1/suggest.py +1 -2
- nucliadb/search/requesters/utils.py +14 -33
- nucliadb/search/search/find.py +2 -8
- nucliadb/search/search/shards.py +9 -25
- nucliadb/train/generator.py +9 -11
- nucliadb/train/generators/field_classifier.py +3 -5
- nucliadb/train/generators/field_streaming.py +3 -5
- nucliadb/train/generators/image_classifier.py +1 -4
- nucliadb/train/generators/paragraph_classifier.py +3 -5
- nucliadb/train/generators/paragraph_streaming.py +3 -5
- nucliadb/train/generators/question_answer_streaming.py +3 -5
- nucliadb/train/generators/sentence_classifier.py +3 -5
- nucliadb/train/generators/token_classifier.py +3 -5
- nucliadb/train/nodes.py +2 -4
- {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/METADATA +6 -6
- {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/RECORD +32 -33
- nucliadb/common/cluster/base.py +0 -146
- {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/WHEEL +0 -0
- {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/top_level.txt +0 -0
@@ -35,8 +35,6 @@ from nidx_protos.nodereader_pb2 import (
|
|
35
35
|
SuggestResponse,
|
36
36
|
)
|
37
37
|
|
38
|
-
from nucliadb.common.cluster import manager as cluster_manager
|
39
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
40
38
|
from nucliadb.common.cluster.exceptions import ShardsNotFound
|
41
39
|
from nucliadb.common.cluster.utils import get_shard_manager
|
42
40
|
from nucliadb.search import logger
|
@@ -78,7 +76,7 @@ async def node_query(
|
|
78
76
|
method: Method,
|
79
77
|
pb_query: SuggestRequest,
|
80
78
|
timeout: Optional[float] = None,
|
81
|
-
) -> tuple[list[SuggestResponse], bool, list[
|
79
|
+
) -> tuple[list[SuggestResponse], bool, list[str]]: ...
|
82
80
|
|
83
81
|
|
84
82
|
@overload
|
@@ -87,7 +85,7 @@ async def node_query(
|
|
87
85
|
method: Method,
|
88
86
|
pb_query: SearchRequest,
|
89
87
|
timeout: Optional[float] = None,
|
90
|
-
) -> tuple[list[SearchResponse], bool, list[
|
88
|
+
) -> tuple[list[SearchResponse], bool, list[str]]: ...
|
91
89
|
|
92
90
|
|
93
91
|
@overload
|
@@ -96,7 +94,7 @@ async def node_query(
|
|
96
94
|
method: Method,
|
97
95
|
pb_query: GraphSearchRequest,
|
98
96
|
timeout: Optional[float] = None,
|
99
|
-
) -> tuple[list[GraphSearchResponse], bool, list[
|
97
|
+
) -> tuple[list[GraphSearchResponse], bool, list[str]]: ...
|
100
98
|
|
101
99
|
|
102
100
|
async def node_query(
|
@@ -104,7 +102,7 @@ async def node_query(
|
|
104
102
|
method: Method,
|
105
103
|
pb_query: REQUEST_TYPE,
|
106
104
|
timeout: Optional[float] = None,
|
107
|
-
) -> tuple[Sequence[Union[T, BaseException]], bool, list[
|
105
|
+
) -> tuple[Sequence[Union[T, BaseException]], bool, list[str]]:
|
108
106
|
timeout = timeout or settings.search_timeout
|
109
107
|
shard_manager = get_shard_manager()
|
110
108
|
try:
|
@@ -116,21 +114,17 @@ async def node_query(
|
|
116
114
|
)
|
117
115
|
|
118
116
|
ops = []
|
119
|
-
|
117
|
+
queried_shards = []
|
120
118
|
incomplete_results = False
|
121
119
|
|
122
120
|
for shard_obj in shard_groups:
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
# let's add it ot the query list if has a valid value
|
131
|
-
func = METHODS[method]
|
132
|
-
ops.append(func(node, shard_id, pb_query)) # type: ignore
|
133
|
-
queried_nodes.append((node, shard_id))
|
121
|
+
shard_id = shard_obj.nidx_shard_id
|
122
|
+
if shard_id is not None:
|
123
|
+
# At least one node is alive for this shard group
|
124
|
+
# let's add it ot the query list if has a valid value
|
125
|
+
func = METHODS[method]
|
126
|
+
ops.append(func(shard_id, pb_query)) # type: ignore
|
127
|
+
queried_shards.append(shard_id)
|
134
128
|
|
135
129
|
if not ops:
|
136
130
|
logger.warning(f"No node found for any of this resources shards {kbid}")
|
@@ -146,8 +140,7 @@ async def node_query(
|
|
146
140
|
)
|
147
141
|
except asyncio.TimeoutError as exc: # pragma: no cover
|
148
142
|
logger.warning(
|
149
|
-
"Timeout while querying
|
150
|
-
extra={"nodes": debug_nodes_info(queried_nodes)},
|
143
|
+
"Timeout while querying nidx",
|
151
144
|
)
|
152
145
|
results = [exc]
|
153
146
|
|
@@ -164,7 +157,7 @@ async def node_query(
|
|
164
157
|
)
|
165
158
|
raise error
|
166
159
|
|
167
|
-
return results, incomplete_results,
|
160
|
+
return results, incomplete_results, queried_shards
|
168
161
|
|
169
162
|
|
170
163
|
def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
|
@@ -201,15 +194,3 @@ def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
|
|
201
194
|
return HTTPException(status_code=status_code, detail=reason)
|
202
195
|
|
203
196
|
return None
|
204
|
-
|
205
|
-
|
206
|
-
def debug_nodes_info(nodes: list[tuple[AbstractIndexNode, str]]) -> list[dict[str, str]]:
|
207
|
-
details: list[dict[str, str]] = []
|
208
|
-
for node, shard_id in nodes:
|
209
|
-
info = {
|
210
|
-
"id": node.id,
|
211
|
-
"shard_id": shard_id,
|
212
|
-
"address": "nidx",
|
213
|
-
}
|
214
|
-
details.append(info)
|
215
|
-
return details
|
nucliadb/search/search/find.py
CHANGED
@@ -23,7 +23,7 @@ from time import time
|
|
23
23
|
from nucliadb.common.external_index_providers.base import ExternalIndexManager
|
24
24
|
from nucliadb.common.external_index_providers.manager import get_external_index_manager
|
25
25
|
from nucliadb.common.models_utils import to_proto
|
26
|
-
from nucliadb.search.requesters.utils import Method,
|
26
|
+
from nucliadb.search.requesters.utils import Method, node_query
|
27
27
|
from nucliadb.search.search.find_merge import (
|
28
28
|
build_find_response,
|
29
29
|
compose_find_resources,
|
@@ -105,7 +105,7 @@ async def _index_node_retrieval(
|
|
105
105
|
) = await legacy_convert_retrieval_to_proto(parsed)
|
106
106
|
|
107
107
|
with metrics.time("node_query"):
|
108
|
-
results, query_incomplete_results,
|
108
|
+
results, query_incomplete_results, queried_shards = await node_query(
|
109
109
|
kbid, Method.SEARCH, pb_query
|
110
110
|
)
|
111
111
|
incomplete_results = incomplete_results or query_incomplete_results
|
@@ -139,10 +139,6 @@ async def _index_node_retrieval(
|
|
139
139
|
retrieval_rephrased_question=rephrased_query,
|
140
140
|
)
|
141
141
|
|
142
|
-
if item.debug:
|
143
|
-
search_results.nodes = debug_nodes_info(queried_nodes)
|
144
|
-
|
145
|
-
queried_shards = [shard_id for _, shard_id in queried_nodes]
|
146
142
|
search_results.shards = queried_shards
|
147
143
|
search_results.autofilters = autofilters
|
148
144
|
|
@@ -156,7 +152,6 @@ async def _index_node_retrieval(
|
|
156
152
|
"client": x_ndb_client,
|
157
153
|
"query": item.model_dump_json(),
|
158
154
|
"time": search_time,
|
159
|
-
"nodes": debug_nodes_info(queried_nodes),
|
160
155
|
"durations": metrics.steps(),
|
161
156
|
},
|
162
157
|
)
|
@@ -169,7 +164,6 @@ async def _index_node_retrieval(
|
|
169
164
|
"client": x_ndb_client,
|
170
165
|
"query": item.model_dump_json(),
|
171
166
|
"time": search_time,
|
172
|
-
"nodes": debug_nodes_info(queried_nodes),
|
173
167
|
"durations": metrics.steps(),
|
174
168
|
},
|
175
169
|
)
|
nucliadb/search/search/shards.py
CHANGED
@@ -17,7 +17,6 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
-
import asyncio
|
21
20
|
|
22
21
|
import backoff
|
23
22
|
from grpc import StatusCode
|
@@ -33,16 +32,7 @@ from nidx_protos.nodereader_pb2 import (
|
|
33
32
|
)
|
34
33
|
from nidx_protos.noderesources_pb2 import Shard
|
35
34
|
|
36
|
-
from nucliadb.common.
|
37
|
-
from nucliadb_telemetry import metrics
|
38
|
-
|
39
|
-
node_observer = metrics.Observer(
|
40
|
-
"node_client",
|
41
|
-
labels={"type": "", "node_id": ""},
|
42
|
-
error_mappings={
|
43
|
-
"timeout": asyncio.CancelledError,
|
44
|
-
},
|
45
|
-
)
|
35
|
+
from nucliadb.common.nidx import get_nidx_api_client, get_nidx_searcher_client
|
46
36
|
|
47
37
|
|
48
38
|
def should_giveup(e: Exception):
|
@@ -54,43 +44,37 @@ def should_giveup(e: Exception):
|
|
54
44
|
@backoff.on_exception(
|
55
45
|
backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
|
56
46
|
)
|
57
|
-
async def query_shard(
|
47
|
+
async def query_shard(shard: str, query: SearchRequest) -> SearchResponse:
|
58
48
|
req = SearchRequest()
|
59
49
|
req.CopyFrom(query)
|
60
50
|
req.shard = shard
|
61
|
-
|
62
|
-
return await node.reader.Search(req) # type: ignore
|
51
|
+
return await get_nidx_searcher_client().Search(req)
|
63
52
|
|
64
53
|
|
65
54
|
@backoff.on_exception(
|
66
55
|
backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
|
67
56
|
)
|
68
|
-
async def get_shard(
|
57
|
+
async def get_shard(shard_id: str) -> Shard:
|
69
58
|
req = GetShardRequest()
|
70
59
|
req.shard_id.id = shard_id
|
71
|
-
|
72
|
-
return await node.reader.GetShard(req) # type: ignore
|
60
|
+
return await get_nidx_api_client().GetShard(req)
|
73
61
|
|
74
62
|
|
75
63
|
@backoff.on_exception(
|
76
64
|
backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
|
77
65
|
)
|
78
|
-
async def suggest_shard(
|
66
|
+
async def suggest_shard(shard: str, query: SuggestRequest) -> SuggestResponse:
|
79
67
|
req = SuggestRequest()
|
80
68
|
req.CopyFrom(query)
|
81
69
|
req.shard = shard
|
82
|
-
|
83
|
-
return await node.reader.Suggest(req) # type: ignore
|
70
|
+
return await get_nidx_searcher_client().Suggest(req)
|
84
71
|
|
85
72
|
|
86
73
|
@backoff.on_exception(
|
87
74
|
backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
|
88
75
|
)
|
89
|
-
async def graph_search_shard(
|
90
|
-
node: AbstractIndexNode, shard: str, query: GraphSearchRequest
|
91
|
-
) -> GraphSearchResponse:
|
76
|
+
async def graph_search_shard(shard: str, query: GraphSearchRequest) -> GraphSearchResponse:
|
92
77
|
req = GraphSearchRequest()
|
93
78
|
req.CopyFrom(query)
|
94
79
|
req.shard = shard
|
95
|
-
|
96
|
-
return await node.reader.GraphSearch(req) # type: ignore
|
80
|
+
return await get_nidx_searcher_client().GraphSearch(req)
|
nucliadb/train/generator.py
CHANGED
@@ -54,7 +54,7 @@ from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
|
|
54
54
|
async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
|
55
55
|
# Get the data structure to generate data
|
56
56
|
shard_manager = get_shard_manager()
|
57
|
-
|
57
|
+
shard_replica_id = await shard_manager.get_shard_id(kbid, shard)
|
58
58
|
|
59
59
|
if trainset.batch_size == 0:
|
60
60
|
trainset.batch_size = 50
|
@@ -62,24 +62,22 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
|
|
62
62
|
batch_generator: Optional[AsyncIterator[TrainBatch]] = None
|
63
63
|
|
64
64
|
if trainset.type == TaskType.FIELD_CLASSIFICATION:
|
65
|
-
batch_generator = field_classification_batch_generator(kbid, trainset,
|
65
|
+
batch_generator = field_classification_batch_generator(kbid, trainset, shard_replica_id)
|
66
66
|
elif trainset.type == TaskType.IMAGE_CLASSIFICATION:
|
67
|
-
batch_generator = image_classification_batch_generator(kbid, trainset,
|
67
|
+
batch_generator = image_classification_batch_generator(kbid, trainset, shard_replica_id)
|
68
68
|
elif trainset.type == TaskType.PARAGRAPH_CLASSIFICATION:
|
69
|
-
batch_generator = paragraph_classification_batch_generator(
|
70
|
-
kbid, trainset, node, shard_replica_id
|
71
|
-
)
|
69
|
+
batch_generator = paragraph_classification_batch_generator(kbid, trainset, shard_replica_id)
|
72
70
|
elif trainset.type == TaskType.TOKEN_CLASSIFICATION:
|
73
|
-
batch_generator = token_classification_batch_generator(kbid, trainset,
|
71
|
+
batch_generator = token_classification_batch_generator(kbid, trainset, shard_replica_id)
|
74
72
|
elif trainset.type == TaskType.SENTENCE_CLASSIFICATION:
|
75
|
-
batch_generator = sentence_classification_batch_generator(kbid, trainset,
|
73
|
+
batch_generator = sentence_classification_batch_generator(kbid, trainset, shard_replica_id)
|
76
74
|
elif trainset.type == TaskType.PARAGRAPH_STREAMING:
|
77
|
-
batch_generator = paragraph_streaming_batch_generator(kbid, trainset,
|
75
|
+
batch_generator = paragraph_streaming_batch_generator(kbid, trainset, shard_replica_id)
|
78
76
|
|
79
77
|
elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
|
80
|
-
batch_generator = question_answer_batch_generator(kbid, trainset,
|
78
|
+
batch_generator = question_answer_batch_generator(kbid, trainset, shard_replica_id)
|
81
79
|
elif trainset.type == TaskType.FIELD_STREAMING:
|
82
|
-
batch_generator = field_streaming_batch_generator(kbid, trainset,
|
80
|
+
batch_generator = field_streaming_batch_generator(kbid, trainset, shard_replica_id)
|
83
81
|
|
84
82
|
if batch_generator is None:
|
85
83
|
raise HTTPException(
|
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
|
|
22
22
|
|
23
23
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
24
24
|
|
25
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
25
|
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
27
27
|
from nucliadb.train import logger
|
28
28
|
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
29
29
|
from nucliadb_protos.dataset_pb2 import (
|
@@ -37,10 +37,9 @@ from nucliadb_protos.dataset_pb2 import (
|
|
37
37
|
def field_classification_batch_generator(
|
38
38
|
kbid: str,
|
39
39
|
trainset: TrainSet,
|
40
|
-
node: AbstractIndexNode,
|
41
40
|
shard_replica_id: str,
|
42
41
|
) -> AsyncGenerator[FieldClassificationBatch, None]:
|
43
|
-
generator = generate_field_classification_payloads(kbid, trainset,
|
42
|
+
generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
|
44
43
|
batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
|
45
44
|
return batch_generator
|
46
45
|
|
@@ -48,7 +47,6 @@ def field_classification_batch_generator(
|
|
48
47
|
async def generate_field_classification_payloads(
|
49
48
|
kbid: str,
|
50
49
|
trainset: TrainSet,
|
51
|
-
node: AbstractIndexNode,
|
52
50
|
shard_replica_id: str,
|
53
51
|
) -> AsyncGenerator[TextLabel, None]:
|
54
52
|
labelset = f"/l/{trainset.filter.labels[0]}"
|
@@ -59,7 +57,7 @@ async def generate_field_classification_payloads(
|
|
59
57
|
request.filter.labels.append(labelset)
|
60
58
|
total = 0
|
61
59
|
|
62
|
-
async for document_item in
|
60
|
+
async for document_item in get_nidx_searcher_client().Documents(request):
|
63
61
|
text_labels = []
|
64
62
|
for label in document_item.labels:
|
65
63
|
if label.startswith(labelset):
|
@@ -22,8 +22,8 @@ from typing import AsyncGenerator, Optional
|
|
22
22
|
|
23
23
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
24
24
|
|
25
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
25
|
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
27
27
|
from nucliadb.train import logger
|
28
28
|
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
29
29
|
from nucliadb_protos.dataset_pb2 import (
|
@@ -38,10 +38,9 @@ from nucliadb_protos.utils_pb2 import ExtractedText
|
|
38
38
|
def field_streaming_batch_generator(
|
39
39
|
kbid: str,
|
40
40
|
trainset: TrainSet,
|
41
|
-
node: AbstractIndexNode,
|
42
41
|
shard_replica_id: str,
|
43
42
|
) -> AsyncGenerator[FieldStreamingBatch, None]:
|
44
|
-
generator = generate_field_streaming_payloads(kbid, trainset,
|
43
|
+
generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id)
|
45
44
|
batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
|
46
45
|
return batch_generator
|
47
46
|
|
@@ -49,7 +48,6 @@ def field_streaming_batch_generator(
|
|
49
48
|
async def generate_field_streaming_payloads(
|
50
49
|
kbid: str,
|
51
50
|
trainset: TrainSet,
|
52
|
-
node: AbstractIndexNode,
|
53
51
|
shard_replica_id: str,
|
54
52
|
) -> AsyncGenerator[FieldSplitData, None]:
|
55
53
|
# Query how many resources has each label
|
@@ -77,7 +75,7 @@ async def generate_field_streaming_payloads(
|
|
77
75
|
total = 0
|
78
76
|
resources = set()
|
79
77
|
|
80
|
-
async for document_item in
|
78
|
+
async for document_item in get_nidx_searcher_client().Documents(request):
|
81
79
|
text_labels = []
|
82
80
|
for label in document_item.labels:
|
83
81
|
text_labels.append(label)
|
@@ -20,7 +20,6 @@
|
|
20
20
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
23
|
from nucliadb.train.generators.utils import batchify
|
25
24
|
from nucliadb_protos.dataset_pb2 import (
|
26
25
|
ImageClassification,
|
@@ -32,10 +31,9 @@ from nucliadb_protos.dataset_pb2 import (
|
|
32
31
|
def image_classification_batch_generator(
|
33
32
|
kbid: str,
|
34
33
|
trainset: TrainSet,
|
35
|
-
node: AbstractIndexNode,
|
36
34
|
shard_replica_id: str,
|
37
35
|
) -> AsyncGenerator[ImageClassificationBatch, None]:
|
38
|
-
generator = generate_image_classification_payloads(kbid, trainset,
|
36
|
+
generator = generate_image_classification_payloads(kbid, trainset, shard_replica_id)
|
39
37
|
batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
|
40
38
|
return batch_generator
|
41
39
|
|
@@ -43,7 +41,6 @@ def image_classification_batch_generator(
|
|
43
41
|
async def generate_image_classification_payloads(
|
44
42
|
kbid: str,
|
45
43
|
trainset: TrainSet,
|
46
|
-
node: AbstractIndexNode,
|
47
44
|
shard_replica_id: str,
|
48
45
|
) -> AsyncGenerator[ImageClassification, None]:
|
49
46
|
# NOTE: image classifications are no longer supported, as the page selection annotations were removed
|
@@ -23,7 +23,7 @@ from typing import AsyncGenerator
|
|
23
23
|
from fastapi import HTTPException
|
24
24
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
25
25
|
|
26
|
-
from nucliadb.common.
|
26
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
27
27
|
from nucliadb.train.generators.utils import batchify, get_paragraph
|
28
28
|
from nucliadb_protos.dataset_pb2 import (
|
29
29
|
Label,
|
@@ -36,7 +36,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
36
36
|
def paragraph_classification_batch_generator(
|
37
37
|
kbid: str,
|
38
38
|
trainset: TrainSet,
|
39
|
-
node: AbstractIndexNode,
|
40
39
|
shard_replica_id: str,
|
41
40
|
) -> AsyncGenerator[ParagraphClassificationBatch, None]:
|
42
41
|
if len(trainset.filter.labels) != 1:
|
@@ -45,7 +44,7 @@ def paragraph_classification_batch_generator(
|
|
45
44
|
detail="Paragraph Classification should be of 1 labelset",
|
46
45
|
)
|
47
46
|
|
48
|
-
generator = generate_paragraph_classification_payloads(kbid, trainset,
|
47
|
+
generator = generate_paragraph_classification_payloads(kbid, trainset, shard_replica_id)
|
49
48
|
batch_generator = batchify(generator, trainset.batch_size, ParagraphClassificationBatch)
|
50
49
|
return batch_generator
|
51
50
|
|
@@ -53,7 +52,6 @@ def paragraph_classification_batch_generator(
|
|
53
52
|
async def generate_paragraph_classification_payloads(
|
54
53
|
kbid: str,
|
55
54
|
trainset: TrainSet,
|
56
|
-
node: AbstractIndexNode,
|
57
55
|
shard_replica_id: str,
|
58
56
|
) -> AsyncGenerator[TextLabel, None]:
|
59
57
|
labelset = f"/l/{trainset.filter.labels[0]}"
|
@@ -63,7 +61,7 @@ async def generate_paragraph_classification_payloads(
|
|
63
61
|
request.shard_id.id = shard_replica_id
|
64
62
|
request.filter.labels.append(labelset)
|
65
63
|
|
66
|
-
async for paragraph_item in
|
64
|
+
async for paragraph_item in get_nidx_searcher_client().Paragraphs(request):
|
67
65
|
text_labels = []
|
68
66
|
for label in paragraph_item.labels:
|
69
67
|
if label.startswith(labelset):
|
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
|
|
22
22
|
|
23
23
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
24
24
|
|
25
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
25
|
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
27
27
|
from nucliadb.train import logger
|
28
28
|
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
29
29
|
from nucliadb_protos.dataset_pb2 import (
|
@@ -36,10 +36,9 @@ from nucliadb_protos.dataset_pb2 import (
|
|
36
36
|
def paragraph_streaming_batch_generator(
|
37
37
|
kbid: str,
|
38
38
|
trainset: TrainSet,
|
39
|
-
node: AbstractIndexNode,
|
40
39
|
shard_replica_id: str,
|
41
40
|
) -> AsyncGenerator[ParagraphStreamingBatch, None]:
|
42
|
-
generator = generate_paragraph_streaming_payloads(kbid, trainset,
|
41
|
+
generator = generate_paragraph_streaming_payloads(kbid, trainset, shard_replica_id)
|
43
42
|
batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
|
44
43
|
return batch_generator
|
45
44
|
|
@@ -47,7 +46,6 @@ def paragraph_streaming_batch_generator(
|
|
47
46
|
async def generate_paragraph_streaming_payloads(
|
48
47
|
kbid: str,
|
49
48
|
trainset: TrainSet,
|
50
|
-
node: AbstractIndexNode,
|
51
49
|
shard_replica_id: str,
|
52
50
|
) -> AsyncGenerator[ParagraphStreamItem, None]:
|
53
51
|
"""Streams paragraphs ordered as if they were read sequentially from each
|
@@ -57,7 +55,7 @@ async def generate_paragraph_streaming_payloads(
|
|
57
55
|
request = StreamRequest()
|
58
56
|
request.shard_id.id = shard_replica_id
|
59
57
|
|
60
|
-
async for document_item in
|
58
|
+
async for document_item in get_nidx_searcher_client().Documents(request):
|
61
59
|
field_id = f"{document_item.uuid}{document_item.field}"
|
62
60
|
rid, field_type, field = field_id.split("/")
|
63
61
|
|
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
|
|
22
22
|
|
23
23
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
24
24
|
|
25
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
25
|
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
27
27
|
from nucliadb.train import logger
|
28
28
|
from nucliadb.train.generators.utils import (
|
29
29
|
batchify,
|
@@ -45,10 +45,9 @@ from nucliadb_protos.resources_pb2 import (
|
|
45
45
|
def question_answer_batch_generator(
|
46
46
|
kbid: str,
|
47
47
|
trainset: TrainSet,
|
48
|
-
node: AbstractIndexNode,
|
49
48
|
shard_replica_id: str,
|
50
49
|
) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
|
51
|
-
generator = generate_question_answer_streaming_payloads(kbid, trainset,
|
50
|
+
generator = generate_question_answer_streaming_payloads(kbid, trainset, shard_replica_id)
|
52
51
|
batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
|
53
52
|
return batch_generator
|
54
53
|
|
@@ -56,13 +55,12 @@ def question_answer_batch_generator(
|
|
56
55
|
async def generate_question_answer_streaming_payloads(
|
57
56
|
kbid: str,
|
58
57
|
trainset: TrainSet,
|
59
|
-
node: AbstractIndexNode,
|
60
58
|
shard_replica_id: str,
|
61
59
|
):
|
62
60
|
request = StreamRequest()
|
63
61
|
request.shard_id.id = shard_replica_id
|
64
62
|
|
65
|
-
async for document_item in
|
63
|
+
async for document_item in get_nidx_searcher_client().Documents(request):
|
66
64
|
field_id = f"{document_item.uuid}{document_item.field}"
|
67
65
|
rid, field_type, field = field_id.split("/")
|
68
66
|
|
@@ -23,8 +23,8 @@ from typing import AsyncGenerator
|
|
23
23
|
from fastapi import HTTPException
|
24
24
|
from nidx_protos.nodereader_pb2 import StreamRequest
|
25
25
|
|
26
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
27
26
|
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
27
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
28
28
|
from nucliadb.train import logger
|
29
29
|
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
30
30
|
from nucliadb_protos.dataset_pb2 import (
|
@@ -38,7 +38,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
38
38
|
def sentence_classification_batch_generator(
|
39
39
|
kbid: str,
|
40
40
|
trainset: TrainSet,
|
41
|
-
node: AbstractIndexNode,
|
42
41
|
shard_replica_id: str,
|
43
42
|
) -> AsyncGenerator[SentenceClassificationBatch, None]:
|
44
43
|
if len(trainset.filter.labels) == 0:
|
@@ -47,7 +46,7 @@ def sentence_classification_batch_generator(
|
|
47
46
|
detail="Sentence Classification should be at least of 1 labelset",
|
48
47
|
)
|
49
48
|
|
50
|
-
generator = generate_sentence_classification_payloads(kbid, trainset,
|
49
|
+
generator = generate_sentence_classification_payloads(kbid, trainset, shard_replica_id)
|
51
50
|
batch_generator = batchify(generator, trainset.batch_size, SentenceClassificationBatch)
|
52
51
|
return batch_generator
|
53
52
|
|
@@ -55,7 +54,6 @@ def sentence_classification_batch_generator(
|
|
55
54
|
async def generate_sentence_classification_payloads(
|
56
55
|
kbid: str,
|
57
56
|
trainset: TrainSet,
|
58
|
-
node: AbstractIndexNode,
|
59
57
|
shard_replica_id: str,
|
60
58
|
) -> AsyncGenerator[MultipleTextSameLabels, None]:
|
61
59
|
labelsets = []
|
@@ -67,7 +65,7 @@ async def generate_sentence_classification_payloads(
|
|
67
65
|
labelsets.append(labelset)
|
68
66
|
request.filter.labels.append(labelset)
|
69
67
|
|
70
|
-
async for paragraph_item in
|
68
|
+
async for paragraph_item in get_nidx_searcher_client().Paragraphs(request):
|
71
69
|
text_labels: list[str] = []
|
72
70
|
for label in paragraph_item.labels:
|
73
71
|
for labelset in labelsets:
|
@@ -23,8 +23,8 @@ from typing import AsyncGenerator, cast
|
|
23
23
|
|
24
24
|
from nidx_protos.nodereader_pb2 import StreamFilter, StreamRequest
|
25
25
|
|
26
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
27
26
|
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
27
|
+
from nucliadb.common.nidx import get_nidx_searcher_client
|
28
28
|
from nucliadb.train import logger
|
29
29
|
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
30
30
|
from nucliadb_protos.dataset_pb2 import (
|
@@ -41,10 +41,9 @@ MAIN = "__main__"
|
|
41
41
|
def token_classification_batch_generator(
|
42
42
|
kbid: str,
|
43
43
|
trainset: TrainSet,
|
44
|
-
node: AbstractIndexNode,
|
45
44
|
shard_replica_id: str,
|
46
45
|
) -> AsyncGenerator[TokenClassificationBatch, None]:
|
47
|
-
generator = generate_token_classification_payloads(kbid, trainset,
|
46
|
+
generator = generate_token_classification_payloads(kbid, trainset, shard_replica_id)
|
48
47
|
batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)
|
49
48
|
return batch_generator
|
50
49
|
|
@@ -52,7 +51,6 @@ def token_classification_batch_generator(
|
|
52
51
|
async def generate_token_classification_payloads(
|
53
52
|
kbid: str,
|
54
53
|
trainset: TrainSet,
|
55
|
-
node: AbstractIndexNode,
|
56
54
|
shard_replica_id: str,
|
57
55
|
) -> AsyncGenerator[TokensClassification, None]:
|
58
56
|
request = StreamRequest()
|
@@ -60,7 +58,7 @@ async def generate_token_classification_payloads(
|
|
60
58
|
for entitygroup in trainset.filter.labels:
|
61
59
|
request.filter.labels.append(f"/e/{entitygroup}")
|
62
60
|
request.filter.conjunction = StreamFilter.Conjunction.OR
|
63
|
-
async for field_item in
|
61
|
+
async for field_item in get_nidx_searcher_client().Documents(request):
|
64
62
|
_, field_type, field = field_item.field.split("/")
|
65
63
|
(
|
66
64
|
split_text,
|
nucliadb/train/nodes.py
CHANGED
@@ -21,7 +21,6 @@ from typing import AsyncIterator, Optional
|
|
21
21
|
|
22
22
|
from nucliadb.common import datamanagers
|
23
23
|
from nucliadb.common.cluster import manager
|
24
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
25
24
|
|
26
25
|
# XXX: this keys shouldn't be exposed outside datamanagers
|
27
26
|
from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
|
@@ -54,15 +53,14 @@ class TrainShardManager(manager.KBShardManager):
|
|
54
53
|
self.driver = driver
|
55
54
|
self.storage = storage
|
56
55
|
|
57
|
-
async def
|
56
|
+
async def get_shard_id(self, kbid: str, shard: str) -> str:
|
58
57
|
shards = await self.get_shards_by_kbid_inner(kbid)
|
59
58
|
try:
|
60
59
|
shard_object: ShardObject = next(filter(lambda x: x.shard == shard, shards.shards))
|
61
60
|
except StopIteration:
|
62
61
|
raise KeyError("Shard not found")
|
63
62
|
|
64
|
-
|
65
|
-
return node_obj, shard_id
|
63
|
+
return shard_object.nidx_shard_id
|
66
64
|
|
67
65
|
async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
|
68
66
|
if kbid is None:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nucliadb
|
3
|
-
Version: 6.4.0.
|
3
|
+
Version: 6.4.0.post4132
|
4
4
|
Summary: NucliaDB
|
5
5
|
Author-email: Nuclia <nucliadb@nuclia.com>
|
6
6
|
License: AGPL
|
@@ -20,11 +20,11 @@ Classifier: Programming Language :: Python :: 3.12
|
|
20
20
|
Classifier: Programming Language :: Python :: 3 :: Only
|
21
21
|
Requires-Python: <4,>=3.9
|
22
22
|
Description-Content-Type: text/markdown
|
23
|
-
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.
|
24
|
-
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.
|
25
|
-
Requires-Dist: nucliadb-protos>=6.4.0.
|
26
|
-
Requires-Dist: nucliadb-models>=6.4.0.
|
27
|
-
Requires-Dist: nidx-protos>=6.4.0.
|
23
|
+
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4132
|
24
|
+
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4132
|
25
|
+
Requires-Dist: nucliadb-protos>=6.4.0.post4132
|
26
|
+
Requires-Dist: nucliadb-models>=6.4.0.post4132
|
27
|
+
Requires-Dist: nidx-protos>=6.4.0.post4132
|
28
28
|
Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
|
29
29
|
Requires-Dist: nuclia-models>=0.24.2
|
30
30
|
Requires-Dist: uvicorn[standard]
|