nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418)
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -402
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +64 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +114 -113
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +25 -127
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -17,32 +17,55 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
+import asyncio
+import copy
+from collections import deque
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Deque, Dict, List, Optional, Sequence, Tuple, Union, cast

+import yaml
+from pydantic import BaseModel
+
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
+from nucliadb.common.maindb.utils import get_driver
+from nucliadb.common.models_utils import from_proto
 from nucliadb.ingest.fields.base import Field
 from nucliadb.ingest.fields.conversation import Conversation
+from nucliadb.ingest.fields.file import File
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
-from nucliadb.ingest.orm.resource import KB_REVERSE
-from nucliadb.ingest.orm.resource import Resource as ResourceORM
-from nucliadb.middleware.transaction import get_read_only_transaction
 from nucliadb.search import logger
-from nucliadb.search.search import paragraphs
-from nucliadb.search.search.chat.images import get_page_image, get_paragraph_image
+from nucliadb.search.search import cache
+from nucliadb.search.search.chat.images import (
+    get_file_thumbnail_image,
+    get_page_image,
+    get_paragraph_image,
+)
+from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
+from nucliadb.search.search.paragraphs import get_paragraph_text
+from nucliadb_models.metadata import Extra, Origin
 from nucliadb_models.search import (
     SCORE_TYPE,
+    ConversationalStrategy,
+    FieldExtensionStrategy,
     FindParagraph,
+    FullResourceStrategy,
+    HierarchyResourceStrategy,
     ImageRagStrategy,
     ImageRagStrategyName,
-    KnowledgeboxFindResults,
+    MetadataExtensionStrategy,
+    MetadataExtensionType,
+    NeighbouringParagraphsStrategy,
+    PageImageStrategy,
+    ParagraphImageStrategy,
     PromptContext,
     PromptContextImages,
     PromptContextOrder,
     RagStrategy,
     RagStrategyName,
+    TableImageStrategy,
 )
 from nucliadb_protos import resources_pb2
-from nucliadb_utils.asyncio_utils import ConcurrentRunner, run_concurrently
+from nucliadb_utils.asyncio_utils import run_concurrently
 from nucliadb_utils.utilities import get_storage

 MAX_RESOURCE_TASKS = 5
@@ -53,6 +76,12 @@ MAX_RESOURCE_FIELD_TASKS = 4
 # The hope here is it will be enough to get the answer to the question.
 CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15

+TextBlockId = Union[ParagraphId, FieldId]
+
+
+class ParagraphIdNotFoundInExtractedMetadata(Exception):
+    pass
+

 class CappedPromptContext:
     """
@@ -70,16 +99,26 @@ class CappedPromptContext:
         self._size = 0

     def __setitem__(self, key: str, value: str) -> None:
+        prev_value_len = len(self.output.get(key, ""))
         if self.max_size is None:
-            # Unbounded size
-            self.output[key] = value
+            # Unbounded size context
+            to_add = value
         else:
-            existing_len = len(self.output.get(key, ""))
-            self._size -= existing_len
-            size_available = self.max_size - self._size
-            if size_available > 0:
-                self.output[key] = value[:size_available]
-                self._size += len(self.output[key])
+            # Make sure we don't exceed the max size
+            size_available = max(self.max_size - self._size + prev_value_len, 0)
+            to_add = value[:size_available]
+        self.output[key] = to_add
+        self._size = self._size - prev_value_len + len(to_add)
+
+    def __getitem__(self, key: str) -> str:
+        return self.output.__getitem__(key)
+
+    def __delitem__(self, key: str) -> None:
+        value = self.output.pop(key, "")
+        self._size -= len(value)
+
+    def text_block_ids(self) -> list[str]:
+        return list(self.output.keys())

     @property
     def size(self) -> int:
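
The reworked __setitem__ above keeps a running character budget for the whole context: overwriting a key first credits back the space held by its previous value, then truncates the incoming value to whatever budget remains. A minimal self-contained sketch of that accounting (class name and values are illustrative, not part of the package):

    class CappedDict:
        """Dict of str -> str whose total value length never exceeds max_size."""

        def __init__(self, max_size=None):
            self.output = {}
            self.max_size = max_size
            self._size = 0

        def __setitem__(self, key, value):
            prev_value_len = len(self.output.get(key, ""))
            if self.max_size is None:
                to_add = value  # unbounded context
            else:
                # Credit back the old value before computing the remaining budget
                size_available = max(self.max_size - self._size + prev_value_len, 0)
                to_add = value[:size_available]
            self.output[key] = to_add
            self._size = self._size - prev_value_len + len(to_add)

    ctx = CappedDict(max_size=5)
    ctx["a"] = "xxxx"  # stored whole, 4 of 5 chars used
    ctx["b"] = "yyyy"  # truncated to "y", budget exhausted
    ctx["a"] = "zz"    # the 4 chars of "xxxx" are credited back first
    assert ctx.output == {"a": "zz", "b": "y"} and ctx._size == 3
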
@@ -94,15 +133,15 @@ async def get_next_conversation_messages(
     num_messages: int,
     message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
     msg_to: Optional[str] = None,
-):
+) -> List[resources_pb2.Message]:
     output = []
     cmetadata = await field_obj.get_metadata()
     for current_page in range(page, cmetadata.pages + 1):
         conv = await field_obj.db_get_value(current_page)
         for message in conv.messages[start_idx:]:
-            if message_type is not None and message.type != message_type:
+            if message_type is not None and message.type != message_type:  # pragma: no cover
                 continue
-            if msg_to is not None and msg_to not in message.to:
+            if msg_to is not None and msg_to not in message.to:  # pragma: no cover
                 continue
             output.append(message)
             if len(output) >= num_messages:
@@ -125,16 +164,21 @@ async def find_conversation_message(


 async def get_expanded_conversation_messages(
-    *, kb: KnowledgeBoxORM, rid: str, field_id: str, mident: str
+    *,
+    kb: KnowledgeBoxORM,
+    rid: str,
+    field_id: str,
+    mident: str,
+    max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
 ) -> list[resources_pb2.Message]:
     resource = await kb.get(rid)
-    if resource is None:
+    if resource is None:  # pragma: no cover
         return []
-    field_obj = await resource.get_field(field_id, KB_REVERSE["c"], load=True)
+    field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True)  # type: ignore
     found_message, found_page, found_idx = await find_conversation_message(
         field_obj=field_obj, mident=mident
     )
-    if found_message is None:
+    if found_message is None:  # pragma: no cover
         return []
     elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
         # only try to get answer if it was a question
@@ -150,7 +194,7 @@ async def get_expanded_conversation_messages(
         field_obj=field_obj,
         page=found_page,
         start_idx=found_idx + 1,
-        num_messages=CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
+        num_messages=max_messages,
     )


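
The conversation expansion above relies on the text-block id convention visible in these hunks: <rid>/<field_type>/<field_id>/<message_ident>/<start>-<end>, where field type "c" marks a conversation field. A short sketch with made-up id values:

    pid = "0a1b2c3d/c/support-chat/msg-42/0-120"  # hypothetical paragraph id
    rid, field_type, field_id, mident = pid.split("/")[:4]
    assert (field_type, mident) == ("c", "msg-42")
    # Expanded answer messages are re-keyed with the same convention:
    #   f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
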
@@ -169,83 +213,27 @@ async def default_prompt_context(
     - Using an dict prevents from duplicates pulled in through conversation expansion.
     """
     # Sort retrieved paragraphs by decreasing order (most relevant first)
-    txn = await get_read_only_transaction()
-    storage = await get_storage()
-    kb = KnowledgeBoxORM(txn, storage, kbid)
-    for paragraph in ordered_paragraphs:
-        context[paragraph.id] = _clean_paragraph_text(paragraph)
-
-        # If the paragraph is a conversation and it matches semantically, we assume we
-        # have matched with the question, therefore try to include the answer to the
-        # context by pulling the next few messages of the conversation field
-        rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
-        if field_type == "c" and paragraph.score_type in (
-            SCORE_TYPE.VECTOR,
-            SCORE_TYPE.BOTH,
-        ):
-            expanded_msgs = await get_expanded_conversation_messages(
-                kb=kb, rid=rid, field_id=field_id, mident=mident
-            )
-            for msg in expanded_msgs:
-                text = msg.content.text.strip()
-                pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
-                context[pid] = text
-
-
-async def get_field_extracted_text(field: Field) -> Optional[tuple[Field, str]]:
-    extracted_text_pb = await field.get_extracted_text(force=True)
-    if extracted_text_pb is None:
-        return None
-    return field, extracted_text_pb.text
-
-
-async def get_resource_field_extracted_text(
-    kb_obj: KnowledgeBoxORM,
-    resource_uuid,
-    field_id: str,
-) -> Optional[tuple[Field, str]]:
-    resource = await kb_obj.get(resource_uuid)
-    if resource is None:
-        return None
-
-    try:
-        field_type, field_key = field_id.strip("/").split("/")
-    except ValueError:
-        logger.error(f"Invalid field id: {field_id}. Skipping getting extracted text.")
-        return None
-    field = await resource.get_field(field_key, KB_REVERSE[field_type], load=False)
-    if field is None:
-        return None
-    result = await get_field_extracted_text(field)
-    if result is None:
-        return None
-    _, extracted_text = result
-    return field, extracted_text
-
-
-async def get_resource_extracted_texts(
-    kbid: str,
-    resource_uuid: str,
-) -> list[tuple[Field, str]]:
-    txn = await get_read_only_transaction()
-    storage = await get_storage()
-    kb = KnowledgeBoxORM(txn, storage, kbid)
-    resource = ResourceORM(
-        txn=txn,
-        storage=storage,
-        kb=kb,
-        uuid=resource_uuid,
-    )
-
-    # Schedule the extraction of the text of each field in the resource
-    runner = ConcurrentRunner(max_tasks=MAX_RESOURCE_FIELD_TASKS)
-    for field_type, field_key in await resource.get_fields(force=True):
-        field = await resource.get_field(field_key, field_type, load=False)
-        runner.schedule(get_field_extracted_text(field))
-
-    # Wait for the results
-    results = await runner.wait()
-    return [result for result in results if result is not None]
+    async with get_driver().transaction(read_only=True) as txn:
+        storage = await get_storage()
+        kb = KnowledgeBoxORM(txn, storage, kbid)
+        for paragraph in ordered_paragraphs:
+            context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+            # If the paragraph is a conversation and it matches semantically, we assume we
+            # have matched with the question, therefore try to include the answer to the
+            # context by pulling the next few messages of the conversation field
+            rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+            if field_type == "c" and paragraph.score_type in (
+                SCORE_TYPE.VECTOR,
+                SCORE_TYPE.BOTH,
+            ):
+                expanded_msgs = await get_expanded_conversation_messages(
+                    kb=kb, rid=rid, field_id=field_id, mident=mident
+                )
+                for msg in expanded_msgs:
+                    text = msg.content.text.strip()
+                    pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
+                    context[pid] = text


 async def full_resource_prompt_context(
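
The rewritten default_prompt_context above swaps the middleware helper get_read_only_transaction() for a driver-scoped async context manager, so the read-only transaction lives exactly as long as the context-building work. A sketch of the pattern, with the body elided:

    async with get_driver().transaction(read_only=True) as txn:
        kb = KnowledgeBoxORM(txn, await get_storage(), kbid)
        ...  # build the prompt context while the transaction is open
    # the transaction is released here, instead of at the end of the request
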
@@ -253,19 +241,18 @@ async def full_resource_prompt_context(
     kbid: str,
     ordered_paragraphs: list[FindParagraph],
     resource: Optional[str],
-    number_of_full_resources: Optional[int] = None,
+    strategy: FullResourceStrategy,
 ) -> None:
     """
     Algorithm steps:
     - Collect the list of resources in the results (in order of relevance).
     - For each resource, collect the extracted text from all its fields and craft the context.
-
     Arguments:
         context: The context to be updated.
         kbid: The knowledge box id.
-        results: The results of the retrieval (find) operation.
+        ordered_paragraphs: The results of the retrieval (find) operation.
         resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
-        number_of_full_resources: The number of full resources to include in the context.
+        strategy: strategy instance containing, for example, the number of full resources to include in the context.
     """  # noqa: E501
     if resource is not None:
         # The user has specified a resource to be included in the context.
@@ -274,32 +261,205 @@
         # Collect the list of resources in the results (in order of relevance).
         ordered_resources = []
         for paragraph in ordered_paragraphs:
-            resource_uuid = paragraph.id.split("/")[0]
+            resource_uuid = parse_text_block_id(paragraph.id).rid
             if resource_uuid not in ordered_resources:
-                ordered_resources.append(resource_uuid)
+                skip = False
+                if strategy.apply_to is not None:
+                    # decide whether the resource should be extended or not
+                    for label in strategy.apply_to.exclude:
+                        skip = skip or (label in (paragraph.labels or []))
+
+                if not skip:
+                    ordered_resources.append(resource_uuid)

     # For each resource, collect the extracted text from all its fields.
-    resource_extracted_texts = await run_concurrently(
+    resources_extracted_texts = await run_concurrently(
         [
-            get_resource_extracted_texts(kbid, resource_uuid)
-            for resource_uuid in ordered_resources[:number_of_full_resources]
+            hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
+            for resource_uuid in ordered_resources[: strategy.count]
         ],
         max_concurrent=MAX_RESOURCE_TASKS,
     )
-
-    for extracted_texts in resource_extracted_texts:
-        if extracted_texts is None:
+    added_fields = set()
+    for resource_extracted_texts in resources_extracted_texts:
+        if resource_extracted_texts is None:
             continue
-        for field, extracted_text in extracted_texts:
+        for field, extracted_text in resource_extracted_texts:
+            # First off, remove the text block ids from paragraphs that belong to
+            # the same field, as otherwise the context will be duplicated.
+            for tb_id in context.text_block_ids():
+                if tb_id.startswith(field.full()):
+                    del context[tb_id]
             # Add the extracted text of each field to the context.
-            context[field.resource_unique_id] = extracted_text
+            context[field.full()] = extracted_text
+            added_fields.add(field.full())
+
+    if strategy.include_remaining_text_blocks:
+        for paragraph in ordered_paragraphs:
+            pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
+            if pid.field_id.full() not in added_fields:
+                context[paragraph.id] = _clean_paragraph_text(paragraph)


-async def composed_prompt_context(
+async def extend_prompt_context_with_metadata(
+    context: CappedPromptContext,
+    kbid: str,
+    strategy: MetadataExtensionStrategy,
+) -> None:
+    text_block_ids: list[TextBlockId] = []
+    for text_block_id in context.text_block_ids():
+        try:
+            text_block_ids.append(parse_text_block_id(text_block_id))
+        except ValueError:  # pragma: no cover
+            # Some text block ids are not paragraphs nor fields, so they are skipped
+            # (e.g. USER_CONTEXT_0, when the user provides extra context)
+            continue
+    if len(text_block_ids) == 0:  # pragma: no cover
+        return
+
+    if MetadataExtensionType.ORIGIN in strategy.types:
+        await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
+        await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.NERS in strategy.types:
+        await extend_prompt_context_with_ner(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.EXTRA_METADATA in strategy.types:
+        await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
+
+
+def parse_text_block_id(text_block_id: str) -> TextBlockId:
+    try:
+        # Typically, the text block id is a paragraph id
+        return ParagraphId.from_string(text_block_id)
+    except ValueError:
+        # When we're doing `full_resource` or `hierarchy` strategies,the text block id
+        # is a field id
+        return FieldId.from_string(text_block_id)
+
+
+async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
+        origin = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_origin = await resource.get_origin()
+            if pb_origin is not None:
+                origin = from_proto.origin(pb_origin)
+        return rid, origin
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
+    rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
+    for tb_id in text_block_ids:
+        origin = rid_to_origin.get(tb_id.rid)
+        if origin is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
+
+
+async def extend_prompt_context_with_classification_labels(
+    context, kbid, text_block_ids: list[TextBlockId]
+):
+    async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        labels = set()
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            pb_basic = await resource.get_basic()
+            if pb_basic is not None:
+                # Add the classification labels of the resource
+                for classif in pb_basic.usermetadata.classifications:
+                    labels.add((classif.labelset, classif.label))
+                # Add the classifications labels of the field
+                for fc in pb_basic.computedmetadata.field_classifications:
+                    if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
+                        for classif in fc.classifications:
+                            if classif.cancelled_by_user:  # pragma: no cover
+                                continue
+                            labels.add((classif.labelset, classif.label))
+        return _id, list(labels)
+
+    classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
+    for tb_id in text_block_ids:
+        labels = tb_id_to_labels.get(tb_id)
+        if labels is not None and tb_id.full() in context.output:
+            labels_text = "DOCUMENT CLASSIFICATION LABELS:"
+            for labelset, label in labels:
+                labels_text += f"\n - {label} ({labelset})"
+            context[tb_id.full()] += "\n\n" + labels_text
+
+
+async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        ners: dict[str, set[str]] = {}
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            field = await resource.get_field(fid.key, fid.pb_type, load=False)
+            fcm = await field.get_field_metadata()
+            if fcm is not None:
+                # Data Augmentation + Processor entities
+                for (
+                    data_aumgentation_task_id,
+                    entities_wrapper,
+                ) in fcm.metadata.entities.items():
+                    for entity in entities_wrapper.entities:
+                        ners.setdefault(entity.label, set()).add(entity.text)
+                # Legacy processor entities
+                # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+                for token, family in fcm.metadata.ner.items():
+                    ners.setdefault(family, set()).add(token)
+        return _id, ners
+
+    nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
+    for tb_id in text_block_ids:
+        ners = tb_id_to_ners.get(tb_id)
+        if ners is not None and tb_id.full() in context.output:
+            ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
+            for family, tokens in ners.items():
+                ners_text += f"\n - {family}:"
+                for token in sorted(list(tokens)):
+                    ners_text += f"\n - {token}"
+            context[tb_id.full()] += "\n\n" + ners_text
+
+
+async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
+        extra = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_extra = await resource.get_extra()
+            if pb_extra is not None:
+                extra = from_proto.extra(pb_extra)
+        return rid, extra
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
+    rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
+    for tb_id in text_block_ids:
+        extra = rid_to_extra.get(tb_id.rid)
+        if extra is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
+
+
+def to_yaml(obj: BaseModel) -> str:
+    return yaml.dump(
+        obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
+        default_flow_style=False,
+        indent=2,
+        sort_keys=True,
+    )
+
+
+async def field_extension_prompt_context(
     context: CappedPromptContext,
     kbid: str,
     ordered_paragraphs: list[FindParagraph],
-    extend_with_fields: list[str],
+    strategy: FieldExtensionStrategy,
 ) -> None:
     """
     Algorithm steps:
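
The rewritten field_extension_prompt_context (continued in the next hunk) parses the strategy's field specs into typed FieldId objects instead of passing raw strings around, so malformed specs fail early with ValueError. A sketch with hypothetical values, using FieldId as it appears in these hunks:

    from nucliadb.common.ids import FieldId

    resource_uuid = "0a1b2c3d"  # hypothetical resource uuid
    field_spec = "/t/mytext"    # "<field_type>/<key>" spec from the strategy
    fid = FieldId.from_string(f"{resource_uuid}/{field_spec.strip('/')}")
    # fid.rid == "0a1b2c3d" and fid.key == "mytext"; fid.full() is later used
    # as the context key and to evict overlapping paragraph text blocks
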
@@ -310,33 +470,402 @@ async def composed_prompt_context(
310
470
  """
311
471
  ordered_resources = []
312
472
  for paragraph in ordered_paragraphs:
313
- resource_uuid = paragraph.id.split("/")[0]
473
+ resource_uuid = ParagraphId.from_string(paragraph.id).rid
314
474
  if resource_uuid not in ordered_resources:
315
475
  ordered_resources.append(resource_uuid)
316
476
 
317
477
  # Fetch the extracted texts of the specified fields for each resource
318
- txn = await get_read_only_transaction()
319
- kb_obj = KnowledgeBoxORM(txn, await get_storage(), kbid)
320
-
321
- tasks = [
322
- get_resource_field_extracted_text(kb_obj, resource_uuid, field_id)
323
- for resource_uuid in ordered_resources
324
- for field_id in extend_with_fields
325
- ]
478
+ extend_fields = strategy.fields
479
+ extend_field_ids = []
480
+ for resource_uuid in ordered_resources:
481
+ for field_id in extend_fields:
482
+ try:
483
+ fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
484
+ extend_field_ids.append(fid)
485
+ except ValueError: # pragma: no cover
486
+ # Invalid field id, skiping
487
+ continue
488
+
489
+ tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
326
490
  field_extracted_texts = await run_concurrently(tasks)
327
491
 
328
492
  for result in field_extracted_texts:
329
- if result is None:
493
+ if result is None: # pragma: no cover
330
494
  continue
331
- # Add the extracted text of each field to the beginning of the context.
332
495
  field, extracted_text = result
333
- context[field.resource_unique_id] = extracted_text
496
+ # First off, remove the text block ids from paragraphs that belong to
497
+ # the same field, as otherwise the context will be duplicated.
498
+ for tb_id in context.text_block_ids():
499
+ if tb_id.startswith(field.full()):
500
+ del context[tb_id]
501
+ # Add the extracted text of each field to the beginning of the context.
502
+ context[field.full()] = extracted_text
334
503
 
335
504
  # Add the extracted text of each paragraph to the end of the context.
336
505
  for paragraph in ordered_paragraphs:
337
506
  context[paragraph.id] = _clean_paragraph_text(paragraph)
338
507
 
339
508
 
509
+ async def get_paragraph_text_with_neighbours(
510
+ kbid: str,
511
+ pid: ParagraphId,
512
+ field_paragraphs: list[ParagraphId],
513
+ before: int = 0,
514
+ after: int = 0,
515
+ ) -> tuple[ParagraphId, str]:
516
+ """
517
+ This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
518
+ Parameters:
519
+ kbid: The knowledge box id.
520
+ pid: The matching paragraph id.
521
+ field_paragraphs: The list of paragraph ids of the field.
522
+ before: The number of paragraphs to include before the matching paragraph.
523
+ after: The number of paragraphs to include after the matching paragraph.
524
+ """
525
+
526
+ async def _get_paragraph_text(
527
+ kbid: str,
528
+ pid: ParagraphId,
529
+ ) -> tuple[ParagraphId, str]:
530
+ return pid, await get_paragraph_text(
531
+ kbid=kbid,
532
+ paragraph_id=pid,
533
+ log_on_missing_field=True,
534
+ )
535
+
536
+ ops = []
537
+ try:
538
+ for paragraph_index in get_neighbouring_paragraph_indexes(
539
+ field_paragraphs=field_paragraphs,
540
+ matching_paragraph=pid,
541
+ before=before,
542
+ after=after,
543
+ ):
544
+ neighbour_pid = field_paragraphs[paragraph_index]
545
+ ops.append(
546
+ asyncio.create_task(
547
+ _get_paragraph_text(
548
+ kbid=kbid,
549
+ pid=neighbour_pid,
550
+ )
551
+ )
552
+ )
553
+ except ParagraphIdNotFoundInExtractedMetadata:
554
+ logger.warning(
555
+ "Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
556
+ extra={
557
+ "kbid": kbid,
558
+ "matching_paragraph": pid.full(),
559
+ "field_paragraphs": [p.full() for p in field_paragraphs],
560
+ },
561
+ )
562
+ # If we could not find the matching paragraph in the extracted metadata, we can't retrieve
563
+ # the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
564
+ ops.append(
565
+ asyncio.create_task(
566
+ _get_paragraph_text(
567
+ kbid=kbid,
568
+ pid=pid,
569
+ )
570
+ )
571
+ )
572
+
573
+ results = []
574
+ if len(ops) > 0:
575
+ results = await asyncio.gather(*ops)
576
+
577
+ # Sort the results by the paragraph start
578
+ results.sort(key=lambda x: x[0].paragraph_start)
579
+ paragraph_texts = []
580
+ for _, text in results:
581
+ if text != "":
582
+ paragraph_texts.append(text)
583
+ return pid, "\n\n".join(paragraph_texts)
584
+
585
+
586
+ async def get_field_paragraphs_list(
587
+ kbid: str,
588
+ field: FieldId,
589
+ paragraphs: list[ParagraphId],
590
+ ) -> None:
591
+ """
592
+ Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
593
+ """
594
+ resource = await cache.get_resource(kbid, field.rid)
595
+ if resource is None: # pragma: no cover
596
+ return
597
+ field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
598
+ field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
599
+ force=True
600
+ )
601
+ if field_metadata is None: # pragma: no cover
602
+ return
603
+ for paragraph in field_metadata.metadata.paragraphs:
604
+ paragraphs.append(
605
+ ParagraphId(
606
+ field_id=field,
607
+ paragraph_start=paragraph.start,
608
+ paragraph_end=paragraph.end,
609
+ )
610
+ )
611
+
612
+
613
+ async def neighbouring_paragraphs_prompt_context(
614
+ context: CappedPromptContext,
615
+ kbid: str,
616
+ ordered_text_blocks: list[FindParagraph],
617
+ strategy: NeighbouringParagraphsStrategy,
618
+ ) -> None:
619
+ """
620
+ This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
621
+ paragraphs in the ordered_paragraphs list. The number of paragraphs to include before and after each paragraph
622
+ """
623
+ # First, get the sorted list of paragraphs for each matching field
624
+ # so we can know the indexes of the neighbouring paragraphs
625
+ unique_fields = {
626
+ ParagraphId.from_string(text_block.id).field_id for text_block in ordered_text_blocks
627
+ }
628
+ paragraphs_by_field: dict[FieldId, list[ParagraphId]] = {}
629
+ field_ops = []
630
+ for field_id in unique_fields:
631
+ plist = paragraphs_by_field.setdefault(field_id, [])
632
+ field_ops.append(
633
+ asyncio.create_task(get_field_paragraphs_list(kbid=kbid, field=field_id, paragraphs=plist))
634
+ )
635
+ if field_ops:
636
+ await asyncio.gather(*field_ops)
637
+
638
+ # Now, get the paragraph texts with the neighbouring paragraphs
639
+ paragraph_ops = []
640
+ for text_block in ordered_text_blocks:
641
+ pid = ParagraphId.from_string(text_block.id)
642
+ paragraph_ops.append(
643
+ asyncio.create_task(
644
+ get_paragraph_text_with_neighbours(
645
+ kbid=kbid,
646
+ pid=pid,
647
+ before=strategy.before,
648
+ after=strategy.after,
649
+ field_paragraphs=paragraphs_by_field.get(pid.field_id, []),
650
+ )
651
+ )
652
+ )
653
+ if not paragraph_ops: # pragma: no cover
654
+ return
655
+
656
+ results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
657
+ # Add the paragraph texts to the context
658
+ for pid, text in results:
659
+ if text != "":
660
+ context[pid.full()] = text
661
+
662
+
663
+ async def conversation_prompt_context(
664
+ context: CappedPromptContext,
665
+ kbid: str,
666
+ ordered_paragraphs: list[FindParagraph],
667
+ conversational_strategy: ConversationalStrategy,
668
+ visual_llm: bool,
669
+ ):
670
+ analyzed_fields: List[str] = []
671
+ async with get_driver().transaction(read_only=True) as txn:
672
+ storage = await get_storage()
673
+ kb = KnowledgeBoxORM(txn, storage, kbid)
674
+ for paragraph in ordered_paragraphs:
675
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
676
+
677
+ # If the paragraph is a conversation and it matches semantically, we assume we
678
+ # have matched with the question, therefore try to include the answer to the
679
+ # context by pulling the next few messages of the conversation field
680
+ rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
681
+ if field_type == "c" and paragraph.score_type in (
682
+ SCORE_TYPE.VECTOR,
683
+ SCORE_TYPE.BOTH,
684
+ SCORE_TYPE.BM25,
685
+ ):
686
+ field_unique_id = "-".join([rid, field_type, field_id])
687
+ if field_unique_id in analyzed_fields:
688
+ continue
689
+ resource = await kb.get(rid)
690
+ if resource is None: # pragma: no cover
691
+ continue
692
+
693
+ field_obj: Conversation = await resource.get_field(
694
+ field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
695
+ ) # type: ignore
696
+ cmetadata = await field_obj.get_metadata()
697
+
698
+ attachments: List[resources_pb2.FieldRef] = []
699
+ if conversational_strategy.full:
700
+ extracted_text = await field_obj.get_extracted_text()
701
+ for current_page in range(1, cmetadata.pages + 1):
702
+ conv = await field_obj.db_get_value(current_page)
703
+
704
+ for message in conv.messages:
705
+ ident = message.ident
706
+ if extracted_text is not None:
707
+ text = extracted_text.split_text.get(ident, message.content.text.strip())
708
+ else:
709
+ text = message.content.text.strip()
710
+ pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
711
+ context[pid] = text
712
+ attachments.extend(message.content.attachments_fields)
713
+ else:
714
+ # Add first message
715
+ extracted_text = await field_obj.get_extracted_text()
716
+ first_page = await field_obj.db_get_value()
717
+ if len(first_page.messages) > 0:
718
+ message = first_page.messages[0]
719
+ ident = message.ident
720
+ if extracted_text is not None:
721
+ text = extracted_text.split_text.get(ident, message.content.text.strip())
722
+ else:
723
+ text = message.content.text.strip()
724
+ pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
725
+ context[pid] = text
726
+ attachments.extend(message.content.attachments_fields)
727
+
728
+ messages: Deque[resources_pb2.Message] = deque(
729
+ maxlen=conversational_strategy.max_messages
730
+ )
731
+
732
+ pending = -1
733
+ for page in range(1, cmetadata.pages + 1):
734
+ # Collect the messages with the window asked by the user arround the match paragraph
735
+ conv = await field_obj.db_get_value(page)
736
+ for message in conv.messages:
737
+ messages.append(message)
738
+ if pending > 0:
739
+ pending -= 1
740
+ if message.ident == mident:
741
+ pending = (conversational_strategy.max_messages - 1) // 2
742
+ if pending == 0:
743
+ break
744
+ if pending == 0:
745
+ break
746
+
747
+ for message in messages:
748
+ text = message.content.text.strip()
749
+ pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
750
+ context[pid] = text
751
+ attachments.extend(message.content.attachments_fields)
752
+
753
+ if conversational_strategy.attachments_text:
754
+ # add on the context the images if vlm enabled
755
+ for attachment in attachments:
756
+ field: File = await resource.get_field(
757
+ attachment.field_id, attachment.field_type, load=True
758
+ ) # type: ignore
759
+ extracted_text = await field.get_extracted_text()
760
+ if extracted_text is not None:
761
+ pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
762
+ context[pid] = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
763
+
764
+ if conversational_strategy.attachments_images and visual_llm:
765
+ for attachment in attachments:
766
+ file_field: File = await resource.get_field(
767
+ attachment.field_id, attachment.field_type, load=True
768
+ ) # type: ignore
769
+ image = await get_file_thumbnail_image(file_field)
770
+ if image is not None:
771
+ pid = f"{rid}/f/{attachment.field_id}/0-0"
772
+ context.images[pid] = image
773
+
774
+ analyzed_fields.append(field_unique_id)
775
+
776
+
777
+ async def hierarchy_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+     strategy: HierarchyResourceStrategy,
+ ) -> None:
+     """
+     Gets the paragraph texts (extended with extra characters, when extra_characters > 0) and
+     crafts a context in which all paragraphs of the same resource are grouped together. Each
+     group also includes the resource title and summary so that the LLM has a better
+     understanding of the context.
+     """
+     paragraphs_extra_characters = max(strategy.count, 0)
+     # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
+     # in the response to the user
+     ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
+     resources: Dict[str, ExtraCharsParagraph] = {}
+
+     # Iterate the paragraphs to get the extended text
+     for paragraph in ordered_paragraphs_copy:
+         paragraph_id = ParagraphId.from_string(paragraph.id)
+         extended_paragraph_text = paragraph.text
+         if paragraphs_extra_characters > 0:
+             extended_paragraph_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=paragraph_id,
+                 log_on_missing_field=True,
+             )
+         rid = paragraph_id.rid
+         if rid not in resources:
+             # Get the title and the summary of the resource
+             title_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=ParagraphId(
+                     field_id=FieldId(
+                         rid=rid,
+                         type="a",
+                         key="title",
+                     ),
+                     paragraph_start=0,
+                     paragraph_end=500,
+                 ),
+                 log_on_missing_field=False,
+             )
+             summary_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=ParagraphId(
+                     field_id=FieldId(
+                         rid=rid,
+                         type="a",
+                         key="summary",
+                     ),
+                     paragraph_start=0,
+                     paragraph_end=1000,
+                 ),
+                 log_on_missing_field=False,
+             )
+             resources[rid] = ExtraCharsParagraph(
+                 title=title_text,
+                 summary=summary_text,
+                 paragraphs=[(paragraph, extended_paragraph_text)],
+             )
+         else:
+             resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
+
+     # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
+     # extended paragraph text of all the paragraphs in the resource.
+     for values in resources.values():
+         title_text = values.title
+         summary_text = values.summary
+         first_paragraph = None
+         text_with_hierarchy = ""
+         for paragraph, extended_paragraph_text in values.paragraphs:
+             if first_paragraph is None:
+                 first_paragraph = paragraph
+             text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
+             # All paragraphs of the resource are cleared except the first one, which will be the
+             # one containing the whole hierarchy information
+             paragraph.text = ""
+
+         if first_paragraph is not None:
+             # The first paragraph is the only one holding the hierarchy information
+             first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
+
+     # Now that the paragraphs have been modified, we can add them to the context
+     for paragraph in ordered_paragraphs_copy:
+         if paragraph.text == "":
+             # Skip paragraphs that were cleared in the hierarchy expansion
+             continue
+         context[paragraph.id] = _clean_paragraph_text(paragraph)
+     return
+
+
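To see what hierarchy_prompt_context produces, here is a hedged miniature of the grouping step: paragraphs, modelled as (rid, text) pairs purely for illustration, are bucketed per resource and each bucket is rendered with the DOCUMENT / SUMMARY / RESOURCE CONTENT header that the first paragraph ends up carrying:

```python
from collections import defaultdict

# Illustrative stand-ins: (resource id, paragraph text) pairs in relevance order.
hits = [("r1", "alpha"), ("r2", "beta"), ("r1", "gamma")]

grouped: dict = defaultdict(list)
for rid, text in hits:
    grouped[rid].append(text)

for rid, texts in grouped.items():
    blocks = "".join(f"\n EXTRACTED BLOCK: \n {t} \n\n " for t in texts)
    # In the real function, title/summary come from the resource's "a/title"
    # and "a/summary" fields; here they are made up.
    print(f"DOCUMENT: {rid}-title \n SUMMARY: {rid}-summary \n RESOURCE CONTENT: {blocks}")
```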
  class PromptContextBuilder:
      """
      Builds the context for the LLM prompt.
@@ -345,7 +874,7 @@ class PromptContextBuilder:
      def __init__(
          self,
          kbid: str,
-         find_results: KnowledgeboxFindResults,
+         ordered_paragraphs: list[FindParagraph],
          resource: Optional[str] = None,
          user_context: Optional[list[str]] = None,
          strategies: Optional[Sequence[RagStrategy]] = None,
@@ -354,7 +883,7 @@ class PromptContextBuilder:
          visual_llm: bool = False,
      ):
          self.kbid = kbid
-         self.ordered_paragraphs = get_ordered_paragraphs(find_results)
+         self.ordered_paragraphs = ordered_paragraphs
          self.resource = resource
          self.user_context = user_context
          self.strategies = strategies
@@ -374,98 +903,175 @@ class PromptContextBuilder:
          ccontext = CappedPromptContext(max_size=self.max_context_characters)
          self.prepend_user_context(ccontext)
          await self._build_context(ccontext)
-
          if self.visual_llm:
              await self._build_context_images(ccontext)

          context = ccontext.output
          context_images = ccontext.images
-         context_order = {
-             text_block_id: order for order, text_block_id in enumerate(context.keys())
-         }
+         context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
          return context, context_order, context_images

      async def _build_context_images(self, context: CappedPromptContext) -> None:
-         flatten_strategies = []
-         page_count = 5
-         gather_pages = False
-         gather_tables = False
-         if self.image_strategies is not None:
-             for strategy in self.image_strategies:
-                 flatten_strategies.append(strategy.name)
-                 if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
-                     gather_pages = True
-                     if strategy.count is not None:  # type: ignore
-                         page_count = strategy.count  # type: ignore
-                 if strategy.name == ImageRagStrategyName.TABLES:
-                     gather_tables = True
-
+         if self.image_strategies is None or len(self.image_strategies) == 0:
+             # Nothing to do
+             return
+         page_image_strategy: Optional[PageImageStrategy] = None
+         max_page_images = 5
+         table_image_strategy: Optional[TableImageStrategy] = None
+         paragraph_image_strategy: Optional[ParagraphImageStrategy] = None
+         for strategy in self.image_strategies:
+             if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
+                 if page_image_strategy is None:
+                     page_image_strategy = cast(PageImageStrategy, strategy)
+                     if page_image_strategy.count is not None:
+                         max_page_images = page_image_strategy.count
+             elif strategy.name == ImageRagStrategyName.TABLES:
+                 if table_image_strategy is None:
+                     table_image_strategy = cast(TableImageStrategy, strategy)
+             elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
+                 if paragraph_image_strategy is None:
+                     paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
+             else:  # pragma: no cover
+                 logger.warning(
+                     "Unknown image strategy",
+                     extra={"strategy": strategy.name, "kbid": self.kbid},
+                 )
+         page_images_added = 0
          for paragraph in self.ordered_paragraphs:
-             if paragraph.page_with_visual and paragraph.position:
-                 if (
-                     gather_pages
-                     and paragraph.position.page_number
-                     and len(context.images) < page_count
-                 ):
-                     field = "/".join(paragraph.id.split("/")[:3])
-                     page = paragraph.position.page_number
-                     page_id = f"{field}/{page}"
-                     if page_id not in context.images:
-                         context.images[page_id] = await get_page_image(
-                             self.kbid, paragraph.id, page
-                         )
+             pid = ParagraphId.from_string(paragraph.id)
+             paragraph_page_number = get_paragraph_page_number(paragraph)
              if (
-                 gather_tables
-                 and paragraph.is_a_table
-                 and paragraph.reference
-                 and paragraph.reference != ""
+                 page_image_strategy is not None
+                 and page_images_added < max_page_images
+                 and paragraph_page_number is not None
              ):
-                 image = paragraph.reference
-                 context.images[paragraph.id] = await get_paragraph_image(
-                     self.kbid, paragraph.id, image
-                 )
+                 # page_image_id: rid/f/myfield/0
+                 page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
+                 if page_image_id not in context.images:
+                     image = await get_page_image(self.kbid, pid, paragraph_page_number)
+                     if image is not None:
+                         context.images[page_image_id] = image
+                         page_images_added += 1
+                     else:
+                         logger.warning(
+                             "Could not retrieve image for paragraph from storage",
+                             extra={
+                                 "kbid": self.kbid,
+                                 "paragraph": pid.full(),
+                                 "page_number": paragraph_page_number,
+                             },
+                         )
+
+             add_table = table_image_strategy is not None and paragraph.is_a_table
+             add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
+             if (add_table or add_paragraph) and (
+                 paragraph.reference is not None and paragraph.reference != ""
+             ):
+                 pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
+                 if pimage is not None:
+                     context.images[paragraph.id] = pimage
+                 else:
+                     logger.warning(
+                         "Could not retrieve image for paragraph from storage",
+                         extra={
+                             "kbid": self.kbid,
+                             "paragraph": pid.full(),
+                             "reference": paragraph.reference,
+                         },
+                     )

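The page-image branch deduplicates on a page_image_id of the form rid/field_type/field/page and only advances its counter when an image is actually stored. A small sketch of that bookkeeping, with made-up ids and byte payloads standing in for the storage calls:

```python
max_page_images = 2
images: dict = {}
page_images_added = 0

# (field path, page number) pairs as they might come from matched paragraphs
hits = [("rid/f/doc", 0), ("rid/f/doc", 0), ("rid/f/doc", 3), ("rid/f/other", 1)]
for field_path, page in hits:
    if page_images_added >= max_page_images:
        break
    page_image_id = f"{field_path}/{page}"
    if page_image_id not in images:
        images[page_image_id] = b"<png>"  # stands in for await get_page_image(...)
        page_images_added += 1

print(sorted(images))  # ['rid/f/doc/0', 'rid/f/doc/3']: duplicate skipped, cap respected
```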
      async def _build_context(self, context: CappedPromptContext) -> None:
          if self.strategies is None or len(self.strategies) == 0:
+             # When no strategy is specified, use the default one
              await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
              return
-
-         number_of_full_resources = 0
-         distance = 0
-         extend_with_fields = []
+         else:
+             # Add the paragraphs to the context and then apply the strategies
+             for paragraph in self.ordered_paragraphs:
+                 context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+         full_resource: Optional[FullResourceStrategy] = None
+         hierarchy: Optional[HierarchyResourceStrategy] = None
+         neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
+         field_extension: Optional[FieldExtensionStrategy] = None
+         metadata_extension: Optional[MetadataExtensionStrategy] = None
+         conversational_strategy: Optional[ConversationalStrategy] = None
          for strategy in self.strategies:
              if strategy.name == RagStrategyName.FIELD_EXTENSION:
-                 extend_with_fields.extend(strategy.fields)  # type: ignore
+                 field_extension = cast(FieldExtensionStrategy, strategy)
+             elif strategy.name == RagStrategyName.CONVERSATION:
+                 conversational_strategy = cast(ConversationalStrategy, strategy)
              elif strategy.name == RagStrategyName.FULL_RESOURCE:
-                 if self.resource:
-                     number_of_full_resources = 1
-                 else:
-                     number_of_full_resources = strategy.count or len(self.ordered_paragraphs)  # type: ignore
+                 full_resource = cast(FullResourceStrategy, strategy)
+                 if self.resource:  # pragma: no cover
+                     # When the retrieval is scoped to a specific resource,
+                     # the full resource strategy only includes that resource
+                     full_resource.count = 1
              elif strategy.name == RagStrategyName.HIERARCHY:
-                 distance = strategy.count  # type: ignore
+                 hierarchy = cast(HierarchyResourceStrategy, strategy)
+             elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
+                 neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
+             elif strategy.name == RagStrategyName.METADATA_EXTENSION:
+                 metadata_extension = cast(MetadataExtensionStrategy, strategy)
+             elif strategy.name != RagStrategyName.PREQUERIES:  # pragma: no cover
+                 # Prequeries are not handled here
+                 logger.warning(
+                     "Unknown rag strategy",
+                     extra={"strategy": strategy.name, "kbid": self.kbid},
+                 )

-         if number_of_full_resources:
+         if full_resource:
+             # When full resource is enabled, only metadata extension is allowed.
              await full_resource_prompt_context(
                  context,
                  self.kbid,
                  self.ordered_paragraphs,
                  self.resource,
-                 number_of_full_resources,
+                 full_resource,
              )
+             if metadata_extension:
+                 await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
              return

-         if distance > 0:
-             await get_extra_chars(self.kbid, self.ordered_paragraphs, distance)
-             await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
-             return
+         if hierarchy:
+             await hierarchy_prompt_context(
+                 context,
+                 self.kbid,
+                 self.ordered_paragraphs,
+                 hierarchy,
+             )
+         if neighbouring_paragraphs:
+             await neighbouring_paragraphs_prompt_context(
+                 context,
+                 self.kbid,
+                 self.ordered_paragraphs,
+                 neighbouring_paragraphs,
+             )
+         if field_extension:
+             await field_extension_prompt_context(
+                 context,
+                 self.kbid,
+                 self.ordered_paragraphs,
+                 field_extension,
+             )
+         if conversational_strategy:
+             await conversation_prompt_context(
+                 context,
+                 self.kbid,
+                 self.ordered_paragraphs,
+                 conversational_strategy,
+                 self.visual_llm,
+             )
+         if metadata_extension:
+             await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)

-         await composed_prompt_context(
-             context,
-             self.kbid,
-             self.ordered_paragraphs,
-             extend_with_fields=extend_with_fields,
-         )
-         return
+
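The dispatch above gives FULL_RESOURCE precedence: it short-circuits (optionally combined with METADATA_EXTENSION), while the remaining strategies compose in a fixed order on top of the already-added matching paragraphs. A hedged sketch of that control flow with plain strings; the plan() helper is illustrative only, not part of the codebase:

```python
def plan(names: set) -> list:
    steps = ["matching_paragraphs"]  # always added first when strategies are present
    if "full_resource" in names:
        steps.append("full_resource")
        if "metadata_extension" in names:
            steps.append("metadata_extension")
        return steps  # everything else is ignored
    for name in ("hierarchy", "neighbouring_paragraphs", "field_extension",
                 "conversation", "metadata_extension"):
        if name in names:
            steps.append(name)
    return steps

print(plan({"full_resource", "hierarchy", "metadata_extension"}))
# ['matching_paragraphs', 'full_resource', 'metadata_extension']
print(plan({"hierarchy", "conversation"}))
# ['matching_paragraphs', 'hierarchy', 'conversation']
```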
+ def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
+     if not paragraph.page_with_visual:
+         return None
+     if paragraph.position is None:
+         return None
+     return paragraph.position.page_number


  @dataclass
@@ -475,67 +1081,6 @@ class ExtraCharsParagraph:
      paragraphs: List[Tuple[FindParagraph, str]]


- async def get_extra_chars(
-     kbid: str, ordered_paragraphs: list[FindParagraph], distance: int
- ):
-     etcache = paragraphs.ExtractedTextCache()
-     resources: Dict[str, ExtraCharsParagraph] = {}
-     for paragraph in ordered_paragraphs:
-         rid, field_type, field = paragraph.id.split("/")[:3]
-         field_path = "/".join([rid, field_type, field])
-         position = paragraph.id.split("/")[-1]
-         start, end = position.split("-")
-         int_start = int(start)
-         int_end = int(end) + distance
-
-         new_text = await paragraphs.get_paragraph_text(
-             kbid=kbid,
-             rid=rid,
-             field=field_path,
-             start=int_start,
-             end=int_end,
-             extracted_text_cache=etcache,
-         )
-         if rid not in resources:
-             title_text = await paragraphs.get_paragraph_text(
-                 kbid=kbid,
-                 rid=rid,
-                 field="/a/title",
-                 start=0,
-                 end=500,
-                 extracted_text_cache=etcache,
-             )
-             summary_text = await paragraphs.get_paragraph_text(
-                 kbid=kbid,
-                 rid=rid,
-                 field="/a/summary",
-                 start=0,
-                 end=1000,
-                 extracted_text_cache=etcache,
-             )
-             resources[rid] = ExtraCharsParagraph(
-                 title=title_text,
-                 summary=summary_text,
-                 paragraphs=[(paragraph, new_text)],
-             )
-         else:
-             resources[rid].paragraphs.append((paragraph, new_text))  # type: ignore
-
-     for values in resources.values():
-         title_text = values.title
-         summary_text = values.summary
-         first_paragraph = None
-         text = ""
-         for paragraph, text in values.paragraphs:
-             if first_paragraph is None:
-                 first_paragraph = paragraph
-             text += "EXTRACTED BLOCK: \n " + text + " \n\n "
-             paragraph.text = ""
-
-         if first_paragraph is not None:
-             first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text}"
-
-
540
1085
  text = paragraph.text.strip()
541
1086
  # Do not send highlight marks on prompt context
@@ -543,17 +1088,23 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:
543
1088
  return text
544
1089
 
545
1090
 
546
- def get_ordered_paragraphs(results: KnowledgeboxFindResults) -> list[FindParagraph]:
1091
+ def get_neighbouring_paragraph_indexes(
1092
+ field_paragraphs: list[ParagraphId],
1093
+ matching_paragraph: ParagraphId,
1094
+ before: int,
1095
+ after: int,
1096
+ ) -> list[int]:
547
1097
  """
548
- Returns the list of paragraphs in the results, ordered by relevance.
1098
+ Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
549
1099
  """
550
- return sorted(
551
- [
552
- paragraph
553
- for resource in results.resources.values()
554
- for field in resource.fields.values()
555
- for paragraph in field.paragraphs.values()
556
- ],
557
- key=lambda paragraph: paragraph.order,
558
- reverse=False,
559
- )
1100
+ assert before >= 0
1101
+ assert after >= 0
1102
+ try:
1103
+ matching_index = field_paragraphs.index(matching_paragraph)
1104
+ except ValueError:
1105
+ raise ParagraphIdNotFoundInExtractedMetadata(
1106
+ f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
1107
+ )
1108
+ start_index = max(0, matching_index - before)
1109
+ end_index = min(len(field_paragraphs), matching_index + after + 1)
1110
+ return list(range(start_index, end_index))
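For reference, the window arithmetic with integer stand-ins for the ParagraphId objects: a match at index 5 of 8 paragraphs, asking for 2 paragraphs before and 1 after:

```python
field_paragraphs = list(range(8))   # stand-ins for ParagraphId objects
matching_index = field_paragraphs.index(5)
before, after = 2, 1

start_index = max(0, matching_index - before)
end_index = min(len(field_paragraphs), matching_index + after + 1)
print(list(range(start_index, end_index)))  # [3, 4, 5, 6]
```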