nucliadb-6.4.0.post4127-py3-none-any.whl → nucliadb-6.4.0.post4132-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nucliadb/common/cluster/grpc_node_dummy.py +1 -18
  2. nucliadb/common/cluster/manager.py +26 -21
  3. nucliadb/common/cluster/rebalance.py +7 -7
  4. nucliadb/common/cluster/rollover.py +12 -5
  5. nucliadb/common/nidx.py +0 -44
  6. nucliadb/ingest/consumer/auditing.py +5 -5
  7. nucliadb/ingest/consumer/shard_creator.py +5 -4
  8. nucliadb/ingest/orm/entities.py +4 -5
  9. nucliadb/metrics_exporter.py +0 -19
  10. nucliadb/purge/orphan_shards.py +17 -14
  11. nucliadb/search/api/v1/knowledgebox.py +6 -14
  12. nucliadb/search/api/v1/resource/search.py +2 -5
  13. nucliadb/search/api/v1/search.py +2 -6
  14. nucliadb/search/api/v1/suggest.py +1 -2
  15. nucliadb/search/requesters/utils.py +14 -33
  16. nucliadb/search/search/find.py +2 -8
  17. nucliadb/search/search/shards.py +9 -25
  18. nucliadb/train/generator.py +9 -11
  19. nucliadb/train/generators/field_classifier.py +3 -5
  20. nucliadb/train/generators/field_streaming.py +3 -5
  21. nucliadb/train/generators/image_classifier.py +1 -4
  22. nucliadb/train/generators/paragraph_classifier.py +3 -5
  23. nucliadb/train/generators/paragraph_streaming.py +3 -5
  24. nucliadb/train/generators/question_answer_streaming.py +3 -5
  25. nucliadb/train/generators/sentence_classifier.py +3 -5
  26. nucliadb/train/generators/token_classifier.py +3 -5
  27. nucliadb/train/nodes.py +2 -4
  28. {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/METADATA +6 -6
  29. {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/RECORD +32 -33
  30. nucliadb/common/cluster/base.py +0 -146
  31. {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/WHEEL +0 -0
  32. {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/entry_points.txt +0 -0
  33. {nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/top_level.txt +0 -0
nucliadb/search/requesters/utils.py CHANGED
@@ -35,8 +35,6 @@ from nidx_protos.nodereader_pb2 import (
     SuggestResponse,
 )
 
-from nucliadb.common.cluster import manager as cluster_manager
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.cluster.exceptions import ShardsNotFound
 from nucliadb.common.cluster.utils import get_shard_manager
 from nucliadb.search import logger
@@ -78,7 +76,7 @@ async def node_query(
     method: Method,
     pb_query: SuggestRequest,
     timeout: Optional[float] = None,
-) -> tuple[list[SuggestResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...
+) -> tuple[list[SuggestResponse], bool, list[str]]: ...
 
 
 @overload
@@ -87,7 +85,7 @@ async def node_query(
     method: Method,
     pb_query: SearchRequest,
     timeout: Optional[float] = None,
-) -> tuple[list[SearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...
+) -> tuple[list[SearchResponse], bool, list[str]]: ...
 
 
 @overload
@@ -96,7 +94,7 @@ async def node_query(
     method: Method,
     pb_query: GraphSearchRequest,
     timeout: Optional[float] = None,
-) -> tuple[list[GraphSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...
+) -> tuple[list[GraphSearchResponse], bool, list[str]]: ...
 
 
 async def node_query(
@@ -104,7 +102,7 @@ async def node_query(
     method: Method,
     pb_query: REQUEST_TYPE,
     timeout: Optional[float] = None,
-) -> tuple[Sequence[Union[T, BaseException]], bool, list[tuple[AbstractIndexNode, str]]]:
+) -> tuple[Sequence[Union[T, BaseException]], bool, list[str]]:
     timeout = timeout or settings.search_timeout
     shard_manager = get_shard_manager()
     try:
@@ -116,21 +114,17 @@ async def node_query(
     )
 
     ops = []
-    queried_nodes = []
+    queried_shards = []
     incomplete_results = False
 
     for shard_obj in shard_groups:
-        try:
-            node, shard_id = cluster_manager.choose_node(shard_obj)
-        except KeyError:
-            incomplete_results = True
-        else:
-            if shard_id is not None:
-                # At least one node is alive for this shard group
-                # let's add it ot the query list if has a valid value
-                func = METHODS[method]
-                ops.append(func(node, shard_id, pb_query))  # type: ignore
-                queried_nodes.append((node, shard_id))
+        shard_id = shard_obj.nidx_shard_id
+        if shard_id is not None:
+            # At least one node is alive for this shard group
+            # let's add it ot the query list if has a valid value
+            func = METHODS[method]
+            ops.append(func(shard_id, pb_query))  # type: ignore
+            queried_shards.append(shard_id)
 
     if not ops:
         logger.warning(f"No node found for any of this resources shards {kbid}")
@@ -146,8 +140,7 @@ async def node_query(
         )
     except asyncio.TimeoutError as exc:  # pragma: no cover
         logger.warning(
-            "Timeout while querying nodes",
-            extra={"nodes": debug_nodes_info(queried_nodes)},
+            "Timeout while querying nidx",
         )
         results = [exc]
 
@@ -164,7 +157,7 @@ async def node_query(
         )
         raise error
 
-    return results, incomplete_results, queried_nodes
+    return results, incomplete_results, queried_shards
 
 
 def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
@@ -201,15 +194,3 @@ def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
         return HTTPException(status_code=status_code, detail=reason)
 
     return None
-
-
-def debug_nodes_info(nodes: list[tuple[AbstractIndexNode, str]]) -> list[dict[str, str]]:
-    details: list[dict[str, str]] = []
-    for node, shard_id in nodes:
-        info = {
-            "id": node.id,
-            "shard_id": shard_id,
-            "address": "nidx",
-        }
-        details.append(info)
-    return details
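
Note: after this change, node_query returns plain nidx shard ids (list[str]) rather than (AbstractIndexNode, shard_id) tuples, and debug_nodes_info is gone. A minimal caller sketch under the new signature; the wrapper function is hypothetical, while Method, node_query and SearchRequest are the names from the diff above:

    from nidx_protos.nodereader_pb2 import SearchRequest

    from nucliadb.search.requesters.utils import Method, node_query

    async def search_kb(kbid: str, pb_query: SearchRequest):
        # queried_shards is now a list of nidx shard id strings
        results, incomplete, queried_shards = await node_query(kbid, Method.SEARCH, pb_query)
        return results, incomplete, queried_shards
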
nucliadb/search/search/find.py CHANGED
@@ -23,7 +23,7 @@ from time import time
 from nucliadb.common.external_index_providers.base import ExternalIndexManager
 from nucliadb.common.external_index_providers.manager import get_external_index_manager
 from nucliadb.common.models_utils import to_proto
-from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
+from nucliadb.search.requesters.utils import Method, node_query
 from nucliadb.search.search.find_merge import (
     build_find_response,
     compose_find_resources,
@@ -105,7 +105,7 @@ async def _index_node_retrieval(
     ) = await legacy_convert_retrieval_to_proto(parsed)
 
     with metrics.time("node_query"):
-        results, query_incomplete_results, queried_nodes = await node_query(
+        results, query_incomplete_results, queried_shards = await node_query(
             kbid, Method.SEARCH, pb_query
         )
     incomplete_results = incomplete_results or query_incomplete_results
@@ -139,10 +139,6 @@ async def _index_node_retrieval(
         retrieval_rephrased_question=rephrased_query,
     )
 
-    if item.debug:
-        search_results.nodes = debug_nodes_info(queried_nodes)
-
-    queried_shards = [shard_id for _, shard_id in queried_nodes]
     search_results.shards = queried_shards
     search_results.autofilters = autofilters
 
@@ -156,7 +152,6 @@ async def _index_node_retrieval(
             "client": x_ndb_client,
             "query": item.model_dump_json(),
             "time": search_time,
-            "nodes": debug_nodes_info(queried_nodes),
             "durations": metrics.steps(),
         },
     )
@@ -169,7 +164,6 @@ async def _index_node_retrieval(
             "client": x_ndb_client,
             "query": item.model_dump_json(),
             "time": search_time,
-            "nodes": debug_nodes_info(queried_nodes),
             "durations": metrics.steps(),
         },
     )
nucliadb/search/search/shards.py CHANGED
@@ -17,7 +17,6 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import asyncio
 
 import backoff
 from grpc import StatusCode
@@ -33,16 +32,7 @@ from nidx_protos.nodereader_pb2 import (
 )
 from nidx_protos.noderesources_pb2 import Shard
 
-from nucliadb.common.cluster.base import AbstractIndexNode
-from nucliadb_telemetry import metrics
-
-node_observer = metrics.Observer(
-    "node_client",
-    labels={"type": "", "node_id": ""},
-    error_mappings={
-        "timeout": asyncio.CancelledError,
-    },
-)
+from nucliadb.common.nidx import get_nidx_api_client, get_nidx_searcher_client
 
 
 def should_giveup(e: Exception):
@@ -54,43 +44,37 @@ def should_giveup(e: Exception):
 @backoff.on_exception(
     backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
 )
-async def query_shard(node: AbstractIndexNode, shard: str, query: SearchRequest) -> SearchResponse:
+async def query_shard(shard: str, query: SearchRequest) -> SearchResponse:
     req = SearchRequest()
     req.CopyFrom(query)
     req.shard = shard
-    with node_observer({"type": "search", "node_id": node.id}):
-        return await node.reader.Search(req)  # type: ignore
+    return await get_nidx_searcher_client().Search(req)
 
 
 @backoff.on_exception(
     backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
 )
-async def get_shard(node: AbstractIndexNode, shard_id: str) -> Shard:
+async def get_shard(shard_id: str) -> Shard:
     req = GetShardRequest()
     req.shard_id.id = shard_id
-    with node_observer({"type": "get_shard", "node_id": node.id}):
-        return await node.reader.GetShard(req)  # type: ignore
+    return await get_nidx_api_client().GetShard(req)
 
 
 @backoff.on_exception(
     backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
 )
-async def suggest_shard(node: AbstractIndexNode, shard: str, query: SuggestRequest) -> SuggestResponse:
+async def suggest_shard(shard: str, query: SuggestRequest) -> SuggestResponse:
     req = SuggestRequest()
     req.CopyFrom(query)
     req.shard = shard
-    with node_observer({"type": "suggest", "node_id": node.id}):
-        return await node.reader.Suggest(req)  # type: ignore
+    return await get_nidx_searcher_client().Suggest(req)
 
 
 @backoff.on_exception(
     backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
 )
-async def graph_search_shard(
-    node: AbstractIndexNode, shard: str, query: GraphSearchRequest
-) -> GraphSearchResponse:
+async def graph_search_shard(shard: str, query: GraphSearchRequest) -> GraphSearchResponse:
     req = GraphSearchRequest()
     req.CopyFrom(query)
     req.shard = shard
-    with node_observer({"type": "graph_search", "node_id": node.id}):
-        return await node.reader.GraphSearch(req)  # type: ignore
+    return await get_nidx_searcher_client().GraphSearch(req)
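
Note: the per-shard helpers drop the node argument and the node_observer metrics wrapper; they call the shared nidx clients directly (Search/Suggest/GraphSearch on the searcher client, GetShard on the api client). A usage sketch, assuming an existing nidx shard id; the wrapper function is illustrative, query_shard and SearchRequest are from the diff above:

    from nidx_protos.nodereader_pb2 import SearchRequest

    from nucliadb.search.search.shards import query_shard

    async def run_query(nidx_shard_id: str) -> None:
        request = SearchRequest()  # populate query fields as needed
        # retried with exponential backoff, per the decorator above
        response = await query_shard(nidx_shard_id, request)
        print(response)
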
nucliadb/train/generator.py CHANGED
@@ -54,7 +54,7 @@ from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
 async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
     # Get the data structure to generate data
     shard_manager = get_shard_manager()
-    node, shard_replica_id = await shard_manager.get_reader(kbid, shard)
+    shard_replica_id = await shard_manager.get_shard_id(kbid, shard)
 
     if trainset.batch_size == 0:
         trainset.batch_size = 50
@@ -62,24 +62,22 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
     batch_generator: Optional[AsyncIterator[TrainBatch]] = None
 
     if trainset.type == TaskType.FIELD_CLASSIFICATION:
-        batch_generator = field_classification_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = field_classification_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.IMAGE_CLASSIFICATION:
-        batch_generator = image_classification_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = image_classification_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.PARAGRAPH_CLASSIFICATION:
-        batch_generator = paragraph_classification_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = paragraph_classification_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.TOKEN_CLASSIFICATION:
-        batch_generator = token_classification_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = token_classification_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.SENTENCE_CLASSIFICATION:
-        batch_generator = sentence_classification_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = sentence_classification_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.PARAGRAPH_STREAMING:
-        batch_generator = paragraph_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = paragraph_streaming_batch_generator(kbid, trainset, shard_replica_id)
 
     elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
-        batch_generator = question_answer_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = question_answer_batch_generator(kbid, trainset, shard_replica_id)
     elif trainset.type == TaskType.FIELD_STREAMING:
-        batch_generator = field_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
+        batch_generator = field_streaming_batch_generator(kbid, trainset, shard_replica_id)
 
     if batch_generator is None:
         raise HTTPException(
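
Note: every train batch generator now takes (kbid, trainset, shard_replica_id); the node argument is gone throughout. A sketch of the new calling convention; the build_generator helper is hypothetical, the generator, TaskType and TrainSet names are from the diff:

    from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

    from nucliadb.train.generators.field_classifier import field_classification_batch_generator

    def build_generator(kbid: str, shard_replica_id: str):
        trainset = TrainSet()
        trainset.type = TaskType.FIELD_CLASSIFICATION
        trainset.batch_size = 50  # same default applied by generate_train_data
        # no node argument anymore: the generator resolves the nidx client itself
        return field_classification_batch_generator(kbid, trainset, shard_replica_id)
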
nucliadb/train/generators/field_classifier.py CHANGED
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
@@ -37,10 +37,9 @@ from nucliadb_protos.dataset_pb2 import (
 def field_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[FieldClassificationBatch, None]:
-    generator = generate_field_classification_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
     return batch_generator
 
@@ -48,7 +47,6 @@ def field_classification_batch_generator(
 async def generate_field_classification_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[TextLabel, None]:
     labelset = f"/l/{trainset.filter.labels[0]}"
@@ -59,7 +57,7 @@ async def generate_field_classification_payloads(
     request.filter.labels.append(labelset)
     total = 0
 
-    async for document_item in node.stream_get_fields(request):
+    async for document_item in get_nidx_searcher_client().Documents(request):
        text_labels = []
        for label in document_item.labels:
            if label.startswith(labelset):
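
Note: across all the train generators in this diff, node.stream_get_fields and node.stream_get_paragraphs are replaced by the Documents and Paragraphs RPCs on the shared nidx searcher client. A minimal sketch of that shared streaming pattern, assuming a valid shard replica id; the stream_field_labels helper is hypothetical:

    from nidx_protos.nodereader_pb2 import StreamRequest

    from nucliadb.common.nidx import get_nidx_searcher_client

    async def stream_field_labels(shard_replica_id: str):
        request = StreamRequest()
        request.shard_id.id = shard_replica_id
        # Documents() replaces node.stream_get_fields();
        # Paragraphs() replaces node.stream_get_paragraphs()
        async for document_item in get_nidx_searcher_client().Documents(request):
            yield document_item.uuid, list(document_item.labels)
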
nucliadb/train/generators/field_streaming.py CHANGED
@@ -22,8 +22,8 @@ from typing import AsyncGenerator, Optional
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
@@ -38,10 +38,9 @@ from nucliadb_protos.utils_pb2 import ExtractedText
 def field_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[FieldStreamingBatch, None]:
-    generator = generate_field_streaming_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
     return batch_generator
 
@@ -49,7 +48,6 @@ def field_streaming_batch_generator(
 async def generate_field_streaming_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[FieldSplitData, None]:
     # Query how many resources has each label
@@ -77,7 +75,7 @@ async def generate_field_streaming_payloads(
     total = 0
     resources = set()
 
-    async for document_item in node.stream_get_fields(request):
+    async for document_item in get_nidx_searcher_client().Documents(request):
        text_labels = []
        for label in document_item.labels:
            text_labels.append(label)
nucliadb/train/generators/image_classifier.py CHANGED
@@ -20,7 +20,6 @@
 
 from typing import AsyncGenerator
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.train.generators.utils import batchify
 from nucliadb_protos.dataset_pb2 import (
     ImageClassification,
@@ -32,10 +31,9 @@ from nucliadb_protos.dataset_pb2 import (
 def image_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ImageClassificationBatch, None]:
-    generator = generate_image_classification_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_image_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
     return batch_generator
 
@@ -43,7 +41,6 @@ def image_classification_batch_generator(
 async def generate_image_classification_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ImageClassification, None]:
     # NOTE: image classifications are no longer supported, as the page selection annotations were removed
nucliadb/train/generators/paragraph_classifier.py CHANGED
@@ -23,7 +23,7 @@ from typing import AsyncGenerator
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train.generators.utils import batchify, get_paragraph
 from nucliadb_protos.dataset_pb2 import (
     Label,
@@ -36,7 +36,6 @@ from nucliadb_protos.dataset_pb2 import (
 def paragraph_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ParagraphClassificationBatch, None]:
     if len(trainset.filter.labels) != 1:
@@ -45,7 +44,7 @@ def paragraph_classification_batch_generator(
             detail="Paragraph Classification should be of 1 labelset",
         )
 
-    generator = generate_paragraph_classification_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_paragraph_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ParagraphClassificationBatch)
     return batch_generator
 
@@ -53,7 +52,6 @@ def paragraph_classification_batch_generator(
 async def generate_paragraph_classification_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[TextLabel, None]:
     labelset = f"/l/{trainset.filter.labels[0]}"
@@ -63,7 +61,7 @@ async def generate_paragraph_classification_payloads(
     request.shard_id.id = shard_replica_id
     request.filter.labels.append(labelset)
 
-    async for paragraph_item in node.stream_get_paragraphs(request):
+    async for paragraph_item in get_nidx_searcher_client().Paragraphs(request):
        text_labels = []
        for label in paragraph_item.labels:
            if label.startswith(labelset):
nucliadb/train/generators/paragraph_streaming.py CHANGED
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
@@ -36,10 +36,9 @@ from nucliadb_protos.dataset_pb2 import (
 def paragraph_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ParagraphStreamingBatch, None]:
-    generator = generate_paragraph_streaming_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_paragraph_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
     return batch_generator
 
@@ -47,7 +46,6 @@ def paragraph_streaming_batch_generator(
 async def generate_paragraph_streaming_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ParagraphStreamItem, None]:
     """Streams paragraphs ordered as if they were read sequentially from each
@@ -57,7 +55,7 @@ async def generate_paragraph_streaming_payloads(
     request = StreamRequest()
     request.shard_id.id = shard_replica_id
 
-    async for document_item in node.stream_get_fields(request):
+    async for document_item in get_nidx_searcher_client().Documents(request):
        field_id = f"{document_item.uuid}{document_item.field}"
        rid, field_type, field = field_id.split("/")
 
nucliadb/train/generators/question_answer_streaming.py CHANGED
@@ -22,8 +22,8 @@ from typing import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import (
     batchify,
@@ -45,10 +45,9 @@ from nucliadb_protos.resources_pb2 import (
 def question_answer_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
-    generator = generate_question_answer_streaming_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_question_answer_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
     return batch_generator
 
@@ -56,13 +55,12 @@ def question_answer_batch_generator(
 async def generate_question_answer_streaming_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ):
     request = StreamRequest()
     request.shard_id.id = shard_replica_id
 
-    async for document_item in node.stream_get_fields(request):
+    async for document_item in get_nidx_searcher_client().Documents(request):
        field_id = f"{document_item.uuid}{document_item.field}"
        rid, field_type, field = field_id.split("/")
 
nucliadb/train/generators/sentence_classifier.py CHANGED
@@ -23,8 +23,8 @@ from typing import AsyncGenerator
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
@@ -38,7 +38,6 @@ from nucliadb_protos.dataset_pb2 import (
 def sentence_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[SentenceClassificationBatch, None]:
     if len(trainset.filter.labels) == 0:
@@ -47,7 +46,7 @@ def sentence_classification_batch_generator(
             detail="Sentence Classification should be at least of 1 labelset",
         )
 
-    generator = generate_sentence_classification_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_sentence_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, SentenceClassificationBatch)
     return batch_generator
 
@@ -55,7 +54,6 @@ def sentence_classification_batch_generator(
 async def generate_sentence_classification_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[MultipleTextSameLabels, None]:
     labelsets = []
@@ -67,7 +65,7 @@ async def generate_sentence_classification_payloads(
         labelsets.append(labelset)
         request.filter.labels.append(labelset)
 
-    async for paragraph_item in node.stream_get_paragraphs(request):
+    async for paragraph_item in get_nidx_searcher_client().Paragraphs(request):
        text_labels: list[str] = []
        for label in paragraph_item.labels:
            for labelset in labelsets:
nucliadb/train/generators/token_classifier.py CHANGED
@@ -23,8 +23,8 @@ from typing import AsyncGenerator, cast
 
 from nidx_protos.nodereader_pb2 import StreamFilter, StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.common.nidx import get_nidx_searcher_client
 from nucliadb.train import logger
 from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
@@ -41,10 +41,9 @@ MAIN = "__main__"
 def token_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[TokenClassificationBatch, None]:
-    generator = generate_token_classification_payloads(kbid, trainset, node, shard_replica_id)
+    generator = generate_token_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)
     return batch_generator
 
@@ -52,7 +51,6 @@ def token_classification_batch_generator(
 async def generate_token_classification_payloads(
     kbid: str,
     trainset: TrainSet,
-    node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[TokensClassification, None]:
     request = StreamRequest()
@@ -60,7 +58,7 @@ async def generate_token_classification_payloads(
     for entitygroup in trainset.filter.labels:
         request.filter.labels.append(f"/e/{entitygroup}")
     request.filter.conjunction = StreamFilter.Conjunction.OR
-    async for field_item in node.stream_get_fields(request):
+    async for field_item in get_nidx_searcher_client().Documents(request):
        _, field_type, field = field_item.field.split("/")
        (
            split_text,
nucliadb/train/nodes.py CHANGED
@@ -21,7 +21,6 @@ from typing import AsyncIterator, Optional
 
 from nucliadb.common import datamanagers
 from nucliadb.common.cluster import manager
-from nucliadb.common.cluster.base import AbstractIndexNode
 
 # XXX: this keys shouldn't be exposed outside datamanagers
 from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
@@ -54,15 +53,14 @@ class TrainShardManager(manager.KBShardManager):
         self.driver = driver
         self.storage = storage
 
-    async def get_reader(self, kbid: str, shard: str) -> tuple[AbstractIndexNode, str]:
+    async def get_shard_id(self, kbid: str, shard: str) -> str:
         shards = await self.get_shards_by_kbid_inner(kbid)
         try:
             shard_object: ShardObject = next(filter(lambda x: x.shard == shard, shards.shards))
         except StopIteration:
             raise KeyError("Shard not found")
 
-        node_obj, shard_id = manager.choose_node(shard_object)
-        return node_obj, shard_id
+        return shard_object.nidx_shard_id
 
     async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
         if kbid is None:
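
Note: TrainShardManager.get_reader becomes get_shard_id and returns only the nidx shard id string, which then feeds the train generators above. A small usage sketch; the resolve_shard helper is hypothetical:

    from nucliadb.train.nodes import TrainShardManager

    async def resolve_shard(manager: TrainShardManager, kbid: str, shard: str) -> str:
        # raises KeyError("Shard not found") if the logical shard is not in the KB
        shard_replica_id = await manager.get_shard_id(kbid, shard)
        return shard_replica_id
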
{nucliadb-6.4.0.post4127.dist-info → nucliadb-6.4.0.post4132.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nucliadb
-Version: 6.4.0.post4127
+Version: 6.4.0.post4132
 Summary: NucliaDB
 Author-email: Nuclia <nucliadb@nuclia.com>
 License: AGPL
@@ -20,11 +20,11 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3 :: Only
 Requires-Python: <4,>=3.9
 Description-Content-Type: text/markdown
-Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4127
-Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4127
-Requires-Dist: nucliadb-protos>=6.4.0.post4127
-Requires-Dist: nucliadb-models>=6.4.0.post4127
-Requires-Dist: nidx-protos>=6.4.0.post4127
+Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4132
+Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4132
+Requires-Dist: nucliadb-protos>=6.4.0.post4132
+Requires-Dist: nucliadb-models>=6.4.0.post4132
+Requires-Dist: nidx-protos>=6.4.0.post4132
 Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
 Requires-Dist: nuclia-models>=0.24.2
 Requires-Dist: uvicorn[standard]