nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431)
  1. migrations/0002_rollover_shards.py +1 -2
  2. migrations/0003_allfields_key.py +2 -37
  3. migrations/0004_rollover_shards.py +1 -2
  4. migrations/0005_rollover_shards.py +1 -2
  5. migrations/0006_rollover_shards.py +2 -4
  6. migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
  7. migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
  8. migrations/0010_fix_corrupt_indexes.py +11 -12
  9. migrations/0011_materialize_labelset_ids.py +2 -18
  10. migrations/0012_rollover_shards.py +6 -12
  11. migrations/0013_rollover_shards.py +2 -4
  12. migrations/0014_rollover_shards.py +5 -7
  13. migrations/0015_targeted_rollover.py +6 -12
  14. migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
  15. migrations/0017_multiple_writable_shards.py +3 -6
  16. migrations/0018_purge_orphan_kbslugs.py +59 -0
  17. migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
  18. migrations/0020_drain_nodes_from_cluster.py +83 -0
  19. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
  20. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  21. migrations/0023_backfill_pg_catalog.py +80 -0
  22. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  23. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  24. migrations/0027_rollover_texts3.py +73 -0
  25. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  26. migrations/pg/0002_catalog.py +42 -0
  27. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  28. nucliadb/common/cluster/base.py +41 -24
  29. nucliadb/common/cluster/discovery/base.py +6 -14
  30. nucliadb/common/cluster/discovery/k8s.py +9 -19
  31. nucliadb/common/cluster/discovery/manual.py +1 -3
  32. nucliadb/common/cluster/discovery/single.py +1 -2
  33. nucliadb/common/cluster/discovery/utils.py +1 -3
  34. nucliadb/common/cluster/grpc_node_dummy.py +11 -16
  35. nucliadb/common/cluster/index_node.py +10 -19
  36. nucliadb/common/cluster/manager.py +223 -102
  37. nucliadb/common/cluster/rebalance.py +42 -37
  38. nucliadb/common/cluster/rollover.py +377 -204
  39. nucliadb/common/cluster/settings.py +16 -9
  40. nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
  41. nucliadb/common/cluster/standalone/index_node.py +4 -11
  42. nucliadb/common/cluster/standalone/service.py +2 -6
  43. nucliadb/common/cluster/standalone/utils.py +9 -6
  44. nucliadb/common/cluster/utils.py +43 -29
  45. nucliadb/common/constants.py +20 -0
  46. nucliadb/common/context/__init__.py +6 -4
  47. nucliadb/common/context/fastapi.py +8 -5
  48. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  49. nucliadb/common/datamanagers/__init__.py +24 -5
  50. nucliadb/common/datamanagers/atomic.py +102 -0
  51. nucliadb/common/datamanagers/cluster.py +5 -5
  52. nucliadb/common/datamanagers/entities.py +6 -16
  53. nucliadb/common/datamanagers/fields.py +84 -0
  54. nucliadb/common/datamanagers/kb.py +101 -24
  55. nucliadb/common/datamanagers/labels.py +26 -56
  56. nucliadb/common/datamanagers/processing.py +2 -6
  57. nucliadb/common/datamanagers/resources.py +214 -117
  58. nucliadb/common/datamanagers/rollover.py +77 -16
  59. nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
  60. nucliadb/common/datamanagers/utils.py +19 -11
  61. nucliadb/common/datamanagers/vectorsets.py +110 -0
  62. nucliadb/common/external_index_providers/base.py +257 -0
  63. nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
  64. nucliadb/common/external_index_providers/manager.py +101 -0
  65. nucliadb/common/external_index_providers/pinecone.py +933 -0
  66. nucliadb/common/external_index_providers/settings.py +52 -0
  67. nucliadb/common/http_clients/auth.py +3 -6
  68. nucliadb/common/http_clients/processing.py +6 -11
  69. nucliadb/common/http_clients/utils.py +1 -3
  70. nucliadb/common/ids.py +240 -0
  71. nucliadb/common/locking.py +43 -13
  72. nucliadb/common/maindb/driver.py +11 -35
  73. nucliadb/common/maindb/exceptions.py +6 -6
  74. nucliadb/common/maindb/local.py +22 -9
  75. nucliadb/common/maindb/pg.py +206 -111
  76. nucliadb/common/maindb/utils.py +13 -44
  77. nucliadb/common/models_utils/from_proto.py +479 -0
  78. nucliadb/common/models_utils/to_proto.py +60 -0
  79. nucliadb/common/nidx.py +260 -0
  80. nucliadb/export_import/datamanager.py +25 -19
  81. nucliadb/export_import/exceptions.py +8 -0
  82. nucliadb/export_import/exporter.py +20 -7
  83. nucliadb/export_import/importer.py +6 -11
  84. nucliadb/export_import/models.py +5 -5
  85. nucliadb/export_import/tasks.py +4 -4
  86. nucliadb/export_import/utils.py +94 -54
  87. nucliadb/health.py +1 -3
  88. nucliadb/ingest/app.py +15 -11
  89. nucliadb/ingest/consumer/auditing.py +30 -147
  90. nucliadb/ingest/consumer/consumer.py +96 -52
  91. nucliadb/ingest/consumer/materializer.py +10 -12
  92. nucliadb/ingest/consumer/pull.py +12 -27
  93. nucliadb/ingest/consumer/service.py +20 -19
  94. nucliadb/ingest/consumer/shard_creator.py +7 -14
  95. nucliadb/ingest/consumer/utils.py +1 -3
  96. nucliadb/ingest/fields/base.py +139 -188
  97. nucliadb/ingest/fields/conversation.py +18 -5
  98. nucliadb/ingest/fields/exceptions.py +1 -4
  99. nucliadb/ingest/fields/file.py +7 -25
  100. nucliadb/ingest/fields/link.py +11 -16
  101. nucliadb/ingest/fields/text.py +9 -4
  102. nucliadb/ingest/orm/brain.py +255 -262
  103. nucliadb/ingest/orm/broker_message.py +181 -0
  104. nucliadb/ingest/orm/entities.py +36 -51
  105. nucliadb/ingest/orm/exceptions.py +12 -0
  106. nucliadb/ingest/orm/knowledgebox.py +334 -278
  107. nucliadb/ingest/orm/processor/__init__.py +2 -697
  108. nucliadb/ingest/orm/processor/auditing.py +117 -0
  109. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  110. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  111. nucliadb/ingest/orm/processor/processor.py +752 -0
  112. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  113. nucliadb/ingest/orm/resource.py +280 -520
  114. nucliadb/ingest/orm/utils.py +25 -31
  115. nucliadb/ingest/partitions.py +3 -9
  116. nucliadb/ingest/processing.py +76 -81
  117. nucliadb/ingest/py.typed +0 -0
  118. nucliadb/ingest/serialize.py +37 -173
  119. nucliadb/ingest/service/__init__.py +1 -3
  120. nucliadb/ingest/service/writer.py +186 -577
  121. nucliadb/ingest/settings.py +13 -22
  122. nucliadb/ingest/utils.py +3 -6
  123. nucliadb/learning_proxy.py +264 -51
  124. nucliadb/metrics_exporter.py +30 -19
  125. nucliadb/middleware/__init__.py +1 -3
  126. nucliadb/migrator/command.py +1 -3
  127. nucliadb/migrator/datamanager.py +13 -13
  128. nucliadb/migrator/migrator.py +57 -37
  129. nucliadb/migrator/settings.py +2 -1
  130. nucliadb/migrator/utils.py +18 -10
  131. nucliadb/purge/__init__.py +139 -33
  132. nucliadb/purge/orphan_shards.py +7 -13
  133. nucliadb/reader/__init__.py +1 -3
  134. nucliadb/reader/api/models.py +3 -14
  135. nucliadb/reader/api/v1/__init__.py +0 -1
  136. nucliadb/reader/api/v1/download.py +27 -94
  137. nucliadb/reader/api/v1/export_import.py +4 -4
  138. nucliadb/reader/api/v1/knowledgebox.py +13 -13
  139. nucliadb/reader/api/v1/learning_config.py +8 -12
  140. nucliadb/reader/api/v1/resource.py +67 -93
  141. nucliadb/reader/api/v1/services.py +70 -125
  142. nucliadb/reader/app.py +16 -46
  143. nucliadb/reader/lifecycle.py +18 -4
  144. nucliadb/reader/py.typed +0 -0
  145. nucliadb/reader/reader/notifications.py +10 -31
  146. nucliadb/search/__init__.py +1 -3
  147. nucliadb/search/api/v1/__init__.py +2 -2
  148. nucliadb/search/api/v1/ask.py +112 -0
  149. nucliadb/search/api/v1/catalog.py +184 -0
  150. nucliadb/search/api/v1/feedback.py +17 -25
  151. nucliadb/search/api/v1/find.py +41 -41
  152. nucliadb/search/api/v1/knowledgebox.py +90 -62
  153. nucliadb/search/api/v1/predict_proxy.py +2 -2
  154. nucliadb/search/api/v1/resource/ask.py +66 -117
  155. nucliadb/search/api/v1/resource/search.py +51 -72
  156. nucliadb/search/api/v1/router.py +1 -0
  157. nucliadb/search/api/v1/search.py +50 -197
  158. nucliadb/search/api/v1/suggest.py +40 -54
  159. nucliadb/search/api/v1/summarize.py +9 -5
  160. nucliadb/search/api/v1/utils.py +2 -1
  161. nucliadb/search/app.py +16 -48
  162. nucliadb/search/lifecycle.py +10 -3
  163. nucliadb/search/predict.py +176 -188
  164. nucliadb/search/py.typed +0 -0
  165. nucliadb/search/requesters/utils.py +41 -63
  166. nucliadb/search/search/cache.py +149 -20
  167. nucliadb/search/search/chat/ask.py +918 -0
  168. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
  169. nucliadb/search/search/chat/images.py +41 -17
  170. nucliadb/search/search/chat/prompt.py +851 -282
  171. nucliadb/search/search/chat/query.py +274 -267
  172. nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
  173. nucliadb/search/search/fetch.py +43 -36
  174. nucliadb/search/search/filters.py +9 -15
  175. nucliadb/search/search/find.py +214 -54
  176. nucliadb/search/search/find_merge.py +408 -391
  177. nucliadb/search/search/hydrator.py +191 -0
  178. nucliadb/search/search/merge.py +198 -234
  179. nucliadb/search/search/metrics.py +73 -2
  180. nucliadb/search/search/paragraphs.py +64 -106
  181. nucliadb/search/search/pgcatalog.py +233 -0
  182. nucliadb/search/search/predict_proxy.py +1 -1
  183. nucliadb/search/search/query.py +386 -257
  184. nucliadb/search/search/query_parser/exceptions.py +22 -0
  185. nucliadb/search/search/query_parser/models.py +101 -0
  186. nucliadb/search/search/query_parser/parser.py +183 -0
  187. nucliadb/search/search/rank_fusion.py +204 -0
  188. nucliadb/search/search/rerankers.py +270 -0
  189. nucliadb/search/search/shards.py +4 -38
  190. nucliadb/search/search/summarize.py +14 -18
  191. nucliadb/search/search/utils.py +27 -4
  192. nucliadb/search/settings.py +15 -1
  193. nucliadb/standalone/api_router.py +4 -10
  194. nucliadb/standalone/app.py +17 -14
  195. nucliadb/standalone/auth.py +7 -21
  196. nucliadb/standalone/config.py +9 -12
  197. nucliadb/standalone/introspect.py +5 -5
  198. nucliadb/standalone/lifecycle.py +26 -25
  199. nucliadb/standalone/migrations.py +58 -0
  200. nucliadb/standalone/purge.py +9 -8
  201. nucliadb/standalone/py.typed +0 -0
  202. nucliadb/standalone/run.py +25 -18
  203. nucliadb/standalone/settings.py +10 -14
  204. nucliadb/standalone/versions.py +15 -5
  205. nucliadb/tasks/consumer.py +8 -12
  206. nucliadb/tasks/producer.py +7 -6
  207. nucliadb/tests/config.py +53 -0
  208. nucliadb/train/__init__.py +1 -3
  209. nucliadb/train/api/utils.py +1 -2
  210. nucliadb/train/api/v1/shards.py +2 -2
  211. nucliadb/train/api/v1/trainset.py +4 -6
  212. nucliadb/train/app.py +14 -47
  213. nucliadb/train/generator.py +10 -19
  214. nucliadb/train/generators/field_classifier.py +7 -19
  215. nucliadb/train/generators/field_streaming.py +156 -0
  216. nucliadb/train/generators/image_classifier.py +12 -18
  217. nucliadb/train/generators/paragraph_classifier.py +5 -9
  218. nucliadb/train/generators/paragraph_streaming.py +6 -9
  219. nucliadb/train/generators/question_answer_streaming.py +19 -20
  220. nucliadb/train/generators/sentence_classifier.py +9 -15
  221. nucliadb/train/generators/token_classifier.py +45 -36
  222. nucliadb/train/generators/utils.py +14 -18
  223. nucliadb/train/lifecycle.py +7 -3
  224. nucliadb/train/nodes.py +23 -32
  225. nucliadb/train/py.typed +0 -0
  226. nucliadb/train/servicer.py +13 -21
  227. nucliadb/train/settings.py +2 -6
  228. nucliadb/train/types.py +13 -10
  229. nucliadb/train/upload.py +3 -6
  230. nucliadb/train/uploader.py +20 -25
  231. nucliadb/train/utils.py +1 -1
  232. nucliadb/writer/__init__.py +1 -3
  233. nucliadb/writer/api/constants.py +0 -5
  234. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  235. nucliadb/writer/api/v1/export_import.py +102 -49
  236. nucliadb/writer/api/v1/field.py +196 -620
  237. nucliadb/writer/api/v1/knowledgebox.py +221 -71
  238. nucliadb/writer/api/v1/learning_config.py +2 -2
  239. nucliadb/writer/api/v1/resource.py +114 -216
  240. nucliadb/writer/api/v1/services.py +64 -132
  241. nucliadb/writer/api/v1/slug.py +61 -0
  242. nucliadb/writer/api/v1/transaction.py +67 -0
  243. nucliadb/writer/api/v1/upload.py +184 -215
  244. nucliadb/writer/app.py +11 -61
  245. nucliadb/writer/back_pressure.py +62 -43
  246. nucliadb/writer/exceptions.py +0 -4
  247. nucliadb/writer/lifecycle.py +21 -15
  248. nucliadb/writer/py.typed +0 -0
  249. nucliadb/writer/resource/audit.py +2 -1
  250. nucliadb/writer/resource/basic.py +48 -62
  251. nucliadb/writer/resource/field.py +45 -135
  252. nucliadb/writer/resource/origin.py +1 -2
  253. nucliadb/writer/settings.py +14 -5
  254. nucliadb/writer/tus/__init__.py +17 -15
  255. nucliadb/writer/tus/azure.py +111 -0
  256. nucliadb/writer/tus/dm.py +17 -5
  257. nucliadb/writer/tus/exceptions.py +1 -3
  258. nucliadb/writer/tus/gcs.py +56 -84
  259. nucliadb/writer/tus/local.py +21 -37
  260. nucliadb/writer/tus/s3.py +28 -68
  261. nucliadb/writer/tus/storage.py +5 -56
  262. nucliadb/writer/vectorsets.py +125 -0
  263. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  264. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  265. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  266. nucliadb/common/maindb/redis.py +0 -194
  267. nucliadb/common/maindb/tikv.py +0 -412
  268. nucliadb/ingest/fields/layout.py +0 -58
  269. nucliadb/ingest/tests/conftest.py +0 -30
  270. nucliadb/ingest/tests/fixtures.py +0 -771
  271. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  272. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
  273. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
  274. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  275. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  276. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  277. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
  278. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  279. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  280. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  281. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
  282. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  283. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  284. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
  285. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  286. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  287. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  288. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  289. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
  290. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  291. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  292. nucliadb/middleware/transaction.py +0 -117
  293. nucliadb/reader/api/v1/learning_collector.py +0 -63
  294. nucliadb/reader/tests/__init__.py +0 -19
  295. nucliadb/reader/tests/conftest.py +0 -31
  296. nucliadb/reader/tests/fixtures.py +0 -136
  297. nucliadb/reader/tests/test_list_resources.py +0 -75
  298. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  299. nucliadb/reader/tests/test_reader_resource.py +0 -379
  300. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  301. nucliadb/search/api/v1/chat.py +0 -258
  302. nucliadb/search/api/v1/resource/chat.py +0 -94
  303. nucliadb/search/tests/__init__.py +0 -19
  304. nucliadb/search/tests/conftest.py +0 -33
  305. nucliadb/search/tests/fixtures.py +0 -199
  306. nucliadb/search/tests/node.py +0 -465
  307. nucliadb/search/tests/unit/__init__.py +0 -18
  308. nucliadb/search/tests/unit/api/__init__.py +0 -19
  309. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  310. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  311. nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
  312. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
  313. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  314. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  315. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
  316. nucliadb/search/tests/unit/search/__init__.py +0 -18
  317. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  318. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
  319. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  320. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  321. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  322. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
  323. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  324. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  325. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  326. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  327. nucliadb/search/tests/unit/search/test_query.py +0 -201
  328. nucliadb/search/tests/unit/test_app.py +0 -79
  329. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  330. nucliadb/search/tests/unit/test_merge.py +0 -34
  331. nucliadb/search/tests/unit/test_predict.py +0 -584
  332. nucliadb/standalone/tests/__init__.py +0 -19
  333. nucliadb/standalone/tests/conftest.py +0 -33
  334. nucliadb/standalone/tests/fixtures.py +0 -38
  335. nucliadb/standalone/tests/unit/__init__.py +0 -18
  336. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  337. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  338. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  339. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  340. nucliadb/tests/benchmarks/__init__.py +0 -19
  341. nucliadb/tests/benchmarks/test_search.py +0 -99
  342. nucliadb/tests/conftest.py +0 -32
  343. nucliadb/tests/fixtures.py +0 -736
  344. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
  345. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
  346. nucliadb/tests/migrations/__init__.py +0 -19
  347. nucliadb/tests/migrations/test_migration_0017.py +0 -80
  348. nucliadb/tests/tikv.py +0 -240
  349. nucliadb/tests/unit/__init__.py +0 -19
  350. nucliadb/tests/unit/common/__init__.py +0 -19
  351. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  352. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  353. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
  354. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  355. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
  356. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
  357. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
  358. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
  359. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
  360. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  361. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  362. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  363. nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
  364. nucliadb/tests/unit/common/test_context.py +0 -36
  365. nucliadb/tests/unit/export_import/__init__.py +0 -19
  366. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  367. nucliadb/tests/unit/export_import/test_utils.py +0 -294
  368. nucliadb/tests/unit/migrator/__init__.py +0 -19
  369. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  370. nucliadb/tests/unit/tasks/__init__.py +0 -19
  371. nucliadb/tests/unit/tasks/conftest.py +0 -42
  372. nucliadb/tests/unit/tasks/test_consumer.py +0 -93
  373. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  374. nucliadb/tests/unit/tasks/test_tasks.py +0 -60
  375. nucliadb/tests/unit/test_field_ids.py +0 -49
  376. nucliadb/tests/unit/test_health.py +0 -84
  377. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  378. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  379. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  380. nucliadb/tests/unit/test_purge.py +0 -138
  381. nucliadb/tests/utils/__init__.py +0 -74
  382. nucliadb/tests/utils/aiohttp_session.py +0 -44
  383. nucliadb/tests/utils/broker_messages/__init__.py +0 -167
  384. nucliadb/tests/utils/broker_messages/fields.py +0 -181
  385. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  386. nucliadb/tests/utils/entities.py +0 -78
  387. nucliadb/train/api/v1/check.py +0 -60
  388. nucliadb/train/tests/__init__.py +0 -19
  389. nucliadb/train/tests/conftest.py +0 -29
  390. nucliadb/train/tests/fixtures.py +0 -342
  391. nucliadb/train/tests/test_field_classification.py +0 -122
  392. nucliadb/train/tests/test_get_entities.py +0 -80
  393. nucliadb/train/tests/test_get_info.py +0 -51
  394. nucliadb/train/tests/test_get_ontology.py +0 -34
  395. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  396. nucliadb/train/tests/test_image_classification.py +0 -222
  397. nucliadb/train/tests/test_list_fields.py +0 -39
  398. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  399. nucliadb/train/tests/test_list_resources.py +0 -39
  400. nucliadb/train/tests/test_list_sentences.py +0 -71
  401. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  402. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  403. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  404. nucliadb/train/tests/test_sentence_classification.py +0 -143
  405. nucliadb/train/tests/test_token_classification.py +0 -136
  406. nucliadb/train/tests/utils.py +0 -108
  407. nucliadb/writer/layouts/__init__.py +0 -51
  408. nucliadb/writer/layouts/v1.py +0 -59
  409. nucliadb/writer/resource/vectors.py +0 -120
  410. nucliadb/writer/tests/__init__.py +0 -19
  411. nucliadb/writer/tests/conftest.py +0 -31
  412. nucliadb/writer/tests/fixtures.py +0 -192
  413. nucliadb/writer/tests/test_fields.py +0 -486
  414. nucliadb/writer/tests/test_files.py +0 -743
  415. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  416. nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
  417. nucliadb/writer/tests/test_resources.py +0 -546
  418. nucliadb/writer/tests/test_service.py +0 -137
  419. nucliadb/writer/tests/test_tus.py +0 -203
  420. nucliadb/writer/tests/utils.py +0 -35
  421. nucliadb/writer/tus/pg.py +0 -125
  422. nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
  423. nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
  424. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  425. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  426. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  427. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  428. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  429. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  430. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  431. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -19,31 +19,37 @@
19
19
  #
20
20
  import json
21
21
  import os
22
+ import random
22
23
  from enum import Enum
23
- from typing import AsyncIterator, Optional
24
+ from typing import Any, AsyncIterator, Optional
24
25
  from unittest.mock import AsyncMock, Mock
25
26
 
26
27
  import aiohttp
27
28
  import backoff
28
- from nucliadb_protos.utils_pb2 import RelationNode
29
+ from nuclia_models.predict.generative_responses import GenerativeChunk
30
+ from pydantic import ValidationError
29
31
 
30
- from nucliadb.ingest.tests.vectors import Q, Qm2023
32
+ from nucliadb.common import datamanagers
31
33
  from nucliadb.search import logger
32
- from nucliadb_models.search import (
33
- AskDocumentModel,
34
- ChatModel,
35
- FeedbackRequest,
34
+ from nucliadb.tests.vectors import Q, Qm2023
35
+ from nucliadb_models.internal.predict import (
36
36
  Ner,
37
37
  QueryInfo,
38
- RephraseModel,
38
+ RerankModel,
39
+ RerankResponse,
39
40
  SentenceSearch,
41
+ TokenSearch,
42
+ )
43
+ from nucliadb_models.search import (
44
+ ChatModel,
45
+ RephraseModel,
40
46
  SummarizedResource,
41
47
  SummarizedResponse,
42
48
  SummarizeModel,
43
- TokenSearch,
44
49
  )
45
- from nucliadb_telemetry import metrics
46
- from nucliadb_utils import const
50
+ from nucliadb_protos.utils_pb2 import RelationNode
51
+ from nucliadb_telemetry import errors, metrics
52
+ from nucliadb_utils.const import Features
47
53
  from nucliadb_utils.exceptions import LimitsExceededError
48
54
  from nucliadb_utils.settings import nuclia_settings
49
55
  from nucliadb_utils.utilities import Utility, has_feature, set_utility
@@ -59,10 +65,6 @@ class ProxiedPredictAPIError(Exception):
59
65
  self.detail = detail
60
66
 
61
67
 
62
- class PredictVectorMissing(Exception):
63
- pass
64
-
65
-
66
68
  class NUAKeyMissingError(Exception):
67
69
  pass
68
70
 
@@ -77,13 +79,12 @@ class RephraseMissingContextError(Exception):
77
79
 
78
80
  DUMMY_RELATION_NODE = [
79
81
  RelationNode(value="Ferran", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
80
- RelationNode(
81
- value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"
82
- ),
82
+ RelationNode(value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
83
83
  ]
84
84
 
85
85
  DUMMY_REPHRASE_QUERY = "This is a rephrased query"
86
86
  DUMMY_LEARNING_ID = "00"
87
+ DUMMY_LEARNING_MODEL = "chatgpt"
87
88
 
88
89
 
89
90
  PUBLIC_PREDICT = "/api/v1/predict"
@@ -94,11 +95,12 @@ TOKENS = "/tokens"
94
95
  QUERY = "/query"
95
96
  SUMMARIZE = "/summarize"
96
97
  CHAT = "/chat"
97
- ASK_DOCUMENT = "/ask_document"
98
98
  REPHRASE = "/rephrase"
99
99
  FEEDBACK = "/feedback"
100
+ RERANK = "/rerank"
100
101
 
101
102
  NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
103
+ NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"
102
104
 
103
105
 
104
106
  predict_observer = metrics.Observer(
@@ -107,7 +109,6 @@ predict_observer = metrics.Observer(
107
109
  error_mappings={
108
110
  "over_limits": LimitsExceededError,
109
111
  "predict_api_error": SendToPredictError,
110
- "empty_vectors": PredictVectorMissing,
111
112
  },
112
113
  )
113
114
 
@@ -121,6 +122,13 @@ class AnswerStatusCode(str, Enum):
121
122
  ERROR = "-1"
122
123
  NO_CONTEXT = "-2"
123
124
 
125
+ def prettify(self) -> str:
126
+ return {
127
+ AnswerStatusCode.SUCCESS: "success",
128
+ AnswerStatusCode.ERROR: "error",
129
+ AnswerStatusCode.NO_CONTEXT: "no_context",
130
+ }[self]
131
+
124
132
 
125
133
  async def start_predict_engine():
126
134
  if nuclia_settings.dummy_predict:
@@ -144,9 +152,7 @@ def convert_relations(data: dict[str, list[dict[str, str]]]) -> list[RelationNod
144
152
  for token in data["tokens"]:
145
153
  text = token["text"]
146
154
  klass = token["ner"]
147
- result.append(
148
- RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass)
149
- )
155
+ result.append(RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass))
150
156
  return result
151
157
 
152
158
 
@@ -179,9 +185,7 @@ class PredictEngine:
179
185
  await self.session.close()
180
186
 
181
187
  def check_nua_key_is_configured_for_onprem(self):
182
- if self.onprem and (
183
- self.nuclia_service_account is None and self.local_predict is False
184
- ):
188
+ if self.onprem and (self.nuclia_service_account is None and self.local_predict is False):
185
189
  raise NUAKeyMissingError()
186
190
 
187
191
  def get_predict_url(self, endpoint: str, kbid: str) -> str:
@@ -193,7 +197,7 @@ class PredictEngine:
193
197
  # /api/v1/predict/rephrase/{kbid}
194
198
  return f"{self.public_url}{PUBLIC_PREDICT}{endpoint}/{kbid}"
195
199
  else:
196
- if has_feature(const.Features.VERSIONED_PRIVATE_PREDICT):
200
+ if has_feature(Features.VERSIONED_PRIVATE_PREDICT):
197
201
  return f"{self.cluster_url}{VERSIONED_PRIVATE_PREDICT}{endpoint}"
198
202
  else:
199
203
  return f"{self.cluster_url}{PRIVATE_PREDICT}{endpoint}"
@@ -207,16 +211,13 @@ class PredictEngine:
207
211
  else:
208
212
  return {"X-STF-KBID": kbid}
209
213
 
210
- async def check_response(
211
- self, resp: aiohttp.ClientResponse, expected_status: int = 200
212
- ) -> None:
214
+ async def check_response(self, resp: aiohttp.ClientResponse, expected_status: int = 200) -> None:
213
215
  if resp.status == expected_status:
214
216
  return
215
217
 
216
218
  if resp.status == 402:
217
219
  data = await resp.json()
218
220
  raise LimitsExceededError(402, data["detail"])
219
-
220
221
  try:
221
222
  data = await resp.json()
222
223
  try:
@@ -228,7 +229,10 @@ class PredictEngine:
228
229
  aiohttp.client_exceptions.ContentTypeError,
229
230
  ):
230
231
  detail = await resp.text()
231
- logger.error(f"Predict API error at {resp.url}: {detail}")
232
+ if str(resp.status).startswith("5"):
233
+ logger.error(f"Predict API error at {resp.url}: {detail}")
234
+ else:
235
+ logger.info(f"Predict API error at {resp.url}: {detail}")
232
236
  raise ProxiedPredictAPIError(status=resp.status, detail=detail)
233
237
 
234
238
  @backoff.on_exception(
@@ -241,36 +245,6 @@ class PredictEngine:
241
245
  func = getattr(self.session, method.lower())
242
246
  return await func(**request_args)
243
247
 
244
- @predict_observer.wrap({"type": "feedback"})
245
- async def send_feedback(
246
- self,
247
- kbid: str,
248
- item: FeedbackRequest,
249
- x_nucliadb_user: str,
250
- x_ndb_client: str,
251
- x_forwarded_for: str,
252
- ):
253
- try:
254
- self.check_nua_key_is_configured_for_onprem()
255
- except NUAKeyMissingError:
256
- logger.warning(
257
- "Nuclia Service account is not defined so could not send the feedback"
258
- )
259
- return
260
-
261
- data = item.dict()
262
- data["user_id"] = x_nucliadb_user
263
- data["client"] = x_ndb_client
264
- data["forwarded"] = x_forwarded_for
265
-
266
- resp = await self.make_request(
267
- "POST",
268
- url=self.get_predict_url(FEEDBACK, kbid),
269
- json=data,
270
- headers=self.get_predict_headers(kbid),
271
- )
272
- await self.check_response(resp, expected_status=204)
273
-
274
248
  @predict_observer.wrap({"type": "rephrase"})
275
249
  async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
276
250
  try:
@@ -283,16 +257,20 @@ class PredictEngine:
283
257
  resp = await self.make_request(
284
258
  "POST",
285
259
  url=self.get_predict_url(REPHRASE, kbid),
286
- json=item.dict(),
260
+ json=item.model_dump(),
287
261
  headers=self.get_predict_headers(kbid),
288
262
  )
289
263
  await self.check_response(resp, expected_status=200)
290
264
  return await _parse_rephrase_response(resp)
291
265
 
292
- @predict_observer.wrap({"type": "chat"})
293
- async def chat_query(
266
+ @predict_observer.wrap({"type": "chat_ndjson"})
267
+ async def chat_query_ndjson(
294
268
  self, kbid: str, item: ChatModel
295
- ) -> tuple[str, AsyncIterator[bytes]]:
269
+ ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
270
+ """
271
+ Chat query using the new stream format
272
+ Format specs: https://github.com/ndjson/ndjson-spec
273
+ """
296
274
  try:
297
275
  self.check_nua_key_is_configured_for_onprem()
298
276
  except NUAKeyMissingError:
@@ -300,60 +278,62 @@ class PredictEngine:
300
278
  logger.warning(error)
301
279
  raise SendToPredictError(error)
302
280
 
281
+ # The ndjson format is triggered by the Accept header
282
+ headers = self.get_predict_headers(kbid)
283
+ headers["Accept"] = "application/x-ndjson"
284
+
303
285
  resp = await self.make_request(
304
286
  "POST",
305
287
  url=self.get_predict_url(CHAT, kbid),
306
- json=item.dict(),
307
- headers=self.get_predict_headers(kbid),
288
+ json=item.model_dump(),
289
+ headers=headers,
308
290
  timeout=None,
309
291
  )
310
292
  await self.check_response(resp, expected_status=200)
311
293
  ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
312
- return ident, get_answer_generator(resp)
313
-
314
- @predict_observer.wrap({"type": "ask_document"})
315
- async def ask_document(
316
- self, kbid: str, question: str, blocks: list[list[str]], user_id: str
317
- ) -> str:
318
- try:
319
- self.check_nua_key_is_configured_for_onprem()
320
- except NUAKeyMissingError:
321
- error = "Nuclia Service account is not defined so could not ask document"
322
- logger.warning(error)
323
- raise SendToPredictError(error)
324
-
325
- item = AskDocumentModel(question=question, blocks=blocks, user_id=user_id)
326
- resp = await self.make_request(
327
- "POST",
328
- url=self.get_predict_url(ASK_DOCUMENT, kbid),
329
- json=item.dict(),
330
- headers=self.get_predict_headers(kbid),
331
- timeout=None,
332
- )
333
- await self.check_response(resp, expected_status=200)
334
- return await resp.text()
294
+ model = resp.headers.get(NUCLIA_LEARNING_MODEL_HEADER)
295
+ return ident, model, get_chat_ndjson_generator(resp)
335
296
 
336
297
  @predict_observer.wrap({"type": "query"})
337
298
  async def query(
338
299
  self,
339
300
  kbid: str,
340
301
  sentence: str,
302
+ semantic_model: Optional[str] = None,
341
303
  generative_model: Optional[str] = None,
342
- rephrase: Optional[bool] = False,
304
+ rephrase: bool = False,
305
+ rephrase_prompt: Optional[str] = None,
343
306
  ) -> QueryInfo:
307
+ """
308
+ Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance:
309
+ - The embeddings
310
+ - The entities
311
+ - The stop words
312
+ - The semantic threshold
313
+ - etc.
314
+
315
+ :param kbid: KnowledgeBox ID
316
+ :param sentence: The query sentence
317
+ :param semantic_model: The semantic model to use to generate the embeddings
318
+ :param generative_model: The generative model that will be used to generate the answer
319
+ :param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval
320
+ :param rephrase_prompt: Custom prompt to use for rephrasing
321
+ """
344
322
  try:
345
323
  self.check_nua_key_is_configured_for_onprem()
346
324
  except NUAKeyMissingError:
347
- error = (
348
- "Nuclia Service account is not defined so could not ask query endpoint"
349
- )
325
+ error = "Nuclia Service account is not defined so could not ask query endpoint"
350
326
  logger.warning(error)
351
327
  raise SendToPredictError(error)
352
328
 
353
- params = {
329
+ params: dict[str, Any] = {
354
330
  "text": sentence,
355
331
  "rephrase": str(rephrase),
356
332
  }
333
+ if rephrase_prompt is not None:
334
+ params["rephrase_prompt"] = rephrase_prompt
335
+ if semantic_model is not None:
336
+ params["semantic_models"] = [semantic_model]
357
337
  if generative_model is not None:
358
338
  params["generative_model"] = generative_model
359
339
 
@@ -367,28 +347,6 @@ class PredictEngine:
367
347
  data = await resp.json()
368
348
  return QueryInfo(**data)
369
349
 
370
- @predict_observer.wrap({"type": "sentence"})
371
- async def convert_sentence_to_vector(self, kbid: str, sentence: str) -> list[float]:
372
- try:
373
- self.check_nua_key_is_configured_for_onprem()
374
- except NUAKeyMissingError:
375
- logger.warning(
376
- "Nuclia Service account is not defined so could not retrieve vectors for the query"
377
- )
378
- return []
379
-
380
- resp = await self.make_request(
381
- "GET",
382
- url=self.get_predict_url(SENTENCE, kbid),
383
- params={"text": sentence},
384
- headers=self.get_predict_headers(kbid),
385
- )
386
- await self.check_response(resp, expected_status=200)
387
- data = await resp.json()
388
- if len(data["data"]) == 0:
389
- raise PredictVectorMissing()
390
- return data["data"]
391
-
392
350
  @predict_observer.wrap({"type": "entities"})
393
351
  async def detect_entities(self, kbid: str, sentence: str) -> list[RelationNode]:
394
352
  try:
@@ -420,26 +378,46 @@ class PredictEngine:
420
378
  resp = await self.make_request(
421
379
  "POST",
422
380
  url=self.get_predict_url(SUMMARIZE, kbid),
423
- json=item.dict(),
381
+ json=item.model_dump(),
424
382
  headers=self.get_predict_headers(kbid),
425
383
  timeout=None,
426
384
  )
427
385
  await self.check_response(resp, expected_status=200)
428
386
  data = await resp.json()
429
- return SummarizedResponse.parse_obj(data)
387
+ return SummarizedResponse.model_validate(data)
388
+
389
+ @predict_observer.wrap({"type": "rerank"})
390
+ async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
391
+ try:
392
+ self.check_nua_key_is_configured_for_onprem()
393
+ except NUAKeyMissingError:
394
+ error = "Nuclia Service account is not defined. Rerank operation could not be performed"
395
+ logger.warning(error)
396
+ raise SendToPredictError(error)
397
+ resp = await self.make_request(
398
+ "POST",
399
+ url=self.get_predict_url(RERANK, kbid),
400
+ json=item.model_dump(),
401
+ headers=self.get_predict_headers(kbid),
402
+ )
403
+ await self.check_response(resp, expected_status=200)
404
+ data = await resp.json()
405
+ return RerankResponse.model_validate(data)
430
406
 
431
407
 
432
408
  class DummyPredictEngine(PredictEngine):
409
+ default_semantic_threshold = 0.7
410
+
433
411
  def __init__(self):
434
412
  self.onprem = True
435
413
  self.cluster_url = "http://localhost:8000"
436
414
  self.public_url = "http://localhost:8000"
437
415
  self.calls = []
438
- self.generated_answer = [
439
- b"valid ",
440
- b"answer ",
441
- b" to",
442
- AnswerStatusCode.SUCCESS.encode(),
416
+ self.ndjson_answer = [
417
+ b'{"chunk": {"type": "text", "text": "valid "}}\n',
418
+ b'{"chunk": {"type": "text", "text": "answer "}}\n',
419
+ b'{"chunk": {"type": "text", "text": "to"}}\n',
420
+ b'{"chunk": {"type": "status", "code": "0"}}\n',
443
421
  ]
444
422
  self.max_context = 1000
445
423
 
@@ -458,84 +436,72 @@ class DummyPredictEngine(PredictEngine):
458
436
  response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
459
437
  return response
460
438
 
461
- async def send_feedback(
462
- self,
463
- kbid: str,
464
- item: FeedbackRequest,
465
- x_nucliadb_user: str,
466
- x_ndb_client: str,
467
- x_forwarded_for: str,
468
- ):
469
- self.calls.append(("send_feedback", item))
470
- return
471
-
472
439
  async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
473
440
  self.calls.append(("rephrase_query", item))
474
441
  return DUMMY_REPHRASE_QUERY
475
442
 
476
- async def chat_query(
443
+ async def chat_query_ndjson(
477
444
  self, kbid: str, item: ChatModel
478
- ) -> tuple[str, AsyncIterator[bytes]]:
479
- self.calls.append(("chat_query", item))
445
+ ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
446
+ self.calls.append(("chat_query_ndjson", item))
480
447
 
481
448
  async def generate():
482
- for i in self.generated_answer:
483
- yield i
484
-
485
- return (DUMMY_LEARNING_ID, generate())
449
+ for item in self.ndjson_answer:
450
+ yield GenerativeChunk.model_validate_json(item)
486
451
 
487
- async def ask_document(
488
- self, kbid: str, query: str, blocks: list[list[str]], user_id: str
489
- ) -> str:
490
- self.calls.append(("ask_document", (query, blocks, user_id)))
491
- answer = os.environ.get("TEST_ASK_DOCUMENT") or "Answer to your question"
492
- return answer
452
+ return (DUMMY_LEARNING_ID, DUMMY_LEARNING_MODEL, generate())
493
453
 
494
454
  async def query(
495
455
  self,
496
456
  kbid: str,
497
457
  sentence: str,
458
+ semantic_model: Optional[str] = None,
498
459
  generative_model: Optional[str] = None,
499
- rephrase: Optional[bool] = False,
460
+ rephrase: bool = False,
461
+ rephrase_prompt: Optional[str] = None,
500
462
  ) -> QueryInfo:
501
463
  self.calls.append(("query", sentence))
502
- if (
503
- os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21"
504
- ): # pragma: no cover
505
- return QueryInfo(
506
- language="en",
507
- stop_words=[],
508
- semantic_threshold=0.7,
509
- visual_llm=True,
510
- max_context=self.max_context,
511
- entities=TokenSearch(
512
- tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
513
- ),
514
- sentence=SentenceSearch(data=Qm2023, time=0.0),
515
- query=sentence,
516
- )
517
- else:
518
- return QueryInfo(
519
- language="en",
520
- stop_words=[],
521
- semantic_threshold=0.7,
522
- visual_llm=True,
523
- max_context=self.max_context,
524
- entities=TokenSearch(
525
- tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
526
- ),
527
- sentence=SentenceSearch(data=Q, time=0.0),
528
- query=sentence,
529
- )
530
464
 
531
- async def convert_sentence_to_vector(self, kbid: str, sentence: str) -> list[float]:
532
- self.calls.append(("convert_sentence_to_vector", sentence))
533
- if (
534
- os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21"
535
- ): # pragma: no cover
536
- return Qm2023
465
+ if os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21": # pragma: no cover
466
+ base_vector = Qm2023
537
467
  else:
538
- return Q
468
+ base_vector = Q
469
+
470
+ # populate data with existing vectorsets
471
+ async with datamanagers.with_ro_transaction() as txn:
472
+ semantic_thresholds = {}
473
+ vectors = {}
474
+ timings = {}
475
+ async for vectorset_id, config in datamanagers.vectorsets.iter(txn, kbid=kbid):
476
+ semantic_thresholds[vectorset_id] = self.default_semantic_threshold
477
+ vectorset_dimension = config.vectorset_index_config.vector_dimension
478
+ if vectorset_dimension > len(base_vector):
479
+ padding = vectorset_dimension - len(base_vector)
480
+ vectors[vectorset_id] = base_vector + [random.random()] * padding
481
+ else:
482
+ vectors[vectorset_id] = base_vector[:vectorset_dimension]
483
+
484
+ timings[vectorset_id] = 0.010
485
+
486
+ # and fake data with the passed one too
487
+ model = semantic_model or "<PREDICT-DEFAULT-SEMANTIC-MODEL>"
488
+ semantic_thresholds[model] = self.default_semantic_threshold
489
+ vectors[model] = base_vector
490
+ timings[model] = 0.0
491
+
492
+ return QueryInfo(
493
+ language="en",
494
+ stop_words=[],
495
+ semantic_thresholds=semantic_thresholds,
496
+ visual_llm=True,
497
+ max_context=self.max_context,
498
+ entities=TokenSearch(tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0),
499
+ sentence=SentenceSearch(
500
+ vectors=vectors,
501
+ timings=timings,
502
+ ),
503
+ query=sentence,
504
+ )
539
505
 
540
506
  async def detect_entities(self, kbid: str, sentence: str) -> list[RelationNode]:
541
507
  self.calls.append(("detect_entities", sentence))
@@ -554,9 +520,16 @@ class DummyPredictEngine(PredictEngine):
554
520
  rsummary = []
555
521
  for field_id, field_text in item.resources[rid].fields.items():
556
522
  rsummary.append(f"{field_id}: {field_text}")
557
- response.resources[rid] = SummarizedResource(
558
- summary="\n\n".join(rsummary), tokens=10
559
- )
523
+ response.resources[rid] = SummarizedResource(summary="\n\n".join(rsummary), tokens=10)
524
+ return response
525
+
526
+ async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
527
+ self.calls.append(("rerank", (kbid, item)))
528
+ # as we don't have information about the retrieval scores, return a
529
+ # random score given by the dict iteration
530
+ response = RerankResponse(
531
+ context_scores={paragraph_id: i for i, paragraph_id in enumerate(item.context.keys())}
532
+ )
560
533
  return response
561
534
 
562
535
 
@@ -578,6 +551,21 @@ def get_answer_generator(response: aiohttp.ClientResponse):
578
551
  return _iter_answer_chunks(response.content.iter_chunks())
579
552
 
580
553
 
554
+ def get_chat_ndjson_generator(
555
+ response: aiohttp.ClientResponse,
556
+ ) -> AsyncIterator[GenerativeChunk]:
557
+ async def _parse_generative_chunks(gen):
558
+ async for chunk in gen:
559
+ try:
560
+ yield GenerativeChunk.model_validate_json(chunk.strip())
561
+ except ValidationError as ex:
562
+ errors.capture_exception(ex)
563
+ logger.error(f"Invalid chunk received: {chunk}")
564
+ continue
565
+
566
+ return _parse_generative_chunks(response.content)
567
+
568
+
581
569
  async def _parse_rephrase_response(
582
570
  resp: aiohttp.ClientResponse,
583
571
  ) -> str:
File without changes