nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431) hide show
  1. migrations/0002_rollover_shards.py +1 -2
  2. migrations/0003_allfields_key.py +2 -37
  3. migrations/0004_rollover_shards.py +1 -2
  4. migrations/0005_rollover_shards.py +1 -2
  5. migrations/0006_rollover_shards.py +2 -4
  6. migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
  7. migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
  8. migrations/0010_fix_corrupt_indexes.py +11 -12
  9. migrations/0011_materialize_labelset_ids.py +2 -18
  10. migrations/0012_rollover_shards.py +6 -12
  11. migrations/0013_rollover_shards.py +2 -4
  12. migrations/0014_rollover_shards.py +5 -7
  13. migrations/0015_targeted_rollover.py +6 -12
  14. migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
  15. migrations/0017_multiple_writable_shards.py +3 -6
  16. migrations/0018_purge_orphan_kbslugs.py +59 -0
  17. migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
  18. migrations/0020_drain_nodes_from_cluster.py +83 -0
  19. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
  20. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  21. migrations/0023_backfill_pg_catalog.py +80 -0
  22. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  23. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  24. migrations/0027_rollover_texts3.py +73 -0
  25. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  26. migrations/pg/0002_catalog.py +42 -0
  27. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  28. nucliadb/common/cluster/base.py +41 -24
  29. nucliadb/common/cluster/discovery/base.py +6 -14
  30. nucliadb/common/cluster/discovery/k8s.py +9 -19
  31. nucliadb/common/cluster/discovery/manual.py +1 -3
  32. nucliadb/common/cluster/discovery/single.py +1 -2
  33. nucliadb/common/cluster/discovery/utils.py +1 -3
  34. nucliadb/common/cluster/grpc_node_dummy.py +11 -16
  35. nucliadb/common/cluster/index_node.py +10 -19
  36. nucliadb/common/cluster/manager.py +223 -102
  37. nucliadb/common/cluster/rebalance.py +42 -37
  38. nucliadb/common/cluster/rollover.py +377 -204
  39. nucliadb/common/cluster/settings.py +16 -9
  40. nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
  41. nucliadb/common/cluster/standalone/index_node.py +4 -11
  42. nucliadb/common/cluster/standalone/service.py +2 -6
  43. nucliadb/common/cluster/standalone/utils.py +9 -6
  44. nucliadb/common/cluster/utils.py +43 -29
  45. nucliadb/common/constants.py +20 -0
  46. nucliadb/common/context/__init__.py +6 -4
  47. nucliadb/common/context/fastapi.py +8 -5
  48. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  49. nucliadb/common/datamanagers/__init__.py +24 -5
  50. nucliadb/common/datamanagers/atomic.py +102 -0
  51. nucliadb/common/datamanagers/cluster.py +5 -5
  52. nucliadb/common/datamanagers/entities.py +6 -16
  53. nucliadb/common/datamanagers/fields.py +84 -0
  54. nucliadb/common/datamanagers/kb.py +101 -24
  55. nucliadb/common/datamanagers/labels.py +26 -56
  56. nucliadb/common/datamanagers/processing.py +2 -6
  57. nucliadb/common/datamanagers/resources.py +214 -117
  58. nucliadb/common/datamanagers/rollover.py +77 -16
  59. nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
  60. nucliadb/common/datamanagers/utils.py +19 -11
  61. nucliadb/common/datamanagers/vectorsets.py +110 -0
  62. nucliadb/common/external_index_providers/base.py +257 -0
  63. nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
  64. nucliadb/common/external_index_providers/manager.py +101 -0
  65. nucliadb/common/external_index_providers/pinecone.py +933 -0
  66. nucliadb/common/external_index_providers/settings.py +52 -0
  67. nucliadb/common/http_clients/auth.py +3 -6
  68. nucliadb/common/http_clients/processing.py +6 -11
  69. nucliadb/common/http_clients/utils.py +1 -3
  70. nucliadb/common/ids.py +240 -0
  71. nucliadb/common/locking.py +43 -13
  72. nucliadb/common/maindb/driver.py +11 -35
  73. nucliadb/common/maindb/exceptions.py +6 -6
  74. nucliadb/common/maindb/local.py +22 -9
  75. nucliadb/common/maindb/pg.py +206 -111
  76. nucliadb/common/maindb/utils.py +13 -44
  77. nucliadb/common/models_utils/from_proto.py +479 -0
  78. nucliadb/common/models_utils/to_proto.py +60 -0
  79. nucliadb/common/nidx.py +260 -0
  80. nucliadb/export_import/datamanager.py +25 -19
  81. nucliadb/export_import/exceptions.py +8 -0
  82. nucliadb/export_import/exporter.py +20 -7
  83. nucliadb/export_import/importer.py +6 -11
  84. nucliadb/export_import/models.py +5 -5
  85. nucliadb/export_import/tasks.py +4 -4
  86. nucliadb/export_import/utils.py +94 -54
  87. nucliadb/health.py +1 -3
  88. nucliadb/ingest/app.py +15 -11
  89. nucliadb/ingest/consumer/auditing.py +30 -147
  90. nucliadb/ingest/consumer/consumer.py +96 -52
  91. nucliadb/ingest/consumer/materializer.py +10 -12
  92. nucliadb/ingest/consumer/pull.py +12 -27
  93. nucliadb/ingest/consumer/service.py +20 -19
  94. nucliadb/ingest/consumer/shard_creator.py +7 -14
  95. nucliadb/ingest/consumer/utils.py +1 -3
  96. nucliadb/ingest/fields/base.py +139 -188
  97. nucliadb/ingest/fields/conversation.py +18 -5
  98. nucliadb/ingest/fields/exceptions.py +1 -4
  99. nucliadb/ingest/fields/file.py +7 -25
  100. nucliadb/ingest/fields/link.py +11 -16
  101. nucliadb/ingest/fields/text.py +9 -4
  102. nucliadb/ingest/orm/brain.py +255 -262
  103. nucliadb/ingest/orm/broker_message.py +181 -0
  104. nucliadb/ingest/orm/entities.py +36 -51
  105. nucliadb/ingest/orm/exceptions.py +12 -0
  106. nucliadb/ingest/orm/knowledgebox.py +334 -278
  107. nucliadb/ingest/orm/processor/__init__.py +2 -697
  108. nucliadb/ingest/orm/processor/auditing.py +117 -0
  109. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  110. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  111. nucliadb/ingest/orm/processor/processor.py +752 -0
  112. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  113. nucliadb/ingest/orm/resource.py +280 -520
  114. nucliadb/ingest/orm/utils.py +25 -31
  115. nucliadb/ingest/partitions.py +3 -9
  116. nucliadb/ingest/processing.py +76 -81
  117. nucliadb/ingest/py.typed +0 -0
  118. nucliadb/ingest/serialize.py +37 -173
  119. nucliadb/ingest/service/__init__.py +1 -3
  120. nucliadb/ingest/service/writer.py +186 -577
  121. nucliadb/ingest/settings.py +13 -22
  122. nucliadb/ingest/utils.py +3 -6
  123. nucliadb/learning_proxy.py +264 -51
  124. nucliadb/metrics_exporter.py +30 -19
  125. nucliadb/middleware/__init__.py +1 -3
  126. nucliadb/migrator/command.py +1 -3
  127. nucliadb/migrator/datamanager.py +13 -13
  128. nucliadb/migrator/migrator.py +57 -37
  129. nucliadb/migrator/settings.py +2 -1
  130. nucliadb/migrator/utils.py +18 -10
  131. nucliadb/purge/__init__.py +139 -33
  132. nucliadb/purge/orphan_shards.py +7 -13
  133. nucliadb/reader/__init__.py +1 -3
  134. nucliadb/reader/api/models.py +3 -14
  135. nucliadb/reader/api/v1/__init__.py +0 -1
  136. nucliadb/reader/api/v1/download.py +27 -94
  137. nucliadb/reader/api/v1/export_import.py +4 -4
  138. nucliadb/reader/api/v1/knowledgebox.py +13 -13
  139. nucliadb/reader/api/v1/learning_config.py +8 -12
  140. nucliadb/reader/api/v1/resource.py +67 -93
  141. nucliadb/reader/api/v1/services.py +70 -125
  142. nucliadb/reader/app.py +16 -46
  143. nucliadb/reader/lifecycle.py +18 -4
  144. nucliadb/reader/py.typed +0 -0
  145. nucliadb/reader/reader/notifications.py +10 -31
  146. nucliadb/search/__init__.py +1 -3
  147. nucliadb/search/api/v1/__init__.py +2 -2
  148. nucliadb/search/api/v1/ask.py +112 -0
  149. nucliadb/search/api/v1/catalog.py +184 -0
  150. nucliadb/search/api/v1/feedback.py +17 -25
  151. nucliadb/search/api/v1/find.py +41 -41
  152. nucliadb/search/api/v1/knowledgebox.py +90 -62
  153. nucliadb/search/api/v1/predict_proxy.py +2 -2
  154. nucliadb/search/api/v1/resource/ask.py +66 -117
  155. nucliadb/search/api/v1/resource/search.py +51 -72
  156. nucliadb/search/api/v1/router.py +1 -0
  157. nucliadb/search/api/v1/search.py +50 -197
  158. nucliadb/search/api/v1/suggest.py +40 -54
  159. nucliadb/search/api/v1/summarize.py +9 -5
  160. nucliadb/search/api/v1/utils.py +2 -1
  161. nucliadb/search/app.py +16 -48
  162. nucliadb/search/lifecycle.py +10 -3
  163. nucliadb/search/predict.py +176 -188
  164. nucliadb/search/py.typed +0 -0
  165. nucliadb/search/requesters/utils.py +41 -63
  166. nucliadb/search/search/cache.py +149 -20
  167. nucliadb/search/search/chat/ask.py +918 -0
  168. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
  169. nucliadb/search/search/chat/images.py +41 -17
  170. nucliadb/search/search/chat/prompt.py +851 -282
  171. nucliadb/search/search/chat/query.py +274 -267
  172. nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
  173. nucliadb/search/search/fetch.py +43 -36
  174. nucliadb/search/search/filters.py +9 -15
  175. nucliadb/search/search/find.py +214 -54
  176. nucliadb/search/search/find_merge.py +408 -391
  177. nucliadb/search/search/hydrator.py +191 -0
  178. nucliadb/search/search/merge.py +198 -234
  179. nucliadb/search/search/metrics.py +73 -2
  180. nucliadb/search/search/paragraphs.py +64 -106
  181. nucliadb/search/search/pgcatalog.py +233 -0
  182. nucliadb/search/search/predict_proxy.py +1 -1
  183. nucliadb/search/search/query.py +386 -257
  184. nucliadb/search/search/query_parser/exceptions.py +22 -0
  185. nucliadb/search/search/query_parser/models.py +101 -0
  186. nucliadb/search/search/query_parser/parser.py +183 -0
  187. nucliadb/search/search/rank_fusion.py +204 -0
  188. nucliadb/search/search/rerankers.py +270 -0
  189. nucliadb/search/search/shards.py +4 -38
  190. nucliadb/search/search/summarize.py +14 -18
  191. nucliadb/search/search/utils.py +27 -4
  192. nucliadb/search/settings.py +15 -1
  193. nucliadb/standalone/api_router.py +4 -10
  194. nucliadb/standalone/app.py +17 -14
  195. nucliadb/standalone/auth.py +7 -21
  196. nucliadb/standalone/config.py +9 -12
  197. nucliadb/standalone/introspect.py +5 -5
  198. nucliadb/standalone/lifecycle.py +26 -25
  199. nucliadb/standalone/migrations.py +58 -0
  200. nucliadb/standalone/purge.py +9 -8
  201. nucliadb/standalone/py.typed +0 -0
  202. nucliadb/standalone/run.py +25 -18
  203. nucliadb/standalone/settings.py +10 -14
  204. nucliadb/standalone/versions.py +15 -5
  205. nucliadb/tasks/consumer.py +8 -12
  206. nucliadb/tasks/producer.py +7 -6
  207. nucliadb/tests/config.py +53 -0
  208. nucliadb/train/__init__.py +1 -3
  209. nucliadb/train/api/utils.py +1 -2
  210. nucliadb/train/api/v1/shards.py +2 -2
  211. nucliadb/train/api/v1/trainset.py +4 -6
  212. nucliadb/train/app.py +14 -47
  213. nucliadb/train/generator.py +10 -19
  214. nucliadb/train/generators/field_classifier.py +7 -19
  215. nucliadb/train/generators/field_streaming.py +156 -0
  216. nucliadb/train/generators/image_classifier.py +12 -18
  217. nucliadb/train/generators/paragraph_classifier.py +5 -9
  218. nucliadb/train/generators/paragraph_streaming.py +6 -9
  219. nucliadb/train/generators/question_answer_streaming.py +19 -20
  220. nucliadb/train/generators/sentence_classifier.py +9 -15
  221. nucliadb/train/generators/token_classifier.py +45 -36
  222. nucliadb/train/generators/utils.py +14 -18
  223. nucliadb/train/lifecycle.py +7 -3
  224. nucliadb/train/nodes.py +23 -32
  225. nucliadb/train/py.typed +0 -0
  226. nucliadb/train/servicer.py +13 -21
  227. nucliadb/train/settings.py +2 -6
  228. nucliadb/train/types.py +13 -10
  229. nucliadb/train/upload.py +3 -6
  230. nucliadb/train/uploader.py +20 -25
  231. nucliadb/train/utils.py +1 -1
  232. nucliadb/writer/__init__.py +1 -3
  233. nucliadb/writer/api/constants.py +0 -5
  234. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  235. nucliadb/writer/api/v1/export_import.py +102 -49
  236. nucliadb/writer/api/v1/field.py +196 -620
  237. nucliadb/writer/api/v1/knowledgebox.py +221 -71
  238. nucliadb/writer/api/v1/learning_config.py +2 -2
  239. nucliadb/writer/api/v1/resource.py +114 -216
  240. nucliadb/writer/api/v1/services.py +64 -132
  241. nucliadb/writer/api/v1/slug.py +61 -0
  242. nucliadb/writer/api/v1/transaction.py +67 -0
  243. nucliadb/writer/api/v1/upload.py +184 -215
  244. nucliadb/writer/app.py +11 -61
  245. nucliadb/writer/back_pressure.py +62 -43
  246. nucliadb/writer/exceptions.py +0 -4
  247. nucliadb/writer/lifecycle.py +21 -15
  248. nucliadb/writer/py.typed +0 -0
  249. nucliadb/writer/resource/audit.py +2 -1
  250. nucliadb/writer/resource/basic.py +48 -62
  251. nucliadb/writer/resource/field.py +45 -135
  252. nucliadb/writer/resource/origin.py +1 -2
  253. nucliadb/writer/settings.py +14 -5
  254. nucliadb/writer/tus/__init__.py +17 -15
  255. nucliadb/writer/tus/azure.py +111 -0
  256. nucliadb/writer/tus/dm.py +17 -5
  257. nucliadb/writer/tus/exceptions.py +1 -3
  258. nucliadb/writer/tus/gcs.py +56 -84
  259. nucliadb/writer/tus/local.py +21 -37
  260. nucliadb/writer/tus/s3.py +28 -68
  261. nucliadb/writer/tus/storage.py +5 -56
  262. nucliadb/writer/vectorsets.py +125 -0
  263. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  264. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  265. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  266. nucliadb/common/maindb/redis.py +0 -194
  267. nucliadb/common/maindb/tikv.py +0 -412
  268. nucliadb/ingest/fields/layout.py +0 -58
  269. nucliadb/ingest/tests/conftest.py +0 -30
  270. nucliadb/ingest/tests/fixtures.py +0 -771
  271. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  272. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
  273. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
  274. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  275. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  276. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  277. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
  278. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  279. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  280. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  281. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
  282. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  283. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  284. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
  285. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  286. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  287. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  288. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  289. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
  290. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  291. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  292. nucliadb/middleware/transaction.py +0 -117
  293. nucliadb/reader/api/v1/learning_collector.py +0 -63
  294. nucliadb/reader/tests/__init__.py +0 -19
  295. nucliadb/reader/tests/conftest.py +0 -31
  296. nucliadb/reader/tests/fixtures.py +0 -136
  297. nucliadb/reader/tests/test_list_resources.py +0 -75
  298. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  299. nucliadb/reader/tests/test_reader_resource.py +0 -379
  300. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  301. nucliadb/search/api/v1/chat.py +0 -258
  302. nucliadb/search/api/v1/resource/chat.py +0 -94
  303. nucliadb/search/tests/__init__.py +0 -19
  304. nucliadb/search/tests/conftest.py +0 -33
  305. nucliadb/search/tests/fixtures.py +0 -199
  306. nucliadb/search/tests/node.py +0 -465
  307. nucliadb/search/tests/unit/__init__.py +0 -18
  308. nucliadb/search/tests/unit/api/__init__.py +0 -19
  309. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  310. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  311. nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
  312. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
  313. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  314. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  315. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
  316. nucliadb/search/tests/unit/search/__init__.py +0 -18
  317. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  318. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
  319. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  320. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  321. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  322. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
  323. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  324. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  325. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  326. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  327. nucliadb/search/tests/unit/search/test_query.py +0 -201
  328. nucliadb/search/tests/unit/test_app.py +0 -79
  329. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  330. nucliadb/search/tests/unit/test_merge.py +0 -34
  331. nucliadb/search/tests/unit/test_predict.py +0 -584
  332. nucliadb/standalone/tests/__init__.py +0 -19
  333. nucliadb/standalone/tests/conftest.py +0 -33
  334. nucliadb/standalone/tests/fixtures.py +0 -38
  335. nucliadb/standalone/tests/unit/__init__.py +0 -18
  336. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  337. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  338. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  339. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  340. nucliadb/tests/benchmarks/__init__.py +0 -19
  341. nucliadb/tests/benchmarks/test_search.py +0 -99
  342. nucliadb/tests/conftest.py +0 -32
  343. nucliadb/tests/fixtures.py +0 -736
  344. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
  345. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
  346. nucliadb/tests/migrations/__init__.py +0 -19
  347. nucliadb/tests/migrations/test_migration_0017.py +0 -80
  348. nucliadb/tests/tikv.py +0 -240
  349. nucliadb/tests/unit/__init__.py +0 -19
  350. nucliadb/tests/unit/common/__init__.py +0 -19
  351. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  352. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  353. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
  354. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  355. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
  356. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
  357. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
  358. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
  359. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
  360. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  361. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  362. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  363. nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
  364. nucliadb/tests/unit/common/test_context.py +0 -36
  365. nucliadb/tests/unit/export_import/__init__.py +0 -19
  366. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  367. nucliadb/tests/unit/export_import/test_utils.py +0 -294
  368. nucliadb/tests/unit/migrator/__init__.py +0 -19
  369. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  370. nucliadb/tests/unit/tasks/__init__.py +0 -19
  371. nucliadb/tests/unit/tasks/conftest.py +0 -42
  372. nucliadb/tests/unit/tasks/test_consumer.py +0 -93
  373. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  374. nucliadb/tests/unit/tasks/test_tasks.py +0 -60
  375. nucliadb/tests/unit/test_field_ids.py +0 -49
  376. nucliadb/tests/unit/test_health.py +0 -84
  377. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  378. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  379. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  380. nucliadb/tests/unit/test_purge.py +0 -138
  381. nucliadb/tests/utils/__init__.py +0 -74
  382. nucliadb/tests/utils/aiohttp_session.py +0 -44
  383. nucliadb/tests/utils/broker_messages/__init__.py +0 -167
  384. nucliadb/tests/utils/broker_messages/fields.py +0 -181
  385. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  386. nucliadb/tests/utils/entities.py +0 -78
  387. nucliadb/train/api/v1/check.py +0 -60
  388. nucliadb/train/tests/__init__.py +0 -19
  389. nucliadb/train/tests/conftest.py +0 -29
  390. nucliadb/train/tests/fixtures.py +0 -342
  391. nucliadb/train/tests/test_field_classification.py +0 -122
  392. nucliadb/train/tests/test_get_entities.py +0 -80
  393. nucliadb/train/tests/test_get_info.py +0 -51
  394. nucliadb/train/tests/test_get_ontology.py +0 -34
  395. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  396. nucliadb/train/tests/test_image_classification.py +0 -222
  397. nucliadb/train/tests/test_list_fields.py +0 -39
  398. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  399. nucliadb/train/tests/test_list_resources.py +0 -39
  400. nucliadb/train/tests/test_list_sentences.py +0 -71
  401. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  402. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  403. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  404. nucliadb/train/tests/test_sentence_classification.py +0 -143
  405. nucliadb/train/tests/test_token_classification.py +0 -136
  406. nucliadb/train/tests/utils.py +0 -108
  407. nucliadb/writer/layouts/__init__.py +0 -51
  408. nucliadb/writer/layouts/v1.py +0 -59
  409. nucliadb/writer/resource/vectors.py +0 -120
  410. nucliadb/writer/tests/__init__.py +0 -19
  411. nucliadb/writer/tests/conftest.py +0 -31
  412. nucliadb/writer/tests/fixtures.py +0 -192
  413. nucliadb/writer/tests/test_fields.py +0 -486
  414. nucliadb/writer/tests/test_files.py +0 -743
  415. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  416. nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
  417. nucliadb/writer/tests/test_resources.py +0 -546
  418. nucliadb/writer/tests/test_service.py +0 -137
  419. nucliadb/writer/tests/test_tus.py +0 -203
  420. nucliadb/writer/tests/utils.py +0 -35
  421. nucliadb/writer/tus/pg.py +0 -125
  422. nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
  423. nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
  424. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  425. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  426. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  427. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  428. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  429. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  430. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  431. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -17,32 +17,55 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
+ import asyncio
21
+ import copy
22
+ from collections import deque
20
23
  from dataclasses import dataclass
21
- from typing import Dict, List, Optional, Sequence, Tuple
24
+ from typing import Deque, Dict, List, Optional, Sequence, Tuple, Union, cast
22
25
 
26
+ import yaml
27
+ from pydantic import BaseModel
28
+
29
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
30
+ from nucliadb.common.maindb.utils import get_driver
31
+ from nucliadb.common.models_utils import from_proto
23
32
  from nucliadb.ingest.fields.base import Field
24
33
  from nucliadb.ingest.fields.conversation import Conversation
34
+ from nucliadb.ingest.fields.file import File
25
35
  from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
26
- from nucliadb.ingest.orm.resource import KB_REVERSE
27
- from nucliadb.ingest.orm.resource import Resource as ResourceORM
28
- from nucliadb.middleware.transaction import get_read_only_transaction
29
36
  from nucliadb.search import logger
30
- from nucliadb.search.search import paragraphs
31
- from nucliadb.search.search.chat.images import get_page_image, get_paragraph_image
37
+ from nucliadb.search.search import cache
38
+ from nucliadb.search.search.chat.images import (
39
+ get_file_thumbnail_image,
40
+ get_page_image,
41
+ get_paragraph_image,
42
+ )
43
+ from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
44
+ from nucliadb.search.search.paragraphs import get_paragraph_text
45
+ from nucliadb_models.metadata import Extra, Origin
32
46
  from nucliadb_models.search import (
33
47
  SCORE_TYPE,
48
+ ConversationalStrategy,
49
+ FieldExtensionStrategy,
34
50
  FindParagraph,
51
+ FullResourceStrategy,
52
+ HierarchyResourceStrategy,
35
53
  ImageRagStrategy,
36
54
  ImageRagStrategyName,
37
- KnowledgeboxFindResults,
55
+ MetadataExtensionStrategy,
56
+ MetadataExtensionType,
57
+ NeighbouringParagraphsStrategy,
58
+ PageImageStrategy,
59
+ ParagraphImageStrategy,
38
60
  PromptContext,
39
61
  PromptContextImages,
40
62
  PromptContextOrder,
41
63
  RagStrategy,
42
64
  RagStrategyName,
65
+ TableImageStrategy,
43
66
  )
44
67
  from nucliadb_protos import resources_pb2
45
- from nucliadb_utils.asyncio_utils import ConcurrentRunner, run_concurrently
68
+ from nucliadb_utils.asyncio_utils import run_concurrently
46
69
  from nucliadb_utils.utilities import get_storage
47
70
 
48
71
  MAX_RESOURCE_TASKS = 5
@@ -53,12 +76,20 @@ MAX_RESOURCE_FIELD_TASKS = 4
53
76
  # The hope here is it will be enough to get the answer to the question.
54
77
  CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
55
78
 
79
+ TextBlockId = Union[ParagraphId, FieldId]
80
+
81
+
82
+ class ParagraphIdNotFoundInExtractedMetadata(Exception):
83
+ pass
84
+
56
85
 
57
86
  class CappedPromptContext:
58
87
  """
59
- Class to keep track of the size of the prompt context and raise an exception if it exceeds the configured limit.
88
+ Class to keep track of the size (in number of characters) of the prompt context
89
+ and raise an exception if it exceeds the configured limit.
60
90
 
61
- This class will automatically trim data that exceeds the limit when it's being set on the dictionary.
91
+ This class will automatically trim data that exceeds the limit when it's being
92
+ set on the dictionary.
62
93
  """
63
94
 
64
95
  def __init__(self, max_size: Optional[int]):
@@ -68,15 +99,26 @@ class CappedPromptContext:
68
99
  self._size = 0
69
100
 
70
101
  def __setitem__(self, key: str, value: str) -> None:
102
+ prev_value_len = len(self.output.get(key, ""))
71
103
  if self.max_size is None:
72
- self.output[key] = value
104
+ # Unbounded size context
105
+ to_add = value
73
106
  else:
74
- existing_len = len(self.output.get(key, ""))
75
- self._size -= existing_len
76
- size_available = self.max_size - self._size
77
- if size_available > 0:
78
- self.output[key] = value[:size_available]
79
- self._size += len(self.output[key])
107
+ # Make sure we don't exceed the max size
108
+ size_available = max(self.max_size - self._size + prev_value_len, 0)
109
+ to_add = value[:size_available]
110
+ self.output[key] = to_add
111
+ self._size = self._size - prev_value_len + len(to_add)
112
+
113
+ def __getitem__(self, key: str) -> str:
114
+ return self.output.__getitem__(key)
115
+
116
+ def __delitem__(self, key: str) -> None:
117
+ value = self.output.pop(key, "")
118
+ self._size -= len(value)
119
+
120
+ def text_block_ids(self) -> list[str]:
121
+ return list(self.output.keys())
80
122
 
81
123
  @property
82
124
  def size(self) -> int:
@@ -91,15 +133,15 @@ async def get_next_conversation_messages(
91
133
  num_messages: int,
92
134
  message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
93
135
  msg_to: Optional[str] = None,
94
- ):
136
+ ) -> List[resources_pb2.Message]:
95
137
  output = []
96
138
  cmetadata = await field_obj.get_metadata()
97
139
  for current_page in range(page, cmetadata.pages + 1):
98
140
  conv = await field_obj.db_get_value(current_page)
99
141
  for message in conv.messages[start_idx:]:
100
- if message_type is not None and message.type != message_type:
142
+ if message_type is not None and message.type != message_type: # pragma: no cover
101
143
  continue
102
- if msg_to is not None and msg_to not in message.to:
144
+ if msg_to is not None and msg_to not in message.to: # pragma: no cover
103
145
  continue
104
146
  output.append(message)
105
147
  if len(output) >= num_messages:
@@ -122,16 +164,21 @@ async def find_conversation_message(
122
164
 
123
165
 
124
166
  async def get_expanded_conversation_messages(
125
- *, kb: KnowledgeBoxORM, rid: str, field_id: str, mident: str
167
+ *,
168
+ kb: KnowledgeBoxORM,
169
+ rid: str,
170
+ field_id: str,
171
+ mident: str,
172
+ max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
126
173
  ) -> list[resources_pb2.Message]:
127
174
  resource = await kb.get(rid)
128
- if resource is None:
175
+ if resource is None: # pragma: no cover
129
176
  return []
130
- field_obj = await resource.get_field(field_id, KB_REVERSE["c"], load=True)
177
+ field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True) # type: ignore
131
178
  found_message, found_page, found_idx = await find_conversation_message(
132
179
  field_obj=field_obj, mident=mident
133
180
  )
134
- if found_message is None:
181
+ if found_message is None: # pragma: no cover
135
182
  return []
136
183
  elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
137
184
  # only try to get answer if it was a question
@@ -147,14 +194,14 @@ async def get_expanded_conversation_messages(
147
194
  field_obj=field_obj,
148
195
  page=found_page,
149
196
  start_idx=found_idx + 1,
150
- num_messages=CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
197
+ num_messages=max_messages,
151
198
  )
152
199
 
153
200
 
154
201
  async def default_prompt_context(
155
202
  context: CappedPromptContext,
156
203
  kbid: str,
157
- results: KnowledgeboxFindResults,
204
+ ordered_paragraphs: list[FindParagraph],
158
205
  ) -> None:
159
206
  """
160
207
  - Updates context (which is an ordered dict of text_block_id -> context_text).
@@ -166,128 +213,253 @@ async def default_prompt_context(
166
213
  - Using a dict prevents duplicates pulled in through conversation expansion.
167
214
  """
168
215
  # Sort retrieved paragraphs by decreasing order (most relevant first)
169
- ordered_paras = get_ordered_paragraphs(results)
170
- txn = await get_read_only_transaction()
171
- storage = await get_storage()
172
- kb = KnowledgeBoxORM(txn, storage, kbid)
173
- for paragraph in ordered_paras:
174
- context[paragraph.id] = _clean_paragraph_text(paragraph)
175
-
176
- # If the paragraph is a conversation and it matches semantically, we assume we
177
- # have matched with the question, therefore try to include the answer to the
178
- # context by pulling the next few messages of the conversation field
179
- rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
180
- if field_type == "c" and paragraph.score_type in (
181
- SCORE_TYPE.VECTOR,
182
- SCORE_TYPE.BOTH,
183
- ):
184
- expanded_msgs = await get_expanded_conversation_messages(
185
- kb=kb, rid=rid, field_id=field_id, mident=mident
186
- )
187
- for msg in expanded_msgs:
188
- text = msg.content.text.strip()
189
- pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
190
- context[pid] = text
191
-
192
-
193
- async def get_field_extracted_text(field: Field) -> Optional[tuple[Field, str]]:
194
- extracted_text_pb = await field.get_extracted_text(force=True)
195
- if extracted_text_pb is None:
196
- return None
197
- return field, extracted_text_pb.text
198
-
199
-
200
- async def get_resource_field_extracted_text(
201
- kb_obj: KnowledgeBoxORM,
202
- resource_uuid,
203
- field_id: str,
204
- ) -> Optional[tuple[Field, str]]:
205
- resource = await kb_obj.get(resource_uuid)
206
- if resource is None:
207
- return None
208
-
209
- try:
210
- field_type, field_key = field_id.strip("/").split("/")
211
- except ValueError:
212
- logger.error(f"Invalid field id: {field_id}. Skipping getting extracted text.")
213
- return None
214
- field = await resource.get_field(field_key, KB_REVERSE[field_type], load=False)
215
- if field is None:
216
- return None
217
- result = await get_field_extracted_text(field)
218
- if result is None:
219
- return None
220
- _, extracted_text = result
221
- return field, extracted_text
222
-
223
-
224
- async def get_resource_extracted_texts(
225
- kbid: str,
226
- resource_uuid: str,
227
- ) -> list[tuple[Field, str]]:
228
- txn = await get_read_only_transaction()
229
- storage = await get_storage()
230
- kb = KnowledgeBoxORM(txn, storage, kbid)
231
- resource = ResourceORM(
232
- txn=txn,
233
- storage=storage,
234
- kb=kb,
235
- uuid=resource_uuid,
236
- )
237
-
238
- # Schedule the extraction of the text of each field in the resource
239
- runner = ConcurrentRunner(max_tasks=MAX_RESOURCE_FIELD_TASKS)
240
- for field_type, field_key in await resource.get_fields(force=True):
241
- field = await resource.get_field(field_key, field_type, load=False)
242
- runner.schedule(get_field_extracted_text(field))
243
-
244
- # Wait for the results
245
- results = await runner.wait()
246
- return [result for result in results if result is not None]
216
+ async with get_driver().transaction(read_only=True) as txn:
217
+ storage = await get_storage()
218
+ kb = KnowledgeBoxORM(txn, storage, kbid)
219
+ for paragraph in ordered_paragraphs:
220
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
221
+
222
+ # If the paragraph is a conversation and it matches semantically, we assume we
223
+ # have matched with the question, therefore try to include the answer to the
224
+ # context by pulling the next few messages of the conversation field
225
+ rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
226
+ if field_type == "c" and paragraph.score_type in (
227
+ SCORE_TYPE.VECTOR,
228
+ SCORE_TYPE.BOTH,
229
+ ):
230
+ expanded_msgs = await get_expanded_conversation_messages(
231
+ kb=kb, rid=rid, field_id=field_id, mident=mident
232
+ )
233
+ for msg in expanded_msgs:
234
+ text = msg.content.text.strip()
235
+ pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
236
+ context[pid] = text
247
237
 
248
238
 
249
239
  async def full_resource_prompt_context(
250
240
  context: CappedPromptContext,
251
241
  kbid: str,
252
- results: KnowledgeboxFindResults,
253
- number_of_full_resources: Optional[int] = None,
242
+ ordered_paragraphs: list[FindParagraph],
243
+ resource: Optional[str],
244
+ strategy: FullResourceStrategy,
254
245
  ) -> None:
255
246
  """
256
247
  Algorithm steps:
257
248
  - Collect the list of resources in the results (in order of relevance).
258
249
  - For each resource, collect the extracted text from all its fields and craft the context.
259
- """
260
-
261
- # Collect the list of resources in the results (in order of relevance).
262
- ordered_paras = get_ordered_paragraphs(results)
263
- ordered_resources = []
264
- for paragraph in ordered_paras:
265
- resource_uuid = paragraph.id.split("/")[0]
266
- if resource_uuid not in ordered_resources:
267
- ordered_resources.append(resource_uuid)
250
+ Arguments:
251
+ context: The context to be updated.
252
+ kbid: The knowledge box id.
253
+ ordered_paragraphs: The results of the retrieval (find) operation.
254
+ resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
255
+ strategy: strategy instance containing, for example, the number of full resources to include in the context.
256
+ """ # noqa: E501
257
+ if resource is not None:
258
+ # The user has specified a resource to be included in the context.
259
+ ordered_resources = [resource]
260
+ else:
261
+ # Collect the list of resources in the results (in order of relevance).
262
+ ordered_resources = []
263
+ for paragraph in ordered_paragraphs:
264
+ resource_uuid = parse_text_block_id(paragraph.id).rid
265
+ if resource_uuid not in ordered_resources:
266
+ skip = False
267
+ if strategy.apply_to is not None:
268
+ # decide whether the resource should be extended or not
269
+ for label in strategy.apply_to.exclude:
270
+ skip = skip or (label in (paragraph.labels or []))
271
+
272
+ if not skip:
273
+ ordered_resources.append(resource_uuid)
268
274
 
269
275
  # For each resource, collect the extracted text from all its fields.
270
- resource_extracted_texts = await run_concurrently(
276
+ resources_extracted_texts = await run_concurrently(
271
277
  [
272
- get_resource_extracted_texts(kbid, resource_uuid)
273
- for resource_uuid in ordered_resources[:number_of_full_resources]
278
+ hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
279
+ for resource_uuid in ordered_resources[: strategy.count]
274
280
  ],
275
281
  max_concurrent=MAX_RESOURCE_TASKS,
276
282
  )
277
-
278
- for extracted_texts in resource_extracted_texts:
279
- if extracted_texts is None:
283
+ added_fields = set()
284
+ for resource_extracted_texts in resources_extracted_texts:
285
+ if resource_extracted_texts is None:
280
286
  continue
281
- for field, extracted_text in extracted_texts:
287
+ for field, extracted_text in resource_extracted_texts:
288
+ # First off, remove the text block ids from paragraphs that belong to
289
+ # the same field, as otherwise the context will be duplicated.
290
+ for tb_id in context.text_block_ids():
291
+ if tb_id.startswith(field.full()):
292
+ del context[tb_id]
282
293
  # Add the extracted text of each field to the context.
283
- context[field.resource_unique_id] = extracted_text
294
+ context[field.full()] = extracted_text
295
+ added_fields.add(field.full())
296
+
297
+ if strategy.include_remaining_text_blocks:
298
+ for paragraph in ordered_paragraphs:
299
+ pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
300
+ if pid.field_id.full() not in added_fields:
301
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
284
302
 
285
303
 
286
- async def composed_prompt_context(
304
+ async def extend_prompt_context_with_metadata(
287
305
  context: CappedPromptContext,
288
306
  kbid: str,
289
- results: KnowledgeboxFindResults,
290
- extend_with_fields: list[str],
307
+ strategy: MetadataExtensionStrategy,
308
+ ) -> None:
309
+ text_block_ids: list[TextBlockId] = []
310
+ for text_block_id in context.text_block_ids():
311
+ try:
312
+ text_block_ids.append(parse_text_block_id(text_block_id))
313
+ except ValueError: # pragma: no cover
314
+ # Some text block ids are not paragraphs nor fields, so they are skipped
315
+ # (e.g. USER_CONTEXT_0, when the user provides extra context)
316
+ continue
317
+ if len(text_block_ids) == 0: # pragma: no cover
318
+ return
319
+
320
+ if MetadataExtensionType.ORIGIN in strategy.types:
321
+ await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
322
+
323
+ if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
324
+ await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
325
+
326
+ if MetadataExtensionType.NERS in strategy.types:
327
+ await extend_prompt_context_with_ner(context, kbid, text_block_ids)
328
+
329
+ if MetadataExtensionType.EXTRA_METADATA in strategy.types:
330
+ await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
331
+
332
+
333
+ def parse_text_block_id(text_block_id: str) -> TextBlockId:
334
+ try:
335
+ # Typically, the text block id is a paragraph id
336
+ return ParagraphId.from_string(text_block_id)
337
+ except ValueError:
338
+ # When we're doing `full_resource` or `hierarchy` strategies, the text block id
339
+ # is a field id
340
+ return FieldId.from_string(text_block_id)
341
+
342
+
343
+ async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids: list[TextBlockId]):
344
+ async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
345
+ origin = None
346
+ resource = await cache.get_resource(kbid, rid)
347
+ if resource is not None:
348
+ pb_origin = await resource.get_origin()
349
+ if pb_origin is not None:
350
+ origin = from_proto.origin(pb_origin)
351
+ return rid, origin
352
+
353
+ rids = {tb_id.rid for tb_id in text_block_ids}
354
+ origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
355
+ rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
356
+ for tb_id in text_block_ids:
357
+ origin = rid_to_origin.get(tb_id.rid)
358
+ if origin is not None and tb_id.full() in context.output:
359
+ context[tb_id.full()] += f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
360
+
361
+
362
+ async def extend_prompt_context_with_classification_labels(
363
+ context, kbid, text_block_ids: list[TextBlockId]
364
+ ):
365
+ async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
366
+ fid = _id if isinstance(_id, FieldId) else _id.field_id
367
+ labels = set()
368
+ resource = await cache.get_resource(kbid, fid.rid)
369
+ if resource is not None:
370
+ pb_basic = await resource.get_basic()
371
+ if pb_basic is not None:
372
+ # Add the classification labels of the resource
373
+ for classif in pb_basic.usermetadata.classifications:
374
+ labels.add((classif.labelset, classif.label))
375
+ # Add the classifications labels of the field
376
+ for fc in pb_basic.computedmetadata.field_classifications:
377
+ if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
378
+ for classif in fc.classifications:
379
+ if classif.cancelled_by_user: # pragma: no cover
380
+ continue
381
+ labels.add((classif.labelset, classif.label))
382
+ return _id, list(labels)
383
+
384
+ classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
385
+ tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
386
+ for tb_id in text_block_ids:
387
+ labels = tb_id_to_labels.get(tb_id)
388
+ if labels is not None and tb_id.full() in context.output:
389
+ labels_text = "DOCUMENT CLASSIFICATION LABELS:"
390
+ for labelset, label in labels:
391
+ labels_text += f"\n - {label} ({labelset})"
392
+ context[tb_id.full()] += "\n\n" + labels_text
393
+
394
+
395
+ async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[TextBlockId]):
396
+ async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
397
+ fid = _id if isinstance(_id, FieldId) else _id.field_id
398
+ ners: dict[str, set[str]] = {}
399
+ resource = await cache.get_resource(kbid, fid.rid)
400
+ if resource is not None:
401
+ field = await resource.get_field(fid.key, fid.pb_type, load=False)
402
+ fcm = await field.get_field_metadata()
403
+ if fcm is not None:
404
+ # Data Augmentation + Processor entities
405
+ for (
406
+ data_aumgentation_task_id,
407
+ entities_wrapper,
408
+ ) in fcm.metadata.entities.items():
409
+ for entity in entities_wrapper.entities:
410
+ ners.setdefault(entity.label, set()).add(entity.text)
411
+ # Legacy processor entities
412
+ # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
413
+ for token, family in fcm.metadata.ner.items():
414
+ ners.setdefault(family, set()).add(token)
415
+ return _id, ners
416
+
417
+ nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
418
+ tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
419
+ for tb_id in text_block_ids:
420
+ ners = tb_id_to_ners.get(tb_id)
421
+ if ners is not None and tb_id.full() in context.output:
422
+ ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
423
+ for family, tokens in ners.items():
424
+ ners_text += f"\n - {family}:"
425
+ for token in sorted(list(tokens)):
426
+ ners_text += f"\n - {token}"
427
+ context[tb_id.full()] += "\n\n" + ners_text
428
+
429
+
430
+ async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids: list[TextBlockId]):
431
+ async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
432
+ extra = None
433
+ resource = await cache.get_resource(kbid, rid)
434
+ if resource is not None:
435
+ pb_extra = await resource.get_extra()
436
+ if pb_extra is not None:
437
+ extra = from_proto.extra(pb_extra)
438
+ return rid, extra
439
+
440
+ rids = {tb_id.rid for tb_id in text_block_ids}
441
+ extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
442
+ rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
443
+ for tb_id in text_block_ids:
444
+ extra = rid_to_extra.get(tb_id.rid)
445
+ if extra is not None and tb_id.full() in context.output:
446
+ context[tb_id.full()] += f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
447
+
448
+
449
+ def to_yaml(obj: BaseModel) -> str:
450
+ return yaml.dump(
451
+ obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
452
+ default_flow_style=False,
453
+ indent=2,
454
+ sort_keys=True,
455
+ )
456
+
457
+
458
+ async def field_extension_prompt_context(
459
+ context: CappedPromptContext,
460
+ kbid: str,
461
+ ordered_paragraphs: list[FindParagraph],
462
+ strategy: FieldExtensionStrategy,
291
463
  ) -> None:
292
464
  """
293
465
  Algorithm steps:
@@ -296,35 +468,402 @@ async def composed_prompt_context(
296
468
  - Add the extracted text of each field to the beginning of the context.
297
469
  - Add the extracted text of each paragraph to the end of the context.
298
470
  """
299
- # Collect the list of resources in the results (in order of relevance).
300
- ordered_paras = get_ordered_paragraphs(results)
301
471
  ordered_resources = []
302
- for paragraph in ordered_paras:
303
- resource_uuid = paragraph.id.split("/")[0]
472
+ for paragraph in ordered_paragraphs:
473
+ resource_uuid = ParagraphId.from_string(paragraph.id).rid
304
474
  if resource_uuid not in ordered_resources:
305
475
  ordered_resources.append(resource_uuid)
306
476
 
307
477
  # Fetch the extracted texts of the specified fields for each resource
308
- txn = await get_read_only_transaction()
309
- kb_obj = KnowledgeBoxORM(txn, await get_storage(), kbid)
310
-
311
- tasks = [
312
- get_resource_field_extracted_text(kb_obj, resource_uuid, field_id)
313
- for resource_uuid in ordered_resources
314
- for field_id in extend_with_fields
315
- ]
478
+ extend_fields = strategy.fields
479
+ extend_field_ids = []
480
+ for resource_uuid in ordered_resources:
481
+ for field_id in extend_fields:
482
+ try:
483
+ fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
484
+ extend_field_ids.append(fid)
485
+ except ValueError: # pragma: no cover
486
+ # Invalid field id, skipping
487
+ continue
488
+
489
+ tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
316
490
  field_extracted_texts = await run_concurrently(tasks)
317
491
 
318
492
  for result in field_extracted_texts:
319
- if result is None:
493
+ if result is None: # pragma: no cover
320
494
  continue
321
- # Add the extracted text of each field to the beginning of the context.
322
495
  field, extracted_text = result
323
- context[field.resource_unique_id] = extracted_text
496
+ # First off, remove the text block ids from paragraphs that belong to
497
+ # the same field, as otherwise the context will be duplicated.
498
+ for tb_id in context.text_block_ids():
499
+ if tb_id.startswith(field.full()):
500
+ del context[tb_id]
501
+ # Add the extracted text of each field to the beginning of the context.
502
+ context[field.full()] = extracted_text
324
503
 
325
504
  # Add the extracted text of each paragraph to the end of the context.
326
- for paragraph in ordered_paras:
505
+ for paragraph in ordered_paragraphs:
506
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
507
+
508
+
509
+ async def get_paragraph_text_with_neighbours(
510
+ kbid: str,
511
+ pid: ParagraphId,
512
+ field_paragraphs: list[ParagraphId],
513
+ before: int = 0,
514
+ after: int = 0,
515
+ ) -> tuple[ParagraphId, str]:
516
+ """
517
+ This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
518
+ Parameters:
519
+ kbid: The knowledge box id.
520
+ pid: The matching paragraph id.
521
+ field_paragraphs: The list of paragraph ids of the field.
522
+ before: The number of paragraphs to include before the matching paragraph.
523
+ after: The number of paragraphs to include after the matching paragraph.
524
+ """
525
+
526
+ async def _get_paragraph_text(
527
+ kbid: str,
528
+ pid: ParagraphId,
529
+ ) -> tuple[ParagraphId, str]:
530
+ return pid, await get_paragraph_text(
531
+ kbid=kbid,
532
+ paragraph_id=pid,
533
+ log_on_missing_field=True,
534
+ )
535
+
536
+ ops = []
537
+ try:
538
+ for paragraph_index in get_neighbouring_paragraph_indexes(
539
+ field_paragraphs=field_paragraphs,
540
+ matching_paragraph=pid,
541
+ before=before,
542
+ after=after,
543
+ ):
544
+ neighbour_pid = field_paragraphs[paragraph_index]
545
+ ops.append(
546
+ asyncio.create_task(
547
+ _get_paragraph_text(
548
+ kbid=kbid,
549
+ pid=neighbour_pid,
550
+ )
551
+ )
552
+ )
553
+ except ParagraphIdNotFoundInExtractedMetadata:
554
+ logger.warning(
555
+ "Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
556
+ extra={
557
+ "kbid": kbid,
558
+ "matching_paragraph": pid.full(),
559
+ "field_paragraphs": [p.full() for p in field_paragraphs],
560
+ },
561
+ )
562
+ # If we could not find the matching paragraph in the extracted metadata, we can't retrieve
563
+ # the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
564
+ ops.append(
565
+ asyncio.create_task(
566
+ _get_paragraph_text(
567
+ kbid=kbid,
568
+ pid=pid,
569
+ )
570
+ )
571
+ )
572
+
573
+ results = []
574
+ if len(ops) > 0:
575
+ results = await asyncio.gather(*ops)
576
+
577
+ # Sort the results by the paragraph start
578
+ results.sort(key=lambda x: x[0].paragraph_start)
579
+ paragraph_texts = []
580
+ for _, text in results:
581
+ if text != "":
582
+ paragraph_texts.append(text)
583
+ return pid, "\n\n".join(paragraph_texts)
584
+
585
+
586
+ async def get_field_paragraphs_list(
587
+ kbid: str,
588
+ field: FieldId,
589
+ paragraphs: list[ParagraphId],
590
+ ) -> None:
591
+ """
592
+ Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
593
+ """
594
+ resource = await cache.get_resource(kbid, field.rid)
595
+ if resource is None: # pragma: no cover
596
+ return
597
+ field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
598
+ field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
599
+ force=True
600
+ )
601
+ if field_metadata is None: # pragma: no cover
602
+ return
603
+ for paragraph in field_metadata.metadata.paragraphs:
604
+ paragraphs.append(
605
+ ParagraphId(
606
+ field_id=field,
607
+ paragraph_start=paragraph.start,
608
+ paragraph_end=paragraph.end,
609
+ )
610
+ )
611
+
612
+
613
+ async def neighbouring_paragraphs_prompt_context(
614
+ context: CappedPromptContext,
615
+ kbid: str,
616
+ ordered_text_blocks: list[FindParagraph],
617
+ strategy: NeighbouringParagraphsStrategy,
618
+ ) -> None:
619
+ """
620
+ This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
621
+ paragraphs in the ordered_text_blocks list. The number of paragraphs to include before and after each paragraph is taken from strategy.before and strategy.after.
622
+ """
623
+ # First, get the sorted list of paragraphs for each matching field
624
+ # so we can know the indexes of the neighbouring paragraphs
625
+ unique_fields = {
626
+ ParagraphId.from_string(text_block.id).field_id for text_block in ordered_text_blocks
627
+ }
628
+ paragraphs_by_field: dict[FieldId, list[ParagraphId]] = {}
629
+ field_ops = []
630
+ for field_id in unique_fields:
631
+ plist = paragraphs_by_field.setdefault(field_id, [])
632
+ field_ops.append(
633
+ asyncio.create_task(get_field_paragraphs_list(kbid=kbid, field=field_id, paragraphs=plist))
634
+ )
635
+ if field_ops:
636
+ await asyncio.gather(*field_ops)
637
+
638
+ # Now, get the paragraph texts with the neighbouring paragraphs
639
+ paragraph_ops = []
640
+ for text_block in ordered_text_blocks:
641
+ pid = ParagraphId.from_string(text_block.id)
642
+ paragraph_ops.append(
643
+ asyncio.create_task(
644
+ get_paragraph_text_with_neighbours(
645
+ kbid=kbid,
646
+ pid=pid,
647
+ before=strategy.before,
648
+ after=strategy.after,
649
+ field_paragraphs=paragraphs_by_field.get(pid.field_id, []),
650
+ )
651
+ )
652
+ )
653
+ if not paragraph_ops: # pragma: no cover
654
+ return
655
+
656
+ results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
657
+ # Add the paragraph texts to the context
658
+ for pid, text in results:
659
+ if text != "":
660
+ context[pid.full()] = text
661
+
662
+
663
+ async def conversation_prompt_context(
664
+ context: CappedPromptContext,
665
+ kbid: str,
666
+ ordered_paragraphs: list[FindParagraph],
667
+ conversational_strategy: ConversationalStrategy,
668
+ visual_llm: bool,
669
+ ):
670
+ analyzed_fields: List[str] = []
671
+ async with get_driver().transaction(read_only=True) as txn:
672
+ storage = await get_storage()
673
+ kb = KnowledgeBoxORM(txn, storage, kbid)
674
+ for paragraph in ordered_paragraphs:
675
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
676
+
677
+ # If the paragraph is a conversation and it matches semantically, we assume we
678
+ # have matched with the question, therefore try to include the answer to the
679
+ # context by pulling the next few messages of the conversation field
680
+ rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
681
+ if field_type == "c" and paragraph.score_type in (
682
+ SCORE_TYPE.VECTOR,
683
+ SCORE_TYPE.BOTH,
684
+ SCORE_TYPE.BM25,
685
+ ):
686
+ field_unique_id = "-".join([rid, field_type, field_id])
687
+ if field_unique_id in analyzed_fields:
688
+ continue
689
+ resource = await kb.get(rid)
690
+ if resource is None: # pragma: no cover
691
+ continue
692
+
693
+ field_obj: Conversation = await resource.get_field(
694
+ field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
695
+ ) # type: ignore
696
+ cmetadata = await field_obj.get_metadata()
697
+
698
+ attachments: List[resources_pb2.FieldRef] = []
699
+ if conversational_strategy.full:
700
+ extracted_text = await field_obj.get_extracted_text()
701
+ for current_page in range(1, cmetadata.pages + 1):
702
+ conv = await field_obj.db_get_value(current_page)
703
+
704
+ for message in conv.messages:
705
+ ident = message.ident
706
+ if extracted_text is not None:
707
+ text = extracted_text.split_text.get(ident, message.content.text.strip())
708
+ else:
709
+ text = message.content.text.strip()
710
+ pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
711
+ context[pid] = text
712
+ attachments.extend(message.content.attachments_fields)
713
+ else:
714
+ # Add first message
715
+ extracted_text = await field_obj.get_extracted_text()
716
+ first_page = await field_obj.db_get_value()
717
+ if len(first_page.messages) > 0:
718
+ message = first_page.messages[0]
719
+ ident = message.ident
720
+ if extracted_text is not None:
721
+ text = extracted_text.split_text.get(ident, message.content.text.strip())
722
+ else:
723
+ text = message.content.text.strip()
724
+ pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
725
+ context[pid] = text
726
+ attachments.extend(message.content.attachments_fields)
727
+
728
+ messages: Deque[resources_pb2.Message] = deque(
729
+ maxlen=conversational_strategy.max_messages
730
+ )
731
+
732
+ pending = -1
733
+ for page in range(1, cmetadata.pages + 1):
734
+ # Collect the messages within the window asked by the user around the matching paragraph
735
+ conv = await field_obj.db_get_value(page)
736
+ for message in conv.messages:
737
+ messages.append(message)
738
+ if pending > 0:
739
+ pending -= 1
740
+ if message.ident == mident:
741
+ pending = (conversational_strategy.max_messages - 1) // 2
742
+ if pending == 0:
743
+ break
744
+ if pending == 0:
745
+ break
746
+
747
+ for message in messages:
748
+ text = message.content.text.strip()
749
+ pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
750
+ context[pid] = text
751
+ attachments.extend(message.content.attachments_fields)
752
+
753
+ if conversational_strategy.attachments_text:
754
+ # Add the extracted text of the attachments to the context
755
+ for attachment in attachments:
756
+ field: File = await resource.get_field(
757
+ attachment.field_id, attachment.field_type, load=True
758
+ ) # type: ignore
759
+ extracted_text = await field.get_extracted_text()
760
+ if extracted_text is not None:
761
+ pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
762
+ context[pid] = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
763
+
764
+ if conversational_strategy.attachments_images and visual_llm:
765
+ for attachment in attachments:
766
+ file_field: File = await resource.get_field(
767
+ attachment.field_id, attachment.field_type, load=True
768
+ ) # type: ignore
769
+ image = await get_file_thumbnail_image(file_field)
770
+ if image is not None:
771
+ pid = f"{rid}/f/{attachment.field_id}/0-0"
772
+ context.images[pid] = image
773
+
774
+ analyzed_fields.append(field_unique_id)
775
+
776
+
777
+ async def hierarchy_prompt_context(
778
+ context: CappedPromptContext,
779
+ kbid: str,
780
+ ordered_paragraphs: list[FindParagraph],
781
+ strategy: HierarchyResourceStrategy,
782
+ ) -> None:
783
+ """
784
+ This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
785
+ craft a context with all paragraphs of the same resource grouped together. Moreover, on each group of paragraphs,
786
+ it includes the resource title and summary so that the LLM can have a better understanding of the context.
787
+ """
788
+ paragraphs_extra_characters = max(strategy.count, 0)
789
+ # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
790
+ # in the response to the user
791
+ ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
792
+ resources: Dict[str, ExtraCharsParagraph] = {}
793
+
794
+ # Iterate paragraphs to get extended text
795
+ for paragraph in ordered_paragraphs_copy:
796
+ paragraph_id = ParagraphId.from_string(paragraph.id)
797
+ extended_paragraph_text = paragraph.text
798
+ if paragraphs_extra_characters > 0:
799
+ extended_paragraph_text = await get_paragraph_text(
800
+ kbid=kbid,
801
+ paragraph_id=paragraph_id,
802
+ log_on_missing_field=True,
803
+ )
804
+ rid = paragraph_id.rid
805
+ if rid not in resources:
806
+ # Get the title and the summary of the resource
807
+ title_text = await get_paragraph_text(
808
+ kbid=kbid,
809
+ paragraph_id=ParagraphId(
810
+ field_id=FieldId(
811
+ rid=rid,
812
+ type="a",
813
+ key="title",
814
+ ),
815
+ paragraph_start=0,
816
+ paragraph_end=500,
817
+ ),
818
+ log_on_missing_field=False,
819
+ )
820
+ summary_text = await get_paragraph_text(
821
+ kbid=kbid,
822
+ paragraph_id=ParagraphId(
823
+ field_id=FieldId(
824
+ rid=rid,
825
+ type="a",
826
+ key="summary",
827
+ ),
828
+ paragraph_start=0,
829
+ paragraph_end=1000,
830
+ ),
831
+ log_on_missing_field=False,
832
+ )
833
+ resources[rid] = ExtraCharsParagraph(
834
+ title=title_text,
835
+ summary=summary_text,
836
+ paragraphs=[(paragraph, extended_paragraph_text)],
837
+ )
838
+ else:
839
+ resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
840
+
841
+ # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
842
+ # extended paragraph text of all the paragraphs in the resource.
843
+ for values in resources.values():
844
+ title_text = values.title
845
+ summary_text = values.summary
846
+ first_paragraph = None
847
+ text_with_hierarchy = ""
848
+ for paragraph, extended_paragraph_text in values.paragraphs:
849
+ if first_paragraph is None:
850
+ first_paragraph = paragraph
851
+ text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
852
+ # All paragraphs of the resource are cleared except the first one, which will be the
853
+ # one containing the whole hierarchy information
854
+ paragraph.text = ""
855
+
856
+ if first_paragraph is not None:
857
+ # The first paragraph is the only one holding the hierarchy information
858
+ first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
859
+
860
+ # Now that the paragraphs have been modified, we can add them to the context
861
+ for paragraph in ordered_paragraphs_copy:
862
+ if paragraph.text == "":
863
+ # Skip paragraphs that were cleared in the hierarchy expansion
864
+ continue
327
865
  context[paragraph.id] = _clean_paragraph_text(paragraph)
866
+ return
328
867
 
329
868
 
330
869
  class PromptContextBuilder:
@@ -335,19 +874,21 @@ class PromptContextBuilder:
335
874
  def __init__(
336
875
  self,
337
876
  kbid: str,
338
- find_results: KnowledgeboxFindResults,
877
+ ordered_paragraphs: list[FindParagraph],
878
+ resource: Optional[str] = None,
339
879
  user_context: Optional[list[str]] = None,
340
880
  strategies: Optional[Sequence[RagStrategy]] = None,
341
881
  image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
342
- max_context_size: Optional[int] = None,
882
+ max_context_characters: Optional[int] = None,
343
883
  visual_llm: bool = False,
344
884
  ):
345
885
  self.kbid = kbid
346
- self.find_results = find_results
886
+ self.ordered_paragraphs = ordered_paragraphs
887
+ self.resource = resource
347
888
  self.user_context = user_context
348
889
  self.strategies = strategies
349
890
  self.image_strategies = image_strategies
350
- self.max_context_size = max_context_size
891
+ self.max_context_characters = max_context_characters
351
892
  self.visual_llm = visual_llm
352
893
 
353
894
  def prepend_user_context(self, context: CappedPromptContext):
@@ -359,95 +900,178 @@ class PromptContextBuilder:
359
900
  async def build(
360
901
  self,
361
902
  ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages]:
362
- ccontext = CappedPromptContext(max_size=self.max_context_size)
903
+ ccontext = CappedPromptContext(max_size=self.max_context_characters)
363
904
  self.prepend_user_context(ccontext)
364
905
  await self._build_context(ccontext)
365
-
366
906
  if self.visual_llm:
367
907
  await self._build_context_images(ccontext)
368
908
 
369
909
  context = ccontext.output
370
910
  context_images = ccontext.images
371
- context_order = {
372
- text_block_id: order for order, text_block_id in enumerate(context.keys())
373
- }
911
+ context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
374
912
  return context, context_order, context_images
375
913
 
376
914
  async def _build_context_images(self, context: CappedPromptContext) -> None:
377
- ordered_paras = get_ordered_paragraphs(self.find_results)
378
- flatten_strategies = []
379
- page_count = 5
380
- gather_pages = False
381
- gather_tables = False
382
- if self.image_strategies is not None:
383
- for strategy in self.image_strategies:
384
- flatten_strategies.append(strategy.name)
385
- if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
386
- gather_pages = True
387
- if strategy.count is not None: # type: ignore
388
- page_count = strategy.count # type: ignore
389
- if strategy.name == ImageRagStrategyName.TABLES:
390
- gather_tables = True
391
-
392
- for paragraph in ordered_paras:
393
- if paragraph.page_with_visual and paragraph.position:
394
- if (
395
- gather_pages
396
- and paragraph.position.page_number
397
- and len(context.images) < page_count
398
- ):
399
- field = "/".join(paragraph.id.split("/")[:3])
400
- page = paragraph.position.page_number
401
- page_id = f"{field}/{page}"
402
- if page_id not in context.images:
403
- context.images[page_id] = await get_page_image(
404
- self.kbid, paragraph.id, page
405
- )
915
+ if self.image_strategies is None or len(self.image_strategies) == 0:
916
+ # Nothing to do
917
+ return
918
+ page_image_strategy: Optional[PageImageStrategy] = None
919
+ max_page_images = 5
920
+ table_image_strategy: Optional[TableImageStrategy] = None
921
+ paragraph_image_strategy: Optional[ParagraphImageStrategy] = None
922
+ for strategy in self.image_strategies:
923
+ if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
924
+ if page_image_strategy is None:
925
+ page_image_strategy = cast(PageImageStrategy, strategy)
926
+ if page_image_strategy.count is not None:
927
+ max_page_images = page_image_strategy.count
928
+ elif strategy.name == ImageRagStrategyName.TABLES:
929
+ if table_image_strategy is None:
930
+ table_image_strategy = cast(TableImageStrategy, strategy)
931
+ elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
932
+ if paragraph_image_strategy is None:
933
+ paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
934
+ else: # pragma: no cover
935
+ logger.warning(
936
+ "Unknown image strategy",
937
+ extra={"strategy": strategy.name, "kbid": self.kbid},
938
+ )
939
+ page_images_added = 0
940
+ for paragraph in self.ordered_paragraphs:
941
+ pid = ParagraphId.from_string(paragraph.id)
942
+ paragraph_page_number = get_paragraph_page_number(paragraph)
406
943
  if (
407
- gather_tables
408
- and paragraph.is_a_table
409
- and paragraph.reference
410
- and paragraph.reference != ""
944
+ page_image_strategy is not None
945
+ and page_images_added < max_page_images
946
+ and paragraph_page_number is not None
411
947
  ):
412
- image = paragraph.reference
413
- context.images[paragraph.id] = await get_paragraph_image(
414
- self.kbid, paragraph.id, image
415
- )
948
+ # page_image_id: rid/f/myfield/0
949
+ page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
950
+ if page_image_id not in context.images:
951
+ image = await get_page_image(self.kbid, pid, paragraph_page_number)
952
+ if image is not None:
953
+ context.images[page_image_id] = image
954
+ page_images_added += 1
955
+ else:
956
+ logger.warning(
957
+ f"Could not retrieve image for paragraph from storage",
958
+ extra={
959
+ "kbid": self.kbid,
960
+ "paragraph": pid.full(),
961
+ "page_number": paragraph_page_number,
962
+ },
963
+ )
964
+
965
+ add_table = table_image_strategy is not None and paragraph.is_a_table
966
+ add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
967
+ if (add_table or add_paragraph) and (
968
+ paragraph.reference is not None and paragraph.reference != ""
969
+ ):
970
+ pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
971
+ if pimage is not None:
972
+ context.images[paragraph.id] = pimage
973
+ else:
974
+ logger.warning(
975
+ f"Could not retrieve image for paragraph from storage",
976
+ extra={
977
+ "kbid": self.kbid,
978
+ "paragraph": pid.full(),
979
+ "reference": paragraph.reference,
980
+ },
981
+ )
416
982
 
417
983
  async def _build_context(self, context: CappedPromptContext) -> None:
418
984
  if self.strategies is None or len(self.strategies) == 0:
419
- await default_prompt_context(context, self.kbid, self.find_results)
985
+ # When no strategy is specified, use the default one
986
+ await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
420
987
  return
421
-
422
- number_of_full_resources = 0
423
- distance = 0
424
- extend_with_fields = []
988
+ else:
989
+ # Add the paragraphs to the context and then apply the strategies
990
+ for paragraph in self.ordered_paragraphs:
991
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
992
+
993
+ full_resource: Optional[FullResourceStrategy] = None
994
+ hierarchy: Optional[HierarchyResourceStrategy] = None
995
+ neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
996
+ field_extension: Optional[FieldExtensionStrategy] = None
997
+ metadata_extension: Optional[MetadataExtensionStrategy] = None
998
+ conversational_strategy: Optional[ConversationalStrategy] = None
425
999
  for strategy in self.strategies:
426
1000
  if strategy.name == RagStrategyName.FIELD_EXTENSION:
427
- extend_with_fields.extend(strategy.fields) # type: ignore
1001
+ field_extension = cast(FieldExtensionStrategy, strategy)
1002
+ elif strategy.name == RagStrategyName.CONVERSATION:
1003
+ conversational_strategy = cast(ConversationalStrategy, strategy)
428
1004
  elif strategy.name == RagStrategyName.FULL_RESOURCE:
429
- number_of_full_resources = strategy.count or self.find_results.total # type: ignore
1005
+ full_resource = cast(FullResourceStrategy, strategy)
1006
+ if self.resource: # pragma: no cover
1007
+ # When the retrieval is scoped to a specific resource
1008
+ # the full resource strategy only includes that resource
1009
+ full_resource.count = 1
430
1010
  elif strategy.name == RagStrategyName.HIERARCHY:
431
- distance = strategy.count # type: ignore
1011
+ hierarchy = cast(HierarchyResourceStrategy, strategy)
1012
+ elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
1013
+ neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
1014
+ elif strategy.name == RagStrategyName.METADATA_EXTENSION:
1015
+ metadata_extension = cast(MetadataExtensionStrategy, strategy)
1016
+ elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover
1017
+ # Prequeries are not handled here
1018
+ logger.warning(
1019
+ "Unknown rag strategy",
1020
+ extra={"strategy": strategy.name, "kbid": self.kbid},
1021
+ )
432
1022
 
433
- if number_of_full_resources:
1023
+ if full_resource:
1024
+ # When full resource is enabled, only metadata extension is allowed.
434
1025
  await full_resource_prompt_context(
435
- context, self.kbid, self.find_results, number_of_full_resources
1026
+ context,
1027
+ self.kbid,
1028
+ self.ordered_paragraphs,
1029
+ self.resource,
1030
+ full_resource,
436
1031
  )
1032
+ if metadata_extension:
1033
+ await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
437
1034
  return
438
1035
 
439
- if distance > 0:
440
- await get_extra_chars(self.kbid, self.find_results, distance)
441
- await default_prompt_context(context, self.kbid, self.find_results)
442
- return
1036
+ if hierarchy:
1037
+ await hierarchy_prompt_context(
1038
+ context,
1039
+ self.kbid,
1040
+ self.ordered_paragraphs,
1041
+ hierarchy,
1042
+ )
1043
+ if neighbouring_paragraphs:
1044
+ await neighbouring_paragraphs_prompt_context(
1045
+ context,
1046
+ self.kbid,
1047
+ self.ordered_paragraphs,
1048
+ neighbouring_paragraphs,
1049
+ )
1050
+ if field_extension:
1051
+ await field_extension_prompt_context(
1052
+ context,
1053
+ self.kbid,
1054
+ self.ordered_paragraphs,
1055
+ field_extension,
1056
+ )
1057
+ if conversational_strategy:
1058
+ await conversation_prompt_context(
1059
+ context,
1060
+ self.kbid,
1061
+ self.ordered_paragraphs,
1062
+ conversational_strategy,
1063
+ self.visual_llm,
1064
+ )
1065
+ if metadata_extension:
1066
+ await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
443
1067
 
444
- await composed_prompt_context(
445
- context,
446
- self.kbid,
447
- self.find_results,
448
- extend_with_fields=extend_with_fields,
449
- )
450
- return
1068
+
1069
+ def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
1070
+ if not paragraph.page_with_visual:
1071
+ return None
1072
+ if paragraph.position is None:
1073
+ return None
1074
+ return paragraph.position.page_number
451
1075
 
452
1076
 
453
1077
  @dataclass
@@ -457,67 +1081,6 @@ class ExtraCharsParagraph:
457
1081
  paragraphs: List[Tuple[FindParagraph, str]]
458
1082
 
459
1083
 
460
- async def get_extra_chars(
461
- kbid: str, find_results: KnowledgeboxFindResults, distance: int
462
- ):
463
- etcache = paragraphs.ExtractedTextCache()
464
- resources: Dict[str, ExtraCharsParagraph] = {}
465
- for paragraph in get_ordered_paragraphs(find_results):
466
- rid, field_type, field = paragraph.id.split("/")[:3]
467
- field_path = "/".join([rid, field_type, field])
468
- position = paragraph.id.split("/")[-1]
469
- start, end = position.split("-")
470
- int_start = int(start)
471
- int_end = int(end) + distance
472
-
473
- new_text = await paragraphs.get_paragraph_text(
474
- kbid=kbid,
475
- rid=rid,
476
- field=field_path,
477
- start=int_start,
478
- end=int_end,
479
- extracted_text_cache=etcache,
480
- )
481
- if rid not in resources:
482
- title_text = await paragraphs.get_paragraph_text(
483
- kbid=kbid,
484
- rid=rid,
485
- field="/a/title",
486
- start=0,
487
- end=500,
488
- extracted_text_cache=etcache,
489
- )
490
- summary_text = await paragraphs.get_paragraph_text(
491
- kbid=kbid,
492
- rid=rid,
493
- field="/a/summary",
494
- start=0,
495
- end=1000,
496
- extracted_text_cache=etcache,
497
- )
498
- resources[rid] = ExtraCharsParagraph(
499
- title=title_text,
500
- summary=summary_text,
501
- paragraphs=[(paragraph, new_text)],
502
- )
503
- else:
504
- resources[rid].paragraphs.append((paragraph, new_text)) # type: ignore
505
-
506
- for key, values in resources.items():
507
- title_text = values.title
508
- summary_text = values.summary
509
- first_paragraph = None
510
- text = ""
511
- for paragraph, text in values.paragraphs:
512
- if first_paragraph is None:
513
- first_paragraph = paragraph
514
- text += "EXTRACTED BLOCK: \n " + text + " \n\n "
515
- paragraph.text = ""
516
-
517
- if first_paragraph is not None:
518
- first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text}"
519
-
520
-
521
1084
  def _clean_paragraph_text(paragraph: FindParagraph) -> str:
522
1085
  text = paragraph.text.strip()
523
1086
  # Do not send highlight marks on prompt context
@@ -525,17 +1088,23 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:
525
1088
  return text
526
1089
 
527
1090
 
528
- def get_ordered_paragraphs(results: KnowledgeboxFindResults) -> list[FindParagraph]:
1091
+ def get_neighbouring_paragraph_indexes(
1092
+ field_paragraphs: list[ParagraphId],
1093
+ matching_paragraph: ParagraphId,
1094
+ before: int,
1095
+ after: int,
1096
+ ) -> list[int]:
529
1097
  """
530
- Returns the list of paragraphs in the results, ordered by relevance.
1098
+ Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
531
1099
  """
532
- return sorted(
533
- [
534
- paragraph
535
- for resource in results.resources.values()
536
- for field in resource.fields.values()
537
- for paragraph in field.paragraphs.values()
538
- ],
539
- key=lambda paragraph: paragraph.order,
540
- reverse=False,
541
- )
1100
+ assert before >= 0
1101
+ assert after >= 0
1102
+ try:
1103
+ matching_index = field_paragraphs.index(matching_paragraph)
1104
+ except ValueError:
1105
+ raise ParagraphIdNotFoundInExtractedMetadata(
1106
+ f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
1107
+ )
1108
+ start_index = max(0, matching_index - before)
1109
+ end_index = min(len(field_paragraphs), matching_index + after + 1)
1110
+ return list(range(start_index, end_index))