nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2798__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0003_allfields_key.py +1 -35
- migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
- migrations/0010_fix_corrupt_indexes.py +10 -10
- migrations/0011_materialize_labelset_ids.py +1 -16
- migrations/0012_rollover_shards.py +5 -10
- migrations/0014_rollover_shards.py +4 -5
- migrations/0015_targeted_rollover.py +5 -10
- migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
- migrations/0017_multiple_writable_shards.py +2 -4
- migrations/0018_purge_orphan_kbslugs.py +5 -7
- migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
- migrations/0020_drain_nodes_from_cluster.py +3 -3
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +30 -16
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +3 -11
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +174 -59
- nucliadb/common/cluster/rebalance.py +27 -29
- nucliadb/common/cluster/rollover.py +353 -194
- nucliadb/common/cluster/settings.py +6 -0
- nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +2 -6
- nucliadb/common/cluster/utils.py +29 -22
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +3 -0
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +7 -1
- nucliadb/common/datamanagers/atomic.py +22 -4
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +83 -37
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +41 -103
- nucliadb/common/datamanagers/rollover.py +76 -15
- nucliadb/common/datamanagers/synonyms.py +1 -1
- nucliadb/common/datamanagers/utils.py +15 -6
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +29 -7
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +3 -0
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +11 -42
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exporter.py +5 -11
- nucliadb/export_import/importer.py +5 -7
- nucliadb/export_import/models.py +3 -3
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +25 -37
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +21 -19
- nucliadb/ingest/consumer/consumer.py +82 -47
- nucliadb/ingest/consumer/materializer.py +5 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +19 -17
- nucliadb/ingest/consumer/shard_creator.py +2 -4
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +137 -105
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -16
- nucliadb/ingest/fields/link.py +5 -10
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +200 -213
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +322 -197
- nucliadb/ingest/orm/processor/__init__.py +2 -700
- nucliadb/ingest/orm/processor/auditing.py +4 -23
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +249 -403
- nucliadb/ingest/orm/utils.py +4 -4
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +70 -73
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -167
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +185 -412
- nucliadb/ingest/settings.py +10 -20
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +242 -55
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +47 -30
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +1 -12
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +21 -88
- nucliadb/reader/api/v1/export_import.py +1 -1
- nucliadb/reader/api/v1/knowledgebox.py +10 -10
- nucliadb/reader/api/v1/learning_config.py +2 -6
- nucliadb/reader/api/v1/resource.py +62 -88
- nucliadb/reader/api/v1/services.py +64 -83
- nucliadb/reader/app.py +12 -29
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -28
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +1 -2
- nucliadb/search/api/v1/ask.py +17 -10
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +16 -24
- nucliadb/search/api/v1/find.py +36 -36
- nucliadb/search/api/v1/knowledgebox.py +89 -60
- nucliadb/search/api/v1/resource/ask.py +2 -8
- nucliadb/search/api/v1/resource/search.py +49 -70
- nucliadb/search/api/v1/search.py +44 -210
- nucliadb/search/api/v1/suggest.py +39 -54
- nucliadb/search/app.py +12 -32
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +136 -187
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +25 -58
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +571 -123
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +817 -266
- nucliadb/search/search/chat/query.py +213 -309
- nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -53
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +187 -223
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +305 -150
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +3 -32
- nucliadb/search/search/summarize.py +7 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +8 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +7 -10
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +1 -3
- nucliadb/standalone/purge.py +1 -1
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +3 -6
- nucliadb/standalone/settings.py +9 -16
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +1 -1
- nucliadb/train/api/v1/trainset.py +2 -4
- nucliadb/train/app.py +10 -31
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +48 -39
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +19 -23
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +67 -14
- nucliadb/writer/api/v1/field.py +16 -269
- nucliadb/writer/api/v1/knowledgebox.py +218 -68
- nucliadb/writer/api/v1/resource.py +68 -88
- nucliadb/writer/api/v1/services.py +51 -70
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +143 -117
- nucliadb/writer/app.py +6 -43
- nucliadb/writer/back_pressure.py +16 -38
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -46
- nucliadb/writer/resource/field.py +37 -128
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +6 -2
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +49 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2798.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2798.dist-info/RECORD +343 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -433
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -764
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
- nucliadb/ingest/tests/unit/test_cache.py +0 -31
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -353
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -263
- nucliadb/search/api/v1/resource/chat.py +0 -174
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -466
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -153
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -525
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_migrations.py +0 -63
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -735
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
- nucliadb/tests/migrations/test_migration_0017.py +0 -76
- nucliadb/tests/migrations/test_migration_0018.py +0 -95
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
- nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -301
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -92
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -58
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -86
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -136
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -171
- nucliadb/tests/utils/broker_messages/fields.py +0 -197
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -221
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -101
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -191
- nucliadb/writer/tests/test_fields.py +0 -475
- nucliadb/writer/tests/test_files.py +0 -740
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
- nucliadb/writer/tests/test_resources.py +0 -476
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
- nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/entry_points.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/top_level.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2798.dist-info}/zip-safe +0 -0
@@ -18,68 +18,43 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
import asyncio
|
21
|
-
from
|
22
|
-
from time import monotonic as time
|
23
|
-
from typing import AsyncGenerator, AsyncIterator, Optional
|
24
|
-
|
25
|
-
from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
|
21
|
+
from typing import Optional
|
26
22
|
|
23
|
+
from nucliadb.common.models_utils import to_proto
|
27
24
|
from nucliadb.search import logger
|
28
25
|
from nucliadb.search.predict import AnswerStatusCode
|
29
26
|
from nucliadb.search.requesters.utils import Method, node_query
|
30
|
-
from nucliadb.search.search.chat.
|
27
|
+
from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
|
31
28
|
from nucliadb.search.search.exceptions import IncompleteFindResultsError
|
32
29
|
from nucliadb.search.search.find import find
|
33
30
|
from nucliadb.search.search.merge import merge_relations_results
|
31
|
+
from nucliadb.search.search.metrics import RAGMetrics
|
34
32
|
from nucliadb.search.search.query import QueryParser
|
33
|
+
from nucliadb.search.settings import settings
|
35
34
|
from nucliadb.search.utilities import get_predict
|
36
35
|
from nucliadb_models.search import (
|
37
|
-
|
36
|
+
AskRequest,
|
38
37
|
ChatContextMessage,
|
39
|
-
ChatModel,
|
40
38
|
ChatOptions,
|
41
|
-
ChatRequest,
|
42
39
|
FindRequest,
|
43
40
|
KnowledgeboxFindResults,
|
44
|
-
MinScore,
|
45
41
|
NucliaDBClientType,
|
42
|
+
PreQueriesStrategy,
|
43
|
+
PreQuery,
|
44
|
+
PreQueryResult,
|
46
45
|
PromptContext,
|
47
46
|
PromptContextOrder,
|
48
47
|
Relations,
|
49
48
|
RephraseModel,
|
50
49
|
SearchOptions,
|
51
|
-
|
50
|
+
parse_rephrase_prompt,
|
52
51
|
)
|
53
52
|
from nucliadb_protos import audit_pb2
|
53
|
+
from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
|
54
54
|
from nucliadb_telemetry.errors import capture_exception
|
55
|
-
from nucliadb_utils.helpers import async_gen_lookahead
|
56
55
|
from nucliadb_utils.utilities import get_audit
|
57
56
|
|
58
57
|
NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this."
|
59
|
-
AUDIT_TEXT_RESULT_SEP = " \n\n "
|
60
|
-
START_OF_CITATIONS = b"_CIT_"
|
61
|
-
|
62
|
-
|
63
|
-
class FoundStatusCode:
|
64
|
-
def __init__(self, default: AnswerStatusCode = AnswerStatusCode.SUCCESS):
|
65
|
-
self._value = AnswerStatusCode.SUCCESS
|
66
|
-
|
67
|
-
def set(self, value: AnswerStatusCode) -> None:
|
68
|
-
self._value = value
|
69
|
-
|
70
|
-
@property
|
71
|
-
def value(self) -> AnswerStatusCode:
|
72
|
-
return self._value
|
73
|
-
|
74
|
-
|
75
|
-
@dataclass
|
76
|
-
class ChatResult:
|
77
|
-
nuclia_learning_id: Optional[str]
|
78
|
-
answer_stream: AsyncIterator[bytes]
|
79
|
-
status_code: FoundStatusCode
|
80
|
-
find_results: KnowledgeboxFindResults
|
81
|
-
prompt_context: PromptContext
|
82
|
-
prompt_context_order: PromptContextOrder
|
83
58
|
|
84
59
|
|
85
60
|
async def rephrase_query(
|
@@ -101,70 +76,120 @@ async def rephrase_query(
|
|
101
76
|
return await predict.rephrase_query(kbid, req)
|
102
77
|
|
103
78
|
|
104
|
-
async def format_generated_answer(
|
105
|
-
answer_generator: AsyncGenerator[bytes, None], output_status_code: FoundStatusCode
|
106
|
-
):
|
107
|
-
status_code: Optional[AnswerStatusCode] = None
|
108
|
-
is_last_chunk = False
|
109
|
-
async for answer_chunk, is_last_chunk in async_gen_lookahead(answer_generator):
|
110
|
-
if is_last_chunk:
|
111
|
-
try:
|
112
|
-
status_code = _parse_answer_status_code(answer_chunk)
|
113
|
-
except ValueError:
|
114
|
-
# TODO: remove this in the future, it's
|
115
|
-
# just for bw compatibility until predict
|
116
|
-
# is updated to the new protocol
|
117
|
-
status_code = AnswerStatusCode.SUCCESS
|
118
|
-
yield answer_chunk
|
119
|
-
else:
|
120
|
-
# TODO: this should be needed but, in case we receive the status
|
121
|
-
# code mixed with text, we strip it and return the text
|
122
|
-
if len(answer_chunk) != len(status_code.encode()):
|
123
|
-
answer_chunk = answer_chunk.rstrip(status_code.encode())
|
124
|
-
yield answer_chunk
|
125
|
-
break
|
126
|
-
yield answer_chunk
|
127
|
-
if not is_last_chunk:
|
128
|
-
logger.warning("BUG: /chat endpoint without last chunk")
|
129
|
-
|
130
|
-
output_status_code.set(status_code or AnswerStatusCode.SUCCESS)
|
131
|
-
|
132
|
-
|
133
79
|
async def get_find_results(
|
134
80
|
*,
|
135
81
|
kbid: str,
|
136
82
|
query: str,
|
137
|
-
|
83
|
+
item: AskRequest,
|
84
|
+
ndb_client: NucliaDBClientType,
|
85
|
+
user: str,
|
86
|
+
origin: str,
|
87
|
+
metrics: RAGMetrics = RAGMetrics(),
|
88
|
+
prequeries_strategy: Optional[PreQueriesStrategy] = None,
|
89
|
+
) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], QueryParser]:
|
90
|
+
prequeries_results = None
|
91
|
+
prefilter_queries_results = None
|
92
|
+
queries_results = None
|
93
|
+
if prequeries_strategy is not None:
|
94
|
+
prefilters = [prequery for prequery in prequeries_strategy.queries if prequery.prefilter]
|
95
|
+
prequeries = [prequery for prequery in prequeries_strategy.queries if not prequery.prefilter]
|
96
|
+
if len(prefilters) > 0:
|
97
|
+
with metrics.time("prefilters"):
|
98
|
+
prefilter_queries_results = await run_prequeries(
|
99
|
+
kbid,
|
100
|
+
prefilters,
|
101
|
+
x_ndb_client=ndb_client,
|
102
|
+
x_nucliadb_user=user,
|
103
|
+
x_forwarded_for=origin,
|
104
|
+
generative_model=item.generative_model,
|
105
|
+
metrics=metrics,
|
106
|
+
)
|
107
|
+
prefilter_matching_resources = {
|
108
|
+
resource
|
109
|
+
for _, find_results in prefilter_queries_results
|
110
|
+
for resource in find_results.resources.keys()
|
111
|
+
}
|
112
|
+
if len(prefilter_matching_resources) == 0:
|
113
|
+
raise NoRetrievalResultsError()
|
114
|
+
# Make sure the main query and prequeries use the same resource filters.
|
115
|
+
# This is important to avoid returning results that don't match the prefilter.
|
116
|
+
item.resource_filters = list(prefilter_matching_resources)
|
117
|
+
for prequery in prequeries:
|
118
|
+
prequery.request.resource_filters = list(prefilter_matching_resources)
|
119
|
+
prequery.request.show_hidden = item.show_hidden
|
120
|
+
|
121
|
+
if prequeries:
|
122
|
+
with metrics.time("prequeries"):
|
123
|
+
queries_results = await run_prequeries(
|
124
|
+
kbid,
|
125
|
+
prequeries,
|
126
|
+
x_ndb_client=ndb_client,
|
127
|
+
x_nucliadb_user=user,
|
128
|
+
x_forwarded_for=origin,
|
129
|
+
generative_model=item.generative_model,
|
130
|
+
metrics=metrics,
|
131
|
+
)
|
132
|
+
|
133
|
+
prequeries_results = (prefilter_queries_results or []) + (queries_results or [])
|
134
|
+
|
135
|
+
with metrics.time("main_query"):
|
136
|
+
main_results, query_parser = await run_main_query(
|
137
|
+
kbid,
|
138
|
+
query,
|
139
|
+
item,
|
140
|
+
ndb_client,
|
141
|
+
user,
|
142
|
+
origin,
|
143
|
+
metrics=metrics,
|
144
|
+
)
|
145
|
+
return main_results, prequeries_results, query_parser
|
146
|
+
|
147
|
+
|
148
|
+
async def run_main_query(
|
149
|
+
kbid: str,
|
150
|
+
query: str,
|
151
|
+
item: AskRequest,
|
138
152
|
ndb_client: NucliaDBClientType,
|
139
153
|
user: str,
|
140
154
|
origin: str,
|
155
|
+
metrics: RAGMetrics = RAGMetrics(),
|
141
156
|
) -> tuple[KnowledgeboxFindResults, QueryParser]:
|
142
157
|
find_request = FindRequest()
|
143
|
-
find_request.resource_filters =
|
158
|
+
find_request.resource_filters = item.resource_filters
|
144
159
|
find_request.features = []
|
145
|
-
if ChatOptions.
|
146
|
-
find_request.features.append(SearchOptions.
|
147
|
-
if ChatOptions.
|
148
|
-
find_request.features.append(SearchOptions.
|
149
|
-
if ChatOptions.RELATIONS in
|
160
|
+
if ChatOptions.SEMANTIC in item.features:
|
161
|
+
find_request.features.append(SearchOptions.SEMANTIC)
|
162
|
+
if ChatOptions.KEYWORD in item.features:
|
163
|
+
find_request.features.append(SearchOptions.KEYWORD)
|
164
|
+
if ChatOptions.RELATIONS in item.features:
|
150
165
|
find_request.features.append(SearchOptions.RELATIONS)
|
151
166
|
find_request.query = query
|
152
|
-
find_request.fields =
|
153
|
-
find_request.filters =
|
154
|
-
find_request.field_type_filter =
|
155
|
-
find_request.min_score =
|
156
|
-
find_request.
|
157
|
-
find_request.
|
158
|
-
find_request.
|
159
|
-
find_request.
|
160
|
-
find_request.
|
161
|
-
find_request.
|
162
|
-
find_request.
|
163
|
-
find_request.
|
164
|
-
find_request.
|
165
|
-
find_request.
|
166
|
-
find_request.
|
167
|
-
find_request.
|
167
|
+
find_request.fields = item.fields
|
168
|
+
find_request.filters = item.filters
|
169
|
+
find_request.field_type_filter = item.field_type_filter
|
170
|
+
find_request.min_score = item.min_score
|
171
|
+
find_request.vectorset = item.vectorset
|
172
|
+
find_request.range_creation_start = item.range_creation_start
|
173
|
+
find_request.range_creation_end = item.range_creation_end
|
174
|
+
find_request.range_modification_start = item.range_modification_start
|
175
|
+
find_request.range_modification_end = item.range_modification_end
|
176
|
+
find_request.show = item.show
|
177
|
+
find_request.extracted = item.extracted
|
178
|
+
find_request.shards = item.shards
|
179
|
+
find_request.autofilter = item.autofilter
|
180
|
+
find_request.highlight = item.highlight
|
181
|
+
find_request.security = item.security
|
182
|
+
find_request.debug = item.debug
|
183
|
+
find_request.rephrase = item.rephrase
|
184
|
+
find_request.rephrase_prompt = parse_rephrase_prompt(item)
|
185
|
+
find_request.rank_fusion = item.rank_fusion
|
186
|
+
find_request.reranker = item.reranker
|
187
|
+
# We don't support pagination, we always get the top_k results.
|
188
|
+
find_request.top_k = item.top_k
|
189
|
+
find_request.show_hidden = item.show_hidden
|
190
|
+
|
191
|
+
# this executes the model validators, that can tweak some fields
|
192
|
+
FindRequest.model_validate(find_request)
|
168
193
|
|
169
194
|
find_results, incomplete, query_parser = await find(
|
170
195
|
kbid,
|
@@ -172,7 +197,8 @@ async def get_find_results(
|
|
172
197
|
ndb_client,
|
173
198
|
user,
|
174
199
|
origin,
|
175
|
-
generative_model=
|
200
|
+
generative_model=item.generative_model,
|
201
|
+
metrics=metrics,
|
176
202
|
)
|
177
203
|
if incomplete:
|
178
204
|
raise IncompleteFindResultsError()
|
@@ -180,269 +206,100 @@ async def get_find_results(
|
|
180
206
|
|
181
207
|
|
182
208
|
async def get_relations_results(
|
183
|
-
*,
|
209
|
+
*,
|
210
|
+
kbid: str,
|
211
|
+
text_answer: str,
|
212
|
+
target_shard_replicas: Optional[list[str]],
|
213
|
+
timeout: Optional[float] = None,
|
184
214
|
) -> Relations:
|
185
215
|
try:
|
186
216
|
predict = get_predict()
|
187
217
|
detected_entities = await predict.detect_entities(kbid, text_answer)
|
188
|
-
|
189
|
-
|
190
|
-
|
218
|
+
request = SearchRequest()
|
219
|
+
request.relation_subgraph.entry_points.extend(detected_entities)
|
220
|
+
request.relation_subgraph.depth = 1
|
191
221
|
|
192
|
-
|
222
|
+
results: list[SearchResponse]
|
193
223
|
(
|
194
|
-
|
224
|
+
results,
|
195
225
|
_,
|
196
226
|
_,
|
197
227
|
) = await node_query(
|
198
228
|
kbid,
|
199
|
-
Method.
|
200
|
-
|
229
|
+
Method.SEARCH,
|
230
|
+
request,
|
201
231
|
target_shard_replicas=target_shard_replicas,
|
232
|
+
timeout=timeout,
|
233
|
+
use_read_replica_nodes=True,
|
234
|
+
retry_on_primary=False,
|
202
235
|
)
|
203
|
-
|
204
|
-
|
205
|
-
)
|
236
|
+
relations_results: list[RelationSearchResponse] = [result.relation for result in results]
|
237
|
+
return await merge_relations_results(relations_results, request.relation_subgraph)
|
206
238
|
except Exception as exc:
|
207
239
|
capture_exception(exc)
|
208
240
|
logger.exception("Error getting relations results")
|
209
241
|
return Relations(entities={})
|
210
242
|
|
211
243
|
|
212
|
-
|
213
|
-
await asyncio.sleep(0)
|
214
|
-
yield NOT_ENOUGH_CONTEXT_ANSWER.encode()
|
215
|
-
yield AnswerStatusCode.NO_CONTEXT.encode()
|
216
|
-
|
217
|
-
|
218
|
-
async def chat(
|
219
|
-
kbid: str,
|
220
|
-
chat_request: ChatRequest,
|
221
|
-
user_id: str,
|
222
|
-
client_type: NucliaDBClientType,
|
223
|
-
origin: str,
|
224
|
-
resource: Optional[str] = None,
|
225
|
-
) -> ChatResult:
|
226
|
-
start_time = time()
|
227
|
-
nuclia_learning_id: Optional[str] = None
|
228
|
-
chat_history = chat_request.context or []
|
229
|
-
user_context = chat_request.extra_context or []
|
230
|
-
user_query = chat_request.query
|
231
|
-
rephrased_query = None
|
232
|
-
prompt_context: PromptContext = {}
|
233
|
-
prompt_context_order: PromptContextOrder = {}
|
234
|
-
|
235
|
-
if len(chat_history) > 0 or len(user_context) > 0:
|
236
|
-
rephrased_query = await rephrase_query(
|
237
|
-
kbid,
|
238
|
-
chat_history=chat_history,
|
239
|
-
query=user_query,
|
240
|
-
user_id=user_id,
|
241
|
-
user_context=user_context,
|
242
|
-
generative_model=chat_request.generative_model,
|
243
|
-
)
|
244
|
-
|
245
|
-
# Retrieval is not needed if we are chatting on a specific
|
246
|
-
# resource and the full_resource strategy is enabled
|
247
|
-
needs_retrieval = True
|
248
|
-
if resource is not None:
|
249
|
-
chat_request.resource_filters = [resource]
|
250
|
-
if any(
|
251
|
-
strategy.name == "full_resource" for strategy in chat_request.rag_strategies
|
252
|
-
):
|
253
|
-
needs_retrieval = False
|
254
|
-
|
255
|
-
if needs_retrieval:
|
256
|
-
find_results, query_parser = await get_find_results(
|
257
|
-
kbid=kbid,
|
258
|
-
query=rephrased_query or user_query,
|
259
|
-
chat_request=chat_request,
|
260
|
-
ndb_client=client_type,
|
261
|
-
user=user_id,
|
262
|
-
origin=origin,
|
263
|
-
)
|
264
|
-
status_code = FoundStatusCode()
|
265
|
-
if len(find_results.resources) == 0:
|
266
|
-
# If no resources were found on the retrieval, we return
|
267
|
-
# a "Not enough context" answer and skip the llm query
|
268
|
-
answer_stream = format_generated_answer(
|
269
|
-
not_enough_context_generator(), status_code
|
270
|
-
)
|
271
|
-
return ChatResult(
|
272
|
-
nuclia_learning_id=nuclia_learning_id,
|
273
|
-
answer_stream=answer_stream,
|
274
|
-
status_code=status_code,
|
275
|
-
find_results=find_results,
|
276
|
-
prompt_context=prompt_context,
|
277
|
-
prompt_context_order=prompt_context_order,
|
278
|
-
)
|
279
|
-
else:
|
280
|
-
status_code = FoundStatusCode()
|
281
|
-
find_results = KnowledgeboxFindResults(resources={}, min_score=None)
|
282
|
-
query_parser = QueryParser(
|
283
|
-
kbid=kbid,
|
284
|
-
features=[],
|
285
|
-
query="",
|
286
|
-
filters=chat_request.filters,
|
287
|
-
page_number=0,
|
288
|
-
page_size=0,
|
289
|
-
min_score=MinScore(),
|
290
|
-
)
|
291
|
-
|
292
|
-
query_parser.max_tokens = chat_request.max_tokens # type: ignore
|
293
|
-
max_tokens_context = await query_parser.get_max_tokens_context()
|
294
|
-
prompt_context_builder = PromptContextBuilder(
|
295
|
-
kbid=kbid,
|
296
|
-
find_results=find_results,
|
297
|
-
resource=resource,
|
298
|
-
user_context=user_context,
|
299
|
-
strategies=chat_request.rag_strategies,
|
300
|
-
image_strategies=chat_request.rag_images_strategies,
|
301
|
-
max_context_characters=tokens_to_chars(max_tokens_context),
|
302
|
-
visual_llm=await query_parser.get_visual_llm_enabled(),
|
303
|
-
)
|
304
|
-
(
|
305
|
-
prompt_context,
|
306
|
-
prompt_context_order,
|
307
|
-
prompt_context_images,
|
308
|
-
) = await prompt_context_builder.build()
|
309
|
-
user_prompt = None
|
310
|
-
if chat_request.prompt is not None:
|
311
|
-
user_prompt = UserPrompt(prompt=chat_request.prompt)
|
312
|
-
chat_model = ChatModel(
|
313
|
-
user_id=user_id,
|
314
|
-
query_context=prompt_context,
|
315
|
-
query_context_order=prompt_context_order,
|
316
|
-
chat_history=chat_history,
|
317
|
-
question=user_query,
|
318
|
-
truncate=True,
|
319
|
-
user_prompt=user_prompt,
|
320
|
-
citations=chat_request.citations,
|
321
|
-
generative_model=chat_request.generative_model,
|
322
|
-
max_tokens=query_parser.get_max_tokens_answer(),
|
323
|
-
query_context_images=prompt_context_images,
|
324
|
-
prefer_markdown=chat_request.prefer_markdown,
|
325
|
-
)
|
326
|
-
predict = get_predict()
|
327
|
-
nuclia_learning_id, predict_generator = await predict.chat_query(kbid, chat_model)
|
328
|
-
|
329
|
-
async def _wrapped_stream():
|
330
|
-
# so we can audit after streamed out answer
|
331
|
-
text_answer = b""
|
332
|
-
async for chunk in format_generated_answer(predict_generator, status_code):
|
333
|
-
text_answer += chunk
|
334
|
-
yield chunk
|
335
|
-
|
336
|
-
await maybe_audit_chat(
|
337
|
-
kbid=kbid,
|
338
|
-
user_id=user_id,
|
339
|
-
client_type=client_type,
|
340
|
-
origin=origin,
|
341
|
-
duration=time() - start_time,
|
342
|
-
user_query=user_query,
|
343
|
-
rephrased_query=rephrased_query,
|
344
|
-
text_answer=text_answer,
|
345
|
-
status_code=status_code.value,
|
346
|
-
chat_history=chat_history,
|
347
|
-
query_context=prompt_context,
|
348
|
-
query_context_order=prompt_context_order,
|
349
|
-
learning_id=nuclia_learning_id,
|
350
|
-
)
|
351
|
-
|
352
|
-
answer_stream = _wrapped_stream()
|
353
|
-
return ChatResult(
|
354
|
-
nuclia_learning_id=nuclia_learning_id,
|
355
|
-
answer_stream=answer_stream,
|
356
|
-
status_code=status_code,
|
357
|
-
find_results=find_results,
|
358
|
-
prompt_context=prompt_context,
|
359
|
-
prompt_context_order=prompt_context_order,
|
360
|
-
)
|
361
|
-
|
362
|
-
|
363
|
-
def _parse_answer_status_code(chunk: bytes) -> AnswerStatusCode:
|
364
|
-
"""
|
365
|
-
Parses the status code from the last chunk of the answer.
|
366
|
-
"""
|
367
|
-
try:
|
368
|
-
return AnswerStatusCode(chunk.decode())
|
369
|
-
except ValueError:
|
370
|
-
# In some cases, even if the status code was yield separately
|
371
|
-
# at the server side, the status code is appended to the previous chunk...
|
372
|
-
# It may be a bug in the aiohttp.StreamResponse implementation,
|
373
|
-
# but we haven't spotted it yet. For now, we just try to parse the status code
|
374
|
-
# from the tail of the chunk.
|
375
|
-
logger.debug(
|
376
|
-
f"Error decoding status code from /chat's last chunk. Chunk: {chunk!r}"
|
377
|
-
)
|
378
|
-
if chunk == b"":
|
379
|
-
raise
|
380
|
-
if chunk.endswith(b"0"):
|
381
|
-
return AnswerStatusCode.SUCCESS
|
382
|
-
return AnswerStatusCode(chunk[-2:].decode())
|
383
|
-
|
384
|
-
|
385
|
-
async def maybe_audit_chat(
|
244
|
+
def maybe_audit_chat(
|
386
245
|
*,
|
387
246
|
kbid: str,
|
388
247
|
user_id: str,
|
389
248
|
client_type: NucliaDBClientType,
|
390
249
|
origin: str,
|
391
|
-
|
250
|
+
generative_answer_time: float,
|
251
|
+
generative_answer_first_chunk_time: float,
|
252
|
+
rephrase_time: Optional[float],
|
392
253
|
user_query: str,
|
393
254
|
rephrased_query: Optional[str],
|
394
255
|
text_answer: bytes,
|
395
|
-
status_code:
|
256
|
+
status_code: AnswerStatusCode,
|
396
257
|
chat_history: list[ChatContextMessage],
|
397
258
|
query_context: PromptContext,
|
398
259
|
query_context_order: PromptContextOrder,
|
399
260
|
learning_id: str,
|
261
|
+
model: str,
|
400
262
|
):
|
401
263
|
audit = get_audit()
|
402
264
|
if audit is None:
|
403
265
|
return
|
404
266
|
|
405
267
|
audit_answer = parse_audit_answer(text_answer, status_code)
|
268
|
+
# Append chat history
|
269
|
+
chat_history_context = [
|
270
|
+
audit_pb2.ChatContext(author=message.author, text=message.text) for message in chat_history
|
271
|
+
]
|
406
272
|
|
407
|
-
# Append
|
408
|
-
|
409
|
-
audit_pb2.
|
410
|
-
for
|
273
|
+
# Append paragraphs retrieved on this chat
|
274
|
+
chat_retrieved_context = [
|
275
|
+
audit_pb2.RetrievedContext(text_block_id=paragraph_id, text=text)
|
276
|
+
for paragraph_id, text in query_context.items()
|
411
277
|
]
|
412
|
-
|
413
|
-
|
414
|
-
audit_pb2.ChatContext(
|
415
|
-
author=Author.NUCLIA,
|
416
|
-
text=AUDIT_TEXT_RESULT_SEP.join(query_context_paragaph_ids),
|
417
|
-
)
|
418
|
-
)
|
419
|
-
await audit.chat(
|
278
|
+
|
279
|
+
audit.chat(
|
420
280
|
kbid,
|
421
281
|
user_id,
|
422
|
-
client_type
|
282
|
+
to_proto.client_type(client_type),
|
423
283
|
origin,
|
424
|
-
duration,
|
425
284
|
question=user_query,
|
285
|
+
generative_answer_time=generative_answer_time,
|
286
|
+
generative_answer_first_chunk_time=generative_answer_first_chunk_time,
|
287
|
+
rephrase_time=rephrase_time,
|
426
288
|
rephrased_question=rephrased_query,
|
427
|
-
|
289
|
+
chat_context=chat_history_context,
|
290
|
+
retrieved_context=chat_retrieved_context,
|
428
291
|
answer=audit_answer,
|
429
292
|
learning_id=learning_id,
|
293
|
+
status_code=int(status_code.value),
|
294
|
+
model=model,
|
430
295
|
)
|
431
296
|
|
432
297
|
|
433
|
-
def parse_audit_answer(
|
434
|
-
raw_text_answer: bytes, status_code: Optional[AnswerStatusCode]
|
435
|
-
) -> Optional[str]:
|
298
|
+
def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
|
436
299
|
if status_code == AnswerStatusCode.NO_CONTEXT:
|
437
300
|
# We don't want to audit "Not enough context to answer this." and instead set a None.
|
438
301
|
return None
|
439
|
-
|
440
|
-
try:
|
441
|
-
raw_audit_answer, _ = raw_text_answer.split(START_OF_CITATIONS)
|
442
|
-
except ValueError:
|
443
|
-
raw_audit_answer = raw_text_answer
|
444
|
-
audit_answer = raw_audit_answer.decode()
|
445
|
-
return audit_answer
|
302
|
+
return raw_text_answer.decode()
|
446
303
|
|
447
304
|
|
448
305
|
def tokens_to_chars(n_tokens: int) -> int:
|
@@ -458,47 +315,55 @@ class ChatAuditor:
|
|
458
315
|
user_id: str,
|
459
316
|
client_type: NucliaDBClientType,
|
460
317
|
origin: str,
|
461
|
-
start_time: float,
|
462
318
|
user_query: str,
|
463
319
|
rephrased_query: Optional[str],
|
464
320
|
chat_history: list[ChatContextMessage],
|
465
321
|
learning_id: Optional[str],
|
466
322
|
query_context: PromptContext,
|
467
323
|
query_context_order: PromptContextOrder,
|
324
|
+
model: str,
|
468
325
|
):
|
469
326
|
self.kbid = kbid
|
470
327
|
self.user_id = user_id
|
471
328
|
self.client_type = client_type
|
472
329
|
self.origin = origin
|
473
|
-
self.start_time = start_time
|
474
330
|
self.user_query = user_query
|
475
331
|
self.rephrased_query = rephrased_query
|
476
332
|
self.chat_history = chat_history
|
477
333
|
self.learning_id = learning_id
|
478
334
|
self.query_context = query_context
|
479
335
|
self.query_context_order = query_context_order
|
336
|
+
self.model = model
|
480
337
|
|
481
|
-
|
482
|
-
|
338
|
+
def audit(
|
339
|
+
self,
|
340
|
+
text_answer: bytes,
|
341
|
+
generative_answer_time: float,
|
342
|
+
generative_answer_first_chunk_time: float,
|
343
|
+
rephrase_time: Optional[float],
|
344
|
+
status_code: AnswerStatusCode,
|
345
|
+
):
|
346
|
+
maybe_audit_chat(
|
483
347
|
kbid=self.kbid,
|
484
348
|
user_id=self.user_id,
|
485
349
|
client_type=self.client_type,
|
486
350
|
origin=self.origin,
|
487
|
-
duration=time() - self.start_time,
|
488
351
|
user_query=self.user_query,
|
489
352
|
rephrased_query=self.rephrased_query,
|
353
|
+
generative_answer_time=generative_answer_time,
|
354
|
+
generative_answer_first_chunk_time=generative_answer_first_chunk_time,
|
355
|
+
rephrase_time=rephrase_time,
|
490
356
|
text_answer=text_answer,
|
491
357
|
status_code=status_code,
|
492
358
|
chat_history=self.chat_history,
|
493
359
|
query_context=self.query_context,
|
494
360
|
query_context_order=self.query_context_order,
|
495
361
|
learning_id=self.learning_id or "unknown",
|
362
|
+
model=self.model,
|
496
363
|
)
|
497
364
|
|
498
365
|
|
499
|
-
def sorted_prompt_context_list(
|
500
|
-
context: PromptContext, order: PromptContextOrder
|
501
|
-
) -> list[str]:
|
366
|
+
def sorted_prompt_context_list(context: PromptContext, order: PromptContextOrder) -> list[str]:
|
502
367
|
"""
|
503
368
|
context = {"x": "foo", "y": "bar"}
|
504
369
|
order = {"y": 1, "x": 0}
|
@@ -509,3 +374,42 @@ def sorted_prompt_context_list(
|
|
509
374
|
key=lambda item: order.get(item[0], float("inf")),
|
510
375
|
)
|
511
376
|
return list(map(lambda item: item[1], sorted_items))
|
377
|
+
|
378
|
+
|
379
|
+
async def run_prequeries(
|
380
|
+
kbid: str,
|
381
|
+
prequeries: list[PreQuery],
|
382
|
+
x_ndb_client: NucliaDBClientType,
|
383
|
+
x_nucliadb_user: str,
|
384
|
+
x_forwarded_for: str,
|
385
|
+
generative_model: Optional[str] = None,
|
386
|
+
metrics: RAGMetrics = RAGMetrics(),
|
387
|
+
) -> list[PreQueryResult]:
|
388
|
+
"""
|
389
|
+
Runs simultaneous find requests for each prequery and returns the merged results according to the normalized weights.
|
390
|
+
"""
|
391
|
+
results: list[PreQueryResult] = []
|
392
|
+
max_parallel_prequeries = asyncio.Semaphore(settings.prequeries_max_parallel)
|
393
|
+
|
394
|
+
async def _prequery_find(
|
395
|
+
prequery: PreQuery,
|
396
|
+
):
|
397
|
+
async with max_parallel_prequeries:
|
398
|
+
find_results, _, _ = await find(
|
399
|
+
kbid,
|
400
|
+
prequery.request,
|
401
|
+
x_ndb_client,
|
402
|
+
x_nucliadb_user,
|
403
|
+
x_forwarded_for,
|
404
|
+
generative_model=generative_model,
|
405
|
+
metrics=metrics,
|
406
|
+
)
|
407
|
+
return prequery, find_results
|
408
|
+
|
409
|
+
ops = []
|
410
|
+
for prequery in prequeries:
|
411
|
+
ops.append(asyncio.create_task(_prequery_find(prequery)))
|
412
|
+
ops_results = await asyncio.gather(*ops)
|
413
|
+
for prequery, find_results in ops_results:
|
414
|
+
results.append((prequery, find_results))
|
415
|
+
return results
|