nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
nucliadb/search/predict.py
CHANGED
@@ -19,31 +19,37 @@
|
|
19
19
|
#
|
20
20
|
import json
|
21
21
|
import os
|
22
|
+
import random
|
22
23
|
from enum import Enum
|
23
|
-
from typing import AsyncIterator, Optional
|
24
|
+
from typing import Any, AsyncIterator, Optional
|
24
25
|
from unittest.mock import AsyncMock, Mock
|
25
26
|
|
26
27
|
import aiohttp
|
27
28
|
import backoff
|
28
|
-
from
|
29
|
+
from nuclia_models.predict.generative_responses import GenerativeChunk
|
30
|
+
from pydantic import ValidationError
|
29
31
|
|
30
|
-
from nucliadb.
|
32
|
+
from nucliadb.common import datamanagers
|
31
33
|
from nucliadb.search import logger
|
32
|
-
from
|
33
|
-
|
34
|
-
ChatModel,
|
35
|
-
FeedbackRequest,
|
34
|
+
from nucliadb.tests.vectors import Q, Qm2023
|
35
|
+
from nucliadb_models.internal.predict import (
|
36
36
|
Ner,
|
37
37
|
QueryInfo,
|
38
|
-
|
38
|
+
RerankModel,
|
39
|
+
RerankResponse,
|
39
40
|
SentenceSearch,
|
41
|
+
TokenSearch,
|
42
|
+
)
|
43
|
+
from nucliadb_models.search import (
|
44
|
+
ChatModel,
|
45
|
+
RephraseModel,
|
40
46
|
SummarizedResource,
|
41
47
|
SummarizedResponse,
|
42
48
|
SummarizeModel,
|
43
|
-
TokenSearch,
|
44
49
|
)
|
45
|
-
from
|
46
|
-
from
|
50
|
+
from nucliadb_protos.utils_pb2 import RelationNode
|
51
|
+
from nucliadb_telemetry import errors, metrics
|
52
|
+
from nucliadb_utils.const import Features
|
47
53
|
from nucliadb_utils.exceptions import LimitsExceededError
|
48
54
|
from nucliadb_utils.settings import nuclia_settings
|
49
55
|
from nucliadb_utils.utilities import Utility, has_feature, set_utility
|
@@ -59,10 +65,6 @@ class ProxiedPredictAPIError(Exception):
|
|
59
65
|
self.detail = detail
|
60
66
|
|
61
67
|
|
62
|
-
class PredictVectorMissing(Exception):
|
63
|
-
pass
|
64
|
-
|
65
|
-
|
66
68
|
class NUAKeyMissingError(Exception):
|
67
69
|
pass
|
68
70
|
|
@@ -77,13 +79,12 @@ class RephraseMissingContextError(Exception):
|
|
77
79
|
|
78
80
|
DUMMY_RELATION_NODE = [
|
79
81
|
RelationNode(value="Ferran", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
|
80
|
-
RelationNode(
|
81
|
-
value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"
|
82
|
-
),
|
82
|
+
RelationNode(value="Joan Antoni", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
|
83
83
|
]
|
84
84
|
|
85
85
|
DUMMY_REPHRASE_QUERY = "This is a rephrased query"
|
86
86
|
DUMMY_LEARNING_ID = "00"
|
87
|
+
DUMMY_LEARNING_MODEL = "chatgpt"
|
87
88
|
|
88
89
|
|
89
90
|
PUBLIC_PREDICT = "/api/v1/predict"
|
@@ -94,11 +95,12 @@ TOKENS = "/tokens"
|
|
94
95
|
QUERY = "/query"
|
95
96
|
SUMMARIZE = "/summarize"
|
96
97
|
CHAT = "/chat"
|
97
|
-
ASK_DOCUMENT = "/ask_document"
|
98
98
|
REPHRASE = "/rephrase"
|
99
99
|
FEEDBACK = "/feedback"
|
100
|
+
RERANK = "/rerank"
|
100
101
|
|
101
102
|
NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
|
103
|
+
NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"
|
102
104
|
|
103
105
|
|
104
106
|
predict_observer = metrics.Observer(
|
@@ -107,7 +109,6 @@ predict_observer = metrics.Observer(
|
|
107
109
|
error_mappings={
|
108
110
|
"over_limits": LimitsExceededError,
|
109
111
|
"predict_api_error": SendToPredictError,
|
110
|
-
"empty_vectors": PredictVectorMissing,
|
111
112
|
},
|
112
113
|
)
|
113
114
|
|
@@ -121,6 +122,13 @@ class AnswerStatusCode(str, Enum):
|
|
121
122
|
ERROR = "-1"
|
122
123
|
NO_CONTEXT = "-2"
|
123
124
|
|
125
|
+
def prettify(self) -> str:
|
126
|
+
return {
|
127
|
+
AnswerStatusCode.SUCCESS: "success",
|
128
|
+
AnswerStatusCode.ERROR: "error",
|
129
|
+
AnswerStatusCode.NO_CONTEXT: "no_context",
|
130
|
+
}[self]
|
131
|
+
|
124
132
|
|
125
133
|
async def start_predict_engine():
|
126
134
|
if nuclia_settings.dummy_predict:
|
@@ -144,9 +152,7 @@ def convert_relations(data: dict[str, list[dict[str, str]]]) -> list[RelationNod
|
|
144
152
|
for token in data["tokens"]:
|
145
153
|
text = token["text"]
|
146
154
|
klass = token["ner"]
|
147
|
-
result.append(
|
148
|
-
RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass)
|
149
|
-
)
|
155
|
+
result.append(RelationNode(value=text, ntype=RelationNode.NodeType.ENTITY, subtype=klass))
|
150
156
|
return result
|
151
157
|
|
152
158
|
|
@@ -179,9 +185,7 @@ class PredictEngine:
|
|
179
185
|
await self.session.close()
|
180
186
|
|
181
187
|
def check_nua_key_is_configured_for_onprem(self):
|
182
|
-
if self.onprem and (
|
183
|
-
self.nuclia_service_account is None and self.local_predict is False
|
184
|
-
):
|
188
|
+
if self.onprem and (self.nuclia_service_account is None and self.local_predict is False):
|
185
189
|
raise NUAKeyMissingError()
|
186
190
|
|
187
191
|
def get_predict_url(self, endpoint: str, kbid: str) -> str:
|
@@ -193,7 +197,7 @@ class PredictEngine:
|
|
193
197
|
# /api/v1/predict/rephrase/{kbid}
|
194
198
|
return f"{self.public_url}{PUBLIC_PREDICT}{endpoint}/{kbid}"
|
195
199
|
else:
|
196
|
-
if has_feature(
|
200
|
+
if has_feature(Features.VERSIONED_PRIVATE_PREDICT):
|
197
201
|
return f"{self.cluster_url}{VERSIONED_PRIVATE_PREDICT}{endpoint}"
|
198
202
|
else:
|
199
203
|
return f"{self.cluster_url}{PRIVATE_PREDICT}{endpoint}"
|
@@ -207,16 +211,13 @@ class PredictEngine:
|
|
207
211
|
else:
|
208
212
|
return {"X-STF-KBID": kbid}
|
209
213
|
|
210
|
-
async def check_response(
|
211
|
-
self, resp: aiohttp.ClientResponse, expected_status: int = 200
|
212
|
-
) -> None:
|
214
|
+
async def check_response(self, resp: aiohttp.ClientResponse, expected_status: int = 200) -> None:
|
213
215
|
if resp.status == expected_status:
|
214
216
|
return
|
215
217
|
|
216
218
|
if resp.status == 402:
|
217
219
|
data = await resp.json()
|
218
220
|
raise LimitsExceededError(402, data["detail"])
|
219
|
-
|
220
221
|
try:
|
221
222
|
data = await resp.json()
|
222
223
|
try:
|
@@ -228,7 +229,10 @@ class PredictEngine:
|
|
228
229
|
aiohttp.client_exceptions.ContentTypeError,
|
229
230
|
):
|
230
231
|
detail = await resp.text()
|
231
|
-
|
232
|
+
if str(resp.status).startswith("5"):
|
233
|
+
logger.error(f"Predict API error at {resp.url}: {detail}")
|
234
|
+
else:
|
235
|
+
logger.info(f"Predict API error at {resp.url}: {detail}")
|
232
236
|
raise ProxiedPredictAPIError(status=resp.status, detail=detail)
|
233
237
|
|
234
238
|
@backoff.on_exception(
|
@@ -241,36 +245,6 @@ class PredictEngine:
|
|
241
245
|
func = getattr(self.session, method.lower())
|
242
246
|
return await func(**request_args)
|
243
247
|
|
244
|
-
@predict_observer.wrap({"type": "feedback"})
|
245
|
-
async def send_feedback(
|
246
|
-
self,
|
247
|
-
kbid: str,
|
248
|
-
item: FeedbackRequest,
|
249
|
-
x_nucliadb_user: str,
|
250
|
-
x_ndb_client: str,
|
251
|
-
x_forwarded_for: str,
|
252
|
-
):
|
253
|
-
try:
|
254
|
-
self.check_nua_key_is_configured_for_onprem()
|
255
|
-
except NUAKeyMissingError:
|
256
|
-
logger.warning(
|
257
|
-
"Nuclia Service account is not defined so could not send the feedback"
|
258
|
-
)
|
259
|
-
return
|
260
|
-
|
261
|
-
data = item.dict()
|
262
|
-
data["user_id"] = x_nucliadb_user
|
263
|
-
data["client"] = x_ndb_client
|
264
|
-
data["forwarded"] = x_forwarded_for
|
265
|
-
|
266
|
-
resp = await self.make_request(
|
267
|
-
"POST",
|
268
|
-
url=self.get_predict_url(FEEDBACK, kbid),
|
269
|
-
json=data,
|
270
|
-
headers=self.get_predict_headers(kbid),
|
271
|
-
)
|
272
|
-
await self.check_response(resp, expected_status=204)
|
273
|
-
|
274
248
|
@predict_observer.wrap({"type": "rephrase"})
|
275
249
|
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
276
250
|
try:
|
@@ -283,16 +257,20 @@ class PredictEngine:
|
|
283
257
|
resp = await self.make_request(
|
284
258
|
"POST",
|
285
259
|
url=self.get_predict_url(REPHRASE, kbid),
|
286
|
-
json=item.
|
260
|
+
json=item.model_dump(),
|
287
261
|
headers=self.get_predict_headers(kbid),
|
288
262
|
)
|
289
263
|
await self.check_response(resp, expected_status=200)
|
290
264
|
return await _parse_rephrase_response(resp)
|
291
265
|
|
292
|
-
@predict_observer.wrap({"type": "
|
293
|
-
async def
|
266
|
+
@predict_observer.wrap({"type": "chat_ndjson"})
|
267
|
+
async def chat_query_ndjson(
|
294
268
|
self, kbid: str, item: ChatModel
|
295
|
-
) -> tuple[str, AsyncIterator[
|
269
|
+
) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
|
270
|
+
"""
|
271
|
+
Chat query using the new stream format
|
272
|
+
Format specs: https://github.com/ndjson/ndjson-spec
|
273
|
+
"""
|
296
274
|
try:
|
297
275
|
self.check_nua_key_is_configured_for_onprem()
|
298
276
|
except NUAKeyMissingError:
|
@@ -300,60 +278,62 @@ class PredictEngine:
|
|
300
278
|
logger.warning(error)
|
301
279
|
raise SendToPredictError(error)
|
302
280
|
|
281
|
+
# The ndjson format is triggered by the Accept header
|
282
|
+
headers = self.get_predict_headers(kbid)
|
283
|
+
headers["Accept"] = "application/x-ndjson"
|
284
|
+
|
303
285
|
resp = await self.make_request(
|
304
286
|
"POST",
|
305
287
|
url=self.get_predict_url(CHAT, kbid),
|
306
|
-
json=item.
|
307
|
-
headers=
|
288
|
+
json=item.model_dump(),
|
289
|
+
headers=headers,
|
308
290
|
timeout=None,
|
309
291
|
)
|
310
292
|
await self.check_response(resp, expected_status=200)
|
311
293
|
ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
|
312
|
-
|
313
|
-
|
314
|
-
@predict_observer.wrap({"type": "ask_document"})
|
315
|
-
async def ask_document(
|
316
|
-
self, kbid: str, question: str, blocks: list[list[str]], user_id: str
|
317
|
-
) -> str:
|
318
|
-
try:
|
319
|
-
self.check_nua_key_is_configured_for_onprem()
|
320
|
-
except NUAKeyMissingError:
|
321
|
-
error = "Nuclia Service account is not defined so could not ask document"
|
322
|
-
logger.warning(error)
|
323
|
-
raise SendToPredictError(error)
|
324
|
-
|
325
|
-
item = AskDocumentModel(question=question, blocks=blocks, user_id=user_id)
|
326
|
-
resp = await self.make_request(
|
327
|
-
"POST",
|
328
|
-
url=self.get_predict_url(ASK_DOCUMENT, kbid),
|
329
|
-
json=item.dict(),
|
330
|
-
headers=self.get_predict_headers(kbid),
|
331
|
-
timeout=None,
|
332
|
-
)
|
333
|
-
await self.check_response(resp, expected_status=200)
|
334
|
-
return await resp.text()
|
294
|
+
model = resp.headers.get(NUCLIA_LEARNING_MODEL_HEADER)
|
295
|
+
return ident, model, get_chat_ndjson_generator(resp)
|
335
296
|
|
336
297
|
@predict_observer.wrap({"type": "query"})
|
337
298
|
async def query(
|
338
299
|
self,
|
339
300
|
kbid: str,
|
340
301
|
sentence: str,
|
302
|
+
semantic_model: Optional[str] = None,
|
341
303
|
generative_model: Optional[str] = None,
|
342
|
-
rephrase:
|
304
|
+
rephrase: bool = False,
|
305
|
+
rephrase_prompt: Optional[str] = None,
|
343
306
|
) -> QueryInfo:
|
307
|
+
"""
|
308
|
+
Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance:
|
309
|
+
- The embeddings
|
310
|
+
- The entities
|
311
|
+
- The stop words
|
312
|
+
- The semantic threshold
|
313
|
+
- etc.
|
314
|
+
|
315
|
+
:param kbid: KnowledgeBox ID
|
316
|
+
:param sentence: The query sentence
|
317
|
+
:param semantic_model: The semantic model to use to generate the embeddings
|
318
|
+
:param generative_model: The generative model that will be used to generate the answer
|
319
|
+
:param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval
|
320
|
+
:param rephrase_prompt: Custom prompt to use for rephrasing
|
321
|
+
"""
|
344
322
|
try:
|
345
323
|
self.check_nua_key_is_configured_for_onprem()
|
346
324
|
except NUAKeyMissingError:
|
347
|
-
error =
|
348
|
-
"Nuclia Service account is not defined so could not ask query endpoint"
|
349
|
-
)
|
325
|
+
error = "Nuclia Service account is not defined so could not ask query endpoint"
|
350
326
|
logger.warning(error)
|
351
327
|
raise SendToPredictError(error)
|
352
328
|
|
353
|
-
params = {
|
329
|
+
params: dict[str, Any] = {
|
354
330
|
"text": sentence,
|
355
331
|
"rephrase": str(rephrase),
|
356
332
|
}
|
333
|
+
if rephrase_prompt is not None:
|
334
|
+
params["rephrase_prompt"] = rephrase_prompt
|
335
|
+
if semantic_model is not None:
|
336
|
+
params["semantic_models"] = [semantic_model]
|
357
337
|
if generative_model is not None:
|
358
338
|
params["generative_model"] = generative_model
|
359
339
|
|
@@ -367,28 +347,6 @@ class PredictEngine:
|
|
367
347
|
data = await resp.json()
|
368
348
|
return QueryInfo(**data)
|
369
349
|
|
370
|
-
@predict_observer.wrap({"type": "sentence"})
|
371
|
-
async def convert_sentence_to_vector(self, kbid: str, sentence: str) -> list[float]:
|
372
|
-
try:
|
373
|
-
self.check_nua_key_is_configured_for_onprem()
|
374
|
-
except NUAKeyMissingError:
|
375
|
-
logger.warning(
|
376
|
-
"Nuclia Service account is not defined so could not retrieve vectors for the query"
|
377
|
-
)
|
378
|
-
return []
|
379
|
-
|
380
|
-
resp = await self.make_request(
|
381
|
-
"GET",
|
382
|
-
url=self.get_predict_url(SENTENCE, kbid),
|
383
|
-
params={"text": sentence},
|
384
|
-
headers=self.get_predict_headers(kbid),
|
385
|
-
)
|
386
|
-
await self.check_response(resp, expected_status=200)
|
387
|
-
data = await resp.json()
|
388
|
-
if len(data["data"]) == 0:
|
389
|
-
raise PredictVectorMissing()
|
390
|
-
return data["data"]
|
391
|
-
|
392
350
|
@predict_observer.wrap({"type": "entities"})
|
393
351
|
async def detect_entities(self, kbid: str, sentence: str) -> list[RelationNode]:
|
394
352
|
try:
|
@@ -420,26 +378,46 @@ class PredictEngine:
|
|
420
378
|
resp = await self.make_request(
|
421
379
|
"POST",
|
422
380
|
url=self.get_predict_url(SUMMARIZE, kbid),
|
423
|
-
json=item.
|
381
|
+
json=item.model_dump(),
|
424
382
|
headers=self.get_predict_headers(kbid),
|
425
383
|
timeout=None,
|
426
384
|
)
|
427
385
|
await self.check_response(resp, expected_status=200)
|
428
386
|
data = await resp.json()
|
429
|
-
return SummarizedResponse.
|
387
|
+
return SummarizedResponse.model_validate(data)
|
388
|
+
|
389
|
+
@predict_observer.wrap({"type": "rerank"})
|
390
|
+
async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
|
391
|
+
try:
|
392
|
+
self.check_nua_key_is_configured_for_onprem()
|
393
|
+
except NUAKeyMissingError:
|
394
|
+
error = "Nuclia Service account is not defined. Rerank operation could not be performed"
|
395
|
+
logger.warning(error)
|
396
|
+
raise SendToPredictError(error)
|
397
|
+
resp = await self.make_request(
|
398
|
+
"POST",
|
399
|
+
url=self.get_predict_url(RERANK, kbid),
|
400
|
+
json=item.model_dump(),
|
401
|
+
headers=self.get_predict_headers(kbid),
|
402
|
+
)
|
403
|
+
await self.check_response(resp, expected_status=200)
|
404
|
+
data = await resp.json()
|
405
|
+
return RerankResponse.model_validate(data)
|
430
406
|
|
431
407
|
|
432
408
|
class DummyPredictEngine(PredictEngine):
|
409
|
+
default_semantic_threshold = 0.7
|
410
|
+
|
433
411
|
def __init__(self):
|
434
412
|
self.onprem = True
|
435
413
|
self.cluster_url = "http://localhost:8000"
|
436
414
|
self.public_url = "http://localhost:8000"
|
437
415
|
self.calls = []
|
438
|
-
self.
|
439
|
-
b"valid ",
|
440
|
-
b"answer ",
|
441
|
-
b" to",
|
442
|
-
|
416
|
+
self.ndjson_answer = [
|
417
|
+
b'{"chunk": {"type": "text", "text": "valid "}}\n',
|
418
|
+
b'{"chunk": {"type": "text", "text": "answer "}}\n',
|
419
|
+
b'{"chunk": {"type": "text", "text": "to"}}\n',
|
420
|
+
b'{"chunk": {"type": "status", "code": "0"}}\n',
|
443
421
|
]
|
444
422
|
self.max_context = 1000
|
445
423
|
|
@@ -458,84 +436,72 @@ class DummyPredictEngine(PredictEngine):
|
|
458
436
|
response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
|
459
437
|
return response
|
460
438
|
|
461
|
-
async def send_feedback(
|
462
|
-
self,
|
463
|
-
kbid: str,
|
464
|
-
item: FeedbackRequest,
|
465
|
-
x_nucliadb_user: str,
|
466
|
-
x_ndb_client: str,
|
467
|
-
x_forwarded_for: str,
|
468
|
-
):
|
469
|
-
self.calls.append(("send_feedback", item))
|
470
|
-
return
|
471
|
-
|
472
439
|
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
473
440
|
self.calls.append(("rephrase_query", item))
|
474
441
|
return DUMMY_REPHRASE_QUERY
|
475
442
|
|
476
|
-
async def
|
443
|
+
async def chat_query_ndjson(
|
477
444
|
self, kbid: str, item: ChatModel
|
478
|
-
) -> tuple[str, AsyncIterator[
|
479
|
-
self.calls.append(("
|
445
|
+
) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
|
446
|
+
self.calls.append(("chat_query_ndjson", item))
|
480
447
|
|
481
448
|
async def generate():
|
482
|
-
for
|
483
|
-
yield
|
484
|
-
|
485
|
-
return (DUMMY_LEARNING_ID, generate())
|
449
|
+
for item in self.ndjson_answer:
|
450
|
+
yield GenerativeChunk.model_validate_json(item)
|
486
451
|
|
487
|
-
|
488
|
-
self, kbid: str, query: str, blocks: list[list[str]], user_id: str
|
489
|
-
) -> str:
|
490
|
-
self.calls.append(("ask_document", (query, blocks, user_id)))
|
491
|
-
answer = os.environ.get("TEST_ASK_DOCUMENT") or "Answer to your question"
|
492
|
-
return answer
|
452
|
+
return (DUMMY_LEARNING_ID, DUMMY_LEARNING_MODEL, generate())
|
493
453
|
|
494
454
|
async def query(
|
495
455
|
self,
|
496
456
|
kbid: str,
|
497
457
|
sentence: str,
|
458
|
+
semantic_model: Optional[str] = None,
|
498
459
|
generative_model: Optional[str] = None,
|
499
|
-
rephrase:
|
460
|
+
rephrase: bool = False,
|
461
|
+
rephrase_prompt: Optional[str] = None,
|
500
462
|
) -> QueryInfo:
|
501
463
|
self.calls.append(("query", sentence))
|
502
|
-
if (
|
503
|
-
os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21"
|
504
|
-
): # pragma: no cover
|
505
|
-
return QueryInfo(
|
506
|
-
language="en",
|
507
|
-
stop_words=[],
|
508
|
-
semantic_threshold=0.7,
|
509
|
-
visual_llm=True,
|
510
|
-
max_context=self.max_context,
|
511
|
-
entities=TokenSearch(
|
512
|
-
tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
|
513
|
-
),
|
514
|
-
sentence=SentenceSearch(data=Qm2023, time=0.0),
|
515
|
-
query=sentence,
|
516
|
-
)
|
517
|
-
else:
|
518
|
-
return QueryInfo(
|
519
|
-
language="en",
|
520
|
-
stop_words=[],
|
521
|
-
semantic_threshold=0.7,
|
522
|
-
visual_llm=True,
|
523
|
-
max_context=self.max_context,
|
524
|
-
entities=TokenSearch(
|
525
|
-
tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0
|
526
|
-
),
|
527
|
-
sentence=SentenceSearch(data=Q, time=0.0),
|
528
|
-
query=sentence,
|
529
|
-
)
|
530
464
|
|
531
|
-
|
532
|
-
|
533
|
-
if (
|
534
|
-
os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21"
|
535
|
-
): # pragma: no cover
|
536
|
-
return Qm2023
|
465
|
+
if os.environ.get("TEST_SENTENCE_ENCODER") == "multilingual-2023-02-21": # pragma: no cover
|
466
|
+
base_vector = Qm2023
|
537
467
|
else:
|
538
|
-
|
468
|
+
base_vector = Q
|
469
|
+
|
470
|
+
# populate data with existing vectorsets
|
471
|
+
async with datamanagers.with_ro_transaction() as txn:
|
472
|
+
semantic_thresholds = {}
|
473
|
+
vectors = {}
|
474
|
+
timings = {}
|
475
|
+
async for vectorset_id, config in datamanagers.vectorsets.iter(txn, kbid=kbid):
|
476
|
+
semantic_thresholds[vectorset_id] = self.default_semantic_threshold
|
477
|
+
vectorset_dimension = config.vectorset_index_config.vector_dimension
|
478
|
+
if vectorset_dimension > len(base_vector):
|
479
|
+
padding = vectorset_dimension - len(base_vector)
|
480
|
+
vectors[vectorset_id] = base_vector + [random.random()] * padding
|
481
|
+
else:
|
482
|
+
vectors[vectorset_id] = base_vector[:vectorset_dimension]
|
483
|
+
|
484
|
+
timings[vectorset_id] = 0.010
|
485
|
+
|
486
|
+
# and fake data with the passed one too
|
487
|
+
model = semantic_model or "<PREDICT-DEFAULT-SEMANTIC-MODEL>"
|
488
|
+
semantic_thresholds[model] = self.default_semantic_threshold
|
489
|
+
vectors[model] = base_vector
|
490
|
+
timings[model] = 0.0
|
491
|
+
|
492
|
+
return QueryInfo(
|
493
|
+
language="en",
|
494
|
+
stop_words=[],
|
495
|
+
semantic_thresholds=semantic_thresholds,
|
496
|
+
visual_llm=True,
|
497
|
+
max_context=self.max_context,
|
498
|
+
entities=TokenSearch(tokens=[Ner(text="text", ner="PERSON", start=0, end=2)], time=0.0),
|
499
|
+
sentence=SentenceSearch(
|
500
|
+
vectors=vectors,
|
501
|
+
timings=timings,
|
502
|
+
),
|
503
|
+
query=sentence,
|
504
|
+
)
|
539
505
|
|
540
506
|
async def detect_entities(self, kbid: str, sentence: str) -> list[RelationNode]:
|
541
507
|
self.calls.append(("detect_entities", sentence))
|
@@ -554,9 +520,16 @@ class DummyPredictEngine(PredictEngine):
|
|
554
520
|
rsummary = []
|
555
521
|
for field_id, field_text in item.resources[rid].fields.items():
|
556
522
|
rsummary.append(f"{field_id}: {field_text}")
|
557
|
-
response.resources[rid] = SummarizedResource(
|
558
|
-
|
559
|
-
|
523
|
+
response.resources[rid] = SummarizedResource(summary="\n\n".join(rsummary), tokens=10)
|
524
|
+
return response
|
525
|
+
|
526
|
+
async def rerank(self, kbid: str, item: RerankModel) -> RerankResponse:
|
527
|
+
self.calls.append(("rerank", (kbid, item)))
|
528
|
+
# as we don't have information about the retrieval scores, return a
|
529
|
+
# random score given by the dict iteration
|
530
|
+
response = RerankResponse(
|
531
|
+
context_scores={paragraph_id: i for i, paragraph_id in enumerate(item.context.keys())}
|
532
|
+
)
|
560
533
|
return response
|
561
534
|
|
562
535
|
|
@@ -578,6 +551,21 @@ def get_answer_generator(response: aiohttp.ClientResponse):
|
|
578
551
|
return _iter_answer_chunks(response.content.iter_chunks())
|
579
552
|
|
580
553
|
|
554
|
+
def get_chat_ndjson_generator(
|
555
|
+
response: aiohttp.ClientResponse,
|
556
|
+
) -> AsyncIterator[GenerativeChunk]:
|
557
|
+
async def _parse_generative_chunks(gen):
|
558
|
+
async for chunk in gen:
|
559
|
+
try:
|
560
|
+
yield GenerativeChunk.model_validate_json(chunk.strip())
|
561
|
+
except ValidationError as ex:
|
562
|
+
errors.capture_exception(ex)
|
563
|
+
logger.error(f"Invalid chunk received: {chunk}")
|
564
|
+
continue
|
565
|
+
|
566
|
+
return _parse_generative_chunks(response.content)
|
567
|
+
|
568
|
+
|
581
569
|
async def _parse_rephrase_response(
|
582
570
|
resp: aiohttp.ClientResponse,
|
583
571
|
) -> str:
|
nucliadb/search/py.typed
ADDED
File without changes
|