nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2798__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418) hide show
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -403
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +70 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +143 -117
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +37 -128
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2798.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2798.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/zip-safe +0 -0
@@ -18,68 +18,43 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from dataclasses import dataclass
22
- from time import monotonic as time
23
- from typing import AsyncGenerator, AsyncIterator, Optional
24
-
25
- from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
21
+ from typing import Optional
26
22
 
23
+ from nucliadb.common.models_utils import to_proto
27
24
  from nucliadb.search import logger
28
25
  from nucliadb.search.predict import AnswerStatusCode
29
26
  from nucliadb.search.requesters.utils import Method, node_query
30
- from nucliadb.search.search.chat.prompt import PromptContextBuilder
27
+ from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
31
28
  from nucliadb.search.search.exceptions import IncompleteFindResultsError
32
29
  from nucliadb.search.search.find import find
33
30
  from nucliadb.search.search.merge import merge_relations_results
31
+ from nucliadb.search.search.metrics import RAGMetrics
34
32
  from nucliadb.search.search.query import QueryParser
33
+ from nucliadb.search.settings import settings
35
34
  from nucliadb.search.utilities import get_predict
36
35
  from nucliadb_models.search import (
37
- Author,
36
+ AskRequest,
38
37
  ChatContextMessage,
39
- ChatModel,
40
38
  ChatOptions,
41
- ChatRequest,
42
39
  FindRequest,
43
40
  KnowledgeboxFindResults,
44
- MinScore,
45
41
  NucliaDBClientType,
42
+ PreQueriesStrategy,
43
+ PreQuery,
44
+ PreQueryResult,
46
45
  PromptContext,
47
46
  PromptContextOrder,
48
47
  Relations,
49
48
  RephraseModel,
50
49
  SearchOptions,
51
- UserPrompt,
50
+ parse_rephrase_prompt,
52
51
  )
53
52
  from nucliadb_protos import audit_pb2
53
+ from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
54
54
  from nucliadb_telemetry.errors import capture_exception
55
- from nucliadb_utils.helpers import async_gen_lookahead
56
55
  from nucliadb_utils.utilities import get_audit
57
56
 
58
57
  NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this."
59
- AUDIT_TEXT_RESULT_SEP = " \n\n "
60
- START_OF_CITATIONS = b"_CIT_"
61
-
62
-
63
- class FoundStatusCode:
64
- def __init__(self, default: AnswerStatusCode = AnswerStatusCode.SUCCESS):
65
- self._value = AnswerStatusCode.SUCCESS
66
-
67
- def set(self, value: AnswerStatusCode) -> None:
68
- self._value = value
69
-
70
- @property
71
- def value(self) -> AnswerStatusCode:
72
- return self._value
73
-
74
-
75
- @dataclass
76
- class ChatResult:
77
- nuclia_learning_id: Optional[str]
78
- answer_stream: AsyncIterator[bytes]
79
- status_code: FoundStatusCode
80
- find_results: KnowledgeboxFindResults
81
- prompt_context: PromptContext
82
- prompt_context_order: PromptContextOrder
83
58
 
84
59
 
85
60
  async def rephrase_query(
@@ -101,70 +76,120 @@ async def rephrase_query(
101
76
  return await predict.rephrase_query(kbid, req)
102
77
 
103
78
 
104
- async def format_generated_answer(
105
- answer_generator: AsyncGenerator[bytes, None], output_status_code: FoundStatusCode
106
- ):
107
- status_code: Optional[AnswerStatusCode] = None
108
- is_last_chunk = False
109
- async for answer_chunk, is_last_chunk in async_gen_lookahead(answer_generator):
110
- if is_last_chunk:
111
- try:
112
- status_code = _parse_answer_status_code(answer_chunk)
113
- except ValueError:
114
- # TODO: remove this in the future, it's
115
- # just for bw compatibility until predict
116
- # is updated to the new protocol
117
- status_code = AnswerStatusCode.SUCCESS
118
- yield answer_chunk
119
- else:
120
- # TODO: this should be needed but, in case we receive the status
121
- # code mixed with text, we strip it and return the text
122
- if len(answer_chunk) != len(status_code.encode()):
123
- answer_chunk = answer_chunk.rstrip(status_code.encode())
124
- yield answer_chunk
125
- break
126
- yield answer_chunk
127
- if not is_last_chunk:
128
- logger.warning("BUG: /chat endpoint without last chunk")
129
-
130
- output_status_code.set(status_code or AnswerStatusCode.SUCCESS)
131
-
132
-
133
79
  async def get_find_results(
134
80
  *,
135
81
  kbid: str,
136
82
  query: str,
137
- chat_request: ChatRequest,
83
+ item: AskRequest,
84
+ ndb_client: NucliaDBClientType,
85
+ user: str,
86
+ origin: str,
87
+ metrics: RAGMetrics = RAGMetrics(),
88
+ prequeries_strategy: Optional[PreQueriesStrategy] = None,
89
+ ) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], QueryParser]:
90
+ prequeries_results = None
91
+ prefilter_queries_results = None
92
+ queries_results = None
93
+ if prequeries_strategy is not None:
94
+ prefilters = [prequery for prequery in prequeries_strategy.queries if prequery.prefilter]
95
+ prequeries = [prequery for prequery in prequeries_strategy.queries if not prequery.prefilter]
96
+ if len(prefilters) > 0:
97
+ with metrics.time("prefilters"):
98
+ prefilter_queries_results = await run_prequeries(
99
+ kbid,
100
+ prefilters,
101
+ x_ndb_client=ndb_client,
102
+ x_nucliadb_user=user,
103
+ x_forwarded_for=origin,
104
+ generative_model=item.generative_model,
105
+ metrics=metrics,
106
+ )
107
+ prefilter_matching_resources = {
108
+ resource
109
+ for _, find_results in prefilter_queries_results
110
+ for resource in find_results.resources.keys()
111
+ }
112
+ if len(prefilter_matching_resources) == 0:
113
+ raise NoRetrievalResultsError()
114
+ # Make sure the main query and prequeries use the same resource filters.
115
+ # This is important to avoid returning results that don't match the prefilter.
116
+ item.resource_filters = list(prefilter_matching_resources)
117
+ for prequery in prequeries:
118
+ prequery.request.resource_filters = list(prefilter_matching_resources)
119
+ prequery.request.show_hidden = item.show_hidden
120
+
121
+ if prequeries:
122
+ with metrics.time("prequeries"):
123
+ queries_results = await run_prequeries(
124
+ kbid,
125
+ prequeries,
126
+ x_ndb_client=ndb_client,
127
+ x_nucliadb_user=user,
128
+ x_forwarded_for=origin,
129
+ generative_model=item.generative_model,
130
+ metrics=metrics,
131
+ )
132
+
133
+ prequeries_results = (prefilter_queries_results or []) + (queries_results or [])
134
+
135
+ with metrics.time("main_query"):
136
+ main_results, query_parser = await run_main_query(
137
+ kbid,
138
+ query,
139
+ item,
140
+ ndb_client,
141
+ user,
142
+ origin,
143
+ metrics=metrics,
144
+ )
145
+ return main_results, prequeries_results, query_parser
146
+
147
+
148
+ async def run_main_query(
149
+ kbid: str,
150
+ query: str,
151
+ item: AskRequest,
138
152
  ndb_client: NucliaDBClientType,
139
153
  user: str,
140
154
  origin: str,
155
+ metrics: RAGMetrics = RAGMetrics(),
141
156
  ) -> tuple[KnowledgeboxFindResults, QueryParser]:
142
157
  find_request = FindRequest()
143
- find_request.resource_filters = chat_request.resource_filters
158
+ find_request.resource_filters = item.resource_filters
144
159
  find_request.features = []
145
- if ChatOptions.VECTORS in chat_request.features:
146
- find_request.features.append(SearchOptions.VECTOR)
147
- if ChatOptions.PARAGRAPHS in chat_request.features:
148
- find_request.features.append(SearchOptions.PARAGRAPH)
149
- if ChatOptions.RELATIONS in chat_request.features:
160
+ if ChatOptions.SEMANTIC in item.features:
161
+ find_request.features.append(SearchOptions.SEMANTIC)
162
+ if ChatOptions.KEYWORD in item.features:
163
+ find_request.features.append(SearchOptions.KEYWORD)
164
+ if ChatOptions.RELATIONS in item.features:
150
165
  find_request.features.append(SearchOptions.RELATIONS)
151
166
  find_request.query = query
152
- find_request.fields = chat_request.fields
153
- find_request.filters = chat_request.filters
154
- find_request.field_type_filter = chat_request.field_type_filter
155
- find_request.min_score = chat_request.min_score
156
- find_request.range_creation_start = chat_request.range_creation_start
157
- find_request.range_creation_end = chat_request.range_creation_end
158
- find_request.range_modification_start = chat_request.range_modification_start
159
- find_request.range_modification_end = chat_request.range_modification_end
160
- find_request.show = chat_request.show
161
- find_request.extracted = chat_request.extracted
162
- find_request.shards = chat_request.shards
163
- find_request.autofilter = chat_request.autofilter
164
- find_request.highlight = chat_request.highlight
165
- find_request.security = chat_request.security
166
- find_request.debug = chat_request.debug
167
- find_request.rephrase = chat_request.rephrase
167
+ find_request.fields = item.fields
168
+ find_request.filters = item.filters
169
+ find_request.field_type_filter = item.field_type_filter
170
+ find_request.min_score = item.min_score
171
+ find_request.vectorset = item.vectorset
172
+ find_request.range_creation_start = item.range_creation_start
173
+ find_request.range_creation_end = item.range_creation_end
174
+ find_request.range_modification_start = item.range_modification_start
175
+ find_request.range_modification_end = item.range_modification_end
176
+ find_request.show = item.show
177
+ find_request.extracted = item.extracted
178
+ find_request.shards = item.shards
179
+ find_request.autofilter = item.autofilter
180
+ find_request.highlight = item.highlight
181
+ find_request.security = item.security
182
+ find_request.debug = item.debug
183
+ find_request.rephrase = item.rephrase
184
+ find_request.rephrase_prompt = parse_rephrase_prompt(item)
185
+ find_request.rank_fusion = item.rank_fusion
186
+ find_request.reranker = item.reranker
187
+ # We don't support pagination, we always get the top_k results.
188
+ find_request.top_k = item.top_k
189
+ find_request.show_hidden = item.show_hidden
190
+
191
+ # this executes the model validators, that can tweak some fields
192
+ FindRequest.model_validate(find_request)
168
193
 
169
194
  find_results, incomplete, query_parser = await find(
170
195
  kbid,
@@ -172,7 +197,8 @@ async def get_find_results(
172
197
  ndb_client,
173
198
  user,
174
199
  origin,
175
- generative_model=chat_request.generative_model,
200
+ generative_model=item.generative_model,
201
+ metrics=metrics,
176
202
  )
177
203
  if incomplete:
178
204
  raise IncompleteFindResultsError()
@@ -180,269 +206,100 @@ async def get_find_results(
180
206
 
181
207
 
182
208
async def get_relations_results(
    *,
    kbid: str,
    text_answer: str,
    target_shard_replicas: Optional[list[str]],
    timeout: Optional[float] = None,
) -> Relations:
    """
    Fetch the relations around the entities detected in `text_answer`.

    The entities returned by the predict API are used as entry points of a
    depth-1 relation subgraph, which is sent to the index nodes as a regular
    search request. The relation part of every response is then merged into
    a single `Relations` object.

    This is best-effort: any error is reported and logged, and an empty
    `Relations` is returned instead of propagating the exception.
    """
    try:
        entities = await get_predict().detect_entities(kbid, text_answer)

        search_request = SearchRequest()
        search_request.relation_subgraph.entry_points.extend(entities)
        search_request.relation_subgraph.depth = 1

        # node_query returns a 3-tuple; only the per-shard responses are needed here
        shard_responses: list[SearchResponse]
        shard_responses, _, _ = await node_query(
            kbid,
            Method.SEARCH,
            search_request,
            target_shard_replicas=target_shard_replicas,
            timeout=timeout,
            use_read_replica_nodes=True,
            retry_on_primary=False,
        )

        # Keep only the relation section of each search response
        relation_responses: list[RelationSearchResponse] = [
            response.relation for response in shard_responses
        ]
        return await merge_relations_results(relation_responses, search_request.relation_subgraph)
    except Exception as exc:
        capture_exception(exc)
        logger.exception("Error getting relations results")
        return Relations(entities={})
210
242
 
211
243
 
212
- async def not_enough_context_generator():
213
- await asyncio.sleep(0)
214
- yield NOT_ENOUGH_CONTEXT_ANSWER.encode()
215
- yield AnswerStatusCode.NO_CONTEXT.encode()
216
-
217
-
218
- async def chat(
219
- kbid: str,
220
- chat_request: ChatRequest,
221
- user_id: str,
222
- client_type: NucliaDBClientType,
223
- origin: str,
224
- resource: Optional[str] = None,
225
- ) -> ChatResult:
226
- start_time = time()
227
- nuclia_learning_id: Optional[str] = None
228
- chat_history = chat_request.context or []
229
- user_context = chat_request.extra_context or []
230
- user_query = chat_request.query
231
- rephrased_query = None
232
- prompt_context: PromptContext = {}
233
- prompt_context_order: PromptContextOrder = {}
234
-
235
- if len(chat_history) > 0 or len(user_context) > 0:
236
- rephrased_query = await rephrase_query(
237
- kbid,
238
- chat_history=chat_history,
239
- query=user_query,
240
- user_id=user_id,
241
- user_context=user_context,
242
- generative_model=chat_request.generative_model,
243
- )
244
-
245
- # Retrieval is not needed if we are chatting on a specific
246
- # resource and the full_resource strategy is enabled
247
- needs_retrieval = True
248
- if resource is not None:
249
- chat_request.resource_filters = [resource]
250
- if any(
251
- strategy.name == "full_resource" for strategy in chat_request.rag_strategies
252
- ):
253
- needs_retrieval = False
254
-
255
- if needs_retrieval:
256
- find_results, query_parser = await get_find_results(
257
- kbid=kbid,
258
- query=rephrased_query or user_query,
259
- chat_request=chat_request,
260
- ndb_client=client_type,
261
- user=user_id,
262
- origin=origin,
263
- )
264
- status_code = FoundStatusCode()
265
- if len(find_results.resources) == 0:
266
- # If no resources were found on the retrieval, we return
267
- # a "Not enough context" answer and skip the llm query
268
- answer_stream = format_generated_answer(
269
- not_enough_context_generator(), status_code
270
- )
271
- return ChatResult(
272
- nuclia_learning_id=nuclia_learning_id,
273
- answer_stream=answer_stream,
274
- status_code=status_code,
275
- find_results=find_results,
276
- prompt_context=prompt_context,
277
- prompt_context_order=prompt_context_order,
278
- )
279
- else:
280
- status_code = FoundStatusCode()
281
- find_results = KnowledgeboxFindResults(resources={}, min_score=None)
282
- query_parser = QueryParser(
283
- kbid=kbid,
284
- features=[],
285
- query="",
286
- filters=chat_request.filters,
287
- page_number=0,
288
- page_size=0,
289
- min_score=MinScore(),
290
- )
291
-
292
- query_parser.max_tokens = chat_request.max_tokens # type: ignore
293
- max_tokens_context = await query_parser.get_max_tokens_context()
294
- prompt_context_builder = PromptContextBuilder(
295
- kbid=kbid,
296
- find_results=find_results,
297
- resource=resource,
298
- user_context=user_context,
299
- strategies=chat_request.rag_strategies,
300
- image_strategies=chat_request.rag_images_strategies,
301
- max_context_characters=tokens_to_chars(max_tokens_context),
302
- visual_llm=await query_parser.get_visual_llm_enabled(),
303
- )
304
- (
305
- prompt_context,
306
- prompt_context_order,
307
- prompt_context_images,
308
- ) = await prompt_context_builder.build()
309
- user_prompt = None
310
- if chat_request.prompt is not None:
311
- user_prompt = UserPrompt(prompt=chat_request.prompt)
312
- chat_model = ChatModel(
313
- user_id=user_id,
314
- query_context=prompt_context,
315
- query_context_order=prompt_context_order,
316
- chat_history=chat_history,
317
- question=user_query,
318
- truncate=True,
319
- user_prompt=user_prompt,
320
- citations=chat_request.citations,
321
- generative_model=chat_request.generative_model,
322
- max_tokens=query_parser.get_max_tokens_answer(),
323
- query_context_images=prompt_context_images,
324
- prefer_markdown=chat_request.prefer_markdown,
325
- )
326
- predict = get_predict()
327
- nuclia_learning_id, predict_generator = await predict.chat_query(kbid, chat_model)
328
-
329
- async def _wrapped_stream():
330
- # so we can audit after streamed out answer
331
- text_answer = b""
332
- async for chunk in format_generated_answer(predict_generator, status_code):
333
- text_answer += chunk
334
- yield chunk
335
-
336
- await maybe_audit_chat(
337
- kbid=kbid,
338
- user_id=user_id,
339
- client_type=client_type,
340
- origin=origin,
341
- duration=time() - start_time,
342
- user_query=user_query,
343
- rephrased_query=rephrased_query,
344
- text_answer=text_answer,
345
- status_code=status_code.value,
346
- chat_history=chat_history,
347
- query_context=prompt_context,
348
- query_context_order=prompt_context_order,
349
- learning_id=nuclia_learning_id,
350
- )
351
-
352
- answer_stream = _wrapped_stream()
353
- return ChatResult(
354
- nuclia_learning_id=nuclia_learning_id,
355
- answer_stream=answer_stream,
356
- status_code=status_code,
357
- find_results=find_results,
358
- prompt_context=prompt_context,
359
- prompt_context_order=prompt_context_order,
360
- )
361
-
362
-
363
- def _parse_answer_status_code(chunk: bytes) -> AnswerStatusCode:
364
- """
365
- Parses the status code from the last chunk of the answer.
366
- """
367
- try:
368
- return AnswerStatusCode(chunk.decode())
369
- except ValueError:
370
- # In some cases, even if the status code was yield separately
371
- # at the server side, the status code is appended to the previous chunk...
372
- # It may be a bug in the aiohttp.StreamResponse implementation,
373
- # but we haven't spotted it yet. For now, we just try to parse the status code
374
- # from the tail of the chunk.
375
- logger.debug(
376
- f"Error decoding status code from /chat's last chunk. Chunk: {chunk!r}"
377
- )
378
- if chunk == b"":
379
- raise
380
- if chunk.endswith(b"0"):
381
- return AnswerStatusCode.SUCCESS
382
- return AnswerStatusCode(chunk[-2:].decode())
383
-
384
-
385
def maybe_audit_chat(
    *,
    kbid: str,
    user_id: str,
    client_type: NucliaDBClientType,
    origin: str,
    generative_answer_time: float,
    generative_answer_first_chunk_time: float,
    rephrase_time: Optional[float],
    user_query: str,
    rephrased_query: Optional[str],
    text_answer: bytes,
    status_code: AnswerStatusCode,
    chat_history: list[ChatContextMessage],
    query_context: PromptContext,
    query_context_order: PromptContextOrder,
    learning_id: str,
    model: str,
):
    """
    Record a chat interaction in the audit log, when auditing is available.

    Nothing happens if `get_audit()` returns None. Otherwise the chat history
    and the retrieved context are converted to their audit protobuf messages
    and handed to `audit.chat` together with the timings, the answer, the
    status code and the model.

    `query_context_order` is accepted for interface compatibility but is not
    used in this function.
    """
    auditor = get_audit()
    if auditor is None:
        return

    answer = parse_audit_answer(text_answer, status_code)

    # One audit message per message of the conversation history
    history_messages = []
    for message in chat_history:
        history_messages.append(audit_pb2.ChatContext(author=message.author, text=message.text))

    # One audit entry per text block of the prompt context, keyed by its id
    retrieved_blocks = []
    for paragraph_id, text in query_context.items():
        retrieved_blocks.append(audit_pb2.RetrievedContext(text_block_id=paragraph_id, text=text))

    auditor.chat(
        kbid,
        user_id,
        to_proto.client_type(client_type),
        origin,
        question=user_query,
        generative_answer_time=generative_answer_time,
        generative_answer_first_chunk_time=generative_answer_first_chunk_time,
        rephrase_time=rephrase_time,
        rephrased_question=rephrased_query,
        chat_context=history_messages,
        retrieved_context=retrieved_blocks,
        answer=answer,
        learning_id=learning_id,
        status_code=int(status_code.value),
        model=model,
    )
431
296
 
432
297
 
433
def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
    """
    Return the answer text to store in the audit log.

    When the status code is `NO_CONTEXT` the answer is the canned
    "Not enough context to answer this." message, which is not worth auditing,
    so None is returned. In every other case the raw bytes are decoded.
    """
    no_context = status_code == AnswerStatusCode.NO_CONTEXT
    return None if no_context else raw_text_answer.decode()
446
303
 
447
304
 
448
305
  def tokens_to_chars(n_tokens: int) -> int:
@@ -458,47 +315,55 @@ class ChatAuditor:
458
315
  user_id: str,
459
316
  client_type: NucliaDBClientType,
460
317
  origin: str,
461
- start_time: float,
462
318
  user_query: str,
463
319
  rephrased_query: Optional[str],
464
320
  chat_history: list[ChatContextMessage],
465
321
  learning_id: Optional[str],
466
322
  query_context: PromptContext,
467
323
  query_context_order: PromptContextOrder,
324
+ model: str,
468
325
  ):
469
326
  self.kbid = kbid
470
327
  self.user_id = user_id
471
328
  self.client_type = client_type
472
329
  self.origin = origin
473
- self.start_time = start_time
474
330
  self.user_query = user_query
475
331
  self.rephrased_query = rephrased_query
476
332
  self.chat_history = chat_history
477
333
  self.learning_id = learning_id
478
334
  self.query_context = query_context
479
335
  self.query_context_order = query_context_order
336
+ self.model = model
480
337
 
481
def audit(
    self,
    text_answer: bytes,
    generative_answer_time: float,
    generative_answer_first_chunk_time: float,
    rephrase_time: Optional[float],
    status_code: AnswerStatusCode,
):
    """
    Audit the chat using the request data stored on this auditor.

    The per-answer values (answer text, timings and status code) come from
    the arguments; everything else is read from the instance attributes.
    """
    # A learning id is always sent; use a placeholder when none was stored
    learning_id = self.learning_id or "unknown"

    maybe_audit_chat(
        # Data captured when the auditor was created
        kbid=self.kbid,
        user_id=self.user_id,
        client_type=self.client_type,
        origin=self.origin,
        user_query=self.user_query,
        rephrased_query=self.rephrased_query,
        chat_history=self.chat_history,
        query_context=self.query_context,
        query_context_order=self.query_context_order,
        learning_id=learning_id,
        model=self.model,
        # Data only known once the answer has been generated
        text_answer=text_answer,
        status_code=status_code,
        generative_answer_time=generative_answer_time,
        generative_answer_first_chunk_time=generative_answer_first_chunk_time,
        rephrase_time=rephrase_time,
    )
497
364
 
498
365
 
499
- def sorted_prompt_context_list(
500
- context: PromptContext, order: PromptContextOrder
501
- ) -> list[str]:
366
+ def sorted_prompt_context_list(context: PromptContext, order: PromptContextOrder) -> list[str]:
502
367
  """
503
368
  context = {"x": "foo", "y": "bar"}
504
369
  order = {"y": 1, "x": 0}
@@ -509,3 +374,42 @@ def sorted_prompt_context_list(
509
374
  key=lambda item: order.get(item[0], float("inf")),
510
375
  )
511
376
  return list(map(lambda item: item[1], sorted_items))
377
+
378
+
379
async def run_prequeries(
    kbid: str,
    prequeries: list[PreQuery],
    x_ndb_client: NucliaDBClientType,
    x_nucliadb_user: str,
    x_forwarded_for: str,
    generative_model: Optional[str] = None,
    metrics: Optional[RAGMetrics] = None,
) -> list[PreQueryResult]:
    """
    Run one find request per prequery, concurrently, and return their results.

    At most `settings.prequeries_max_parallel` find requests are in flight at
    the same time. The returned list has one `(prequery, find_results)` pair
    per input prequery, in the same order as `prequeries`. No merging or
    weighting of the results is done here.

    Args:
        kbid: Knowledge box to search in.
        prequeries: Prequeries to run; `prequery.request` is the find request.
        x_ndb_client: Client type forwarded to `find`.
        x_nucliadb_user: User forwarded to `find`.
        x_forwarded_for: Origin forwarded to `find`.
        generative_model: Optional generative model forwarded to `find`.
        metrics: Metrics object shared by all the find requests of this call.
            When omitted, a fresh `RAGMetrics` is created for the call.

    Raises:
        Any exception raised by `find` for one of the prequeries.
    """
    # The default is built per call: a `RAGMetrics()` default in the signature
    # would be evaluated once and shared by every call that omits `metrics`.
    rag_metrics = metrics if metrics is not None else RAGMetrics()
    max_parallel_prequeries = asyncio.Semaphore(settings.prequeries_max_parallel)

    async def _prequery_find(prequery: PreQuery) -> PreQueryResult:
        # The semaphore bounds how many find requests run at the same time
        async with max_parallel_prequeries:
            find_results, _, _ = await find(
                kbid,
                prequery.request,
                x_ndb_client,
                x_nucliadb_user,
                x_forwarded_for,
                generative_model=generative_model,
                metrics=rag_metrics,
            )
        return prequery, find_results

    # gather schedules every coroutine as a task and keeps the input order
    return list(await asyncio.gather(*(_prequery_find(prequery) for prequery in prequeries)))