nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431) hide show
  1. migrations/0002_rollover_shards.py +1 -2
  2. migrations/0003_allfields_key.py +2 -37
  3. migrations/0004_rollover_shards.py +1 -2
  4. migrations/0005_rollover_shards.py +1 -2
  5. migrations/0006_rollover_shards.py +2 -4
  6. migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
  7. migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
  8. migrations/0010_fix_corrupt_indexes.py +11 -12
  9. migrations/0011_materialize_labelset_ids.py +2 -18
  10. migrations/0012_rollover_shards.py +6 -12
  11. migrations/0013_rollover_shards.py +2 -4
  12. migrations/0014_rollover_shards.py +5 -7
  13. migrations/0015_targeted_rollover.py +6 -12
  14. migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
  15. migrations/0017_multiple_writable_shards.py +3 -6
  16. migrations/0018_purge_orphan_kbslugs.py +59 -0
  17. migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
  18. migrations/0020_drain_nodes_from_cluster.py +83 -0
  19. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
  20. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  21. migrations/0023_backfill_pg_catalog.py +80 -0
  22. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  23. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  24. migrations/0027_rollover_texts3.py +73 -0
  25. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  26. migrations/pg/0002_catalog.py +42 -0
  27. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  28. nucliadb/common/cluster/base.py +41 -24
  29. nucliadb/common/cluster/discovery/base.py +6 -14
  30. nucliadb/common/cluster/discovery/k8s.py +9 -19
  31. nucliadb/common/cluster/discovery/manual.py +1 -3
  32. nucliadb/common/cluster/discovery/single.py +1 -2
  33. nucliadb/common/cluster/discovery/utils.py +1 -3
  34. nucliadb/common/cluster/grpc_node_dummy.py +11 -16
  35. nucliadb/common/cluster/index_node.py +10 -19
  36. nucliadb/common/cluster/manager.py +223 -102
  37. nucliadb/common/cluster/rebalance.py +42 -37
  38. nucliadb/common/cluster/rollover.py +377 -204
  39. nucliadb/common/cluster/settings.py +16 -9
  40. nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
  41. nucliadb/common/cluster/standalone/index_node.py +4 -11
  42. nucliadb/common/cluster/standalone/service.py +2 -6
  43. nucliadb/common/cluster/standalone/utils.py +9 -6
  44. nucliadb/common/cluster/utils.py +43 -29
  45. nucliadb/common/constants.py +20 -0
  46. nucliadb/common/context/__init__.py +6 -4
  47. nucliadb/common/context/fastapi.py +8 -5
  48. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  49. nucliadb/common/datamanagers/__init__.py +24 -5
  50. nucliadb/common/datamanagers/atomic.py +102 -0
  51. nucliadb/common/datamanagers/cluster.py +5 -5
  52. nucliadb/common/datamanagers/entities.py +6 -16
  53. nucliadb/common/datamanagers/fields.py +84 -0
  54. nucliadb/common/datamanagers/kb.py +101 -24
  55. nucliadb/common/datamanagers/labels.py +26 -56
  56. nucliadb/common/datamanagers/processing.py +2 -6
  57. nucliadb/common/datamanagers/resources.py +214 -117
  58. nucliadb/common/datamanagers/rollover.py +77 -16
  59. nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
  60. nucliadb/common/datamanagers/utils.py +19 -11
  61. nucliadb/common/datamanagers/vectorsets.py +110 -0
  62. nucliadb/common/external_index_providers/base.py +257 -0
  63. nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
  64. nucliadb/common/external_index_providers/manager.py +101 -0
  65. nucliadb/common/external_index_providers/pinecone.py +933 -0
  66. nucliadb/common/external_index_providers/settings.py +52 -0
  67. nucliadb/common/http_clients/auth.py +3 -6
  68. nucliadb/common/http_clients/processing.py +6 -11
  69. nucliadb/common/http_clients/utils.py +1 -3
  70. nucliadb/common/ids.py +240 -0
  71. nucliadb/common/locking.py +43 -13
  72. nucliadb/common/maindb/driver.py +11 -35
  73. nucliadb/common/maindb/exceptions.py +6 -6
  74. nucliadb/common/maindb/local.py +22 -9
  75. nucliadb/common/maindb/pg.py +206 -111
  76. nucliadb/common/maindb/utils.py +13 -44
  77. nucliadb/common/models_utils/from_proto.py +479 -0
  78. nucliadb/common/models_utils/to_proto.py +60 -0
  79. nucliadb/common/nidx.py +260 -0
  80. nucliadb/export_import/datamanager.py +25 -19
  81. nucliadb/export_import/exceptions.py +8 -0
  82. nucliadb/export_import/exporter.py +20 -7
  83. nucliadb/export_import/importer.py +6 -11
  84. nucliadb/export_import/models.py +5 -5
  85. nucliadb/export_import/tasks.py +4 -4
  86. nucliadb/export_import/utils.py +94 -54
  87. nucliadb/health.py +1 -3
  88. nucliadb/ingest/app.py +15 -11
  89. nucliadb/ingest/consumer/auditing.py +30 -147
  90. nucliadb/ingest/consumer/consumer.py +96 -52
  91. nucliadb/ingest/consumer/materializer.py +10 -12
  92. nucliadb/ingest/consumer/pull.py +12 -27
  93. nucliadb/ingest/consumer/service.py +20 -19
  94. nucliadb/ingest/consumer/shard_creator.py +7 -14
  95. nucliadb/ingest/consumer/utils.py +1 -3
  96. nucliadb/ingest/fields/base.py +139 -188
  97. nucliadb/ingest/fields/conversation.py +18 -5
  98. nucliadb/ingest/fields/exceptions.py +1 -4
  99. nucliadb/ingest/fields/file.py +7 -25
  100. nucliadb/ingest/fields/link.py +11 -16
  101. nucliadb/ingest/fields/text.py +9 -4
  102. nucliadb/ingest/orm/brain.py +255 -262
  103. nucliadb/ingest/orm/broker_message.py +181 -0
  104. nucliadb/ingest/orm/entities.py +36 -51
  105. nucliadb/ingest/orm/exceptions.py +12 -0
  106. nucliadb/ingest/orm/knowledgebox.py +334 -278
  107. nucliadb/ingest/orm/processor/__init__.py +2 -697
  108. nucliadb/ingest/orm/processor/auditing.py +117 -0
  109. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  110. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  111. nucliadb/ingest/orm/processor/processor.py +752 -0
  112. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  113. nucliadb/ingest/orm/resource.py +280 -520
  114. nucliadb/ingest/orm/utils.py +25 -31
  115. nucliadb/ingest/partitions.py +3 -9
  116. nucliadb/ingest/processing.py +76 -81
  117. nucliadb/ingest/py.typed +0 -0
  118. nucliadb/ingest/serialize.py +37 -173
  119. nucliadb/ingest/service/__init__.py +1 -3
  120. nucliadb/ingest/service/writer.py +186 -577
  121. nucliadb/ingest/settings.py +13 -22
  122. nucliadb/ingest/utils.py +3 -6
  123. nucliadb/learning_proxy.py +264 -51
  124. nucliadb/metrics_exporter.py +30 -19
  125. nucliadb/middleware/__init__.py +1 -3
  126. nucliadb/migrator/command.py +1 -3
  127. nucliadb/migrator/datamanager.py +13 -13
  128. nucliadb/migrator/migrator.py +57 -37
  129. nucliadb/migrator/settings.py +2 -1
  130. nucliadb/migrator/utils.py +18 -10
  131. nucliadb/purge/__init__.py +139 -33
  132. nucliadb/purge/orphan_shards.py +7 -13
  133. nucliadb/reader/__init__.py +1 -3
  134. nucliadb/reader/api/models.py +3 -14
  135. nucliadb/reader/api/v1/__init__.py +0 -1
  136. nucliadb/reader/api/v1/download.py +27 -94
  137. nucliadb/reader/api/v1/export_import.py +4 -4
  138. nucliadb/reader/api/v1/knowledgebox.py +13 -13
  139. nucliadb/reader/api/v1/learning_config.py +8 -12
  140. nucliadb/reader/api/v1/resource.py +67 -93
  141. nucliadb/reader/api/v1/services.py +70 -125
  142. nucliadb/reader/app.py +16 -46
  143. nucliadb/reader/lifecycle.py +18 -4
  144. nucliadb/reader/py.typed +0 -0
  145. nucliadb/reader/reader/notifications.py +10 -31
  146. nucliadb/search/__init__.py +1 -3
  147. nucliadb/search/api/v1/__init__.py +2 -2
  148. nucliadb/search/api/v1/ask.py +112 -0
  149. nucliadb/search/api/v1/catalog.py +184 -0
  150. nucliadb/search/api/v1/feedback.py +17 -25
  151. nucliadb/search/api/v1/find.py +41 -41
  152. nucliadb/search/api/v1/knowledgebox.py +90 -62
  153. nucliadb/search/api/v1/predict_proxy.py +2 -2
  154. nucliadb/search/api/v1/resource/ask.py +66 -117
  155. nucliadb/search/api/v1/resource/search.py +51 -72
  156. nucliadb/search/api/v1/router.py +1 -0
  157. nucliadb/search/api/v1/search.py +50 -197
  158. nucliadb/search/api/v1/suggest.py +40 -54
  159. nucliadb/search/api/v1/summarize.py +9 -5
  160. nucliadb/search/api/v1/utils.py +2 -1
  161. nucliadb/search/app.py +16 -48
  162. nucliadb/search/lifecycle.py +10 -3
  163. nucliadb/search/predict.py +176 -188
  164. nucliadb/search/py.typed +0 -0
  165. nucliadb/search/requesters/utils.py +41 -63
  166. nucliadb/search/search/cache.py +149 -20
  167. nucliadb/search/search/chat/ask.py +918 -0
  168. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
  169. nucliadb/search/search/chat/images.py +41 -17
  170. nucliadb/search/search/chat/prompt.py +851 -282
  171. nucliadb/search/search/chat/query.py +274 -267
  172. nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
  173. nucliadb/search/search/fetch.py +43 -36
  174. nucliadb/search/search/filters.py +9 -15
  175. nucliadb/search/search/find.py +214 -54
  176. nucliadb/search/search/find_merge.py +408 -391
  177. nucliadb/search/search/hydrator.py +191 -0
  178. nucliadb/search/search/merge.py +198 -234
  179. nucliadb/search/search/metrics.py +73 -2
  180. nucliadb/search/search/paragraphs.py +64 -106
  181. nucliadb/search/search/pgcatalog.py +233 -0
  182. nucliadb/search/search/predict_proxy.py +1 -1
  183. nucliadb/search/search/query.py +386 -257
  184. nucliadb/search/search/query_parser/exceptions.py +22 -0
  185. nucliadb/search/search/query_parser/models.py +101 -0
  186. nucliadb/search/search/query_parser/parser.py +183 -0
  187. nucliadb/search/search/rank_fusion.py +204 -0
  188. nucliadb/search/search/rerankers.py +270 -0
  189. nucliadb/search/search/shards.py +4 -38
  190. nucliadb/search/search/summarize.py +14 -18
  191. nucliadb/search/search/utils.py +27 -4
  192. nucliadb/search/settings.py +15 -1
  193. nucliadb/standalone/api_router.py +4 -10
  194. nucliadb/standalone/app.py +17 -14
  195. nucliadb/standalone/auth.py +7 -21
  196. nucliadb/standalone/config.py +9 -12
  197. nucliadb/standalone/introspect.py +5 -5
  198. nucliadb/standalone/lifecycle.py +26 -25
  199. nucliadb/standalone/migrations.py +58 -0
  200. nucliadb/standalone/purge.py +9 -8
  201. nucliadb/standalone/py.typed +0 -0
  202. nucliadb/standalone/run.py +25 -18
  203. nucliadb/standalone/settings.py +10 -14
  204. nucliadb/standalone/versions.py +15 -5
  205. nucliadb/tasks/consumer.py +8 -12
  206. nucliadb/tasks/producer.py +7 -6
  207. nucliadb/tests/config.py +53 -0
  208. nucliadb/train/__init__.py +1 -3
  209. nucliadb/train/api/utils.py +1 -2
  210. nucliadb/train/api/v1/shards.py +2 -2
  211. nucliadb/train/api/v1/trainset.py +4 -6
  212. nucliadb/train/app.py +14 -47
  213. nucliadb/train/generator.py +10 -19
  214. nucliadb/train/generators/field_classifier.py +7 -19
  215. nucliadb/train/generators/field_streaming.py +156 -0
  216. nucliadb/train/generators/image_classifier.py +12 -18
  217. nucliadb/train/generators/paragraph_classifier.py +5 -9
  218. nucliadb/train/generators/paragraph_streaming.py +6 -9
  219. nucliadb/train/generators/question_answer_streaming.py +19 -20
  220. nucliadb/train/generators/sentence_classifier.py +9 -15
  221. nucliadb/train/generators/token_classifier.py +45 -36
  222. nucliadb/train/generators/utils.py +14 -18
  223. nucliadb/train/lifecycle.py +7 -3
  224. nucliadb/train/nodes.py +23 -32
  225. nucliadb/train/py.typed +0 -0
  226. nucliadb/train/servicer.py +13 -21
  227. nucliadb/train/settings.py +2 -6
  228. nucliadb/train/types.py +13 -10
  229. nucliadb/train/upload.py +3 -6
  230. nucliadb/train/uploader.py +20 -25
  231. nucliadb/train/utils.py +1 -1
  232. nucliadb/writer/__init__.py +1 -3
  233. nucliadb/writer/api/constants.py +0 -5
  234. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  235. nucliadb/writer/api/v1/export_import.py +102 -49
  236. nucliadb/writer/api/v1/field.py +196 -620
  237. nucliadb/writer/api/v1/knowledgebox.py +221 -71
  238. nucliadb/writer/api/v1/learning_config.py +2 -2
  239. nucliadb/writer/api/v1/resource.py +114 -216
  240. nucliadb/writer/api/v1/services.py +64 -132
  241. nucliadb/writer/api/v1/slug.py +61 -0
  242. nucliadb/writer/api/v1/transaction.py +67 -0
  243. nucliadb/writer/api/v1/upload.py +184 -215
  244. nucliadb/writer/app.py +11 -61
  245. nucliadb/writer/back_pressure.py +62 -43
  246. nucliadb/writer/exceptions.py +0 -4
  247. nucliadb/writer/lifecycle.py +21 -15
  248. nucliadb/writer/py.typed +0 -0
  249. nucliadb/writer/resource/audit.py +2 -1
  250. nucliadb/writer/resource/basic.py +48 -62
  251. nucliadb/writer/resource/field.py +45 -135
  252. nucliadb/writer/resource/origin.py +1 -2
  253. nucliadb/writer/settings.py +14 -5
  254. nucliadb/writer/tus/__init__.py +17 -15
  255. nucliadb/writer/tus/azure.py +111 -0
  256. nucliadb/writer/tus/dm.py +17 -5
  257. nucliadb/writer/tus/exceptions.py +1 -3
  258. nucliadb/writer/tus/gcs.py +56 -84
  259. nucliadb/writer/tus/local.py +21 -37
  260. nucliadb/writer/tus/s3.py +28 -68
  261. nucliadb/writer/tus/storage.py +5 -56
  262. nucliadb/writer/vectorsets.py +125 -0
  263. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  264. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  265. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  266. nucliadb/common/maindb/redis.py +0 -194
  267. nucliadb/common/maindb/tikv.py +0 -412
  268. nucliadb/ingest/fields/layout.py +0 -58
  269. nucliadb/ingest/tests/conftest.py +0 -30
  270. nucliadb/ingest/tests/fixtures.py +0 -771
  271. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  272. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
  273. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
  274. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  275. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  276. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  277. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
  278. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  279. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  280. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  281. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
  282. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  283. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  284. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
  285. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  286. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  287. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  288. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  289. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
  290. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  291. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  292. nucliadb/middleware/transaction.py +0 -117
  293. nucliadb/reader/api/v1/learning_collector.py +0 -63
  294. nucliadb/reader/tests/__init__.py +0 -19
  295. nucliadb/reader/tests/conftest.py +0 -31
  296. nucliadb/reader/tests/fixtures.py +0 -136
  297. nucliadb/reader/tests/test_list_resources.py +0 -75
  298. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  299. nucliadb/reader/tests/test_reader_resource.py +0 -379
  300. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  301. nucliadb/search/api/v1/chat.py +0 -258
  302. nucliadb/search/api/v1/resource/chat.py +0 -94
  303. nucliadb/search/tests/__init__.py +0 -19
  304. nucliadb/search/tests/conftest.py +0 -33
  305. nucliadb/search/tests/fixtures.py +0 -199
  306. nucliadb/search/tests/node.py +0 -465
  307. nucliadb/search/tests/unit/__init__.py +0 -18
  308. nucliadb/search/tests/unit/api/__init__.py +0 -19
  309. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  310. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  311. nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
  312. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
  313. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  314. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  315. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
  316. nucliadb/search/tests/unit/search/__init__.py +0 -18
  317. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  318. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
  319. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  320. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  321. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  322. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
  323. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  324. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  325. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  326. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  327. nucliadb/search/tests/unit/search/test_query.py +0 -201
  328. nucliadb/search/tests/unit/test_app.py +0 -79
  329. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  330. nucliadb/search/tests/unit/test_merge.py +0 -34
  331. nucliadb/search/tests/unit/test_predict.py +0 -584
  332. nucliadb/standalone/tests/__init__.py +0 -19
  333. nucliadb/standalone/tests/conftest.py +0 -33
  334. nucliadb/standalone/tests/fixtures.py +0 -38
  335. nucliadb/standalone/tests/unit/__init__.py +0 -18
  336. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  337. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  338. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  339. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  340. nucliadb/tests/benchmarks/__init__.py +0 -19
  341. nucliadb/tests/benchmarks/test_search.py +0 -99
  342. nucliadb/tests/conftest.py +0 -32
  343. nucliadb/tests/fixtures.py +0 -736
  344. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
  345. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
  346. nucliadb/tests/migrations/__init__.py +0 -19
  347. nucliadb/tests/migrations/test_migration_0017.py +0 -80
  348. nucliadb/tests/tikv.py +0 -240
  349. nucliadb/tests/unit/__init__.py +0 -19
  350. nucliadb/tests/unit/common/__init__.py +0 -19
  351. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  352. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  353. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
  354. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  355. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
  356. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
  357. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
  358. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
  359. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
  360. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  361. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  362. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  363. nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
  364. nucliadb/tests/unit/common/test_context.py +0 -36
  365. nucliadb/tests/unit/export_import/__init__.py +0 -19
  366. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  367. nucliadb/tests/unit/export_import/test_utils.py +0 -294
  368. nucliadb/tests/unit/migrator/__init__.py +0 -19
  369. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  370. nucliadb/tests/unit/tasks/__init__.py +0 -19
  371. nucliadb/tests/unit/tasks/conftest.py +0 -42
  372. nucliadb/tests/unit/tasks/test_consumer.py +0 -93
  373. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  374. nucliadb/tests/unit/tasks/test_tasks.py +0 -60
  375. nucliadb/tests/unit/test_field_ids.py +0 -49
  376. nucliadb/tests/unit/test_health.py +0 -84
  377. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  378. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  379. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  380. nucliadb/tests/unit/test_purge.py +0 -138
  381. nucliadb/tests/utils/__init__.py +0 -74
  382. nucliadb/tests/utils/aiohttp_session.py +0 -44
  383. nucliadb/tests/utils/broker_messages/__init__.py +0 -167
  384. nucliadb/tests/utils/broker_messages/fields.py +0 -181
  385. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  386. nucliadb/tests/utils/entities.py +0 -78
  387. nucliadb/train/api/v1/check.py +0 -60
  388. nucliadb/train/tests/__init__.py +0 -19
  389. nucliadb/train/tests/conftest.py +0 -29
  390. nucliadb/train/tests/fixtures.py +0 -342
  391. nucliadb/train/tests/test_field_classification.py +0 -122
  392. nucliadb/train/tests/test_get_entities.py +0 -80
  393. nucliadb/train/tests/test_get_info.py +0 -51
  394. nucliadb/train/tests/test_get_ontology.py +0 -34
  395. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  396. nucliadb/train/tests/test_image_classification.py +0 -222
  397. nucliadb/train/tests/test_list_fields.py +0 -39
  398. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  399. nucliadb/train/tests/test_list_resources.py +0 -39
  400. nucliadb/train/tests/test_list_sentences.py +0 -71
  401. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  402. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  403. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  404. nucliadb/train/tests/test_sentence_classification.py +0 -143
  405. nucliadb/train/tests/test_token_classification.py +0 -136
  406. nucliadb/train/tests/utils.py +0 -108
  407. nucliadb/writer/layouts/__init__.py +0 -51
  408. nucliadb/writer/layouts/v1.py +0 -59
  409. nucliadb/writer/resource/vectors.py +0 -120
  410. nucliadb/writer/tests/__init__.py +0 -19
  411. nucliadb/writer/tests/conftest.py +0 -31
  412. nucliadb/writer/tests/fixtures.py +0 -192
  413. nucliadb/writer/tests/test_fields.py +0 -486
  414. nucliadb/writer/tests/test_files.py +0 -743
  415. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  416. nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
  417. nucliadb/writer/tests/test_resources.py +0 -546
  418. nucliadb/writer/tests/test_service.py +0 -137
  419. nucliadb/writer/tests/test_tus.py +0 -203
  420. nucliadb/writer/tests/utils.py +0 -35
  421. nucliadb/writer/tus/pg.py +0 -125
  422. nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
  423. nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
  424. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  425. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  426. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  427. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  428. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  429. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  430. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  431. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -18,67 +18,43 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from dataclasses import dataclass
22
- from time import monotonic as time
23
- from typing import AsyncGenerator, AsyncIterator, Optional
24
-
25
- from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
21
+ from typing import Optional
26
22
 
23
+ from nucliadb.common.models_utils import to_proto
27
24
  from nucliadb.search import logger
28
25
  from nucliadb.search.predict import AnswerStatusCode
29
26
  from nucliadb.search.requesters.utils import Method, node_query
30
- from nucliadb.search.search.chat.prompt import PromptContextBuilder
27
+ from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
31
28
  from nucliadb.search.search.exceptions import IncompleteFindResultsError
32
29
  from nucliadb.search.search.find import find
33
30
  from nucliadb.search.search.merge import merge_relations_results
31
+ from nucliadb.search.search.metrics import RAGMetrics
34
32
  from nucliadb.search.search.query import QueryParser
33
+ from nucliadb.search.settings import settings
35
34
  from nucliadb.search.utilities import get_predict
36
35
  from nucliadb_models.search import (
37
- Author,
36
+ AskRequest,
38
37
  ChatContextMessage,
39
- ChatModel,
40
38
  ChatOptions,
41
- ChatRequest,
42
39
  FindRequest,
43
40
  KnowledgeboxFindResults,
44
41
  NucliaDBClientType,
42
+ PreQueriesStrategy,
43
+ PreQuery,
44
+ PreQueryResult,
45
45
  PromptContext,
46
46
  PromptContextOrder,
47
47
  Relations,
48
48
  RephraseModel,
49
49
  SearchOptions,
50
- UserPrompt,
50
+ parse_rephrase_prompt,
51
51
  )
52
52
  from nucliadb_protos import audit_pb2
53
+ from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
53
54
  from nucliadb_telemetry.errors import capture_exception
54
- from nucliadb_utils.helpers import async_gen_lookahead
55
55
  from nucliadb_utils.utilities import get_audit
56
56
 
57
57
  NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this."
58
- AUDIT_TEXT_RESULT_SEP = " \n\n "
59
- START_OF_CITATIONS = b"_CIT_"
60
-
61
-
62
- class FoundStatusCode:
63
- def __init__(self, default: AnswerStatusCode = AnswerStatusCode.SUCCESS):
64
- self._value = AnswerStatusCode.SUCCESS
65
-
66
- def set(self, value: AnswerStatusCode) -> None:
67
- self._value = value
68
-
69
- @property
70
- def value(self) -> AnswerStatusCode:
71
- return self._value
72
-
73
-
74
- @dataclass
75
- class ChatResult:
76
- nuclia_learning_id: Optional[str]
77
- answer_stream: AsyncIterator[bytes]
78
- status_code: FoundStatusCode
79
- find_results: KnowledgeboxFindResults
80
- prompt_context: PromptContext
81
- prompt_context_order: PromptContextOrder
82
58
 
83
59
 
84
60
  async def rephrase_query(
@@ -100,70 +76,120 @@ async def rephrase_query(
100
76
  return await predict.rephrase_query(kbid, req)
101
77
 
102
78
 
103
- async def format_generated_answer(
104
- answer_generator: AsyncGenerator[bytes, None], output_status_code: FoundStatusCode
105
- ):
106
- status_code: Optional[AnswerStatusCode] = None
107
- is_last_chunk = False
108
- async for answer_chunk, is_last_chunk in async_gen_lookahead(answer_generator):
109
- if is_last_chunk:
110
- try:
111
- status_code = _parse_answer_status_code(answer_chunk)
112
- except ValueError:
113
- # TODO: remove this in the future, it's
114
- # just for bw compatibility until predict
115
- # is updated to the new protocol
116
- status_code = AnswerStatusCode.SUCCESS
117
- yield answer_chunk
118
- else:
119
- # TODO: this should be needed but, in case we receive the status
120
- # code mixed with text, we strip it and return the text
121
- if len(answer_chunk) != len(status_code.encode()):
122
- answer_chunk = answer_chunk.rstrip(status_code.encode())
123
- yield answer_chunk
124
- break
125
- yield answer_chunk
126
- if not is_last_chunk:
127
- logger.warning("BUG: /chat endpoint without last chunk")
128
-
129
- output_status_code.set(status_code or AnswerStatusCode.SUCCESS)
130
-
131
-
132
79
  async def get_find_results(
133
80
  *,
134
81
  kbid: str,
135
82
  query: str,
136
- chat_request: ChatRequest,
83
+ item: AskRequest,
137
84
  ndb_client: NucliaDBClientType,
138
85
  user: str,
139
86
  origin: str,
87
+ metrics: RAGMetrics = RAGMetrics(),
88
+ prequeries_strategy: Optional[PreQueriesStrategy] = None,
89
+ ) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], QueryParser]:
90
+ prequeries_results = None
91
+ prefilter_queries_results = None
92
+ queries_results = None
93
+ if prequeries_strategy is not None:
94
+ prefilters = [prequery for prequery in prequeries_strategy.queries if prequery.prefilter]
95
+ prequeries = [prequery for prequery in prequeries_strategy.queries if not prequery.prefilter]
96
+ if len(prefilters) > 0:
97
+ with metrics.time("prefilters"):
98
+ prefilter_queries_results = await run_prequeries(
99
+ kbid,
100
+ prefilters,
101
+ x_ndb_client=ndb_client,
102
+ x_nucliadb_user=user,
103
+ x_forwarded_for=origin,
104
+ generative_model=item.generative_model,
105
+ metrics=metrics,
106
+ )
107
+ prefilter_matching_resources = {
108
+ resource
109
+ for _, find_results in prefilter_queries_results
110
+ for resource in find_results.resources.keys()
111
+ }
112
+ if len(prefilter_matching_resources) == 0:
113
+ raise NoRetrievalResultsError()
114
+ # Make sure the main query and prequeries use the same resource filters.
115
+ # This is important to avoid returning results that don't match the prefilter.
116
+ item.resource_filters = list(prefilter_matching_resources)
117
+ for prequery in prequeries:
118
+ prequery.request.resource_filters = list(prefilter_matching_resources)
119
+ prequery.request.show_hidden = item.show_hidden
120
+
121
+ if prequeries:
122
+ with metrics.time("prequeries"):
123
+ queries_results = await run_prequeries(
124
+ kbid,
125
+ prequeries,
126
+ x_ndb_client=ndb_client,
127
+ x_nucliadb_user=user,
128
+ x_forwarded_for=origin,
129
+ generative_model=item.generative_model,
130
+ metrics=metrics,
131
+ )
132
+
133
+ prequeries_results = (prefilter_queries_results or []) + (queries_results or [])
134
+
135
+ with metrics.time("main_query"):
136
+ main_results, query_parser = await run_main_query(
137
+ kbid,
138
+ query,
139
+ item,
140
+ ndb_client,
141
+ user,
142
+ origin,
143
+ metrics=metrics,
144
+ )
145
+ return main_results, prequeries_results, query_parser
146
+
147
+
148
+ async def run_main_query(
149
+ kbid: str,
150
+ query: str,
151
+ item: AskRequest,
152
+ ndb_client: NucliaDBClientType,
153
+ user: str,
154
+ origin: str,
155
+ metrics: RAGMetrics = RAGMetrics(),
140
156
  ) -> tuple[KnowledgeboxFindResults, QueryParser]:
141
157
  find_request = FindRequest()
142
- find_request.resource_filters = chat_request.resource_filters
158
+ find_request.resource_filters = item.resource_filters
143
159
  find_request.features = []
144
- if ChatOptions.VECTORS in chat_request.features:
145
- find_request.features.append(SearchOptions.VECTOR)
146
- if ChatOptions.PARAGRAPHS in chat_request.features:
147
- find_request.features.append(SearchOptions.PARAGRAPH)
148
- if ChatOptions.RELATIONS in chat_request.features:
160
+ if ChatOptions.SEMANTIC in item.features:
161
+ find_request.features.append(SearchOptions.SEMANTIC)
162
+ if ChatOptions.KEYWORD in item.features:
163
+ find_request.features.append(SearchOptions.KEYWORD)
164
+ if ChatOptions.RELATIONS in item.features:
149
165
  find_request.features.append(SearchOptions.RELATIONS)
150
166
  find_request.query = query
151
- find_request.fields = chat_request.fields
152
- find_request.filters = chat_request.filters
153
- find_request.field_type_filter = chat_request.field_type_filter
154
- find_request.min_score = chat_request.min_score
155
- find_request.range_creation_start = chat_request.range_creation_start
156
- find_request.range_creation_end = chat_request.range_creation_end
157
- find_request.range_modification_start = chat_request.range_modification_start
158
- find_request.range_modification_end = chat_request.range_modification_end
159
- find_request.show = chat_request.show
160
- find_request.extracted = chat_request.extracted
161
- find_request.shards = chat_request.shards
162
- find_request.autofilter = chat_request.autofilter
163
- find_request.highlight = chat_request.highlight
164
- find_request.security = chat_request.security
165
- find_request.debug = chat_request.debug
166
- find_request.rephrase = chat_request.rephrase
167
+ find_request.fields = item.fields
168
+ find_request.filters = item.filters
169
+ find_request.field_type_filter = item.field_type_filter
170
+ find_request.min_score = item.min_score
171
+ find_request.vectorset = item.vectorset
172
+ find_request.range_creation_start = item.range_creation_start
173
+ find_request.range_creation_end = item.range_creation_end
174
+ find_request.range_modification_start = item.range_modification_start
175
+ find_request.range_modification_end = item.range_modification_end
176
+ find_request.show = item.show
177
+ find_request.extracted = item.extracted
178
+ find_request.shards = item.shards
179
+ find_request.autofilter = item.autofilter
180
+ find_request.highlight = item.highlight
181
+ find_request.security = item.security
182
+ find_request.debug = item.debug
183
+ find_request.rephrase = item.rephrase
184
+ find_request.rephrase_prompt = parse_rephrase_prompt(item)
185
+ find_request.rank_fusion = item.rank_fusion
186
+ find_request.reranker = item.reranker
187
+ # We don't support pagination, we always get the top_k results.
188
+ find_request.top_k = item.top_k
189
+ find_request.show_hidden = item.show_hidden
190
+
191
+ # this executes the model validators, that can tweak some fields
192
+ FindRequest.model_validate(find_request)
167
193
 
168
194
  find_results, incomplete, query_parser = await find(
169
195
  kbid,
@@ -171,7 +197,8 @@ async def get_find_results(
171
197
  ndb_client,
172
198
  user,
173
199
  origin,
174
- generative_model=chat_request.generative_model,
200
+ generative_model=item.generative_model,
201
+ metrics=metrics,
175
202
  )
176
203
  if incomplete:
177
204
  raise IncompleteFindResultsError()
@@ -179,230 +206,210 @@ async def get_find_results(
179
206
 
180
207
 
181
208
  async def get_relations_results(
182
- *, kbid: str, chat_request: ChatRequest, text_answer: str
209
+ *,
210
+ kbid: str,
211
+ text_answer: str,
212
+ target_shard_replicas: Optional[list[str]],
213
+ timeout: Optional[float] = None,
183
214
  ) -> Relations:
184
215
  try:
185
216
  predict = get_predict()
186
217
  detected_entities = await predict.detect_entities(kbid, text_answer)
187
- relation_request = RelationSearchRequest()
188
- relation_request.subgraph.entry_points.extend(detected_entities)
189
- relation_request.subgraph.depth = 1
218
+ request = SearchRequest()
219
+ request.relation_subgraph.entry_points.extend(detected_entities)
220
+ request.relation_subgraph.depth = 1
190
221
 
191
- relations_results: list[RelationSearchResponse]
222
+ results: list[SearchResponse]
192
223
  (
193
- relations_results,
224
+ results,
194
225
  _,
195
226
  _,
196
227
  ) = await node_query(
197
228
  kbid,
198
- Method.RELATIONS,
199
- relation_request,
200
- target_shard_replicas=chat_request.shards,
201
- )
202
- return await merge_relations_results(
203
- relations_results, relation_request.subgraph
229
+ Method.SEARCH,
230
+ request,
231
+ target_shard_replicas=target_shard_replicas,
232
+ timeout=timeout,
233
+ use_read_replica_nodes=True,
234
+ retry_on_primary=False,
204
235
  )
236
+ relations_results: list[RelationSearchResponse] = [result.relation for result in results]
237
+ return await merge_relations_results(relations_results, request.relation_subgraph)
205
238
  except Exception as exc:
206
239
  capture_exception(exc)
207
240
  logger.exception("Error getting relations results")
208
241
  return Relations(entities={})
209
242
 
210
243
 
211
- async def not_enough_context_generator():
212
- await asyncio.sleep(0)
213
- yield NOT_ENOUGH_CONTEXT_ANSWER.encode()
214
- yield AnswerStatusCode.NO_CONTEXT.encode()
215
-
216
-
217
- async def chat(
218
- kbid: str,
219
- chat_request: ChatRequest,
220
- user_id: str,
221
- client_type: NucliaDBClientType,
222
- origin: str,
223
- ) -> ChatResult:
224
- start_time = time()
225
- nuclia_learning_id: Optional[str] = None
226
- chat_history = chat_request.context or []
227
- user_context = chat_request.extra_context or []
228
- user_query = chat_request.query
229
- rephrased_query = None
230
- prompt_context: PromptContext = {}
231
- prompt_context_order: PromptContextOrder = {}
232
-
233
- if len(chat_history) > 0 or len(user_context) > 0:
234
- rephrased_query = await rephrase_query(
235
- kbid,
236
- chat_history=chat_history,
237
- query=user_query,
238
- user_id=user_id,
239
- user_context=user_context,
240
- generative_model=chat_request.generative_model,
241
- )
242
-
243
- find_results, query_parser = await get_find_results(
244
- kbid=kbid,
245
- query=rephrased_query or user_query,
246
- chat_request=chat_request,
247
- ndb_client=client_type,
248
- user=user_id,
249
- origin=origin,
250
- )
251
-
252
- status_code = FoundStatusCode()
253
- if len(find_results.resources) == 0:
254
- answer_stream = format_generated_answer(
255
- not_enough_context_generator(), status_code
256
- )
257
- else:
258
- prompt_context_builder = PromptContextBuilder(
259
- kbid=kbid,
260
- find_results=find_results,
261
- user_context=user_context,
262
- strategies=chat_request.rag_strategies,
263
- image_strategies=chat_request.rag_images_strategies,
264
- max_context_size=await query_parser.get_max_context(),
265
- visual_llm=await query_parser.get_visual_llm_enabled(),
266
- )
267
- (
268
- prompt_context,
269
- prompt_context_order,
270
- prompt_context_images,
271
- ) = await prompt_context_builder.build()
272
- user_prompt = None
273
- if chat_request.prompt is not None:
274
- user_prompt = UserPrompt(prompt=chat_request.prompt)
275
-
276
- chat_model = ChatModel(
277
- user_id=user_id,
278
- query_context=prompt_context,
279
- query_context_order=prompt_context_order,
280
- chat_history=chat_history,
281
- question=user_query,
282
- truncate=True,
283
- user_prompt=user_prompt,
284
- citations=chat_request.citations,
285
- generative_model=chat_request.generative_model,
286
- max_tokens=chat_request.max_tokens,
287
- query_context_images=prompt_context_images,
288
- )
289
- predict = get_predict()
290
- nuclia_learning_id, predict_generator = await predict.chat_query(
291
- kbid, chat_model
292
- )
293
-
294
- async def _wrapped_stream():
295
- # so we can audit after streamed out answer
296
- text_answer = b""
297
- async for chunk in format_generated_answer(predict_generator, status_code):
298
- text_answer += chunk
299
- yield chunk
300
-
301
- await maybe_audit_chat(
302
- kbid=kbid,
303
- user_id=user_id,
304
- client_type=client_type,
305
- origin=origin,
306
- duration=time() - start_time,
307
- user_query=user_query,
308
- rephrased_query=rephrased_query,
309
- text_answer=text_answer,
310
- status_code=status_code.value,
311
- chat_history=chat_history,
312
- query_context=prompt_context,
313
- learning_id=nuclia_learning_id,
314
- )
315
-
316
- answer_stream = _wrapped_stream()
317
-
318
- return ChatResult(
319
- nuclia_learning_id=nuclia_learning_id,
320
- answer_stream=answer_stream,
321
- status_code=status_code,
322
- find_results=find_results,
323
- prompt_context=prompt_context,
324
- prompt_context_order=prompt_context_order,
325
- )
326
-
327
-
328
- def _parse_answer_status_code(chunk: bytes) -> AnswerStatusCode:
329
- """
330
- Parses the status code from the last chunk of the answer.
331
- """
332
- try:
333
- return AnswerStatusCode(chunk.decode())
334
- except ValueError:
335
- # In some cases, even if the status code was yield separately
336
- # at the server side, the status code is appended to the previous chunk...
337
- # It may be a bug in the aiohttp.StreamResponse implementation,
338
- # but we haven't spotted it yet. For now, we just try to parse the status code
339
- # from the tail of the chunk.
340
- logger.debug(
341
- f"Error decoding status code from /chat's last chunk. Chunk: {chunk!r}"
342
- )
343
- if chunk == b"":
344
- raise
345
- if chunk.endswith(b"0"):
346
- return AnswerStatusCode.SUCCESS
347
- return AnswerStatusCode(chunk[-2:].decode())
348
-
349
-
350
- async def maybe_audit_chat(
244
+ def maybe_audit_chat(
351
245
  *,
352
246
  kbid: str,
353
247
  user_id: str,
354
248
  client_type: NucliaDBClientType,
355
249
  origin: str,
356
- duration: float,
250
+ generative_answer_time: float,
251
+ generative_answer_first_chunk_time: float,
252
+ rephrase_time: Optional[float],
357
253
  user_query: str,
358
254
  rephrased_query: Optional[str],
359
255
  text_answer: bytes,
360
- status_code: Optional[AnswerStatusCode],
256
+ status_code: AnswerStatusCode,
361
257
  chat_history: list[ChatContextMessage],
362
- query_context: list[str],
258
+ query_context: PromptContext,
259
+ query_context_order: PromptContextOrder,
363
260
  learning_id: str,
261
+ model: str,
364
262
  ):
365
263
  audit = get_audit()
366
264
  if audit is None:
367
265
  return
368
266
 
369
267
  audit_answer = parse_audit_answer(text_answer, status_code)
268
+ # Append chat history
269
+ chat_history_context = [
270
+ audit_pb2.ChatContext(author=message.author, text=message.text) for message in chat_history
271
+ ]
370
272
 
371
- # Append chat history and query context
372
- audit_context = [
373
- audit_pb2.ChatContext(author=message.author, text=message.text)
374
- for message in chat_history
273
+ # Append paragraphs retrieved on this chat
274
+ chat_retrieved_context = [
275
+ audit_pb2.RetrievedContext(text_block_id=paragraph_id, text=text)
276
+ for paragraph_id, text in query_context.items()
375
277
  ]
376
- audit_context.append(
377
- audit_pb2.ChatContext(
378
- author=Author.NUCLIA,
379
- text=AUDIT_TEXT_RESULT_SEP.join(query_context),
380
- )
381
- )
382
- await audit.chat(
278
+
279
+ audit.chat(
383
280
  kbid,
384
281
  user_id,
385
- client_type.to_proto(),
282
+ to_proto.client_type(client_type),
386
283
  origin,
387
- duration,
388
284
  question=user_query,
285
+ generative_answer_time=generative_answer_time,
286
+ generative_answer_first_chunk_time=generative_answer_first_chunk_time,
287
+ rephrase_time=rephrase_time,
389
288
  rephrased_question=rephrased_query,
390
- context=audit_context,
289
+ chat_context=chat_history_context,
290
+ retrieved_context=chat_retrieved_context,
391
291
  answer=audit_answer,
392
292
  learning_id=learning_id,
293
+ status_code=int(status_code.value),
294
+ model=model,
393
295
  )
394
296
 
395
297
 
396
- def parse_audit_answer(
397
- raw_text_answer: bytes, status_code: Optional[AnswerStatusCode]
398
- ) -> Optional[str]:
298
+ def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
399
299
  if status_code == AnswerStatusCode.NO_CONTEXT:
400
300
  # We don't want to audit "Not enough context to answer this." and instead set a None.
401
301
  return None
402
- # Split citations part from answer
403
- try:
404
- raw_audit_answer, _ = raw_text_answer.split(START_OF_CITATIONS)
405
- except ValueError:
406
- raw_audit_answer = raw_text_answer
407
- audit_answer = raw_audit_answer.decode()
408
- return audit_answer
302
+ return raw_text_answer.decode()
303
+
304
+
305
+ def tokens_to_chars(n_tokens: int) -> int:
306
+ # Multiply by 3 to have a good margin and guess between characters and tokens.
307
+ # This will be properly cut at the NUA predict API.
308
+ return n_tokens * 3
309
+
310
+
311
+ class ChatAuditor:
312
+ def __init__(
313
+ self,
314
+ kbid: str,
315
+ user_id: str,
316
+ client_type: NucliaDBClientType,
317
+ origin: str,
318
+ user_query: str,
319
+ rephrased_query: Optional[str],
320
+ chat_history: list[ChatContextMessage],
321
+ learning_id: Optional[str],
322
+ query_context: PromptContext,
323
+ query_context_order: PromptContextOrder,
324
+ model: str,
325
+ ):
326
+ self.kbid = kbid
327
+ self.user_id = user_id
328
+ self.client_type = client_type
329
+ self.origin = origin
330
+ self.user_query = user_query
331
+ self.rephrased_query = rephrased_query
332
+ self.chat_history = chat_history
333
+ self.learning_id = learning_id
334
+ self.query_context = query_context
335
+ self.query_context_order = query_context_order
336
+ self.model = model
337
+
338
+ def audit(
339
+ self,
340
+ text_answer: bytes,
341
+ generative_answer_time: float,
342
+ generative_answer_first_chunk_time: float,
343
+ rephrase_time: Optional[float],
344
+ status_code: AnswerStatusCode,
345
+ ):
346
+ maybe_audit_chat(
347
+ kbid=self.kbid,
348
+ user_id=self.user_id,
349
+ client_type=self.client_type,
350
+ origin=self.origin,
351
+ user_query=self.user_query,
352
+ rephrased_query=self.rephrased_query,
353
+ generative_answer_time=generative_answer_time,
354
+ generative_answer_first_chunk_time=generative_answer_first_chunk_time,
355
+ rephrase_time=rephrase_time,
356
+ text_answer=text_answer,
357
+ status_code=status_code,
358
+ chat_history=self.chat_history,
359
+ query_context=self.query_context,
360
+ query_context_order=self.query_context_order,
361
+ learning_id=self.learning_id or "unknown",
362
+ model=self.model,
363
+ )
364
+
365
+
366
+ def sorted_prompt_context_list(context: PromptContext, order: PromptContextOrder) -> list[str]:
367
+ """
368
+ context = {"x": "foo", "y": "bar"}
369
+ order = {"y": 1, "x": 0}
370
+ sorted_prompt_context_list(context, order) == ["foo", "bar"]
371
+ """
372
+ sorted_items = sorted(
373
+ context.items(),
374
+ key=lambda item: order.get(item[0], float("inf")),
375
+ )
376
+ return list(map(lambda item: item[1], sorted_items))
377
+
378
+
379
+ async def run_prequeries(
380
+ kbid: str,
381
+ prequeries: list[PreQuery],
382
+ x_ndb_client: NucliaDBClientType,
383
+ x_nucliadb_user: str,
384
+ x_forwarded_for: str,
385
+ generative_model: Optional[str] = None,
386
+ metrics: RAGMetrics = RAGMetrics(),
387
+ ) -> list[PreQueryResult]:
388
+ """
389
+ Runs simultaneous find requests for each prequery and returns the merged results according to the normalized weights.
390
+ """
391
+ results: list[PreQueryResult] = []
392
+ max_parallel_prequeries = asyncio.Semaphore(settings.prequeries_max_parallel)
393
+
394
+ async def _prequery_find(
395
+ prequery: PreQuery,
396
+ ):
397
+ async with max_parallel_prequeries:
398
+ find_results, _, _ = await find(
399
+ kbid,
400
+ prequery.request,
401
+ x_ndb_client,
402
+ x_nucliadb_user,
403
+ x_forwarded_for,
404
+ generative_model=generative_model,
405
+ metrics=metrics,
406
+ )
407
+ return prequery, find_results
408
+
409
+ ops = []
410
+ for prequery in prequeries:
411
+ ops.append(asyncio.create_task(_prequery_find(prequery)))
412
+ ops_results = await asyncio.gather(*ops)
413
+ for prequery, find_results in ops_results:
414
+ results.append((prequery, find_results))
415
+ return results
@@ -18,11 +18,13 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
- from nucliadb.common.maindb.utils import get_driver
22
- from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
21
+ from typing import TypeVar
23
22
 
23
+ T = TypeVar("T")
24
24
 
25
- async def resource_slug_exists(kbid: str, slug: str) -> bool:
26
- driver = get_driver()
27
- async with driver.transaction() as txn:
28
- return await KnowledgeBox.resource_slug_exists(txn, kbid, slug)
25
+
26
+ def cut_page(items: list[T], top_k: int) -> tuple[list[T], bool]:
27
+ """Return a slice of `items` representing the specified page and a boolean
28
+ indicating whether there is a next page or not"""
29
+ next_page = len(items) > top_k
30
+ return items[:top_k], next_page