nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -0,0 +1,918 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
#
|
20
|
+
import dataclasses
|
21
|
+
import functools
|
22
|
+
import json
|
23
|
+
from typing import AsyncGenerator, Optional, cast
|
24
|
+
|
25
|
+
from nuclia_models.predict.generative_responses import (
|
26
|
+
CitationsGenerativeResponse,
|
27
|
+
GenerativeChunk,
|
28
|
+
JSONGenerativeResponse,
|
29
|
+
MetaGenerativeResponse,
|
30
|
+
StatusGenerativeResponse,
|
31
|
+
TextGenerativeResponse,
|
32
|
+
)
|
33
|
+
from pydantic_core import ValidationError
|
34
|
+
|
35
|
+
from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
|
36
|
+
from nucliadb.models.responses import HTTPClientError
|
37
|
+
from nucliadb.search import logger, predict
|
38
|
+
from nucliadb.search.predict import (
|
39
|
+
AnswerStatusCode,
|
40
|
+
RephraseMissingContextError,
|
41
|
+
)
|
42
|
+
from nucliadb.search.search.chat.exceptions import (
|
43
|
+
AnswerJsonSchemaTooLong,
|
44
|
+
NoRetrievalResultsError,
|
45
|
+
)
|
46
|
+
from nucliadb.search.search.chat.prompt import PromptContextBuilder
|
47
|
+
from nucliadb.search.search.chat.query import (
|
48
|
+
NOT_ENOUGH_CONTEXT_ANSWER,
|
49
|
+
ChatAuditor,
|
50
|
+
get_find_results,
|
51
|
+
get_relations_results,
|
52
|
+
rephrase_query,
|
53
|
+
sorted_prompt_context_list,
|
54
|
+
tokens_to_chars,
|
55
|
+
)
|
56
|
+
from nucliadb.search.search.exceptions import (
|
57
|
+
IncompleteFindResultsError,
|
58
|
+
InvalidQueryError,
|
59
|
+
)
|
60
|
+
from nucliadb.search.search.metrics import RAGMetrics
|
61
|
+
from nucliadb.search.search.query import QueryParser
|
62
|
+
from nucliadb.search.utilities import get_predict
|
63
|
+
from nucliadb_models.search import (
|
64
|
+
AnswerAskResponseItem,
|
65
|
+
AskRequest,
|
66
|
+
AskResponseItem,
|
67
|
+
AskResponseItemType,
|
68
|
+
AskRetrievalMatch,
|
69
|
+
AskTimings,
|
70
|
+
AskTokens,
|
71
|
+
ChatModel,
|
72
|
+
ChatOptions,
|
73
|
+
CitationsAskResponseItem,
|
74
|
+
DebugAskResponseItem,
|
75
|
+
ErrorAskResponseItem,
|
76
|
+
FindParagraph,
|
77
|
+
FindRequest,
|
78
|
+
JSONAskResponseItem,
|
79
|
+
KnowledgeboxFindResults,
|
80
|
+
MetadataAskResponseItem,
|
81
|
+
MinScore,
|
82
|
+
NucliaDBClientType,
|
83
|
+
PrequeriesAskResponseItem,
|
84
|
+
PreQueriesStrategy,
|
85
|
+
PreQuery,
|
86
|
+
PreQueryResult,
|
87
|
+
PromptContext,
|
88
|
+
PromptContextOrder,
|
89
|
+
RagStrategyName,
|
90
|
+
Relations,
|
91
|
+
RelationsAskResponseItem,
|
92
|
+
RetrievalAskResponseItem,
|
93
|
+
SearchOptions,
|
94
|
+
StatusAskResponseItem,
|
95
|
+
SyncAskMetadata,
|
96
|
+
SyncAskResponse,
|
97
|
+
UserPrompt,
|
98
|
+
parse_custom_prompt,
|
99
|
+
parse_rephrase_prompt,
|
100
|
+
)
|
101
|
+
from nucliadb_telemetry import errors
|
102
|
+
from nucliadb_utils.exceptions import LimitsExceededError
|
103
|
+
|
104
|
+
|
105
|
+
@dataclasses.dataclass
class RetrievalMatch:
    """A single retrieved paragraph together with its weighted relevance score."""

    # Paragraph returned by the retrieval step.
    paragraph: FindParagraph
    # Score after applying the query weight (main query vs prequery weighting).
    weighted_score: float
|
109
|
+
|
110
|
+
|
111
|
+
@dataclasses.dataclass
class RetrievalResults:
    """Aggregated output of the retrieval phase.

    Holds the main query results, the parser used to build the query, the
    weight given to the main query (relative to prequeries), the optional
    prequery results and the globally ranked best matches.
    """

    main_query: KnowledgeboxFindResults
    query_parser: QueryParser
    main_query_weight: float
    prequeries: Optional[list[PreQueryResult]] = None
    best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
|
118
|
+
|
119
|
+
|
120
|
+
class AskResult:
|
121
|
+
def __init__(
    self,
    *,
    kbid: str,
    ask_request: AskRequest,
    main_results: KnowledgeboxFindResults,
    prequeries_results: Optional[list[PreQueryResult]],
    nuclia_learning_id: Optional[str],
    predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
    prompt_context: PromptContext,
    prompt_context_order: PromptContextOrder,
    auditor: ChatAuditor,
    metrics: RAGMetrics,
    best_matches: list[RetrievalMatch],
    debug_chat_model: Optional[ChatModel],
):
    """Hold everything needed to stream or serialize an /ask response.

    The constructor only stores its inputs; the answer-related attributes
    (`_answer_text`, `_object`, `_status`, `_citations`, `_metadata`,
    `_relations`) start empty and are filled in while consuming the predict
    answer stream.
    """
    # Attributes provided by the caller
    self.kbid = kbid
    self.ask_request = ask_request
    self.main_results = main_results
    self.prequeries_results = prequeries_results or []
    self.nuclia_learning_id = nuclia_learning_id
    self.predict_answer_stream = predict_answer_stream
    self.prompt_context = prompt_context
    self.debug_chat_model = debug_chat_model
    self.prompt_context_order = prompt_context_order
    self.auditor: ChatAuditor = auditor
    self.metrics: RAGMetrics = metrics
    self.best_matches: list[RetrievalMatch] = best_matches

    # Attributes computed while consuming the predict chat answer stream
    self._answer_text = ""
    self._object: Optional[JSONGenerativeResponse] = None
    self._status: Optional[StatusGenerativeResponse] = None
    self._citations: Optional[CitationsGenerativeResponse] = None
    self._metadata: Optional[MetaGenerativeResponse] = None
    self._relations: Optional[Relations] = None
|
158
|
+
|
159
|
+
@property
def status_code(self) -> AnswerStatusCode:
    """Answer status reported by predict; SUCCESS when no status was streamed."""
    if self._status is None:
        return AnswerStatusCode.SUCCESS
    return AnswerStatusCode(self._status.code)
|
164
|
+
|
165
|
+
@property
def status_error_details(self) -> Optional[str]:
    """Error details of the predict status, or None when no status was streamed."""
    status = self._status
    if status is None:  # pragma: no cover
        return None
    return status.details
|
170
|
+
|
171
|
+
@property
def ask_request_with_relations(self) -> bool:
    """True when the client enabled the relations feature on the ask request."""
    return ChatOptions.RELATIONS in self.ask_request.features
|
174
|
+
|
175
|
+
@property
def ask_request_with_debug_flag(self) -> bool:
    """True when the client requested debug information in the response."""
    return self.ask_request.debug
|
178
|
+
|
179
|
+
async def ndjson_stream(self) -> AsyncGenerator[str, None]:
    """Stream the response items encoded as newline-delimited JSON.

    Any unexpected exception halts the stream: it is captured for error
    telemetry, logged with the kbid, and surfaced to the client as a final
    error item (including the exception text only when debug was requested).
    """
    try:
        async for item in self._stream():
            yield self._ndjson_encode(item)
    except Exception as exc:
        errors.capture_exception(exc)
        logger.error(
            f"Unexpected error while generating the answer: {exc}",
            extra={"kbid": self.kbid},
        )
        error_message = "Unexpected error while generating the answer. Please try again later."
        if self.ask_request_with_debug_flag:
            error_message += f" Error: {exc}"
        yield self._ndjson_encode(ErrorAskResponseItem(error=error_message))
        return
|
197
|
+
|
198
|
+
def _ndjson_encode(self, item: AskResponseItemType) -> str:
    """Wrap *item* in an AskResponseItem envelope and serialize it as one NDJSON line."""
    envelope = AskResponseItem(item=item)
    return envelope.model_dump_json(exclude_none=True, by_alias=True) + "\n"
|
201
|
+
|
202
|
+
async def _stream(self) -> AsyncGenerator[AskResponseItemType, None]:
    """Yield the ask response items in their streaming order.

    Order: answer text chunks, JSON object (structured answers), retrieval
    results, prequery results, status, then — only on success — audit,
    citations, metadata, relations and optional debug info. On an error
    status the stream halts right after the status item.
    """
    # 1. Stream out the generative answer text as it arrives from predict.
    first_chunk_yielded = False
    with self.metrics.time("stream_predict_answer"):
        async for answer_chunk in self._stream_predict_answer_text():
            yield AnswerAskResponseItem(text=answer_chunk)
            if not first_chunk_yielded:
                self.metrics.record_first_chunk_yielded()
                first_chunk_yielded = True

    # 2. Structured (JSON schema) answers are yielded as a single object.
    if self._object is not None:
        yield JSONAskResponseItem(object=self._object.object)
        if not first_chunk_yielded:
            # For JSON generative responses, the first chunk is considered
            # yielded when the JSON object is emitted, not the text.
            self.metrics.record_first_chunk_yielded()
            first_chunk_yielded = True

    # 3. Retrieval results of the main query, plus the ranked best matches.
    yield RetrievalAskResponseItem(
        results=self.main_results,
        best_matches=[AskRetrievalMatch(id=match.paragraph.id) for match in self.best_matches],
    )

    if len(self.prequeries_results) > 0:
        item = PrequeriesAskResponseItem()
        for index, (prequery, result) in enumerate(self.prequeries_results):
            # Prequeries without an explicit id get a positional one.
            item.results[prequery.id or f"prequery_{index}"] = result
        yield item

    # 4. Status. If predict yielded an error status, emit it (with details)
    # and halt the stream immediately.
    if self.status_code == AnswerStatusCode.ERROR:
        yield StatusAskResponseItem(
            code=self.status_code.value,
            status=self.status_code.prettify(),
            details=self.status_error_details or "Unknown error",
        )
        return

    yield StatusAskResponseItem(
        code=self.status_code.value,
        status=self.status_code.prettify(),
    )

    # 5. Audit the generated answer: plain text, or the serialized JSON
    # object when a structured answer was produced.
    if self._object is None:
        audit_answer = self._answer_text.encode("utf-8")
    else:
        audit_answer = json.dumps(self._object.object).encode("utf-8")

    try:
        rephrase_time = self.metrics.elapsed("rephrase")
    except KeyError:
        # Not all ask requests have a rephrase step
        rephrase_time = None

    self.auditor.audit(
        text_answer=audit_answer,
        generative_answer_time=self.metrics.elapsed("stream_predict_answer"),
        generative_answer_first_chunk_time=self.metrics.get_first_chunk_time() or 0,
        rephrase_time=rephrase_time,
        status_code=self.status_code,
    )

    # 6. Citations, if predict returned any.
    if self._citations is not None:
        yield CitationsAskResponseItem(citations=self._citations.citations)

    # 7. Token usage and timing metadata about the generative call.
    if self._metadata is not None:
        yield MetadataAskResponseItem(
            tokens=AskTokens(
                input=self._metadata.input_tokens,
                output=self._metadata.output_tokens,
                input_nuclia=self._metadata.input_nuclia_tokens,
                output_nuclia=self._metadata.output_nuclia_tokens,
            ),
            timings=AskTimings(
                generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
                generative_total=self._metadata.timings.get("generative"),
            ),
        )

    # 8. Relations results, only when the feature was requested and the
    # answer generation succeeded.
    if self.ask_request_with_relations and self.status_code == AnswerStatusCode.SUCCESS:
        relations = await self.get_relations_results()
        yield RelationsAskResponseItem(relations=relations)

    # 9. Debug information (prompt context and predict request), on demand.
    if self.ask_request_with_debug_flag:
        predict_request = None
        if self.debug_chat_model:
            predict_request = self.debug_chat_model.model_dump(mode="json")
        yield DebugAskResponseItem(
            metadata={
                "prompt_context": sorted_prompt_context_list(
                    self.prompt_context, self.prompt_context_order
                ),
                "predict_request": predict_request,
            }
        )
|
312
|
+
|
313
|
+
async def json(self) -> str:
    """
    Consume the whole answer stream in memory and render everything as a
    single SyncAskResponse JSON payload (the non-streaming flavour of /ask).
    """
    # Drain the stream first so all accumulated state (_answer_text,
    # _metadata, _citations, _object, ...) is populated before serializing.
    async for _ in self._stream():
        ...

    sync_metadata = None
    if self._metadata is not None:
        sync_metadata = SyncAskMetadata(
            tokens=AskTokens(
                input=self._metadata.input_tokens,
                output=self._metadata.output_tokens,
                input_nuclia=self._metadata.input_nuclia_tokens,
                output_nuclia=self._metadata.output_nuclia_tokens,
            ),
            timings=AskTimings(
                generative_first_chunk=self._metadata.timings.get("generative_first_chunk"),
                generative_total=self._metadata.timings.get("generative"),
            ),
        )

    citations = self._citations.citations if self._citations is not None else {}
    answer_json = self._object.object if self._object is not None else None

    prequeries_map: Optional[dict[str, KnowledgeboxFindResults]] = None
    if self.prequeries_results:
        # Queries without an explicit id get a positional fallback id.
        prequeries_map = {
            (prequery.id or f"prequery_{idx}"): found
            for idx, (prequery, found) in enumerate(self.prequeries_results)
        }

    retrieval_matches = [AskRetrievalMatch(id=m.paragraph.id) for m in self.best_matches]

    response = SyncAskResponse(
        answer=self._answer_text,
        answer_json=answer_json,
        status=self.status_code.prettify(),
        relations=self._relations,
        retrieval_results=self.main_results,
        retrieval_best_matches=retrieval_matches,
        prequeries=prequeries_map,
        citations=citations,
        metadata=sync_metadata,
        learning_id=self.nuclia_learning_id or "",
    )
    if self.status_code == AnswerStatusCode.ERROR and self.status_error_details:
        response.error_details = self.status_error_details
    if self.ask_request_with_debug_flag:
        # Debug mode exposes the exact prompt context and predict request.
        response.prompt_context = sorted_prompt_context_list(
            self.prompt_context, self.prompt_context_order
        )
        if self.debug_chat_model:
            response.predict_request = self.debug_chat_model.model_dump(mode="json")
    return response.model_dump_json(exclude_none=True, by_alias=True)
|
376
|
+
|
377
|
+
async def get_relations_results(self) -> Relations:
    """Return the relations for the answer text, computing them lazily once."""
    if self._relations is not None:
        return self._relations
    with self.metrics.time("relations"):
        # This calls the module-level `get_relations_results` helper, which
        # shares the method's name but lives in the module namespace.
        self._relations = await get_relations_results(
            kbid=self.kbid,
            text_answer=self._answer_text,
            target_shard_replicas=self.ask_request.shards,
            timeout=5.0,
        )
    return self._relations
|
387
|
+
|
388
|
+
async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
    """
    Reads the stream of the generative model, yielding the answer text but also parsing
    other items like status codes, citations and miscellaneous metadata.

    This method does not assume any order in the stream of items, but it assumes that at least
    the answer text is streamed in order.
    """
    async for generative_chunk in self.predict_answer_stream:
        item = generative_chunk.chunk
        if isinstance(item, TextGenerativeResponse):
            # Accumulate the full answer text (used later for audit, relations
            # and the sync response) while also streaming each chunk out.
            self._answer_text += item.text
            yield item.text
        elif isinstance(item, JSONGenerativeResponse):
            # Structured (json_schema) answer: stored, not yielded as text.
            self._object = item
        elif isinstance(item, StatusGenerativeResponse):
            self._status = item
        elif isinstance(item, CitationsGenerativeResponse):
            self._citations = item
        elif isinstance(item, MetaGenerativeResponse):
            # Token counts and generative timings reported by the model.
            self._metadata = item
        else:
            # Unknown item kinds are logged and skipped so that new upstream
            # response types do not break the stream.
            logger.warning(
                f"Unexpected item in predict answer stream: {item}",
                extra={"kbid": self.kbid},
            )
|
414
|
+
|
415
|
+
|
416
|
+
class NotEnoughContextAskResult(AskResult):
    # Lightweight AskResult used when retrieval found nothing: it skips the
    # generative model entirely and answers with a canned "not enough
    # context" message.
    def __init__(
        self,
        main_results: Optional[KnowledgeboxFindResults] = None,
        prequeries_results: Optional[list[PreQueryResult]] = None,
    ):
        # Intentionally does not call super().__init__(): only the attributes
        # read by the two overridden methods below are set.
        self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
        self.prequeries_results = prequeries_results or []
        self.nuclia_learning_id = None

    async def ndjson_stream(self) -> AsyncGenerator[str, None]:
        """
        In the case where there are no results in the retrieval phase, we simply
        return the find results and the messages indicating that there is not enough
        context in the corpus to answer.
        """
        yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
        yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
        status = AnswerStatusCode.NO_CONTEXT
        yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))

    async def json(self) -> str:
        # NOTE(review): `status` is the raw enum member here, whereas
        # ndjson_stream and AskResult.json use status.prettify() — confirm
        # SyncAskResponse coerces both to the same serialized string.
        return SyncAskResponse(
            answer=NOT_ENOUGH_CONTEXT_ANSWER,
            retrieval_results=self.main_results,
            status=AnswerStatusCode.NO_CONTEXT,
        ).model_dump_json()
|
443
|
+
|
444
|
+
|
445
|
+
async def ask(
    *,
    kbid: str,
    ask_request: AskRequest,
    user_id: str,
    client_type: NucliaDBClientType,
    origin: str,
    resource: Optional[str] = None,
) -> AskResult:
    """
    Orchestrate one /ask request: optionally rephrase the query, run retrieval,
    build the prompt context, start the generative stream and return an
    AskResult that the caller can consume as ndjson or as a sync JSON payload.

    When `resource` is set, retrieval is scoped to that resource. Returns a
    NotEnoughContextAskResult (no generative call) when retrieval finds nothing.
    """
    metrics = RAGMetrics()
    chat_history = ask_request.context or []
    user_context = ask_request.extra_context or []
    user_query = ask_request.query

    # Maybe rephrase the query: only worthwhile when there is chat history or
    # extra user context to fold into a standalone retrieval query.
    rephrased_query = None
    if len(chat_history) > 0 or len(user_context) > 0:
        try:
            with metrics.time("rephrase"):
                rephrased_query = await rephrase_query(
                    kbid,
                    chat_history=chat_history,
                    query=user_query,
                    user_id=user_id,
                    user_context=user_context,
                    generative_model=ask_request.generative_model,
                )
        except RephraseMissingContextError:
            # Rephrase is best-effort: fall back to the original query.
            logger.info("Failed to rephrase ask query, using original")

    try:
        retrieval_results = await retrieval_step(
            kbid=kbid,
            # Prefer the rephrased query for retrieval if available
            main_query=rephrased_query or user_query,
            ask_request=ask_request,
            client_type=client_type,
            user_id=user_id,
            origin=origin,
            metrics=metrics,
            resource=resource,
        )
    except NoRetrievalResultsError as err:
        # If a retrieval was attempted but no results were found,
        # early return the ask endpoint without querying the generative model
        return NotEnoughContextAskResult(
            main_results=err.main_query,
            prequeries_results=err.prequeries,
        )

    query_parser = retrieval_results.query_parser

    # Now we build the prompt context
    with metrics.time("context_building"):
        query_parser.max_tokens = ask_request.max_tokens  # type: ignore
        max_tokens_context = await query_parser.get_max_tokens_context()
        prompt_context_builder = PromptContextBuilder(
            kbid=kbid,
            ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
            resource=resource,
            user_context=user_context,
            strategies=ask_request.rag_strategies,
            image_strategies=ask_request.rag_images_strategies,
            max_context_characters=tokens_to_chars(max_tokens_context),
            visual_llm=await query_parser.get_visual_llm_enabled(),
        )
        (
            prompt_context,
            prompt_context_order,
            prompt_context_images,
        ) = await prompt_context_builder.build()

    # Make the chat request to the predict API
    custom_prompt = parse_custom_prompt(ask_request)
    chat_model = ChatModel(
        user_id=user_id,
        system=custom_prompt.system,
        user_prompt=UserPrompt(prompt=custom_prompt.user) if custom_prompt.user else None,
        query_context=prompt_context,
        query_context_order=prompt_context_order,
        chat_history=chat_history,
        # The generative model is asked the original user question, not the
        # rephrased one (which is only used for retrieval above).
        question=user_query,
        truncate=True,
        citations=ask_request.citations,
        citation_threshold=ask_request.citation_threshold,
        generative_model=ask_request.generative_model,
        max_tokens=query_parser.get_max_tokens_answer(),
        query_context_images=prompt_context_images,
        json_schema=ask_request.answer_json_schema,
        rerank_context=False,
        top_k=ask_request.top_k,
    )
    with metrics.time("stream_start"):
        predict = get_predict()
        (
            nuclia_learning_id,
            nuclia_learning_model,
            predict_answer_stream,
        ) = await predict.chat_query_ndjson(kbid, chat_model)
    # Kept for the debug flag: exposes the exact request sent to predict.
    debug_chat_model = chat_model

    auditor = ChatAuditor(
        kbid=kbid,
        user_id=user_id,
        client_type=client_type,
        origin=origin,
        user_query=user_query,
        rephrased_query=rephrased_query,
        chat_history=chat_history,
        learning_id=nuclia_learning_id,
        query_context=prompt_context,
        query_context_order=prompt_context_order,
        model=nuclia_learning_model,
    )
    return AskResult(
        kbid=kbid,
        ask_request=ask_request,
        main_results=retrieval_results.main_query,
        prequeries_results=retrieval_results.prequeries,
        nuclia_learning_id=nuclia_learning_id,
        predict_answer_stream=predict_answer_stream,  # type: ignore
        prompt_context=prompt_context,
        prompt_context_order=prompt_context_order,
        auditor=auditor,
        metrics=metrics,
        best_matches=retrieval_results.best_matches,
        debug_chat_model=debug_chat_model,
    )
|
573
|
+
|
574
|
+
|
575
|
+
def handled_ask_exceptions(func):
    """
    Decorator for ask endpoint handlers that maps known domain exceptions to
    HTTPClientError responses instead of letting them propagate.

    The wrapped coroutine's return value is passed through untouched on
    success; on a recognized exception an HTTPClientError is *returned* (not
    raised) so FastAPI serializes it as the response.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except KnowledgeBoxNotFound:
            return HTTPClientError(
                status_code=404,
                # Fixed: was an f-string with no placeholders (F541).
                detail="Knowledge Box not found.",
            )
        except LimitsExceededError as exc:
            return HTTPClientError(status_code=exc.status_code, detail=exc.detail)
        except predict.ProxiedPredictAPIError as err:
            # Propagate the status/detail reported by the predict API.
            return HTTPClientError(
                status_code=err.status,
                detail=err.detail,
            )
        except IncompleteFindResultsError:
            return HTTPClientError(
                status_code=529,
                detail="Temporary error on information retrieval. Please try again.",
            )
        except predict.RephraseMissingContextError:
            return HTTPClientError(
                status_code=412,
                detail="Unable to rephrase the query with the provided context.",
            )
        except predict.RephraseError as err:
            return HTTPClientError(
                status_code=529,
                detail=f"Temporary error while rephrasing the query. Please try again later. Error: {err}",
            )
        except InvalidQueryError as exc:
            return HTTPClientError(status_code=412, detail=str(exc))

    return wrapper
|
611
|
+
|
612
|
+
|
613
|
+
def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
|
614
|
+
query_ids = []
|
615
|
+
for rag_strategy in ask_request.rag_strategies:
|
616
|
+
if rag_strategy.name == RagStrategyName.PREQUERIES:
|
617
|
+
prequeries = cast(PreQueriesStrategy, rag_strategy)
|
618
|
+
# Give each query a unique id if they don't have one
|
619
|
+
for index, query in enumerate(prequeries.queries):
|
620
|
+
if query.id is None:
|
621
|
+
query.id = f"prequery_{index}"
|
622
|
+
if query.id in query_ids:
|
623
|
+
raise InvalidQueryError(
|
624
|
+
"rag_strategies",
|
625
|
+
"Prequeries must have unique ids",
|
626
|
+
)
|
627
|
+
query_ids.append(query.id)
|
628
|
+
return prequeries
|
629
|
+
return None
|
630
|
+
|
631
|
+
|
632
|
+
async def retrieval_step(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
    resource: Optional[str] = None,
) -> RetrievalResults:
    """
    Encapsulates all the retrieval logic of the ask endpoint, dispatching to a
    resource-scoped retrieval when `resource` is given, or a knowledge-box-wide
    one otherwise.
    """
    if resource is not None:
        return await retrieval_in_resource(
            kbid,
            resource,
            main_query,
            ask_request,
            client_type,
            user_id,
            origin,
            metrics,
        )
    return await retrieval_in_kb(
        kbid,
        main_query,
        ask_request,
        client_type,
        user_id,
        origin,
        metrics,
    )
|
666
|
+
|
667
|
+
|
668
|
+
async def retrieval_in_kb(
    kbid: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval phase of /ask over the whole knowledge box: the main
    query plus any prequeries declared in the request's rag_strategies.

    Raises NoRetrievalResultsError when neither the main query nor any
    prequery matched any resource.
    """
    prequeries = parse_prequeries(ask_request)
    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )
        if len(main_results.resources) == 0 and all(
            len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)

    # Without prequeries the main query carries all the weight.
    main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
    best_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
|
707
|
+
|
708
|
+
|
709
|
+
async def retrieval_in_resource(
    kbid: str,
    resource: str,
    main_query: str,
    ask_request: AskRequest,
    client_type: NucliaDBClientType,
    user_id: str,
    origin: str,
    metrics: RAGMetrics,
) -> RetrievalResults:
    """
    Run the retrieval phase of /ask scoped to a single resource.

    Mutates ask_request (and any prequery requests) so that resource_filters
    is pinned to the given resource. Raises NoRetrievalResultsError when
    nothing matched, and InvalidQueryError for prequeries with prefilter.
    """
    if any(strategy.name == "full_resource" for strategy in ask_request.rag_strategies):
        # Retrieval is not needed if we are chatting on a specific resource and the full_resource strategy is enabled
        return RetrievalResults(
            main_query=KnowledgeboxFindResults(resources={}, min_score=None),
            prequeries=None,
            # Placeholder parser: no actual query will be executed.
            query_parser=QueryParser(
                kbid=kbid,
                features=[],
                query="",
                label_filters=ask_request.filters,
                keyword_filters=ask_request.keyword_filters,
                top_k=0,
                min_score=MinScore(),
            ),
            main_query_weight=1.0,
        )

    prequeries = parse_prequeries(ask_request)
    if prequeries is None and ask_request.answer_json_schema is not None and main_query == "":
        # Structured answer with no query: derive one prequery per schema property.
        prequeries = calculate_prequeries_for_json_schema(ask_request)

    # Make sure the retrieval is scoped to the resource if provided
    ask_request.resource_filters = [resource]
    if prequeries is not None:
        for prequery in prequeries.queries:
            if prequery.prefilter is True:
                raise InvalidQueryError(
                    "rag_strategies",
                    "Prequeries with prefilter are not supported when asking on a resource",
                )
            prequery.request.resource_filters = [resource]

    with metrics.time("retrieval"):
        main_results, prequeries_results, query_parser = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
            ndb_client=client_type,
            user=user_id,
            origin=origin,
            metrics=metrics,
            prequeries_strategy=prequeries,
        )
        if len(main_results.resources) == 0 and all(
            len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
        ):
            raise NoRetrievalResultsError(main_results, prequeries_results)
    # Without prequeries the main query carries all the weight.
    main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
    best_matches = compute_best_matches(
        main_results=main_results,
        prequeries_results=prequeries_results,
        main_query_weight=main_query_weight,
    )
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
        query_parser=query_parser,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
|
779
|
+
|
780
|
+
|
781
|
+
def compute_best_matches(
    main_results: KnowledgeboxFindResults,
    prequeries_results: Optional[list[PreQueryResult]] = None,
    main_query_weight: float = 1.0,
) -> list[RetrievalMatch]:
    """
    Merge the main retrieval results with any prequery results and return all
    matched paragraphs ordered by relevance (descending weighted score).

    Each paragraph keeps its original score untouched; the weighted score
    (score times the normalized weight of the query that matched it) only
    determines the order in which paragraphs appear in the LLM prompt context.
    A paragraph matched by several queries accumulates the weighted scores of
    each of them.

    `main_query_weight` is the weight given to paragraphs matching the main
    query when computing the final score.
    """

    def _paragraphs_of(results: KnowledgeboxFindResults):
        # Flatten the resource -> field -> paragraph hierarchy.
        for res in results.resources.values():
            for fld in res.fields.values():
                yield from fld.paragraphs.values()

    prequery_pairs = prequeries_results or []
    total_weight = main_query_weight + sum(pq.weight for pq, _ in prequery_pairs)
    matches_by_pid: dict[str, RetrievalMatch] = {}

    main_norm = main_query_weight / total_weight
    for par in _paragraphs_of(main_results):
        matches_by_pid[par.id] = RetrievalMatch(
            paragraph=par,
            weighted_score=par.score * main_norm,
        )

    for prequery, prequery_find in prequery_pairs:
        pq_norm = prequery.weight / total_weight
        for par in _paragraphs_of(prequery_find):
            score = par.score * pq_norm
            existing = matches_by_pid.get(par.id)
            if existing is not None:
                # Same paragraph matched by several queries: sum the
                # weighted scores.
                existing.weighted_score += score
            else:
                matches_by_pid[par.id] = RetrievalMatch(
                    paragraph=par,
                    weighted_score=score,
                )

    return sorted(
        matches_by_pid.values(),
        key=lambda m: m.weighted_score,
        reverse=True,
    )
|
834
|
+
|
835
|
+
|
836
|
+
def calculate_prequeries_for_json_schema(
    ask_request: AskRequest,
) -> Optional[PreQueriesStrategy]:
    """
    Generate a PreQueriesStrategy with one equally-weighted query per property
    of the JSON schema found in ask_request.answer_json_schema.

    This is useful for the use-case where the user is asking for a structured answer
    on a corpus that is too big to send to the generative model.

    For instance, a JSON schema with "parameters.properties" entries

        "title":  {"type": "string", "description": "The title of the book"},
        "author": {"type": "string", "description": "The author of the book"},

    generates two prequeries with weight 1.0 each, querying
    "title: The title of the book" and "author: The author of the book".

    Side effect: on success, ask_request.rag_strategies is replaced with the
    computed strategy.

    Returns None when the schema declares no properties. Raises
    AnswerJsonSchemaTooLong when the schema has more properties than the
    maximum number of prequeries allowed by PreQueriesStrategy.
    """
    json_schema = ask_request.answer_json_schema or {}
    properties = json_schema.get("parameters", {}).get("properties", {})
    if not properties:  # pragma: no cover
        return None

    # Reuse the retrieval features requested for the main ask query.
    features = []
    if ChatOptions.SEMANTIC in ask_request.features:
        features.append(SearchOptions.SEMANTIC)
    if ChatOptions.KEYWORD in ask_request.features:
        features.append(SearchOptions.KEYWORD)

    prequeries: list[PreQuery] = []
    for prop_name, prop_def in properties.items():
        # Query on the property name, enriched with its description if any.
        query = prop_name
        if prop_def.get("description"):
            query += f": {prop_def['description']}"
        req = FindRequest(
            query=query,
            features=features,
            filters=[],
            keyword_filters=[],
            top_k=10,
            min_score=ask_request.min_score,
            vectorset=ask_request.vectorset,
            highlight=False,
            debug=False,
            show=[],
            with_duplicates=False,
            with_synonyms=False,
            resource_filters=[],  # to be filled with the resource filter
            rephrase=ask_request.rephrase,
            rephrase_prompt=parse_rephrase_prompt(ask_request),
            security=ask_request.security,
            autofilter=False,
        )
        prequeries.append(PreQuery(request=req, weight=1.0))

    try:
        strategy = PreQueriesStrategy(queries=prequeries)
    except ValidationError as exc:
        # Chain the pydantic error so the root cause stays visible.
        raise AnswerJsonSchemaTooLong(
            "Answer JSON schema with too many properties generated too many prequeries"
        ) from exc

    ask_request.rag_strategies = [strategy]
    return strategy
|