nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2798__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418)
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -403
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +70 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +143 -117
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +37 -128
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2798.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2798.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/zip-safe +0 -0
@@ -19,31 +19,37 @@
19
19
  #
20
20
  import json
21
21
  import os
22
+ import random
22
23
  from enum import Enum
23
- from typing import Any, AsyncIterator, Literal, Optional, Union
24
+ from typing import Any, AsyncIterator, Optional
24
25
  from unittest.mock import AsyncMock, Mock
25
26
 
26
27
  import aiohttp
27
28
  import backoff
28
- from nucliadb_protos.utils_pb2 import RelationNode
29
- from pydantic import BaseModel, Field, ValidationError
29
+ from nuclia_models.predict.generative_responses import GenerativeChunk
30
+ from pydantic import ValidationError
30
31
 
31
- from nucliadb.ingest.tests.vectors import Q, Qm2023
32
+ from nucliadb.common import datamanagers
32
33
  from nucliadb.search import logger
33
- from nucliadb_models.search import (
34
- ChatModel,
35
- FeedbackRequest,
34
+ from nucliadb.tests.vectors import Q, Qm2023
35
+ from nucliadb_models.internal.predict import (
36
36
  Ner,
37
37
  QueryInfo,
38
- RephraseModel,
38
+ RerankModel,
39
+ RerankResponse,
39
40
  SentenceSearch,
41
+ TokenSearch,
42
+ )
43
+ from nucliadb_models.search import (
44
+ ChatModel,
45
+ RephraseModel,
40
46
  SummarizedResource,
41
47
  SummarizedResponse,
42
48
  SummarizeModel,
43
- TokenSearch,
44
49
  )
50
+ from nucliadb_protos.utils_pb2 import RelationNode
45
51
  from nucliadb_telemetry import errors, metrics
46
- from nucliadb_utils import const
52
+ from nucliadb_utils.const import Features
47
53
  from nucliadb_utils.exceptions import LimitsExceededError
48
54
  from nucliadb_utils.settings import nuclia_settings
49
55
  from nucliadb_utils.utilities import Utility, has_feature, set_utility
@@ -73,13 +79,12 @@ class RephraseMissingContextError(Exception):
73
79
 
74
80
  DUMMY_RELATION_NODE = [
75
81
  RelationNode(value="Ferran", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
76
- RelationNode(
77
- value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"
78
- ),
82
+ RelationNode(value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
79
83
  ]
80
84
 
81
85
  DUMMY_REPHRASE_QUERY = "This is a rephrased query"
82
86
  DUMMY_LEARNING_ID = "00"
87
+ DUMMY_LEARNING_MODEL = "chatgpt"
83
88
 
84
89
 
85
90
  PUBLIC_PREDICT = "/api/v1/predict"
@@ -92,8 +97,10 @@ SUMMARIZE = "/summarize"
92
97
  CHAT = "/chat"
93
98
  REPHRASE = "/rephrase"
94
99
  FEEDBACK = "/feedback"
100
+ RERANK = "/rerank"
95
101
 
96
102
  NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
103
+ NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"
97
104
 
98
105
 
99
106
  predict_observer = metrics.Observer(
@@ -123,41 +130,6 @@ class AnswerStatusCode(str, Enum):
123
130
  }[self]
124
131
 
125
132
 
126
- class TextGenerativeResponse(BaseModel):
127
- type: Literal["text"] = "text"
128
- text: str
129
-
130
-
131
- class MetaGenerativeResponse(BaseModel):
132
- type: Literal["meta"] = "meta"
133
- input_tokens: int
134
- output_tokens: int
135
- timings: dict[str, float]
136
-
137
-
138
- class CitationsGenerativeResponse(BaseModel):
139
- type: Literal["citations"] = "citations"
140
- citations: dict[str, Any]
141
-
142
-
143
- class StatusGenerativeResponse(BaseModel):
144
- type: Literal["status"] = "status"
145
- code: str
146
- details: Optional[str] = None
147
-
148
-
149
- GenerativeResponse = Union[
150
- TextGenerativeResponse,
151
- MetaGenerativeResponse,
152
- CitationsGenerativeResponse,
153
- StatusGenerativeResponse,
154
- ]
155
-
156
-
157
- class GenerativeChunk(BaseModel):
158
- chunk: GenerativeResponse = Field(..., discriminator="type")
159
-
160
-
161
133
  async def start_predict_engine():
162
134
  if nuclia_settings.dummy_predict:
163
135
  predict_util = DummyPredictEngine()
@@ -180,9 +152,7 @@ def convert_relations(data: dict[str, list[dict[str, str]]]) -> list[RelationNod
180
152
  for token in data["tokens"]:
181
153
  text = token["text"]
182
154
  klass = token["ner"]
183
- result.append(
184
- RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass)
185
- )
155
+ result.append(RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass))
186
156
  return result
187
157
 
188
158
 
@@ -215,9 +185,7 @@ class PredictEngine:
215
185
  await self.session.close()
216
186
 
217
187
  def check_nua_key_is_configured_for_onprem(self):
218
- if self.onprem and (
219
- self.nuclia_service_account is None and self.local_predict is False
220
- ):
188
+ if self.onprem and (self.nuclia_service_account is None and self.local_predict is False):
221
189
  raise NUAKeyMissingError()
222
190
 
223
191
  def get_predict_url(self, endpoint: str, kbid: str) -> str:
@@ -229,7 +197,7 @@ class PredictEngine:
229
197
  # /api/v1/predict/rephrase/{kbid}
230
198
  return f"{self.public_url}{PUBLIC_PREDICT}{endpoint}/{kbid}"
231
199
  else:
232
- if has_feature(const.Features.VERSIONED_PRIVATE_PREDICT):
200
+ if has_feature(Features.VERSIONED_PRIVATE_PREDICT):
233
201
  return f"{self.cluster_url}{VERSIONED_PRIVATE_PREDICT}{endpoint}"
234
202
  else:
235
203
  return f"{self.cluster_url}{PRIVATE_PREDICT}{endpoint}"
@@ -243,16 +211,13 @@ class PredictEngine:
243
211
  else:
244
212
  return {"X-STF-KBID": kbid}
245
213
 
246
- async def check_response(
247
- self, resp: aiohttp.ClientResponse, expected_status: int = 200
248
- ) -> None:
214
+ async def check_response(self, resp: aiohttp.ClientResponse, expected_status: int = 200) -> None:
249
215
  if resp.status == expected_status:
250
216
  return
251
217
 
252
218
  if resp.status == 402:
253
219
  data = await resp.json()
254
220
  raise LimitsExceededError(402, data["detail"])
255
-
256
221
  try:
257
222
  data = await resp.json()
258
223
  try:
@@ -264,7 +229,10 @@ class PredictEngine:
264
229
  aiohttp.client_exceptions.ContentTypeError,
265
230
  ):
266
231
  detail = await resp.text()
267
- logger.error(f"Predict API error at {resp.url}: {detail}")
232
+ if str(resp.status).startswith("5"):
233
+ logger.error(f"Predict API error at {resp.url}: {detail}")
234
+ else:
235
+ logger.info(f"Predict API error at {resp.url}: {detail}")
268
236
  raise ProxiedPredictAPIError(status=resp.status, detail=detail)
269
237
 
270
238
  @backoff.on_exception(
@@ -277,36 +245,6 @@ class PredictEngine:
277
245
  func = getattr(self.session, method.lower())
278
246
  return await func(**request_args)
279
247
 
280
- @predict_observer.wrap({"type": "feedback"})
281
- async def send_feedback(
282
- self,
283
- kbid: str,
284
- item: FeedbackRequest,
285
- x_nucliadb_user: str,
286
- x_ndb_client: str,
287
- x_forwarded_for: str,
288
- ):
289
- try:
290
- self.check_nua_key_is_configured_for_onprem()
291
- except NUAKeyMissingError:
292
- logger.warning(
293
- "Nuclia Service account is not defined so could not send the feedback"
294
- )
295
- return
296
-
297
- data = item.dict()
298
- data["user_id"] = x_nucliadb_user
299
- data["client"] = x_ndb_client
300
- data["forwarded"] = x_forwarded_for
301
-
302
- resp = await self.make_request(
303
- "POST",
304
- url=self.get_predict_url(FEEDBACK, kbid),
305
- json=data,
306
- headers=self.get_predict_headers(kbid),
307
- )
308
- await self.check_response(resp, expected_status=204)
309
-
310
248
  @predict_observer.wrap({"type": "rephrase"})
311
249
  async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
312
250
  try:
@@ -319,38 +257,16 @@ class PredictEngine:
319
257
  resp = await self.make_request(
320
258
  "POST",
321
259
  url=self.get_predict_url(REPHRASE, kbid),
322
- json=item.dict(),
260
+ json=item.model_dump(),
323
261
  headers=self.get_predict_headers(kbid),
324
262
  )
325
263
  await self.check_response(resp, expected_status=200)
326
264
  return await _parse_rephrase_response(resp)
327
265
 
328
- @predict_observer.wrap({"type": "chat"})
329
- async def chat_query(
330
- self, kbid: str, item: ChatModel
331
- ) -> tuple[str, AsyncIterator[bytes]]:
332
- try:
333
- self.check_nua_key_is_configured_for_onprem()
334
- except NUAKeyMissingError:
335
- error = "Nuclia Service account is not defined so the chat operation could not be performed"
336
- logger.warning(error)
337
- raise SendToPredictError(error)
338
-
339
- resp = await self.make_request(
340
- "POST",
341
- url=self.get_predict_url(CHAT, kbid),
342
- json=item.dict(),
343
- headers=self.get_predict_headers(kbid),
344
- timeout=None,
345
- )
346
- await self.check_response(resp, expected_status=200)
347
- ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
348
- return ident, get_answer_generator(resp)
349
-
350
266
  @predict_observer.wrap({"type": "chat_ndjson"})
351
267
  async def chat_query_ndjson(
352
268
  self, kbid: str, item: ChatModel
353
- ) -> tuple[str, AsyncIterator[GenerativeChunk]]:
269
+ ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
354
270
  """
355
271
  Chat query using the new stream format
356
272
  Format specs: https://github.com/ndjson/ndjson-spec
@@ -369,35 +285,55 @@ class PredictEngine:
369
285
  resp = await self.make_request(
370
286
  "POST",
371
287
  url=self.get_predict_url(CHAT, kbid),
372
- json=item.dict(),
288
+ json=item.model_dump(),
373
289
  headers=headers,
374
290
  timeout=None,
375
291
  )
376
292
  await self.check_response(resp, expected_status=200)
377
293
  ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
378
- return ident, get_chat_ndjson_generator(resp)
294
+ model = resp.headers.get(NUCLIA_LEARNING_MODEL_HEADER)
295
+ return ident, model, get_chat_ndjson_generator(resp)
379
296
 
380
297
  @predict_observer.wrap({"type": "query"})
381
298
  async def query(
382
299
  self,
383
300
  kbid: str,
384
301
  sentence: str,
302
+ semantic_model: Optional[str] = None,
385
303
  generative_model: Optional[str] = None,
386
- rephrase: Optional[bool] = False,
304
+ rephrase: bool = False,
305
+ rephrase_prompt: Optional[str] = None,
387
306
  ) -> QueryInfo:
307
+ """
308
+ Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance:
309
+ - The embeddings
310
+ - The entities
311
+ - The stop words
312
+ - The semantic threshold
313
+ - etc.
314
+
315
+ :param kbid: KnowledgeBox ID
316
+ :param sentence: The query sentence
317
+ :param semantic_model: The semantic model to use to generate the embeddings
318
+ :param generative_model: The generative model that will be used to generate the answer
319
+ :param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval
320
+ :param rephrase_prompt: Custom prompt to use for rephrasing
321
+ """
388
322
  try:
389
323
  self.check_nua_key_is_configured_for_onprem()
390
324
  except NUAKeyMissingError:
391
- error = (
392
- "Nuclia Service account is not defined so could not ask query endpoint"
393
- )
325
+ error = "Nuclia Service account is not defined so could not ask query endpoint"
394
326
  logger.warning(error)
395
327
  raise SendToPredictError(error)
396
328
 
397
- params = {
329
+ params: dict[str, Any] = {
398
330
  "text": sentence,
399
331
  "rephrase": str(rephrase),
400
332
  }
333
+ if rephrase_prompt is not None:
334
+ params["rephrase_prompt"] = rephrase_prompt
335
+ if semantic_model is not None:
336
+ params["semantic_models"] = [semantic_model]
401
337
  if generative_model is not None:
402
338
  params["generative_model"] = generative_model
403
339
 
@@ -442,27 +378,41 @@ class PredictEngine:
442
378
  resp = await self.make_request(
443
379
  "POST",
444
380
  url=self.get_predict_url(SUMMARIZE, kbid),
445
- json=item.dict(),
381
+ json=item.model_dump(),
446
382
  headers=self.get_predict_headers(kbid),
447
383
  timeout=None,
448
384
  )
449
385
  await self.check_response(resp, expected_status=200)
450
386
  data = await resp.json()
451
- return SummarizedResponse.parse_obj(data)
387
+ return SummarizedResponse.model_validate(data)
388
+
389
+ @predict_observer.wrap({"type": "rerank"})
390
+ async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
391
+ try:
392
+ self.check_nua_key_is_configured_for_onprem()
393
+ except NUAKeyMissingError:
394
+ error = "Nuclia Service account is not defined. Rerank operation could not be performed"
395
+ logger.warning(error)
396
+ raise SendToPredictError(error)
397
+ resp = await self.make_request(
398
+ "POST",
399
+ url=self.get_predict_url(RERANK, kbid),
400
+ json=item.model_dump(),
401
+ headers=self.get_predict_headers(kbid),
402
+ )
403
+ await self.check_response(resp, expected_status=200)
404
+ data = await resp.json()
405
+ return RerankResponse.model_validate(data)
452
406
 
453
407
 
454
408
  class DummyPredictEngine(PredictEngine):
409
+ default_semantic_threshold = 0.7
410
+
455
411
  def __init__(self):
456
412
  self.onprem = True
457
413
  self.cluster_url = "http://localhost:8000"
458
414
  self.public_url = "http://localhost:8000"
459
415
  self.calls = []
460
- self.generated_answer = [
461
- b"valid ",
462
- b"answer ",
463
- b" to",
464
- AnswerStatusCode.SUCCESS.encode(),
465
- ]
466
416
  self.ndjson_answer = [
467
417
  b'{"chunk": {"type": "text", "text": "valid "}}\n',
468
418
  b'{"chunk": {"type": "text", "text": "answer "}}\n',
@@ -486,79 +436,72 @@ class DummyPredictEngine(PredictEngine):
486
436
  response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
487
437
  return response
488
438
 
489
- async def send_feedback(
490
- self,
491
- kbid: str,
492
- item: FeedbackRequest,
493
- x_nucliadb_user: str,
494
- x_ndb_client: str,
495
- x_forwarded_for: str,
496
- ):
497
- self.calls.append(("send_feedback", item))
498
- return
499
-
500
439
  async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
501
440
  self.calls.append(("rephrase_query", item))
502
441
  return DUMMY_REPHRASE_QUERY
503
442
 
504
- async def chat_query(
505
- self, kbid: str, item: ChatModel
506
- ) -> tuple[str, AsyncIterator[bytes]]:
507
- self.calls.append(("chat_query", item))
508
-
509
- async def generate():
510
- for i in self.generated_answer:
511
- yield i
512
-
513
- return (DUMMY_LEARNING_ID, generate())
514
-
515
443
  async def chat_query_ndjson(
516
444
  self, kbid: str, item: ChatModel
517
- ) -> tuple[str, AsyncIterator[bytes]]:
445
+ ) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
518
446
  self.calls.append(("chat_query_ndjson", item))
519
447
 
520
448
  async def generate():
521
449
  for item in self.ndjson_answer:
522
- yield GenerativeChunk.parse_raw(item)
450
+ yield GenerativeChunk.model_validate_json(item)
523
451
 
524
- return (DUMMY_LEARNING_ID, generate())
452
+ return (DUMMY_LEARNING_ID, DUMMY_LEARNING_MODEL, generate())
525
453
 
526
454
  async def query(
527
455
  self,
528
456
  kbid: str,
529
457
  sentence: str,
458
+ semantic_model: Optional[str] = None,
530
459
  generative_model: Optional[str] = None,
531
- rephrase: Optional[bool] = False,
460
+ rephrase: bool = False,
461
+ rephrase_prompt: Optional[str] = None,
532
462
  ) -> QueryInfo:
533
463
  self.calls.append(("query", sentence))
534
- if (
535
- os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21"
536
- ): # pragma: no cover
537
- return QueryInfo(
538
- language="en",
539
- stop_words=[],
540
- semantic_threshold=0.7,
541
- visual_llm=True,
542
- max_context=self.max_context,
543
- entities=TokenSearch(
544
- tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
545
- ),
546
- sentence=SentenceSearch(data=Qm2023, time=0.0),
547
- query=sentence,
548
- )
464
+
465
+ if os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21": # pragma: no cover
466
+ base_vector = Qm2023
549
467
  else:
550
- return QueryInfo(
551
- language="en",
552
- stop_words=[],
553
- semantic_threshold=0.7,
554
- visual_llm=True,
555
- max_context=self.max_context,
556
- entities=TokenSearch(
557
- tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
558
- ),
559
- sentence=SentenceSearch(data=Q, time=0.0),
560
- query=sentence,
561
- )
468
+ base_vector = Q
469
+
470
+ # populate data with existing vectorsets
471
+ async with datamanagers.with_ro_transaction() as txn:
472
+ semantic_thresholds = {}
473
+ vectors = {}
474
+ timings = {}
475
+ async for vectorset_id, config in datamanagers.vectorsets.iter(txn, kbid=kbid):
476
+ semantic_thresholds[vectorset_id] = self.default_semantic_threshold
477
+ vectorset_dimension = config.vectorset_index_config.vector_dimension
478
+ if vectorset_dimension > len(base_vector):
479
+ padding = vectorset_dimension - len(base_vector)
480
+ vectors[vectorset_id] = base_vector + [random.random()] * padding
481
+ else:
482
+ vectors[vectorset_id] = base_vector[:vectorset_dimension]
483
+
484
+ timings[vectorset_id] = 0.010
485
+
486
+ # and fake data with the passed one too
487
+ model = semantic_model or "<PREDICT-DEFAULT-SEMANTIC-MODEL>"
488
+ semantic_thresholds[model] = self.default_semantic_threshold
489
+ vectors[model] = base_vector
490
+ timings[model] = 0.0
491
+
492
+ return QueryInfo(
493
+ language="en",
494
+ stop_words=[],
495
+ semantic_thresholds=semantic_thresholds,
496
+ visual_llm=True,
497
+ max_context=self.max_context,
498
+ entities=TokenSearch(tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0),
499
+ sentence=SentenceSearch(
500
+ vectors=vectors,
501
+ timings=timings,
502
+ ),
503
+ query=sentence,
504
+ )
562
505
 
563
506
  async def detect_entities(self, kbid: str, sentence: str) -> list[RelationNode]:
564
507
  self.calls.append(("detect_entities", sentence))
@@ -577,9 +520,16 @@ class DummyPredictEngine(PredictEngine):
577
520
  rsummary = []
578
521
  for field_id, field_text in item.resources[rid].fields.items():
579
522
  rsummary.append(f"{field_id}: {field_text}")
580
- response.resources[rid] = SummarizedResource(
581
- summary="\n\n".join(rsummary), tokens=10
582
- )
523
+ response.resources[rid] = SummarizedResource(summary="\n\n".join(rsummary), tokens=10)
524
+ return response
525
+
526
+ async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
527
+ self.calls.append(("rerank", (kbid, item)))
528
+ # as we don't have information about the retrieval scores, return a
529
+ # random score given by the dict iteration
530
+ response = RerankResponse(
531
+ context_scores={paragraph_id: i for i, paragraph_id in enumerate(item.context.keys())}
532
+ )
583
533
  return response
584
534
 
585
535
 
@@ -604,11 +554,10 @@ def get_answer_generator(response: aiohttp.ClientResponse):
604
554
  def get_chat_ndjson_generator(
605
555
  response: aiohttp.ClientResponse,
606
556
  ) -> AsyncIterator[GenerativeChunk]:
607
-
608
557
  async def _parse_generative_chunks(gen):
609
558
  async for chunk in gen:
610
559
  try:
611
- yield GenerativeChunk.parse_raw(chunk.strip())
560
+ yield GenerativeChunk.model_validate_json(chunk.strip())
612
561
  except ValidationError as ex:
613
562
  errors.capture_exception(ex)
614
563
  logger.error(f"Invalid chunk received: {chunk}")
File without changes