nucliadb 6.2.1.post2954__py3-none-any.whl → 6.2.1.post2972__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nucliadb/common/cluster/manager.py +33 -331
  2. nucliadb/common/cluster/rebalance.py +2 -2
  3. nucliadb/common/cluster/rollover.py +12 -71
  4. nucliadb/common/cluster/standalone/utils.py +0 -43
  5. nucliadb/common/cluster/utils.py +0 -16
  6. nucliadb/common/nidx.py +21 -23
  7. nucliadb/health.py +0 -7
  8. nucliadb/ingest/app.py +0 -8
  9. nucliadb/ingest/consumer/auditing.py +1 -1
  10. nucliadb/ingest/consumer/shard_creator.py +1 -1
  11. nucliadb/ingest/orm/entities.py +3 -6
  12. nucliadb/purge/orphan_shards.py +6 -4
  13. nucliadb/search/api/v1/knowledgebox.py +1 -5
  14. nucliadb/search/predict.py +4 -4
  15. nucliadb/search/requesters/utils.py +1 -2
  16. nucliadb/search/search/chat/ask.py +18 -11
  17. nucliadb/search/search/chat/query.py +1 -1
  18. nucliadb/search/search/shards.py +19 -0
  19. nucliadb/standalone/introspect.py +0 -25
  20. nucliadb/train/lifecycle.py +0 -6
  21. nucliadb/train/nodes.py +1 -5
  22. nucliadb/writer/back_pressure.py +17 -46
  23. nucliadb/writer/settings.py +2 -2
  24. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/METADATA +5 -7
  25. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/RECORD +29 -39
  26. nucliadb/common/cluster/discovery/__init__.py +0 -19
  27. nucliadb/common/cluster/discovery/base.py +0 -178
  28. nucliadb/common/cluster/discovery/k8s.py +0 -301
  29. nucliadb/common/cluster/discovery/manual.py +0 -57
  30. nucliadb/common/cluster/discovery/single.py +0 -51
  31. nucliadb/common/cluster/discovery/types.py +0 -32
  32. nucliadb/common/cluster/discovery/utils.py +0 -67
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
  34. nucliadb/common/cluster/standalone/index_node.py +0 -123
  35. nucliadb/common/cluster/standalone/service.py +0 -84
  36. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/WHEEL +0 -0
  37. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/entry_points.txt +0 -0
  38. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/top_level.txt +0 -0
  39. {nucliadb-6.2.1.post2954.dist-info → nucliadb-6.2.1.post2972.dist-info}/zip-safe +0 -0
nucliadb/common/cluster/standalone/utils.py CHANGED
@@ -19,13 +19,10 @@
 
 import logging
 import os
-import shutil
 import uuid
-from socket import gethostname
 
 from nucliadb.common.cluster.settings import StandaloneNodeRole
 from nucliadb.common.cluster.settings import settings as cluster_settings
-from nucliadb.common.cluster.standalone.index_node import StandaloneIndexNode
 
 logger = logging.getLogger(__name__)
 
@@ -46,46 +43,6 @@ def get_standalone_node_id() -> str:
         return str(uuid.UUID(bytes=f.read()))
 
 
-_SELF_INDEX_NODE = None
-
-
-def get_self() -> StandaloneIndexNode:
-    """
-    This returns an instance of the standalone index node
-    so when API requests come into this mode, we don't
-    make another grpc request since this node can service it directly.
-    """
-    if not is_index_node():
-        raise Exception("This node is not an Index Node. You should not reach this code path.")
-    global _SELF_INDEX_NODE
-    node_id = get_standalone_node_id()
-    if _SELF_INDEX_NODE is None or node_id != _SELF_INDEX_NODE.id:
-        if "NUCLIADB_SERVICE_HOST" in os.environ:
-            hn = os.environ["HOSTNAME"]
-            ns = os.environ.get("NAMESPACE", "nucliadb")
-            host = f"{hn}.{ns}"
-        else:
-            host = gethostname()
-        _SELF_INDEX_NODE = StandaloneIndexNode(id=node_id, address=host, shard_count=0, available_disk=0)
-    try:
-        _, _, available_disk = shutil.disk_usage(cluster_settings.data_path)
-        _SELF_INDEX_NODE.available_disk = available_disk
-    except FileNotFoundError:  # pragma: no cover
-        ...
-    try:
-        _shards_dir = os.path.join(cluster_settings.data_path, "shards")
-        _SELF_INDEX_NODE.shard_count = len(
-            [
-                shard_dir
-                for shard_dir in os.listdir(_shards_dir)
-                if os.path.isdir(os.path.join(_shards_dir, shard_dir))
-            ]
-        )
-    except FileNotFoundError:  # pragma: no cover
-        ...
-    return _SELF_INDEX_NODE
-
-
 def is_index_node() -> bool:
     return cluster_settings.standalone_node_role in (
         StandaloneNodeRole.ALL,
nucliadb/common/cluster/utils.py CHANGED
@@ -23,20 +23,11 @@ from typing import TYPE_CHECKING, Optional, Union
 import backoff
 
 from nucliadb.common import datamanagers
-from nucliadb.common.cluster.discovery.utils import (
-    setup_cluster_discovery,
-    teardown_cluster_discovery,
-)
 from nucliadb.common.cluster.manager import (
     KBShardManager,
     StandaloneKBShardManager,
-    clear_index_nodes,
 )
 from nucliadb.common.cluster.settings import settings
-from nucliadb.common.cluster.standalone.service import (
-    start_grpc as start_standalone_grpc,
-)
-from nucliadb.common.cluster.standalone.utils import is_index_node
 from nucliadb.ingest.orm.resource import Resource
 from nucliadb_protos import nodereader_pb2, writer_pb2
 from nucliadb_utils import const
@@ -62,12 +53,8 @@ async def setup_cluster() -> Union[KBShardManager, StandaloneKBShardManager]:
         # already setup
         return get_utility(Utility.SHARD_MANAGER)
 
-    await setup_cluster_discovery()
     mng: Union[KBShardManager, StandaloneKBShardManager]
     if settings.standalone_mode:
-        if is_index_node():
-            server = await start_standalone_grpc()
-            set_utility(_STANDALONE_SERVER, server)
         mng = StandaloneKBShardManager()
     else:
         mng = KBShardManager()
@@ -76,7 +63,6 @@ async def setup_cluster() -> Union[KBShardManager, StandaloneKBShardManager]:
 
 
 async def teardown_cluster():
-    await teardown_cluster_discovery()
     if get_utility(Utility.SHARD_MANAGER):
         clean_utility(Utility.SHARD_MANAGER)
 
@@ -85,8 +71,6 @@ async def teardown_cluster():
         await std_server.stop(None)
         clean_utility(_STANDALONE_SERVER)
 
-    clear_index_nodes()
-
 
 def get_shard_manager() -> KBShardManager:
     return get_utility(Utility.SHARD_MANAGER)  # type: ignore
nucliadb/common/nidx.py CHANGED
@@ -37,12 +37,10 @@ from nucliadb_utils.settings import FileBackendConfig, indexing_settings, storage_settings
 from nucliadb_utils.storages.settings import settings as extended_storage_settings
 from nucliadb_utils.utilities import Utility, clean_utility, get_utility, set_utility
 
-NIDX_ENABLED = bool(os.environ.get("NIDX_ENABLED"))
-
 
 class NidxUtility:
-    api_client = None
-    searcher_client = None
+    api_client: NidxApiStub
+    searcher_client: NidxSearcherStub
 
     async def initialize(self):
         raise NotImplementedError()
@@ -98,6 +96,9 @@ class NidxBindingUtility(NidxUtility):
 
         self.config = {
             "METADATA__DATABASE_URL": ingest_settings.driver_pg_url,
+            "SEARCHER__METADATA_REFRESH_INTERVAL": str(
+                indexing_settings.index_searcher_refresh_interval
+            ),
             **_storage_config("INDEXER", None),
             **_storage_config("STORAGE", "nidx"),
         }
@@ -158,11 +159,8 @@ class NidxServiceUtility(NidxUtility):
         return res.seq
 
 
-async def start_nidx_utility() -> Optional[NidxUtility]:
-    if not NIDX_ENABLED:
-        return None
-
-    nidx = get_nidx()
+async def start_nidx_utility() -> NidxUtility:
+    nidx = get_utility(Utility.NIDX)
     if nidx:
         return nidx
 
@@ -178,30 +176,33 @@ async def start_nidx_utility() -> Optional[NidxUtility]:
 
 
 async def stop_nidx_utility():
-    nidx_utility = get_nidx()
+    nidx_utility = get_utility(Utility.NIDX)
     if nidx_utility:
         clean_utility(Utility.NIDX)
         await nidx_utility.finalize()
 
 
-def get_nidx() -> Optional[NidxUtility]:
-    return get_utility(Utility.NIDX)
+def get_nidx() -> NidxUtility:
+    nidx = get_utility(Utility.NIDX)
+    if nidx is None:
+        raise Exception("nidx not initialized")
+    return nidx
 
 
-def get_nidx_api_client() -> Optional["NidxApiStub"]:
+def get_nidx_api_client() -> "NidxApiStub":
     nidx = get_nidx()
-    if nidx:
+    if nidx.api_client:
         return nidx.api_client
     else:
-        return None
+        raise Exception("nidx not initialized")
 
 
-def get_nidx_searcher_client() -> Optional["NidxSearcherStub"]:
+def get_nidx_searcher_client() -> "NidxSearcherStub":
     nidx = get_nidx()
-    if nidx:
+    if nidx.searcher_client:
         return nidx.searcher_client
     else:
-        return None
+        raise Exception("nidx not initialized")
 
 
 # TODO: Remove the index node abstraction
@@ -252,9 +253,6 @@ class FakeNode(AbstractIndexNode):
         return "nidx"
 
 
-def get_nidx_fake_node() -> Optional[FakeNode]:
+def get_nidx_fake_node() -> FakeNode:
     nidx = get_nidx()
-    if nidx:
-        return FakeNode(nidx.api_client, nidx.searcher_client)
-    else:
-        return None
+    return FakeNode(nidx.api_client, nidx.searcher_client)
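The nidx accessors above switch from `Optional` returns to strict getters that raise when the utility is missing. A minimal sketch of that pattern, with a plain dict and `object` standing in for nucliadb's utility registry and gRPC stubs:

```python
from typing import Optional

_UTILITIES: dict[str, object] = {}  # stand-in for nucliadb_utils' utility registry


def get_utility(name: str) -> Optional[object]:
    return _UTILITIES.get(name)


def get_nidx() -> object:
    # Strict accessor: a missing utility is a startup bug, so it fails loudly
    # here instead of returning None and forcing checks at every call site.
    nidx = get_utility("nidx")
    if nidx is None:
        raise Exception("nidx not initialized")
    return nidx
```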
nucliadb/health.py CHANGED
@@ -40,13 +40,6 @@ def nats_manager_healthy() -> bool:
     return nats_manager.healthy()
 
 
-def nodes_health_check() -> bool:
-    from nucliadb.common.cluster import manager
-    from nucliadb.ingest.settings import DriverConfig, settings
-
-    return len(manager.INDEX_NODES) > 0 or settings.driver == DriverConfig.LOCAL
-
-
 def pubsub_check() -> bool:
     driver: Optional[PubSubDriver] = get_utility(Utility.PUBSUB)
     if driver is None:
nucliadb/ingest/app.py CHANGED
@@ -22,10 +22,6 @@ import importlib.metadata
 from typing import Awaitable, Callable
 
 from nucliadb import health
-from nucliadb.common.cluster.discovery.utils import (
-    setup_cluster_discovery,
-    teardown_cluster_discovery,
-)
 from nucliadb.common.cluster.settings import settings as cluster_settings
 from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
 from nucliadb.common.context import ApplicationContext
@@ -89,13 +85,9 @@ async def initialize() -> list[Callable[[], Awaitable[None]]]:
     )
     finalizers.append(stop_nats_manager)
 
-    await setup_cluster_discovery()
-    finalizers.append(teardown_cluster_discovery)
-
     health.register_health_checks(
         [
             health.nats_manager_healthy,
-            health.nodes_health_check,
             health.pubsub_check,
         ]
     )
nucliadb/ingest/consumer/auditing.py CHANGED
@@ -113,7 +113,7 @@ class IndexAuditHandler:
 
         for shard_obj in shard_groups:
             # TODO: Uses node for auditing, don't want to suddenly change metrics
-            node, shard_id = choose_node(shard_obj, use_nidx=False)
+            node, shard_id = choose_node(shard_obj)
             shard: nodereader_pb2.Shard = await node.reader.GetShard(
                 nodereader_pb2.GetShardRequest(shard_id=noderesources_pb2.ShardId(id=shard_id))  # type: ignore
             )
nucliadb/ingest/consumer/shard_creator.py CHANGED
@@ -103,7 +103,7 @@ class ShardCreatorHandler:
         async with locking.distributed_lock(locking.NEW_SHARD_LOCK.format(kbid=kbid)):
             # remember, a lock will do at least 1+ reads and 1 write.
             # with heavy writes, this adds some simple k/v pressure
-            node, shard_id = choose_node(current_shard, use_nidx=True)
+            node, shard_id = choose_node(current_shard)
             shard: nodereader_pb2.Shard = await node.reader.GetShard(
                 nodereader_pb2.GetShardRequest(shard_id=noderesources_pb2.ShardId(id=shard_id))  # type: ignore
             )
nucliadb/ingest/orm/entities.py CHANGED
@@ -37,6 +37,7 @@ from nucliadb.common.datamanagers.entities import (
 from nucliadb.common.maindb.driver import Transaction
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
 from nucliadb.ingest.settings import settings
+from nucliadb.search.search.shards import query_shard
 from nucliadb_protos.knowledgebox_pb2 import (
     DeletedEntitiesGroups,
     EntitiesGroup,
@@ -54,8 +55,6 @@ from nucliadb_protos.nodereader_pb2 import (
 from nucliadb_protos.utils_pb2 import RelationNode
 from nucliadb_protos.writer_pb2 import GetEntitiesResponse
 from nucliadb_telemetry import errors
-from nucliadb_utils import const
-from nucliadb_utils.utilities import has_feature
 
 from .exceptions import EntityManagementException
 
@@ -218,14 +217,13 @@ class EntitiesManager:
                     ],
                 ),
             )
-            response = await node.reader.Search(request)  # type: ignore
+            response = await query_shard(node, shard_id, request)
             return response.relation
 
         results = await shard_manager.apply_for_all_shards(
             self.kbid,
             do_entities_search,
             settings.relation_search_timeout,
-            use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": self.kbid}),
             use_read_replica_nodes=self.use_read_replica_nodes,
         )
         for result in results:
@@ -315,7 +313,7 @@ class EntitiesManager:
                 paragraph=False,
                 faceted=Faceted(labels=["/e"]),
             )
-            response: SearchResponse = await node.reader.Search(request)  # type: ignore
+            response: SearchResponse = await query_shard(node, shard_id, request)
             try:
                 facetresults = response.document.facets["/e"].facetresults
                 return {facet.tag.split("/")[-1] for facet in facetresults}
@@ -327,7 +325,6 @@ class EntitiesManager:
             self.kbid,
             query_indexed_entities_group_names,
             settings.relation_types_timeout,
-            use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": self.kbid}),
             use_read_replica_nodes=self.use_read_replica_nodes,
         )
         for result in results:
nucliadb/purge/orphan_shards.py CHANGED
@@ -33,6 +33,7 @@ from nucliadb.common.cluster.manager import KBShardManager
 from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
 from nucliadb.common.maindb.driver import Driver
 from nucliadb.common.maindb.utils import setup_driver, teardown_driver
+from nucliadb.common.nidx import start_nidx_utility, stop_nidx_utility
 from nucliadb.ingest import logger
 from nucliadb_telemetry import errors
 from nucliadb_telemetry.logs import setup_logging
@@ -135,10 +136,9 @@ async def _get_stored_shards(driver: Driver) -> dict[str, ShardLocation]:
             continue
         else:
             for shard_object_pb in kb_shards:
-                for shard_replica_pb in shard_object_pb.replicas:
-                    shard_replica_id = shard_replica_pb.shard.id
-                    node_id = shard_replica_pb.node
-                    stored_shards[shard_replica_id] = ShardLocation(kbid=kbid, node_id=node_id)
+                stored_shards[shard_object_pb.nidx_shard_id] = ShardLocation(
+                    kbid=kbid, node_id="nidx"
+                )
     return stored_shards
 
 
@@ -241,6 +241,7 @@ async def main():
     """
     args = parse_arguments()
 
+    await start_nidx_utility()
    await setup_cluster()
     driver = await setup_driver()
 
@@ -253,6 +254,7 @@ async def main():
     finally:
         await teardown_driver()
         await teardown_cluster()
+        await stop_nidx_utility()
 
 
 def run() -> int:  # pragma: no cover
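Two of these hunks add the nidx utility to the purge script's lifecycle: it is started before the cluster and stopped after teardown. A sketch of that ordering; the function names mirror the diff, but the bodies here are illustrative no-ops:

```python
import asyncio


async def start_nidx_utility() -> None: ...  # stand-ins for the real setup functions
async def setup_cluster() -> None: ...
async def setup_driver() -> None: ...
async def teardown_driver() -> None: ...
async def teardown_cluster() -> None: ...
async def stop_nidx_utility() -> None: ...


async def main() -> None:
    # nidx comes up first so shard lookups can be served while the script
    # runs, and goes down last: teardown mirrors setup in reverse.
    await start_nidx_utility()
    await setup_cluster()
    await setup_driver()
    try:
        ...  # detect and purge orphan shards
    finally:
        await teardown_driver()
        await teardown_cluster()
        await stop_nidx_utility()


if __name__ == "__main__":
    asyncio.run(main())
```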
nucliadb/search/api/v1/knowledgebox.py CHANGED
@@ -48,9 +48,7 @@ from nucliadb_protos.noderesources_pb2 import Shard
 from nucliadb_protos.writer_pb2 import ShardObject as PBShardObject
 from nucliadb_protos.writer_pb2 import Shards
 from nucliadb_telemetry import errors
-from nucliadb_utils import const
 from nucliadb_utils.authentication import requires, requires_one
-from nucliadb_utils.utilities import has_feature
 
 MAX_PARAGRAPHS_FOR_SMALL_KB = 250_000
 
@@ -166,9 +164,7 @@ async def get_node_index_counts(kbid: str) -> tuple[IndexCounts, list[str]]:
     queried_shards = []
     for shard_object in shard_groups:
         try:
-            node, shard_id = choose_node(
-                shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
-            )
+            node, shard_id = choose_node(shard_object)
         except KeyError:
             raise HTTPException(
                 status_code=500,
nucliadb/search/predict.py CHANGED
@@ -21,7 +21,7 @@ import json
 import os
 import random
 from enum import Enum
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncGenerator, Optional
 from unittest.mock import AsyncMock, Mock
 
 import aiohttp
@@ -268,7 +268,7 @@ class PredictEngine:
     @predict_observer.wrap({"type": "chat_ndjson"})
     async def chat_query_ndjson(
         self, kbid: str, item: ChatModel
-    ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
+    ) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
         """
         Chat query using the new stream format
         Format specs: https://github.com/ndjson/ndjson-spec
@@ -444,7 +444,7 @@ class DummyPredictEngine(PredictEngine):
 
     async def chat_query_ndjson(
         self, kbid: str, item: ChatModel
-    ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
+    ) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
         self.calls.append(("chat_query_ndjson", item))
 
         async def generate():
@@ -555,7 +555,7 @@ def get_answer_generator(response: aiohttp.ClientResponse):
 
 def get_chat_ndjson_generator(
     response: aiohttp.ClientResponse,
-) -> AsyncIterator[GenerativeChunk]:
+) -> AsyncGenerator[GenerativeChunk, None]:
     async def _parse_generative_chunks(gen):
         async for chunk in gen:
             try:
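The `AsyncIterator` to `AsyncGenerator` annotation changes above are type-accuracy fixes: an `async def` that yields produces an `AsyncGenerator`, a subtype of `AsyncIterator`. A small self-contained illustration (the string chunks stand in for `GenerativeChunk`):

```python
import asyncio
from typing import AsyncGenerator, AsyncIterator


async def chunks() -> AsyncGenerator[str, None]:
    # An async def containing `yield` is an AsyncGenerator, so the narrower
    # annotation is the accurate one for these functions.
    for piece in ("answer ", "text"):
        yield piece


async def consume(stream: AsyncIterator[str]) -> str:
    # Callers that only iterate can still accept the wider AsyncIterator type.
    return "".join([piece async for piece in stream])


if __name__ == "__main__":
    print(asyncio.run(consume(chunks())))  # -> answer text
```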
nucliadb/search/requesters/utils.py CHANGED
@@ -123,7 +123,6 @@ async def node_query(
         try:
             node, shard_id = cluster_manager.choose_node(
                 shard_obj,
-                use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid}),
                 use_read_replica_nodes=use_read_replica_nodes,
                 target_shard_replicas=target_shard_replicas,
             )
@@ -224,7 +223,7 @@ def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
             )
         else:
             errors.capture_exception(result)
-            logger.exception("Error while querying shard data", exc_info=result)
+            logger.exception(f"Error while querying shard data {result}", exc_info=result)
 
     return HTTPException(status_code=status_code, detail=reason)
 
nucliadb/search/search/chat/ask.py CHANGED
@@ -129,7 +129,7 @@ class AskResult:
         main_results: KnowledgeboxFindResults,
         prequeries_results: Optional[list[PreQueryResult]],
         nuclia_learning_id: Optional[str],
-        predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
+        predict_answer_stream: Optional[AsyncGenerator[GenerativeChunk, None]],
         prompt_context: PromptContext,
         prompt_context_order: PromptContextOrder,
         auditor: ChatAuditor,
@@ -396,6 +396,9 @@ class AskResult:
         This method does not assume any order in the stream of items, but it assumes that at least
         the answer text is streamed in order.
         """
+        if self.predict_answer_stream is None:
+            # In some cases, clients may want to skip the answer generation step
+            return
         async for generative_chunk in self.predict_answer_stream:
             item = generative_chunk.chunk
             if isinstance(item, TextGenerativeResponse):
@@ -562,14 +565,18 @@ async def ask(
         rerank_context=False,
         top_k=ask_request.top_k,
     )
-    with metrics.time("stream_start"):
-        predict = get_predict()
-        (
-            nuclia_learning_id,
-            nuclia_learning_model,
-            predict_answer_stream,
-        ) = await predict.chat_query_ndjson(kbid, chat_model)
-    debug_chat_model = chat_model
+
+    nuclia_learning_id = None
+    nuclia_learning_model = None
+    predict_answer_stream = None
+    if ask_request.generate_answer:
+        with metrics.time("stream_start"):
+            predict = get_predict()
+            (
+                nuclia_learning_id,
+                nuclia_learning_model,
+                predict_answer_stream,
+            ) = await predict.chat_query_ndjson(kbid, chat_model)
 
     auditor = ChatAuditor(
         kbid=kbid,
@@ -590,13 +597,13 @@ async def ask(
         main_results=retrieval_results.main_query,
         prequeries_results=retrieval_results.prequeries,
         nuclia_learning_id=nuclia_learning_id,
-        predict_answer_stream=predict_answer_stream,  # type: ignore
+        predict_answer_stream=predict_answer_stream,
         prompt_context=prompt_context,
         prompt_context_order=prompt_context_order,
         auditor=auditor,
         metrics=metrics,
         best_matches=retrieval_results.best_matches,
-        debug_chat_model=debug_chat_model,
+        debug_chat_model=chat_model,
     )
 
 
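These hunks make answer generation optional: when `ask_request.generate_answer` is false, the predict call is skipped, `predict_answer_stream` stays `None`, and the stream consumer returns early. A minimal sketch of the consumer-side guard (`str` stands in for `GenerativeChunk`, and the method name here is hypothetical):

```python
from typing import AsyncGenerator, Optional


class AskResult:
    def __init__(self, predict_answer_stream: Optional[AsyncGenerator[str, None]]):
        self.predict_answer_stream = predict_answer_stream

    async def answer_chunks(self) -> AsyncGenerator[str, None]:
        if self.predict_answer_stream is None:
            # generate_answer=False: retrieval-only ask, nothing to stream
            return
        async for chunk in self.predict_answer_stream:
            yield chunk
```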
nucliadb/search/search/chat/query.py CHANGED
@@ -349,7 +349,7 @@ class ChatAuditor:
         learning_id: Optional[str],
         query_context: PromptContext,
         query_context_order: PromptContextOrder,
-        model: str,
+        model: Optional[str],
     ):
         self.kbid = kbid
         self.user_id = user_id
nucliadb/search/search/shards.py CHANGED
@@ -19,6 +19,10 @@
 #
 import asyncio
 
+import backoff
+from grpc import StatusCode
+from grpc.aio import AioRpcError
+
 from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb_protos.nodereader_pb2 import (
     GetShardRequest,
@@ -39,6 +43,15 @@ node_observer = metrics.Observer(
 )
 
 
+def should_giveup(e: Exception):
+    if isinstance(e, AioRpcError) and e.code() != StatusCode.NOT_FOUND:
+        return True
+    return False
+
+
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def query_shard(node: AbstractIndexNode, shard: str, query: SearchRequest) -> SearchResponse:
     req = SearchRequest()
     req.CopyFrom(query)
@@ -47,6 +60,9 @@ async def query_shard(node: AbstractIndexNode, shard: str, query: SearchRequest) -> SearchResponse:
     return await node.reader.Search(req)  # type: ignore
 
 
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def get_shard(node: AbstractIndexNode, shard_id: str) -> Shard:
     req = GetShardRequest()
     req.shard_id.id = shard_id
@@ -54,6 +70,9 @@ async def get_shard(node: AbstractIndexNode, shard_id: str) -> Shard:
     return await node.reader.GetShard(req)  # type: ignore
 
 
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def suggest_shard(node: AbstractIndexNode, shard: str, query: SuggestRequest) -> SuggestResponse:
     req = SuggestRequest()
     req.CopyFrom(query)
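The new `backoff` decorators retry shard reads up to three attempts with exponential backoff (factor 0.1, no jitter). `should_giveup` aborts immediately on any gRPC error other than `NOT_FOUND`, so only NOT_FOUND responses and non-gRPC failures are retried. A runnable sketch of those semantics (requires `backoff` and `grpcio`; `flaky_read` is a hypothetical stand-in for a shard read):

```python
import asyncio

import backoff
from grpc import StatusCode
from grpc.aio import AioRpcError


def should_giveup(e: Exception) -> bool:
    # Retry NOT_FOUND (e.g. a shard not yet visible to the searcher) and
    # non-gRPC failures; give up at once on every other gRPC status.
    if isinstance(e, AioRpcError) and e.code() != StatusCode.NOT_FOUND:
        return True
    return False


@backoff.on_exception(
    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
)
async def flaky_read(attempts: list[int]) -> str:
    attempts.append(1)
    if len(attempts) < 3:
        raise RuntimeError("transient failure")  # not an AioRpcError: retried
    return "ok"


if __name__ == "__main__":
    attempts: list[int] = []
    print(asyncio.run(flaky_read(attempts)), len(attempts))  # -> ok 3
```

The same `query_shard` helper is now also used by the entities manager (see the `ingest/orm/entities.py` hunks above), so those searches pick up this retry behavior as well.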
nucliadb/standalone/introspect.py CHANGED
@@ -32,7 +32,6 @@ import psutil
 from fastapi import FastAPI
 from pydantic import BaseModel
 
-from nucliadb.common.cluster import manager as cluster_manager
 from nucliadb.standalone.settings import Settings
 from nucliadb_telemetry.settings import LogOutputType, LogSettings
 
@@ -83,7 +82,6 @@ async def stream_tar(app: FastAPI) -> AsyncGenerator[bytes, None]:
     with tarfile.open(tar_file, mode="w:gz") as tar:
         await add_system_info(temp_dir, tar)
         await add_dependencies(temp_dir, tar)
-        await add_cluster_info(temp_dir, tar)
         settings: Settings = app.settings.copy()  # type: ignore
         await add_settings(temp_dir, tar, settings)
         if settings.log_output_type == LogOutputType.FILE:
@@ -145,29 +143,6 @@ def _add_dependencies_to_tar(temp_dir: str, tar: tarfile.TarFile):
     tar.add(dependendies_file, arcname="dependencies.txt")
 
 
-async def add_cluster_info(temp_dir: str, tar: tarfile.TarFile):
-    loop = asyncio.get_event_loop()
-    await loop.run_in_executor(None, _add_cluster_info_to_tar, temp_dir, tar)
-
-
-def _add_cluster_info_to_tar(temp_dir: str, tar: tarfile.TarFile):
-    cluster_info = ClusterInfo(
-        nodes=[
-            NodeInfo(
-                id=node.id,
-                address=node.address,
-                shard_count=node.shard_count,
-                primary_id=node.primary_id,
-            )
-            for node in cluster_manager.get_index_nodes()
-        ]
-    )
-    cluster_info_file = os.path.join(temp_dir, "cluster_info.txt")
-    with open(cluster_info_file, "w") as f:
-        f.write(cluster_info.model_dump_json(indent=4))
-    tar.add(cluster_info_file, arcname="cluster_info.txt")
-
-
 async def add_settings(temp_dir: str, tar: tarfile.TarFile, settings: Settings):
     loop = asyncio.get_event_loop()
     await loop.run_in_executor(None, _add_settings_to_tar, temp_dir, tar, settings)
nucliadb/train/lifecycle.py CHANGED
@@ -22,10 +22,6 @@ from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
 
-from nucliadb.common.cluster.discovery.utils import (
-    setup_cluster_discovery,
-    teardown_cluster_discovery,
-)
 from nucliadb.train import SERVICE_NAME
 from nucliadb.train.utils import (
     start_shard_manager,
@@ -40,7 +36,6 @@ from nucliadb_utils.utilities import start_audit_utility, stop_audit_utility
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await setup_telemetry(SERVICE_NAME)
-    await setup_cluster_discovery()
     await start_shard_manager()
     await start_train_grpc(SERVICE_NAME)
     await start_audit_utility(SERVICE_NAME)
@@ -50,5 +45,4 @@ async def lifespan(app: FastAPI):
     await stop_audit_utility()
     await stop_train_grpc()
     await stop_shard_manager()
-    await teardown_cluster_discovery()
     await clean_telemetry(SERVICE_NAME)
nucliadb/train/nodes.py CHANGED
@@ -45,9 +45,7 @@ from nucliadb_protos.train_pb2 import (
     TrainSentence,
 )
 from nucliadb_protos.writer_pb2 import ShardObject
-from nucliadb_utils import const
 from nucliadb_utils.storages.storage import Storage
-from nucliadb_utils.utilities import has_feature
 
 
 class TrainShardManager(manager.KBShardManager):
@@ -63,9 +61,7 @@ class TrainShardManager(manager.KBShardManager):
         except StopIteration:
             raise KeyError("Shard not found")
 
-        node_obj, shard_id = manager.choose_node(
-            shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
-        )
+        node_obj, shard_id = manager.choose_node(shard_object)
         return node_obj, shard_id
 
     async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]: