nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431) hide show
  1. migrations/0002_rollover_shards.py +1 -2
  2. migrations/0003_allfields_key.py +2 -37
  3. migrations/0004_rollover_shards.py +1 -2
  4. migrations/0005_rollover_shards.py +1 -2
  5. migrations/0006_rollover_shards.py +2 -4
  6. migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
  7. migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
  8. migrations/0010_fix_corrupt_indexes.py +11 -12
  9. migrations/0011_materialize_labelset_ids.py +2 -18
  10. migrations/0012_rollover_shards.py +6 -12
  11. migrations/0013_rollover_shards.py +2 -4
  12. migrations/0014_rollover_shards.py +5 -7
  13. migrations/0015_targeted_rollover.py +6 -12
  14. migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
  15. migrations/0017_multiple_writable_shards.py +3 -6
  16. migrations/0018_purge_orphan_kbslugs.py +59 -0
  17. migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
  18. migrations/0020_drain_nodes_from_cluster.py +83 -0
  19. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
  20. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  21. migrations/0023_backfill_pg_catalog.py +80 -0
  22. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  23. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  24. migrations/0027_rollover_texts3.py +73 -0
  25. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  26. migrations/pg/0002_catalog.py +42 -0
  27. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  28. nucliadb/common/cluster/base.py +41 -24
  29. nucliadb/common/cluster/discovery/base.py +6 -14
  30. nucliadb/common/cluster/discovery/k8s.py +9 -19
  31. nucliadb/common/cluster/discovery/manual.py +1 -3
  32. nucliadb/common/cluster/discovery/single.py +1 -2
  33. nucliadb/common/cluster/discovery/utils.py +1 -3
  34. nucliadb/common/cluster/grpc_node_dummy.py +11 -16
  35. nucliadb/common/cluster/index_node.py +10 -19
  36. nucliadb/common/cluster/manager.py +223 -102
  37. nucliadb/common/cluster/rebalance.py +42 -37
  38. nucliadb/common/cluster/rollover.py +377 -204
  39. nucliadb/common/cluster/settings.py +16 -9
  40. nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
  41. nucliadb/common/cluster/standalone/index_node.py +4 -11
  42. nucliadb/common/cluster/standalone/service.py +2 -6
  43. nucliadb/common/cluster/standalone/utils.py +9 -6
  44. nucliadb/common/cluster/utils.py +43 -29
  45. nucliadb/common/constants.py +20 -0
  46. nucliadb/common/context/__init__.py +6 -4
  47. nucliadb/common/context/fastapi.py +8 -5
  48. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  49. nucliadb/common/datamanagers/__init__.py +24 -5
  50. nucliadb/common/datamanagers/atomic.py +102 -0
  51. nucliadb/common/datamanagers/cluster.py +5 -5
  52. nucliadb/common/datamanagers/entities.py +6 -16
  53. nucliadb/common/datamanagers/fields.py +84 -0
  54. nucliadb/common/datamanagers/kb.py +101 -24
  55. nucliadb/common/datamanagers/labels.py +26 -56
  56. nucliadb/common/datamanagers/processing.py +2 -6
  57. nucliadb/common/datamanagers/resources.py +214 -117
  58. nucliadb/common/datamanagers/rollover.py +77 -16
  59. nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
  60. nucliadb/common/datamanagers/utils.py +19 -11
  61. nucliadb/common/datamanagers/vectorsets.py +110 -0
  62. nucliadb/common/external_index_providers/base.py +257 -0
  63. nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
  64. nucliadb/common/external_index_providers/manager.py +101 -0
  65. nucliadb/common/external_index_providers/pinecone.py +933 -0
  66. nucliadb/common/external_index_providers/settings.py +52 -0
  67. nucliadb/common/http_clients/auth.py +3 -6
  68. nucliadb/common/http_clients/processing.py +6 -11
  69. nucliadb/common/http_clients/utils.py +1 -3
  70. nucliadb/common/ids.py +240 -0
  71. nucliadb/common/locking.py +43 -13
  72. nucliadb/common/maindb/driver.py +11 -35
  73. nucliadb/common/maindb/exceptions.py +6 -6
  74. nucliadb/common/maindb/local.py +22 -9
  75. nucliadb/common/maindb/pg.py +206 -111
  76. nucliadb/common/maindb/utils.py +13 -44
  77. nucliadb/common/models_utils/from_proto.py +479 -0
  78. nucliadb/common/models_utils/to_proto.py +60 -0
  79. nucliadb/common/nidx.py +260 -0
  80. nucliadb/export_import/datamanager.py +25 -19
  81. nucliadb/export_import/exceptions.py +8 -0
  82. nucliadb/export_import/exporter.py +20 -7
  83. nucliadb/export_import/importer.py +6 -11
  84. nucliadb/export_import/models.py +5 -5
  85. nucliadb/export_import/tasks.py +4 -4
  86. nucliadb/export_import/utils.py +94 -54
  87. nucliadb/health.py +1 -3
  88. nucliadb/ingest/app.py +15 -11
  89. nucliadb/ingest/consumer/auditing.py +30 -147
  90. nucliadb/ingest/consumer/consumer.py +96 -52
  91. nucliadb/ingest/consumer/materializer.py +10 -12
  92. nucliadb/ingest/consumer/pull.py +12 -27
  93. nucliadb/ingest/consumer/service.py +20 -19
  94. nucliadb/ingest/consumer/shard_creator.py +7 -14
  95. nucliadb/ingest/consumer/utils.py +1 -3
  96. nucliadb/ingest/fields/base.py +139 -188
  97. nucliadb/ingest/fields/conversation.py +18 -5
  98. nucliadb/ingest/fields/exceptions.py +1 -4
  99. nucliadb/ingest/fields/file.py +7 -25
  100. nucliadb/ingest/fields/link.py +11 -16
  101. nucliadb/ingest/fields/text.py +9 -4
  102. nucliadb/ingest/orm/brain.py +255 -262
  103. nucliadb/ingest/orm/broker_message.py +181 -0
  104. nucliadb/ingest/orm/entities.py +36 -51
  105. nucliadb/ingest/orm/exceptions.py +12 -0
  106. nucliadb/ingest/orm/knowledgebox.py +334 -278
  107. nucliadb/ingest/orm/processor/__init__.py +2 -697
  108. nucliadb/ingest/orm/processor/auditing.py +117 -0
  109. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  110. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  111. nucliadb/ingest/orm/processor/processor.py +752 -0
  112. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  113. nucliadb/ingest/orm/resource.py +280 -520
  114. nucliadb/ingest/orm/utils.py +25 -31
  115. nucliadb/ingest/partitions.py +3 -9
  116. nucliadb/ingest/processing.py +76 -81
  117. nucliadb/ingest/py.typed +0 -0
  118. nucliadb/ingest/serialize.py +37 -173
  119. nucliadb/ingest/service/__init__.py +1 -3
  120. nucliadb/ingest/service/writer.py +186 -577
  121. nucliadb/ingest/settings.py +13 -22
  122. nucliadb/ingest/utils.py +3 -6
  123. nucliadb/learning_proxy.py +264 -51
  124. nucliadb/metrics_exporter.py +30 -19
  125. nucliadb/middleware/__init__.py +1 -3
  126. nucliadb/migrator/command.py +1 -3
  127. nucliadb/migrator/datamanager.py +13 -13
  128. nucliadb/migrator/migrator.py +57 -37
  129. nucliadb/migrator/settings.py +2 -1
  130. nucliadb/migrator/utils.py +18 -10
  131. nucliadb/purge/__init__.py +139 -33
  132. nucliadb/purge/orphan_shards.py +7 -13
  133. nucliadb/reader/__init__.py +1 -3
  134. nucliadb/reader/api/models.py +3 -14
  135. nucliadb/reader/api/v1/__init__.py +0 -1
  136. nucliadb/reader/api/v1/download.py +27 -94
  137. nucliadb/reader/api/v1/export_import.py +4 -4
  138. nucliadb/reader/api/v1/knowledgebox.py +13 -13
  139. nucliadb/reader/api/v1/learning_config.py +8 -12
  140. nucliadb/reader/api/v1/resource.py +67 -93
  141. nucliadb/reader/api/v1/services.py +70 -125
  142. nucliadb/reader/app.py +16 -46
  143. nucliadb/reader/lifecycle.py +18 -4
  144. nucliadb/reader/py.typed +0 -0
  145. nucliadb/reader/reader/notifications.py +10 -31
  146. nucliadb/search/__init__.py +1 -3
  147. nucliadb/search/api/v1/__init__.py +2 -2
  148. nucliadb/search/api/v1/ask.py +112 -0
  149. nucliadb/search/api/v1/catalog.py +184 -0
  150. nucliadb/search/api/v1/feedback.py +17 -25
  151. nucliadb/search/api/v1/find.py +41 -41
  152. nucliadb/search/api/v1/knowledgebox.py +90 -62
  153. nucliadb/search/api/v1/predict_proxy.py +2 -2
  154. nucliadb/search/api/v1/resource/ask.py +66 -117
  155. nucliadb/search/api/v1/resource/search.py +51 -72
  156. nucliadb/search/api/v1/router.py +1 -0
  157. nucliadb/search/api/v1/search.py +50 -197
  158. nucliadb/search/api/v1/suggest.py +40 -54
  159. nucliadb/search/api/v1/summarize.py +9 -5
  160. nucliadb/search/api/v1/utils.py +2 -1
  161. nucliadb/search/app.py +16 -48
  162. nucliadb/search/lifecycle.py +10 -3
  163. nucliadb/search/predict.py +176 -188
  164. nucliadb/search/py.typed +0 -0
  165. nucliadb/search/requesters/utils.py +41 -63
  166. nucliadb/search/search/cache.py +149 -20
  167. nucliadb/search/search/chat/ask.py +918 -0
  168. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
  169. nucliadb/search/search/chat/images.py +41 -17
  170. nucliadb/search/search/chat/prompt.py +851 -282
  171. nucliadb/search/search/chat/query.py +274 -267
  172. nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
  173. nucliadb/search/search/fetch.py +43 -36
  174. nucliadb/search/search/filters.py +9 -15
  175. nucliadb/search/search/find.py +214 -54
  176. nucliadb/search/search/find_merge.py +408 -391
  177. nucliadb/search/search/hydrator.py +191 -0
  178. nucliadb/search/search/merge.py +198 -234
  179. nucliadb/search/search/metrics.py +73 -2
  180. nucliadb/search/search/paragraphs.py +64 -106
  181. nucliadb/search/search/pgcatalog.py +233 -0
  182. nucliadb/search/search/predict_proxy.py +1 -1
  183. nucliadb/search/search/query.py +386 -257
  184. nucliadb/search/search/query_parser/exceptions.py +22 -0
  185. nucliadb/search/search/query_parser/models.py +101 -0
  186. nucliadb/search/search/query_parser/parser.py +183 -0
  187. nucliadb/search/search/rank_fusion.py +204 -0
  188. nucliadb/search/search/rerankers.py +270 -0
  189. nucliadb/search/search/shards.py +4 -38
  190. nucliadb/search/search/summarize.py +14 -18
  191. nucliadb/search/search/utils.py +27 -4
  192. nucliadb/search/settings.py +15 -1
  193. nucliadb/standalone/api_router.py +4 -10
  194. nucliadb/standalone/app.py +17 -14
  195. nucliadb/standalone/auth.py +7 -21
  196. nucliadb/standalone/config.py +9 -12
  197. nucliadb/standalone/introspect.py +5 -5
  198. nucliadb/standalone/lifecycle.py +26 -25
  199. nucliadb/standalone/migrations.py +58 -0
  200. nucliadb/standalone/purge.py +9 -8
  201. nucliadb/standalone/py.typed +0 -0
  202. nucliadb/standalone/run.py +25 -18
  203. nucliadb/standalone/settings.py +10 -14
  204. nucliadb/standalone/versions.py +15 -5
  205. nucliadb/tasks/consumer.py +8 -12
  206. nucliadb/tasks/producer.py +7 -6
  207. nucliadb/tests/config.py +53 -0
  208. nucliadb/train/__init__.py +1 -3
  209. nucliadb/train/api/utils.py +1 -2
  210. nucliadb/train/api/v1/shards.py +2 -2
  211. nucliadb/train/api/v1/trainset.py +4 -6
  212. nucliadb/train/app.py +14 -47
  213. nucliadb/train/generator.py +10 -19
  214. nucliadb/train/generators/field_classifier.py +7 -19
  215. nucliadb/train/generators/field_streaming.py +156 -0
  216. nucliadb/train/generators/image_classifier.py +12 -18
  217. nucliadb/train/generators/paragraph_classifier.py +5 -9
  218. nucliadb/train/generators/paragraph_streaming.py +6 -9
  219. nucliadb/train/generators/question_answer_streaming.py +19 -20
  220. nucliadb/train/generators/sentence_classifier.py +9 -15
  221. nucliadb/train/generators/token_classifier.py +45 -36
  222. nucliadb/train/generators/utils.py +14 -18
  223. nucliadb/train/lifecycle.py +7 -3
  224. nucliadb/train/nodes.py +23 -32
  225. nucliadb/train/py.typed +0 -0
  226. nucliadb/train/servicer.py +13 -21
  227. nucliadb/train/settings.py +2 -6
  228. nucliadb/train/types.py +13 -10
  229. nucliadb/train/upload.py +3 -6
  230. nucliadb/train/uploader.py +20 -25
  231. nucliadb/train/utils.py +1 -1
  232. nucliadb/writer/__init__.py +1 -3
  233. nucliadb/writer/api/constants.py +0 -5
  234. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  235. nucliadb/writer/api/v1/export_import.py +102 -49
  236. nucliadb/writer/api/v1/field.py +196 -620
  237. nucliadb/writer/api/v1/knowledgebox.py +221 -71
  238. nucliadb/writer/api/v1/learning_config.py +2 -2
  239. nucliadb/writer/api/v1/resource.py +114 -216
  240. nucliadb/writer/api/v1/services.py +64 -132
  241. nucliadb/writer/api/v1/slug.py +61 -0
  242. nucliadb/writer/api/v1/transaction.py +67 -0
  243. nucliadb/writer/api/v1/upload.py +184 -215
  244. nucliadb/writer/app.py +11 -61
  245. nucliadb/writer/back_pressure.py +62 -43
  246. nucliadb/writer/exceptions.py +0 -4
  247. nucliadb/writer/lifecycle.py +21 -15
  248. nucliadb/writer/py.typed +0 -0
  249. nucliadb/writer/resource/audit.py +2 -1
  250. nucliadb/writer/resource/basic.py +48 -62
  251. nucliadb/writer/resource/field.py +45 -135
  252. nucliadb/writer/resource/origin.py +1 -2
  253. nucliadb/writer/settings.py +14 -5
  254. nucliadb/writer/tus/__init__.py +17 -15
  255. nucliadb/writer/tus/azure.py +111 -0
  256. nucliadb/writer/tus/dm.py +17 -5
  257. nucliadb/writer/tus/exceptions.py +1 -3
  258. nucliadb/writer/tus/gcs.py +56 -84
  259. nucliadb/writer/tus/local.py +21 -37
  260. nucliadb/writer/tus/s3.py +28 -68
  261. nucliadb/writer/tus/storage.py +5 -56
  262. nucliadb/writer/vectorsets.py +125 -0
  263. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  264. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  265. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  266. nucliadb/common/maindb/redis.py +0 -194
  267. nucliadb/common/maindb/tikv.py +0 -412
  268. nucliadb/ingest/fields/layout.py +0 -58
  269. nucliadb/ingest/tests/conftest.py +0 -30
  270. nucliadb/ingest/tests/fixtures.py +0 -771
  271. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  272. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
  273. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
  274. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  275. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  276. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  277. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
  278. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  279. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  280. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  281. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
  282. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  283. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  284. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
  285. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  286. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  287. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  288. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  289. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
  290. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  291. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  292. nucliadb/middleware/transaction.py +0 -117
  293. nucliadb/reader/api/v1/learning_collector.py +0 -63
  294. nucliadb/reader/tests/__init__.py +0 -19
  295. nucliadb/reader/tests/conftest.py +0 -31
  296. nucliadb/reader/tests/fixtures.py +0 -136
  297. nucliadb/reader/tests/test_list_resources.py +0 -75
  298. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  299. nucliadb/reader/tests/test_reader_resource.py +0 -379
  300. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  301. nucliadb/search/api/v1/chat.py +0 -258
  302. nucliadb/search/api/v1/resource/chat.py +0 -94
  303. nucliadb/search/tests/__init__.py +0 -19
  304. nucliadb/search/tests/conftest.py +0 -33
  305. nucliadb/search/tests/fixtures.py +0 -199
  306. nucliadb/search/tests/node.py +0 -465
  307. nucliadb/search/tests/unit/__init__.py +0 -18
  308. nucliadb/search/tests/unit/api/__init__.py +0 -19
  309. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  310. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  311. nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
  312. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
  313. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  314. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  315. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
  316. nucliadb/search/tests/unit/search/__init__.py +0 -18
  317. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  318. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
  319. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  320. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  321. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  322. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
  323. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  324. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  325. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  326. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  327. nucliadb/search/tests/unit/search/test_query.py +0 -201
  328. nucliadb/search/tests/unit/test_app.py +0 -79
  329. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  330. nucliadb/search/tests/unit/test_merge.py +0 -34
  331. nucliadb/search/tests/unit/test_predict.py +0 -584
  332. nucliadb/standalone/tests/__init__.py +0 -19
  333. nucliadb/standalone/tests/conftest.py +0 -33
  334. nucliadb/standalone/tests/fixtures.py +0 -38
  335. nucliadb/standalone/tests/unit/__init__.py +0 -18
  336. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  337. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  338. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  339. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  340. nucliadb/tests/benchmarks/__init__.py +0 -19
  341. nucliadb/tests/benchmarks/test_search.py +0 -99
  342. nucliadb/tests/conftest.py +0 -32
  343. nucliadb/tests/fixtures.py +0 -736
  344. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
  345. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
  346. nucliadb/tests/migrations/__init__.py +0 -19
  347. nucliadb/tests/migrations/test_migration_0017.py +0 -80
  348. nucliadb/tests/tikv.py +0 -240
  349. nucliadb/tests/unit/__init__.py +0 -19
  350. nucliadb/tests/unit/common/__init__.py +0 -19
  351. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  352. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  353. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
  354. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  355. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
  356. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
  357. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
  358. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
  359. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
  360. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  361. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  362. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  363. nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
  364. nucliadb/tests/unit/common/test_context.py +0 -36
  365. nucliadb/tests/unit/export_import/__init__.py +0 -19
  366. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  367. nucliadb/tests/unit/export_import/test_utils.py +0 -294
  368. nucliadb/tests/unit/migrator/__init__.py +0 -19
  369. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  370. nucliadb/tests/unit/tasks/__init__.py +0 -19
  371. nucliadb/tests/unit/tasks/conftest.py +0 -42
  372. nucliadb/tests/unit/tasks/test_consumer.py +0 -93
  373. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  374. nucliadb/tests/unit/tasks/test_tasks.py +0 -60
  375. nucliadb/tests/unit/test_field_ids.py +0 -49
  376. nucliadb/tests/unit/test_health.py +0 -84
  377. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  378. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  379. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  380. nucliadb/tests/unit/test_purge.py +0 -138
  381. nucliadb/tests/utils/__init__.py +0 -74
  382. nucliadb/tests/utils/aiohttp_session.py +0 -44
  383. nucliadb/tests/utils/broker_messages/__init__.py +0 -167
  384. nucliadb/tests/utils/broker_messages/fields.py +0 -181
  385. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  386. nucliadb/tests/utils/entities.py +0 -78
  387. nucliadb/train/api/v1/check.py +0 -60
  388. nucliadb/train/tests/__init__.py +0 -19
  389. nucliadb/train/tests/conftest.py +0 -29
  390. nucliadb/train/tests/fixtures.py +0 -342
  391. nucliadb/train/tests/test_field_classification.py +0 -122
  392. nucliadb/train/tests/test_get_entities.py +0 -80
  393. nucliadb/train/tests/test_get_info.py +0 -51
  394. nucliadb/train/tests/test_get_ontology.py +0 -34
  395. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  396. nucliadb/train/tests/test_image_classification.py +0 -222
  397. nucliadb/train/tests/test_list_fields.py +0 -39
  398. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  399. nucliadb/train/tests/test_list_resources.py +0 -39
  400. nucliadb/train/tests/test_list_sentences.py +0 -71
  401. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  402. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  403. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  404. nucliadb/train/tests/test_sentence_classification.py +0 -143
  405. nucliadb/train/tests/test_token_classification.py +0 -136
  406. nucliadb/train/tests/utils.py +0 -108
  407. nucliadb/writer/layouts/__init__.py +0 -51
  408. nucliadb/writer/layouts/v1.py +0 -59
  409. nucliadb/writer/resource/vectors.py +0 -120
  410. nucliadb/writer/tests/__init__.py +0 -19
  411. nucliadb/writer/tests/conftest.py +0 -31
  412. nucliadb/writer/tests/fixtures.py +0 -192
  413. nucliadb/writer/tests/test_fields.py +0 -486
  414. nucliadb/writer/tests/test_files.py +0 -743
  415. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  416. nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
  417. nucliadb/writer/tests/test_resources.py +0 -546
  418. nucliadb/writer/tests/test_service.py +0 -137
  419. nucliadb/writer/tests/test_tus.py +0 -203
  420. nucliadb/writer/tests/utils.py +0 -35
  421. nucliadb/writer/tus/pg.py +0 -125
  422. nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
  423. nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
  424. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  425. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  426. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  427. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  428. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  429. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  430. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  431. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -0,0 +1,918 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+ #
20
+ import dataclasses
21
+ import functools
22
+ import json
23
+ from typing import AsyncGenerator, Optional, cast
24
+
25
+ from nuclia_models.predict.generative_responses import (
26
+ CitationsGenerativeResponse,
27
+ GenerativeChunk,
28
+ JSONGenerativeResponse,
29
+ MetaGenerativeResponse,
30
+ StatusGenerativeResponse,
31
+ TextGenerativeResponse,
32
+ )
33
+ from pydantic_core import ValidationError
34
+
35
+ from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
36
+ from nucliadb.models.responses import HTTPClientError
37
+ from nucliadb.search import logger, predict
38
+ from nucliadb.search.predict import (
39
+ AnswerStatusCode,
40
+ RephraseMissingContextError,
41
+ )
42
+ from nucliadb.search.search.chat.exceptions import (
43
+ AnswerJsonSchemaTooLong,
44
+ NoRetrievalResultsError,
45
+ )
46
+ from nucliadb.search.search.chat.prompt import PromptContextBuilder
47
+ from nucliadb.search.search.chat.query import (
48
+ NOT_ENOUGH_CONTEXT_ANSWER,
49
+ ChatAuditor,
50
+ get_find_results,
51
+ get_relations_results,
52
+ rephrase_query,
53
+ sorted_prompt_context_list,
54
+ tokens_to_chars,
55
+ )
56
+ from nucliadb.search.search.exceptions import (
57
+ IncompleteFindResultsError,
58
+ InvalidQueryError,
59
+ )
60
+ from nucliadb.search.search.metrics import RAGMetrics
61
+ from nucliadb.search.search.query import QueryParser
62
+ from nucliadb.search.utilities import get_predict
63
+ from nucliadb_models.search import (
64
+ AnswerAskResponseItem,
65
+ AskRequest,
66
+ AskResponseItem,
67
+ AskResponseItemType,
68
+ AskRetrievalMatch,
69
+ AskTimings,
70
+ AskTokens,
71
+ ChatModel,
72
+ ChatOptions,
73
+ CitationsAskResponseItem,
74
+ DebugAskResponseItem,
75
+ ErrorAskResponseItem,
76
+ FindParagraph,
77
+ FindRequest,
78
+ JSONAskResponseItem,
79
+ KnowledgeboxFindResults,
80
+ MetadataAskResponseItem,
81
+ MinScore,
82
+ NucliaDBClientType,
83
+ PrequeriesAskResponseItem,
84
+ PreQueriesStrategy,
85
+ PreQuery,
86
+ PreQueryResult,
87
+ PromptContext,
88
+ PromptContextOrder,
89
+ RagStrategyName,
90
+ Relations,
91
+ RelationsAskResponseItem,
92
+ RetrievalAskResponseItem,
93
+ SearchOptions,
94
+ StatusAskResponseItem,
95
+ SyncAskMetadata,
96
+ SyncAskResponse,
97
+ UserPrompt,
98
+ parse_custom_prompt,
99
+ parse_rephrase_prompt,
100
+ )
101
+ from nucliadb_telemetry import errors
102
+ from nucliadb_utils.exceptions import LimitsExceededError
103
+
104
+
105
+ @dataclasses.dataclass
106
+ class RetrievalMatch:
107
+ paragraph: FindParagraph
108
+ weighted_score: float
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class RetrievalResults:
113
+ main_query: KnowledgeboxFindResults
114
+ query_parser: QueryParser
115
+ main_query_weight: float
116
+ prequeries: Optional[list[PreQueryResult]] = None
117
+ best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
118
+
119
+
120
+ class AskResult:
121
+ def __init__(
122
+ self,
123
+ *,
124
+ kbid: str,
125
+ ask_request: AskRequest,
126
+ main_results: KnowledgeboxFindResults,
127
+ prequeries_results: Optional[list[PreQueryResult]],
128
+ nuclia_learning_id: Optional[str],
129
+ predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
130
+ prompt_context: PromptContext,
131
+ prompt_context_order: PromptContextOrder,
132
+ auditor: ChatAuditor,
133
+ metrics: RAGMetrics,
134
+ best_matches: list[RetrievalMatch],
135
+ debug_chat_model: Optional[ChatModel],
136
+ ):
137
+ # Initial attributes
138
+ self.kbid = kbid
139
+ self.ask_request = ask_request
140
+ self.main_results = main_results
141
+ self.prequeries_results = prequeries_results or []
142
+ self.nuclia_learning_id = nuclia_learning_id
143
+ self.predict_answer_stream = predict_answer_stream
144
+ self.prompt_context = prompt_context
145
+ self.debug_chat_model = debug_chat_model
146
+ self.prompt_context_order = prompt_context_order
147
+ self.auditor: ChatAuditor = auditor
148
+ self.metrics: RAGMetrics = metrics
149
+ self.best_matches: list[RetrievalMatch] = best_matches
150
+
151
+ # Computed from the predict chat answer stream
152
+ self._answer_text = ""
153
+ self._object: Optional[JSONGenerativeResponse] = None
154
+ self._status: Optional[StatusGenerativeResponse] = None
155
+ self._citations: Optional[CitationsGenerativeResponse] = None
156
+ self._metadata: Optional[MetaGenerativeResponse] = None
157
+ self._relations: Optional[Relations] = None
158
+
159
+ @property
160
+ def status_code(self) -> AnswerStatusCode:
161
+ if self._status is None:
162
+ return AnswerStatusCode.SUCCESS
163
+ return AnswerStatusCode(self._status.code)
164
+
165
+ @property
166
+ def status_error_details(self) -> Optional[str]:
167
+ if self._status is None: # pragma: no cover
168
+ return None
169
+ return self._status.details
170
+
171
+ @property
172
+ def ask_request_with_relations(self) -> bool:
173
+ return ChatOptions.RELATIONS in self.ask_request.features
174
+
175
+ @property
176
+ def ask_request_with_debug_flag(self) -> bool:
177
+ return self.ask_request.debug
178
+
179
+ async def ndjson_stream(self) -> AsyncGenerator[str, None]:
180
+ try:
181
+ async for item in self._stream():
182
+ yield self._ndjson_encode(item)
183
+ except Exception as exc:
184
+ # Handle any unexpected error that might happen
185
+ # during the streaming and halt the stream
186
+ errors.capture_exception(exc)
187
+ logger.error(
188
+ f"Unexpected error while generating the answer: {exc}",
189
+ extra={"kbid": self.kbid},
190
+ )
191
+ error_message = "Unexpected error while generating the answer. Please try again later."
192
+ if self.ask_request_with_debug_flag:
193
+ error_message += f" Error: {exc}"
194
+ item = ErrorAskResponseItem(error=error_message)
195
+ yield self._ndjson_encode(item)
196
+ return
197
+
198
+ def _ndjson_encode(self, item: AskResponseItemType) -> str:
199
+ result_item = AskResponseItem(item=item)
200
+ return result_item.model_dump_json(exclude_none=True, by_alias=True) + "\n"
201
+
202
async def _stream(self) -> AsyncGenerator[AskResponseItemType, None]:
    """
    Yield the ask response items in streaming order: answer text (or JSON
    object), retrieval results, prequery results, status, citations,
    metadata, relations and, optionally, debug information.

    Also audits the generated answer once the status is known. On an ERROR
    status from predict, the stream is halted right after the status item.
    """
    # First, stream out the predict answer
    first_chunk_yielded = False
    with self.metrics.time("stream_predict_answer"):
        async for answer_chunk in self._stream_predict_answer_text():
            yield AnswerAskResponseItem(text=answer_chunk)
            if not first_chunk_yielded:
                # Record time-to-first-chunk exactly once, for latency metrics
                self.metrics.record_first_chunk_yielded()
                first_chunk_yielded = True

    if self._object is not None:
        yield JSONAskResponseItem(object=self._object.object)
        if not first_chunk_yielded:
            # When there is a JSON generative response, we consider the first chunk yielded
            # to be the moment when the JSON object is yielded, not the text
            self.metrics.record_first_chunk_yielded()
            first_chunk_yielded = True

    yield RetrievalAskResponseItem(
        results=self.main_results,
        best_matches=[
            AskRetrievalMatch(
                id=match.paragraph.id,
            )
            for match in self.best_matches
        ],
    )

    if len(self.prequeries_results) > 0:
        item = PrequeriesAskResponseItem()
        for index, (prequery, result) in enumerate(self.prequeries_results):
            # Fall back to a positional id when the prequery has none
            prequery_id = prequery.id or f"prequery_{index}"
            item.results[prequery_id] = result
        yield item

    # Then the status
    if self.status_code == AnswerStatusCode.ERROR:
        # If predict yielded an error status, we yield it too and halt the stream immediately
        yield StatusAskResponseItem(
            code=self.status_code.value,
            status=self.status_code.prettify(),
            details=self.status_error_details or "Unknown error",
        )
        return

    yield StatusAskResponseItem(
        code=self.status_code.value,
        status=self.status_code.prettify(),
    )

    # Audit the answer (either the plain text or the serialized JSON object)
    if self._object is None:
        audit_answer = self._answer_text.encode("utf-8")
    else:
        audit_answer = json.dumps(self._object.object).encode("utf-8")

    try:
        rephrase_time = self.metrics.elapsed("rephrase")
    except KeyError:
        # Not all ask requests have a rephrase step
        rephrase_time = None

    self.auditor.audit(
        text_answer=audit_answer,
        generative_answer_time=self.metrics.elapsed("stream_predict_answer"),
        generative_answer_first_chunk_time=self.metrics.get_first_chunk_time() or 0,
        rephrase_time=rephrase_time,
        status_code=self.status_code,
    )

    # Stream out the citations
    if self._citations is not None:
        yield CitationsAskResponseItem(citations=self._citations.citations)

    # Stream out generic metadata about the answer
    if self._metadata is not None:
        yield MetadataAskResponseItem(
            tokens=AskTokens(
                input=self._metadata.input_tokens,
                output=self._metadata.output_tokens,
                input_nuclia=self._metadata.input_nuclia_tokens,
                output_nuclia=self._metadata.output_nuclia_tokens,
            ),
            timings=AskTimings(
                generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
                generative_total=self._metadata.timings.get("generative"),
            ),
        )

    # Stream out the relations results (only when requested and on success)
    should_query_relations = (
        self.ask_request_with_relations and self.status_code == AnswerStatusCode.SUCCESS
    )
    if should_query_relations:
        relations = await self.get_relations_results()
        yield RelationsAskResponseItem(relations=relations)

    # Stream out debug information
    if self.ask_request_with_debug_flag:
        predict_request = None
        if self.debug_chat_model:
            predict_request = self.debug_chat_model.model_dump(mode="json")
        yield DebugAskResponseItem(
            metadata={
                "prompt_context": sorted_prompt_context_list(
                    self.prompt_context, self.prompt_context_order
                ),
                "predict_request": predict_request,
            }
        )
+
313
async def json(self) -> str:
    """
    Render the ask result as a single synchronous JSON payload.

    Consumes the whole internal stream first (discarding the items) so that
    all the state (answer text, metadata, citations, status, ...) is
    populated, then serializes it as a SyncAskResponse.
    """
    # First, run the stream in memory to get all the data in memory
    async for _ in self._stream():
        ...

    metadata = None
    if self._metadata is not None:
        metadata = SyncAskMetadata(
            tokens=AskTokens(
                input=self._metadata.input_tokens,
                output=self._metadata.output_tokens,
                input_nuclia=self._metadata.input_nuclia_tokens,
                output_nuclia=self._metadata.output_nuclia_tokens,
            ),
            timings=AskTimings(
                generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
                generative_total=self._metadata.timings.get("generative"),
            ),
        )
    citations = {}
    if self._citations is not None:
        citations = self._citations.citations

    answer_json = None
    if self._object is not None:
        answer_json = self._object.object

    prequeries_results: Optional[dict[str, KnowledgeboxFindResults]] = None
    if self.prequeries_results:
        prequeries_results = {}
        for index, (prequery, result) in enumerate(self.prequeries_results):
            # Fall back to a positional id when the prequery has none
            prequery_id = prequery.id or f"prequery_{index}"
            prequeries_results[prequery_id] = result

    best_matches = [
        AskRetrievalMatch(
            id=match.paragraph.id,
        )
        for match in self.best_matches
    ]

    response = SyncAskResponse(
        answer=self._answer_text,
        answer_json=answer_json,
        status=self.status_code.prettify(),
        relations=self._relations,
        retrieval_results=self.main_results,
        retrieval_best_matches=best_matches,
        prequeries=prequeries_results,
        citations=citations,
        metadata=metadata,
        learning_id=self.nuclia_learning_id or "",
    )
    if self.status_code == AnswerStatusCode.ERROR and self.status_error_details:
        response.error_details = self.status_error_details
    if self.ask_request_with_debug_flag:
        # Debug payload mirrors what DebugAskResponseItem carries in the stream
        sorted_prompt_context = sorted_prompt_context_list(
            self.prompt_context, self.prompt_context_order
        )
        response.prompt_context = sorted_prompt_context
        if self.debug_chat_model:
            response.predict_request = self.debug_chat_model.model_dump(mode="json")
    return response.model_dump_json(exclude_none=True, by_alias=True)
+
377
async def get_relations_results(self) -> Relations:
    """
    Fetch the relations results for the generated answer, caching them on
    the instance so repeated calls only query once.
    """
    if self._relations is None:
        with self.metrics.time("relations"):
            self._relations = await get_relations_results(
                kbid=self.kbid,
                text_answer=self._answer_text,
                target_shard_replicas=self.ask_request.shards,
                timeout=5.0,
            )
    return self._relations
+
388
async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
    """
    Reads the stream of the generative model, yielding the answer text but also parsing
    other items like status codes, citations and miscellaneous metadata.

    This method does not assume any order in the stream of items, but it assumes that at least
    the answer text is streamed in order.
    """
    async for generative_chunk in self.predict_answer_stream:
        item = generative_chunk.chunk
        if isinstance(item, TextGenerativeResponse):
            # Answer text chunks are both accumulated and yielded downstream
            self._answer_text += item.text
            yield item.text
        elif isinstance(item, JSONGenerativeResponse):
            self._object = item
        elif isinstance(item, StatusGenerativeResponse):
            self._status = item
        elif isinstance(item, CitationsGenerativeResponse):
            self._citations = item
        elif isinstance(item, MetaGenerativeResponse):
            self._metadata = item
        else:
            # Unknown chunk types are logged and skipped rather than failing the stream
            logger.warning(
                f"Unexpected item in predict answer stream: {item}",
                extra={"kbid": self.kbid},
            )
+
415
+
416
class NotEnoughContextAskResult(AskResult):
    """
    Degenerate ask result used when the retrieval phase found nothing to
    answer from: it skips the generative model entirely and returns the
    (empty) find results plus a canned "not enough context" answer.
    """

    def __init__(
        self,
        main_results: Optional[KnowledgeboxFindResults] = None,
        prequeries_results: Optional[list[PreQueryResult]] = None,
    ):
        # Intentionally does not call super().__init__: this result only needs
        # the retrieval results and a learning id to render its responses.
        self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
        self.prequeries_results = prequeries_results or []
        self.nuclia_learning_id = None

    async def ndjson_stream(self) -> AsyncGenerator[str, None]:
        """
        In the case where there are no results in the retrieval phase, we simply
        return the find results and the messages indicating that there is not enough
        context in the corpus to answer.
        """
        yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
        yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
        status = AnswerStatusCode.NO_CONTEXT
        yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))

    async def json(self) -> str:
        """Render the not-enough-context result as a sync JSON payload."""
        # Use the prettified status and the same serialization flags as
        # AskResult.json and the ndjson path, so that sync and streaming
        # responses report the status consistently.
        return SyncAskResponse(
            answer=NOT_ENOUGH_CONTEXT_ANSWER,
            retrieval_results=self.main_results,
            status=AnswerStatusCode.NO_CONTEXT.prettify(),
        ).model_dump_json(exclude_none=True, by_alias=True)
+
444
+
445
async def ask(
    *,
    kbid: str,
    ask_request: AskRequest,
    user_id: str,
    client_type: NucliaDBClientType,
    origin: str,
    resource: Optional[str] = None,
) -> AskResult:
    """
    Main entry point of the ask logic: optionally rephrase the query, run the
    retrieval step, build the prompt context and kick off the generative
    answer stream.

    Returns an AskResult wrapping the (not yet consumed) predict answer
    stream, or a NotEnoughContextAskResult when retrieval finds no results.
    """
    metrics = RAGMetrics()
    chat_history = ask_request.context or []
    user_context = ask_request.extra_context or []
    user_query = ask_request.query

    # Maybe rephrase the query
    rephrased_query = None
    if len(chat_history) > 0 or len(user_context) > 0:
        try:
            with metrics.time("rephrase"):
                rephrased_query = await rephrase_query(
                    kbid,
                    chat_history=chat_history,
                    query=user_query,
                    user_id=user_id,
                    user_context=user_context,
                    generative_model=ask_request.generative_model,
                )
        except RephraseMissingContextError:
            # Rephrasing is best-effort: fall back to the original query
            logger.info("Failed to rephrase ask query, using original")

    try:
        retrieval_results = await retrieval_step(
            kbid=kbid,
            # Prefer the rephrased query for retrieval if available
            main_query=rephrased_query or user_query,
            ask_request=ask_request,
            client_type=client_type,
            user_id=user_id,
            origin=origin,
            metrics=metrics,
            resource=resource,
        )
    except NoRetrievalResultsError as err:
        # If a retrieval was attempted but no results were found,
        # early return the ask endpoint without querying the generative model
        return NotEnoughContextAskResult(
            main_results=err.main_query,
            prequeries_results=err.prequeries,
        )

    query_parser = retrieval_results.query_parser

    # Now we build the prompt context
    with metrics.time("context_building"):
        query_parser.max_tokens = ask_request.max_tokens  # type: ignore
        max_tokens_context = await query_parser.get_max_tokens_context()
        prompt_context_builder = PromptContextBuilder(
            kbid=kbid,
            ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
            resource=resource,
            user_context=user_context,
            strategies=ask_request.rag_strategies,
            image_strategies=ask_request.rag_images_strategies,
            max_context_characters=tokens_to_chars(max_tokens_context),
            visual_llm=await query_parser.get_visual_llm_enabled(),
        )
        (
            prompt_context,
            prompt_context_order,
            prompt_context_images,
        ) = await prompt_context_builder.build()

    # Make the chat request to the predict API
    custom_prompt = parse_custom_prompt(ask_request)
    chat_model = ChatModel(
        user_id=user_id,
        system=custom_prompt.system,
        user_prompt=UserPrompt(prompt=custom_prompt.user) if custom_prompt.user else None,
        query_context=prompt_context,
        query_context_order=prompt_context_order,
        chat_history=chat_history,
        question=user_query,
        truncate=True,
        citations=ask_request.citations,
        citation_threshold=ask_request.citation_threshold,
        generative_model=ask_request.generative_model,
        max_tokens=query_parser.get_max_tokens_answer(),
        query_context_images=prompt_context_images,
        json_schema=ask_request.answer_json_schema,
        rerank_context=False,
        top_k=ask_request.top_k,
    )
    with metrics.time("stream_start"):
        predict = get_predict()
        (
            nuclia_learning_id,
            nuclia_learning_model,
            predict_answer_stream,
        ) = await predict.chat_query_ndjson(kbid, chat_model)
        # Kept around so debug responses can echo the exact predict request
        debug_chat_model = chat_model

    auditor = ChatAuditor(
        kbid=kbid,
        user_id=user_id,
        client_type=client_type,
        origin=origin,
        user_query=user_query,
        rephrased_query=rephrased_query,
        chat_history=chat_history,
        learning_id=nuclia_learning_id,
        query_context=prompt_context,
        query_context_order=prompt_context_order,
        model=nuclia_learning_model,
    )
    return AskResult(
        kbid=kbid,
        ask_request=ask_request,
        main_results=retrieval_results.main_query,
        prequeries_results=retrieval_results.prequeries,
        nuclia_learning_id=nuclia_learning_id,
        predict_answer_stream=predict_answer_stream,  # type: ignore
        prompt_context=prompt_context,
        prompt_context_order=prompt_context_order,
        auditor=auditor,
        metrics=metrics,
        best_matches=retrieval_results.best_matches,
        debug_chat_model=debug_chat_model,
    )
+
574
+
575
def handled_ask_exceptions(func):
    """
    Decorator for ask endpoint handlers that maps known domain errors to
    HTTPClientError responses instead of letting them propagate.

    The wrapped coroutine's return value is passed through untouched when no
    exception is raised.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except KnowledgeBoxNotFound:
            return HTTPClientError(
                status_code=404,
                # Fixed: was an f-string with no placeholders
                detail="Knowledge Box not found.",
            )
        except LimitsExceededError as exc:
            return HTTPClientError(status_code=exc.status_code, detail=exc.detail)
        except predict.ProxiedPredictAPIError as err:
            return HTTPClientError(
                status_code=err.status,
                detail=err.detail,
            )
        except IncompleteFindResultsError:
            return HTTPClientError(
                status_code=529,
                detail="Temporary error on information retrieval. Please try again.",
            )
        except predict.RephraseMissingContextError:
            return HTTPClientError(
                status_code=412,
                detail="Unable to rephrase the query with the provided context.",
            )
        except predict.RephraseError as err:
            return HTTPClientError(
                status_code=529,
                detail=f"Temporary error while rephrasing the query. Please try again later. Error: {err}",
            )
        except InvalidQueryError as exc:
            return HTTPClientError(status_code=412, detail=str(exc))

    return wrapper
+
612
+
613
def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
    """
    Return the prequeries strategy from the request's rag strategies, if any.

    Every query without an id gets a positional one ("prequery_<n>"), and all
    prequery ids are validated to be unique.

    Raises:
        InvalidQueryError: if two prequeries end up with the same id.
    """
    for strategy in ask_request.rag_strategies:
        if strategy.name != RagStrategyName.PREQUERIES:
            continue
        prequeries_strategy = cast(PreQueriesStrategy, strategy)
        seen_ids: list[str] = []
        for position, prequery in enumerate(prequeries_strategy.queries):
            # Give each query a unique id if it doesn't have one
            if prequery.id is None:
                prequery.id = f"prequery_{position}"
            if prequery.id in seen_ids:
                raise InvalidQueryError(
                    "rag_strategies",
                    "Prequeries must have unique ids",
                )
            seen_ids.append(prequery.id)
        return prequeries_strategy
    return None
+
631
+
632
async def retrieval_step(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
    resource: Optional[str] = None,
) -> RetrievalResults:
    """
    This function encapsulates all the logic related to retrieval in the ask endpoint.

    Dispatches to a resource-scoped retrieval when a resource id is given,
    and to a knowledge-box-wide retrieval otherwise.
    """
    if resource is not None:
        return await retrieval_in_resource(
            kbid,
            resource,
            main_query,
            ask_request,
            client_type,
            user_id,
            origin,
            metrics,
        )
    return await retrieval_in_kb(
        kbid,
        main_query,
        ask_request,
        client_type,
        user_id,
        origin,
        metrics,
    )
+
667
+
668
async def retrieval_in_kb(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval step over the whole knowledge box, including any
    prequeries declared in the request's rag strategies.

    Raises:
        NoRetrievalResultsError: when neither the main query nor any prequery
            matched any resources.
    """
    prequeries = parse_prequeries(ask_request)
    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )
        # Only bail out when every query (main and prequeries) came back empty
        if len(main_results.resources) == 0 and all(
            len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)

    main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
    best_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
+
708
+
709
async def retrieval_in_resource(
    kbid: str,
    resource: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval step scoped to a single resource.

    Skips retrieval altogether when the full_resource strategy is enabled,
    and mutates ask_request (resource_filters, prequery filters) so every
    query only matches the given resource.

    Raises:
        InvalidQueryError: if a prequery uses prefilter (unsupported here).
        NoRetrievalResultsError: when no query matched any resources.
    """
    if any(strategy.name == "full_resource" for strategy in ask_request.rag_strategies):
        # Retrieval is not needed if we are chatting on a specific resource and the full_resource strategy is enabled
        return RetrievalResults(
            main_query=KnowledgeboxFindResults(resources={}, min_score=None),
            prequeries=None,
            query_parser=QueryParser(
                kbid=kbid,
                features=[],
                query="",
                label_filters=ask_request.filters,
                keyword_filters=ask_request.keyword_filters,
                top_k=0,
                min_score=MinScore(),
            ),
            main_query_weight=1.0,
        )

    prequeries = parse_prequeries(ask_request)
    if prequeries is None and ask_request.answer_json_schema is not None and main_query == "":
        # With a JSON schema and no query, derive one prequery per schema property
        prequeries = calculate_prequeries_for_json_schema(ask_request)

    # Make sure the retrieval is scoped to the resource if provided
    ask_request.resource_filters = [resource]
    if prequeries is not None:
        for prequery in prequeries.queries:
            if prequery.prefilter is True:
                raise InvalidQueryError(
                    "rag_strategies",
                    "Prequeries with prefilter are not supported when asking on a resource",
                )
            prequery.request.resource_filters = [resource]

    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )
        # Only bail out when every query (main and prequeries) came back empty
        if len(main_results.resources) == 0 and all(
            len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)
    main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
    best_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
+
780
+
781
def compute_best_matches(
    main_results: KnowledgeboxFindResults,
    prequeries_results: Optional[list[PreQueryResult]] = None,
    main_query_weight: float = 1.0,
) -> list[RetrievalMatch]:
    """
    Returns the list of matches of the retrieval results, ordered by relevance (descending weighted score).

    If prequeries_results is provided, the paragraphs of the prequeries are weighted according to the
    normalized weight of the prequery. The paragraph score is not modified, but it is used to determine the order in which they
    are presented in the LLM prompt context.

    If a paragraph is matched in various prequeries, the final weighted score is the sum of the weighted scores for each prequery.

    `main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
    """

    def _paragraphs_of(results: KnowledgeboxFindResults):
        # Flatten resources -> fields -> paragraphs
        for resource in results.resources.values():
            for field in resource.fields.values():
                yield from field.paragraphs.values()

    prequeries = prequeries_results or []
    total_weight = main_query_weight + sum(pq.weight for pq, _ in prequeries)

    matches: dict[str, RetrievalMatch] = {}

    main_factor = main_query_weight / total_weight
    for paragraph in _paragraphs_of(main_results):
        matches[paragraph.id] = RetrievalMatch(
            paragraph=paragraph,
            weighted_score=paragraph.score * main_factor,
        )

    for prequery, prequery_results in prequeries:
        prequery_factor = prequery.weight / total_weight
        for paragraph in _paragraphs_of(prequery_results):
            weighted = paragraph.score * prequery_factor
            existing = matches.get(paragraph.id)
            if existing is not None:
                # A paragraph matched by several (pre)queries accumulates
                # the weighted scores of each match
                existing.weighted_score += weighted
            else:
                matches[paragraph.id] = RetrievalMatch(
                    paragraph=paragraph,
                    weighted_score=weighted,
                )

    return sorted(
        matches.values(),
        key=lambda match: match.weighted_score,
        reverse=True,
    )
+
835
+
836
def calculate_prequeries_for_json_schema(
    ask_request: AskRequest,
) -> Optional[PreQueriesStrategy]:
    """
    This function generates a PreQueriesStrategy with a query for each property in the JSON schema
    found in ask_request.answer_json_schema.

    This is useful for the use-case where the user is asking for a structured answer on a corpus
    that is too big to send to the generative model.

    Note: as a side effect, the generated strategy replaces the request's
    rag_strategies list.

    For instance, a JSON schema like this:
    {
        "name": "book_ordering",
        "description": "Structured answer for a book to order",
        "parameters": {
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "The title of the book"
                },
                "author": {
                    "type": "string",
                    "description": "The author of the book"
                },
            },
            "required": ["title", "author"]
        }
    }
    Will generate a PreQueriesStrategy with 2 queries, one for each property in the JSON schema, with equal weights
    [
        PreQuery(request=FindRequest(query="The title of the book", ...), weight=1.0),
        PreQuery(request=FindRequest(query="The author of the book", ...), weight=1.0),
    ]
    """
    prequeries: list[PreQuery] = []
    json_schema = ask_request.answer_json_schema or {}
    # Prequeries reuse the retrieval features requested for the main ask
    features = []
    if ChatOptions.SEMANTIC in ask_request.features:
        features.append(SearchOptions.SEMANTIC)
    if ChatOptions.KEYWORD in ask_request.features:
        features.append(SearchOptions.KEYWORD)

    properties = json_schema.get("parameters", {}).get("properties", {})
    if len(properties) == 0:  # pragma: no cover
        return None
    for prop_name, prop_def in properties.items():
        # The query is the property name, enriched with its description if any
        query = prop_name
        if prop_def.get("description"):
            query += f": {prop_def['description']}"
        req = FindRequest(
            query=query,
            features=features,
            filters=[],
            keyword_filters=[],
            top_k=10,
            min_score=ask_request.min_score,
            vectorset=ask_request.vectorset,
            highlight=False,
            debug=False,
            show=[],
            with_duplicates=False,
            with_synonyms=False,
            resource_filters=[],  # to be filled with the resource filter
            rephrase=ask_request.rephrase,
            rephrase_prompt=parse_rephrase_prompt(ask_request),
            security=ask_request.security,
            autofilter=False,
        )
        prequery = PreQuery(
            request=req,
            weight=1.0,
        )
        prequeries.append(prequery)
    try:
        strategy = PreQueriesStrategy(queries=prequeries)
    except ValidationError:
        # The strategy model bounds how many queries are allowed; too many
        # schema properties would exceed it
        raise AnswerJsonSchemaTooLong(
            "Answer JSON schema with too many properties generated too many prequeries"
        )

    ask_request.rag_strategies = [strategy]
    return strategy