nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -18,67 +18,43 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
import asyncio
|
21
|
-
from
|
22
|
-
from time import monotonic as time
|
23
|
-
from typing import AsyncGenerator, AsyncIterator, Optional
|
24
|
-
|
25
|
-
from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
|
21
|
+
from typing import Optional
|
26
22
|
|
23
|
+
from nucliadb.common.models_utils import to_proto
|
27
24
|
from nucliadb.search import logger
|
28
25
|
from nucliadb.search.predict import AnswerStatusCode
|
29
26
|
from nucliadb.search.requesters.utils import Method, node_query
|
30
|
-
from nucliadb.search.search.chat.
|
27
|
+
from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
|
31
28
|
from nucliadb.search.search.exceptions import IncompleteFindResultsError
|
32
29
|
from nucliadb.search.search.find import find
|
33
30
|
from nucliadb.search.search.merge import merge_relations_results
|
31
|
+
from nucliadb.search.search.metrics import RAGMetrics
|
34
32
|
from nucliadb.search.search.query import QueryParser
|
33
|
+
from nucliadb.search.settings import settings
|
35
34
|
from nucliadb.search.utilities import get_predict
|
36
35
|
from nucliadb_models.search import (
|
37
|
-
|
36
|
+
AskRequest,
|
38
37
|
ChatContextMessage,
|
39
|
-
ChatModel,
|
40
38
|
ChatOptions,
|
41
|
-
ChatRequest,
|
42
39
|
FindRequest,
|
43
40
|
KnowledgeboxFindResults,
|
44
41
|
NucliaDBClientType,
|
42
|
+
PreQueriesStrategy,
|
43
|
+
PreQuery,
|
44
|
+
PreQueryResult,
|
45
45
|
PromptContext,
|
46
46
|
PromptContextOrder,
|
47
47
|
Relations,
|
48
48
|
RephraseModel,
|
49
49
|
SearchOptions,
|
50
|
-
|
50
|
+
parse_rephrase_prompt,
|
51
51
|
)
|
52
52
|
from nucliadb_protos import audit_pb2
|
53
|
+
from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
|
53
54
|
from nucliadb_telemetry.errors import capture_exception
|
54
|
-
from nucliadb_utils.helpers import async_gen_lookahead
|
55
55
|
from nucliadb_utils.utilities import get_audit
|
56
56
|
|
57
57
|
NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this."
|
58
|
-
AUDIT_TEXT_RESULT_SEP = " \n\n "
|
59
|
-
START_OF_CITATIONS = b"_CIT_"
|
60
|
-
|
61
|
-
|
62
|
-
class FoundStatusCode:
|
63
|
-
def __init__(self, default: AnswerStatusCode = AnswerStatusCode.SUCCESS):
|
64
|
-
self._value = AnswerStatusCode.SUCCESS
|
65
|
-
|
66
|
-
def set(self, value: AnswerStatusCode) -> None:
|
67
|
-
self._value = value
|
68
|
-
|
69
|
-
@property
|
70
|
-
def value(self) -> AnswerStatusCode:
|
71
|
-
return self._value
|
72
|
-
|
73
|
-
|
74
|
-
@dataclass
|
75
|
-
class ChatResult:
|
76
|
-
nuclia_learning_id: Optional[str]
|
77
|
-
answer_stream: AsyncIterator[bytes]
|
78
|
-
status_code: FoundStatusCode
|
79
|
-
find_results: KnowledgeboxFindResults
|
80
|
-
prompt_context: PromptContext
|
81
|
-
prompt_context_order: PromptContextOrder
|
82
58
|
|
83
59
|
|
84
60
|
async def rephrase_query(
|
@@ -100,70 +76,120 @@ async def rephrase_query(
|
|
100
76
|
return await predict.rephrase_query(kbid, req)
|
101
77
|
|
102
78
|
|
103
|
-
async def format_generated_answer(
|
104
|
-
answer_generator: AsyncGenerator[bytes, None], output_status_code: FoundStatusCode
|
105
|
-
):
|
106
|
-
status_code: Optional[AnswerStatusCode] = None
|
107
|
-
is_last_chunk = False
|
108
|
-
async for answer_chunk, is_last_chunk in async_gen_lookahead(answer_generator):
|
109
|
-
if is_last_chunk:
|
110
|
-
try:
|
111
|
-
status_code = _parse_answer_status_code(answer_chunk)
|
112
|
-
except ValueError:
|
113
|
-
# TODO: remove this in the future, it's
|
114
|
-
# just for bw compatibility until predict
|
115
|
-
# is updated to the new protocol
|
116
|
-
status_code = AnswerStatusCode.SUCCESS
|
117
|
-
yield answer_chunk
|
118
|
-
else:
|
119
|
-
# TODO: this should be needed but, in case we receive the status
|
120
|
-
# code mixed with text, we strip it and return the text
|
121
|
-
if len(answer_chunk) != len(status_code.encode()):
|
122
|
-
answer_chunk = answer_chunk.rstrip(status_code.encode())
|
123
|
-
yield answer_chunk
|
124
|
-
break
|
125
|
-
yield answer_chunk
|
126
|
-
if not is_last_chunk:
|
127
|
-
logger.warning("BUG: /chat endpoint without last chunk")
|
128
|
-
|
129
|
-
output_status_code.set(status_code or AnswerStatusCode.SUCCESS)
|
130
|
-
|
131
|
-
|
132
79
|
async def get_find_results(
|
133
80
|
*,
|
134
81
|
kbid: str,
|
135
82
|
query: str,
|
136
|
-
|
83
|
+
item: AskRequest,
|
137
84
|
ndb_client: NucliaDBClientType,
|
138
85
|
user: str,
|
139
86
|
origin: str,
|
87
|
+
metrics: RAGMetrics = RAGMetrics(),
|
88
|
+
prequeries_strategy: Optional[PreQueriesStrategy] = None,
|
89
|
+
) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], QueryParser]:
|
90
|
+
prequeries_results = None
|
91
|
+
prefilter_queries_results = None
|
92
|
+
queries_results = None
|
93
|
+
if prequeries_strategy is not None:
|
94
|
+
prefilters = [prequery for prequery in prequeries_strategy.queries if prequery.prefilter]
|
95
|
+
prequeries = [prequery for prequery in prequeries_strategy.queries if not prequery.prefilter]
|
96
|
+
if len(prefilters) > 0:
|
97
|
+
with metrics.time("prefilters"):
|
98
|
+
prefilter_queries_results = await run_prequeries(
|
99
|
+
kbid,
|
100
|
+
prefilters,
|
101
|
+
x_ndb_client=ndb_client,
|
102
|
+
x_nucliadb_user=user,
|
103
|
+
x_forwarded_for=origin,
|
104
|
+
generative_model=item.generative_model,
|
105
|
+
metrics=metrics,
|
106
|
+
)
|
107
|
+
prefilter_matching_resources = {
|
108
|
+
resource
|
109
|
+
for _, find_results in prefilter_queries_results
|
110
|
+
for resource in find_results.resources.keys()
|
111
|
+
}
|
112
|
+
if len(prefilter_matching_resources) == 0:
|
113
|
+
raise NoRetrievalResultsError()
|
114
|
+
# Make sure the main query and prequeries use the same resource filters.
|
115
|
+
# This is important to avoid returning results that don't match the prefilter.
|
116
|
+
item.resource_filters = list(prefilter_matching_resources)
|
117
|
+
for prequery in prequeries:
|
118
|
+
prequery.request.resource_filters = list(prefilter_matching_resources)
|
119
|
+
prequery.request.show_hidden = item.show_hidden
|
120
|
+
|
121
|
+
if prequeries:
|
122
|
+
with metrics.time("prequeries"):
|
123
|
+
queries_results = await run_prequeries(
|
124
|
+
kbid,
|
125
|
+
prequeries,
|
126
|
+
x_ndb_client=ndb_client,
|
127
|
+
x_nucliadb_user=user,
|
128
|
+
x_forwarded_for=origin,
|
129
|
+
generative_model=item.generative_model,
|
130
|
+
metrics=metrics,
|
131
|
+
)
|
132
|
+
|
133
|
+
prequeries_results = (prefilter_queries_results or []) + (queries_results or [])
|
134
|
+
|
135
|
+
with metrics.time("main_query"):
|
136
|
+
main_results, query_parser = await run_main_query(
|
137
|
+
kbid,
|
138
|
+
query,
|
139
|
+
item,
|
140
|
+
ndb_client,
|
141
|
+
user,
|
142
|
+
origin,
|
143
|
+
metrics=metrics,
|
144
|
+
)
|
145
|
+
return main_results, prequeries_results, query_parser
|
146
|
+
|
147
|
+
|
148
|
+
async def run_main_query(
|
149
|
+
kbid: str,
|
150
|
+
query: str,
|
151
|
+
item: AskRequest,
|
152
|
+
ndb_client: NucliaDBClientType,
|
153
|
+
user: str,
|
154
|
+
origin: str,
|
155
|
+
metrics: RAGMetrics = RAGMetrics(),
|
140
156
|
) -> tuple[KnowledgeboxFindResults, QueryParser]:
|
141
157
|
find_request = FindRequest()
|
142
|
-
find_request.resource_filters =
|
158
|
+
find_request.resource_filters = item.resource_filters
|
143
159
|
find_request.features = []
|
144
|
-
if ChatOptions.
|
145
|
-
find_request.features.append(SearchOptions.
|
146
|
-
if ChatOptions.
|
147
|
-
find_request.features.append(SearchOptions.
|
148
|
-
if ChatOptions.RELATIONS in
|
160
|
+
if ChatOptions.SEMANTIC in item.features:
|
161
|
+
find_request.features.append(SearchOptions.SEMANTIC)
|
162
|
+
if ChatOptions.KEYWORD in item.features:
|
163
|
+
find_request.features.append(SearchOptions.KEYWORD)
|
164
|
+
if ChatOptions.RELATIONS in item.features:
|
149
165
|
find_request.features.append(SearchOptions.RELATIONS)
|
150
166
|
find_request.query = query
|
151
|
-
find_request.fields =
|
152
|
-
find_request.filters =
|
153
|
-
find_request.field_type_filter =
|
154
|
-
find_request.min_score =
|
155
|
-
find_request.
|
156
|
-
find_request.
|
157
|
-
find_request.
|
158
|
-
find_request.
|
159
|
-
find_request.
|
160
|
-
find_request.
|
161
|
-
find_request.
|
162
|
-
find_request.
|
163
|
-
find_request.
|
164
|
-
find_request.
|
165
|
-
find_request.
|
166
|
-
find_request.
|
167
|
+
find_request.fields = item.fields
|
168
|
+
find_request.filters = item.filters
|
169
|
+
find_request.field_type_filter = item.field_type_filter
|
170
|
+
find_request.min_score = item.min_score
|
171
|
+
find_request.vectorset = item.vectorset
|
172
|
+
find_request.range_creation_start = item.range_creation_start
|
173
|
+
find_request.range_creation_end = item.range_creation_end
|
174
|
+
find_request.range_modification_start = item.range_modification_start
|
175
|
+
find_request.range_modification_end = item.range_modification_end
|
176
|
+
find_request.show = item.show
|
177
|
+
find_request.extracted = item.extracted
|
178
|
+
find_request.shards = item.shards
|
179
|
+
find_request.autofilter = item.autofilter
|
180
|
+
find_request.highlight = item.highlight
|
181
|
+
find_request.security = item.security
|
182
|
+
find_request.debug = item.debug
|
183
|
+
find_request.rephrase = item.rephrase
|
184
|
+
find_request.rephrase_prompt = parse_rephrase_prompt(item)
|
185
|
+
find_request.rank_fusion = item.rank_fusion
|
186
|
+
find_request.reranker = item.reranker
|
187
|
+
# We don't support pagination, we always get the top_k results.
|
188
|
+
find_request.top_k = item.top_k
|
189
|
+
find_request.show_hidden = item.show_hidden
|
190
|
+
|
191
|
+
# this executes the model validators, that can tweak some fields
|
192
|
+
FindRequest.model_validate(find_request)
|
167
193
|
|
168
194
|
find_results, incomplete, query_parser = await find(
|
169
195
|
kbid,
|
@@ -171,7 +197,8 @@ async def get_find_results(
|
|
171
197
|
ndb_client,
|
172
198
|
user,
|
173
199
|
origin,
|
174
|
-
generative_model=
|
200
|
+
generative_model=item.generative_model,
|
201
|
+
metrics=metrics,
|
175
202
|
)
|
176
203
|
if incomplete:
|
177
204
|
raise IncompleteFindResultsError()
|
@@ -179,230 +206,210 @@ async def get_find_results(
|
|
179
206
|
|
180
207
|
|
181
208
|
async def get_relations_results(
|
182
|
-
*,
|
209
|
+
*,
|
210
|
+
kbid: str,
|
211
|
+
text_answer: str,
|
212
|
+
target_shard_replicas: Optional[list[str]],
|
213
|
+
timeout: Optional[float] = None,
|
183
214
|
) -> Relations:
|
184
215
|
try:
|
185
216
|
predict = get_predict()
|
186
217
|
detected_entities = await predict.detect_entities(kbid, text_answer)
|
187
|
-
|
188
|
-
|
189
|
-
|
218
|
+
request = SearchRequest()
|
219
|
+
request.relation_subgraph.entry_points.extend(detected_entities)
|
220
|
+
request.relation_subgraph.depth = 1
|
190
221
|
|
191
|
-
|
222
|
+
results: list[SearchResponse]
|
192
223
|
(
|
193
|
-
|
224
|
+
results,
|
194
225
|
_,
|
195
226
|
_,
|
196
227
|
) = await node_query(
|
197
228
|
kbid,
|
198
|
-
Method.
|
199
|
-
|
200
|
-
target_shard_replicas=
|
201
|
-
|
202
|
-
|
203
|
-
|
229
|
+
Method.SEARCH,
|
230
|
+
request,
|
231
|
+
target_shard_replicas=target_shard_replicas,
|
232
|
+
timeout=timeout,
|
233
|
+
use_read_replica_nodes=True,
|
234
|
+
retry_on_primary=False,
|
204
235
|
)
|
236
|
+
relations_results: list[RelationSearchResponse] = [result.relation for result in results]
|
237
|
+
return await merge_relations_results(relations_results, request.relation_subgraph)
|
205
238
|
except Exception as exc:
|
206
239
|
capture_exception(exc)
|
207
240
|
logger.exception("Error getting relations results")
|
208
241
|
return Relations(entities={})
|
209
242
|
|
210
243
|
|
211
|
-
|
212
|
-
await asyncio.sleep(0)
|
213
|
-
yield NOT_ENOUGH_CONTEXT_ANSWER.encode()
|
214
|
-
yield AnswerStatusCode.NO_CONTEXT.encode()
|
215
|
-
|
216
|
-
|
217
|
-
async def chat(
|
218
|
-
kbid: str,
|
219
|
-
chat_request: ChatRequest,
|
220
|
-
user_id: str,
|
221
|
-
client_type: NucliaDBClientType,
|
222
|
-
origin: str,
|
223
|
-
) -> ChatResult:
|
224
|
-
start_time = time()
|
225
|
-
nuclia_learning_id: Optional[str] = None
|
226
|
-
chat_history = chat_request.context or []
|
227
|
-
user_context = chat_request.extra_context or []
|
228
|
-
user_query = chat_request.query
|
229
|
-
rephrased_query = None
|
230
|
-
prompt_context: PromptContext = {}
|
231
|
-
prompt_context_order: PromptContextOrder = {}
|
232
|
-
|
233
|
-
if len(chat_history) > 0 or len(user_context) > 0:
|
234
|
-
rephrased_query = await rephrase_query(
|
235
|
-
kbid,
|
236
|
-
chat_history=chat_history,
|
237
|
-
query=user_query,
|
238
|
-
user_id=user_id,
|
239
|
-
user_context=user_context,
|
240
|
-
generative_model=chat_request.generative_model,
|
241
|
-
)
|
242
|
-
|
243
|
-
find_results, query_parser = await get_find_results(
|
244
|
-
kbid=kbid,
|
245
|
-
query=rephrased_query or user_query,
|
246
|
-
chat_request=chat_request,
|
247
|
-
ndb_client=client_type,
|
248
|
-
user=user_id,
|
249
|
-
origin=origin,
|
250
|
-
)
|
251
|
-
|
252
|
-
status_code = FoundStatusCode()
|
253
|
-
if len(find_results.resources) == 0:
|
254
|
-
answer_stream = format_generated_answer(
|
255
|
-
not_enough_context_generator(), status_code
|
256
|
-
)
|
257
|
-
else:
|
258
|
-
prompt_context_builder = PromptContextBuilder(
|
259
|
-
kbid=kbid,
|
260
|
-
find_results=find_results,
|
261
|
-
user_context=user_context,
|
262
|
-
strategies=chat_request.rag_strategies,
|
263
|
-
image_strategies=chat_request.rag_images_strategies,
|
264
|
-
max_context_size=await query_parser.get_max_context(),
|
265
|
-
visual_llm=await query_parser.get_visual_llm_enabled(),
|
266
|
-
)
|
267
|
-
(
|
268
|
-
prompt_context,
|
269
|
-
prompt_context_order,
|
270
|
-
prompt_context_images,
|
271
|
-
) = await prompt_context_builder.build()
|
272
|
-
user_prompt = None
|
273
|
-
if chat_request.prompt is not None:
|
274
|
-
user_prompt = UserPrompt(prompt=chat_request.prompt)
|
275
|
-
|
276
|
-
chat_model = ChatModel(
|
277
|
-
user_id=user_id,
|
278
|
-
query_context=prompt_context,
|
279
|
-
query_context_order=prompt_context_order,
|
280
|
-
chat_history=chat_history,
|
281
|
-
question=user_query,
|
282
|
-
truncate=True,
|
283
|
-
user_prompt=user_prompt,
|
284
|
-
citations=chat_request.citations,
|
285
|
-
generative_model=chat_request.generative_model,
|
286
|
-
max_tokens=chat_request.max_tokens,
|
287
|
-
query_context_images=prompt_context_images,
|
288
|
-
)
|
289
|
-
predict = get_predict()
|
290
|
-
nuclia_learning_id, predict_generator = await predict.chat_query(
|
291
|
-
kbid, chat_model
|
292
|
-
)
|
293
|
-
|
294
|
-
async def _wrapped_stream():
|
295
|
-
# so we can audit after streamed out answer
|
296
|
-
text_answer = b""
|
297
|
-
async for chunk in format_generated_answer(predict_generator, status_code):
|
298
|
-
text_answer += chunk
|
299
|
-
yield chunk
|
300
|
-
|
301
|
-
await maybe_audit_chat(
|
302
|
-
kbid=kbid,
|
303
|
-
user_id=user_id,
|
304
|
-
client_type=client_type,
|
305
|
-
origin=origin,
|
306
|
-
duration=time() - start_time,
|
307
|
-
user_query=user_query,
|
308
|
-
rephrased_query=rephrased_query,
|
309
|
-
text_answer=text_answer,
|
310
|
-
status_code=status_code.value,
|
311
|
-
chat_history=chat_history,
|
312
|
-
query_context=prompt_context,
|
313
|
-
learning_id=nuclia_learning_id,
|
314
|
-
)
|
315
|
-
|
316
|
-
answer_stream = _wrapped_stream()
|
317
|
-
|
318
|
-
return ChatResult(
|
319
|
-
nuclia_learning_id=nuclia_learning_id,
|
320
|
-
answer_stream=answer_stream,
|
321
|
-
status_code=status_code,
|
322
|
-
find_results=find_results,
|
323
|
-
prompt_context=prompt_context,
|
324
|
-
prompt_context_order=prompt_context_order,
|
325
|
-
)
|
326
|
-
|
327
|
-
|
328
|
-
def _parse_answer_status_code(chunk: bytes) -> AnswerStatusCode:
|
329
|
-
"""
|
330
|
-
Parses the status code from the last chunk of the answer.
|
331
|
-
"""
|
332
|
-
try:
|
333
|
-
return AnswerStatusCode(chunk.decode())
|
334
|
-
except ValueError:
|
335
|
-
# In some cases, even if the status code was yield separately
|
336
|
-
# at the server side, the status code is appended to the previous chunk...
|
337
|
-
# It may be a bug in the aiohttp.StreamResponse implementation,
|
338
|
-
# but we haven't spotted it yet. For now, we just try to parse the status code
|
339
|
-
# from the tail of the chunk.
|
340
|
-
logger.debug(
|
341
|
-
f"Error decoding status code from /chat's last chunk. Chunk: {chunk!r}"
|
342
|
-
)
|
343
|
-
if chunk == b"":
|
344
|
-
raise
|
345
|
-
if chunk.endswith(b"0"):
|
346
|
-
return AnswerStatusCode.SUCCESS
|
347
|
-
return AnswerStatusCode(chunk[-2:].decode())
|
348
|
-
|
349
|
-
|
350
|
-
async def maybe_audit_chat(
|
244
|
+
def maybe_audit_chat(
|
351
245
|
*,
|
352
246
|
kbid: str,
|
353
247
|
user_id: str,
|
354
248
|
client_type: NucliaDBClientType,
|
355
249
|
origin: str,
|
356
|
-
|
250
|
+
generative_answer_time: float,
|
251
|
+
generative_answer_first_chunk_time: float,
|
252
|
+
rephrase_time: Optional[float],
|
357
253
|
user_query: str,
|
358
254
|
rephrased_query: Optional[str],
|
359
255
|
text_answer: bytes,
|
360
|
-
status_code:
|
256
|
+
status_code: AnswerStatusCode,
|
361
257
|
chat_history: list[ChatContextMessage],
|
362
|
-
query_context:
|
258
|
+
query_context: PromptContext,
|
259
|
+
query_context_order: PromptContextOrder,
|
363
260
|
learning_id: str,
|
261
|
+
model: str,
|
364
262
|
):
|
365
263
|
audit = get_audit()
|
366
264
|
if audit is None:
|
367
265
|
return
|
368
266
|
|
369
267
|
audit_answer = parse_audit_answer(text_answer, status_code)
|
268
|
+
# Append chat history
|
269
|
+
chat_history_context = [
|
270
|
+
audit_pb2.ChatContext(author=message.author, text=message.text) for message in chat_history
|
271
|
+
]
|
370
272
|
|
371
|
-
# Append
|
372
|
-
|
373
|
-
audit_pb2.
|
374
|
-
for
|
273
|
+
# Append paragraphs retrieved on this chat
|
274
|
+
chat_retrieved_context = [
|
275
|
+
audit_pb2.RetrievedContext(text_block_id=paragraph_id, text=text)
|
276
|
+
for paragraph_id, text in query_context.items()
|
375
277
|
]
|
376
|
-
|
377
|
-
|
378
|
-
author=Author.NUCLIA,
|
379
|
-
text=AUDIT_TEXT_RESULT_SEP.join(query_context),
|
380
|
-
)
|
381
|
-
)
|
382
|
-
await audit.chat(
|
278
|
+
|
279
|
+
audit.chat(
|
383
280
|
kbid,
|
384
281
|
user_id,
|
385
|
-
client_type
|
282
|
+
to_proto.client_type(client_type),
|
386
283
|
origin,
|
387
|
-
duration,
|
388
284
|
question=user_query,
|
285
|
+
generative_answer_time=generative_answer_time,
|
286
|
+
generative_answer_first_chunk_time=generative_answer_first_chunk_time,
|
287
|
+
rephrase_time=rephrase_time,
|
389
288
|
rephrased_question=rephrased_query,
|
390
|
-
|
289
|
+
chat_context=chat_history_context,
|
290
|
+
retrieved_context=chat_retrieved_context,
|
391
291
|
answer=audit_answer,
|
392
292
|
learning_id=learning_id,
|
293
|
+
status_code=int(status_code.value),
|
294
|
+
model=model,
|
393
295
|
)
|
394
296
|
|
395
297
|
|
396
|
-
def parse_audit_answer(
|
397
|
-
raw_text_answer: bytes, status_code: Optional[AnswerStatusCode]
|
398
|
-
) -> Optional[str]:
|
298
|
+
def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
|
399
299
|
if status_code == AnswerStatusCode.NO_CONTEXT:
|
400
300
|
# We don't want to audit "Not enough context to answer this." and instead set a None.
|
401
301
|
return None
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
return
|
302
|
+
return raw_text_answer.decode()
|
303
|
+
|
304
|
+
|
305
|
+
def tokens_to_chars(n_tokens: int) -> int:
|
306
|
+
# Multiply by 3 to have a good margin and guess between characters and tokens.
|
307
|
+
# This will be properly cut at the NUA predict API.
|
308
|
+
return n_tokens * 3
|
309
|
+
|
310
|
+
|
311
|
+
class ChatAuditor:
|
312
|
+
def __init__(
|
313
|
+
self,
|
314
|
+
kbid: str,
|
315
|
+
user_id: str,
|
316
|
+
client_type: NucliaDBClientType,
|
317
|
+
origin: str,
|
318
|
+
user_query: str,
|
319
|
+
rephrased_query: Optional[str],
|
320
|
+
chat_history: list[ChatContextMessage],
|
321
|
+
learning_id: Optional[str],
|
322
|
+
query_context: PromptContext,
|
323
|
+
query_context_order: PromptContextOrder,
|
324
|
+
model: str,
|
325
|
+
):
|
326
|
+
self.kbid = kbid
|
327
|
+
self.user_id = user_id
|
328
|
+
self.client_type = client_type
|
329
|
+
self.origin = origin
|
330
|
+
self.user_query = user_query
|
331
|
+
self.rephrased_query = rephrased_query
|
332
|
+
self.chat_history = chat_history
|
333
|
+
self.learning_id = learning_id
|
334
|
+
self.query_context = query_context
|
335
|
+
self.query_context_order = query_context_order
|
336
|
+
self.model = model
|
337
|
+
|
338
|
+
def audit(
|
339
|
+
self,
|
340
|
+
text_answer: bytes,
|
341
|
+
generative_answer_time: float,
|
342
|
+
generative_answer_first_chunk_time: float,
|
343
|
+
rephrase_time: Optional[float],
|
344
|
+
status_code: AnswerStatusCode,
|
345
|
+
):
|
346
|
+
maybe_audit_chat(
|
347
|
+
kbid=self.kbid,
|
348
|
+
user_id=self.user_id,
|
349
|
+
client_type=self.client_type,
|
350
|
+
origin=self.origin,
|
351
|
+
user_query=self.user_query,
|
352
|
+
rephrased_query=self.rephrased_query,
|
353
|
+
generative_answer_time=generative_answer_time,
|
354
|
+
generative_answer_first_chunk_time=generative_answer_first_chunk_time,
|
355
|
+
rephrase_time=rephrase_time,
|
356
|
+
text_answer=text_answer,
|
357
|
+
status_code=status_code,
|
358
|
+
chat_history=self.chat_history,
|
359
|
+
query_context=self.query_context,
|
360
|
+
query_context_order=self.query_context_order,
|
361
|
+
learning_id=self.learning_id or "unknown",
|
362
|
+
model=self.model,
|
363
|
+
)
|
364
|
+
|
365
|
+
|
366
|
+
def sorted_prompt_context_list(context: PromptContext, order: PromptContextOrder) -> list[str]:
|
367
|
+
"""
|
368
|
+
context = {"x": "foo", "y": "bar"}
|
369
|
+
order = {"y": 1, "x": 0}
|
370
|
+
sorted_prompt_context_list(context, order) == ["foo", "bar"]
|
371
|
+
"""
|
372
|
+
sorted_items = sorted(
|
373
|
+
context.items(),
|
374
|
+
key=lambda item: order.get(item[0], float("inf")),
|
375
|
+
)
|
376
|
+
return list(map(lambda item: item[1], sorted_items))
|
377
|
+
|
378
|
+
|
379
|
+
async def run_prequeries(
|
380
|
+
kbid: str,
|
381
|
+
prequeries: list[PreQuery],
|
382
|
+
x_ndb_client: NucliaDBClientType,
|
383
|
+
x_nucliadb_user: str,
|
384
|
+
x_forwarded_for: str,
|
385
|
+
generative_model: Optional[str] = None,
|
386
|
+
metrics: RAGMetrics = RAGMetrics(),
|
387
|
+
) -> list[PreQueryResult]:
|
388
|
+
"""
|
389
|
+
Runs simultaneous find requests for each prequery and returns the merged results according to the normalized weights.
|
390
|
+
"""
|
391
|
+
results: list[PreQueryResult] = []
|
392
|
+
max_parallel_prequeries = asyncio.Semaphore(settings.prequeries_max_parallel)
|
393
|
+
|
394
|
+
async def _prequery_find(
|
395
|
+
prequery: PreQuery,
|
396
|
+
):
|
397
|
+
async with max_parallel_prequeries:
|
398
|
+
find_results, _, _ = await find(
|
399
|
+
kbid,
|
400
|
+
prequery.request,
|
401
|
+
x_ndb_client,
|
402
|
+
x_nucliadb_user,
|
403
|
+
x_forwarded_for,
|
404
|
+
generative_model=generative_model,
|
405
|
+
metrics=metrics,
|
406
|
+
)
|
407
|
+
return prequery, find_results
|
408
|
+
|
409
|
+
ops = []
|
410
|
+
for prequery in prequeries:
|
411
|
+
ops.append(asyncio.create_task(_prequery_find(prequery)))
|
412
|
+
ops_results = await asyncio.gather(*ops)
|
413
|
+
for prequery, find_results in ops_results:
|
414
|
+
results.append((prequery, find_results))
|
415
|
+
return results
|
@@ -18,11 +18,13 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
|
21
|
-
from
|
22
|
-
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
|
21
|
+
from typing import TypeVar
|
23
22
|
|
23
|
+
T = TypeVar("T")
|
24
24
|
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
25
|
+
|
26
|
+
def cut_page(items: list[T], top_k: int) -> tuple[list[T], bool]:
|
27
|
+
"""Return a slice of `items` representing the specified page and a boolean
|
28
|
+
indicating whether there is a next page or not"""
|
29
|
+
next_page = len(items) > top_k
|
30
|
+
return items[:top_k], next_page
|