nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2798__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418) hide show
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -403
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +70 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +143 -117
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +37 -128
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2798.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2798.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/zip-safe +0 -0
@@ -17,21 +17,32 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
+ import dataclasses
20
21
  import functools
21
- from time import monotonic as time
22
- from typing import AsyncGenerator, Optional
22
+ import json
23
+ from typing import AsyncGenerator, Optional, cast
23
24
 
24
- from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
25
- from nucliadb.models.responses import HTTPClientError
26
- from nucliadb.search import logger, predict
27
- from nucliadb.search.predict import (
28
- AnswerStatusCode,
25
+ from nuclia_models.predict.generative_responses import (
29
26
  CitationsGenerativeResponse,
30
27
  GenerativeChunk,
28
+ JSONGenerativeResponse,
31
29
  MetaGenerativeResponse,
32
30
  StatusGenerativeResponse,
33
31
  TextGenerativeResponse,
34
32
  )
33
+ from pydantic_core import ValidationError
34
+
35
+ from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
36
+ from nucliadb.models.responses import HTTPClientError
37
+ from nucliadb.search import logger, predict
38
+ from nucliadb.search.predict import (
39
+ AnswerStatusCode,
40
+ RephraseMissingContextError,
41
+ )
42
+ from nucliadb.search.search.chat.exceptions import (
43
+ AnswerJsonSchemaTooLong,
44
+ NoRetrievalResultsError,
45
+ )
35
46
  from nucliadb.search.search.chat.prompt import PromptContextBuilder
36
47
  from nucliadb.search.search.chat.query import (
37
48
  NOT_ENOUGH_CONTEXT_ANSWER,
@@ -46,6 +57,7 @@ from nucliadb.search.search.exceptions import (
46
57
  IncompleteFindResultsError,
47
58
  InvalidQueryError,
48
59
  )
60
+ from nucliadb.search.search.metrics import RAGMetrics
49
61
  from nucliadb.search.search.query import QueryParser
50
62
  from nucliadb.search.utilities import get_predict
51
63
  from nucliadb_models.search import (
@@ -53,6 +65,7 @@ from nucliadb_models.search import (
53
65
  AskRequest,
54
66
  AskResponseItem,
55
67
  AskResponseItemType,
68
+ AskRetrievalMatch,
56
69
  AskTimings,
57
70
  AskTokens,
58
71
  ChatModel,
@@ -60,48 +73,84 @@ from nucliadb_models.search import (
60
73
  CitationsAskResponseItem,
61
74
  DebugAskResponseItem,
62
75
  ErrorAskResponseItem,
76
+ FindParagraph,
77
+ FindRequest,
78
+ JSONAskResponseItem,
63
79
  KnowledgeboxFindResults,
64
80
  MetadataAskResponseItem,
65
81
  MinScore,
66
82
  NucliaDBClientType,
83
+ PrequeriesAskResponseItem,
84
+ PreQueriesStrategy,
85
+ PreQuery,
86
+ PreQueryResult,
67
87
  PromptContext,
68
88
  PromptContextOrder,
89
+ RagStrategyName,
69
90
  Relations,
70
91
  RelationsAskResponseItem,
71
92
  RetrievalAskResponseItem,
93
+ SearchOptions,
72
94
  StatusAskResponseItem,
73
95
  SyncAskMetadata,
74
96
  SyncAskResponse,
75
97
  UserPrompt,
98
+ parse_custom_prompt,
99
+ parse_rephrase_prompt,
76
100
  )
101
+ from nucliadb_telemetry import errors
77
102
  from nucliadb_utils.exceptions import LimitsExceededError
78
103
 
79
104
 
105
+ @dataclasses.dataclass
106
+ class RetrievalMatch:
107
+ paragraph: FindParagraph
108
+ weighted_score: float
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class RetrievalResults:
113
+ main_query: KnowledgeboxFindResults
114
+ query_parser: QueryParser
115
+ main_query_weight: float
116
+ prequeries: Optional[list[PreQueryResult]] = None
117
+ best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
118
+
119
+
80
120
  class AskResult:
81
121
  def __init__(
82
122
  self,
83
123
  *,
84
124
  kbid: str,
85
125
  ask_request: AskRequest,
86
- find_results: KnowledgeboxFindResults,
126
+ main_results: KnowledgeboxFindResults,
127
+ prequeries_results: Optional[list[PreQueryResult]],
87
128
  nuclia_learning_id: Optional[str],
88
129
  predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
89
130
  prompt_context: PromptContext,
90
131
  prompt_context_order: PromptContextOrder,
91
132
  auditor: ChatAuditor,
133
+ metrics: RAGMetrics,
134
+ best_matches: list[RetrievalMatch],
135
+ debug_chat_model: Optional[ChatModel],
92
136
  ):
93
137
  # Initial attributes
94
138
  self.kbid = kbid
95
139
  self.ask_request = ask_request
96
- self.find_results = find_results
140
+ self.main_results = main_results
141
+ self.prequeries_results = prequeries_results or []
97
142
  self.nuclia_learning_id = nuclia_learning_id
98
143
  self.predict_answer_stream = predict_answer_stream
99
144
  self.prompt_context = prompt_context
145
+ self.debug_chat_model = debug_chat_model
100
146
  self.prompt_context_order = prompt_context_order
101
- self.auditor = auditor
147
+ self.auditor: ChatAuditor = auditor
148
+ self.metrics: RAGMetrics = metrics
149
+ self.best_matches: list[RetrievalMatch] = best_matches
102
150
 
103
151
  # Computed from the predict chat answer stream
104
152
  self._answer_text = ""
153
+ self._object: Optional[JSONGenerativeResponse] = None
105
154
  self._status: Optional[StatusGenerativeResponse] = None
106
155
  self._citations: Optional[CitationsGenerativeResponse] = None
107
156
  self._metadata: Optional[MetaGenerativeResponse] = None
@@ -113,6 +162,12 @@ class AskResult:
113
162
  return AnswerStatusCode.SUCCESS
114
163
  return AnswerStatusCode(self._status.code)
115
164
 
165
+ @property
166
+ def status_error_details(self) -> Optional[str]:
167
+ if self._status is None: # pragma: no cover
168
+ return None
169
+ return self._status.details
170
+
116
171
  @property
117
172
  def ask_request_with_relations(self) -> bool:
118
173
  return ChatOptions.RELATIONS in self.ask_request.features
@@ -128,34 +183,89 @@ class AskResult:
128
183
  except Exception as exc:
129
184
  # Handle any unexpected error that might happen
130
185
  # during the streaming and halt the stream
131
- item = ErrorAskResponseItem(error=str(exc))
132
- yield self._ndjson_encode(item)
133
-
134
- staus = AnswerStatusCode.ERROR
135
- item = StatusAskResponseItem(code=staus.value, status=staus.prettify())
186
+ errors.capture_exception(exc)
187
+ logger.error(
188
+ f"Unexpected error while generating the answer: {exc}",
189
+ extra={"kbid": self.kbid},
190
+ )
191
+ error_message = "Unexpected error while generating the answer. Please try again later."
192
+ if self.ask_request_with_debug_flag:
193
+ error_message += f" Error: {exc}"
194
+ item = ErrorAskResponseItem(error=error_message)
136
195
  yield self._ndjson_encode(item)
137
196
  return
138
197
 
139
198
  def _ndjson_encode(self, item: AskResponseItemType) -> str:
140
199
  result_item = AskResponseItem(item=item)
141
- return result_item.json(exclude_unset=False, exclude_none=True) + "\n"
200
+ return result_item.model_dump_json(exclude_none=True, by_alias=True) + "\n"
142
201
 
143
202
  async def _stream(self) -> AsyncGenerator[AskResponseItemType, None]:
144
- # First stream out the find results
145
- yield RetrievalAskResponseItem(results=self.find_results)
203
+ # First, stream out the predict answer
204
+ first_chunk_yielded = False
205
+ with self.metrics.time("stream_predict_answer"):
206
+ async for answer_chunk in self._stream_predict_answer_text():
207
+ yield AnswerAskResponseItem(text=answer_chunk)
208
+ if not first_chunk_yielded:
209
+ self.metrics.record_first_chunk_yielded()
210
+ first_chunk_yielded = True
211
+
212
+ if self._object is not None:
213
+ yield JSONAskResponseItem(object=self._object.object)
214
+ if not first_chunk_yielded:
215
+ # When there is a JSON generative response, we consider the first chunk yielded
216
+ # to be the moment when the JSON object is yielded, not the text
217
+ self.metrics.record_first_chunk_yielded()
218
+ first_chunk_yielded = True
219
+
220
+ yield RetrievalAskResponseItem(
221
+ results=self.main_results,
222
+ best_matches=[
223
+ AskRetrievalMatch(
224
+ id=match.paragraph.id,
225
+ )
226
+ for match in self.best_matches
227
+ ],
228
+ )
146
229
 
147
- # Then stream out the predict answer
148
- async for answer_chunk in self._stream_predict_answer_text():
149
- yield AnswerAskResponseItem(text=answer_chunk)
230
+ if len(self.prequeries_results) > 0:
231
+ item = PrequeriesAskResponseItem()
232
+ for index, (prequery, result) in enumerate(self.prequeries_results):
233
+ prequery_id = prequery.id or f"prequery_{index}"
234
+ item.results[prequery_id] = result
235
+ yield item
236
+
237
+ # Then the status
238
+ if self.status_code == AnswerStatusCode.ERROR:
239
+ # If predict yielded an error status, we yield it too and halt the stream immediately
240
+ yield StatusAskResponseItem(
241
+ code=self.status_code.value,
242
+ status=self.status_code.prettify(),
243
+ details=self.status_error_details or "Unknown error",
244
+ )
245
+ return
150
246
 
151
- # Then the status code
152
247
  yield StatusAskResponseItem(
153
- code=self.status_code.value, status=self.status_code.prettify()
248
+ code=self.status_code.value,
249
+ status=self.status_code.prettify(),
154
250
  )
155
251
 
156
252
  # Audit the answer
157
- await self.auditor.audit(
158
- text_answer=self._answer_text.encode("utf-8"),
253
+ if self._object is None:
254
+ audit_answer = self._answer_text.encode("utf-8")
255
+ else:
256
+ audit_answer = json.dumps(self._object.object).encode("utf-8")
257
+
258
+ try:
259
+ rephrase_time = self.metrics.elapsed("rephrase")
260
+ except KeyError:
261
+ # Not all ask requests have a rephrase step
262
+ rephrase_time = None
263
+
264
+ self.auditor.audit(
265
+ text_answer=audit_answer,
266
+ generative_answer_time=self.metrics.elapsed("stream_predict_answer"),
267
+ generative_answer_first_chunk_time=self.metrics.get_first_chunk_time() or 0,
268
+ rephrase_time=rephrase_time,
159
269
  status_code=self.status_code,
160
270
  )
161
271
 
@@ -163,25 +273,24 @@ class AskResult:
163
273
  if self._citations is not None:
164
274
  yield CitationsAskResponseItem(citations=self._citations.citations)
165
275
 
166
- # Stream out other metadata about the answer if available
276
+ # Stream out generic metadata about the answer
167
277
  if self._metadata is not None:
168
278
  yield MetadataAskResponseItem(
169
279
  tokens=AskTokens(
170
280
  input=self._metadata.input_tokens,
171
281
  output=self._metadata.output_tokens,
282
+ input_nuclia=self._metadata.input_nuclia_tokens,
283
+ output_nuclia=self._metadata.output_nuclia_tokens,
172
284
  ),
173
285
  timings=AskTimings(
174
- generative_first_chunk=self._metadata.timings.get(
175
- "generative_first_chunk"
176
- ),
286
+ generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
177
287
  generative_total=self._metadata.timings.get("generative"),
178
288
  ),
179
289
  )
180
290
 
181
291
  # Stream out the relations results
182
292
  should_query_relations = (
183
- self.ask_request_with_relations
184
- and self.status_code != AnswerStatusCode.NO_CONTEXT
293
+ self.ask_request_with_relations and self.status_code == AnswerStatusCode.SUCCESS
185
294
  )
186
295
  if should_query_relations:
187
296
  relations = await self.get_relations_results()
@@ -189,11 +298,15 @@ class AskResult:
189
298
 
190
299
  # Stream out debug information
191
300
  if self.ask_request_with_debug_flag:
301
+ predict_request = None
302
+ if self.debug_chat_model:
303
+ predict_request = self.debug_chat_model.model_dump(mode="json")
192
304
  yield DebugAskResponseItem(
193
305
  metadata={
194
306
  "prompt_context": sorted_prompt_context_list(
195
307
  self.prompt_context, self.prompt_context_order
196
- )
308
+ ),
309
+ "predict_request": predict_request,
197
310
  }
198
311
  )
199
312
 
@@ -208,40 +321,68 @@ class AskResult:
208
321
  tokens=AskTokens(
209
322
  input=self._metadata.input_tokens,
210
323
  output=self._metadata.output_tokens,
324
+ input_nuclia=self._metadata.input_nuclia_tokens,
325
+ output_nuclia=self._metadata.output_nuclia_tokens,
211
326
  ),
212
327
  timings=AskTimings(
213
- generative_first_chunk=self._metadata.timings.get(
214
- "generative_first_chunk"
215
- ),
328
+ generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
216
329
  generative_total=self._metadata.timings.get("generative"),
217
330
  ),
218
331
  )
219
332
  citations = {}
220
333
  if self._citations is not None:
221
334
  citations = self._citations.citations
335
+
336
+ answer_json = None
337
+ if self._object is not None:
338
+ answer_json = self._object.object
339
+
340
+ prequeries_results: Optional[dict[str, KnowledgeboxFindResults]] = None
341
+ if self.prequeries_results:
342
+ prequeries_results = {}
343
+ for index, (prequery, result) in enumerate(self.prequeries_results):
344
+ prequery_id = prequery.id or f"prequery_{index}"
345
+ prequeries_results[prequery_id] = result
346
+
347
+ best_matches = [
348
+ AskRetrievalMatch(
349
+ id=match.paragraph.id,
350
+ )
351
+ for match in self.best_matches
352
+ ]
353
+
222
354
  response = SyncAskResponse(
223
355
  answer=self._answer_text,
356
+ answer_json=answer_json,
224
357
  status=self.status_code.prettify(),
225
358
  relations=self._relations,
226
- retrieval_results=self.find_results,
359
+ retrieval_results=self.main_results,
360
+ retrieval_best_matches=best_matches,
361
+ prequeries=prequeries_results,
227
362
  citations=citations,
228
363
  metadata=metadata,
229
364
  learning_id=self.nuclia_learning_id or "",
230
365
  )
366
+ if self.status_code == AnswerStatusCode.ERROR and self.status_error_details:
367
+ response.error_details = self.status_error_details
231
368
  if self.ask_request_with_debug_flag:
232
369
  sorted_prompt_context = sorted_prompt_context_list(
233
370
  self.prompt_context, self.prompt_context_order
234
371
  )
235
372
  response.prompt_context = sorted_prompt_context
236
- return response.json(exclude_unset=True)
373
+ if self.debug_chat_model:
374
+ response.predict_request = self.debug_chat_model.model_dump(mode="json")
375
+ return response.model_dump_json(exclude_none=True, by_alias=True)
237
376
 
238
377
  async def get_relations_results(self) -> Relations:
239
378
  if self._relations is None:
240
- self._relations = await get_relations_results(
241
- kbid=self.kbid,
242
- text_answer=self._answer_text,
243
- target_shard_replicas=self.ask_request.shards,
244
- )
379
+ with self.metrics.time("relations"):
380
+ self._relations = await get_relations_results(
381
+ kbid=self.kbid,
382
+ text_answer=self._answer_text,
383
+ target_shard_replicas=self.ask_request.shards,
384
+ timeout=5.0,
385
+ )
245
386
  return self._relations
246
387
 
247
388
  async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
@@ -257,12 +398,12 @@ class AskResult:
257
398
  if isinstance(item, TextGenerativeResponse):
258
399
  self._answer_text += item.text
259
400
  yield item.text
401
+ elif isinstance(item, JSONGenerativeResponse):
402
+ self._object = item
260
403
  elif isinstance(item, StatusGenerativeResponse):
261
404
  self._status = item
262
- continue
263
405
  elif isinstance(item, CitationsGenerativeResponse):
264
406
  self._citations = item
265
- continue
266
407
  elif isinstance(item, MetaGenerativeResponse):
267
408
  self._metadata = item
268
409
  else:
@@ -275,9 +416,11 @@ class AskResult:
275
416
  class NotEnoughContextAskResult(AskResult):
276
417
  def __init__(
277
418
  self,
278
- find_results: KnowledgeboxFindResults,
419
+ main_results: Optional[KnowledgeboxFindResults] = None,
420
+ prequeries_results: Optional[list[PreQueryResult]] = None,
279
421
  ):
280
- self.find_results = find_results
422
+ self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
423
+ self.prequeries_results = prequeries_results or []
281
424
  self.nuclia_learning_id = None
282
425
 
283
426
  async def ndjson_stream(self) -> AsyncGenerator[str, None]:
@@ -286,19 +429,17 @@ class NotEnoughContextAskResult(AskResult):
286
429
  return the find results and the messages indicating that there is not enough
287
430
  context in the corpus to answer.
288
431
  """
289
- yield self._ndjson_encode(RetrievalAskResponseItem(results=self.find_results))
432
+ yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
290
433
  yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
291
434
  status = AnswerStatusCode.NO_CONTEXT
292
- yield self._ndjson_encode(
293
- StatusAskResponseItem(code=status.value, status=status.prettify())
294
- )
435
+ yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))
295
436
 
296
437
  async def json(self) -> str:
297
438
  return SyncAskResponse(
298
439
  answer=NOT_ENOUGH_CONTEXT_ANSWER,
299
- retrieval_results=self.find_results,
440
+ retrieval_results=self.main_results,
300
441
  status=AnswerStatusCode.NO_CONTEXT,
301
- ).json(exclude_unset=True)
442
+ ).model_dump_json()
302
443
 
303
444
 
304
445
  async def ask(
@@ -310,7 +451,7 @@ async def ask(
310
451
  origin: str,
311
452
  resource: Optional[str] = None,
312
453
  ) -> AskResult:
313
- start_time = time()
454
+ metrics = RAGMetrics()
314
455
  chat_history = ask_request.context or []
315
456
  user_context = ask_request.extra_context or []
316
457
  user_query = ask_request.query
@@ -318,117 +459,116 @@ async def ask(
318
459
  # Maybe rephrase the query
319
460
  rephrased_query = None
320
461
  if len(chat_history) > 0 or len(user_context) > 0:
321
- rephrased_query = await rephrase_query(
322
- kbid,
323
- chat_history=chat_history,
324
- query=user_query,
325
- user_id=user_id,
326
- user_context=user_context,
327
- generative_model=ask_request.generative_model,
328
- )
329
-
330
- # Retrieval is not needed if we are chatting on a specific
331
- # resource and the full_resource strategy is enabled
332
- needs_retrieval = True
333
- if resource is not None:
334
- ask_request.resource_filters = [resource]
335
- if any(
336
- strategy.name == "full_resource" for strategy in ask_request.rag_strategies
337
- ):
338
- needs_retrieval = False
462
+ try:
463
+ with metrics.time("rephrase"):
464
+ rephrased_query = await rephrase_query(
465
+ kbid,
466
+ chat_history=chat_history,
467
+ query=user_query,
468
+ user_id=user_id,
469
+ user_context=user_context,
470
+ generative_model=ask_request.generative_model,
471
+ )
472
+ except RephraseMissingContextError:
473
+ logger.info("Failed to rephrase ask query, using original")
339
474
 
340
- # Maybe do a retrieval query
341
- if needs_retrieval:
342
- find_results, query_parser = await get_find_results(
475
+ try:
476
+ retrieval_results = await retrieval_step(
343
477
  kbid=kbid,
344
- # Prefer the rephrased query if available
345
- query=rephrased_query or user_query,
346
- chat_request=ask_request,
347
- ndb_client=client_type,
348
- user=user_id,
478
+ # Prefer the rephrased query for retrieval if available
479
+ main_query=rephrased_query or user_query,
480
+ ask_request=ask_request,
481
+ client_type=client_type,
482
+ user_id=user_id,
349
483
  origin=origin,
484
+ metrics=metrics,
485
+ resource=resource,
350
486
  )
351
- if len(find_results.resources) == 0:
352
- return NotEnoughContextAskResult(find_results=find_results)
353
-
354
- else:
355
- find_results = KnowledgeboxFindResults(resources={}, min_score=None)
356
- query_parser = QueryParser(
357
- kbid=kbid,
358
- features=[],
359
- query="",
360
- filters=ask_request.filters,
361
- page_number=0,
362
- page_size=0,
363
- min_score=MinScore(),
487
+ except NoRetrievalResultsError as err:
488
+ # If a retrieval was attempted but no results were found,
489
+ # early return the ask endpoint without querying the generative model
490
+ return NotEnoughContextAskResult(
491
+ main_results=err.main_query,
492
+ prequeries_results=err.prequeries,
364
493
  )
365
494
 
366
- # Now we build the prompt context
367
- query_parser.max_tokens = ask_request.max_tokens # type: ignore
368
- max_tokens_context = await query_parser.get_max_tokens_context()
369
- prompt_context_builder = PromptContextBuilder(
370
- kbid=kbid,
371
- find_results=find_results,
372
- resource=resource,
373
- user_context=user_context,
374
- strategies=ask_request.rag_strategies,
375
- image_strategies=ask_request.rag_images_strategies,
376
- max_context_characters=tokens_to_chars(max_tokens_context),
377
- visual_llm=await query_parser.get_visual_llm_enabled(),
378
- )
379
- (
380
- prompt_context,
381
- prompt_context_order,
382
- prompt_context_images,
383
- ) = await prompt_context_builder.build()
495
+ query_parser = retrieval_results.query_parser
384
496
 
385
- # Parse the user prompt (if any)
386
- user_prompt = None
387
- if ask_request.prompt is not None:
388
- user_prompt = UserPrompt(prompt=ask_request.prompt)
497
+ # Now we build the prompt context
498
+ with metrics.time("context_building"):
499
+ query_parser.max_tokens = ask_request.max_tokens # type: ignore
500
+ max_tokens_context = await query_parser.get_max_tokens_context()
501
+ prompt_context_builder = PromptContextBuilder(
502
+ kbid=kbid,
503
+ ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
504
+ resource=resource,
505
+ user_context=user_context,
506
+ strategies=ask_request.rag_strategies,
507
+ image_strategies=ask_request.rag_images_strategies,
508
+ max_context_characters=tokens_to_chars(max_tokens_context),
509
+ visual_llm=await query_parser.get_visual_llm_enabled(),
510
+ )
511
+ (
512
+ prompt_context,
513
+ prompt_context_order,
514
+ prompt_context_images,
515
+ ) = await prompt_context_builder.build()
389
516
 
390
517
  # Make the chat request to the predict API
518
+ custom_prompt = parse_custom_prompt(ask_request)
391
519
  chat_model = ChatModel(
392
520
  user_id=user_id,
521
+ system=custom_prompt.system,
522
+ user_prompt=UserPrompt(prompt=custom_prompt.user) if custom_prompt.user else None,
393
523
  query_context=prompt_context,
394
524
  query_context_order=prompt_context_order,
395
525
  chat_history=chat_history,
396
526
  question=user_query,
397
527
  truncate=True,
398
- user_prompt=user_prompt,
399
528
  citations=ask_request.citations,
529
+ citation_threshold=ask_request.citation_threshold,
400
530
  generative_model=ask_request.generative_model,
401
531
  max_tokens=query_parser.get_max_tokens_answer(),
402
532
  query_context_images=prompt_context_images,
533
+ json_schema=ask_request.answer_json_schema,
534
+ rerank_context=False,
535
+ top_k=ask_request.top_k,
403
536
  )
404
- predict = get_predict()
405
- nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(
406
- kbid, chat_model
407
- )
537
+ with metrics.time("stream_start"):
538
+ predict = get_predict()
539
+ (
540
+ nuclia_learning_id,
541
+ nuclia_learning_model,
542
+ predict_answer_stream,
543
+ ) = await predict.chat_query_ndjson(kbid, chat_model)
544
+ debug_chat_model = chat_model
408
545
 
409
546
  auditor = ChatAuditor(
410
547
  kbid=kbid,
411
548
  user_id=user_id,
412
549
  client_type=client_type,
413
550
  origin=origin,
414
- start_time=start_time,
415
551
  user_query=user_query,
416
552
  rephrased_query=rephrased_query,
417
553
  chat_history=chat_history,
418
554
  learning_id=nuclia_learning_id,
419
555
  query_context=prompt_context,
420
556
  query_context_order=prompt_context_order,
557
+ model=nuclia_learning_model,
421
558
  )
422
-
423
559
  return AskResult(
424
560
  kbid=kbid,
425
561
  ask_request=ask_request,
426
- find_results=find_results,
562
+ main_results=retrieval_results.main_query,
563
+ prequeries_results=retrieval_results.prequeries,
427
564
  nuclia_learning_id=nuclia_learning_id,
428
- predict_answer_stream=predict_answer_stream,
565
+ predict_answer_stream=predict_answer_stream, # type: ignore
429
566
  prompt_context=prompt_context,
430
567
  prompt_context_order=prompt_context_order,
431
568
  auditor=auditor,
569
+ metrics=metrics,
570
+ best_matches=retrieval_results.best_matches,
571
+ debug_chat_model=debug_chat_model,
432
572
  )
433
573
 
434
574
 
@@ -468,3 +608,311 @@ def handled_ask_exceptions(func):
468
608
  return HTTPClientError(status_code=412, detail=str(exc))
469
609
 
470
610
  return wrapper
611
+
612
+
613
def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
    """
    Return the first prequeries rag strategy of the ask request, or None if there is none.

    Queries that come without an id are assigned `prequery_<position>` in place.
    An InvalidQueryError is raised as soon as two queries of the strategy share an id.
    """
    strategy = next(
        (
            candidate
            for candidate in ask_request.rag_strategies
            if candidate.name == RagStrategyName.PREQUERIES
        ),
        None,
    )
    if strategy is None:
        return None

    prequeries = cast(PreQueriesStrategy, strategy)
    seen_ids = set()
    for position, prequery in enumerate(prequeries.queries):
        # Queries without an explicit id get one derived from their position
        if prequery.id is None:
            prequery.id = f"prequery_{position}"
        if prequery.id in seen_ids:
            raise InvalidQueryError(
                "rag_strategies",
                "Prequeries must have unique ids",
            )
        seen_ids.add(prequery.id)
    return prequeries
630
+
631
+
632
async def retrieval_step(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
    resource: Optional[str] = None,
) -> RetrievalResults:
    """
    Single entry point for the retrieval logic of the ask endpoint.

    Delegates to `retrieval_in_resource` when a resource is given and to
    `retrieval_in_kb` otherwise, forwarding the remaining arguments unchanged.
    """
    # Arguments shared by both retrieval flavours, in the positional order they expect
    shared_args = (main_query, ask_request, client_type, user_id, origin, metrics)
    if resource is not None:
        return await retrieval_in_resource(kbid, resource, *shared_args)
    return await retrieval_in_kb(kbid, *shared_args)
666
+
667
+
668
async def retrieval_in_kb(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the main query (plus the prequeries of the request, if any) and rank the matches.

    Raises NoRetrievalResultsError when neither the main query nor any prequery
    returned a resource.
    """
    prequeries_strategy = parse_prequeries(ask_request)

    with metrics.time("retrieval"):
        find_output = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries_strategy,
        )
        main_results, prequeries_results, query_parser = find_output

        main_is_empty = len(main_results.resources) == 0
        if main_is_empty and not any(
            len(found.resources) > 0 for (_, found) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)

    # Without a prequeries strategy the main query carries all the weight
    if prequeries_strategy is None:
        main_query_weight = 1.0
    else:
        main_query_weight = prequeries_strategy.main_query_weight

    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=compute_best_matches(
            main_results=main_results,
            prequeries_results=prequeries_results,
            main_query_weight=main_query_weight,
        ),
    )
707
+
708
+
709
async def retrieval_in_resource(
    kbid: str,
    resource: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Retrieval for an ask request that targets a single resource.

    - With the full_resource strategy no find query is issued and empty results are returned.
    - Otherwise the main query and every prequery are restricted to `resource` through
      their `resource_filters` (the ask request is modified in place).
    - When the request has no prequeries, an empty main query and an answer JSON schema,
      prequeries are derived from the schema.

    Raises InvalidQueryError if a prequery uses prefilter, and NoRetrievalResultsError
    when neither the main query nor any prequery returned a resource.
    """
    uses_full_resource = any(
        strategy.name == "full_resource" for strategy in ask_request.rag_strategies
    )
    if uses_full_resource:
        # Retrieval is not needed if we are chatting on a specific resource and the full_resource strategy is enabled
        empty_results = KnowledgeboxFindResults(resources={}, min_score=None)
        placeholder_parser = QueryParser(
            kbid=kbid,
            features=[],
            query="",
            label_filters=ask_request.filters,
            keyword_filters=ask_request.keyword_filters,
            top_k=0,
            min_score=MinScore(),
        )
        return RetrievalResults(
            main_query=empty_results,
            prequeries=None,
            query_parser=placeholder_parser,
            main_query_weight=1.0,
        )

    prequeries_strategy = parse_prequeries(ask_request)
    if (
        prequeries_strategy is None
        and ask_request.answer_json_schema is not None
        and main_query == ""
    ):
        prequeries_strategy = calculate_prequeries_for_json_schema(ask_request)

    # Make sure the retrieval is scoped to the resource if provided
    ask_request.resource_filters = [resource]
    scoped_queries = prequeries_strategy.queries if prequeries_strategy is not None else []
    for prequery in scoped_queries:
        if prequery.prefilter is True:
            raise InvalidQueryError(
                "rag_strategies",
                "Prequeries with prefilter are not supported when asking on a resource",
            )
        prequery.request.resource_filters = [resource]

    with metrics.time("retrieval"):
        find_output = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries_strategy,
        )
        main_results, prequeries_results, query_parser = find_output

        main_is_empty = len(main_results.resources) == 0
        if main_is_empty and not any(
            len(found.resources) > 0 for (_, found) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)

    # Without a prequeries strategy the main query carries all the weight
    if prequeries_strategy is None:
        main_query_weight = 1.0
    else:
        main_query_weight = prequeries_strategy.main_query_weight

    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=compute_best_matches(
            main_results=main_results,
            prequeries_results=prequeries_results,
            main_query_weight=main_query_weight,
        ),
    )
779
+
780
+
781
def compute_best_matches(
    main_results: KnowledgeboxFindResults,
    prequeries_results: Optional[list[PreQueryResult]] = None,
    main_query_weight: float = 1.0,
) -> list[RetrievalMatch]:
    """
    Returns the list of matches of the retrieval results, ordered by relevance (descending weighted score).

    Every paragraph score is multiplied by the normalized weight of the query that matched it
    (its weight divided by the sum of `main_query_weight` and all prequery weights). The paragraph
    score itself is not modified; the weighted score is only used to order the matches.

    If a paragraph is matched by several queries (main query and/or prequeries), its final
    weighted score is the sum of the weighted scores obtained for each of them.

    `main_query_weight` is the weight given to the paragraphs matching the main query.

    If all the weights add up to zero, normalization is skipped instead of dividing by zero.
    """

    def iter_paragraphs(results: KnowledgeboxFindResults):
        for resource in results.resources.values():
            for field in resource.fields.values():
                yield from field.paragraphs.values()

    weighted_results = prequeries_results or []
    total_weights = main_query_weight + sum(prequery.weight for prequery, _ in weighted_results)
    if total_weights == 0:
        # Nothing to normalize against: avoid a ZeroDivisionError and keep the raw weights
        total_weights = 1.0

    paragraph_id_to_match: dict[str, RetrievalMatch] = {}

    main_normalized_weight = main_query_weight / total_weights
    for paragraph in iter_paragraphs(main_results):
        paragraph_id_to_match[paragraph.id] = RetrievalMatch(
            paragraph=paragraph,
            weighted_score=paragraph.score * main_normalized_weight,
        )

    for prequery, prequery_results in weighted_results:
        normalized_weight = prequery.weight / total_weights
        for paragraph in iter_paragraphs(prequery_results):
            weighted_score = paragraph.score * normalized_weight
            if paragraph.id in paragraph_id_to_match:
                # Already matched by another query: accumulate the weighted scores
                paragraph_id_to_match[paragraph.id].weighted_score += weighted_score
            else:
                paragraph_id_to_match[paragraph.id] = RetrievalMatch(
                    paragraph=paragraph,
                    weighted_score=weighted_score,
                )

    return sorted(
        paragraph_id_to_match.values(),
        key=lambda match: match.weighted_score,
        reverse=True,
    )
834
+
835
+
836
def calculate_prequeries_for_json_schema(
    ask_request: AskRequest,
) -> Optional[PreQueriesStrategy]:
    """
    This function generates a PreQueriesStrategy with a query for each property in the JSON schema
    found in ask_request.answer_json_schema.

    This is useful for the use-case where the user is asking for a structured answer on a corpus
    that is too big to send to the generative model.

    For instance, a JSON schema like this:
    {
        "name": "book_ordering",
        "description": "Structured answer for a book to order",
        "parameters": {
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "The title of the book"
                },
                "author": {
                    "type": "string",
                    "description": "The author of the book"
                },
            },
            "required": ["title", "author"]
        }
    }
    Will generate a PreQueriesStrategy with 2 queries, one for each property in the JSON schema, with equal weights.
    Each query is the property name, followed by its description when there is one:
    [
        PreQuery(request=FindRequest(query="title: The title of the book", ...), weight=1.0),
        PreQuery(request=FindRequest(query="author: The author of the book", ...), weight=1.0),
    ]

    Returns None when the schema declares no properties under `parameters.properties`.

    Side effect: on success, `ask_request.rag_strategies` is replaced by a list holding only
    the generated strategy.

    Raises AnswerJsonSchemaTooLong when the generated prequeries do not pass the
    PreQueriesStrategy validation.
    """
    json_schema = ask_request.answer_json_schema or {}

    # The schema is user input: tolerate null or non-object values instead of failing with
    # an AttributeError, and bail out before doing any other work.
    parameters = json_schema.get("parameters")
    properties = parameters.get("properties") if isinstance(parameters, dict) else None
    if not isinstance(properties, dict) or len(properties) == 0:  # pragma: no cover
        return None

    features = []
    if ChatOptions.SEMANTIC in ask_request.features:
        features.append(SearchOptions.SEMANTIC)
    if ChatOptions.KEYWORD in ask_request.features:
        features.append(SearchOptions.KEYWORD)

    prequeries: list[PreQuery] = []
    for prop_name, prop_def in properties.items():
        query = prop_name
        description = prop_def.get("description") if isinstance(prop_def, dict) else None
        if description:
            query += f": {description}"
        req = FindRequest(
            query=query,
            features=features,
            filters=[],
            keyword_filters=[],
            top_k=10,
            min_score=ask_request.min_score,
            vectorset=ask_request.vectorset,
            highlight=False,
            debug=False,
            show=[],
            with_duplicates=False,
            with_synonyms=False,
            resource_filters=[],  # to be filled with the resource filter
            rephrase=ask_request.rephrase,
            rephrase_prompt=parse_rephrase_prompt(ask_request),
            security=ask_request.security,
            autofilter=False,
        )
        prequeries.append(PreQuery(request=req, weight=1.0))

    try:
        strategy = PreQueriesStrategy(queries=prequeries)
    except ValidationError as err:
        raise AnswerJsonSchemaTooLong(
            "Answer JSON schema with too many properties generated too many prequeries"
        ) from err

    # NOTE(review): this drops any other rag strategy present in the request — confirm it is intended
    ask_request.rag_strategies = [strategy]
    return strategy