nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0003_allfields_key.py +1 -35
- migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
- migrations/0010_fix_corrupt_indexes.py +10 -10
- migrations/0011_materialize_labelset_ids.py +1 -16
- migrations/0012_rollover_shards.py +5 -10
- migrations/0014_rollover_shards.py +4 -5
- migrations/0015_targeted_rollover.py +5 -10
- migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
- migrations/0017_multiple_writable_shards.py +2 -4
- migrations/0018_purge_orphan_kbslugs.py +5 -7
- migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
- migrations/0020_drain_nodes_from_cluster.py +3 -3
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +30 -16
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +3 -11
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +174 -59
- nucliadb/common/cluster/rebalance.py +27 -29
- nucliadb/common/cluster/rollover.py +353 -194
- nucliadb/common/cluster/settings.py +6 -0
- nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +2 -6
- nucliadb/common/cluster/utils.py +29 -22
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +3 -0
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +7 -1
- nucliadb/common/datamanagers/atomic.py +22 -4
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +83 -37
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +41 -103
- nucliadb/common/datamanagers/rollover.py +76 -15
- nucliadb/common/datamanagers/synonyms.py +1 -1
- nucliadb/common/datamanagers/utils.py +15 -6
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +29 -7
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +3 -0
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +11 -42
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exporter.py +5 -11
- nucliadb/export_import/importer.py +5 -7
- nucliadb/export_import/models.py +3 -3
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +25 -37
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +21 -19
- nucliadb/ingest/consumer/consumer.py +82 -47
- nucliadb/ingest/consumer/materializer.py +5 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +19 -17
- nucliadb/ingest/consumer/shard_creator.py +2 -4
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +137 -105
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -16
- nucliadb/ingest/fields/link.py +5 -10
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +200 -213
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +322 -197
- nucliadb/ingest/orm/processor/__init__.py +2 -700
- nucliadb/ingest/orm/processor/auditing.py +4 -23
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +249 -402
- nucliadb/ingest/orm/utils.py +4 -4
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +64 -73
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -167
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +185 -412
- nucliadb/ingest/settings.py +10 -20
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +242 -55
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +47 -30
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +1 -12
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +21 -88
- nucliadb/reader/api/v1/export_import.py +1 -1
- nucliadb/reader/api/v1/knowledgebox.py +10 -10
- nucliadb/reader/api/v1/learning_config.py +2 -6
- nucliadb/reader/api/v1/resource.py +62 -88
- nucliadb/reader/api/v1/services.py +64 -83
- nucliadb/reader/app.py +12 -29
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -28
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +1 -2
- nucliadb/search/api/v1/ask.py +17 -10
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +16 -24
- nucliadb/search/api/v1/find.py +36 -36
- nucliadb/search/api/v1/knowledgebox.py +89 -60
- nucliadb/search/api/v1/resource/ask.py +2 -8
- nucliadb/search/api/v1/resource/search.py +49 -70
- nucliadb/search/api/v1/search.py +44 -210
- nucliadb/search/api/v1/suggest.py +39 -54
- nucliadb/search/app.py +12 -32
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +136 -187
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +25 -58
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +571 -123
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +817 -266
- nucliadb/search/search/chat/query.py +213 -309
- nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -53
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +187 -223
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +305 -150
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +3 -32
- nucliadb/search/search/summarize.py +7 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +8 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +7 -10
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +1 -3
- nucliadb/standalone/purge.py +1 -1
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +3 -6
- nucliadb/standalone/settings.py +9 -16
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +1 -1
- nucliadb/train/api/v1/trainset.py +2 -4
- nucliadb/train/app.py +10 -31
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +48 -39
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +19 -23
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +67 -14
- nucliadb/writer/api/v1/field.py +16 -269
- nucliadb/writer/api/v1/knowledgebox.py +218 -68
- nucliadb/writer/api/v1/resource.py +68 -88
- nucliadb/writer/api/v1/services.py +51 -70
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +114 -113
- nucliadb/writer/app.py +6 -43
- nucliadb/writer/back_pressure.py +16 -38
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -46
- nucliadb/writer/resource/field.py +25 -127
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +6 -2
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +49 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -433
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -764
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
- nucliadb/ingest/tests/unit/test_cache.py +0 -31
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -353
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -263
- nucliadb/search/api/v1/resource/chat.py +0 -174
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -466
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -153
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -525
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_migrations.py +0 -63
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -735
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
- nucliadb/tests/migrations/test_migration_0017.py +0 -76
- nucliadb/tests/migrations/test_migration_0018.py +0 -95
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
- nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -301
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -92
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -58
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -86
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -136
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -171
- nucliadb/tests/utils/broker_messages/fields.py +0 -197
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -221
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -101
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -191
- nucliadb/writer/tests/test_fields.py +0 -475
- nucliadb/writer/tests/test_files.py +0 -740
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
- nucliadb/writer/tests/test_resources.py +0 -476
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
- nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -17,21 +17,32 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
+
import dataclasses
|
20
21
|
import functools
|
21
|
-
|
22
|
-
from typing import AsyncGenerator, Optional
|
22
|
+
import json
|
23
|
+
from typing import AsyncGenerator, Optional, cast
|
23
24
|
|
24
|
-
from
|
25
|
-
from nucliadb.models.responses import HTTPClientError
|
26
|
-
from nucliadb.search import logger, predict
|
27
|
-
from nucliadb.search.predict import (
|
28
|
-
AnswerStatusCode,
|
25
|
+
from nuclia_models.predict.generative_responses import (
|
29
26
|
CitationsGenerativeResponse,
|
30
27
|
GenerativeChunk,
|
28
|
+
JSONGenerativeResponse,
|
31
29
|
MetaGenerativeResponse,
|
32
30
|
StatusGenerativeResponse,
|
33
31
|
TextGenerativeResponse,
|
34
32
|
)
|
33
|
+
from pydantic_core import ValidationError
|
34
|
+
|
35
|
+
from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
|
36
|
+
from nucliadb.models.responses import HTTPClientError
|
37
|
+
from nucliadb.search import logger, predict
|
38
|
+
from nucliadb.search.predict import (
|
39
|
+
AnswerStatusCode,
|
40
|
+
RephraseMissingContextError,
|
41
|
+
)
|
42
|
+
from nucliadb.search.search.chat.exceptions import (
|
43
|
+
AnswerJsonSchemaTooLong,
|
44
|
+
NoRetrievalResultsError,
|
45
|
+
)
|
35
46
|
from nucliadb.search.search.chat.prompt import PromptContextBuilder
|
36
47
|
from nucliadb.search.search.chat.query import (
|
37
48
|
NOT_ENOUGH_CONTEXT_ANSWER,
|
@@ -46,6 +57,7 @@ from nucliadb.search.search.exceptions import (
|
|
46
57
|
IncompleteFindResultsError,
|
47
58
|
InvalidQueryError,
|
48
59
|
)
|
60
|
+
from nucliadb.search.search.metrics import RAGMetrics
|
49
61
|
from nucliadb.search.search.query import QueryParser
|
50
62
|
from nucliadb.search.utilities import get_predict
|
51
63
|
from nucliadb_models.search import (
|
@@ -53,6 +65,7 @@ from nucliadb_models.search import (
|
|
53
65
|
AskRequest,
|
54
66
|
AskResponseItem,
|
55
67
|
AskResponseItemType,
|
68
|
+
AskRetrievalMatch,
|
56
69
|
AskTimings,
|
57
70
|
AskTokens,
|
58
71
|
ChatModel,
|
@@ -60,48 +73,84 @@ from nucliadb_models.search import (
|
|
60
73
|
CitationsAskResponseItem,
|
61
74
|
DebugAskResponseItem,
|
62
75
|
ErrorAskResponseItem,
|
76
|
+
FindParagraph,
|
77
|
+
FindRequest,
|
78
|
+
JSONAskResponseItem,
|
63
79
|
KnowledgeboxFindResults,
|
64
80
|
MetadataAskResponseItem,
|
65
81
|
MinScore,
|
66
82
|
NucliaDBClientType,
|
83
|
+
PrequeriesAskResponseItem,
|
84
|
+
PreQueriesStrategy,
|
85
|
+
PreQuery,
|
86
|
+
PreQueryResult,
|
67
87
|
PromptContext,
|
68
88
|
PromptContextOrder,
|
89
|
+
RagStrategyName,
|
69
90
|
Relations,
|
70
91
|
RelationsAskResponseItem,
|
71
92
|
RetrievalAskResponseItem,
|
93
|
+
SearchOptions,
|
72
94
|
StatusAskResponseItem,
|
73
95
|
SyncAskMetadata,
|
74
96
|
SyncAskResponse,
|
75
97
|
UserPrompt,
|
98
|
+
parse_custom_prompt,
|
99
|
+
parse_rephrase_prompt,
|
76
100
|
)
|
101
|
+
from nucliadb_telemetry import errors
|
77
102
|
from nucliadb_utils.exceptions import LimitsExceededError
|
78
103
|
|
79
104
|
|
105
|
+
@dataclasses.dataclass
|
106
|
+
class RetrievalMatch:
|
107
|
+
paragraph: FindParagraph
|
108
|
+
weighted_score: float
|
109
|
+
|
110
|
+
|
111
|
+
@dataclasses.dataclass
|
112
|
+
class RetrievalResults:
|
113
|
+
main_query: KnowledgeboxFindResults
|
114
|
+
query_parser: QueryParser
|
115
|
+
main_query_weight: float
|
116
|
+
prequeries: Optional[list[PreQueryResult]] = None
|
117
|
+
best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
|
118
|
+
|
119
|
+
|
80
120
|
class AskResult:
|
81
121
|
def __init__(
|
82
122
|
self,
|
83
123
|
*,
|
84
124
|
kbid: str,
|
85
125
|
ask_request: AskRequest,
|
86
|
-
|
126
|
+
main_results: KnowledgeboxFindResults,
|
127
|
+
prequeries_results: Optional[list[PreQueryResult]],
|
87
128
|
nuclia_learning_id: Optional[str],
|
88
129
|
predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
|
89
130
|
prompt_context: PromptContext,
|
90
131
|
prompt_context_order: PromptContextOrder,
|
91
132
|
auditor: ChatAuditor,
|
133
|
+
metrics: RAGMetrics,
|
134
|
+
best_matches: list[RetrievalMatch],
|
135
|
+
debug_chat_model: Optional[ChatModel],
|
92
136
|
):
|
93
137
|
# Initial attributes
|
94
138
|
self.kbid = kbid
|
95
139
|
self.ask_request = ask_request
|
96
|
-
self.
|
140
|
+
self.main_results = main_results
|
141
|
+
self.prequeries_results = prequeries_results or []
|
97
142
|
self.nuclia_learning_id = nuclia_learning_id
|
98
143
|
self.predict_answer_stream = predict_answer_stream
|
99
144
|
self.prompt_context = prompt_context
|
145
|
+
self.debug_chat_model = debug_chat_model
|
100
146
|
self.prompt_context_order = prompt_context_order
|
101
|
-
self.auditor = auditor
|
147
|
+
self.auditor: ChatAuditor = auditor
|
148
|
+
self.metrics: RAGMetrics = metrics
|
149
|
+
self.best_matches: list[RetrievalMatch] = best_matches
|
102
150
|
|
103
151
|
# Computed from the predict chat answer stream
|
104
152
|
self._answer_text = ""
|
153
|
+
self._object: Optional[JSONGenerativeResponse] = None
|
105
154
|
self._status: Optional[StatusGenerativeResponse] = None
|
106
155
|
self._citations: Optional[CitationsGenerativeResponse] = None
|
107
156
|
self._metadata: Optional[MetaGenerativeResponse] = None
|
@@ -113,6 +162,12 @@ class AskResult:
|
|
113
162
|
return AnswerStatusCode.SUCCESS
|
114
163
|
return AnswerStatusCode(self._status.code)
|
115
164
|
|
165
|
+
@property
|
166
|
+
def status_error_details(self) -> Optional[str]:
|
167
|
+
if self._status is None: # pragma: no cover
|
168
|
+
return None
|
169
|
+
return self._status.details
|
170
|
+
|
116
171
|
@property
|
117
172
|
def ask_request_with_relations(self) -> bool:
|
118
173
|
return ChatOptions.RELATIONS in self.ask_request.features
|
@@ -128,34 +183,89 @@ class AskResult:
|
|
128
183
|
except Exception as exc:
|
129
184
|
# Handle any unexpected error that might happen
|
130
185
|
# during the streaming and halt the stream
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
186
|
+
errors.capture_exception(exc)
|
187
|
+
logger.error(
|
188
|
+
f"Unexpected error while generating the answer: {exc}",
|
189
|
+
extra={"kbid": self.kbid},
|
190
|
+
)
|
191
|
+
error_message = "Unexpected error while generating the answer. Please try again later."
|
192
|
+
if self.ask_request_with_debug_flag:
|
193
|
+
error_message += f" Error: {exc}"
|
194
|
+
item = ErrorAskResponseItem(error=error_message)
|
136
195
|
yield self._ndjson_encode(item)
|
137
196
|
return
|
138
197
|
|
139
198
|
def _ndjson_encode(self, item: AskResponseItemType) -> str:
|
140
199
|
result_item = AskResponseItem(item=item)
|
141
|
-
return result_item.
|
200
|
+
return result_item.model_dump_json(exclude_none=True, by_alias=True) + "\n"
|
142
201
|
|
143
202
|
async def _stream(self) -> AsyncGenerator[AskResponseItemType, None]:
|
144
|
-
# First stream out the
|
145
|
-
|
203
|
+
# First, stream out the predict answer
|
204
|
+
first_chunk_yielded = False
|
205
|
+
with self.metrics.time("stream_predict_answer"):
|
206
|
+
async for answer_chunk in self._stream_predict_answer_text():
|
207
|
+
yield AnswerAskResponseItem(text=answer_chunk)
|
208
|
+
if not first_chunk_yielded:
|
209
|
+
self.metrics.record_first_chunk_yielded()
|
210
|
+
first_chunk_yielded = True
|
211
|
+
|
212
|
+
if self._object is not None:
|
213
|
+
yield JSONAskResponseItem(object=self._object.object)
|
214
|
+
if not first_chunk_yielded:
|
215
|
+
# When there is a JSON generative response, we consider the first chunk yielded
|
216
|
+
# to be the moment when the JSON object is yielded, not the text
|
217
|
+
self.metrics.record_first_chunk_yielded()
|
218
|
+
first_chunk_yielded = True
|
219
|
+
|
220
|
+
yield RetrievalAskResponseItem(
|
221
|
+
results=self.main_results,
|
222
|
+
best_matches=[
|
223
|
+
AskRetrievalMatch(
|
224
|
+
id=match.paragraph.id,
|
225
|
+
)
|
226
|
+
for match in self.best_matches
|
227
|
+
],
|
228
|
+
)
|
146
229
|
|
147
|
-
|
148
|
-
|
149
|
-
|
230
|
+
if len(self.prequeries_results) > 0:
|
231
|
+
item = PrequeriesAskResponseItem()
|
232
|
+
for index, (prequery, result) in enumerate(self.prequeries_results):
|
233
|
+
prequery_id = prequery.id or f"prequery_{index}"
|
234
|
+
item.results[prequery_id] = result
|
235
|
+
yield item
|
236
|
+
|
237
|
+
# Then the status
|
238
|
+
if self.status_code == AnswerStatusCode.ERROR:
|
239
|
+
# If predict yielded an error status, we yield it too and halt the stream immediately
|
240
|
+
yield StatusAskResponseItem(
|
241
|
+
code=self.status_code.value,
|
242
|
+
status=self.status_code.prettify(),
|
243
|
+
details=self.status_error_details or "Unknown error",
|
244
|
+
)
|
245
|
+
return
|
150
246
|
|
151
|
-
# Then the status code
|
152
247
|
yield StatusAskResponseItem(
|
153
|
-
code=self.status_code.value,
|
248
|
+
code=self.status_code.value,
|
249
|
+
status=self.status_code.prettify(),
|
154
250
|
)
|
155
251
|
|
156
252
|
# Audit the answer
|
157
|
-
|
158
|
-
|
253
|
+
if self._object is None:
|
254
|
+
audit_answer = self._answer_text.encode("utf-8")
|
255
|
+
else:
|
256
|
+
audit_answer = json.dumps(self._object.object).encode("utf-8")
|
257
|
+
|
258
|
+
try:
|
259
|
+
rephrase_time = self.metrics.elapsed("rephrase")
|
260
|
+
except KeyError:
|
261
|
+
# Not all ask requests have a rephrase step
|
262
|
+
rephrase_time = None
|
263
|
+
|
264
|
+
self.auditor.audit(
|
265
|
+
text_answer=audit_answer,
|
266
|
+
generative_answer_time=self.metrics.elapsed("stream_predict_answer"),
|
267
|
+
generative_answer_first_chunk_time=self.metrics.get_first_chunk_time() or 0,
|
268
|
+
rephrase_time=rephrase_time,
|
159
269
|
status_code=self.status_code,
|
160
270
|
)
|
161
271
|
|
@@ -163,25 +273,24 @@ class AskResult:
|
|
163
273
|
if self._citations is not None:
|
164
274
|
yield CitationsAskResponseItem(citations=self._citations.citations)
|
165
275
|
|
166
|
-
# Stream out
|
276
|
+
# Stream out generic metadata about the answer
|
167
277
|
if self._metadata is not None:
|
168
278
|
yield MetadataAskResponseItem(
|
169
279
|
tokens=AskTokens(
|
170
280
|
input=self._metadata.input_tokens,
|
171
281
|
output=self._metadata.output_tokens,
|
282
|
+
input_nuclia=self._metadata.input_nuclia_tokens,
|
283
|
+
output_nuclia=self._metadata.output_nuclia_tokens,
|
172
284
|
),
|
173
285
|
timings=AskTimings(
|
174
|
-
generative_first_chunk=self._metadata.timings.get(
|
175
|
-
"generative_first_chunk"
|
176
|
-
),
|
286
|
+
generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
|
177
287
|
generative_total=self._metadata.timings.get("generative"),
|
178
288
|
),
|
179
289
|
)
|
180
290
|
|
181
291
|
# Stream out the relations results
|
182
292
|
should_query_relations = (
|
183
|
-
self.ask_request_with_relations
|
184
|
-
and self.status_code != AnswerStatusCode.NO_CONTEXT
|
293
|
+
self.ask_request_with_relations and self.status_code == AnswerStatusCode.SUCCESS
|
185
294
|
)
|
186
295
|
if should_query_relations:
|
187
296
|
relations = await self.get_relations_results()
|
@@ -189,11 +298,15 @@ class AskResult:
|
|
189
298
|
|
190
299
|
# Stream out debug information
|
191
300
|
if self.ask_request_with_debug_flag:
|
301
|
+
predict_request = None
|
302
|
+
if self.debug_chat_model:
|
303
|
+
predict_request = self.debug_chat_model.model_dump(mode="json")
|
192
304
|
yield DebugAskResponseItem(
|
193
305
|
metadata={
|
194
306
|
"prompt_context": sorted_prompt_context_list(
|
195
307
|
self.prompt_context, self.prompt_context_order
|
196
|
-
)
|
308
|
+
),
|
309
|
+
"predict_request": predict_request,
|
197
310
|
}
|
198
311
|
)
|
199
312
|
|
@@ -208,40 +321,68 @@ class AskResult:
|
|
208
321
|
tokens=AskTokens(
|
209
322
|
input=self._metadata.input_tokens,
|
210
323
|
output=self._metadata.output_tokens,
|
324
|
+
input_nuclia=self._metadata.input_nuclia_tokens,
|
325
|
+
output_nuclia=self._metadata.output_nuclia_tokens,
|
211
326
|
),
|
212
327
|
timings=AskTimings(
|
213
|
-
generative_first_chunk=self._metadata.timings.get(
|
214
|
-
"generative_first_chunk"
|
215
|
-
),
|
328
|
+
generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
|
216
329
|
generative_total=self._metadata.timings.get("generative"),
|
217
330
|
),
|
218
331
|
)
|
219
332
|
citations = {}
|
220
333
|
if self._citations is not None:
|
221
334
|
citations = self._citations.citations
|
335
|
+
|
336
|
+
answer_json = None
|
337
|
+
if self._object is not None:
|
338
|
+
answer_json = self._object.object
|
339
|
+
|
340
|
+
prequeries_results: Optional[dict[str, KnowledgeboxFindResults]] = None
|
341
|
+
if self.prequeries_results:
|
342
|
+
prequeries_results = {}
|
343
|
+
for index, (prequery, result) in enumerate(self.prequeries_results):
|
344
|
+
prequery_id = prequery.id or f"prequery_{index}"
|
345
|
+
prequeries_results[prequery_id] = result
|
346
|
+
|
347
|
+
best_matches = [
|
348
|
+
AskRetrievalMatch(
|
349
|
+
id=match.paragraph.id,
|
350
|
+
)
|
351
|
+
for match in self.best_matches
|
352
|
+
]
|
353
|
+
|
222
354
|
response = SyncAskResponse(
|
223
355
|
answer=self._answer_text,
|
356
|
+
answer_json=answer_json,
|
224
357
|
status=self.status_code.prettify(),
|
225
358
|
relations=self._relations,
|
226
|
-
retrieval_results=self.
|
359
|
+
retrieval_results=self.main_results,
|
360
|
+
retrieval_best_matches=best_matches,
|
361
|
+
prequeries=prequeries_results,
|
227
362
|
citations=citations,
|
228
363
|
metadata=metadata,
|
229
364
|
learning_id=self.nuclia_learning_id or "",
|
230
365
|
)
|
366
|
+
if self.status_code == AnswerStatusCode.ERROR and self.status_error_details:
|
367
|
+
response.error_details = self.status_error_details
|
231
368
|
if self.ask_request_with_debug_flag:
|
232
369
|
sorted_prompt_context = sorted_prompt_context_list(
|
233
370
|
self.prompt_context, self.prompt_context_order
|
234
371
|
)
|
235
372
|
response.prompt_context = sorted_prompt_context
|
236
|
-
|
373
|
+
if self.debug_chat_model:
|
374
|
+
response.predict_request = self.debug_chat_model.model_dump(mode="json")
|
375
|
+
return response.model_dump_json(exclude_none=True, by_alias=True)
|
237
376
|
|
238
377
|
async def get_relations_results(self) -> Relations:
    """
    Return the relations computed for the generated answer.

    The relations are queried only while `self._relations` is still None; the query
    is timed under the "relations" metric and its result is kept on the instance.
    """
    if self._relations is not None:
        # Already computed by a previous call
        return self._relations

    with self.metrics.time("relations"):
        self._relations = await get_relations_results(
            kbid=self.kbid,
            text_answer=self._answer_text,
            target_shard_replicas=self.ask_request.shards,
            timeout=5.0,
        )
    return self._relations
|
246
387
|
|
247
388
|
async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
|
@@ -257,12 +398,12 @@ class AskResult:
|
|
257
398
|
if isinstance(item, TextGenerativeResponse):
|
258
399
|
self._answer_text += item.text
|
259
400
|
yield item.text
|
401
|
+
elif isinstance(item, JSONGenerativeResponse):
|
402
|
+
self._object = item
|
260
403
|
elif isinstance(item, StatusGenerativeResponse):
|
261
404
|
self._status = item
|
262
|
-
continue
|
263
405
|
elif isinstance(item, CitationsGenerativeResponse):
|
264
406
|
self._citations = item
|
265
|
-
continue
|
266
407
|
elif isinstance(item, MetaGenerativeResponse):
|
267
408
|
self._metadata = item
|
268
409
|
else:
|
@@ -275,9 +416,11 @@ class AskResult:
|
|
275
416
|
class NotEnoughContextAskResult(AskResult):
|
276
417
|
def __init__(
|
277
418
|
self,
|
278
|
-
|
419
|
+
main_results: Optional[KnowledgeboxFindResults] = None,
|
420
|
+
prequeries_results: Optional[list[PreQueryResult]] = None,
|
279
421
|
):
|
280
|
-
self.
|
422
|
+
self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
|
423
|
+
self.prequeries_results = prequeries_results or []
|
281
424
|
self.nuclia_learning_id = None
|
282
425
|
|
283
426
|
async def ndjson_stream(self) -> AsyncGenerator[str, None]:
|
@@ -286,19 +429,17 @@ class NotEnoughContextAskResult(AskResult):
|
|
286
429
|
return the find results and the messages indicating that there is not enough
|
287
430
|
context in the corpus to answer.
|
288
431
|
"""
|
289
|
-
yield self._ndjson_encode(RetrievalAskResponseItem(results=self.
|
432
|
+
yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
|
290
433
|
yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
|
291
434
|
status = AnswerStatusCode.NO_CONTEXT
|
292
|
-
yield self._ndjson_encode(
|
293
|
-
StatusAskResponseItem(code=status.value, status=status.prettify())
|
294
|
-
)
|
435
|
+
yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))
|
295
436
|
|
296
437
|
async def json(self) -> str:
|
297
438
|
return SyncAskResponse(
|
298
439
|
answer=NOT_ENOUGH_CONTEXT_ANSWER,
|
299
|
-
retrieval_results=self.
|
440
|
+
retrieval_results=self.main_results,
|
300
441
|
status=AnswerStatusCode.NO_CONTEXT,
|
301
|
-
).
|
442
|
+
).model_dump_json()
|
302
443
|
|
303
444
|
|
304
445
|
async def ask(
|
@@ -310,7 +451,7 @@ async def ask(
|
|
310
451
|
origin: str,
|
311
452
|
resource: Optional[str] = None,
|
312
453
|
) -> AskResult:
|
313
|
-
|
454
|
+
metrics = RAGMetrics()
|
314
455
|
chat_history = ask_request.context or []
|
315
456
|
user_context = ask_request.extra_context or []
|
316
457
|
user_query = ask_request.query
|
@@ -318,117 +459,116 @@ async def ask(
|
|
318
459
|
# Maybe rephrase the query
|
319
460
|
rephrased_query = None
|
320
461
|
if len(chat_history) > 0 or len(user_context) > 0:
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
if resource is not None:
|
334
|
-
ask_request.resource_filters = [resource]
|
335
|
-
if any(
|
336
|
-
strategy.name == "full_resource" for strategy in ask_request.rag_strategies
|
337
|
-
):
|
338
|
-
needs_retrieval = False
|
462
|
+
try:
|
463
|
+
with metrics.time("rephrase"):
|
464
|
+
rephrased_query = await rephrase_query(
|
465
|
+
kbid,
|
466
|
+
chat_history=chat_history,
|
467
|
+
query=user_query,
|
468
|
+
user_id=user_id,
|
469
|
+
user_context=user_context,
|
470
|
+
generative_model=ask_request.generative_model,
|
471
|
+
)
|
472
|
+
except RephraseMissingContextError:
|
473
|
+
logger.info("Failed to rephrase ask query, using original")
|
339
474
|
|
340
|
-
|
341
|
-
|
342
|
-
find_results, query_parser = await get_find_results(
|
475
|
+
try:
|
476
|
+
retrieval_results = await retrieval_step(
|
343
477
|
kbid=kbid,
|
344
|
-
# Prefer the rephrased query if available
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
478
|
+
# Prefer the rephrased query for retrieval if available
|
479
|
+
main_query=rephrased_query or user_query,
|
480
|
+
ask_request=ask_request,
|
481
|
+
client_type=client_type,
|
482
|
+
user_id=user_id,
|
349
483
|
origin=origin,
|
484
|
+
metrics=metrics,
|
485
|
+
resource=resource,
|
350
486
|
)
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
kbid=kbid,
|
358
|
-
features=[],
|
359
|
-
query="",
|
360
|
-
filters=ask_request.filters,
|
361
|
-
page_number=0,
|
362
|
-
page_size=0,
|
363
|
-
min_score=MinScore(),
|
487
|
+
except NoRetrievalResultsError as err:
|
488
|
+
# If a retrieval was attempted but no results were found,
|
489
|
+
# early return the ask endpoint without querying the generative model
|
490
|
+
return NotEnoughContextAskResult(
|
491
|
+
main_results=err.main_query,
|
492
|
+
prequeries_results=err.prequeries,
|
364
493
|
)
|
365
494
|
|
366
|
-
|
367
|
-
query_parser.max_tokens = ask_request.max_tokens # type: ignore
|
368
|
-
max_tokens_context = await query_parser.get_max_tokens_context()
|
369
|
-
prompt_context_builder = PromptContextBuilder(
|
370
|
-
kbid=kbid,
|
371
|
-
find_results=find_results,
|
372
|
-
resource=resource,
|
373
|
-
user_context=user_context,
|
374
|
-
strategies=ask_request.rag_strategies,
|
375
|
-
image_strategies=ask_request.rag_images_strategies,
|
376
|
-
max_context_characters=tokens_to_chars(max_tokens_context),
|
377
|
-
visual_llm=await query_parser.get_visual_llm_enabled(),
|
378
|
-
)
|
379
|
-
(
|
380
|
-
prompt_context,
|
381
|
-
prompt_context_order,
|
382
|
-
prompt_context_images,
|
383
|
-
) = await prompt_context_builder.build()
|
495
|
+
query_parser = retrieval_results.query_parser
|
384
496
|
|
385
|
-
#
|
386
|
-
|
387
|
-
|
388
|
-
|
497
|
+
# Now we build the prompt context
|
498
|
+
with metrics.time("context_building"):
|
499
|
+
query_parser.max_tokens = ask_request.max_tokens # type: ignore
|
500
|
+
max_tokens_context = await query_parser.get_max_tokens_context()
|
501
|
+
prompt_context_builder = PromptContextBuilder(
|
502
|
+
kbid=kbid,
|
503
|
+
ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
|
504
|
+
resource=resource,
|
505
|
+
user_context=user_context,
|
506
|
+
strategies=ask_request.rag_strategies,
|
507
|
+
image_strategies=ask_request.rag_images_strategies,
|
508
|
+
max_context_characters=tokens_to_chars(max_tokens_context),
|
509
|
+
visual_llm=await query_parser.get_visual_llm_enabled(),
|
510
|
+
)
|
511
|
+
(
|
512
|
+
prompt_context,
|
513
|
+
prompt_context_order,
|
514
|
+
prompt_context_images,
|
515
|
+
) = await prompt_context_builder.build()
|
389
516
|
|
390
517
|
# Make the chat request to the predict API
|
518
|
+
custom_prompt = parse_custom_prompt(ask_request)
|
391
519
|
chat_model = ChatModel(
|
392
520
|
user_id=user_id,
|
521
|
+
system=custom_prompt.system,
|
522
|
+
user_prompt=UserPrompt(prompt=custom_prompt.user) if custom_prompt.user else None,
|
393
523
|
query_context=prompt_context,
|
394
524
|
query_context_order=prompt_context_order,
|
395
525
|
chat_history=chat_history,
|
396
526
|
question=user_query,
|
397
527
|
truncate=True,
|
398
|
-
user_prompt=user_prompt,
|
399
528
|
citations=ask_request.citations,
|
529
|
+
citation_threshold=ask_request.citation_threshold,
|
400
530
|
generative_model=ask_request.generative_model,
|
401
531
|
max_tokens=query_parser.get_max_tokens_answer(),
|
402
532
|
query_context_images=prompt_context_images,
|
533
|
+
json_schema=ask_request.answer_json_schema,
|
534
|
+
rerank_context=False,
|
535
|
+
top_k=ask_request.top_k,
|
403
536
|
)
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
537
|
+
with metrics.time("stream_start"):
|
538
|
+
predict = get_predict()
|
539
|
+
(
|
540
|
+
nuclia_learning_id,
|
541
|
+
nuclia_learning_model,
|
542
|
+
predict_answer_stream,
|
543
|
+
) = await predict.chat_query_ndjson(kbid, chat_model)
|
544
|
+
debug_chat_model = chat_model
|
408
545
|
|
409
546
|
auditor = ChatAuditor(
|
410
547
|
kbid=kbid,
|
411
548
|
user_id=user_id,
|
412
549
|
client_type=client_type,
|
413
550
|
origin=origin,
|
414
|
-
start_time=start_time,
|
415
551
|
user_query=user_query,
|
416
552
|
rephrased_query=rephrased_query,
|
417
553
|
chat_history=chat_history,
|
418
554
|
learning_id=nuclia_learning_id,
|
419
555
|
query_context=prompt_context,
|
420
556
|
query_context_order=prompt_context_order,
|
557
|
+
model=nuclia_learning_model,
|
421
558
|
)
|
422
|
-
|
423
559
|
return AskResult(
|
424
560
|
kbid=kbid,
|
425
561
|
ask_request=ask_request,
|
426
|
-
|
562
|
+
main_results=retrieval_results.main_query,
|
563
|
+
prequeries_results=retrieval_results.prequeries,
|
427
564
|
nuclia_learning_id=nuclia_learning_id,
|
428
|
-
predict_answer_stream=predict_answer_stream,
|
565
|
+
predict_answer_stream=predict_answer_stream, # type: ignore
|
429
566
|
prompt_context=prompt_context,
|
430
567
|
prompt_context_order=prompt_context_order,
|
431
568
|
auditor=auditor,
|
569
|
+
metrics=metrics,
|
570
|
+
best_matches=retrieval_results.best_matches,
|
571
|
+
debug_chat_model=debug_chat_model,
|
432
572
|
)
|
433
573
|
|
434
574
|
|
@@ -468,3 +608,311 @@ def handled_ask_exceptions(func):
|
|
468
608
|
return HTTPClientError(status_code=412, detail=str(exc))
|
469
609
|
|
470
610
|
return wrapper
|
611
|
+
|
612
|
+
|
613
|
+
def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
    """
    Return the prequeries strategy of the ask request, or None if there is none.

    Only the first strategy named RagStrategyName.PREQUERIES is considered. Its queries
    are updated in place: a query without an id gets `prequery_{index}`, where index is
    its position in the list of queries.

    Raises InvalidQueryError if two queries end up sharing the same id.
    """
    for rag_strategy in ask_request.rag_strategies:
        if rag_strategy.name != RagStrategyName.PREQUERIES:
            continue

        strategy = cast(PreQueriesStrategy, rag_strategy)
        seen_ids = set()
        for position, prequery in enumerate(strategy.queries):
            # Queries without an explicit id are identified by their position
            if prequery.id is None:
                prequery.id = f"prequery_{position}"
            if prequery.id in seen_ids:
                raise InvalidQueryError(
                    "rag_strategies",
                    "Prequeries must have unique ids",
                )
            seen_ids.add(prequery.id)
        return strategy

    return None
|
630
|
+
|
631
|
+
|
632
|
+
async def retrieval_step(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
    resource: Optional[str] = None,
) -> RetrievalResults:
    """
    Entry point for the retrieval logic of the ask endpoint.

    Delegates to `retrieval_in_resource` when `resource` is given, and to
    `retrieval_in_kb` otherwise.
    """
    if resource is not None:
        return await retrieval_in_resource(
            kbid,
            resource,
            main_query,
            ask_request,
            client_type,
            user_id,
            origin,
            metrics,
        )

    return await retrieval_in_kb(
        kbid,
        main_query,
        ask_request,
        client_type,
        user_id,
        origin,
        metrics,
    )
|
666
|
+
|
667
|
+
|
668
|
+
async def retrieval_in_kb(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval of an ask request that is not scoped to a single resource.

    The main query and the prequeries of the request (if any) are searched under the
    "retrieval" metric, and their matches are merged into a single ranked list.

    Raises NoRetrievalResultsError when neither the main query nor any prequery
    returned resources.
    """
    prequeries = parse_prequeries(ask_request)

    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )

    nothing_found = len(main_results.resources) == 0 and not any(
        len(result.resources) > 0 for _, result in prequeries_results or []
    )
    if nothing_found:
        raise NoRetrievalResultsError(main_results, prequeries_results)

    # Without a prequeries strategy the main query carries all the weight
    main_query_weight = 1.0 if prequeries is None else prequeries.main_query_weight
    ranked_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=ranked_matches,
    )
|
707
|
+
|
708
|
+
|
709
|
+
async def retrieval_in_resource(
    kbid: str,
    resource: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval of an ask request that targets a single resource.

    If the full_resource strategy is requested, no search is done and empty results
    are returned. Otherwise the request and its prequeries are restricted to `resource`
    (by overwriting their `resource_filters`) before searching. When there are no
    prequeries, the main query is empty and an answer JSON schema is provided, the
    prequeries are derived from that schema.

    Raises InvalidQueryError if a prequery has `prefilter` enabled, and
    NoRetrievalResultsError if neither the main query nor any prequery returned resources.
    """
    uses_full_resource = any(strategy.name == "full_resource" for strategy in ask_request.rag_strategies)
    if uses_full_resource:
        # Chatting on a specific resource with the full_resource strategy needs no retrieval
        empty_results = KnowledgeboxFindResults(resources={}, min_score=None)
        empty_parser = QueryParser(
            kbid=kbid,
            features=[],
            query="",
            label_filters=ask_request.filters,
            keyword_filters=ask_request.keyword_filters,
            top_k=0,
            min_score=MinScore(),
        )
        return RetrievalResults(
            main_query=empty_results,
            prequeries=None,
            query_parser=empty_parser,
            main_query_weight=1.0,
        )

    prequeries = parse_prequeries(ask_request)
    if prequeries is None:
        if ask_request.answer_json_schema is not None and main_query == "":
            prequeries = calculate_prequeries_for_json_schema(ask_request)

    # Scope the main query, and every prequery, to the requested resource
    resource_scope = [resource]
    ask_request.resource_filters = resource_scope
    if prequeries is not None:
        for prequery in prequeries.queries:
            if prequery.prefilter is True:
                raise InvalidQueryError(
                    "rag_strategies",
                    "Prequeries with prefilter are not supported when asking on a resource",
                )
            prequery.request.resource_filters = [resource]

    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )

    nothing_found = len(main_results.resources) == 0 and not any(
        len(result.resources) > 0 for _, result in prequeries_results or []
    )
    if nothing_found:
        raise NoRetrievalResultsError(main_results, prequeries_results)

    main_query_weight = 1.0 if prequeries is None else prequeries.main_query_weight
    ranked_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=ranked_matches,
    )
|
779
|
+
|
780
|
+
|
781
|
+
def compute_best_matches(
    main_results: KnowledgeboxFindResults,
    prequeries_results: Optional[list[PreQueryResult]] = None,
    main_query_weight: float = 1.0,
) -> list[RetrievalMatch]:
    """
    Returns the list of matches of the retrieval results, ordered by relevance (descending weighted score).

    Every query (the main one and each prequery) has a weight, which is normalized by dividing it by the
    sum of all the weights. The weighted score of a paragraph is its score multiplied by the normalized
    weight of the query that matched it. The paragraph score itself is not modified.

    If a paragraph is matched by several queries, its final weighted score is the sum of the weighted
    scores obtained for each of them.

    `main_query_weight` is the weight given to the paragraphs matching the main query.

    If all the weights add up to zero, every weighted score is zero instead of failing with a
    ZeroDivisionError.
    """

    def iter_paragraphs(results: KnowledgeboxFindResults):
        for resource in results.resources.values():
            for field in resource.fields.values():
                yield from field.paragraphs.values()

    prequery_pairs = prequeries_results or []
    total_weights = main_query_weight + sum(prequery.weight for prequery, _ in prequery_pairs)

    def normalize(weight: float) -> float:
        # Guard against a zero total (e.g. main_query_weight=0 and no weighted prequeries)
        return weight / total_weights if total_weights else 0.0

    paragraph_id_to_match: dict[str, RetrievalMatch] = {}

    # The normalized weight is the same for all the paragraphs of a query: compute it once per query
    main_normalized_weight = normalize(main_query_weight)
    for paragraph in iter_paragraphs(main_results):
        paragraph_id_to_match[paragraph.id] = RetrievalMatch(
            paragraph=paragraph,
            weighted_score=paragraph.score * main_normalized_weight,
        )

    for prequery, prequery_results in prequery_pairs:
        normalized_weight = normalize(prequery.weight)
        for paragraph in iter_paragraphs(prequery_results):
            weighted_score = paragraph.score * normalized_weight
            rmatch = paragraph_id_to_match.get(paragraph.id)
            if rmatch is None:
                paragraph_id_to_match[paragraph.id] = RetrievalMatch(
                    paragraph=paragraph,
                    weighted_score=weighted_score,
                )
            else:
                # Paragraph already matched by another query: accumulate the weighted scores
                rmatch.weighted_score += weighted_score

    return sorted(
        paragraph_id_to_match.values(),
        key=lambda match: match.weighted_score,
        reverse=True,
    )
|
834
|
+
|
835
|
+
|
836
|
+
def calculate_prequeries_for_json_schema(
    ask_request: AskRequest,
) -> Optional[PreQueriesStrategy]:
    """
    This function generates a PreQueriesStrategy with a query for each property in the JSON schema
    found in ask_request.answer_json_schema.

    This is useful for the use-case where the user is asking for a structured answer on a corpus
    that is too big to send to the generative model.

    For instance, a JSON schema like this:
    {
        "name": "book_ordering",
        "description": "Structured answer for a book to order",
        "parameters": {
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "The title of the book"
                },
                "author": {
                    "type": "string",
                    "description": "The author of the book"
                }
            },
            "required": ["title", "author"]
        }
    }
    Will generate a PreQueriesStrategy with 2 queries, one for each property in the JSON schema, with equal
    weights. The text of each query is the property name, followed by its description when it has one:
    [
        PreQuery(request=FindRequest(query="title: The title of the book", ...), weight=1.0),
        PreQuery(request=FindRequest(query="author: The author of the book", ...), weight=1.0),
    ]

    Returns None if the schema does not declare any property. Otherwise, the generated strategy is
    returned and it also replaces the contents of ask_request.rag_strategies.

    Raises AnswerJsonSchemaTooLong if PreQueriesStrategy rejects the generated queries with a
    ValidationError.
    """
    json_schema = ask_request.answer_json_schema or {}
    # Tolerate schemas where "parameters" or "properties" are missing or explicitly null
    properties = (json_schema.get("parameters") or {}).get("properties") or {}
    if not properties:  # pragma: no cover
        return None

    features = []
    if ChatOptions.SEMANTIC in ask_request.features:
        features.append(SearchOptions.SEMANTIC)
    if ChatOptions.KEYWORD in ask_request.features:
        features.append(SearchOptions.KEYWORD)

    prequeries: list[PreQuery] = []
    for prop_name, prop_def in properties.items():
        query = prop_name
        if prop_def.get("description"):
            query += f": {prop_def['description']}"
        req = FindRequest(
            query=query,
            features=features,
            filters=[],
            keyword_filters=[],
            top_k=10,
            min_score=ask_request.min_score,
            vectorset=ask_request.vectorset,
            highlight=False,
            debug=False,
            show=[],
            with_duplicates=False,
            with_synonyms=False,
            resource_filters=[],  # to be filled with the resource filter
            rephrase=ask_request.rephrase,
            rephrase_prompt=parse_rephrase_prompt(ask_request),
            security=ask_request.security,
            autofilter=False,
        )
        prequeries.append(PreQuery(request=req, weight=1.0))

    try:
        strategy = PreQueriesStrategy(queries=prequeries)
    except ValidationError as err:
        # Keep the validation failure as the explicit cause of the error
        raise AnswerJsonSchemaTooLong(
            "Answer JSON schema with too many properties generated too many prequeries"
        ) from err

    ask_request.rag_strategies = [strategy]
    return strategy
|