nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246):
  1. migrations/0023_backfill_pg_catalog.py +8 -4
  2. migrations/0028_extracted_vectors_reference.py +1 -1
  3. migrations/0029_backfill_field_status.py +3 -4
  4. migrations/0032_remove_old_relations.py +2 -3
  5. migrations/0038_backfill_catalog_field_labels.py +8 -4
  6. migrations/0039_backfill_converation_splits_metadata.py +106 -0
  7. migrations/0040_migrate_search_configurations.py +79 -0
  8. migrations/0041_reindex_conversations.py +137 -0
  9. migrations/pg/0010_shards_index.py +34 -0
  10. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  11. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  12. nucliadb/backups/create.py +2 -15
  13. nucliadb/backups/restore.py +4 -15
  14. nucliadb/backups/tasks.py +4 -1
  15. nucliadb/common/back_pressure/cache.py +2 -3
  16. nucliadb/common/back_pressure/materializer.py +7 -13
  17. nucliadb/common/back_pressure/settings.py +6 -6
  18. nucliadb/common/back_pressure/utils.py +1 -0
  19. nucliadb/common/cache.py +9 -9
  20. nucliadb/common/catalog/__init__.py +79 -0
  21. nucliadb/common/catalog/dummy.py +36 -0
  22. nucliadb/common/catalog/interface.py +85 -0
  23. nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
  24. nucliadb/common/catalog/utils.py +56 -0
  25. nucliadb/common/cluster/manager.py +8 -23
  26. nucliadb/common/cluster/rebalance.py +484 -112
  27. nucliadb/common/cluster/rollover.py +36 -9
  28. nucliadb/common/cluster/settings.py +4 -9
  29. nucliadb/common/cluster/utils.py +34 -8
  30. nucliadb/common/context/__init__.py +7 -8
  31. nucliadb/common/context/fastapi.py +1 -2
  32. nucliadb/common/datamanagers/__init__.py +2 -4
  33. nucliadb/common/datamanagers/atomic.py +9 -2
  34. nucliadb/common/datamanagers/cluster.py +1 -2
  35. nucliadb/common/datamanagers/fields.py +3 -4
  36. nucliadb/common/datamanagers/kb.py +6 -6
  37. nucliadb/common/datamanagers/labels.py +2 -3
  38. nucliadb/common/datamanagers/resources.py +10 -33
  39. nucliadb/common/datamanagers/rollover.py +5 -7
  40. nucliadb/common/datamanagers/search_configurations.py +1 -2
  41. nucliadb/common/datamanagers/synonyms.py +1 -2
  42. nucliadb/common/datamanagers/utils.py +4 -4
  43. nucliadb/common/datamanagers/vectorsets.py +4 -4
  44. nucliadb/common/external_index_providers/base.py +32 -5
  45. nucliadb/common/external_index_providers/manager.py +5 -34
  46. nucliadb/common/external_index_providers/settings.py +1 -27
  47. nucliadb/common/filter_expression.py +129 -41
  48. nucliadb/common/http_clients/exceptions.py +8 -0
  49. nucliadb/common/http_clients/processing.py +16 -23
  50. nucliadb/common/http_clients/utils.py +3 -0
  51. nucliadb/common/ids.py +82 -58
  52. nucliadb/common/locking.py +1 -2
  53. nucliadb/common/maindb/driver.py +9 -8
  54. nucliadb/common/maindb/local.py +5 -5
  55. nucliadb/common/maindb/pg.py +9 -8
  56. nucliadb/common/nidx.py +22 -5
  57. nucliadb/common/vector_index_config.py +1 -1
  58. nucliadb/export_import/datamanager.py +4 -3
  59. nucliadb/export_import/exporter.py +11 -19
  60. nucliadb/export_import/importer.py +13 -6
  61. nucliadb/export_import/tasks.py +2 -0
  62. nucliadb/export_import/utils.py +6 -18
  63. nucliadb/health.py +2 -2
  64. nucliadb/ingest/app.py +8 -8
  65. nucliadb/ingest/consumer/consumer.py +8 -10
  66. nucliadb/ingest/consumer/pull.py +10 -8
  67. nucliadb/ingest/consumer/service.py +5 -30
  68. nucliadb/ingest/consumer/shard_creator.py +16 -5
  69. nucliadb/ingest/consumer/utils.py +1 -1
  70. nucliadb/ingest/fields/base.py +37 -49
  71. nucliadb/ingest/fields/conversation.py +55 -9
  72. nucliadb/ingest/fields/exceptions.py +1 -2
  73. nucliadb/ingest/fields/file.py +22 -8
  74. nucliadb/ingest/fields/link.py +7 -7
  75. nucliadb/ingest/fields/text.py +2 -3
  76. nucliadb/ingest/orm/brain_v2.py +89 -57
  77. nucliadb/ingest/orm/broker_message.py +2 -4
  78. nucliadb/ingest/orm/entities.py +10 -209
  79. nucliadb/ingest/orm/index_message.py +128 -113
  80. nucliadb/ingest/orm/knowledgebox.py +91 -59
  81. nucliadb/ingest/orm/processor/auditing.py +1 -3
  82. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  83. nucliadb/ingest/orm/processor/processor.py +98 -153
  84. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  85. nucliadb/ingest/orm/resource.py +82 -71
  86. nucliadb/ingest/orm/utils.py +1 -1
  87. nucliadb/ingest/partitions.py +12 -1
  88. nucliadb/ingest/processing.py +17 -17
  89. nucliadb/ingest/serialize.py +202 -145
  90. nucliadb/ingest/service/writer.py +15 -114
  91. nucliadb/ingest/settings.py +36 -15
  92. nucliadb/ingest/utils.py +1 -2
  93. nucliadb/learning_proxy.py +23 -26
  94. nucliadb/metrics_exporter.py +20 -6
  95. nucliadb/middleware/__init__.py +82 -1
  96. nucliadb/migrator/datamanager.py +4 -11
  97. nucliadb/migrator/migrator.py +1 -2
  98. nucliadb/migrator/models.py +1 -2
  99. nucliadb/migrator/settings.py +1 -2
  100. nucliadb/models/internal/augment.py +614 -0
  101. nucliadb/models/internal/processing.py +19 -19
  102. nucliadb/openapi.py +2 -2
  103. nucliadb/purge/__init__.py +3 -8
  104. nucliadb/purge/orphan_shards.py +1 -2
  105. nucliadb/reader/__init__.py +5 -0
  106. nucliadb/reader/api/models.py +6 -13
  107. nucliadb/reader/api/v1/download.py +59 -38
  108. nucliadb/reader/api/v1/export_import.py +4 -4
  109. nucliadb/reader/api/v1/knowledgebox.py +37 -9
  110. nucliadb/reader/api/v1/learning_config.py +33 -14
  111. nucliadb/reader/api/v1/resource.py +61 -9
  112. nucliadb/reader/api/v1/services.py +18 -14
  113. nucliadb/reader/app.py +3 -1
  114. nucliadb/reader/reader/notifications.py +1 -2
  115. nucliadb/search/api/v1/__init__.py +3 -0
  116. nucliadb/search/api/v1/ask.py +3 -4
  117. nucliadb/search/api/v1/augment.py +585 -0
  118. nucliadb/search/api/v1/catalog.py +15 -19
  119. nucliadb/search/api/v1/find.py +16 -22
  120. nucliadb/search/api/v1/hydrate.py +328 -0
  121. nucliadb/search/api/v1/knowledgebox.py +1 -2
  122. nucliadb/search/api/v1/predict_proxy.py +1 -2
  123. nucliadb/search/api/v1/resource/ask.py +28 -8
  124. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  125. nucliadb/search/api/v1/resource/search.py +9 -11
  126. nucliadb/search/api/v1/retrieve.py +130 -0
  127. nucliadb/search/api/v1/search.py +28 -32
  128. nucliadb/search/api/v1/suggest.py +11 -14
  129. nucliadb/search/api/v1/summarize.py +1 -2
  130. nucliadb/search/api/v1/utils.py +2 -2
  131. nucliadb/search/app.py +3 -2
  132. nucliadb/search/augmentor/__init__.py +21 -0
  133. nucliadb/search/augmentor/augmentor.py +232 -0
  134. nucliadb/search/augmentor/fields.py +704 -0
  135. nucliadb/search/augmentor/metrics.py +24 -0
  136. nucliadb/search/augmentor/paragraphs.py +334 -0
  137. nucliadb/search/augmentor/resources.py +238 -0
  138. nucliadb/search/augmentor/utils.py +33 -0
  139. nucliadb/search/lifecycle.py +3 -1
  140. nucliadb/search/predict.py +33 -19
  141. nucliadb/search/predict_models.py +8 -9
  142. nucliadb/search/requesters/utils.py +11 -10
  143. nucliadb/search/search/cache.py +19 -42
  144. nucliadb/search/search/chat/ask.py +131 -59
  145. nucliadb/search/search/chat/exceptions.py +3 -5
  146. nucliadb/search/search/chat/fetcher.py +201 -0
  147. nucliadb/search/search/chat/images.py +6 -4
  148. nucliadb/search/search/chat/old_prompt.py +1375 -0
  149. nucliadb/search/search/chat/parser.py +510 -0
  150. nucliadb/search/search/chat/prompt.py +563 -615
  151. nucliadb/search/search/chat/query.py +453 -32
  152. nucliadb/search/search/chat/rpc.py +85 -0
  153. nucliadb/search/search/fetch.py +3 -4
  154. nucliadb/search/search/filters.py +8 -11
  155. nucliadb/search/search/find.py +33 -31
  156. nucliadb/search/search/find_merge.py +124 -331
  157. nucliadb/search/search/graph_strategy.py +14 -12
  158. nucliadb/search/search/hydrator/__init__.py +49 -0
  159. nucliadb/search/search/hydrator/fields.py +217 -0
  160. nucliadb/search/search/hydrator/images.py +130 -0
  161. nucliadb/search/search/hydrator/paragraphs.py +323 -0
  162. nucliadb/search/search/hydrator/resources.py +60 -0
  163. nucliadb/search/search/ingestion_agents.py +5 -5
  164. nucliadb/search/search/merge.py +90 -94
  165. nucliadb/search/search/metrics.py +24 -7
  166. nucliadb/search/search/paragraphs.py +7 -9
  167. nucliadb/search/search/predict_proxy.py +44 -18
  168. nucliadb/search/search/query.py +14 -86
  169. nucliadb/search/search/query_parser/fetcher.py +51 -82
  170. nucliadb/search/search/query_parser/models.py +19 -48
  171. nucliadb/search/search/query_parser/old_filters.py +20 -19
  172. nucliadb/search/search/query_parser/parsers/ask.py +5 -6
  173. nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
  174. nucliadb/search/search/query_parser/parsers/common.py +21 -13
  175. nucliadb/search/search/query_parser/parsers/find.py +6 -29
  176. nucliadb/search/search/query_parser/parsers/graph.py +18 -28
  177. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  178. nucliadb/search/search/query_parser/parsers/search.py +15 -56
  179. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  180. nucliadb/search/search/rank_fusion.py +18 -13
  181. nucliadb/search/search/rerankers.py +6 -7
  182. nucliadb/search/search/retrieval.py +300 -0
  183. nucliadb/search/search/summarize.py +5 -6
  184. nucliadb/search/search/utils.py +3 -4
  185. nucliadb/search/settings.py +1 -2
  186. nucliadb/standalone/api_router.py +1 -1
  187. nucliadb/standalone/app.py +4 -3
  188. nucliadb/standalone/auth.py +5 -6
  189. nucliadb/standalone/lifecycle.py +2 -2
  190. nucliadb/standalone/run.py +5 -4
  191. nucliadb/standalone/settings.py +5 -6
  192. nucliadb/standalone/versions.py +3 -4
  193. nucliadb/tasks/consumer.py +13 -8
  194. nucliadb/tasks/models.py +2 -1
  195. nucliadb/tasks/producer.py +3 -3
  196. nucliadb/tasks/retries.py +8 -7
  197. nucliadb/train/api/utils.py +1 -3
  198. nucliadb/train/api/v1/shards.py +1 -2
  199. nucliadb/train/api/v1/trainset.py +1 -2
  200. nucliadb/train/app.py +1 -1
  201. nucliadb/train/generator.py +4 -4
  202. nucliadb/train/generators/field_classifier.py +2 -2
  203. nucliadb/train/generators/field_streaming.py +6 -6
  204. nucliadb/train/generators/image_classifier.py +2 -2
  205. nucliadb/train/generators/paragraph_classifier.py +2 -2
  206. nucliadb/train/generators/paragraph_streaming.py +2 -2
  207. nucliadb/train/generators/question_answer_streaming.py +2 -2
  208. nucliadb/train/generators/sentence_classifier.py +4 -10
  209. nucliadb/train/generators/token_classifier.py +3 -2
  210. nucliadb/train/generators/utils.py +6 -5
  211. nucliadb/train/nodes.py +3 -3
  212. nucliadb/train/resource.py +6 -8
  213. nucliadb/train/settings.py +3 -4
  214. nucliadb/train/types.py +11 -11
  215. nucliadb/train/upload.py +3 -2
  216. nucliadb/train/uploader.py +1 -2
  217. nucliadb/train/utils.py +1 -2
  218. nucliadb/writer/api/v1/export_import.py +4 -1
  219. nucliadb/writer/api/v1/field.py +15 -14
  220. nucliadb/writer/api/v1/knowledgebox.py +18 -56
  221. nucliadb/writer/api/v1/learning_config.py +5 -4
  222. nucliadb/writer/api/v1/resource.py +9 -20
  223. nucliadb/writer/api/v1/services.py +10 -132
  224. nucliadb/writer/api/v1/upload.py +73 -72
  225. nucliadb/writer/app.py +8 -2
  226. nucliadb/writer/resource/basic.py +12 -15
  227. nucliadb/writer/resource/field.py +43 -5
  228. nucliadb/writer/resource/origin.py +7 -0
  229. nucliadb/writer/settings.py +2 -3
  230. nucliadb/writer/tus/__init__.py +2 -3
  231. nucliadb/writer/tus/azure.py +5 -7
  232. nucliadb/writer/tus/dm.py +3 -3
  233. nucliadb/writer/tus/exceptions.py +3 -4
  234. nucliadb/writer/tus/gcs.py +15 -22
  235. nucliadb/writer/tus/s3.py +2 -3
  236. nucliadb/writer/tus/storage.py +3 -3
  237. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
  238. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  239. nucliadb/common/datamanagers/entities.py +0 -139
  240. nucliadb/common/external_index_providers/pinecone.py +0 -894
  241. nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
  242. nucliadb/search/search/hydrator.py +0 -197
  243. nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
  244. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  245. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  246. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from typing import Optional
22
21
 
23
22
  from nucliadb.common import datamanagers
24
23
  from nucliadb.common.maindb.utils import get_driver
@@ -36,7 +35,7 @@ from nucliadb_models.search import (
36
35
  from nucliadb_protos.utils_pb2 import ExtractedText
37
36
  from nucliadb_utils.utilities import get_storage
38
37
 
39
- ExtractedTexts = list[tuple[str, str, Optional[ExtractedText]]]
38
+ ExtractedTexts = list[tuple[str, str, ExtractedText | None]]
40
39
 
41
40
  MAX_GET_EXTRACTED_TEXT_OPS = 20
42
41
 
@@ -46,7 +45,7 @@ class NoResourcesToSummarize(Exception):
46
45
 
47
46
 
48
47
  async def summarize(
49
- kbid: str, request: SummarizeRequest, extra_predict_headers: Optional[dict[str, str]]
48
+ kbid: str, request: SummarizeRequest, extra_predict_headers: dict[str, str] | None
50
49
  ) -> SummarizedResponse:
51
50
  predict_request = SummarizeModel()
52
51
  predict_request.generative_model = request.generative_model
@@ -87,7 +86,7 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->
87
86
  if uuid is None:
88
87
  logger.warning(f"Resource {uuid_or_slug} not found in KB", extra={"kbid": kbid})
89
88
  continue
90
- resource_orm = Resource(txn=txn, storage=storage, kb=kb_orm, uuid=uuid)
89
+ resource_orm = Resource(txn=txn, storage=storage, kbid=kbid, uuid=uuid)
91
90
  fields = await resource_orm.get_fields(force=True)
92
91
  for _, field in fields.items():
93
92
  task = asyncio.create_task(get_extracted_text(uuid_or_slug, field, max_tasks))
@@ -115,14 +114,14 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->
115
114
 
116
115
  async def get_extracted_text(
117
116
  uuid_or_slug, field: Field, max_operations: asyncio.Semaphore
118
- ) -> tuple[str, str, Optional[ExtractedText]]:
117
+ ) -> tuple[str, str, ExtractedText | None]:
119
118
  async with max_operations:
120
119
  extracted_text = await field.get_extracted_text(force=True)
121
120
  field_key = f"{field.type}/{field.id}"
122
121
  return uuid_or_slug, field_key, extracted_text
123
122
 
124
123
 
125
- async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> Optional[str]:
124
+ async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> str | None:
126
125
  """
127
126
  Return the uuid of the resource with the given uuid_or_slug.
128
127
  """
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import logging
21
- from typing import Optional
22
21
 
23
22
  from pydantic import BaseModel
24
23
 
@@ -30,7 +29,7 @@ from nucliadb_utils.utilities import has_feature
30
29
  logger = logging.getLogger(__name__)
31
30
 
32
31
 
33
- async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool]:
32
+ async def filter_hidden_resources(kbid: str, show_hidden: bool) -> bool | None:
34
33
  kb_config = await kb.get_config(kbid=kbid)
35
34
  hidden_enabled = kb_config and kb_config.hidden_resources_enabled
36
35
  if hidden_enabled and not show_hidden:
@@ -41,8 +40,8 @@ async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool
41
40
 
42
41
  def min_score_from_query_params(
43
42
  min_score_bm25: float,
44
- min_score_semantic: Optional[float],
45
- deprecated_min_score: Optional[float],
43
+ min_score_semantic: float | None,
44
+ deprecated_min_score: float | None,
46
45
  ) -> MinScore:
47
46
  # Keep backward compatibility with the deprecated min_score parameter
48
47
  semantic = deprecated_min_score if min_score_semantic is None else min_score_semantic
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
- from typing import Optional
22
21
 
23
22
  from pydantic import Field
24
23
 
@@ -43,7 +42,7 @@ class Settings(DriverSettings):
43
42
  title="Prequeries max parallel",
44
43
  description="The maximum number of prequeries to run in parallel per /ask request",
45
44
  )
46
- nidx_address: Optional[str] = Field(default=None)
45
+ nidx_address: str | None = Field(default=None)
47
46
 
48
47
 
49
48
  settings = Settings()
@@ -57,7 +57,7 @@ async def api_config_check(request: Request):
57
57
  valid_nua_key = True
58
58
  except Exception as exc:
59
59
  logger.warning(f"Error validating nua key", exc_info=exc)
60
- nua_key_check_error = f"Error checking NUA key: {str(exc)}"
60
+ nua_key_check_error = f"Error checking NUA key: {exc!s}"
61
61
  return JSONResponse(
62
62
  {
63
63
  "nua_api_key": {
@@ -31,7 +31,7 @@ from starlette.responses import HTMLResponse
31
31
  from starlette.routing import Mount
32
32
 
33
33
  import nucliadb_admin_assets # type: ignore
34
- from nucliadb.middleware import ProcessTimeHeaderMiddleware
34
+ from nucliadb.middleware import ClientErrorPayloadLoggerMiddleware, ProcessTimeHeaderMiddleware
35
35
  from nucliadb.reader import API_PREFIX
36
36
  from nucliadb.reader.api.v1.router import api as api_reader_v1
37
37
  from nucliadb.search.api.v1.router import api as api_search_v1
@@ -79,7 +79,7 @@ HOMEPAGE_HTML = """
79
79
  </ul>
80
80
  </body>
81
81
  </html>
82
- """ # noqa: E501
82
+ """
83
83
 
84
84
 
85
85
  def application_factory(settings: Settings) -> FastAPI:
@@ -95,13 +95,13 @@ def application_factory(settings: Settings) -> FastAPI:
95
95
  backend=get_auth_backend(settings),
96
96
  ),
97
97
  Middleware(AuditMiddleware, audit_utility_getter=get_audit),
98
+ Middleware(ClientErrorPayloadLoggerMiddleware),
98
99
  ]
99
100
  if running_settings.debug:
100
101
  middleware.append(Middleware(ProcessTimeHeaderMiddleware))
101
102
 
102
103
  fastapi_settings = dict(
103
104
  debug=running_settings.debug,
104
- middleware=middleware,
105
105
  lifespan=lifespan,
106
106
  exception_handlers={
107
107
  Exception: global_exception_handler,
@@ -122,6 +122,7 @@ def application_factory(settings: Settings) -> FastAPI:
122
122
  prefix_format=f"/{API_PREFIX}/v{{major}}",
123
123
  default_version=(1, 0),
124
124
  enable_latest=False,
125
+ middleware=middleware,
125
126
  kwargs=fastapi_settings,
126
127
  )
127
128
 
@@ -19,7 +19,6 @@
19
19
  import base64
20
20
  import logging
21
21
  import time
22
- from typing import Optional
23
22
 
24
23
  import orjson
25
24
  from jwcrypto import jwe, jwk # type: ignore
@@ -51,7 +50,7 @@ def get_mapped_roles(*, settings: Settings, data: dict[str, str]) -> list[str]:
51
50
 
52
51
  async def authenticate_auth_token(
53
52
  settings: Settings, request: HTTPConnection
54
- ) -> Optional[tuple[AuthCredentials, BaseUser]]:
53
+ ) -> tuple[AuthCredentials, BaseUser] | None:
55
54
  if "eph-token" not in request.query_params or settings.jwk_key is None:
56
55
  return None
57
56
 
@@ -81,7 +80,7 @@ class AuthHeaderAuthenticationBackend(NucliaCloudAuthenticationBackend):
81
80
  def __init__(self, settings: Settings) -> None:
82
81
  self.settings = settings
83
82
 
84
- async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
83
+ async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
85
84
  token_resp = await authenticate_auth_token(self.settings, request)
86
85
  if token_resp is not None:
87
86
  return token_resp
@@ -109,7 +108,7 @@ class OAuth2AuthenticationBackend(NucliaCloudAuthenticationBackend):
109
108
  def __init__(self, settings: Settings) -> None:
110
109
  self.settings = settings
111
110
 
112
- async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
111
+ async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
113
112
  token_resp = await authenticate_auth_token(self.settings, request)
114
113
  if token_resp is not None:
115
114
  return token_resp
@@ -160,7 +159,7 @@ class BasicAuthAuthenticationBackend(NucliaCloudAuthenticationBackend):
160
159
  def __init__(self, settings: Settings) -> None:
161
160
  self.settings = settings
162
161
 
163
- async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
162
+ async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
164
163
  token_resp = await authenticate_auth_token(self.settings, request)
165
164
  if token_resp is not None:
166
165
  return token_resp
@@ -189,7 +188,7 @@ class UpstreamNaiveAuthenticationBackend(NucliaCloudAuthenticationBackend):
189
188
  user_header=settings.auth_policy_user_header,
190
189
  )
191
190
 
192
- async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
191
+ async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
193
192
  token_resp = await authenticate_auth_token(self.settings, request)
194
193
  if token_resp is not None:
195
194
  return token_resp
@@ -17,7 +17,7 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- import asyncio
20
+ import inspect
21
21
  from contextlib import asynccontextmanager
22
22
 
23
23
  from fastapi import FastAPI
@@ -56,7 +56,7 @@ async def lifespan(app: FastAPI):
56
56
  yield
57
57
 
58
58
  for finalizer in SYNC_FINALIZERS:
59
- if asyncio.iscoroutinefunction(finalizer):
59
+ if inspect.iscoroutinefunction(finalizer):
60
60
  await finalizer()
61
61
  else:
62
62
  finalizer()
@@ -21,7 +21,6 @@ import asyncio
21
21
  import logging
22
22
  import os
23
23
  import sys
24
- from typing import Optional
25
24
 
26
25
  import argdantic
27
26
  import uvicorn # type: ignore
@@ -116,6 +115,9 @@ def run():
116
115
  if nuclia_settings.nuclia_service_account:
117
116
  settings_to_output["NUA API key"] = "Configured ✔"
118
117
  settings_to_output["NUA API zone"] = nuclia_settings.nuclia_zone
118
+ settings_to_output["NUA API url"] = (
119
+ nuclia_settings.nuclia_public_url.format(zone=nuclia_settings.nuclia_zone) + "/api"
120
+ )
119
121
 
120
122
  settings_to_output_fmted = "\n".join(
121
123
  [f"|| - {k}:{' ' * (27 - len(k))}{v}" for k, v in settings_to_output.items()]
@@ -145,9 +147,8 @@ def run():
145
147
  server.run()
146
148
 
147
149
 
148
- def get_latest_nucliadb() -> Optional[str]:
149
- loop = asyncio.get_event_loop()
150
- return loop.run_until_complete(versions.latest_nucliadb())
150
+ def get_latest_nucliadb() -> str | None:
151
+ return asyncio.run(versions.latest_nucliadb())
151
152
 
152
153
 
153
154
  async def run_async_nucliadb(settings: Settings) -> uvicorn.Server:
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  from enum import Enum
21
- from typing import Optional
22
21
 
23
22
  import pydantic
24
23
 
@@ -44,11 +43,11 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
44
43
  # all settings here are mapped in to other env var settings used
45
44
  # in the app. These are helper settings to make things easier to
46
45
  # use with standalone app vs cluster app.
47
- nua_api_key: Optional[str] = pydantic.Field(
46
+ nua_api_key: str | None = pydantic.Field(
48
47
  default=None,
49
- description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key", # noqa
48
+ description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
50
49
  )
51
- zone: Optional[str] = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
50
+ zone: str | None = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
52
51
  http_host: str = pydantic.Field(default="0.0.0.0", description="HTTP Port")
53
52
  http_port: int = pydantic.Field(default=8080, description="HTTP Port")
54
53
  ingest_grpc_port: int = pydantic.Field(default=8030, description="Ingest GRPC Port")
@@ -83,7 +82,7 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
83
82
  description="Default role to assign to user that is authenticated \
84
83
  upstream. Not used with `upstream_naive` auth policy.",
85
84
  )
86
- auth_policy_role_mapping: Optional[dict[str, dict[str, list[NucliaDBRoles]]]] = pydantic.Field(
85
+ auth_policy_role_mapping: dict[str, dict[str, list[NucliaDBRoles]]] | None = pydantic.Field(
87
86
  default=None,
88
87
  description="""
89
88
  Role mapping for `upstream_auth_header`, `upstream_oauth2` and `upstream_basicauth` auth policies.
@@ -97,7 +96,7 @@ Examples:
97
96
  """,
98
97
  )
99
98
 
100
- jwk_key: Optional[str] = pydantic.Field(
99
+ jwk_key: str | None = pydantic.Field(
101
100
  default=None,
102
101
  description="JWK key used for temporary token generation and validation.",
103
102
  )
@@ -20,7 +20,6 @@
20
20
  import enum
21
21
  import importlib.metadata
22
22
  import logging
23
- from typing import Optional
24
23
 
25
24
  from cachetools import TTLCache
26
25
 
@@ -45,11 +44,11 @@ def installed_nucliadb() -> str:
45
44
  return get_installed_version(StandalonePackages.NUCLIADB.value)
46
45
 
47
46
 
48
- async def latest_nucliadb() -> Optional[str]:
47
+ async def latest_nucliadb() -> str | None:
49
48
  return await get_latest_version(StandalonePackages.NUCLIADB.value)
50
49
 
51
50
 
52
- def nucliadb_updates_available(installed: str, latest: Optional[str]) -> bool:
51
+ def nucliadb_updates_available(installed: str, latest: str | None) -> bool:
53
52
  if latest is None:
54
53
  return False
55
54
  return is_newer_release(installed, latest)
@@ -96,7 +95,7 @@ def get_installed_version(package_name: str) -> str:
96
95
  return importlib.metadata.distribution(package_name).version
97
96
 
98
97
 
99
- async def get_latest_version(package: str) -> Optional[str]:
98
+ async def get_latest_version(package: str) -> str | None:
100
99
  result = CACHE.get(package, None)
101
100
  if result is None:
102
101
  try:
@@ -19,9 +19,10 @@
19
19
  #
20
20
 
21
21
  import asyncio
22
- from typing import Generic, Optional, Type
22
+ from typing import Generic
23
23
 
24
24
  import nats
25
+ import nats.js.api
25
26
  import pydantic
26
27
  from nats.aio.client import Msg
27
28
 
@@ -43,8 +44,9 @@ class NatsTaskConsumer(Generic[MsgType]):
43
44
  stream: NatsStream,
44
45
  consumer: NatsConsumer,
45
46
  callback: Callback,
46
- msg_type: Type[MsgType],
47
- max_concurrent_messages: Optional[int] = None,
47
+ msg_type: type[MsgType],
48
+ max_concurrent_messages: int | None = None,
49
+ max_deliver: int | None = None,
48
50
  ):
49
51
  self.name = name
50
52
  self.stream = stream
@@ -52,6 +54,7 @@ class NatsTaskConsumer(Generic[MsgType]):
52
54
  self.callback = callback
53
55
  self.msg_type = msg_type
54
56
  self.max_concurrent_messages = max_concurrent_messages
57
+ self.max_deliver = max_deliver
55
58
  self.initialized = False
56
59
  self.running_tasks: list[asyncio.Task] = []
57
60
  self.subscription = None
@@ -71,7 +74,8 @@ class NatsTaskConsumer(Generic[MsgType]):
71
74
  for task in self.running_tasks:
72
75
  task.cancel()
73
76
  try:
74
- await asyncio.wait(self.running_tasks, timeout=5)
77
+ if len(self.running_tasks) > 0:
78
+ await asyncio.wait(self.running_tasks, timeout=5)
75
79
  self.running_tasks.clear()
76
80
  except asyncio.TimeoutError:
77
81
  pass
@@ -96,6 +100,7 @@ class NatsTaskConsumer(Generic[MsgType]):
96
100
  ack_wait=nats_consumer_settings.nats_ack_wait,
97
101
  idle_heartbeat=nats_consumer_settings.nats_idle_heartbeat,
98
102
  max_ack_pending=max_ack_pending,
103
+ max_deliver=self.max_deliver,
99
104
  ),
100
105
  )
101
106
  logger.info(
@@ -168,8 +173,6 @@ class NatsTaskConsumer(Generic[MsgType]):
168
173
  },
169
174
  )
170
175
  await msg.ack()
171
- finally:
172
- return
173
176
 
174
177
 
175
178
  def create_consumer(
@@ -177,8 +180,9 @@ def create_consumer(
177
180
  stream: NatsStream,
178
181
  consumer: NatsConsumer,
179
182
  callback: Callback,
180
- msg_type: Type[MsgType],
181
- max_concurrent_messages: Optional[int] = None,
183
+ msg_type: type[MsgType],
184
+ max_concurrent_messages: int | None = None,
185
+ max_retries: int = 100,
182
186
  ) -> NatsTaskConsumer[MsgType]:
183
187
  """
184
188
  Returns a non-initialized consumer
@@ -190,4 +194,5 @@ def create_consumer(
190
194
  callback=callback,
191
195
  msg_type=msg_type,
192
196
  max_concurrent_messages=max_concurrent_messages,
197
+ max_deliver=max_retries,
193
198
  )
nucliadb/tasks/models.py CHANGED
@@ -17,7 +17,8 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import Any, Callable, Coroutine, TypeVar
20
+ from collections.abc import Callable, Coroutine
21
+ from typing import Any, TypeVar
21
22
 
22
23
  import pydantic
23
24
 
@@ -17,7 +17,7 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import Generic, Type
20
+ from typing import Generic
21
21
 
22
22
  from nucliadb.tasks.logger import logger
23
23
  from nucliadb.tasks.models import MsgType
@@ -32,7 +32,7 @@ class NatsTaskProducer(Generic[MsgType]):
32
32
  name: str,
33
33
  stream: NatsStream,
34
34
  producer_subject: str,
35
- msg_type: Type[MsgType],
35
+ msg_type: type[MsgType],
36
36
  ):
37
37
  self.name = name
38
38
  self.stream = stream
@@ -69,7 +69,7 @@ def create_producer(
69
69
  name: str,
70
70
  stream: NatsStream,
71
71
  producer_subject: str,
72
- msg_type: Type[MsgType],
72
+ msg_type: type[MsgType],
73
73
  ) -> NatsTaskProducer[MsgType]:
74
74
  """
75
75
  Returns a non-initialized producer.
nucliadb/tasks/retries.py CHANGED
@@ -19,9 +19,10 @@
19
19
  #
20
20
  import functools
21
21
  import logging
22
+ from collections.abc import Callable
22
23
  from datetime import datetime, timezone
23
24
  from enum import Enum
24
- from typing import Callable, Optional, cast
25
+ from typing import cast
25
26
 
26
27
  from pydantic import BaseModel
27
28
 
@@ -44,7 +45,7 @@ class TaskMetadata(BaseModel):
44
45
  status: Status
45
46
  retries: int = 0
46
47
  error_messages: list[str] = []
47
- last_modified: Optional[datetime] = None
48
+ last_modified: datetime | None = None
48
49
 
49
50
 
50
51
  class TaskRetryHandler:
@@ -87,7 +88,7 @@ class TaskRetryHandler:
87
88
  kbid=self.kbid, task_type=self.task_type, task_id=self.task_id
88
89
  )
89
90
 
90
- async def get_metadata(self) -> Optional[TaskMetadata]:
91
+ async def get_metadata(self) -> TaskMetadata | None:
91
92
  return await _get_metadata(self.context.kv_driver, self.metadata_key)
92
93
 
93
94
  async def set_metadata(self, metadata: TaskMetadata) -> None:
@@ -150,7 +151,7 @@ class TaskRetryHandler:
150
151
  return wrapper
151
152
 
152
153
 
153
- async def _get_metadata(kv_driver: Driver, metadata_key: str) -> Optional[TaskMetadata]:
154
+ async def _get_metadata(kv_driver: Driver, metadata_key: str) -> TaskMetadata | None:
154
155
  async with kv_driver.ro_transaction() as txn:
155
156
  metadata = await txn.get(metadata_key)
156
157
  if metadata is None:
@@ -173,7 +174,7 @@ async def purge_metadata(kv_driver: Driver) -> int:
173
174
  return 0
174
175
 
175
176
  total_purged = 0
176
- start: Optional[str] = ""
177
+ start: str | None = ""
177
178
  while True:
178
179
  start, purged = await purge_batch(kv_driver, start)
179
180
  total_purged += purged
@@ -183,8 +184,8 @@ async def purge_metadata(kv_driver: Driver) -> int:
183
184
 
184
185
 
185
186
  async def purge_batch(
186
- kv_driver: PGDriver, start: Optional[str] = None, batch_size: int = 200
187
- ) -> tuple[Optional[str], int]:
187
+ kv_driver: PGDriver, start: str | None = None, batch_size: int = 200
188
+ ) -> tuple[str | None, int]:
188
189
  """
189
190
  Returns the next start key and the number of purged records. If start is None, it means there are no more records to purge.
190
191
  """
@@ -19,12 +19,10 @@
19
19
  #
20
20
 
21
21
 
22
- from typing import Optional
23
-
24
22
  from nucliadb.train.utils import get_shard_manager
25
23
 
26
24
 
27
- async def get_kb_partitions(kbid: str, prefix: Optional[str] = None) -> list[str]:
25
+ async def get_kb_partitions(kbid: str, prefix: str | None = None) -> list[str]:
28
26
  shard_manager = get_shard_manager()
29
27
  shards = await shard_manager.get_shards_by_kbid_inner(kbid=kbid)
30
28
  valid_shards = []
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import json
21
- from typing import Optional
22
21
 
23
22
  import google.protobuf.message
24
23
  import pydantic
@@ -63,7 +62,7 @@ async def object_get_response(
63
62
  )
64
63
 
65
64
 
66
- async def get_trainset(request: Request) -> tuple[TrainSet, Optional[FilterExpression]]:
65
+ async def get_trainset(request: Request) -> tuple[TrainSet, FilterExpression | None]:
67
66
  if request.headers.get("Content-Type") == "application/json":
68
67
  try:
69
68
  trainset_model = TrainSetModel.model_validate(await request.json())
@@ -18,7 +18,6 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
- from typing import Optional
22
21
 
23
22
  from fastapi import HTTPException, Request
24
23
  from fastapi_versioning import version
@@ -57,7 +56,7 @@ async def get_partitions_prefix(request: Request, kbid: str, prefix: str) -> Tra
57
56
  return await get_partitions(kbid, prefix=prefix)
58
57
 
59
58
 
60
- async def get_partitions(kbid: str, prefix: Optional[str] = None) -> TrainSetPartitions:
59
+ async def get_partitions(kbid: str, prefix: str | None = None) -> TrainSetPartitions:
61
60
  try:
62
61
  all_keys = await get_kb_partitions(kbid, prefix)
63
62
  except ShardNotFound:
nucliadb/train/app.py CHANGED
@@ -50,7 +50,6 @@ errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
50
50
 
51
51
  fastapi_settings = dict(
52
52
  debug=running_settings.debug,
53
- middleware=middleware,
54
53
  lifespan=lifespan,
55
54
  exception_handlers={
56
55
  Exception: global_exception_handler,
@@ -71,6 +70,7 @@ application = VersionedFastAPI(
71
70
  prefix_format=f"/{API_PREFIX}/v{{major}}",
72
71
  default_version=(1, 0),
73
72
  enable_latest=False,
73
+ middleware=middleware,
74
74
  kwargs=fastapi_settings,
75
75
  )
76
76
 
@@ -17,7 +17,7 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import AsyncIterator, Callable, Optional
20
+ from collections.abc import AsyncIterator, Callable
21
21
 
22
22
  from fastapi import HTTPException
23
23
  from grpc import StatusCode
@@ -53,11 +53,11 @@ from nucliadb.train.utils import get_shard_manager
53
53
  from nucliadb_models.filters import FilterExpression
54
54
  from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
55
55
 
56
- BatchGenerator = Callable[[str, TrainSet, str, Optional[FilterExpression]], AsyncIterator[TrainBatch]]
56
+ BatchGenerator = Callable[[str, TrainSet, str, FilterExpression | None], AsyncIterator[TrainBatch]]
57
57
 
58
58
 
59
59
  async def generate_train_data(
60
- kbid: str, shard: str, trainset: TrainSet, filter_expression: Optional[FilterExpression] = None
60
+ kbid: str, shard: str, trainset: TrainSet, filter_expression: FilterExpression | None = None
61
61
  ):
62
62
  # Get the data structure to generate data
63
63
  shard_manager = get_shard_manager()
@@ -66,7 +66,7 @@ async def generate_train_data(
66
66
  if trainset.batch_size == 0:
67
67
  trainset.batch_size = 50
68
68
 
69
- batch_generator: Optional[BatchGenerator] = None
69
+ batch_generator: BatchGenerator | None = None
70
70
 
71
71
  if trainset.type == TaskType.FIELD_CLASSIFICATION:
72
72
  batch_generator = field_classification_batch_generator
@@ -18,7 +18,7 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
- from typing import AsyncGenerator, Optional
21
+ from collections.abc import AsyncGenerator
22
22
 
23
23
  from nidx_protos.nodereader_pb2 import StreamRequest
24
24
 
@@ -39,7 +39,7 @@ def field_classification_batch_generator(
39
39
  kbid: str,
40
40
  trainset: TrainSet,
41
41
  shard_replica_id: str,
42
- filter_expression: Optional[FilterExpression],
42
+ filter_expression: FilterExpression | None,
43
43
  ) -> AsyncGenerator[FieldClassificationBatch, None]:
44
44
  generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
45
45
  batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
@@ -19,7 +19,7 @@
19
19
  #
20
20
 
21
21
  import asyncio
22
- from typing import AsyncGenerator, AsyncIterable, Optional
22
+ from collections.abc import AsyncGenerator, AsyncIterable
23
23
 
24
24
  from nidx_protos.nodereader_pb2 import DocumentItem, StreamRequest
25
25
 
@@ -45,7 +45,7 @@ def field_streaming_batch_generator(
45
45
  kbid: str,
46
46
  trainset: TrainSet,
47
47
  shard_replica_id: str,
48
- filter_expression: Optional[FilterExpression],
48
+ filter_expression: FilterExpression | None,
49
49
  ) -> AsyncGenerator[FieldStreamingBatch, None]:
50
50
  generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id, filter_expression)
51
51
  batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
@@ -53,7 +53,7 @@ def field_streaming_batch_generator(
53
53
 
54
54
 
55
55
  async def generate_field_streaming_payloads(
56
- kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: Optional[FilterExpression]
56
+ kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: FilterExpression | None
57
57
  ) -> AsyncGenerator[FieldSplitData, None]:
58
58
  request = StreamRequest()
59
59
  request.shard_id.id = shard_replica_id
@@ -192,7 +192,7 @@ async def _fetch_basic(kbid: str, fsd: FieldSplitData):
192
192
  fsd.basic.CopyFrom(basic)
193
193
 
194
194
 
195
- async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
195
+ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> ExtractedText | None:
196
196
  orm_resource = await get_resource_from_cache_or_db(kbid, rid)
197
197
 
198
198
  if orm_resource is None:
@@ -208,7 +208,7 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Op
208
208
 
209
209
  async def get_field_metadata(
210
210
  kbid: str, rid: str, field: str, field_type: str
211
- ) -> Optional[FieldComputedMetadata]:
211
+ ) -> FieldComputedMetadata | None:
212
212
  orm_resource = await get_resource_from_cache_or_db(kbid, rid)
213
213
 
214
214
  if orm_resource is None:
@@ -222,7 +222,7 @@ async def get_field_metadata(
222
222
  return field_metadata
223
223
 
224
224
 
225
- async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
225
+ async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Basic | None:
226
226
  orm_resource = await get_resource_from_cache_or_db(kbid, rid)
227
227
 
228
228
  if orm_resource is None: