nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. migrations/0023_backfill_pg_catalog.py +8 -4
  2. migrations/0028_extracted_vectors_reference.py +1 -1
  3. migrations/0029_backfill_field_status.py +3 -4
  4. migrations/0032_remove_old_relations.py +2 -3
  5. migrations/0038_backfill_catalog_field_labels.py +8 -4
  6. migrations/0039_backfill_converation_splits_metadata.py +106 -0
  7. migrations/0040_migrate_search_configurations.py +79 -0
  8. migrations/0041_reindex_conversations.py +137 -0
  9. migrations/pg/0010_shards_index.py +34 -0
  10. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  11. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  12. nucliadb/backups/create.py +2 -15
  13. nucliadb/backups/restore.py +4 -15
  14. nucliadb/backups/tasks.py +4 -1
  15. nucliadb/common/back_pressure/cache.py +2 -3
  16. nucliadb/common/back_pressure/materializer.py +7 -13
  17. nucliadb/common/back_pressure/settings.py +6 -6
  18. nucliadb/common/back_pressure/utils.py +1 -0
  19. nucliadb/common/cache.py +9 -9
  20. nucliadb/common/catalog/__init__.py +79 -0
  21. nucliadb/common/catalog/dummy.py +36 -0
  22. nucliadb/common/catalog/interface.py +85 -0
  23. nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
  24. nucliadb/common/catalog/utils.py +56 -0
  25. nucliadb/common/cluster/manager.py +8 -23
  26. nucliadb/common/cluster/rebalance.py +484 -112
  27. nucliadb/common/cluster/rollover.py +36 -9
  28. nucliadb/common/cluster/settings.py +4 -9
  29. nucliadb/common/cluster/utils.py +34 -8
  30. nucliadb/common/context/__init__.py +7 -8
  31. nucliadb/common/context/fastapi.py +1 -2
  32. nucliadb/common/datamanagers/__init__.py +2 -4
  33. nucliadb/common/datamanagers/atomic.py +9 -2
  34. nucliadb/common/datamanagers/cluster.py +1 -2
  35. nucliadb/common/datamanagers/fields.py +3 -4
  36. nucliadb/common/datamanagers/kb.py +6 -6
  37. nucliadb/common/datamanagers/labels.py +2 -3
  38. nucliadb/common/datamanagers/resources.py +10 -33
  39. nucliadb/common/datamanagers/rollover.py +5 -7
  40. nucliadb/common/datamanagers/search_configurations.py +1 -2
  41. nucliadb/common/datamanagers/synonyms.py +1 -2
  42. nucliadb/common/datamanagers/utils.py +4 -4
  43. nucliadb/common/datamanagers/vectorsets.py +4 -4
  44. nucliadb/common/external_index_providers/base.py +32 -5
  45. nucliadb/common/external_index_providers/manager.py +5 -34
  46. nucliadb/common/external_index_providers/settings.py +1 -27
  47. nucliadb/common/filter_expression.py +129 -41
  48. nucliadb/common/http_clients/exceptions.py +8 -0
  49. nucliadb/common/http_clients/processing.py +16 -23
  50. nucliadb/common/http_clients/utils.py +3 -0
  51. nucliadb/common/ids.py +82 -58
  52. nucliadb/common/locking.py +1 -2
  53. nucliadb/common/maindb/driver.py +9 -8
  54. nucliadb/common/maindb/local.py +5 -5
  55. nucliadb/common/maindb/pg.py +9 -8
  56. nucliadb/common/nidx.py +22 -5
  57. nucliadb/common/vector_index_config.py +1 -1
  58. nucliadb/export_import/datamanager.py +4 -3
  59. nucliadb/export_import/exporter.py +11 -19
  60. nucliadb/export_import/importer.py +13 -6
  61. nucliadb/export_import/tasks.py +2 -0
  62. nucliadb/export_import/utils.py +6 -18
  63. nucliadb/health.py +2 -2
  64. nucliadb/ingest/app.py +8 -8
  65. nucliadb/ingest/consumer/consumer.py +8 -10
  66. nucliadb/ingest/consumer/pull.py +10 -8
  67. nucliadb/ingest/consumer/service.py +5 -30
  68. nucliadb/ingest/consumer/shard_creator.py +16 -5
  69. nucliadb/ingest/consumer/utils.py +1 -1
  70. nucliadb/ingest/fields/base.py +37 -49
  71. nucliadb/ingest/fields/conversation.py +55 -9
  72. nucliadb/ingest/fields/exceptions.py +1 -2
  73. nucliadb/ingest/fields/file.py +22 -8
  74. nucliadb/ingest/fields/link.py +7 -7
  75. nucliadb/ingest/fields/text.py +2 -3
  76. nucliadb/ingest/orm/brain_v2.py +89 -57
  77. nucliadb/ingest/orm/broker_message.py +2 -4
  78. nucliadb/ingest/orm/entities.py +10 -209
  79. nucliadb/ingest/orm/index_message.py +128 -113
  80. nucliadb/ingest/orm/knowledgebox.py +91 -59
  81. nucliadb/ingest/orm/processor/auditing.py +1 -3
  82. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  83. nucliadb/ingest/orm/processor/processor.py +98 -153
  84. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  85. nucliadb/ingest/orm/resource.py +82 -71
  86. nucliadb/ingest/orm/utils.py +1 -1
  87. nucliadb/ingest/partitions.py +12 -1
  88. nucliadb/ingest/processing.py +17 -17
  89. nucliadb/ingest/serialize.py +202 -145
  90. nucliadb/ingest/service/writer.py +15 -114
  91. nucliadb/ingest/settings.py +36 -15
  92. nucliadb/ingest/utils.py +1 -2
  93. nucliadb/learning_proxy.py +23 -26
  94. nucliadb/metrics_exporter.py +20 -6
  95. nucliadb/middleware/__init__.py +82 -1
  96. nucliadb/migrator/datamanager.py +4 -11
  97. nucliadb/migrator/migrator.py +1 -2
  98. nucliadb/migrator/models.py +1 -2
  99. nucliadb/migrator/settings.py +1 -2
  100. nucliadb/models/internal/augment.py +614 -0
  101. nucliadb/models/internal/processing.py +19 -19
  102. nucliadb/openapi.py +2 -2
  103. nucliadb/purge/__init__.py +3 -8
  104. nucliadb/purge/orphan_shards.py +1 -2
  105. nucliadb/reader/__init__.py +5 -0
  106. nucliadb/reader/api/models.py +6 -13
  107. nucliadb/reader/api/v1/download.py +59 -38
  108. nucliadb/reader/api/v1/export_import.py +4 -4
  109. nucliadb/reader/api/v1/knowledgebox.py +37 -9
  110. nucliadb/reader/api/v1/learning_config.py +33 -14
  111. nucliadb/reader/api/v1/resource.py +61 -9
  112. nucliadb/reader/api/v1/services.py +18 -14
  113. nucliadb/reader/app.py +3 -1
  114. nucliadb/reader/reader/notifications.py +1 -2
  115. nucliadb/search/api/v1/__init__.py +3 -0
  116. nucliadb/search/api/v1/ask.py +3 -4
  117. nucliadb/search/api/v1/augment.py +585 -0
  118. nucliadb/search/api/v1/catalog.py +15 -19
  119. nucliadb/search/api/v1/find.py +16 -22
  120. nucliadb/search/api/v1/hydrate.py +328 -0
  121. nucliadb/search/api/v1/knowledgebox.py +1 -2
  122. nucliadb/search/api/v1/predict_proxy.py +1 -2
  123. nucliadb/search/api/v1/resource/ask.py +28 -8
  124. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  125. nucliadb/search/api/v1/resource/search.py +9 -11
  126. nucliadb/search/api/v1/retrieve.py +130 -0
  127. nucliadb/search/api/v1/search.py +28 -32
  128. nucliadb/search/api/v1/suggest.py +11 -14
  129. nucliadb/search/api/v1/summarize.py +1 -2
  130. nucliadb/search/api/v1/utils.py +2 -2
  131. nucliadb/search/app.py +3 -2
  132. nucliadb/search/augmentor/__init__.py +21 -0
  133. nucliadb/search/augmentor/augmentor.py +232 -0
  134. nucliadb/search/augmentor/fields.py +704 -0
  135. nucliadb/search/augmentor/metrics.py +24 -0
  136. nucliadb/search/augmentor/paragraphs.py +334 -0
  137. nucliadb/search/augmentor/resources.py +238 -0
  138. nucliadb/search/augmentor/utils.py +33 -0
  139. nucliadb/search/lifecycle.py +3 -1
  140. nucliadb/search/predict.py +33 -19
  141. nucliadb/search/predict_models.py +8 -9
  142. nucliadb/search/requesters/utils.py +11 -10
  143. nucliadb/search/search/cache.py +19 -42
  144. nucliadb/search/search/chat/ask.py +131 -59
  145. nucliadb/search/search/chat/exceptions.py +3 -5
  146. nucliadb/search/search/chat/fetcher.py +201 -0
  147. nucliadb/search/search/chat/images.py +6 -4
  148. nucliadb/search/search/chat/old_prompt.py +1375 -0
  149. nucliadb/search/search/chat/parser.py +510 -0
  150. nucliadb/search/search/chat/prompt.py +563 -615
  151. nucliadb/search/search/chat/query.py +453 -32
  152. nucliadb/search/search/chat/rpc.py +85 -0
  153. nucliadb/search/search/fetch.py +3 -4
  154. nucliadb/search/search/filters.py +8 -11
  155. nucliadb/search/search/find.py +33 -31
  156. nucliadb/search/search/find_merge.py +124 -331
  157. nucliadb/search/search/graph_strategy.py +14 -12
  158. nucliadb/search/search/hydrator/__init__.py +49 -0
  159. nucliadb/search/search/hydrator/fields.py +217 -0
  160. nucliadb/search/search/hydrator/images.py +130 -0
  161. nucliadb/search/search/hydrator/paragraphs.py +323 -0
  162. nucliadb/search/search/hydrator/resources.py +60 -0
  163. nucliadb/search/search/ingestion_agents.py +5 -5
  164. nucliadb/search/search/merge.py +90 -94
  165. nucliadb/search/search/metrics.py +24 -7
  166. nucliadb/search/search/paragraphs.py +7 -9
  167. nucliadb/search/search/predict_proxy.py +44 -18
  168. nucliadb/search/search/query.py +14 -86
  169. nucliadb/search/search/query_parser/fetcher.py +51 -82
  170. nucliadb/search/search/query_parser/models.py +19 -48
  171. nucliadb/search/search/query_parser/old_filters.py +20 -19
  172. nucliadb/search/search/query_parser/parsers/ask.py +5 -6
  173. nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
  174. nucliadb/search/search/query_parser/parsers/common.py +21 -13
  175. nucliadb/search/search/query_parser/parsers/find.py +6 -29
  176. nucliadb/search/search/query_parser/parsers/graph.py +18 -28
  177. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  178. nucliadb/search/search/query_parser/parsers/search.py +15 -56
  179. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  180. nucliadb/search/search/rank_fusion.py +18 -13
  181. nucliadb/search/search/rerankers.py +6 -7
  182. nucliadb/search/search/retrieval.py +300 -0
  183. nucliadb/search/search/summarize.py +5 -6
  184. nucliadb/search/search/utils.py +3 -4
  185. nucliadb/search/settings.py +1 -2
  186. nucliadb/standalone/api_router.py +1 -1
  187. nucliadb/standalone/app.py +4 -3
  188. nucliadb/standalone/auth.py +5 -6
  189. nucliadb/standalone/lifecycle.py +2 -2
  190. nucliadb/standalone/run.py +5 -4
  191. nucliadb/standalone/settings.py +5 -6
  192. nucliadb/standalone/versions.py +3 -4
  193. nucliadb/tasks/consumer.py +13 -8
  194. nucliadb/tasks/models.py +2 -1
  195. nucliadb/tasks/producer.py +3 -3
  196. nucliadb/tasks/retries.py +8 -7
  197. nucliadb/train/api/utils.py +1 -3
  198. nucliadb/train/api/v1/shards.py +1 -2
  199. nucliadb/train/api/v1/trainset.py +1 -2
  200. nucliadb/train/app.py +1 -1
  201. nucliadb/train/generator.py +4 -4
  202. nucliadb/train/generators/field_classifier.py +2 -2
  203. nucliadb/train/generators/field_streaming.py +6 -6
  204. nucliadb/train/generators/image_classifier.py +2 -2
  205. nucliadb/train/generators/paragraph_classifier.py +2 -2
  206. nucliadb/train/generators/paragraph_streaming.py +2 -2
  207. nucliadb/train/generators/question_answer_streaming.py +2 -2
  208. nucliadb/train/generators/sentence_classifier.py +4 -10
  209. nucliadb/train/generators/token_classifier.py +3 -2
  210. nucliadb/train/generators/utils.py +6 -5
  211. nucliadb/train/nodes.py +3 -3
  212. nucliadb/train/resource.py +6 -8
  213. nucliadb/train/settings.py +3 -4
  214. nucliadb/train/types.py +11 -11
  215. nucliadb/train/upload.py +3 -2
  216. nucliadb/train/uploader.py +1 -2
  217. nucliadb/train/utils.py +1 -2
  218. nucliadb/writer/api/v1/export_import.py +4 -1
  219. nucliadb/writer/api/v1/field.py +15 -14
  220. nucliadb/writer/api/v1/knowledgebox.py +18 -56
  221. nucliadb/writer/api/v1/learning_config.py +5 -4
  222. nucliadb/writer/api/v1/resource.py +9 -20
  223. nucliadb/writer/api/v1/services.py +10 -132
  224. nucliadb/writer/api/v1/upload.py +73 -72
  225. nucliadb/writer/app.py +8 -2
  226. nucliadb/writer/resource/basic.py +12 -15
  227. nucliadb/writer/resource/field.py +43 -5
  228. nucliadb/writer/resource/origin.py +7 -0
  229. nucliadb/writer/settings.py +2 -3
  230. nucliadb/writer/tus/__init__.py +2 -3
  231. nucliadb/writer/tus/azure.py +5 -7
  232. nucliadb/writer/tus/dm.py +3 -3
  233. nucliadb/writer/tus/exceptions.py +3 -4
  234. nucliadb/writer/tus/gcs.py +15 -22
  235. nucliadb/writer/tus/s3.py +2 -3
  236. nucliadb/writer/tus/storage.py +3 -3
  237. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
  238. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  239. nucliadb/common/datamanagers/entities.py +0 -139
  240. nucliadb/common/external_index_providers/pinecone.py +0 -894
  241. nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
  242. nucliadb/search/search/hydrator.py +0 -197
  243. nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
  244. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  245. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  246. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/common/cluster/rebalance.py
@@ -18,162 +18,534 @@
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
  #
  import asyncio
+ import dataclasses
  import logging
+ import math
+ import random
+ from typing import cast

+ from grpc import StatusCode
+ from grpc.aio import AioRpcError
  from nidx_protos import nodereader_pb2, noderesources_pb2

  from nucliadb.common import datamanagers, locking
  from nucliadb.common.cluster.utils import get_shard_manager
  from nucliadb.common.context import ApplicationContext
+ from nucliadb.common.maindb.driver import Driver
+ from nucliadb.common.maindb.pg import PGDriver
  from nucliadb.common.nidx import get_nidx_api_client, get_nidx_searcher_client
+ from nucliadb_protos import writer_pb2
  from nucliadb_telemetry import errors
  from nucliadb_telemetry.logs import setup_logging
  from nucliadb_telemetry.utils import setup_telemetry
  from nucliadb_utils.fastapi.run import serve_metrics

  from .settings import settings
- from .utils import delete_resource_from_shard, index_resource_to_shard
+ from .utils import delete_resource_from_shard, index_resource_to_shard, wait_for_nidx

  logger = logging.getLogger(__name__)

  REBALANCE_LOCK = "rebalance"

+ MAX_MOVES_PER_SHARD = 100
+
+
+ @dataclasses.dataclass
+ class RebalanceShard:
+     id: str
+     nidx_id: str
+     paragraphs: int
+     active: bool
+
+     def to_dict(self):
+         return self.__dict__
+
+
+ class Rebalancer:
+     def __init__(self, context: ApplicationContext, kbid: str):
+         self.context = context
+         self.kbid = kbid
+         self.kb_shards: writer_pb2.Shards | None = None
+
+     async def get_rebalance_shards(self, estimate: bool = False) -> list[RebalanceShard]:
+         """
+         Return the sorted list of shards by increasing paragraph count.
+
+         If estimate is True, it will fetch the paragraph count from nidx shard metadata, which is lighter
+         but deletions are not guaranteed to be reflected. Otherwise, it will get the paragraph counts
+         by querying nidx paragraph index for each shard.
+         """
+         result = []
+         self.kb_shards = await datamanagers.atomic.cluster.get_kb_shards(kbid=self.kbid)
+         if self.kb_shards is not None:
+             for idx, shard in enumerate(self.kb_shards.shards):
+                 if estimate:
+                     shard_metadata = await get_shard_metadata(shard.nidx_shard_id)
+                     paragraphs = shard_metadata.paragraphs
+                 else:
+                     paragraphs = await get_shard_paragraph_count(shard.nidx_shard_id)
+                 result.append(
+                     RebalanceShard(
+                         id=shard.shard,
+                         nidx_id=shard.nidx_shard_id,
+                         paragraphs=paragraphs,
+                         active=(idx == self.kb_shards.actual),
+                     )
+                 )
+         return list(sorted(result, key=lambda x: x.paragraphs))
+
+     async def move_paragraphs(
+         self, from_shard: RebalanceShard, to_shard: RebalanceShard, max_paragraphs: int
+     ) -> int:
+         """
+         Takes random resources from the source shard and tries to move at most max_paragraphs.
+         It stops moving paragraphs until the are no more resources to move.
+         """
+         moved_paragraphs = 0
+
+         resources_batch: list[str] = []
+
+         while moved_paragraphs < max_paragraphs:
+             if len(resources_batch) == 0:
+                 resources_batch = await get_resources_from_shard(
+                     self.context.kv_driver, self.kbid, from_shard.id, n=100
+                 )
+                 if len(resources_batch) == 0:
+                     # No more resources to move or shard not found
+                     break
+
+             # Take a random resource to move
+             resource_id = random.choice(resources_batch)
+
+             assert self.kb_shards is not None
+             from_shard_obj = next(s for s in self.kb_shards.shards if s.shard == from_shard.id)
+             to_shard_obj = next(s for s in self.kb_shards.shards if s.shard == to_shard.id)
+             paragraphs_count = await get_resource_paragraphs_count(resource_id, from_shard.nidx_id)
+             moved = await move_resource_to_shard(
+                 self.context, self.kbid, resource_id, from_shard_obj, to_shard_obj
+             )
+             if moved:
+                 resources_batch.remove(resource_id)
+                 moved_paragraphs += paragraphs_count

- async def get_shards_paragraphs(kbid: str) -> list[tuple[str, int]]:
-     """
-     Ordered shard -> num paragraph by number of paragraphs
-     """
-     async with datamanagers.with_ro_transaction() as txn:
-         kb_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid)
-         if kb_shards is None:
-             return []
-
-         results = {}
-         for shard_meta in kb_shards.shards:
-             # Rebalance using node as source of truth. But it will rebalance nidx
-             shard_data: nodereader_pb2.Shard = await get_nidx_api_client().GetShard(
-                 nodereader_pb2.GetShardRequest(
-                     shard_id=noderesources_pb2.ShardId(id=shard_meta.nidx_shard_id)
-                 )  # type: ignore
+         return moved_paragraphs
+
+     async def wait_for_indexing(self):
+         try:
+             self.context.nats_manager
+         except AssertionError:  # pragma: no cover
+             logger.warning(f"Nats manager not initialized. Cannot wait for indexing")
+             return
+         while True:
+             try:
+                 await wait_for_nidx(self.context.nats_manager, max_wait_seconds=60, max_pending=1000)
+                 return
+             except asyncio.TimeoutError:
+                 logger.warning("Nidx is behind. Backing off rebalancing.", extra={"kbid": self.kbid})
+                 await asyncio.sleep(30)
+
+     async def required(self) -> bool:
+         """
+         Return true if any shard needs rebalancing.
+         """
+         shards = await self.get_rebalance_shards(estimate=True)
+         return any(needs_split(shard) or needs_merge(shard, shards) for shard in shards)
+
+     async def rebalance_shards(self):
+         """
+         Iterate over shards until none of them need more rebalancing.
+
+         Will move excess of paragraphs to other shards (potentially creating new ones), and
+         merge small shards together when possible (potentially deleting empty ones.)
+
+
+         Merge chooses a <90% filled shard and fills it to almost 100%
+         Split chooses a >110% filled shard and reduces it to 100%
+         If the shard is between 90% and 110% full, nobody touches it
+         """
+         while True:
+             await self.wait_for_indexing()
+             shards = await self.get_rebalance_shards()
+
+             # Any shards to split?
+             shard_to_split = next((s for s in shards[::-1] if needs_split(s)), None)
+             if shard_to_split is not None:
+                 await self.split_shard(shard_to_split, shards)
+                 continue
+
+             # Any shards to merge?
+             shard_to_merge = next((s for s in shards if needs_merge(s, shards)), None)
+             if shard_to_merge is not None:
+                 await self.merge_shard(shard_to_merge, shards)
+             else:
+                 break
+
+     async def split_shard(self, shard_to_split: RebalanceShard, shards: list[RebalanceShard]):
+         logger.info(
+             "Splitting excess of paragraphs to other shards",
+             extra={
+                 "kbid": self.kbid,
+                 "shard": shard_to_split.to_dict(),
+             },
          )
-             results[shard_meta.shard] = shard_data.paragraphs

-     return [(shard, paragraphs) for shard, paragraphs in sorted(results.items(), key=lambda x: x[1])]
+         # First off, calculate if the excess fits in the other shards or we need to add a new shard.
+         # Note that we don't filter out the active shard on purpose.
+         excess = shard_to_split.paragraphs - settings.max_shard_paragraphs
+         other_shards = [s for s in shards if s.id != shard_to_split.id]
+         other_shards_capacity = sum(
+             [max(0, (settings.max_shard_paragraphs - s.paragraphs)) for s in other_shards]
+         )
+         if excess > other_shards_capacity:
+             shards_to_add = math.ceil((excess - other_shards_capacity) / settings.max_shard_paragraphs)
+             logger.info(
+                 "More shards needed",
+                 extra={
+                     "kbid": self.kbid,
+                     "shards_to_add": shards_to_add,
+                     "all_shards": [s.to_dict() for s in shards],
+                 },
+             )
+             # Add new shards where to rebalance the excess of paragraphs
+             async with (
+                 locking.distributed_lock(locking.NEW_SHARD_LOCK.format(kbid=self.kbid)),
+                 datamanagers.with_rw_transaction() as txn,
+             ):
+                 kb_config = await datamanagers.kb.get_config(txn, kbid=self.kbid)
+                 prewarm = kb_config is not None and kb_config.prewarm_enabled
+                 sm = get_shard_manager()
+                 for _ in range(shards_to_add):
+                     await sm.create_shard_by_kbid(txn, self.kbid, prewarm_enabled=prewarm)
+                 await txn.commit()

+             # Recalculate after having created shards, the active shard is a different one
+             shards = await self.get_rebalance_shards()
+
+         # Now, move resources to other shards as long as we are still over the max
+         for _ in range(MAX_MOVES_PER_SHARD):
+             shard_paragraphs = next(s.paragraphs for s in shards if s.id == shard_to_split.id)
+             excess = shard_paragraphs - settings.max_shard_paragraphs
+             if excess <= 0:
+                 logger.info(
+                     "Shard rebalanced successfuly",
+                     extra={"kbid": self.kbid, "shard": shard_to_split.to_dict()},
+                 )
+                 break

- async def maybe_add_shard(kbid: str) -> None:
-     async with locking.distributed_lock(locking.NEW_SHARD_LOCK.format(kbid=kbid)):
-         async with datamanagers.with_ro_transaction() as txn:
-             kb_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid)
-             if kb_shards is None:
-                 return
+             target_shard, target_capacity = get_target_shard(shards, shard_to_split, skip_active=False)
+             if target_shard is None:
+                 logger.warning("No target shard found for splitting", extra={"kbid": self.kbid})
+                 break

-         shard_paragraphs = await get_shards_paragraphs(kbid)
-         total_paragraphs = sum([c for _, c in shard_paragraphs])
+             moved_paragraphs = await self.move_paragraphs(
+                 from_shard=shard_to_split,
+                 to_shard=target_shard,
+                 max_paragraphs=min(excess, target_capacity),
+             )

-         if (total_paragraphs / len(kb_shards.shards)) > (
-             settings.max_shard_paragraphs * 0.9  # 90% of the max
-         ):
-             # create new shard
-             async with datamanagers.with_transaction() as txn:
-                 sm = get_shard_manager()
-                 await sm.create_shard_by_kbid(txn, kbid)
-                 await txn.commit()
+             # Update shard paragraph counts
+             shard_to_split.paragraphs -= moved_paragraphs
+             target_shard.paragraphs += moved_paragraphs
+             shards.sort(key=lambda x: x.paragraphs)

+             await self.wait_for_indexing()

- async def move_set_of_kb_resources(
-     context: ApplicationContext,
-     kbid: str,
-     from_shard_id: str,
-     to_shard_id: str,
-     count: int = 20,
- ) -> None:
-     async with datamanagers.with_ro_transaction() as txn:
-         kb_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid)
-         if kb_shards is None:  # pragma: no cover
-             logger.warning("No shards found for kb. This should not happen.", extra={"kbid": kbid})
-             return
+     async def merge_shard(self, shard_to_merge: RebalanceShard, shards: list[RebalanceShard]):
+         logger.info(
+             "Merging shard",
+             extra={
+                 "kbid": self.kbid,
+                 "shard": shard_to_merge.to_dict(),
+             },
+         )
+         empty_shard = False

-     logger.info(
-         "Rebalancing kb shards",
-         extra={"kbid": kbid, "from": from_shard_id, "to": to_shard_id, "count": count},
-     )
+         for _ in range(MAX_MOVES_PER_SHARD):
+             resources_count = await count_resources_in_shard(
+                 self.context.kv_driver, self.kbid, shard_to_merge.id
+             )
+             if resources_count == 0:
+                 logger.info(
+                     "Shard is now empty",
+                     extra={
+                         "kbid": self.kbid,
+                         "shard": shard_to_merge.to_dict(),
+                     },
+                 )
+                 empty_shard = True
+                 break
+
+             logger.info(
+                 "Shard not yet empty",
+                 extra={
+                     "kbid": self.kbid,
+                     "shard": shard_to_merge.to_dict(),
+                     "remaining": resources_count,
+                 },
+             )
+
+             target_shard, target_capacity = get_target_shard(shards, shard_to_merge, skip_active=True)
+             if target_shard is None:
+                 logger.warning(
+                     "No target shard could be found for merging. Moving on",
+                     extra={"kbid": self.kbid, "shard": shard_to_merge.to_dict()},
+                 )
+                 break

-     from_shard = [s for s in kb_shards.shards if s.shard == from_shard_id][0]
-     to_shard = [s for s in kb_shards.shards if s.shard == to_shard_id][0]
+             moved_paragraphs = await self.move_paragraphs(
+                 from_shard=shard_to_merge,
+                 to_shard=target_shard,
+                 max_paragraphs=target_capacity,
+             )
+
+             # Update shard paragraph counts
+             shard_to_merge.paragraphs -= moved_paragraphs
+             target_shard.paragraphs += moved_paragraphs
+             shards.sort(key=lambda x: x.paragraphs)
+
+             await self.wait_for_indexing()
+
+         if empty_shard:
+             # If shard was emptied, delete it
+             async with locking.distributed_lock(locking.NEW_SHARD_LOCK.format(kbid=self.kbid)):
+                 async with datamanagers.with_rw_transaction() as txn:
+                     kb_shards = await datamanagers.cluster.get_kb_shards(
+                         txn, kbid=self.kbid, for_update=True
+                     )
+                     if kb_shards is not None:
+                         logger.info(
+                             "Deleting empty shard",
+                             extra={
+                                 "kbid": self.kbid,
+                                 "shard_id": shard_to_merge.id,
+                                 "nidx_shard_id": shard_to_merge.nidx_id,
+                             },
+                         )
+
+                         # Delete shards from kb shards in maindb
+                         to_delete, to_delete_idx = next(
+                             (s, idx)
+                             for idx, s in enumerate(kb_shards.shards)
+                             if s.shard == shard_to_merge.id
+                         )
+                         kb_shards.shards.remove(to_delete)
+                         if to_delete_idx <= kb_shards.actual:
+                             # Only decrement the actual pointer if we remove before the pointer.
+                             kb_shards.actual -= 1
+                             assert kb_shards.actual >= 0
+                         await datamanagers.cluster.update_kb_shards(
+                             txn, kbid=self.kbid, shards=kb_shards
+                         )
+                         await txn.commit()
+
+                 # Delete shard from nidx
+                 if to_delete:
+                     await get_nidx_api_client().DeleteShard(
+                         noderesources_pb2.ShardId(id=to_delete.nidx_shard_id)
+                     )
+
+
+ async def get_resources_from_shard(driver: Driver, kbid: str, shard_id: str, n: int) -> list[str]:
+     driver = cast(PGDriver, driver)
+     async with driver._get_connection() as conn:
+         cur = conn.cursor("")
+         await cur.execute(
+             """
+             SELECT split_part(key, '/', 5) FROM resources WHERE key ~ '/kbs/[^/]*/r/[^/]*/shard$' AND key ~ %s AND value = %s LIMIT %s;
+             """,
+             (f"/kbs/{kbid}/r/[^/]*/shard$", shard_id, n),
+         )
+         records = await cur.fetchall()
+         rids: list[str] = [r[0] for r in records]
+         return rids
+
+
+ async def get_resource_paragraphs_count(resource_id: str, nidx_shard_id: str) -> int:
+     # Do a search on the fields (paragraph) index and return the number of paragraphs this resource has
+     try:
+         request = nodereader_pb2.SearchRequest(
+             shard=nidx_shard_id,
+             paragraph=True,
+             document=False,
+             result_per_page=0,
+             field_filter=nodereader_pb2.FilterExpression(
+                 resource=nodereader_pb2.FilterExpression.ResourceFilter(resource_id=resource_id)
+             ),
+         )
+         search_response: nodereader_pb2.SearchResponse = await get_nidx_searcher_client().Search(request)
+         return search_response.paragraph.total
+     except AioRpcError as exc:  # pragma: no cover
+         if exc.code() == StatusCode.NOT_FOUND:
+             logger.warning(f"Shard not found in nidx", extra={"nidx_shard_id": nidx_shard_id})
+             return 0
+         raise

-     request = nodereader_pb2.SearchRequest(
-         shard=from_shard.nidx_shard_id,
-         paragraph=False,
-         document=True,
-         result_per_page=count,
+
+ def get_target_shard(
+     shards: list[RebalanceShard], rebalanced_shard: RebalanceShard, skip_active: bool = True
+ ) -> tuple[RebalanceShard | None, int]:
+     """
+     Return the biggest shard with capacity (< 90% of the max paragraphs per shard).
+     """
+     target_shard = next(
+         reversed(
+             [
+                 s
+                 for s in shards
+                 if s.id != rebalanced_shard.id
+                 and s.paragraphs < settings.max_shard_paragraphs * 0.9
+                 and (not skip_active or (skip_active and not s.active))
+             ]
+         ),
+         None,
      )
-     request.field_filter.field.field_type = "a"
-     request.field_filter.field.field_id = "title"
-     search_response: nodereader_pb2.SearchResponse = await get_nidx_searcher_client().Search(request)
+     if target_shard is None:  # pragma: no cover
+         return None, 0
+
+     # Aim to fill target shards up to 100% of max
+     capacity = int(max(0, settings.max_shard_paragraphs - target_shard.paragraphs))
+     return target_shard, capacity
+
+
+ async def count_resources_in_shard(driver: Driver, kbid: str, shard_id: str) -> int:
+     driver = cast(PGDriver, driver)
+     async with driver._get_connection() as conn:
+         cur = conn.cursor("")
+         await cur.execute(
+             """
+             SELECT COUNT(*) FROM resources WHERE key ~ '/kbs/[^/]*/r/[^/]*/shard$' AND key ~ %s AND value = %s;
+             """,
+             (f"/kbs/{kbid}/r/[^/]*/shard$", shard_id),
+         )
+         record = await cur.fetchone()
+         if record is None:  # pragma: no cover
+             return 0
+         return record[0]
+
+
+ async def get_shard_paragraph_count(nidx_shard_id: str) -> int:
+     # Do a search on the fields (paragraph) index
+     try:
+         request = nodereader_pb2.SearchRequest(
+             shard=nidx_shard_id,
+             paragraph=True,
+             document=False,
+             result_per_page=0,
+         )
+         search_response: nodereader_pb2.SearchResponse = await get_nidx_searcher_client().Search(request)
+         return search_response.paragraph.total
+     except AioRpcError as exc:  # pragma: no cover
+         if exc.code() == StatusCode.NOT_FOUND:
+             logger.warning(f"Shard not found in nidx", extra={"nidx_shard_id": nidx_shard_id})
+             return 0
+         raise
+
+
+ async def get_shard_metadata(nidx_shard_id: str) -> nodereader_pb2.Shard:
+     try:
+         shard_metadata: nodereader_pb2.Shard = await get_nidx_api_client().GetShard(
+             nodereader_pb2.GetShardRequest(shard_id=noderesources_pb2.ShardId(id=nidx_shard_id))
+         )
+         return shard_metadata
+     except AioRpcError as exc:  # pragma: no cover
+         if exc.code() == StatusCode.NOT_FOUND:
+             logger.warning(f"Shard not found in nidx", extra={"nidx_shard_id": nidx_shard_id})
+             return nodereader_pb2.Shard()
+         raise
+

-     for result in search_response.document.results:
-         resource_id = result.uuid
+ async def move_resource_to_shard(
+     context: ApplicationContext,
+     kbid: str,
+     resource_id: str,
+     from_shard: writer_pb2.ShardObject,
+     to_shard: writer_pb2.ShardObject,
+ ) -> bool:
+     indexed_to_new = False
+     deleted_from_old = False
+     try:
+         async with (
+             datamanagers.with_transaction() as txn,
+             locking.distributed_lock(
+                 locking.RESOURCE_INDEX_LOCK.format(kbid=kbid, resource_id=resource_id)
+             ),
+         ):
+             found_shard_id = await datamanagers.resources.get_resource_shard_id(
+                 txn, kbid=kbid, rid=resource_id, for_update=True
+             )
+             if found_shard_id is None:  # pragma: no cover
+                 # resource deleted
+                 return False
+             if found_shard_id != from_shard.shard:  # pragma: no cover
+                 # resource could have already been moved
+                 return False
+
+             await datamanagers.resources.set_resource_shard_id(
+                 txn, kbid=kbid, rid=resource_id, shard=to_shard.shard
+             )
+             await index_resource_to_shard(context, kbid, resource_id, to_shard)
+             indexed_to_new = True
+             await delete_resource_from_shard(context, kbid, resource_id, from_shard)
+             deleted_from_old = True
+             await txn.commit()
+         return True
+     except Exception:
+         logger.exception(
+             "Failed to move resource",
+             extra={"kbid": kbid, "resource_id": resource_id},
+         )
+         # XXX Not ideal failure situation here. Try reverting the whole move even though it could be redundant
          try:
-             async with (
-                 datamanagers.with_transaction() as txn,
-                 locking.distributed_lock(
-                     locking.RESOURCE_INDEX_LOCK.format(kbid=kbid, resource_id=resource_id)
-                 ),
-             ):
-                 found_shard_id = await datamanagers.resources.get_resource_shard_id(
-                     txn, kbid=kbid, rid=resource_id, for_update=True
-                 )
-                 if found_shard_id is None:
-                     # resource deleted
-                     continue
-                 if found_shard_id != from_shard_id:
-                     # resource could have already been moved
-                     continue
-
-                 await datamanagers.resources.set_resource_shard_id(
-                     txn, kbid=kbid, rid=resource_id, shard=to_shard_id
-                 )
-                 await index_resource_to_shard(context, kbid, resource_id, to_shard)
-                 await delete_resource_from_shard(context, kbid, resource_id, from_shard)
-                 await txn.commit()
+             if indexed_to_new:
+                 await delete_resource_from_shard(context, kbid, resource_id, to_shard)
+             if deleted_from_old:
+                 await index_resource_to_shard(context, kbid, resource_id, from_shard)
          except Exception:
              logger.exception(
-                 "Failed to move resource",
+                 "Failed to revert move resource. Hopefully you never see this message.",
                  extra={"kbid": kbid, "resource_id": resource_id},
              )
-             # XXX Not ideal failure situation here. Try reverting the whole move even though it could be redundant
-             try:
-                 await index_resource_to_shard(context, kbid, resource_id, from_shard)
-                 await delete_resource_from_shard(context, kbid, resource_id, to_shard)
-             except Exception:
-                 logger.exception(
-                     "Failed to revert move resource. Hopefully you never see this message.",
-                     extra={"kbid": kbid, "resource_id": resource_id},
-                 )
+         return False


- async def rebalance_kb(context: ApplicationContext, kbid: str) -> None:
-     await maybe_add_shard(kbid)
512
+ """
513
+ Return true if the shard is more than 110% of the max.
514
+
515
+ Active shards are not considered for splitting: the shard creator subscriber will
516
+ eventually create a new shard, make it the active one and the previous one, if
517
+ too full, will be split.
518
+ """
519
+ return not shard.active and (shard.paragraphs > (settings.max_shard_paragraphs * 1.1))
158
520
 
159
- shard_paragraphs = await get_shards_paragraphs(kbid)
160
- rebalanced_shards = set()
161
- while any(paragraphs > settings.max_shard_paragraphs for _, paragraphs in shard_paragraphs):
162
- # find the shard with the least/most paragraphs
163
- smallest_shard = shard_paragraphs[0][0]
164
- largest_shard = shard_paragraphs[-1][0]
165
- assert smallest_shard != largest_shard
166
521
 
167
- if smallest_shard in rebalanced_shards:
168
- # XXX This is to prevent flapping data between shards on a single pass
169
- # if we already rebalanced this shard, then we can't do anything else
170
- break
522
+ def needs_merge(shard: RebalanceShard, all_shards: list[RebalanceShard]) -> bool:
523
+ """
524
+ Returns true if a shard is less 75% full and there is enough capacity on the other shards to fit it.
171
525
 
172
- await move_set_of_kb_resources(context, kbid, largest_shard, smallest_shard)
526
+ Active shards are not considered for merging. Shards that are more than 75% full are also skipped.
527
+ """
528
+ if shard.active:
529
+ return False
530
+ if shard.paragraphs > (settings.max_shard_paragraphs * 0.75):
531
+ return False
532
+ other_shards = [s for s in all_shards if s.id != shard.id and not s.active]
533
+ other_shards_capacity = sum(
534
+ [max(0, ((settings.max_shard_paragraphs * 0.9) - s.paragraphs)) for s in other_shards]
535
+ )
536
+ return shard.paragraphs < other_shards_capacity
173
537
 
174
- rebalanced_shards.add(largest_shard)
175
538
 
176
- shard_paragraphs = await get_shards_paragraphs(kbid)
539
+ async def rebalance_kb(context: ApplicationContext, kbid: str) -> None:
540
+ rebalancer = Rebalancer(context, kbid)
541
+ try:
542
+ logger.info("Starting rebalance for kb", extra={"kbid": kbid})
543
+ if await rebalancer.required():
544
+ await rebalancer.rebalance_shards()
545
+ logger.info("Finished rebalance for kb", extra={"kbid": kbid})
546
+ except Exception as err:
547
+ logger.exception("Rebalance finished with error", extra={"kbid": kbid})
548
+ errors.capture_exception(err)
177
549
 
178
550
 
179
551
  async def run(context: ApplicationContext) -> None:
@@ -182,7 +554,7 @@ async def run(context: ApplicationContext) -> None:
      # get all kb ids
      async with datamanagers.with_ro_transaction() as txn:
          kbids = [kbid async for kbid, _ in datamanagers.kb.get_kbs(txn)]
-     # go through each kb and see if shards need to be reduced in size
+     # go through each kb and see if shards need to be rebalanced
      for kbid in kbids:
          async with locking.distributed_lock(locking.KB_SHARDS_LOCK.format(kbid=kbid)):
              await rebalance_kb(context, kbid)
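
For readers skimming the diff, the following is a minimal, standalone sketch of the split/merge thresholds that the new rebalancer's docstrings describe (needs_split over 110% of the limit, needs_merge under 75% with spare capacity elsewhere). The Shard class, MAX_SHARD_PARAGRAPHS value, and sample shard counts below are made up for illustration; in the package the limit comes from settings.max_shard_paragraphs and the real checks live in the code above.

    from dataclasses import dataclass

    MAX_SHARD_PARAGRAPHS = 1_000_000  # stand-in for settings.max_shard_paragraphs


    @dataclass
    class Shard:
        id: str
        paragraphs: int
        active: bool


    def needs_split(shard: Shard) -> bool:
        # Shards over 110% of the limit get split; the active shard is left alone.
        return not shard.active and shard.paragraphs > MAX_SHARD_PARAGRAPHS * 1.1


    def needs_merge(shard: Shard, all_shards: list[Shard]) -> bool:
        # Shards under 75% of the limit are merged away, but only when the other
        # non-active shards have spare capacity (counted up to 90% of the limit).
        if shard.active or shard.paragraphs > MAX_SHARD_PARAGRAPHS * 0.75:
            return False
        others = [s for s in all_shards if s.id != shard.id and not s.active]
        capacity = sum(max(0, MAX_SHARD_PARAGRAPHS * 0.9 - s.paragraphs) for s in others)
        return shard.paragraphs < capacity


    shards = [
        Shard("a", 1_200_000, active=False),  # over 110% of the limit -> split
        Shard("b", 100_000, active=False),    # under 75% and fits elsewhere -> merge
        Shard("d", 300_000, active=False),    # also under 75% -> merge candidate
        Shard("c", 600_000, active=True),     # active shard -> never touched
    ]
    for s in shards:
        action = "split" if needs_split(s) else "merge" if needs_merge(s, shards) else "leave"
        print(s.id, action)

Shards that land between the two thresholds are left untouched, which matches the "nobody touches it" note in the rebalance_shards docstring.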