nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2798__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0003_allfields_key.py +1 -35
- migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
- migrations/0010_fix_corrupt_indexes.py +10 -10
- migrations/0011_materialize_labelset_ids.py +1 -16
- migrations/0012_rollover_shards.py +5 -10
- migrations/0014_rollover_shards.py +4 -5
- migrations/0015_targeted_rollover.py +5 -10
- migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
- migrations/0017_multiple_writable_shards.py +2 -4
- migrations/0018_purge_orphan_kbslugs.py +5 -7
- migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
- migrations/0020_drain_nodes_from_cluster.py +3 -3
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +30 -16
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +3 -11
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +174 -59
- nucliadb/common/cluster/rebalance.py +27 -29
- nucliadb/common/cluster/rollover.py +353 -194
- nucliadb/common/cluster/settings.py +6 -0
- nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +2 -6
- nucliadb/common/cluster/utils.py +29 -22
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +3 -0
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +7 -1
- nucliadb/common/datamanagers/atomic.py +22 -4
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +83 -37
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +41 -103
- nucliadb/common/datamanagers/rollover.py +76 -15
- nucliadb/common/datamanagers/synonyms.py +1 -1
- nucliadb/common/datamanagers/utils.py +15 -6
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +29 -7
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +3 -0
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +11 -42
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exporter.py +5 -11
- nucliadb/export_import/importer.py +5 -7
- nucliadb/export_import/models.py +3 -3
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +25 -37
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +21 -19
- nucliadb/ingest/consumer/consumer.py +82 -47
- nucliadb/ingest/consumer/materializer.py +5 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +19 -17
- nucliadb/ingest/consumer/shard_creator.py +2 -4
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +137 -105
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -16
- nucliadb/ingest/fields/link.py +5 -10
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +200 -213
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +322 -197
- nucliadb/ingest/orm/processor/__init__.py +2 -700
- nucliadb/ingest/orm/processor/auditing.py +4 -23
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +249 -403
- nucliadb/ingest/orm/utils.py +4 -4
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +70 -73
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -167
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +185 -412
- nucliadb/ingest/settings.py +10 -20
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +242 -55
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +47 -30
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +1 -12
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +21 -88
- nucliadb/reader/api/v1/export_import.py +1 -1
- nucliadb/reader/api/v1/knowledgebox.py +10 -10
- nucliadb/reader/api/v1/learning_config.py +2 -6
- nucliadb/reader/api/v1/resource.py +62 -88
- nucliadb/reader/api/v1/services.py +64 -83
- nucliadb/reader/app.py +12 -29
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -28
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +1 -2
- nucliadb/search/api/v1/ask.py +17 -10
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +16 -24
- nucliadb/search/api/v1/find.py +36 -36
- nucliadb/search/api/v1/knowledgebox.py +89 -60
- nucliadb/search/api/v1/resource/ask.py +2 -8
- nucliadb/search/api/v1/resource/search.py +49 -70
- nucliadb/search/api/v1/search.py +44 -210
- nucliadb/search/api/v1/suggest.py +39 -54
- nucliadb/search/app.py +12 -32
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +136 -187
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +25 -58
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +571 -123
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +817 -266
- nucliadb/search/search/chat/query.py +213 -309
- nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -53
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +187 -223
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +305 -150
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +3 -32
- nucliadb/search/search/summarize.py +7 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +8 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +7 -10
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +1 -3
- nucliadb/standalone/purge.py +1 -1
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +3 -6
- nucliadb/standalone/settings.py +9 -16
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +1 -1
- nucliadb/train/api/v1/trainset.py +2 -4
- nucliadb/train/app.py +10 -31
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +48 -39
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +19 -23
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +67 -14
- nucliadb/writer/api/v1/field.py +16 -269
- nucliadb/writer/api/v1/knowledgebox.py +218 -68
- nucliadb/writer/api/v1/resource.py +68 -88
- nucliadb/writer/api/v1/services.py +51 -70
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +143 -117
- nucliadb/writer/app.py +6 -43
- nucliadb/writer/back_pressure.py +16 -38
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -46
- nucliadb/writer/resource/field.py +37 -128
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +6 -2
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +49 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2798.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2798.dist-info/RECORD +343 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -433
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -764
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
- nucliadb/ingest/tests/unit/test_cache.py +0 -31
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -353
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -263
- nucliadb/search/api/v1/resource/chat.py +0 -174
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -466
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -153
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -525
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_migrations.py +0 -63
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -735
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
- nucliadb/tests/migrations/test_migration_0017.py +0 -76
- nucliadb/tests/migrations/test_migration_0018.py +0 -95
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
- nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -301
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -92
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -58
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -86
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -136
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -171
- nucliadb/tests/utils/broker_messages/fields.py +0 -197
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -221
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -101
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -191
- nucliadb/writer/tests/test_fields.py +0 -475
- nucliadb/writer/tests/test_files.py +0 -740
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
- nucliadb/writer/tests/test_resources.py +0 -476
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
- nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/entry_points.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/top_level.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/zip-safe +0 -0
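
The detailed hunks below come from nucliadb/search/search/chat/prompt.py (+817 -266 in the list above), the largest single rewrite in this release: prompt-context building for RAG is reorganized around the new strategy models (FullResourceStrategy, FieldExtensionStrategy, MetadataExtensionStrategy, NeighbouringParagraphsStrategy, ConversationalStrategy, HierarchyResourceStrategy).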
@@ -17,32 +17,55 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
+import asyncio
+import copy
+from collections import deque
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Deque, Dict, List, Optional, Sequence, Tuple, Union, cast
 
+import yaml
+from pydantic import BaseModel
+
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
+from nucliadb.common.maindb.utils import get_driver
+from nucliadb.common.models_utils import from_proto
 from nucliadb.ingest.fields.base import Field
 from nucliadb.ingest.fields.conversation import Conversation
+from nucliadb.ingest.fields.file import File
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
-from nucliadb.ingest.orm.resource import KB_REVERSE
-from nucliadb.ingest.orm.resource import Resource as ResourceORM
-from nucliadb.middleware.transaction import get_read_only_transaction
 from nucliadb.search import logger
-from nucliadb.search.search import
-from nucliadb.search.search.chat.images import
+from nucliadb.search.search import cache
+from nucliadb.search.search.chat.images import (
+    get_file_thumbnail_image,
+    get_page_image,
+    get_paragraph_image,
+)
+from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
+from nucliadb.search.search.paragraphs import get_paragraph_text
+from nucliadb_models.metadata import Extra, Origin
 from nucliadb_models.search import (
     SCORE_TYPE,
+    ConversationalStrategy,
+    FieldExtensionStrategy,
     FindParagraph,
+    FullResourceStrategy,
+    HierarchyResourceStrategy,
     ImageRagStrategy,
     ImageRagStrategyName,
-
+    MetadataExtensionStrategy,
+    MetadataExtensionType,
+    NeighbouringParagraphsStrategy,
+    PageImageStrategy,
+    ParagraphImageStrategy,
     PromptContext,
     PromptContextImages,
     PromptContextOrder,
     RagStrategy,
     RagStrategyName,
+    TableImageStrategy,
 )
 from nucliadb_protos import resources_pb2
-from nucliadb_utils.asyncio_utils import
+from nucliadb_utils.asyncio_utils import run_concurrently
 from nucliadb_utils.utilities import get_storage
 
 MAX_RESOURCE_TASKS = 5
@@ -53,6 +76,12 @@ MAX_RESOURCE_FIELD_TASKS = 4
 # The hope here is it will be enough to get the answer to the question.
 CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
 
+TextBlockId = Union[ParagraphId, FieldId]
+
+
+class ParagraphIdNotFoundInExtractedMetadata(Exception):
+    pass
+
 
 class CappedPromptContext:
     """
@@ -70,16 +99,26 @@ class CappedPromptContext:
         self._size = 0
 
     def __setitem__(self, key: str, value: str) -> None:
+        prev_value_len = len(self.output.get(key, ""))
         if self.max_size is None:
-            # Unbounded size
-
+            # Unbounded size context
+            to_add = value
         else:
-
-            self._size
-
-
-
+            # Make sure we don't exceed the max size
+            size_available = max(self.max_size - self._size + prev_value_len, 0)
+            to_add = value[:size_available]
+        self.output[key] = to_add
+        self._size = self._size - prev_value_len + len(to_add)
+
+    def __getitem__(self, key: str) -> str:
+        return self.output.__getitem__(key)
+
+    def __delitem__(self, key: str) -> None:
+        value = self.output.pop(key, "")
+        self._size -= len(value)
+
+    def text_block_ids(self) -> list[str]:
+        return list(self.output.keys())
 
     @property
     def size(self) -> int:
@@ -94,15 +133,15 @@ async def get_next_conversation_messages(
     num_messages: int,
     message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
     msg_to: Optional[str] = None,
-):
+) -> List[resources_pb2.Message]:
     output = []
     cmetadata = await field_obj.get_metadata()
     for current_page in range(page, cmetadata.pages + 1):
         conv = await field_obj.db_get_value(current_page)
         for message in conv.messages[start_idx:]:
-            if message_type is not None and message.type != message_type:
+            if message_type is not None and message.type != message_type:  # pragma: no cover
                 continue
-            if msg_to is not None and msg_to not in message.to:
+            if msg_to is not None and msg_to not in message.to:  # pragma: no cover
                 continue
             output.append(message)
             if len(output) >= num_messages:
@@ -125,16 +164,21 @@ async def find_conversation_message(
 
 
 async def get_expanded_conversation_messages(
-    *,
+    *,
+    kb: KnowledgeBoxORM,
+    rid: str,
+    field_id: str,
+    mident: str,
+    max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
 ) -> list[resources_pb2.Message]:
     resource = await kb.get(rid)
-    if resource is None:
+    if resource is None:  # pragma: no cover
         return []
-    field_obj = await resource.get_field(field_id,
+    field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True)  # type: ignore
     found_message, found_page, found_idx = await find_conversation_message(
         field_obj=field_obj, mident=mident
     )
-    if found_message is None:
+    if found_message is None:  # pragma: no cover
         return []
     elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
         # only try to get answer if it was a question
@@ -150,7 +194,7 @@ async def get_expanded_conversation_messages(
         field_obj=field_obj,
         page=found_page,
         start_idx=found_idx + 1,
-        num_messages=
+        num_messages=max_messages,
     )
 
 
@@ -169,83 +213,27 @@ async def default_prompt_context(
     - Using an dict prevents from duplicates pulled in through conversation expansion.
     """
     # Sort retrieved paragraphs by decreasing order (most relevant first)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-async def get_field_extracted_text(field: Field) -> Optional[tuple[Field, str]]:
-    extracted_text_pb = await field.get_extracted_text(force=True)
-    if extracted_text_pb is None:
-        return None
-    return field, extracted_text_pb.text
-
-
-async def get_resource_field_extracted_text(
-    kb_obj: KnowledgeBoxORM,
-    resource_uuid,
-    field_id: str,
-) -> Optional[tuple[Field, str]]:
-    resource = await kb_obj.get(resource_uuid)
-    if resource is None:
-        return None
-
-    try:
-        field_type, field_key = field_id.strip("/").split("/")
-    except ValueError:
-        logger.error(f"Invalid field id: {field_id}. Skipping getting extracted text.")
-        return None
-    field = await resource.get_field(field_key, KB_REVERSE[field_type], load=False)
-    if field is None:
-        return None
-    result = await get_field_extracted_text(field)
-    if result is None:
-        return None
-    _, extracted_text = result
-    return field, extracted_text
-
-
-async def get_resource_extracted_texts(
-    kbid: str,
-    resource_uuid: str,
-) -> list[tuple[Field, str]]:
-    txn = await get_read_only_transaction()
-    storage = await get_storage()
-    kb = KnowledgeBoxORM(txn, storage, kbid)
-    resource = ResourceORM(
-        txn=txn,
-        storage=storage,
-        kb=kb,
-        uuid=resource_uuid,
-    )
-
-    # Schedule the extraction of the text of each field in the resource
-    runner = ConcurrentRunner(max_tasks=MAX_RESOURCE_FIELD_TASKS)
-    for field_type, field_key in await resource.get_fields(force=True):
-        field = await resource.get_field(field_key, field_type, load=False)
-        runner.schedule(get_field_extracted_text(field))
-
-    # Wait for the results
-    results = await runner.wait()
-    return [result for result in results if result is not None]
+    async with get_driver().transaction(read_only=True) as txn:
+        storage = await get_storage()
+        kb = KnowledgeBoxORM(txn, storage, kbid)
+        for paragraph in ordered_paragraphs:
+            context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+            # If the paragraph is a conversation and it matches semantically, we assume we
+            # have matched with the question, therefore try to include the answer to the
+            # context by pulling the next few messages of the conversation field
+            rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+            if field_type == "c" and paragraph.score_type in (
+                SCORE_TYPE.VECTOR,
+                SCORE_TYPE.BOTH,
+            ):
+                expanded_msgs = await get_expanded_conversation_messages(
+                    kb=kb, rid=rid, field_id=field_id, mident=mident
+                )
+                for msg in expanded_msgs:
+                    text = msg.content.text.strip()
+                    pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
+                    context[pid] = text
 
 
 async def full_resource_prompt_context(
@@ -253,19 +241,18 @@ async def full_resource_prompt_context(
     kbid: str,
     ordered_paragraphs: list[FindParagraph],
     resource: Optional[str],
-
+    strategy: FullResourceStrategy,
 ) -> None:
     """
     Algorithm steps:
     - Collect the list of resources in the results (in order of relevance).
     - For each resource, collect the extracted text from all its fields and craft the context.
-
     Arguments:
         context: The context to be updated.
         kbid: The knowledge box id.
-
+        ordered_paragraphs: The results of the retrieval (find) operation.
         resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
-
+        strategy: strategy instance containing, for example, the number of full resources to include in the context.
     """  # noqa: E501
     if resource is not None:
         # The user has specified a resource to be included in the context.
@@ -274,32 +261,205 @@ async def full_resource_prompt_context(
     # Collect the list of resources in the results (in order of relevance).
     ordered_resources = []
     for paragraph in ordered_paragraphs:
-        resource_uuid = paragraph.id.
+        resource_uuid = parse_text_block_id(paragraph.id).rid
         if resource_uuid not in ordered_resources:
-
+            skip = False
+            if strategy.apply_to is not None:
+                # decide whether the resource should be extended or not
+                for label in strategy.apply_to.exclude:
+                    skip = skip or (label in (paragraph.labels or []))
+
+            if not skip:
+                ordered_resources.append(resource_uuid)
 
     # For each resource, collect the extracted text from all its fields.
-
+    resources_extracted_texts = await run_concurrently(
         [
-
-            for resource_uuid in ordered_resources[:
+            hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
+            for resource_uuid in ordered_resources[: strategy.count]
         ],
         max_concurrent=MAX_RESOURCE_TASKS,
     )
-
-    for
-    if
+    added_fields = set()
+    for resource_extracted_texts in resources_extracted_texts:
+        if resource_extracted_texts is None:
             continue
-        for field, extracted_text in
+        for field, extracted_text in resource_extracted_texts:
+            # First off, remove the text block ids from paragraphs that belong to
+            # the same field, as otherwise the context will be duplicated.
+            for tb_id in context.text_block_ids():
+                if tb_id.startswith(field.full()):
+                    del context[tb_id]
             # Add the extracted text of each field to the context.
-            context[field.
+            context[field.full()] = extracted_text
+            added_fields.add(field.full())
+
+    if strategy.include_remaining_text_blocks:
+        for paragraph in ordered_paragraphs:
+            pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
+            if pid.field_id.full() not in added_fields:
+                context[paragraph.id] = _clean_paragraph_text(paragraph)
 
 
-async def
+async def extend_prompt_context_with_metadata(
+    context: CappedPromptContext,
+    kbid: str,
+    strategy: MetadataExtensionStrategy,
+) -> None:
+    text_block_ids: list[TextBlockId] = []
+    for text_block_id in context.text_block_ids():
+        try:
+            text_block_ids.append(parse_text_block_id(text_block_id))
+        except ValueError:  # pragma: no cover
+            # Some text block ids are not paragraphs nor fields, so they are skipped
+            # (e.g. USER_CONTEXT_0, when the user provides extra context)
+            continue
+    if len(text_block_ids) == 0:  # pragma: no cover
+        return
+
+    if MetadataExtensionType.ORIGIN in strategy.types:
+        await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
+        await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.NERS in strategy.types:
+        await extend_prompt_context_with_ner(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.EXTRA_METADATA in strategy.types:
+        await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
+
+
+def parse_text_block_id(text_block_id: str) -> TextBlockId:
+    try:
+        # Typically, the text block id is a paragraph id
+        return ParagraphId.from_string(text_block_id)
+    except ValueError:
+        # When we're doing `full_resource` or `hierarchy` strategies, the text block id
+        # is a field id
+        return FieldId.from_string(text_block_id)
+
+
+async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
+        origin = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_origin = await resource.get_origin()
+            if pb_origin is not None:
+                origin = from_proto.origin(pb_origin)
+        return rid, origin
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
+    rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
+    for tb_id in text_block_ids:
+        origin = rid_to_origin.get(tb_id.rid)
+        if origin is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
+
+
+async def extend_prompt_context_with_classification_labels(
+    context, kbid, text_block_ids: list[TextBlockId]
+):
+    async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        labels = set()
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            pb_basic = await resource.get_basic()
+            if pb_basic is not None:
+                # Add the classification labels of the resource
+                for classif in pb_basic.usermetadata.classifications:
+                    labels.add((classif.labelset, classif.label))
+                # Add the classifications labels of the field
+                for fc in pb_basic.computedmetadata.field_classifications:
+                    if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
+                        for classif in fc.classifications:
+                            if classif.cancelled_by_user:  # pragma: no cover
+                                continue
+                            labels.add((classif.labelset, classif.label))
+        return _id, list(labels)
+
+    classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
+    for tb_id in text_block_ids:
+        labels = tb_id_to_labels.get(tb_id)
+        if labels is not None and tb_id.full() in context.output:
+            labels_text = "DOCUMENT CLASSIFICATION LABELS:"
+            for labelset, label in labels:
+                labels_text += f"\n - {label} ({labelset})"
+            context[tb_id.full()] += "\n\n" + labels_text
+
+
+async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        ners: dict[str, set[str]] = {}
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            field = await resource.get_field(fid.key, fid.pb_type, load=False)
+            fcm = await field.get_field_metadata()
+            if fcm is not None:
+                # Data Augmentation + Processor entities
+                for (
+                    data_aumgentation_task_id,
+                    entities_wrapper,
+                ) in fcm.metadata.entities.items():
+                    for entity in entities_wrapper.entities:
+                        ners.setdefault(entity.label, set()).add(entity.text)
+                # Legacy processor entities
+                # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+                for token, family in fcm.metadata.ner.items():
+                    ners.setdefault(family, set()).add(token)
+        return _id, ners
+
+    nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
+    for tb_id in text_block_ids:
+        ners = tb_id_to_ners.get(tb_id)
+        if ners is not None and tb_id.full() in context.output:
+            ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
+            for family, tokens in ners.items():
+                ners_text += f"\n - {family}:"
+                for token in sorted(list(tokens)):
+                    ners_text += f"\n - {token}"
+            context[tb_id.full()] += "\n\n" + ners_text
+
+
+async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
+        extra = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_extra = await resource.get_extra()
+            if pb_extra is not None:
+                extra = from_proto.extra(pb_extra)
+        return rid, extra
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
+    rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
+    for tb_id in text_block_ids:
+        extra = rid_to_extra.get(tb_id.rid)
+        if extra is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
+
+
+def to_yaml(obj: BaseModel) -> str:
+    return yaml.dump(
+        obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
+        default_flow_style=False,
+        indent=2,
+        sort_keys=True,
+    )
+
+
 async def field_extension_prompt_context(
     context: CappedPromptContext,
     kbid: str,
     ordered_paragraphs: list[FindParagraph],
-
+    strategy: FieldExtensionStrategy,
 ) -> None:
     """
     Algorithm steps:
@@ -310,33 +470,402 @@ async def composed_prompt_context(
|
|
310
470
|
"""
|
311
471
|
ordered_resources = []
|
312
472
|
for paragraph in ordered_paragraphs:
|
313
|
-
resource_uuid = paragraph.id.
|
473
|
+
resource_uuid = ParagraphId.from_string(paragraph.id).rid
|
314
474
|
if resource_uuid not in ordered_resources:
|
315
475
|
ordered_resources.append(resource_uuid)
|
316
476
|
|
317
477
|
# Fetch the extracted texts of the specified fields for each resource
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
478
|
+
extend_fields = strategy.fields
|
479
|
+
extend_field_ids = []
|
480
|
+
for resource_uuid in ordered_resources:
|
481
|
+
for field_id in extend_fields:
|
482
|
+
try:
|
483
|
+
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
|
484
|
+
extend_field_ids.append(fid)
|
485
|
+
except ValueError: # pragma: no cover
|
486
|
+
# Invalid field id, skiping
|
487
|
+
continue
|
488
|
+
|
489
|
+
tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
|
326
490
|
field_extracted_texts = await run_concurrently(tasks)
|
327
491
|
|
328
492
|
for result in field_extracted_texts:
|
329
|
-
if result is None:
|
493
|
+
if result is None: # pragma: no cover
|
330
494
|
continue
|
331
|
-
# Add the extracted text of each field to the beginning of the context.
|
332
495
|
field, extracted_text = result
|
333
|
-
|
496
|
+
# First off, remove the text block ids from paragraphs that belong to
|
497
|
+
# the same field, as otherwise the context will be duplicated.
|
498
|
+
for tb_id in context.text_block_ids():
|
499
|
+
if tb_id.startswith(field.full()):
|
500
|
+
del context[tb_id]
|
501
|
+
# Add the extracted text of each field to the beginning of the context.
|
502
|
+
context[field.full()] = extracted_text
|
334
503
|
|
335
504
|
# Add the extracted text of each paragraph to the end of the context.
|
336
505
|
for paragraph in ordered_paragraphs:
|
337
506
|
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
338
507
|
|
339
508
|
|
509
|
+
async def get_paragraph_text_with_neighbours(
|
510
|
+
kbid: str,
|
511
|
+
pid: ParagraphId,
|
512
|
+
field_paragraphs: list[ParagraphId],
|
513
|
+
before: int = 0,
|
514
|
+
after: int = 0,
|
515
|
+
) -> tuple[ParagraphId, str]:
|
516
|
+
"""
|
517
|
+
This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
|
518
|
+
Parameters:
|
519
|
+
kbid: The knowledge box id.
|
520
|
+
pid: The matching paragraph id.
|
521
|
+
field_paragraphs: The list of paragraph ids of the field.
|
522
|
+
before: The number of paragraphs to include before the matching paragraph.
|
523
|
+
after: The number of paragraphs to include after the matching paragraph.
|
524
|
+
"""
|
525
|
+
|
526
|
+
async def _get_paragraph_text(
|
527
|
+
kbid: str,
|
528
|
+
pid: ParagraphId,
|
529
|
+
) -> tuple[ParagraphId, str]:
|
530
|
+
return pid, await get_paragraph_text(
|
531
|
+
kbid=kbid,
|
532
|
+
paragraph_id=pid,
|
533
|
+
log_on_missing_field=True,
|
534
|
+
)
|
535
|
+
|
536
|
+
ops = []
|
537
|
+
try:
|
538
|
+
for paragraph_index in get_neighbouring_paragraph_indexes(
|
539
|
+
field_paragraphs=field_paragraphs,
|
540
|
+
matching_paragraph=pid,
|
541
|
+
before=before,
|
542
|
+
after=after,
|
543
|
+
):
|
544
|
+
neighbour_pid = field_paragraphs[paragraph_index]
|
545
|
+
ops.append(
|
546
|
+
asyncio.create_task(
|
547
|
+
_get_paragraph_text(
|
548
|
+
kbid=kbid,
|
549
|
+
pid=neighbour_pid,
|
550
|
+
)
|
551
|
+
)
|
552
|
+
)
|
553
|
+
except ParagraphIdNotFoundInExtractedMetadata:
|
554
|
+
logger.warning(
|
555
|
+
"Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
|
556
|
+
extra={
|
557
|
+
"kbid": kbid,
|
558
|
+
"matching_paragraph": pid.full(),
|
559
|
+
"field_paragraphs": [p.full() for p in field_paragraphs],
|
560
|
+
},
|
561
|
+
)
|
562
|
+
# If we could not find the matching paragraph in the extracted metadata, we can't retrieve
|
563
|
+
# the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
|
564
|
+
ops.append(
|
565
|
+
asyncio.create_task(
|
566
|
+
_get_paragraph_text(
|
567
|
+
kbid=kbid,
|
568
|
+
pid=pid,
|
569
|
+
)
|
570
|
+
)
|
571
|
+
)
|
572
|
+
|
573
|
+
results = []
|
574
|
+
if len(ops) > 0:
|
575
|
+
results = await asyncio.gather(*ops)
|
576
|
+
|
577
|
+
# Sort the results by the paragraph start
|
578
|
+
results.sort(key=lambda x: x[0].paragraph_start)
|
579
|
+
paragraph_texts = []
|
580
|
+
for _, text in results:
|
581
|
+
if text != "":
|
582
|
+
paragraph_texts.append(text)
|
583
|
+
return pid, "\n\n".join(paragraph_texts)
|
584
|
+
|
585
|
+
|
586
|
+
async def get_field_paragraphs_list(
|
587
|
+
kbid: str,
|
588
|
+
field: FieldId,
|
589
|
+
paragraphs: list[ParagraphId],
|
590
|
+
) -> None:
|
591
|
+
"""
|
592
|
+
Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
|
593
|
+
"""
|
594
|
+
resource = await cache.get_resource(kbid, field.rid)
|
595
|
+
if resource is None: # pragma: no cover
|
596
|
+
return
|
597
|
+
field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
|
598
|
+
field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
|
599
|
+
force=True
|
600
|
+
)
|
601
|
+
if field_metadata is None: # pragma: no cover
|
602
|
+
return
|
603
|
+
for paragraph in field_metadata.metadata.paragraphs:
|
604
|
+
paragraphs.append(
|
605
|
+
ParagraphId(
|
606
|
+
field_id=field,
|
607
|
+
paragraph_start=paragraph.start,
|
608
|
+
paragraph_end=paragraph.end,
|
609
|
+
)
|
610
|
+
)
|
611
|
+
|
612
|
+
|
613
|
+
async def neighbouring_paragraphs_prompt_context(
|
614
|
+
context: CappedPromptContext,
|
615
|
+
kbid: str,
|
616
|
+
ordered_text_blocks: list[FindParagraph],
|
617
|
+
strategy: NeighbouringParagraphsStrategy,
|
618
|
+
) -> None:
|
619
|
+
"""
|
620
|
+
This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
|
621
|
+
paragraphs in the ordered_paragraphs list. The number of paragraphs to include before and after each paragraph
|
622
|
+
"""
|
623
|
+
# First, get the sorted list of paragraphs for each matching field
|
624
|
+
# so we can know the indexes of the neighbouring paragraphs
|
625
|
+
unique_fields = {
|
626
|
+
ParagraphId.from_string(text_block.id).field_id for text_block in ordered_text_blocks
|
627
|
+
}
|
628
|
+
paragraphs_by_field: dict[FieldId, list[ParagraphId]] = {}
|
629
|
+
field_ops = []
|
630
|
+
for field_id in unique_fields:
|
631
|
+
plist = paragraphs_by_field.setdefault(field_id, [])
|
632
|
+
field_ops.append(
|
633
|
+
asyncio.create_task(get_field_paragraphs_list(kbid=kbid, field=field_id, paragraphs=plist))
|
634
|
+
)
|
635
|
+
if field_ops:
|
636
|
+
await asyncio.gather(*field_ops)
|
637
|
+
|
638
|
+
# Now, get the paragraph texts with the neighbouring paragraphs
|
639
|
+
paragraph_ops = []
|
640
|
+
for text_block in ordered_text_blocks:
|
641
|
+
pid = ParagraphId.from_string(text_block.id)
|
642
|
+
paragraph_ops.append(
|
643
|
+
asyncio.create_task(
|
644
|
+
get_paragraph_text_with_neighbours(
|
645
|
+
kbid=kbid,
|
646
|
+
pid=pid,
|
647
|
+
before=strategy.before,
|
648
|
+
after=strategy.after,
|
649
|
+
field_paragraphs=paragraphs_by_field.get(pid.field_id, []),
|
650
|
+
)
|
651
|
+
)
|
652
|
+
)
|
653
|
+
if not paragraph_ops: # pragma: no cover
|
654
|
+
return
|
655
|
+
|
656
|
+
results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
|
657
|
+
# Add the paragraph texts to the context
|
658
|
+
for pid, text in results:
|
659
|
+
if text != "":
|
660
|
+
context[pid.full()] = text
|
661
|
+
|
662
|
+
|
663
|
+
async def conversation_prompt_context(
|
664
|
+
context: CappedPromptContext,
|
665
|
+
kbid: str,
|
666
|
+
ordered_paragraphs: list[FindParagraph],
|
667
|
+
conversational_strategy: ConversationalStrategy,
|
668
|
+
visual_llm: bool,
|
669
|
+
):
|
670
|
+
analyzed_fields: List[str] = []
|
671
|
+
async with get_driver().transaction(read_only=True) as txn:
|
672
|
+
storage = await get_storage()
|
673
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
674
|
+
for paragraph in ordered_paragraphs:
|
675
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
676
|
+
|
677
|
+
# If the paragraph is a conversation and it matches semantically, we assume we
|
678
|
+
# have matched with the question, therefore try to include the answer to the
|
679
|
+
# context by pulling the next few messages of the conversation field
|
680
|
+
rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
|
681
|
+
if field_type == "c" and paragraph.score_type in (
|
682
|
+
SCORE_TYPE.VECTOR,
|
683
|
+
SCORE_TYPE.BOTH,
|
684
|
+
SCORE_TYPE.BM25,
|
685
|
+
):
|
686
|
+
field_unique_id = "-".join([rid, field_type, field_id])
|
687
|
+
if field_unique_id in analyzed_fields:
|
688
|
+
continue
|
689
|
+
resource = await kb.get(rid)
|
690
|
+
if resource is None: # pragma: no cover
|
691
|
+
continue
|
692
|
+
|
693
|
+
field_obj: Conversation = await resource.get_field(
|
694
|
+
field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
|
695
|
+
) # type: ignore
|
696
|
+
cmetadata = await field_obj.get_metadata()
|
697
|
+
|
698
|
+
attachments: List[resources_pb2.FieldRef] = []
|
699
|
+
if conversational_strategy.full:
|
700
|
+
extracted_text = await field_obj.get_extracted_text()
|
701
|
+
for current_page in range(1, cmetadata.pages + 1):
|
702
|
+
conv = await field_obj.db_get_value(current_page)
|
703
|
+
|
704
|
+
for message in conv.messages:
|
705
|
+
ident = message.ident
|
706
|
+
if extracted_text is not None:
|
707
|
+
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
708
|
+
else:
|
709
|
+
text = message.content.text.strip()
|
710
|
+
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
711
|
+
context[pid] = text
|
712
|
+
attachments.extend(message.content.attachments_fields)
|
713
|
+
else:
|
714
|
+
# Add first message
|
715
|
+
extracted_text = await field_obj.get_extracted_text()
|
716
|
+
first_page = await field_obj.db_get_value()
|
717
|
+
if len(first_page.messages) > 0:
|
718
|
+
message = first_page.messages[0]
|
719
|
+
ident = message.ident
|
720
|
+
if extracted_text is not None:
|
721
|
+
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
722
|
+
else:
|
723
|
+
text = message.content.text.strip()
|
724
|
+
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
725
|
+
context[pid] = text
|
726
|
+
attachments.extend(message.content.attachments_fields)
|
727
|
+
|
728
|
+
messages: Deque[resources_pb2.Message] = deque(
|
729
|
+
maxlen=conversational_strategy.max_messages
|
730
|
+
)
|
731
|
+
|
732
|
+
pending = -1
|
733
|
+
for page in range(1, cmetadata.pages + 1):
|
734
|
+
# Collect the messages with the window asked by the user arround the match paragraph
|
735
|
+
conv = await field_obj.db_get_value(page)
|
736
|
+
for message in conv.messages:
|
737
|
+
messages.append(message)
|
738
|
+
if pending > 0:
|
739
|
+
pending -= 1
|
740
|
+
if message.ident == mident:
|
741
|
+
pending = (conversational_strategy.max_messages - 1) // 2
|
742
|
+
if pending == 0:
|
743
|
+
break
|
744
|
+
if pending == 0:
|
745
|
+
break
|
746
|
+
|
747
|
+
for message in messages:
|
748
|
+
text = message.content.text.strip()
|
749
|
+
pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
|
750
|
+
context[pid] = text
|
751
|
+
attachments.extend(message.content.attachments_fields)
|
752
|
+
|
753
|
+
if conversational_strategy.attachments_text:
|
754
|
+
# add on the context the images if vlm enabled
|
755
|
+
for attachment in attachments:
|
756
|
+
field: File = await resource.get_field(
|
757
|
+
attachment.field_id, attachment.field_type, load=True
|
758
|
+
) # type: ignore
|
759
|
+
extracted_text = await field.get_extracted_text()
|
760
|
+
if extracted_text is not None:
|
761
|
+
pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
|
762
|
+
context[pid] = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
|
763
|
+
|
764
|
+
if conversational_strategy.attachments_images and visual_llm:
|
765
|
+
for attachment in attachments:
|
766
|
+
file_field: File = await resource.get_field(
|
767
|
+
attachment.field_id, attachment.field_type, load=True
|
768
|
+
) # type: ignore
|
769
|
+
image = await get_file_thumbnail_image(file_field)
|
770
|
+
if image is not None:
|
771
|
+
pid = f"{rid}/f/{attachment.field_id}/0-0"
|
772
|
+
context.images[pid] = image
|
773
|
+
|
774
|
+
analyzed_fields.append(field_unique_id)
|
775
|
+
|
776
|
+
|
+async def hierarchy_prompt_context(
+    context: CappedPromptContext,
+    kbid: str,
+    ordered_paragraphs: list[FindParagraph],
+    strategy: HierarchyResourceStrategy,
+) -> None:
+    """
+    Get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
+    craft a context with all paragraphs of the same resource grouped together. Each group of
+    paragraphs also includes the resource title and summary so that the LLM can have a better
+    understanding of the context.
+    """
+    paragraphs_extra_characters = max(strategy.count, 0)
+    # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
+    # in the response to the user
+    ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
+    resources: Dict[str, ExtraCharsParagraph] = {}
+
+    # Iterate paragraphs to get extended text
+    for paragraph in ordered_paragraphs_copy:
+        paragraph_id = ParagraphId.from_string(paragraph.id)
+        extended_paragraph_text = paragraph.text
+        if paragraphs_extra_characters > 0:
+            extended_paragraph_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=paragraph_id,
+                log_on_missing_field=True,
+            )
+        rid = paragraph_id.rid
+        if rid not in resources:
+            # Get the title and the summary of the resource
+            title_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=ParagraphId(
+                    field_id=FieldId(
+                        rid=rid,
+                        type="a",
+                        key="title",
+                    ),
+                    paragraph_start=0,
+                    paragraph_end=500,
+                ),
+                log_on_missing_field=False,
+            )
+            summary_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=ParagraphId(
+                    field_id=FieldId(
+                        rid=rid,
+                        type="a",
+                        key="summary",
+                    ),
+                    paragraph_start=0,
+                    paragraph_end=1000,
+                ),
+                log_on_missing_field=False,
+            )
+            resources[rid] = ExtraCharsParagraph(
+                title=title_text,
+                summary=summary_text,
+                paragraphs=[(paragraph, extended_paragraph_text)],
+            )
+        else:
+            resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
+
+    # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
+    # extended paragraph text of all the paragraphs in the resource.
+    for values in resources.values():
+        title_text = values.title
+        summary_text = values.summary
+        first_paragraph = None
+        text_with_hierarchy = ""
+        for paragraph, extended_paragraph_text in values.paragraphs:
+            if first_paragraph is None:
+                first_paragraph = paragraph
+            text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
+            # All paragraphs of the resource are cleared except the first one, which will be the
+            # one containing the whole hierarchy information
+            paragraph.text = ""
+
+        if first_paragraph is not None:
+            # The first paragraph is the only one holding the hierarchy information
+            first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
+
+    # Now that the paragraphs have been modified, we can add them to the context
+    for paragraph in ordered_paragraphs_copy:
+        if paragraph.text == "":
+            # Skip paragraphs that were cleared in the hierarchy expansion
+            continue
+        context[paragraph.id] = _clean_paragraph_text(paragraph)
+    return
+
+
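The net effect of hierarchy_prompt_context is that each resource contributes a single context entry, keyed by its first matching paragraph, that concatenates the resource title, summary, and all extracted blocks. A minimal sketch of that grouping step over plain dicts (Block, group_by_resource, and the sample ids are illustrative stand-ins for the nucliadb types):

    from dataclasses import dataclass

    @dataclass
    class Block:
        id: str  # "rid/field_type/field/start-end"
        text: str

    def group_by_resource(
        blocks: list[Block], titles: dict[str, str], summaries: dict[str, str]
    ) -> dict[str, str]:
        # Bucket the paragraph texts by resource id (the first segment of the block id)
        grouped: dict[str, list[Block]] = {}
        for block in blocks:
            grouped.setdefault(block.id.split("/")[0], []).append(block)

        context: dict[str, str] = {}
        for rid, members in grouped.items():
            body = "".join(f"\n EXTRACTED BLOCK: \n {b.text} \n\n " for b in members)
            # Only the first paragraph of each resource carries the grouped content
            context[members[0].id] = (
                f"DOCUMENT: {titles.get(rid, '')} \n SUMMARY: {summaries.get(rid, '')} \n RESOURCE CONTENT: {body}"
            )
        return context

    blocks = [Block("r1/f/doc/0-10", "first match"), Block("r1/f/doc/10-24", "second match")]
    print(group_by_resource(blocks, {"r1": "My document"}, {"r1": "A short summary"}))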
 class PromptContextBuilder:
     """
     Builds the context for the LLM prompt.
@@ -345,7 +874,7 @@ class PromptContextBuilder:
     def __init__(
         self,
         kbid: str,
-
+        ordered_paragraphs: list[FindParagraph],
         resource: Optional[str] = None,
         user_context: Optional[list[str]] = None,
         strategies: Optional[Sequence[RagStrategy]] = None,
@@ -354,7 +883,7 @@ class PromptContextBuilder:
         visual_llm: bool = False,
     ):
         self.kbid = kbid
-        self.ordered_paragraphs =
+        self.ordered_paragraphs = ordered_paragraphs
         self.resource = resource
         self.user_context = user_context
         self.strategies = strategies
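With this change the builder receives the ordered paragraphs directly instead of deriving them itself. A hypothetical call site under that assumption (the variable names, the async build() entry point, and the returned triple are inferred from the hunks around this one, not confirmed elsewhere in the diff):

    # `paragraphs` would be the list[FindParagraph] coming from retrieval,
    # already sorted by relevance.
    builder = PromptContextBuilder(
        kbid="my-kb",
        ordered_paragraphs=paragraphs,
        visual_llm=False,
    )
    context, context_order, context_images = await builder.build()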
@@ -374,98 +903,175 @@ class PromptContextBuilder:
         ccontext = CappedPromptContext(max_size=self.max_context_characters)
         self.prepend_user_context(ccontext)
         await self._build_context(ccontext)
-
         if self.visual_llm:
             await self._build_context_images(ccontext)

         context = ccontext.output
         context_images = ccontext.images
-        context_order = {
-            text_block_id: order for order, text_block_id in enumerate(context.keys())
-        }
+        context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
         return context, context_order, context_images

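The collapsed context_order comprehension relies on dicts preserving insertion order (guaranteed since Python 3.7), so the position at which each text block was added to the context becomes its order in the prompt:

    context = {"p1": "text a", "p2": "text b", "p3": "text c"}
    context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
    assert context_order == {"p1": 0, "p2": 1, "p3": 2}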
     async def _build_context_images(self, context: CappedPromptContext) -> None:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.image_strategies is None or len(self.image_strategies) == 0:
+            # Nothing to do
+            return
+        page_image_strategy: Optional[PageImageStrategy] = None
+        max_page_images = 5
+        table_image_strategy: Optional[TableImageStrategy] = None
+        paragraph_image_strategy: Optional[ParagraphImageStrategy] = None
+        for strategy in self.image_strategies:
+            if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
+                if page_image_strategy is None:
+                    page_image_strategy = cast(PageImageStrategy, strategy)
+                    if page_image_strategy.count is not None:
+                        max_page_images = page_image_strategy.count
+            elif strategy.name == ImageRagStrategyName.TABLES:
+                if table_image_strategy is None:
+                    table_image_strategy = cast(TableImageStrategy, strategy)
+            elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
+                if paragraph_image_strategy is None:
+                    paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
+            else:  # pragma: no cover
+                logger.warning(
+                    "Unknown image strategy",
+                    extra={"strategy": strategy.name, "kbid": self.kbid},
+                )
+        page_images_added = 0
         for paragraph in self.ordered_paragraphs:
-
-
-                gather_pages
-                and paragraph.position.page_number
-                and len(context.images) < page_count
-            ):
-                field = "/".join(paragraph.id.split("/")[:3])
-                page = paragraph.position.page_number
-                page_id = f"{field}/{page}"
-                if page_id not in context.images:
-                    context.images[page_id] = await get_page_image(
-                        self.kbid, paragraph.id, page
-                    )
+            pid = ParagraphId.from_string(paragraph.id)
+            paragraph_page_number = get_paragraph_page_number(paragraph)
             if (
-
-                and
-                and
-                and paragraph.reference != ""
+                page_image_strategy is not None
+                and page_images_added < max_page_images
+                and paragraph_page_number is not None
             ):
-
-
-
-
+                # page_image_id: rid/f/myfield/0
+                page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
+                if page_image_id not in context.images:
+                    image = await get_page_image(self.kbid, pid, paragraph_page_number)
+                    if image is not None:
+                        context.images[page_image_id] = image
+                        page_images_added += 1
+                    else:
+                        logger.warning(
+                            "Could not retrieve image for paragraph from storage",
+                            extra={
+                                "kbid": self.kbid,
+                                "paragraph": pid.full(),
+                                "page_number": paragraph_page_number,
+                            },
+                        )
+
+            add_table = table_image_strategy is not None and paragraph.is_a_table
+            add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
+            if (add_table or add_paragraph) and (
+                paragraph.reference is not None and paragraph.reference != ""
+            ):
+                pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
+                if pimage is not None:
+                    context.images[paragraph.id] = pimage
+                else:
+                    logger.warning(
+                        "Could not retrieve image for paragraph from storage",
+                        extra={
+                            "kbid": self.kbid,
+                            "paragraph": pid.full(),
+                            "reference": paragraph.reference,
+                        },
+                    )

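Page images are now deduplicated by a page_image_id of the form rid/field_type/field/page, so several matching paragraphs on the same page fetch the image only once, and max_page_images caps the total. A small sketch of that bookkeeping with a stubbed storage lookup (fetch_page_image and the sample ids are illustrative):

    import asyncio
    from typing import Optional

    async def fetch_page_image(page_image_id: str) -> Optional[bytes]:
        # Stand-in for the real storage call; returns None when the image is missing.
        return b"png-bytes"

    async def collect_page_images(
        paragraph_pages: list[tuple[str, int]], max_page_images: int
    ) -> dict[str, bytes]:
        images: dict[str, bytes] = {}
        added = 0
        for field, page in paragraph_pages:
            if added >= max_page_images:
                break
            page_image_id = f"{field}/{page}"  # e.g. "rid/f/myfield/0"
            if page_image_id in images:
                continue  # several paragraphs can live on the same page
            image = await fetch_page_image(page_image_id)
            if image is not None:
                images[page_image_id] = image
                added += 1
        return images

    # Two paragraphs on page 0 of the same field yield a single image:
    pages = [("r1/f/doc", 0), ("r1/f/doc", 0), ("r1/f/doc", 1)]
    print(asyncio.run(collect_page_images(pages, max_page_images=5)))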
     async def _build_context(self, context: CappedPromptContext) -> None:
         if self.strategies is None or len(self.strategies) == 0:
+            # When no strategy is specified, use the default one
             await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
             return
-
-
-
-
+        else:
+            # Add the paragraphs to the context and then apply the strategies
+            for paragraph in self.ordered_paragraphs:
+                context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+        full_resource: Optional[FullResourceStrategy] = None
+        hierarchy: Optional[HierarchyResourceStrategy] = None
+        neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
+        field_extension: Optional[FieldExtensionStrategy] = None
+        metadata_extension: Optional[MetadataExtensionStrategy] = None
+        conversational_strategy: Optional[ConversationalStrategy] = None
         for strategy in self.strategies:
             if strategy.name == RagStrategyName.FIELD_EXTENSION:
-
+                field_extension = cast(FieldExtensionStrategy, strategy)
+            elif strategy.name == RagStrategyName.CONVERSATION:
+                conversational_strategy = cast(ConversationalStrategy, strategy)
             elif strategy.name == RagStrategyName.FULL_RESOURCE:
-
-
-
-
+                full_resource = cast(FullResourceStrategy, strategy)
+                if self.resource:  # pragma: no cover
+                    # When the retrieval is scoped to a specific resource
+                    # the full resource strategy only includes that resource
+                    full_resource.count = 1
             elif strategy.name == RagStrategyName.HIERARCHY:
-
+                hierarchy = cast(HierarchyResourceStrategy, strategy)
+            elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
+                neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
+            elif strategy.name == RagStrategyName.METADATA_EXTENSION:
+                metadata_extension = cast(MetadataExtensionStrategy, strategy)
+            elif strategy.name != RagStrategyName.PREQUERIES:  # pragma: no cover
+                # Prequeries are not handled here
+                logger.warning(
+                    "Unknown rag strategy",
+                    extra={"strategy": strategy.name, "kbid": self.kbid},
+                )

-        if
+        if full_resource:
+            # When full resource is enabled, only metadata extension is allowed.
             await full_resource_prompt_context(
                 context,
                 self.kbid,
                 self.ordered_paragraphs,
                 self.resource,
-
+                full_resource,
             )
+            if metadata_extension:
+                await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
             return

-        if
-        await
-
-
+        if hierarchy:
+            await hierarchy_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                hierarchy,
+            )
+        if neighbouring_paragraphs:
+            await neighbouring_paragraphs_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                neighbouring_paragraphs,
+            )
+        if field_extension:
+            await field_extension_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                field_extension,
+            )
+        if conversational_strategy:
+            await conversation_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                conversational_strategy,
+                self.visual_llm,
+            )
+        if metadata_extension:
+            await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)

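The dispatch above gives FULL_RESOURCE precedence: it replaces the paragraph-level context and short-circuits every other strategy except metadata extension, while the remaining strategies are additive and can be combined. A condensed, runnable sketch of that precedence (the handler names are stand-ins, not nucliadb APIs):

    import asyncio
    from typing import Optional

    Context = dict[str, str]

    async def add_full_resources(context: Context, strategy: str) -> None:
        context["full_resource"] = strategy

    async def add_hierarchy(context: Context, strategy: str) -> None:
        context["hierarchy"] = strategy

    async def add_metadata(context: Context, strategy: str) -> None:
        context["metadata"] = strategy

    async def apply_strategies(
        context: Context,
        full_resource: Optional[str] = None,
        hierarchy: Optional[str] = None,
        metadata: Optional[str] = None,
    ) -> None:
        if full_resource:
            # Full resource replaces paragraph-level context and skips the rest,
            # metadata extension excepted.
            await add_full_resources(context, full_resource)
            if metadata:
                await add_metadata(context, metadata)
            return
        if hierarchy:
            await add_hierarchy(context, hierarchy)
        if metadata:
            await add_metadata(context, metadata)

    ctx: Context = {}
    asyncio.run(apply_strategies(ctx, full_resource="full", hierarchy="hier", metadata="meta"))
    assert "hierarchy" not in ctx  # skipped: full_resource takes precedence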
-
-
-
-
-
-
-
+
+def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
+    if not paragraph.page_with_visual:
+        return None
+    if paragraph.position is None:
+        return None
+    return paragraph.position.page_number


 @dataclass
@@ -475,67 +1081,6 @@ class ExtraCharsParagraph:
     paragraphs: List[Tuple[FindParagraph, str]]


-async def get_extra_chars(
-    kbid: str, ordered_paragraphs: list[FindParagraph], distance: int
-):
-    etcache = paragraphs.ExtractedTextCache()
-    resources: Dict[str, ExtraCharsParagraph] = {}
-    for paragraph in ordered_paragraphs:
-        rid, field_type, field = paragraph.id.split("/")[:3]
-        field_path = "/".join([rid, field_type, field])
-        position = paragraph.id.split("/")[-1]
-        start, end = position.split("-")
-        int_start = int(start)
-        int_end = int(end) + distance
-
-        new_text = await paragraphs.get_paragraph_text(
-            kbid=kbid,
-            rid=rid,
-            field=field_path,
-            start=int_start,
-            end=int_end,
-            extracted_text_cache=etcache,
-        )
-        if rid not in resources:
-            title_text = await paragraphs.get_paragraph_text(
-                kbid=kbid,
-                rid=rid,
-                field="/a/title",
-                start=0,
-                end=500,
-                extracted_text_cache=etcache,
-            )
-            summary_text = await paragraphs.get_paragraph_text(
-                kbid=kbid,
-                rid=rid,
-                field="/a/summary",
-                start=0,
-                end=1000,
-                extracted_text_cache=etcache,
-            )
-            resources[rid] = ExtraCharsParagraph(
-                title=title_text,
-                summary=summary_text,
-                paragraphs=[(paragraph, new_text)],
-            )
-        else:
-            resources[rid].paragraphs.append((paragraph, new_text))  # type: ignore
-
-    for values in resources.values():
-        title_text = values.title
-        summary_text = values.summary
-        first_paragraph = None
-        text = ""
-        for paragraph, text in values.paragraphs:
-            if first_paragraph is None:
-                first_paragraph = paragraph
-            text += "EXTRACTED BLOCK: \n " + text + " \n\n "
-            paragraph.text = ""
-
-        if first_paragraph is not None:
-            first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text}"
-
-
 def _clean_paragraph_text(paragraph: FindParagraph) -> str:
     text = paragraph.text.strip()
     # Do not send highlight marks on prompt context
@@ -543,17 +1088,23 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:
     return text


-def
+def get_neighbouring_paragraph_indexes(
+    field_paragraphs: list[ParagraphId],
+    matching_paragraph: ParagraphId,
+    before: int,
+    after: int,
+) -> list[int]:
     """
-    Returns the
+    Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
     """
-
-
-
-
-
-
-
-
-
-    )
+    assert before >= 0
+    assert after >= 0
+    try:
+        matching_index = field_paragraphs.index(matching_paragraph)
+    except ValueError:
+        raise ParagraphIdNotFoundInExtractedMetadata(
+            f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
+        )
+    start_index = max(0, matching_index - before)
+    end_index = min(len(field_paragraphs), matching_index + after + 1)
+    return list(range(start_index, end_index))
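The window arithmetic clamps at both ends of the field, so a match near the start simply yields fewer leading neighbours. The same computation over plain strings (window_indexes and the sample ids are illustrative):

    def window_indexes(items: list[str], match: str, before: int, after: int) -> list[int]:
        matching_index = items.index(match)  # raises ValueError when the match is absent
        start = max(0, matching_index - before)
        end = min(len(items), matching_index + after + 1)
        return list(range(start, end))

    paragraphs = ["p0", "p1", "p2", "p3", "p4"]
    assert window_indexes(paragraphs, "p1", before=2, after=2) == [0, 1, 2, 3]  # clamped at the start
    assert window_indexes(paragraphs, "p3", before=1, after=1) == [2, 3, 4]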