nucliadb 2.46.1.post382-py3-none-any.whl → 6.2.1.post2777-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
```diff
--- a/nucliadb/search/search/chat/prompt.py
+++ b/nucliadb/search/search/chat/prompt.py
@@ -17,32 +17,55 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
+import asyncio
+import copy
+from collections import deque
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Deque, Dict, List, Optional, Sequence, Tuple, Union, cast
 
+import yaml
+from pydantic import BaseModel
+
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
+from nucliadb.common.maindb.utils import get_driver
+from nucliadb.common.models_utils import from_proto
 from nucliadb.ingest.fields.base import Field
 from nucliadb.ingest.fields.conversation import Conversation
+from nucliadb.ingest.fields.file import File
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
-from nucliadb.ingest.orm.resource import KB_REVERSE
-from nucliadb.ingest.orm.resource import Resource as ResourceORM
-from nucliadb.middleware.transaction import get_read_only_transaction
 from nucliadb.search import logger
-from nucliadb.search.search import
-from nucliadb.search.search.chat.images import
+from nucliadb.search.search import cache
+from nucliadb.search.search.chat.images import (
+    get_file_thumbnail_image,
+    get_page_image,
+    get_paragraph_image,
+)
+from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
+from nucliadb.search.search.paragraphs import get_paragraph_text
+from nucliadb_models.metadata import Extra, Origin
 from nucliadb_models.search import (
     SCORE_TYPE,
+    ConversationalStrategy,
+    FieldExtensionStrategy,
     FindParagraph,
+    FullResourceStrategy,
+    HierarchyResourceStrategy,
     ImageRagStrategy,
     ImageRagStrategyName,
-
+    MetadataExtensionStrategy,
+    MetadataExtensionType,
+    NeighbouringParagraphsStrategy,
+    PageImageStrategy,
+    ParagraphImageStrategy,
     PromptContext,
     PromptContextImages,
     PromptContextOrder,
     RagStrategy,
     RagStrategyName,
+    TableImageStrategy,
 )
 from nucliadb_protos import resources_pb2
-from nucliadb_utils.asyncio_utils import
+from nucliadb_utils.asyncio_utils import run_concurrently
 from nucliadb_utils.utilities import get_storage
 
 MAX_RESOURCE_TASKS = 5
@@ -53,12 +76,20 @@ MAX_RESOURCE_FIELD_TASKS = 4
 # The hope here is it will be enough to get the answer to the question.
 CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
 
+TextBlockId = Union[ParagraphId, FieldId]
+
+
+class ParagraphIdNotFoundInExtractedMetadata(Exception):
+    pass
+
 
 class CappedPromptContext:
     """
-    Class to keep track of the size
+    Class to keep track of the size (in number of characters) of the prompt context
+    and raise an exception if it exceeds the configured limit.
 
-    This class will automatically trim data that exceeds the limit when it's being
+    This class will automatically trim data that exceeds the limit when it's being
+    set on the dictionary.
     """
 
     def __init__(self, max_size: Optional[int]):
@@ -68,15 +99,26 @@ class CappedPromptContext:
         self._size = 0
 
     def __setitem__(self, key: str, value: str) -> None:
+        prev_value_len = len(self.output.get(key, ""))
         if self.max_size is None:
-
+            # Unbounded size context
+            to_add = value
         else:
-
-            self._size
-
-
-
-
+            # Make sure we don't exceed the max size
+            size_available = max(self.max_size - self._size + prev_value_len, 0)
+            to_add = value[:size_available]
+        self.output[key] = to_add
+        self._size = self._size - prev_value_len + len(to_add)
+
+    def __getitem__(self, key: str) -> str:
+        return self.output.__getitem__(key)
+
+    def __delitem__(self, key: str) -> None:
+        value = self.output.pop(key, "")
+        self._size -= len(value)
+
+    def text_block_ids(self) -> list[str]:
+        return list(self.output.keys())
 
     @property
     def size(self) -> int:
```
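
The size accounting that the new `__setitem__` introduces can be exercised on its own. Below is a minimal, self-contained sketch (the `MiniCappedContext` class is a stand-in written for this note, mirroring only what the hunk above shows): an over-budget value is trimmed to the characters still available, and overwriting a key first returns that key's previous length to the budget.

```python
from typing import Optional


class MiniCappedContext:
    """Stand-in for CappedPromptContext, following the hunk above."""

    def __init__(self, max_size: Optional[int]):
        self.output: dict[str, str] = {}
        self.max_size = max_size
        self._size = 0

    def __setitem__(self, key: str, value: str) -> None:
        prev_value_len = len(self.output.get(key, ""))
        if self.max_size is None:
            to_add = value  # unbounded context
        else:
            # Overwriting a key frees its previous length before trimming.
            size_available = max(self.max_size - self._size + prev_value_len, 0)
            to_add = value[:size_available]
        self.output[key] = to_add
        self._size = self._size - prev_value_len + len(to_add)


ctx = MiniCappedContext(max_size=10)
ctx["a"] = "12345678"  # 8 of 10 characters used
ctx["b"] = "ABCDEF"    # only 2 left, so the value is trimmed
assert ctx.output["b"] == "AB"
ctx["a"] = "12"        # shrinking "a" releases 6 characters
ctx["b"] = "ABCDEF"    # now the full value fits
assert ctx.output["b"] == "ABCDEF"
```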
```diff
@@ -91,15 +133,15 @@ async def get_next_conversation_messages(
     num_messages: int,
     message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
     msg_to: Optional[str] = None,
-):
+) -> List[resources_pb2.Message]:
     output = []
     cmetadata = await field_obj.get_metadata()
     for current_page in range(page, cmetadata.pages + 1):
         conv = await field_obj.db_get_value(current_page)
         for message in conv.messages[start_idx:]:
-            if message_type is not None and message.type != message_type:
+            if message_type is not None and message.type != message_type:  # pragma: no cover
                 continue
-            if msg_to is not None and msg_to not in message.to:
+            if msg_to is not None and msg_to not in message.to:  # pragma: no cover
                 continue
             output.append(message)
             if len(output) >= num_messages:
```
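
`get_next_conversation_messages` resumes a conversation at a page and message index, skips messages that fail the optional type/recipient filters, and stops as soon as `num_messages` have been collected. A simplified, synchronous sketch of that traversal (pages reduced to lists of strings; the real code is async, filters protobuf messages, and uses 1-indexed page values):

```python
def next_messages(
    pages: list[list[str]], page: int, start_idx: int, num_messages: int
) -> list[str]:
    """Resume at (page, start_idx) and collect until num_messages are gathered."""
    output: list[str] = []
    for current_page in range(page, len(pages)):
        for message in pages[current_page][start_idx:]:
            output.append(message)
            if len(output) >= num_messages:
                return output
        start_idx = 0  # pages after the first are read from their beginning
    return output


pages = [["q1", "a1", "q2"], ["a2", "q3"]]
assert next_messages(pages, page=0, start_idx=2, num_messages=3) == ["q2", "a2", "q3"]
```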
```diff
@@ -122,16 +164,21 @@ async def find_conversation_message(
 
 
 async def get_expanded_conversation_messages(
-    *,
+    *,
+    kb: KnowledgeBoxORM,
+    rid: str,
+    field_id: str,
+    mident: str,
+    max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
 ) -> list[resources_pb2.Message]:
     resource = await kb.get(rid)
-    if resource is None:
+    if resource is None:  # pragma: no cover
         return []
-    field_obj = await resource.get_field(field_id,
+    field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True)  # type: ignore
     found_message, found_page, found_idx = await find_conversation_message(
         field_obj=field_obj, mident=mident
     )
-    if found_message is None:
+    if found_message is None:  # pragma: no cover
         return []
     elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
         # only try to get answer if it was a question
@@ -147,14 +194,14 @@ async def get_expanded_conversation_messages(
         field_obj=field_obj,
         page=found_page,
         start_idx=found_idx + 1,
-        num_messages=
+        num_messages=max_messages,
     )
 
 
 async def default_prompt_context(
     context: CappedPromptContext,
     kbid: str,
-
+    ordered_paragraphs: list[FindParagraph],
 ) -> None:
     """
     - Updates context (which is an ordered dict of text_block_id -> context_text).
```
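
The helpers above and below all key the context by text-block ids of the form `rid/field_type/field_key/...`, with paragraph ids carrying a final `start-end` character range — this is the shape that the `paragraph.id.split("/")` calls in the next hunk rely on. A quick illustration (the uuid and field key here are fabricated):

```python
# Illustration only: a conversation paragraph id as used for context keys.
pid = "6e5f0c1d/c/support-chat/42/0-57"  # rid / field_type / field_key / message ident / start-end
rid, field_type, field_key, ident = pid.split("/")[:4]
assert (field_type, ident) == ("c", "42")
start, end = map(int, pid.split("/")[-1].split("-"))
assert (start, end) == (0, 57)
```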
```diff
@@ -166,128 +213,253 @@ async def default_prompt_context(
     - Using an dict prevents from duplicates pulled in through conversation expansion.
     """
     # Sort retrieved paragraphs by decreasing order (most relevant first)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            context[pid] = text
-
-
-async def get_field_extracted_text(field: Field) -> Optional[tuple[Field, str]]:
-    extracted_text_pb = await field.get_extracted_text(force=True)
-    if extracted_text_pb is None:
-        return None
-    return field, extracted_text_pb.text
-
-
-async def get_resource_field_extracted_text(
-    kb_obj: KnowledgeBoxORM,
-    resource_uuid,
-    field_id: str,
-) -> Optional[tuple[Field, str]]:
-    resource = await kb_obj.get(resource_uuid)
-    if resource is None:
-        return None
-
-    try:
-        field_type, field_key = field_id.strip("/").split("/")
-    except ValueError:
-        logger.error(f"Invalid field id: {field_id}. Skipping getting extracted text.")
-        return None
-    field = await resource.get_field(field_key, KB_REVERSE[field_type], load=False)
-    if field is None:
-        return None
-    result = await get_field_extracted_text(field)
-    if result is None:
-        return None
-    _, extracted_text = result
-    return field, extracted_text
-
-
-async def get_resource_extracted_texts(
-    kbid: str,
-    resource_uuid: str,
-) -> list[tuple[Field, str]]:
-    txn = await get_read_only_transaction()
-    storage = await get_storage()
-    kb = KnowledgeBoxORM(txn, storage, kbid)
-    resource = ResourceORM(
-        txn=txn,
-        storage=storage,
-        kb=kb,
-        uuid=resource_uuid,
-    )
-
-    # Schedule the extraction of the text of each field in the resource
-    runner = ConcurrentRunner(max_tasks=MAX_RESOURCE_FIELD_TASKS)
-    for field_type, field_key in await resource.get_fields(force=True):
-        field = await resource.get_field(field_key, field_type, load=False)
-        runner.schedule(get_field_extracted_text(field))
-
-    # Wait for the results
-    results = await runner.wait()
-    return [result for result in results if result is not None]
+    async with get_driver().transaction(read_only=True) as txn:
+        storage = await get_storage()
+        kb = KnowledgeBoxORM(txn, storage, kbid)
+        for paragraph in ordered_paragraphs:
+            context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+            # If the paragraph is a conversation and it matches semantically, we assume we
+            # have matched with the question, therefore try to include the answer to the
+            # context by pulling the next few messages of the conversation field
+            rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+            if field_type == "c" and paragraph.score_type in (
+                SCORE_TYPE.VECTOR,
+                SCORE_TYPE.BOTH,
+            ):
+                expanded_msgs = await get_expanded_conversation_messages(
+                    kb=kb, rid=rid, field_id=field_id, mident=mident
+                )
+                for msg in expanded_msgs:
+                    text = msg.content.text.strip()
+                    pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
+                    context[pid] = text
 
 
 async def full_resource_prompt_context(
     context: CappedPromptContext,
     kbid: str,
-
-
+    ordered_paragraphs: list[FindParagraph],
+    resource: Optional[str],
+    strategy: FullResourceStrategy,
 ) -> None:
     """
     Algorithm steps:
     - Collect the list of resources in the results (in order of relevance).
     - For each resource, collect the extracted text from all its fields and craft the context.
-
-
-
-
-
-
-
-
-
+    Arguments:
+        context: The context to be updated.
+        kbid: The knowledge box id.
+        ordered_paragraphs: The results of the retrieval (find) operation.
+        resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
+        strategy: strategy instance containing, for example, the number of full resources to include in the context.
+    """  # noqa: E501
+    if resource is not None:
+        # The user has specified a resource to be included in the context.
+        ordered_resources = [resource]
+    else:
+        # Collect the list of resources in the results (in order of relevance).
+        ordered_resources = []
+        for paragraph in ordered_paragraphs:
+            resource_uuid = parse_text_block_id(paragraph.id).rid
+            if resource_uuid not in ordered_resources:
+                skip = False
+                if strategy.apply_to is not None:
+                    # decide whether the resource should be extended or not
+                    for label in strategy.apply_to.exclude:
+                        skip = skip or (label in (paragraph.labels or []))
+
+                if not skip:
+                    ordered_resources.append(resource_uuid)
 
     # For each resource, collect the extracted text from all its fields.
-
+    resources_extracted_texts = await run_concurrently(
         [
-
-            for resource_uuid in ordered_resources[:
+            hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
+            for resource_uuid in ordered_resources[: strategy.count]
         ],
         max_concurrent=MAX_RESOURCE_TASKS,
     )
-
-    for
-        if
+    added_fields = set()
+    for resource_extracted_texts in resources_extracted_texts:
+        if resource_extracted_texts is None:
             continue
-        for field, extracted_text in
+        for field, extracted_text in resource_extracted_texts:
+            # First off, remove the text block ids from paragraphs that belong to
+            # the same field, as otherwise the context will be duplicated.
+            for tb_id in context.text_block_ids():
+                if tb_id.startswith(field.full()):
+                    del context[tb_id]
             # Add the extracted text of each field to the context.
-            context[field.
+            context[field.full()] = extracted_text
+            added_fields.add(field.full())
+
+    if strategy.include_remaining_text_blocks:
+        for paragraph in ordered_paragraphs:
+            pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
+            if pid.field_id.full() not in added_fields:
+                context[paragraph.id] = _clean_paragraph_text(paragraph)
 
 
-async def
+async def extend_prompt_context_with_metadata(
     context: CappedPromptContext,
     kbid: str,
-
-
+    strategy: MetadataExtensionStrategy,
+) -> None:
+    text_block_ids: list[TextBlockId] = []
+    for text_block_id in context.text_block_ids():
+        try:
+            text_block_ids.append(parse_text_block_id(text_block_id))
+        except ValueError:  # pragma: no cover
+            # Some text block ids are not paragraphs nor fields, so they are skipped
+            # (e.g. USER_CONTEXT_0, when the user provides extra context)
+            continue
+    if len(text_block_ids) == 0:  # pragma: no cover
+        return
+
+    if MetadataExtensionType.ORIGIN in strategy.types:
+        await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
+        await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.NERS in strategy.types:
+        await extend_prompt_context_with_ner(context, kbid, text_block_ids)
+
+    if MetadataExtensionType.EXTRA_METADATA in strategy.types:
+        await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
+
+
+def parse_text_block_id(text_block_id: str) -> TextBlockId:
+    try:
+        # Typically, the text block id is a paragraph id
+        return ParagraphId.from_string(text_block_id)
+    except ValueError:
+        # When we're doing `full_resource` or `hierarchy` strategies,the text block id
+        # is a field id
+        return FieldId.from_string(text_block_id)
+
+
+async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
+        origin = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_origin = await resource.get_origin()
+            if pb_origin is not None:
+                origin = from_proto.origin(pb_origin)
+        return rid, origin
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
+    rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
+    for tb_id in text_block_ids:
+        origin = rid_to_origin.get(tb_id.rid)
+        if origin is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
+
+
+async def extend_prompt_context_with_classification_labels(
+    context, kbid, text_block_ids: list[TextBlockId]
+):
+    async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        labels = set()
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            pb_basic = await resource.get_basic()
+            if pb_basic is not None:
+                # Add the classification labels of the resource
+                for classif in pb_basic.usermetadata.classifications:
+                    labels.add((classif.labelset, classif.label))
+                # Add the classifications labels of the field
+                for fc in pb_basic.computedmetadata.field_classifications:
+                    if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
+                        for classif in fc.classifications:
+                            if classif.cancelled_by_user:  # pragma: no cover
+                                continue
+                            labels.add((classif.labelset, classif.label))
+        return _id, list(labels)
+
+    classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
+    for tb_id in text_block_ids:
+        labels = tb_id_to_labels.get(tb_id)
+        if labels is not None and tb_id.full() in context.output:
+            labels_text = "DOCUMENT CLASSIFICATION LABELS:"
+            for labelset, label in labels:
+                labels_text += f"\n - {label} ({labelset})"
+            context[tb_id.full()] += "\n\n" + labels_text
+
+
+async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
+        fid = _id if isinstance(_id, FieldId) else _id.field_id
+        ners: dict[str, set[str]] = {}
+        resource = await cache.get_resource(kbid, fid.rid)
+        if resource is not None:
+            field = await resource.get_field(fid.key, fid.pb_type, load=False)
+            fcm = await field.get_field_metadata()
+            if fcm is not None:
+                # Data Augmentation + Processor entities
+                for (
+                    data_aumgentation_task_id,
+                    entities_wrapper,
+                ) in fcm.metadata.entities.items():
+                    for entity in entities_wrapper.entities:
+                        ners.setdefault(entity.label, set()).add(entity.text)
+                # Legacy processor entities
+                # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+                for token, family in fcm.metadata.ner.items():
+                    ners.setdefault(family, set()).add(token)
+        return _id, ners
+
+    nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
+    tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
+    for tb_id in text_block_ids:
+        ners = tb_id_to_ners.get(tb_id)
+        if ners is not None and tb_id.full() in context.output:
+            ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
+            for family, tokens in ners.items():
+                ners_text += f"\n - {family}:"
+                for token in sorted(list(tokens)):
+                    ners_text += f"\n - {token}"
+            context[tb_id.full()] += "\n\n" + ners_text
+
+
+async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids: list[TextBlockId]):
+    async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
+        extra = None
+        resource = await cache.get_resource(kbid, rid)
+        if resource is not None:
+            pb_extra = await resource.get_extra()
+            if pb_extra is not None:
+                extra = from_proto.extra(pb_extra)
+        return rid, extra
+
+    rids = {tb_id.rid for tb_id in text_block_ids}
+    extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
+    rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
+    for tb_id in text_block_ids:
+        extra = rid_to_extra.get(tb_id.rid)
+        if extra is not None and tb_id.full() in context.output:
+            context[tb_id.full()] += f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
+
+
+def to_yaml(obj: BaseModel) -> str:
+    return yaml.dump(
+        obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
+        default_flow_style=False,
+        indent=2,
+        sort_keys=True,
+    )
+
+
+async def field_extension_prompt_context(
+    context: CappedPromptContext,
+    kbid: str,
+    ordered_paragraphs: list[FindParagraph],
+    strategy: FieldExtensionStrategy,
 ) -> None:
     """
     Algorithm steps:
```
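
The `to_yaml` helper added above keeps the metadata blocks appended to the context short by dropping unset, default, and `None` fields before dumping (pydantic v2's `model_dump`). A self-contained demonstration with an invented model:

```python
from typing import Optional

import yaml
from pydantic import BaseModel


class Doc(BaseModel):  # invented model, for illustration only
    source: Optional[str] = None
    pages: int = 0
    tags: list[str] = []


def to_yaml(obj: BaseModel) -> str:
    return yaml.dump(
        obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
        default_flow_style=False,
        indent=2,
        sort_keys=True,
    )


print(to_yaml(Doc(source="crawler", tags=["legal", "2024"])))
# source: crawler
# tags:
# - legal
# - '2024'
```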
@@ -296,35 +468,402 @@ async def composed_prompt_context(
|
|
296
468
|
- Add the extracted text of each field to the beginning of the context.
|
297
469
|
- Add the extracted text of each paragraph to the end of the context.
|
298
470
|
"""
|
299
|
-
# Collect the list of resources in the results (in order of relevance).
|
300
|
-
ordered_paras = get_ordered_paragraphs(results)
|
301
471
|
ordered_resources = []
|
302
|
-
for paragraph in
|
303
|
-
resource_uuid = paragraph.id.
|
472
|
+
for paragraph in ordered_paragraphs:
|
473
|
+
resource_uuid = ParagraphId.from_string(paragraph.id).rid
|
304
474
|
if resource_uuid not in ordered_resources:
|
305
475
|
ordered_resources.append(resource_uuid)
|
306
476
|
|
307
477
|
# Fetch the extracted texts of the specified fields for each resource
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
478
|
+
extend_fields = strategy.fields
|
479
|
+
extend_field_ids = []
|
480
|
+
for resource_uuid in ordered_resources:
|
481
|
+
for field_id in extend_fields:
|
482
|
+
try:
|
483
|
+
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
|
484
|
+
extend_field_ids.append(fid)
|
485
|
+
except ValueError: # pragma: no cover
|
486
|
+
# Invalid field id, skiping
|
487
|
+
continue
|
488
|
+
|
489
|
+
tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
|
316
490
|
field_extracted_texts = await run_concurrently(tasks)
|
317
491
|
|
318
492
|
for result in field_extracted_texts:
|
319
|
-
if result is None:
|
493
|
+
if result is None: # pragma: no cover
|
320
494
|
continue
|
321
|
-
# Add the extracted text of each field to the beginning of the context.
|
322
495
|
field, extracted_text = result
|
323
|
-
|
496
|
+
# First off, remove the text block ids from paragraphs that belong to
|
497
|
+
# the same field, as otherwise the context will be duplicated.
|
498
|
+
for tb_id in context.text_block_ids():
|
499
|
+
if tb_id.startswith(field.full()):
|
500
|
+
del context[tb_id]
|
501
|
+
# Add the extracted text of each field to the beginning of the context.
|
502
|
+
context[field.full()] = extracted_text
|
324
503
|
|
325
504
|
# Add the extracted text of each paragraph to the end of the context.
|
326
|
-
for paragraph in
|
505
|
+
for paragraph in ordered_paragraphs:
|
506
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
507
|
+
|
508
|
+
|
509
|
+
async def get_paragraph_text_with_neighbours(
|
510
|
+
kbid: str,
|
511
|
+
pid: ParagraphId,
|
512
|
+
field_paragraphs: list[ParagraphId],
|
513
|
+
before: int = 0,
|
514
|
+
after: int = 0,
|
515
|
+
) -> tuple[ParagraphId, str]:
|
516
|
+
"""
|
517
|
+
This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
|
518
|
+
Parameters:
|
519
|
+
kbid: The knowledge box id.
|
520
|
+
pid: The matching paragraph id.
|
521
|
+
field_paragraphs: The list of paragraph ids of the field.
|
522
|
+
before: The number of paragraphs to include before the matching paragraph.
|
523
|
+
after: The number of paragraphs to include after the matching paragraph.
|
524
|
+
"""
|
525
|
+
|
526
|
+
async def _get_paragraph_text(
|
527
|
+
kbid: str,
|
528
|
+
pid: ParagraphId,
|
529
|
+
) -> tuple[ParagraphId, str]:
|
530
|
+
return pid, await get_paragraph_text(
|
531
|
+
kbid=kbid,
|
532
|
+
paragraph_id=pid,
|
533
|
+
log_on_missing_field=True,
|
534
|
+
)
|
535
|
+
|
536
|
+
ops = []
|
537
|
+
try:
|
538
|
+
for paragraph_index in get_neighbouring_paragraph_indexes(
|
539
|
+
field_paragraphs=field_paragraphs,
|
540
|
+
matching_paragraph=pid,
|
541
|
+
before=before,
|
542
|
+
after=after,
|
543
|
+
):
|
544
|
+
neighbour_pid = field_paragraphs[paragraph_index]
|
545
|
+
ops.append(
|
546
|
+
asyncio.create_task(
|
547
|
+
_get_paragraph_text(
|
548
|
+
kbid=kbid,
|
549
|
+
pid=neighbour_pid,
|
550
|
+
)
|
551
|
+
)
|
552
|
+
)
|
553
|
+
except ParagraphIdNotFoundInExtractedMetadata:
|
554
|
+
logger.warning(
|
555
|
+
"Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
|
556
|
+
extra={
|
557
|
+
"kbid": kbid,
|
558
|
+
"matching_paragraph": pid.full(),
|
559
|
+
"field_paragraphs": [p.full() for p in field_paragraphs],
|
560
|
+
},
|
561
|
+
)
|
562
|
+
# If we could not find the matching paragraph in the extracted metadata, we can't retrieve
|
563
|
+
# the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
|
564
|
+
ops.append(
|
565
|
+
asyncio.create_task(
|
566
|
+
_get_paragraph_text(
|
567
|
+
kbid=kbid,
|
568
|
+
pid=pid,
|
569
|
+
)
|
570
|
+
)
|
571
|
+
)
|
572
|
+
|
573
|
+
results = []
|
574
|
+
if len(ops) > 0:
|
575
|
+
results = await asyncio.gather(*ops)
|
576
|
+
|
577
|
+
# Sort the results by the paragraph start
|
578
|
+
results.sort(key=lambda x: x[0].paragraph_start)
|
579
|
+
paragraph_texts = []
|
580
|
+
for _, text in results:
|
581
|
+
if text != "":
|
582
|
+
paragraph_texts.append(text)
|
583
|
+
return pid, "\n\n".join(paragraph_texts)
|
584
|
+
|
585
|
+
|
586
|
+
async def get_field_paragraphs_list(
|
587
|
+
kbid: str,
|
588
|
+
field: FieldId,
|
589
|
+
paragraphs: list[ParagraphId],
|
590
|
+
) -> None:
|
591
|
+
"""
|
592
|
+
Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
|
593
|
+
"""
|
594
|
+
resource = await cache.get_resource(kbid, field.rid)
|
595
|
+
if resource is None: # pragma: no cover
|
596
|
+
return
|
597
|
+
field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
|
598
|
+
field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
|
599
|
+
force=True
|
600
|
+
)
|
601
|
+
if field_metadata is None: # pragma: no cover
|
602
|
+
return
|
603
|
+
for paragraph in field_metadata.metadata.paragraphs:
|
604
|
+
paragraphs.append(
|
605
|
+
ParagraphId(
|
606
|
+
field_id=field,
|
607
|
+
paragraph_start=paragraph.start,
|
608
|
+
paragraph_end=paragraph.end,
|
609
|
+
)
|
610
|
+
)
|
611
|
+
|
612
|
+
|
613
|
+
async def neighbouring_paragraphs_prompt_context(
|
614
|
+
context: CappedPromptContext,
|
615
|
+
kbid: str,
|
616
|
+
ordered_text_blocks: list[FindParagraph],
|
617
|
+
strategy: NeighbouringParagraphsStrategy,
|
618
|
+
) -> None:
|
619
|
+
"""
|
620
|
+
This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
|
621
|
+
paragraphs in the ordered_paragraphs list. The number of paragraphs to include before and after each paragraph
|
622
|
+
"""
|
623
|
+
# First, get the sorted list of paragraphs for each matching field
|
624
|
+
# so we can know the indexes of the neighbouring paragraphs
|
625
|
+
unique_fields = {
|
626
|
+
ParagraphId.from_string(text_block.id).field_id for text_block in ordered_text_blocks
|
627
|
+
}
|
628
|
+
paragraphs_by_field: dict[FieldId, list[ParagraphId]] = {}
|
629
|
+
field_ops = []
|
630
|
+
for field_id in unique_fields:
|
631
|
+
plist = paragraphs_by_field.setdefault(field_id, [])
|
632
|
+
field_ops.append(
|
633
|
+
asyncio.create_task(get_field_paragraphs_list(kbid=kbid, field=field_id, paragraphs=plist))
|
634
|
+
)
|
635
|
+
if field_ops:
|
636
|
+
await asyncio.gather(*field_ops)
|
637
|
+
|
638
|
+
# Now, get the paragraph texts with the neighbouring paragraphs
|
639
|
+
paragraph_ops = []
|
640
|
+
for text_block in ordered_text_blocks:
|
641
|
+
pid = ParagraphId.from_string(text_block.id)
|
642
|
+
paragraph_ops.append(
|
643
|
+
asyncio.create_task(
|
644
|
+
get_paragraph_text_with_neighbours(
|
645
|
+
kbid=kbid,
|
646
|
+
pid=pid,
|
647
|
+
before=strategy.before,
|
648
|
+
after=strategy.after,
|
649
|
+
field_paragraphs=paragraphs_by_field.get(pid.field_id, []),
|
650
|
+
)
|
651
|
+
)
|
652
|
+
)
|
653
|
+
if not paragraph_ops: # pragma: no cover
|
654
|
+
return
|
655
|
+
|
656
|
+
results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
|
657
|
+
# Add the paragraph texts to the context
|
658
|
+
for pid, text in results:
|
659
|
+
if text != "":
|
660
|
+
context[pid.full()] = text
|
661
|
+
|
662
|
+
|
663
|
+
async def conversation_prompt_context(
|
664
|
+
context: CappedPromptContext,
|
665
|
+
kbid: str,
|
666
|
+
ordered_paragraphs: list[FindParagraph],
|
667
|
+
conversational_strategy: ConversationalStrategy,
|
668
|
+
visual_llm: bool,
|
669
|
+
):
|
670
|
+
analyzed_fields: List[str] = []
|
671
|
+
async with get_driver().transaction(read_only=True) as txn:
|
672
|
+
storage = await get_storage()
|
673
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
674
|
+
for paragraph in ordered_paragraphs:
|
675
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
676
|
+
|
677
|
+
# If the paragraph is a conversation and it matches semantically, we assume we
|
678
|
+
# have matched with the question, therefore try to include the answer to the
|
679
|
+
# context by pulling the next few messages of the conversation field
|
680
|
+
rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
|
681
|
+
if field_type == "c" and paragraph.score_type in (
|
682
|
+
SCORE_TYPE.VECTOR,
|
683
|
+
SCORE_TYPE.BOTH,
|
684
|
+
SCORE_TYPE.BM25,
|
685
|
+
):
+                field_unique_id = "-".join([rid, field_type, field_id])
+                if field_unique_id in analyzed_fields:
+                    continue
+                resource = await kb.get(rid)
+                if resource is None:  # pragma: no cover
+                    continue
+
+                field_obj: Conversation = await resource.get_field(
+                    field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
+                )  # type: ignore
+                cmetadata = await field_obj.get_metadata()
+
+                attachments: List[resources_pb2.FieldRef] = []
+                if conversational_strategy.full:
+                    extracted_text = await field_obj.get_extracted_text()
+                    for current_page in range(1, cmetadata.pages + 1):
+                        conv = await field_obj.db_get_value(current_page)
+
+                        for message in conv.messages:
+                            ident = message.ident
+                            if extracted_text is not None:
+                                text = extracted_text.split_text.get(ident, message.content.text.strip())
+                            else:
+                                text = message.content.text.strip()
+                            pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
+                            context[pid] = text
+                            attachments.extend(message.content.attachments_fields)
+                else:
+                    # Add first message
+                    extracted_text = await field_obj.get_extracted_text()
+                    first_page = await field_obj.db_get_value()
+                    if len(first_page.messages) > 0:
+                        message = first_page.messages[0]
+                        ident = message.ident
+                        if extracted_text is not None:
+                            text = extracted_text.split_text.get(ident, message.content.text.strip())
+                        else:
+                            text = message.content.text.strip()
+                        pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
+                        context[pid] = text
+                        attachments.extend(message.content.attachments_fields)
+
+                    messages: Deque[resources_pb2.Message] = deque(
+                        maxlen=conversational_strategy.max_messages
+                    )
+
+                    pending = -1
+                    for page in range(1, cmetadata.pages + 1):
+                        # Collect the window of messages requested by the user around the matching paragraph
+                        conv = await field_obj.db_get_value(page)
+                        for message in conv.messages:
+                            messages.append(message)
+                            if pending > 0:
+                                pending -= 1
+                            if message.ident == mident:
+                                pending = (conversational_strategy.max_messages - 1) // 2
+                            if pending == 0:
+                                break
+                        if pending == 0:
+                            break
+
+                    for message in messages:
+                        text = message.content.text.strip()
+                        pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
+                        context[pid] = text
+                        attachments.extend(message.content.attachments_fields)
+
+                if conversational_strategy.attachments_text:
+                    # Add the text of the attachments to the context
+                    for attachment in attachments:
+                        field: File = await resource.get_field(
+                            attachment.field_id, attachment.field_type, load=True
+                        )  # type: ignore
+                        extracted_text = await field.get_extracted_text()
+                        if extracted_text is not None:
+                            pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
+                            context[pid] = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
+
+                if conversational_strategy.attachments_images and visual_llm:
+                    # Add the attachment images to the context for the visual LLM
+                    for attachment in attachments:
+                        file_field: File = await resource.get_field(
+                            attachment.field_id, attachment.field_type, load=True
+                        )  # type: ignore
+                        image = await get_file_thumbnail_image(file_field)
+                        if image is not None:
+                            pid = f"{rid}/f/{attachment.field_id}/0-0"
+                            context.images[pid] = image
+
+                analyzed_fields.append(field_unique_id)
+
+
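In the hunk above, the `pending` counter builds a message window centered on the matched conversation message: the deque already holds up to `max_messages` earlier messages, and once the match is seen, roughly half the window is still collected after it before the loop breaks. A minimal, self-contained sketch of that selection logic (the function and argument names here are illustrative, not nucliadb API):

    from collections import deque

    def window_around_match(idents: list, mident: str, max_messages: int) -> list:
        # Keep at most max_messages items; once the match is seen, collect
        # (max_messages - 1) // 2 more so the match lands near the middle.
        messages = deque(maxlen=max_messages)
        pending = -1
        for ident in idents:
            messages.append(ident)
            if pending > 0:
                pending -= 1
            if ident == mident:
                pending = (max_messages - 1) // 2
            if pending == 0:
                break
        return list(messages)

    # window_around_match(["m1", "m2", "m3", "m4", "m5", "m6", "m7"], "m4", 5)
    # returns ["m2", "m3", "m4", "m5", "m6"]: the match sits in the middle.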
+async def hierarchy_prompt_context(
+    context: CappedPromptContext,
+    kbid: str,
+    ordered_paragraphs: list[FindParagraph],
+    strategy: HierarchyResourceStrategy,
+) -> None:
+    """
+    Get the paragraph texts (expanded with extra characters when extra_characters > 0) and
+    craft a context in which all paragraphs of the same resource are grouped together. Each
+    group is prefixed with the resource title and summary so that the LLM gets a better
+    understanding of the context.
+    """
+    paragraphs_extra_characters = max(strategy.count, 0)
+    # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
+    # in the response to the user
+    ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
+    resources: Dict[str, ExtraCharsParagraph] = {}
+
+    # Iterate paragraphs to get extended text
+    for paragraph in ordered_paragraphs_copy:
+        paragraph_id = ParagraphId.from_string(paragraph.id)
+        extended_paragraph_text = paragraph.text
+        if paragraphs_extra_characters > 0:
+            extended_paragraph_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=paragraph_id,
+                log_on_missing_field=True,
+            )
+        rid = paragraph_id.rid
+        if rid not in resources:
+            # Get the title and the summary of the resource
+            title_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=ParagraphId(
+                    field_id=FieldId(
+                        rid=rid,
+                        type="a",
+                        key="title",
+                    ),
+                    paragraph_start=0,
+                    paragraph_end=500,
+                ),
+                log_on_missing_field=False,
+            )
+            summary_text = await get_paragraph_text(
+                kbid=kbid,
+                paragraph_id=ParagraphId(
+                    field_id=FieldId(
+                        rid=rid,
+                        type="a",
+                        key="summary",
+                    ),
+                    paragraph_start=0,
+                    paragraph_end=1000,
+                ),
+                log_on_missing_field=False,
+            )
+            resources[rid] = ExtraCharsParagraph(
+                title=title_text,
+                summary=summary_text,
+                paragraphs=[(paragraph, extended_paragraph_text)],
+            )
+        else:
+            resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
+
+    # Modify the first paragraph of each resource to include the title and summary of the resource,
+    # as well as the extended paragraph text of all the paragraphs in the resource.
+    for values in resources.values():
+        title_text = values.title
+        summary_text = values.summary
+        first_paragraph = None
+        text_with_hierarchy = ""
+        for paragraph, extended_paragraph_text in values.paragraphs:
+            if first_paragraph is None:
+                first_paragraph = paragraph
+            text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
+            # All paragraphs of the resource are cleared except the first one, which will be the
+            # one containing the whole hierarchy information
+            paragraph.text = ""
+
+        if first_paragraph is not None:
+            # The first paragraph is the only one holding the hierarchy information
+            first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
+
+    # Now that the paragraphs have been modified, we can add them to the context
+    for paragraph in ordered_paragraphs_copy:
+        if paragraph.text == "":
+            # Skip paragraphs that were cleared in the hierarchy expansion
+            continue
         context[paragraph.id] = _clean_paragraph_text(paragraph)
+    return


 class PromptContextBuilder:
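hierarchy_prompt_context, added above, collapses each resource's matched paragraphs into a single context entry headed by the resource title and summary. A minimal sketch of the text it assembles (a hypothetical helper that mirrors the format strings in the hunk, not a nucliadb function):

    def render_resource_block(title: str, summary: str, paragraph_texts: list) -> str:
        # Same shape as the text stored on the first paragraph of each resource:
        # the document title, its summary, then every extracted block.
        blocks = "".join(
            "\n EXTRACTED BLOCK: \n " + text + " \n\n " for text in paragraph_texts
        )
        return f"DOCUMENT: {title} \n SUMMARY: {summary} \n RESOURCE CONTENT: {blocks}"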
@@ -335,19 +874,21 @@ class PromptContextBuilder:
     def __init__(
         self,
         kbid: str,
-
+        ordered_paragraphs: list[FindParagraph],
+        resource: Optional[str] = None,
         user_context: Optional[list[str]] = None,
         strategies: Optional[Sequence[RagStrategy]] = None,
         image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
-
+        max_context_characters: Optional[int] = None,
         visual_llm: bool = False,
     ):
         self.kbid = kbid
-        self.
+        self.ordered_paragraphs = ordered_paragraphs
+        self.resource = resource
         self.user_context = user_context
         self.strategies = strategies
         self.image_strategies = image_strategies
-        self.
+        self.max_context_characters = max_context_characters
         self.visual_llm = visual_llm

     def prepend_user_context(self, context: CappedPromptContext):
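Based only on the signature added in this hunk, a hypothetical call site would look as follows (all values are placeholders):

    builder = PromptContextBuilder(
        kbid="my-kb",
        ordered_paragraphs=found_paragraphs,  # list[FindParagraph] from retrieval
        resource=None,                        # optionally scope strategies to one resource id
        strategies=None,                      # None falls back to the default context
        max_context_characters=None,          # None presumably leaves the context uncapped
        visual_llm=False,
    )
    context, context_order, context_images = await builder.build()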
@@ -359,95 +900,178 @@ class PromptContextBuilder:
     async def build(
         self,
     ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages]:
-        ccontext = CappedPromptContext(max_size=self.
+        ccontext = CappedPromptContext(max_size=self.max_context_characters)
         self.prepend_user_context(ccontext)
         await self._build_context(ccontext)
-
         if self.visual_llm:
             await self._build_context_images(ccontext)

         context = ccontext.output
         context_images = ccontext.images
-        context_order = {
-            text_block_id: order for order, text_block_id in enumerate(context.keys())
-        }
+        context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
         return context, context_order, context_images

     async def _build_context_images(self, context: CappedPromptContext) -> None:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-            )
+        if self.image_strategies is None or len(self.image_strategies) == 0:
+            # Nothing to do
+            return
+        page_image_strategy: Optional[PageImageStrategy] = None
+        max_page_images = 5
+        table_image_strategy: Optional[TableImageStrategy] = None
+        paragraph_image_strategy: Optional[ParagraphImageStrategy] = None
+        for strategy in self.image_strategies:
+            if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
+                if page_image_strategy is None:
+                    page_image_strategy = cast(PageImageStrategy, strategy)
+                    if page_image_strategy.count is not None:
+                        max_page_images = page_image_strategy.count
+            elif strategy.name == ImageRagStrategyName.TABLES:
+                if table_image_strategy is None:
+                    table_image_strategy = cast(TableImageStrategy, strategy)
+            elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
+                if paragraph_image_strategy is None:
+                    paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
+            else:  # pragma: no cover
+                logger.warning(
+                    "Unknown image strategy",
+                    extra={"strategy": strategy.name, "kbid": self.kbid},
+                )
+        page_images_added = 0
+        for paragraph in self.ordered_paragraphs:
+            pid = ParagraphId.from_string(paragraph.id)
+            paragraph_page_number = get_paragraph_page_number(paragraph)
             if (
-
-                and
-                and
-                and paragraph.reference != ""
+                page_image_strategy is not None
+                and page_images_added < max_page_images
+                and paragraph_page_number is not None
             ):
-
-
-
-
+                # page_image_id: rid/f/myfield/0
+                page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
+                if page_image_id not in context.images:
+                    image = await get_page_image(self.kbid, pid, paragraph_page_number)
+                    if image is not None:
+                        context.images[page_image_id] = image
+                        page_images_added += 1
+                    else:
+                        logger.warning(
+                            "Could not retrieve image for paragraph from storage",
+                            extra={
+                                "kbid": self.kbid,
+                                "paragraph": pid.full(),
+                                "page_number": paragraph_page_number,
+                            },
+                        )
+
+            add_table = table_image_strategy is not None and paragraph.is_a_table
+            add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
+            if (add_table or add_paragraph) and (
+                paragraph.reference is not None and paragraph.reference != ""
+            ):
+                pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
+                if pimage is not None:
+                    context.images[paragraph.id] = pimage
+                else:
+                    logger.warning(
+                        "Could not retrieve image for paragraph from storage",
+                        extra={
+                            "kbid": self.kbid,
+                            "paragraph": pid.full(),
+                            "reference": paragraph.reference,
+                        },
+                    )
416
982
|
|
417
983
|
async def _build_context(self, context: CappedPromptContext) -> None:
|
418
984
|
if self.strategies is None or len(self.strategies) == 0:
|
419
|
-
|
985
|
+
# When no strategy is specified, use the default one
|
986
|
+
await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
|
420
987
|
return
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
988
|
+
else:
|
989
|
+
# Add the paragraphs to the context and then apply the strategies
|
990
|
+
for paragraph in self.ordered_paragraphs:
|
991
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
992
|
+
|
993
|
+
full_resource: Optional[FullResourceStrategy] = None
|
994
|
+
hierarchy: Optional[HierarchyResourceStrategy] = None
|
995
|
+
neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
|
996
|
+
field_extension: Optional[FieldExtensionStrategy] = None
|
997
|
+
metadata_extension: Optional[MetadataExtensionStrategy] = None
|
998
|
+
conversational_strategy: Optional[ConversationalStrategy] = None
|
425
999
|
for strategy in self.strategies:
|
426
1000
|
if strategy.name == RagStrategyName.FIELD_EXTENSION:
|
427
|
-
|
1001
|
+
field_extension = cast(FieldExtensionStrategy, strategy)
|
1002
|
+
elif strategy.name == RagStrategyName.CONVERSATION:
|
1003
|
+
conversational_strategy = cast(ConversationalStrategy, strategy)
|
428
1004
|
elif strategy.name == RagStrategyName.FULL_RESOURCE:
|
429
|
-
|
1005
|
+
full_resource = cast(FullResourceStrategy, strategy)
|
1006
|
+
if self.resource: # pragma: no cover
|
1007
|
+
# When the retrieval is scoped to a specific resource
|
1008
|
+
# the full resource strategy only includes that resource
|
1009
|
+
full_resource.count = 1
|
430
1010
|
elif strategy.name == RagStrategyName.HIERARCHY:
|
431
|
-
|
1011
|
+
hierarchy = cast(HierarchyResourceStrategy, strategy)
|
1012
|
+
elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
|
1013
|
+
neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
|
1014
|
+
elif strategy.name == RagStrategyName.METADATA_EXTENSION:
|
1015
|
+
metadata_extension = cast(MetadataExtensionStrategy, strategy)
|
1016
|
+
elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover
|
1017
|
+
# Prequeries are not handled here
|
1018
|
+
logger.warning(
|
1019
|
+
"Unknown rag strategy",
|
1020
|
+
extra={"strategy": strategy.name, "kbid": self.kbid},
|
1021
|
+
)
|
432
1022
|
|
433
|
-
if
|
1023
|
+
if full_resource:
|
1024
|
+
# When full resoure is enabled, only metadata extension is allowed.
|
434
1025
|
await full_resource_prompt_context(
|
435
|
-
context,
|
1026
|
+
context,
|
1027
|
+
self.kbid,
|
1028
|
+
self.ordered_paragraphs,
|
1029
|
+
self.resource,
|
1030
|
+
full_resource,
|
436
1031
|
)
|
1032
|
+
if metadata_extension:
|
1033
|
+
await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
|
437
1034
|
return
|
438
1035
|
|
439
|
-
if
|
440
|
-
await
|
441
|
-
|
442
|
-
|
1036
|
+
if hierarchy:
|
1037
|
+
await hierarchy_prompt_context(
|
1038
|
+
context,
|
1039
|
+
self.kbid,
|
1040
|
+
self.ordered_paragraphs,
|
1041
|
+
hierarchy,
|
1042
|
+
)
|
1043
|
+
if neighbouring_paragraphs:
|
1044
|
+
await neighbouring_paragraphs_prompt_context(
|
1045
|
+
context,
|
1046
|
+
self.kbid,
|
1047
|
+
self.ordered_paragraphs,
|
1048
|
+
neighbouring_paragraphs,
|
1049
|
+
)
|
1050
|
+
if field_extension:
|
1051
|
+
await field_extension_prompt_context(
|
1052
|
+
context,
|
1053
|
+
self.kbid,
|
1054
|
+
self.ordered_paragraphs,
|
1055
|
+
field_extension,
|
1056
|
+
)
|
1057
|
+
if conversational_strategy:
|
1058
|
+
await conversation_prompt_context(
|
1059
|
+
context,
|
1060
|
+
self.kbid,
|
1061
|
+
self.ordered_paragraphs,
|
1062
|
+
conversational_strategy,
|
1063
|
+
self.visual_llm,
|
1064
|
+
)
|
1065
|
+
if metadata_extension:
|
1066
|
+
await extend_prompt_context_with_metadata(context, self.kbid, metadata_extension)
|
443
1067
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
1068
|
+
|
1069
|
+
def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
|
1070
|
+
if not paragraph.page_with_visual:
|
1071
|
+
return None
|
1072
|
+
if paragraph.position is None:
|
1073
|
+
return None
|
1074
|
+
return paragraph.position.page_number
|
451
1075
|
|
452
1076
|
|
453
1077
|
@dataclass
|
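_build_context in the hunk above gives FULL_RESOURCE precedence: when present, it short-circuits every other strategy except METADATA_EXTENSION; otherwise the remaining strategies are applied independently, in the fixed order hierarchy, neighbouring paragraphs, field extension, conversation, metadata extension. A compact restatement of that dispatch using plain strings instead of the RagStrategyName enum (a sketch, not nucliadb code):

    def plan_strategies(requested: set) -> list:
        # Mirrors the control flow of _build_context above.
        if "full_resource" in requested:
            plan = ["full_resource"]
            if "metadata_extension" in requested:
                plan.append("metadata_extension")
            return plan
        order = ["hierarchy", "neighbouring_paragraphs", "field_extension",
                 "conversation", "metadata_extension"]
        return [name for name in order if name in requested]

    # plan_strategies({"hierarchy", "full_resource"}) returns ["full_resource"]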
@@ -457,67 +1081,6 @@ class ExtraCharsParagraph:

     paragraphs: List[Tuple[FindParagraph, str]]


-async def get_extra_chars(
-    kbid: str, find_results: KnowledgeboxFindResults, distance: int
-):
-    etcache = paragraphs.ExtractedTextCache()
-    resources: Dict[str, ExtraCharsParagraph] = {}
-    for paragraph in get_ordered_paragraphs(find_results):
-        rid, field_type, field = paragraph.id.split("/")[:3]
-        field_path = "/".join([rid, field_type, field])
-        position = paragraph.id.split("/")[-1]
-        start, end = position.split("-")
-        int_start = int(start)
-        int_end = int(end) + distance
-
-        new_text = await paragraphs.get_paragraph_text(
-            kbid=kbid,
-            rid=rid,
-            field=field_path,
-            start=int_start,
-            end=int_end,
-            extracted_text_cache=etcache,
-        )
-        if rid not in resources:
-            title_text = await paragraphs.get_paragraph_text(
-                kbid=kbid,
-                rid=rid,
-                field="/a/title",
-                start=0,
-                end=500,
-                extracted_text_cache=etcache,
-            )
-            summary_text = await paragraphs.get_paragraph_text(
-                kbid=kbid,
-                rid=rid,
-                field="/a/summary",
-                start=0,
-                end=1000,
-                extracted_text_cache=etcache,
-            )
-            resources[rid] = ExtraCharsParagraph(
-                title=title_text,
-                summary=summary_text,
-                paragraphs=[(paragraph, new_text)],
-            )
-        else:
-            resources[rid].paragraphs.append((paragraph, new_text))  # type: ignore
-
-    for key, values in resources.items():
-        title_text = values.title
-        summary_text = values.summary
-        first_paragraph = None
-        text = ""
-        for paragraph, text in values.paragraphs:
-            if first_paragraph is None:
-                first_paragraph = paragraph
-            text += "EXTRACTED BLOCK: \n " + text + " \n\n "
-            paragraph.text = ""
-
-        if first_paragraph is not None:
-            first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text}"
-
-
 def _clean_paragraph_text(paragraph: FindParagraph) -> str:
     text = paragraph.text.strip()
     # Do not send highlight marks on prompt context
@@ -525,17 +1088,23 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:

     return text


-def
+def get_neighbouring_paragraph_indexes(
+    field_paragraphs: list[ParagraphId],
+    matching_paragraph: ParagraphId,
+    before: int,
+    after: int,
+) -> list[int]:
     """
-    Returns the
+    Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
     """
-
-
-
-
-
-
-
-
-
-    )
+    assert before >= 0
+    assert after >= 0
+    try:
+        matching_index = field_paragraphs.index(matching_paragraph)
+    except ValueError:
+        raise ParagraphIdNotFoundInExtractedMetadata(
+            f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
+        )
+    start_index = max(0, matching_index - before)
+    end_index = min(len(field_paragraphs), matching_index + after + 1)
+    return list(range(start_index, end_index))