nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. migrations/0023_backfill_pg_catalog.py +2 -2
  2. migrations/0029_backfill_field_status.py +3 -4
  3. migrations/0032_remove_old_relations.py +2 -3
  4. migrations/0038_backfill_catalog_field_labels.py +2 -2
  5. migrations/0039_backfill_converation_splits_metadata.py +2 -2
  6. migrations/0041_reindex_conversations.py +137 -0
  7. migrations/pg/0010_shards_index.py +34 -0
  8. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  9. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  10. nucliadb/backups/create.py +2 -15
  11. nucliadb/backups/restore.py +4 -15
  12. nucliadb/backups/tasks.py +4 -1
  13. nucliadb/common/back_pressure/cache.py +2 -3
  14. nucliadb/common/back_pressure/materializer.py +7 -13
  15. nucliadb/common/back_pressure/settings.py +6 -6
  16. nucliadb/common/back_pressure/utils.py +1 -0
  17. nucliadb/common/cache.py +9 -9
  18. nucliadb/common/catalog/interface.py +12 -12
  19. nucliadb/common/catalog/pg.py +41 -29
  20. nucliadb/common/catalog/utils.py +3 -3
  21. nucliadb/common/cluster/manager.py +5 -4
  22. nucliadb/common/cluster/rebalance.py +483 -114
  23. nucliadb/common/cluster/rollover.py +25 -9
  24. nucliadb/common/cluster/settings.py +3 -8
  25. nucliadb/common/cluster/utils.py +34 -8
  26. nucliadb/common/context/__init__.py +7 -8
  27. nucliadb/common/context/fastapi.py +1 -2
  28. nucliadb/common/datamanagers/__init__.py +2 -4
  29. nucliadb/common/datamanagers/atomic.py +4 -2
  30. nucliadb/common/datamanagers/cluster.py +1 -2
  31. nucliadb/common/datamanagers/fields.py +3 -4
  32. nucliadb/common/datamanagers/kb.py +6 -6
  33. nucliadb/common/datamanagers/labels.py +2 -3
  34. nucliadb/common/datamanagers/resources.py +10 -33
  35. nucliadb/common/datamanagers/rollover.py +5 -7
  36. nucliadb/common/datamanagers/search_configurations.py +1 -2
  37. nucliadb/common/datamanagers/synonyms.py +1 -2
  38. nucliadb/common/datamanagers/utils.py +4 -4
  39. nucliadb/common/datamanagers/vectorsets.py +4 -4
  40. nucliadb/common/external_index_providers/base.py +32 -5
  41. nucliadb/common/external_index_providers/manager.py +4 -5
  42. nucliadb/common/filter_expression.py +128 -40
  43. nucliadb/common/http_clients/processing.py +12 -23
  44. nucliadb/common/ids.py +6 -4
  45. nucliadb/common/locking.py +1 -2
  46. nucliadb/common/maindb/driver.py +9 -8
  47. nucliadb/common/maindb/local.py +5 -5
  48. nucliadb/common/maindb/pg.py +9 -8
  49. nucliadb/common/nidx.py +3 -4
  50. nucliadb/export_import/datamanager.py +4 -3
  51. nucliadb/export_import/exporter.py +11 -19
  52. nucliadb/export_import/importer.py +13 -6
  53. nucliadb/export_import/tasks.py +2 -0
  54. nucliadb/export_import/utils.py +6 -18
  55. nucliadb/health.py +2 -2
  56. nucliadb/ingest/app.py +8 -8
  57. nucliadb/ingest/consumer/consumer.py +8 -10
  58. nucliadb/ingest/consumer/pull.py +3 -8
  59. nucliadb/ingest/consumer/service.py +3 -3
  60. nucliadb/ingest/consumer/utils.py +1 -1
  61. nucliadb/ingest/fields/base.py +28 -49
  62. nucliadb/ingest/fields/conversation.py +12 -12
  63. nucliadb/ingest/fields/exceptions.py +1 -2
  64. nucliadb/ingest/fields/file.py +22 -8
  65. nucliadb/ingest/fields/link.py +7 -7
  66. nucliadb/ingest/fields/text.py +2 -3
  67. nucliadb/ingest/orm/brain_v2.py +78 -64
  68. nucliadb/ingest/orm/broker_message.py +2 -4
  69. nucliadb/ingest/orm/entities.py +10 -209
  70. nucliadb/ingest/orm/index_message.py +4 -4
  71. nucliadb/ingest/orm/knowledgebox.py +18 -27
  72. nucliadb/ingest/orm/processor/auditing.py +1 -3
  73. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  74. nucliadb/ingest/orm/processor/processor.py +27 -27
  75. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  76. nucliadb/ingest/orm/resource.py +72 -70
  77. nucliadb/ingest/orm/utils.py +1 -1
  78. nucliadb/ingest/processing.py +17 -17
  79. nucliadb/ingest/serialize.py +202 -145
  80. nucliadb/ingest/service/writer.py +3 -109
  81. nucliadb/ingest/settings.py +3 -4
  82. nucliadb/ingest/utils.py +1 -2
  83. nucliadb/learning_proxy.py +11 -11
  84. nucliadb/metrics_exporter.py +5 -4
  85. nucliadb/middleware/__init__.py +82 -1
  86. nucliadb/migrator/datamanager.py +3 -4
  87. nucliadb/migrator/migrator.py +1 -2
  88. nucliadb/migrator/models.py +1 -2
  89. nucliadb/migrator/settings.py +1 -2
  90. nucliadb/models/internal/augment.py +614 -0
  91. nucliadb/models/internal/processing.py +19 -19
  92. nucliadb/openapi.py +2 -2
  93. nucliadb/purge/__init__.py +3 -8
  94. nucliadb/purge/orphan_shards.py +1 -2
  95. nucliadb/reader/__init__.py +5 -0
  96. nucliadb/reader/api/models.py +6 -13
  97. nucliadb/reader/api/v1/download.py +59 -38
  98. nucliadb/reader/api/v1/export_import.py +4 -4
  99. nucliadb/reader/api/v1/learning_config.py +24 -4
  100. nucliadb/reader/api/v1/resource.py +61 -9
  101. nucliadb/reader/api/v1/services.py +18 -14
  102. nucliadb/reader/app.py +3 -1
  103. nucliadb/reader/reader/notifications.py +1 -2
  104. nucliadb/search/api/v1/__init__.py +2 -0
  105. nucliadb/search/api/v1/ask.py +3 -4
  106. nucliadb/search/api/v1/augment.py +585 -0
  107. nucliadb/search/api/v1/catalog.py +11 -15
  108. nucliadb/search/api/v1/find.py +16 -22
  109. nucliadb/search/api/v1/hydrate.py +25 -25
  110. nucliadb/search/api/v1/knowledgebox.py +1 -2
  111. nucliadb/search/api/v1/predict_proxy.py +1 -2
  112. nucliadb/search/api/v1/resource/ask.py +7 -7
  113. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  114. nucliadb/search/api/v1/resource/search.py +9 -11
  115. nucliadb/search/api/v1/retrieve.py +130 -0
  116. nucliadb/search/api/v1/search.py +28 -32
  117. nucliadb/search/api/v1/suggest.py +11 -14
  118. nucliadb/search/api/v1/summarize.py +1 -2
  119. nucliadb/search/api/v1/utils.py +2 -2
  120. nucliadb/search/app.py +3 -2
  121. nucliadb/search/augmentor/__init__.py +21 -0
  122. nucliadb/search/augmentor/augmentor.py +232 -0
  123. nucliadb/search/augmentor/fields.py +704 -0
  124. nucliadb/search/augmentor/metrics.py +24 -0
  125. nucliadb/search/augmentor/paragraphs.py +334 -0
  126. nucliadb/search/augmentor/resources.py +238 -0
  127. nucliadb/search/augmentor/utils.py +33 -0
  128. nucliadb/search/lifecycle.py +3 -1
  129. nucliadb/search/predict.py +24 -17
  130. nucliadb/search/predict_models.py +8 -9
  131. nucliadb/search/requesters/utils.py +11 -10
  132. nucliadb/search/search/cache.py +19 -23
  133. nucliadb/search/search/chat/ask.py +88 -59
  134. nucliadb/search/search/chat/exceptions.py +3 -5
  135. nucliadb/search/search/chat/fetcher.py +201 -0
  136. nucliadb/search/search/chat/images.py +6 -4
  137. nucliadb/search/search/chat/old_prompt.py +1375 -0
  138. nucliadb/search/search/chat/parser.py +510 -0
  139. nucliadb/search/search/chat/prompt.py +563 -615
  140. nucliadb/search/search/chat/query.py +449 -36
  141. nucliadb/search/search/chat/rpc.py +85 -0
  142. nucliadb/search/search/fetch.py +3 -4
  143. nucliadb/search/search/filters.py +8 -11
  144. nucliadb/search/search/find.py +33 -31
  145. nucliadb/search/search/find_merge.py +124 -331
  146. nucliadb/search/search/graph_strategy.py +14 -12
  147. nucliadb/search/search/hydrator/__init__.py +3 -152
  148. nucliadb/search/search/hydrator/fields.py +92 -50
  149. nucliadb/search/search/hydrator/images.py +7 -7
  150. nucliadb/search/search/hydrator/paragraphs.py +42 -26
  151. nucliadb/search/search/hydrator/resources.py +20 -16
  152. nucliadb/search/search/ingestion_agents.py +5 -5
  153. nucliadb/search/search/merge.py +90 -94
  154. nucliadb/search/search/metrics.py +10 -9
  155. nucliadb/search/search/paragraphs.py +7 -9
  156. nucliadb/search/search/predict_proxy.py +13 -9
  157. nucliadb/search/search/query.py +14 -86
  158. nucliadb/search/search/query_parser/fetcher.py +51 -82
  159. nucliadb/search/search/query_parser/models.py +19 -20
  160. nucliadb/search/search/query_parser/old_filters.py +20 -19
  161. nucliadb/search/search/query_parser/parsers/ask.py +4 -5
  162. nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
  163. nucliadb/search/search/query_parser/parsers/common.py +5 -6
  164. nucliadb/search/search/query_parser/parsers/find.py +6 -26
  165. nucliadb/search/search/query_parser/parsers/graph.py +13 -23
  166. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  167. nucliadb/search/search/query_parser/parsers/search.py +15 -53
  168. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  169. nucliadb/search/search/rank_fusion.py +18 -13
  170. nucliadb/search/search/rerankers.py +5 -6
  171. nucliadb/search/search/retrieval.py +300 -0
  172. nucliadb/search/search/summarize.py +5 -6
  173. nucliadb/search/search/utils.py +3 -4
  174. nucliadb/search/settings.py +1 -2
  175. nucliadb/standalone/api_router.py +1 -1
  176. nucliadb/standalone/app.py +4 -3
  177. nucliadb/standalone/auth.py +5 -6
  178. nucliadb/standalone/lifecycle.py +2 -2
  179. nucliadb/standalone/run.py +2 -4
  180. nucliadb/standalone/settings.py +5 -6
  181. nucliadb/standalone/versions.py +3 -4
  182. nucliadb/tasks/consumer.py +13 -8
  183. nucliadb/tasks/models.py +2 -1
  184. nucliadb/tasks/producer.py +3 -3
  185. nucliadb/tasks/retries.py +8 -7
  186. nucliadb/train/api/utils.py +1 -3
  187. nucliadb/train/api/v1/shards.py +1 -2
  188. nucliadb/train/api/v1/trainset.py +1 -2
  189. nucliadb/train/app.py +1 -1
  190. nucliadb/train/generator.py +4 -4
  191. nucliadb/train/generators/field_classifier.py +2 -2
  192. nucliadb/train/generators/field_streaming.py +6 -6
  193. nucliadb/train/generators/image_classifier.py +2 -2
  194. nucliadb/train/generators/paragraph_classifier.py +2 -2
  195. nucliadb/train/generators/paragraph_streaming.py +2 -2
  196. nucliadb/train/generators/question_answer_streaming.py +2 -2
  197. nucliadb/train/generators/sentence_classifier.py +2 -2
  198. nucliadb/train/generators/token_classifier.py +3 -2
  199. nucliadb/train/generators/utils.py +6 -5
  200. nucliadb/train/nodes.py +3 -3
  201. nucliadb/train/resource.py +6 -8
  202. nucliadb/train/settings.py +3 -4
  203. nucliadb/train/types.py +11 -11
  204. nucliadb/train/upload.py +3 -2
  205. nucliadb/train/uploader.py +1 -2
  206. nucliadb/train/utils.py +1 -2
  207. nucliadb/writer/api/v1/export_import.py +4 -1
  208. nucliadb/writer/api/v1/field.py +7 -11
  209. nucliadb/writer/api/v1/knowledgebox.py +3 -4
  210. nucliadb/writer/api/v1/resource.py +9 -20
  211. nucliadb/writer/api/v1/services.py +10 -132
  212. nucliadb/writer/api/v1/upload.py +73 -72
  213. nucliadb/writer/app.py +8 -2
  214. nucliadb/writer/resource/basic.py +12 -15
  215. nucliadb/writer/resource/field.py +7 -5
  216. nucliadb/writer/resource/origin.py +7 -0
  217. nucliadb/writer/settings.py +2 -3
  218. nucliadb/writer/tus/__init__.py +2 -3
  219. nucliadb/writer/tus/azure.py +1 -3
  220. nucliadb/writer/tus/dm.py +3 -3
  221. nucliadb/writer/tus/exceptions.py +3 -4
  222. nucliadb/writer/tus/gcs.py +5 -6
  223. nucliadb/writer/tus/s3.py +2 -3
  224. nucliadb/writer/tus/storage.py +3 -3
  225. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
  226. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  227. nucliadb/common/datamanagers/entities.py +0 -139
  228. nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
  229. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  230. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  231. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/standalone/auth.py CHANGED
@@ -19,7 +19,6 @@
 import base64
 import logging
 import time
-from typing import Optional
 
 import orjson
 from jwcrypto import jwe, jwk  # type: ignore
@@ -51,7 +50,7 @@ def get_mapped_roles(*, settings: Settings, data: dict[str, str]) -> list[str]:
 
 async def authenticate_auth_token(
     settings: Settings, request: HTTPConnection
-) -> Optional[tuple[AuthCredentials, BaseUser]]:
+) -> tuple[AuthCredentials, BaseUser] | None:
     if "eph-token" not in request.query_params or settings.jwk_key is None:
         return None
 
@@ -81,7 +80,7 @@ class AuthHeaderAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -109,7 +108,7 @@ class OAuth2AuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -160,7 +159,7 @@ class BasicAuthAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -189,7 +188,7 @@ class UpstreamNaiveAuthenticationBackend(NucliaCloudAuthenticationBackend):
             user_header=settings.auth_policy_user_header,
         )
 
-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
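
Note: the change repeated across this release swaps typing.Optional[X] for the PEP 604 union X | None, valid annotation syntax since Python 3.10 and needing no typing import. A minimal sketch of the idiom (names are illustrative, not from nucliadb):

    def find_user(user_id: str, directory: dict[str, str]) -> str | None:
        # `str | None` (PEP 604, Python 3.10+) is equivalent to
        # typing.Optional[str], with no import required.
        return directory.get(user_id)
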
nucliadb/standalone/lifecycle.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import asyncio
+import inspect
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
@@ -56,7 +56,7 @@ async def lifespan(app: FastAPI):
     yield
 
     for finalizer in SYNC_FINALIZERS:
-        if asyncio.iscoroutinefunction(finalizer):
+        if inspect.iscoroutinefunction(finalizer):
             await finalizer()
         else:
             finalizer()
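
Note: replacing asyncio.iscoroutinefunction with inspect.iscoroutinefunction is presumably motivated by the former's deprecation in Python 3.14; the inspect variant also recognizes callables marked with inspect.markcoroutinefunction (3.12+). A sketch of the same dispatch pattern over hypothetical finalizers:

    import inspect

    async def run_finalizers(finalizers) -> None:
        for finalizer in finalizers:
            # inspect.iscoroutinefunction also detects callables marked via
            # inspect.markcoroutinefunction (Python 3.12+); the asyncio
            # variant is deprecated as of Python 3.14.
            if inspect.iscoroutinefunction(finalizer):
                await finalizer()
            else:
                finalizer()
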
nucliadb/standalone/run.py CHANGED
@@ -21,7 +21,6 @@ import asyncio
 import logging
 import os
 import sys
-from typing import Optional
 
 import argdantic
 import uvicorn  # type: ignore
@@ -148,9 +147,8 @@ def run():
     server.run()
 
 
-def get_latest_nucliadb() -> Optional[str]:
-    loop = asyncio.get_event_loop()
-    return loop.run_until_complete(versions.latest_nucliadb())
+def get_latest_nucliadb() -> str | None:
+    return asyncio.run(versions.latest_nucliadb())
 
 
 async def run_async_nucliadb(settings: Settings) -> uvicorn.Server:
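
Note: calling asyncio.get_event_loop() with no running loop emits a DeprecationWarning on recent Python versions, while asyncio.run() creates a fresh loop, runs the coroutine to completion, and closes the loop. A sketch with a stand-in coroutine (the version string is made up):

    import asyncio

    async def latest_nucliadb() -> str | None:
        await asyncio.sleep(0)  # stand-in for the real async lookup
        return "6.10.0"

    def get_latest_nucliadb() -> str | None:
        # asyncio.run() owns the loop's full lifecycle, so nothing leaks
        # into the deprecated global event-loop machinery.
        return asyncio.run(latest_nucliadb())
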
nucliadb/standalone/settings.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from enum import Enum
-from typing import Optional
 
 import pydantic
 
@@ -44,11 +43,11 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
     # all settings here are mapped in to other env var settings used
     # in the app. These are helper settings to make things easier to
     # use with standalone app vs cluster app.
-    nua_api_key: Optional[str] = pydantic.Field(
+    nua_api_key: str | None = pydantic.Field(
         default=None,
-        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",  # noqa
+        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
     )
-    zone: Optional[str] = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
+    zone: str | None = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
     http_host: str = pydantic.Field(default="0.0.0.0", description="HTTP Port")
     http_port: int = pydantic.Field(default=8080, description="HTTP Port")
     ingest_grpc_port: int = pydantic.Field(default=8030, description="Ingest GRPC Port")
@@ -83,7 +82,7 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
         description="Default role to assign to user that is authenticated \
 upstream. Not used with `upstream_naive` auth policy.",
     )
-    auth_policy_role_mapping: Optional[dict[str, dict[str, list[NucliaDBRoles]]]] = pydantic.Field(
+    auth_policy_role_mapping: dict[str, dict[str, list[NucliaDBRoles]]] | None = pydantic.Field(
         default=None,
         description="""
 Role mapping for `upstream_auth_header`, `upstream_oauth2` and `upstream_basicauth` auth policies.
@@ -97,7 +96,7 @@ Examples:
 """,
     )
 
-    jwk_key: Optional[str] = pydantic.Field(
+    jwk_key: str | None = pydantic.Field(
         default=None,
         description="JWK key used for temporary token generation and validation.",
     )
nucliadb/standalone/versions.py CHANGED
@@ -20,7 +20,6 @@
 import enum
 import importlib.metadata
 import logging
-from typing import Optional
 
 from cachetools import TTLCache
 
@@ -45,11 +44,11 @@ def installed_nucliadb() -> str:
     return get_installed_version(StandalonePackages.NUCLIADB.value)
 
 
-async def latest_nucliadb() -> Optional[str]:
+async def latest_nucliadb() -> str | None:
     return await get_latest_version(StandalonePackages.NUCLIADB.value)
 
 
-def nucliadb_updates_available(installed: str, latest: Optional[str]) -> bool:
+def nucliadb_updates_available(installed: str, latest: str | None) -> bool:
     if latest is None:
         return False
     return is_newer_release(installed, latest)
@@ -96,7 +95,7 @@ def get_installed_version(package_name: str) -> str:
     return importlib.metadata.distribution(package_name).version
 
 
-async def get_latest_version(package: str) -> Optional[str]:
+async def get_latest_version(package: str) -> str | None:
     result = CACHE.get(package, None)
     if result is None:
         try:
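
Note: get_latest_version reads through a cachetools.TTLCache, so repeated update checks within the TTL window skip the network. A sketch of that read-through pattern (cache limits and the fetch stub are illustrative, not nucliadb's values):

    from cachetools import TTLCache

    CACHE: TTLCache = TTLCache(maxsize=16, ttl=30 * 60)

    async def fetch_latest_from_index(package: str) -> str | None:
        return "6.10.0"  # stand-in for the real package-index lookup

    async def get_latest_version(package: str) -> str | None:
        # Read-through: only hit the index once the TTL entry expires.
        result = CACHE.get(package, None)
        if result is None:
            result = await fetch_latest_from_index(package)
            CACHE[package] = result
        return result
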
nucliadb/tasks/consumer.py CHANGED
@@ -19,9 +19,10 @@
 #
 
 import asyncio
-from typing import Generic, Optional, Type
+from typing import Generic
 
 import nats
+import nats.js.api
 import pydantic
 from nats.aio.client import Msg
@@ -43,8 +44,9 @@ class NatsTaskConsumer(Generic[MsgType]):
         stream: NatsStream,
         consumer: NatsConsumer,
         callback: Callback,
-        msg_type: Type[MsgType],
-        max_concurrent_messages: Optional[int] = None,
+        msg_type: type[MsgType],
+        max_concurrent_messages: int | None = None,
+        max_deliver: int | None = None,
     ):
         self.name = name
         self.stream = stream
@@ -52,6 +54,7 @@ class NatsTaskConsumer(Generic[MsgType]):
         self.callback = callback
         self.msg_type = msg_type
         self.max_concurrent_messages = max_concurrent_messages
+        self.max_deliver = max_deliver
         self.initialized = False
         self.running_tasks: list[asyncio.Task] = []
         self.subscription = None
@@ -71,7 +74,8 @@
         for task in self.running_tasks:
             task.cancel()
         try:
-            await asyncio.wait(self.running_tasks, timeout=5)
+            if len(self.running_tasks) > 0:
+                await asyncio.wait(self.running_tasks, timeout=5)
             self.running_tasks.clear()
         except asyncio.TimeoutError:
             pass
@@ -96,6 +100,7 @@
                 ack_wait=nats_consumer_settings.nats_ack_wait,
                 idle_heartbeat=nats_consumer_settings.nats_idle_heartbeat,
                 max_ack_pending=max_ack_pending,
+                max_deliver=self.max_deliver,
             ),
         )
         logger.info(
@@ -168,8 +173,6 @@
                 },
             )
             await msg.ack()
-        finally:
-            return
 
 
 def create_consumer(
@@ -177,8 +180,9 @@ def create_consumer(
     stream: NatsStream,
     consumer: NatsConsumer,
    callback: Callback,
-    msg_type: Type[MsgType],
-    max_concurrent_messages: Optional[int] = None,
+    msg_type: type[MsgType],
+    max_concurrent_messages: int | None = None,
+    max_retries: int = 100,
 ) -> NatsTaskConsumer[MsgType]:
     """
     Returns a non-initialized consumer
@@ -190,4 +194,5 @@ def create_consumer(
         callback=callback,
         msg_type=msg_type,
         max_concurrent_messages=max_concurrent_messages,
+        max_deliver=max_retries,
     )
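
Note: two independent fixes here. The new guard matters because asyncio.wait raises ValueError when handed an empty collection of tasks. Separately, max_deliver flows into the JetStream consumer config, capping how many times an unacknowledged message is redelivered (create_consumer defaults it to 100 via max_retries). A sketch of the subscription side with nats-py, assuming a local server:

    import nats
    import nats.js.api

    async def subscribe_with_delivery_cap(subject: str, stream: str, cb) -> None:
        nc = await nats.connect("nats://localhost:4222")
        js = nc.jetstream()
        # max_deliver bounds JetStream redeliveries of an unacknowledged
        # message; after the cap the server stops retrying it.
        await js.subscribe(
            subject,
            stream=stream,
            cb=cb,
            config=nats.js.api.ConsumerConfig(max_deliver=100),
        )
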
nucliadb/tasks/models.py CHANGED
@@ -17,7 +17,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Any, Callable, Coroutine, TypeVar
+from collections.abc import Callable, Coroutine
+from typing import Any, TypeVar
 
 import pydantic
 
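
Note: since Python 3.9 (PEP 585), typing.Callable and typing.Coroutine are deprecated aliases of their directly subscriptable collections.abc counterparts. For example (the alias name is illustrative; the diff does not show how this module uses the imports):

    from collections.abc import Callable, Coroutine
    from typing import Any

    # collections.abc generics are subscriptable since Python 3.9.
    Callback = Callable[..., Coroutine[Any, Any, None]]
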
nucliadb/tasks/producer.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Generic, Type
+from typing import Generic
 
 from nucliadb.tasks.logger import logger
 from nucliadb.tasks.models import MsgType
@@ -32,7 +32,7 @@ class NatsTaskProducer(Generic[MsgType]):
         name: str,
         stream: NatsStream,
         producer_subject: str,
-        msg_type: Type[MsgType],
+        msg_type: type[MsgType],
     ):
         self.name = name
         self.stream = stream
@@ -69,7 +69,7 @@ def create_producer(
     name: str,
     stream: NatsStream,
     producer_subject: str,
-    msg_type: Type[MsgType],
+    msg_type: type[MsgType],
 ) -> NatsTaskProducer[MsgType]:
     """
     Returns a non-initialized producer.
nucliadb/tasks/retries.py CHANGED
@@ -19,9 +19,10 @@
 #
 import functools
 import logging
+from collections.abc import Callable
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Callable, Optional, cast
+from typing import cast
 
 from pydantic import BaseModel
 
@@ -44,7 +45,7 @@ class TaskMetadata(BaseModel):
     status: Status
     retries: int = 0
     error_messages: list[str] = []
-    last_modified: Optional[datetime] = None
+    last_modified: datetime | None = None
 
 
 class TaskRetryHandler:
@@ -87,7 +88,7 @@ class TaskRetryHandler:
             kbid=self.kbid, task_type=self.task_type, task_id=self.task_id
         )
 
-    async def get_metadata(self) -> Optional[TaskMetadata]:
+    async def get_metadata(self) -> TaskMetadata | None:
         return await _get_metadata(self.context.kv_driver, self.metadata_key)
 
     async def set_metadata(self, metadata: TaskMetadata) -> None:
@@ -150,7 +151,7 @@ class TaskRetryHandler:
         return wrapper
 
 
-async def _get_metadata(kv_driver: Driver, metadata_key: str) -> Optional[TaskMetadata]:
+async def _get_metadata(kv_driver: Driver, metadata_key: str) -> TaskMetadata | None:
     async with kv_driver.ro_transaction() as txn:
         metadata = await txn.get(metadata_key)
         if metadata is None:
@@ -173,7 +174,7 @@ async def purge_metadata(kv_driver: Driver) -> int:
         return 0
 
     total_purged = 0
-    start: Optional[str] = ""
+    start: str | None = ""
     while True:
         start, purged = await purge_batch(kv_driver, start)
         total_purged += purged
@@ -183,8 +184,8 @@ async def purge_metadata(kv_driver: Driver) -> int:
 
 
 async def purge_batch(
-    kv_driver: PGDriver, start: Optional[str] = None, batch_size: int = 200
-) -> tuple[Optional[str], int]:
+    kv_driver: PGDriver, start: str | None = None, batch_size: int = 200
+) -> tuple[str | None, int]:
     """
     Returns the next start key and the number of purged records. If start is None, it means there are no more records to purge.
     """
nucliadb/train/api/utils.py CHANGED
@@ -19,12 +19,10 @@
 #
 
 
-from typing import Optional
-
 from nucliadb.train.utils import get_shard_manager
 
 
-async def get_kb_partitions(kbid: str, prefix: Optional[str] = None) -> list[str]:
+async def get_kb_partitions(kbid: str, prefix: str | None = None) -> list[str]:
     shard_manager = get_shard_manager()
     shards = await shard_manager.get_shards_by_kbid_inner(kbid=kbid)
     valid_shards = []
nucliadb/train/api/v1/shards.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import json
-from typing import Optional
 
 import google.protobuf.message
 import pydantic
@@ -63,7 +62,7 @@ async def object_get_response(
     )
 
 
-async def get_trainset(request: Request) -> tuple[TrainSet, Optional[FilterExpression]]:
+async def get_trainset(request: Request) -> tuple[TrainSet, FilterExpression | None]:
     if request.headers.get("Content-Type") == "application/json":
         try:
             trainset_model = TrainSetModel.model_validate(await request.json())
nucliadb/train/api/v1/trainset.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import Optional
 
 from fastapi import HTTPException, Request
 from fastapi_versioning import version
@@ -57,7 +56,7 @@ async def get_partitions_prefix(request: Request, kbid: str, prefix: str) -> TrainSetPartitions:
     return await get_partitions(kbid, prefix=prefix)
 
 
-async def get_partitions(kbid: str, prefix: Optional[str] = None) -> TrainSetPartitions:
+async def get_partitions(kbid: str, prefix: str | None = None) -> TrainSetPartitions:
     try:
         all_keys = await get_kb_partitions(kbid, prefix)
     except ShardNotFound:
nucliadb/train/app.py CHANGED
@@ -50,7 +50,6 @@ errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
 
 fastapi_settings = dict(
     debug=running_settings.debug,
-    middleware=middleware,
     lifespan=lifespan,
     exception_handlers={
         Exception: global_exception_handler,
@@ -71,6 +70,7 @@ application = VersionedFastAPI(
     prefix_format=f"/{API_PREFIX}/v{{major}}",
     default_version=(1, 0),
     enable_latest=False,
+    middleware=middleware,
     kwargs=fastapi_settings,
 )
 
nucliadb/train/generator.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import AsyncIterator, Callable, Optional
+from collections.abc import AsyncIterator, Callable
 
 from fastapi import HTTPException
 from grpc import StatusCode
@@ -53,11 +53,11 @@ from nucliadb.train.utils import get_shard_manager
 from nucliadb_models.filters import FilterExpression
 from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
 
-BatchGenerator = Callable[[str, TrainSet, str, Optional[FilterExpression]], AsyncIterator[TrainBatch]]
+BatchGenerator = Callable[[str, TrainSet, str, FilterExpression | None], AsyncIterator[TrainBatch]]
 
 
 async def generate_train_data(
-    kbid: str, shard: str, trainset: TrainSet, filter_expression: Optional[FilterExpression] = None
+    kbid: str, shard: str, trainset: TrainSet, filter_expression: FilterExpression | None = None
 ):
     # Get the data structure to generate data
     shard_manager = get_shard_manager()
@@ -66,7 +66,7 @@ async def generate_train_data(
     if trainset.batch_size == 0:
         trainset.batch_size = 50
 
-    batch_generator: Optional[BatchGenerator] = None
+    batch_generator: BatchGenerator | None = None
 
     if trainset.type == TaskType.FIELD_CLASSIFICATION:
         batch_generator = field_classification_batch_generator
nucliadb/train/generators/field_classifier.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -39,7 +39,7 @@ def field_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldClassificationBatch, None]:
     generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
nucliadb/train/generators/field_streaming.py CHANGED
@@ -19,7 +19,7 @@
 #
 
 import asyncio
-from typing import AsyncGenerator, AsyncIterable, Optional
+from collections.abc import AsyncGenerator, AsyncIterable
 
 from nidx_protos.nodereader_pb2 import DocumentItem, StreamRequest
 
@@ -45,7 +45,7 @@ def field_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldStreamingBatch, None]:
     generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id, filter_expression)
     batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
@@ -53,7 +53,7 @@ def field_streaming_batch_generator(
 
 
 async def generate_field_streaming_payloads(
-    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: Optional[FilterExpression]
+    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: FilterExpression | None
 ) -> AsyncGenerator[FieldSplitData, None]:
     request = StreamRequest()
     request.shard_id.id = shard_replica_id
@@ -192,7 +192,7 @@ async def _fetch_basic(kbid: str, fsd: FieldSplitData):
     fsd.basic.CopyFrom(basic)
 
 
-async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
+async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> ExtractedText | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:
@@ -208,7 +208,7 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
 
 async def get_field_metadata(
     kbid: str, rid: str, field: str, field_type: str
-) -> Optional[FieldComputedMetadata]:
+) -> FieldComputedMetadata | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:
@@ -222,7 +222,7 @@ async def get_field_metadata(
     return field_metadata
 
 
-async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
+async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Basic | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:
nucliadb/train/generators/image_classifier.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from nucliadb.train.generators.utils import batchify
 from nucliadb_models.filters import FilterExpression
@@ -33,7 +33,7 @@ def image_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ImageClassificationBatch, None]:
     generator = generate_image_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
nucliadb/train/generators/paragraph_classifier.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
@@ -38,7 +38,7 @@ def paragraph_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ParagraphClassificationBatch, None]:
     if len(trainset.filter.labels) != 1:
         raise HTTPException(
nucliadb/train/generators/paragraph_streaming.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -38,7 +38,7 @@ def paragraph_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ParagraphStreamingBatch, None]:
     generator = generate_paragraph_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
nucliadb/train/generators/question_answer_streaming.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -47,7 +47,7 @@ def question_answer_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
     generator = generate_question_answer_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
nucliadb/train/generators/sentence_classifier.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
@@ -40,7 +40,7 @@ def sentence_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[SentenceClassificationBatch, None]:
     if len(trainset.filter.labels) == 0:
         raise HTTPException(
nucliadb/train/generators/token_classifier.py CHANGED
@@ -19,7 +19,8 @@
 #
 
 from collections import OrderedDict
-from typing import AsyncGenerator, Optional, cast
+from collections.abc import AsyncGenerator
+from typing import cast
 
 from nidx_protos.nodereader_pb2 import StreamFilter, StreamRequest
 
@@ -43,7 +44,7 @@ def token_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[TokenClassificationBatch, None]:
     generator = generate_token_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)
nucliadb/train/generators/utils.py CHANGED
@@ -18,7 +18,8 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import Any, AsyncGenerator, AsyncIterator, Optional, Type
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any
 
 from nucliadb.common.cache import get_resource_cache
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
@@ -30,16 +31,16 @@ from nucliadb.train.types import T
 from nucliadb_utils.utilities import get_storage
 
 
-async def get_resource_from_cache_or_db(kbid: str, uuid: str) -> Optional[ResourceORM]:
+async def get_resource_from_cache_or_db(kbid: str, uuid: str) -> ResourceORM | None:
     resource_cache = get_resource_cache()
     if resource_cache is None:
-        return await _get_resource_from_db(kbid, uuid)
         logger.warning("Resource cache is not set")
+        return await _get_resource_from_db(kbid, uuid)
 
     return await resource_cache.get(kbid, uuid)
 
 
-async def _get_resource_from_db(kbid: str, uuid: str) -> Optional[ResourceORM]:
+async def _get_resource_from_db(kbid: str, uuid: str) -> ResourceORM | None:
     storage = await get_storage(service_name=SERVICE_NAME)
     async with get_driver().ro_transaction() as transaction:
         kb = KnowledgeBoxORM(transaction, storage, kbid)
@@ -81,7 +82,7 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
 
 
 async def batchify(
-    producer: AsyncIterator[Any], size: int, batch_klass: Type[T]
+    producer: AsyncIterator[Any], size: int, batch_klass: type[T]
 ) -> AsyncGenerator[T, None]:
     # NOTE: we are supposing all protobuffers have a data field
     batch = []
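
Note: the diff cuts off inside batchify. Given the visible comment that every batch protobuf exposes a repeated `data` field, a helper of this shape typically finishes along these lines (a sketch, not the verbatim nucliadb body):

    from collections.abc import AsyncGenerator, AsyncIterator
    from typing import Any, TypeVar

    T = TypeVar("T")

    async def batchify(
        producer: AsyncIterator[Any], size: int, batch_klass: type[T]
    ) -> AsyncGenerator[T, None]:
        # Assumes batch_klass is a protobuf message with a repeated
        # `data` field, per the original comment.
        batch: list[Any] = []
        async for item in producer:
            batch.append(item)
            if len(batch) == size:
                yield batch_klass(data=batch)
                batch = []
        if batch:  # flush the trailing partial batch
            yield batch_klass(data=batch)
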