nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -20,6 +20,14 @@
|
|
20
20
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.train import logger
|
26
|
+
from nucliadb.train.generators.utils import (
|
27
|
+
batchify,
|
28
|
+
get_paragraph,
|
29
|
+
get_resource_from_cache_or_db,
|
30
|
+
)
|
23
31
|
from nucliadb_protos.dataset_pb2 import (
|
24
32
|
QuestionAnswerStreamingBatch,
|
25
33
|
QuestionAnswerStreamItem,
|
@@ -32,15 +40,6 @@ from nucliadb_protos.resources_pb2 import (
|
|
32
40
|
QuestionAnswerAnnotation,
|
33
41
|
)
|
34
42
|
|
35
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
36
|
-
from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID, KB_REVERSE
|
37
|
-
from nucliadb.train import logger
|
38
|
-
from nucliadb.train.generators.utils import (
|
39
|
-
batchify,
|
40
|
-
get_paragraph,
|
41
|
-
get_resource_from_cache_or_db,
|
42
|
-
)
|
43
|
-
|
44
43
|
|
45
44
|
def question_answer_batch_generator(
|
46
45
|
kbid: str,
|
@@ -48,12 +47,8 @@ def question_answer_batch_generator(
|
|
48
47
|
node: AbstractIndexNode,
|
49
48
|
shard_replica_id: str,
|
50
49
|
) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
|
51
|
-
generator = generate_question_answer_streaming_payloads(
|
52
|
-
|
53
|
-
)
|
54
|
-
batch_generator = batchify(
|
55
|
-
generator, trainset.batch_size, QuestionAnswerStreamingBatch
|
56
|
-
)
|
50
|
+
generator = generate_question_answer_streaming_payloads(kbid, trainset, node, shard_replica_id)
|
51
|
+
batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
|
57
52
|
return batch_generator
|
58
53
|
|
59
54
|
|
@@ -90,14 +85,18 @@ async def generate_question_answer_streaming_payloads(
|
|
90
85
|
item.cancelled_by_user = qa_annotation_pb.cancelled_by_user
|
91
86
|
yield item
|
92
87
|
|
93
|
-
field_type_int =
|
88
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
94
89
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
95
90
|
|
96
91
|
question_answers_pb = await field_obj.get_question_answers()
|
97
92
|
if question_answers_pb is not None:
|
98
|
-
for question_answer_pb in question_answers_pb.question_answer:
|
93
|
+
for question_answer_pb in question_answers_pb.question_answers.question_answer:
|
99
94
|
async for item in iter_stream_items(kbid, question_answer_pb):
|
100
95
|
yield item
|
96
|
+
for question_answer_pb in question_answers_pb.split_question_answers.values():
|
97
|
+
for split_question_answer_pb in question_answer_pb.question_answer:
|
98
|
+
async for item in iter_stream_items(kbid, split_question_answer_pb):
|
99
|
+
yield item
|
101
100
|
|
102
101
|
|
103
102
|
async def iter_stream_items(
|
@@ -109,7 +108,7 @@ async def iter_stream_items(
|
|
109
108
|
for paragraph_id in question_pb.ids_paragraphs:
|
110
109
|
try:
|
111
110
|
text = await get_paragraph(kbid, paragraph_id)
|
112
|
-
except Exception as exc: # pragma:
|
111
|
+
except Exception as exc: # pragma: no cover
|
113
112
|
logger.warning(
|
114
113
|
"Question paragraph couldn't be fetched while streaming Q&A",
|
115
114
|
extra={"kbid": kbid, "paragraph_id": paragraph_id},
|
@@ -128,7 +127,7 @@ async def iter_stream_items(
|
|
128
127
|
for paragraph_id in answer_pb.ids_paragraphs:
|
129
128
|
try:
|
130
129
|
text = await get_paragraph(kbid, paragraph_id)
|
131
|
-
except Exception as exc: # pragma:
|
130
|
+
except Exception as exc: # pragma: no cover
|
132
131
|
logger.warning(
|
133
132
|
"Answer paragraph couldn't be fetched while streaming Q&A",
|
134
133
|
extra={"kbid": kbid, "paragraph_id": paragraph_id},
|
@@ -141,4 +140,4 @@ async def iter_stream_items(
|
|
141
140
|
|
142
141
|
|
143
142
|
def is_same_field(field: FieldID, field_id: str, field_type: str) -> bool:
|
144
|
-
return field.field == field_id and
|
143
|
+
return field.field == field_id and FIELD_TYPE_PB_TO_STR[field.field_type] == field_type
|
@@ -21,6 +21,11 @@
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
23
|
from fastapi import HTTPException
|
24
|
+
|
25
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
27
|
+
from nucliadb.train import logger
|
28
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
24
29
|
from nucliadb_protos.dataset_pb2 import (
|
25
30
|
Label,
|
26
31
|
MultipleTextSameLabels,
|
@@ -29,11 +34,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
29
34
|
)
|
30
35
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
31
36
|
|
32
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
33
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
34
|
-
from nucliadb.train import logger
|
35
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
36
|
-
|
37
37
|
|
38
38
|
def sentence_classification_batch_generator(
|
39
39
|
kbid: str,
|
@@ -47,12 +47,8 @@ def sentence_classification_batch_generator(
|
|
47
47
|
detail="Sentence Classification should be at least of 1 labelset",
|
48
48
|
)
|
49
49
|
|
50
|
-
generator = generate_sentence_classification_payloads(
|
51
|
-
|
52
|
-
)
|
53
|
-
batch_generator = batchify(
|
54
|
-
generator, trainset.batch_size, SentenceClassificationBatch
|
55
|
-
)
|
50
|
+
generator = generate_sentence_classification_payloads(kbid, trainset, node, shard_replica_id)
|
51
|
+
batch_generator = batchify(generator, trainset.batch_size, SentenceClassificationBatch)
|
56
52
|
return batch_generator
|
57
53
|
|
58
54
|
|
@@ -107,14 +103,12 @@ async def get_sentences(kbid: str, result: str) -> list[str]:
|
|
107
103
|
logger.error(f"{rid} does not exist on DB")
|
108
104
|
return []
|
109
105
|
|
110
|
-
field_type_int =
|
106
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
111
107
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
112
108
|
extracted_text = await field_obj.get_extracted_text()
|
113
109
|
field_metadata = await field_obj.get_field_metadata()
|
114
110
|
if extracted_text is None:
|
115
|
-
logger.warning(
|
116
|
-
f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
|
117
|
-
)
|
111
|
+
logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
|
118
112
|
return []
|
119
113
|
|
120
114
|
splitted_texts = []
|
@@ -21,6 +21,10 @@
|
|
21
21
|
from collections import OrderedDict
|
22
22
|
from typing import AsyncGenerator, cast
|
23
23
|
|
24
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
25
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.train import logger
|
27
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
24
28
|
from nucliadb_protos.dataset_pb2 import (
|
25
29
|
TokenClassificationBatch,
|
26
30
|
TokensClassification,
|
@@ -28,11 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
28
32
|
)
|
29
33
|
from nucliadb_protos.nodereader_pb2 import StreamFilter, StreamRequest
|
30
34
|
|
31
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
32
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
33
|
-
from nucliadb.train import logger
|
34
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
35
|
-
|
36
35
|
NERS_DICT = dict[str, dict[str, list[tuple[int, int]]]]
|
37
36
|
POSITION_DICT = OrderedDict[tuple[int, int], tuple[str, str]]
|
38
37
|
MAIN = "__main__"
|
@@ -44,9 +43,7 @@ def token_classification_batch_generator(
|
|
44
43
|
node: AbstractIndexNode,
|
45
44
|
shard_replica_id: str,
|
46
45
|
) -> AsyncGenerator[TokenClassificationBatch, None]:
|
47
|
-
generator = generate_token_classification_payloads(
|
48
|
-
kbid, trainset, node, shard_replica_id
|
49
|
-
)
|
46
|
+
generator = generate_token_classification_payloads(kbid, trainset, node, shard_replica_id)
|
50
47
|
batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)
|
51
48
|
return batch_generator
|
52
49
|
|
@@ -97,13 +94,11 @@ async def get_field_text(
|
|
97
94
|
logger.error(f"{rid} does not exist on DB")
|
98
95
|
return {}, {}, {}
|
99
96
|
|
100
|
-
field_type_int =
|
97
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
101
98
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
102
99
|
extracted_text = await field_obj.get_extracted_text()
|
103
100
|
if extracted_text is None:
|
104
|
-
logger.warning(
|
105
|
-
f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
|
106
|
-
)
|
101
|
+
logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
|
107
102
|
return {}, {}, {}
|
108
103
|
|
109
104
|
split_text: dict[str, str] = extracted_text.split_text
|
@@ -138,16 +133,42 @@ async def get_field_text(
|
|
138
133
|
split = MAIN
|
139
134
|
else:
|
140
135
|
split = token.split
|
141
|
-
split_ners[split].setdefault(token.klass, {}).setdefault(
|
142
|
-
|
143
|
-
)
|
144
|
-
split_ners[split][token.klass][token.token].append(
|
145
|
-
(token.start, token.end)
|
146
|
-
)
|
136
|
+
split_ners[split].setdefault(token.klass, {}).setdefault(token.token, [])
|
137
|
+
split_ners[split][token.klass][token.token].append((token.start, token.end))
|
147
138
|
|
148
139
|
field_metadata = await field_obj.get_field_metadata()
|
149
140
|
# Check computed definition of entities
|
150
141
|
if field_metadata is not None:
|
142
|
+
# Data Augmentation + Processor entities
|
143
|
+
for data_augmentation_task_id, entities in field_metadata.metadata.entities.items():
|
144
|
+
for entity in entities.entities:
|
145
|
+
entity_text = entity.text
|
146
|
+
entity_label = entity.label
|
147
|
+
entity_positions = entity.positions
|
148
|
+
if entity_label in valid_entity_groups:
|
149
|
+
split_ners[MAIN].setdefault(entity_label, {}).setdefault(entity_text, [])
|
150
|
+
for position in entity_positions:
|
151
|
+
split_ners[MAIN][entity_label][entity_text].append(
|
152
|
+
(position.start, position.end)
|
153
|
+
)
|
154
|
+
|
155
|
+
for split, split_metadata in field_metadata.split_metadata.items():
|
156
|
+
for data_augmentation_task_id, entities in split_metadata.entities.items():
|
157
|
+
for entity in entities.entities:
|
158
|
+
entity_text = entity.text
|
159
|
+
entity_label = entity.label
|
160
|
+
entity_positions = entity.positions
|
161
|
+
if entity_label in valid_entity_groups:
|
162
|
+
split_ners.setdefault(split, {}).setdefault(entity_label, {}).setdefault(
|
163
|
+
entity_text, []
|
164
|
+
)
|
165
|
+
for position in entity_positions:
|
166
|
+
split_ners[split][entity_label][entity_text].append(
|
167
|
+
(position.start, position.end)
|
168
|
+
)
|
169
|
+
|
170
|
+
# Legacy processor entities
|
171
|
+
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
|
151
172
|
for entity_key, positions in field_metadata.metadata.positions.items():
|
152
173
|
entities = entity_key.split("/")
|
153
174
|
entity_group = entities[0]
|
@@ -156,9 +177,7 @@ async def get_field_text(
|
|
156
177
|
if entity_group in valid_entity_groups:
|
157
178
|
split_ners[MAIN].setdefault(entity_group, {}).setdefault(entity, [])
|
158
179
|
for position in positions.position:
|
159
|
-
split_ners[MAIN][entity_group][entity].append(
|
160
|
-
(position.start, position.end)
|
161
|
-
)
|
180
|
+
split_ners[MAIN][entity_group][entity].append((position.start, position.end))
|
162
181
|
|
163
182
|
for split, split_metadata in field_metadata.split_metadata.items():
|
164
183
|
for entity_key, positions in split_metadata.positions.items():
|
@@ -166,24 +185,16 @@ async def get_field_text(
|
|
166
185
|
entity_group = entities[0]
|
167
186
|
entity = "/".join(entities[1:])
|
168
187
|
if entity_group in valid_entity_groups:
|
169
|
-
split_ners.setdefault(split, {}).setdefault(
|
170
|
-
entity_group, {}
|
171
|
-
).setdefault(entity, [])
|
188
|
+
split_ners.setdefault(split, {}).setdefault(entity_group, {}).setdefault(entity, [])
|
172
189
|
for position in positions.position:
|
173
|
-
split_ners[split][entity_group][entity].append(
|
174
|
-
(position.start, position.end)
|
175
|
-
)
|
190
|
+
split_ners[split][entity_group][entity].append((position.start, position.end))
|
176
191
|
|
177
192
|
for split, invalid_tokens in invalid_tokens_split.items():
|
178
193
|
for token.klass, token.token, token.start, token.end in invalid_tokens:
|
179
194
|
if token.klass in split_ners.get(split, {}):
|
180
195
|
if token.token in split_ners.get(split, {}).get(token.klass, {}):
|
181
|
-
if (token.start, token.end) in split_ners[split][token.klass][
|
182
|
-
token.token
|
183
|
-
]:
|
184
|
-
split_ners[split][token.klass][token.token].remove(
|
185
|
-
(token.start, token.end)
|
186
|
-
)
|
196
|
+
if (token.start, token.end) in split_ners[split][token.klass][token.token]:
|
197
|
+
split_ners[split][token.klass][token.token].remove((token.start, token.end))
|
187
198
|
if len(split_ners[split][token.klass][token.token]) == 0:
|
188
199
|
del split_ners[split][token.klass][token.token]
|
189
200
|
if len(split_ners[split][token.klass]) == 0:
|
@@ -197,9 +208,7 @@ async def get_field_text(
|
|
197
208
|
for position in positions:
|
198
209
|
split_positions[position] = (entity_group, entity)
|
199
210
|
|
200
|
-
ordered_positions[split] = OrderedDict(
|
201
|
-
sorted(split_positions.items(), key=lambda x: x[0])
|
202
|
-
)
|
211
|
+
ordered_positions[split] = OrderedDict(sorted(split_positions.items(), key=lambda x: x[0]))
|
203
212
|
|
204
213
|
split_paragraphs: dict[str, list[tuple[int, int]]] = {}
|
205
214
|
if field_metadata is not None:
|
@@ -19,19 +19,17 @@
|
|
19
19
|
#
|
20
20
|
|
21
21
|
from contextvars import ContextVar
|
22
|
-
from typing import Any, AsyncIterator, Optional
|
22
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Type
|
23
23
|
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.common.maindb.utils import get_driver
|
24
26
|
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
|
25
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
26
27
|
from nucliadb.ingest.orm.resource import Resource as ResourceORM
|
27
|
-
from nucliadb.middleware.transaction import get_read_only_transaction
|
28
28
|
from nucliadb.train import SERVICE_NAME, logger
|
29
|
-
from nucliadb.train.types import
|
29
|
+
from nucliadb.train.types import T
|
30
30
|
from nucliadb_utils.utilities import get_storage
|
31
31
|
|
32
|
-
rcache: ContextVar[Optional[dict[str, ResourceORM]]] = ContextVar(
|
33
|
-
"rcache", default=None
|
34
|
-
)
|
32
|
+
rcache: ContextVar[Optional[dict[str, ResourceORM]]] = ContextVar("rcache", default=None)
|
35
33
|
|
36
34
|
|
37
35
|
def get_resource_cache(clear: bool = False) -> dict[str, ResourceORM]:
|
@@ -46,12 +44,12 @@ async def get_resource_from_cache_or_db(kbid: str, uuid: str) -> Optional[Resour
|
|
46
44
|
resouce_cache = get_resource_cache()
|
47
45
|
orm_resource: Optional[ResourceORM] = None
|
48
46
|
if uuid not in resouce_cache:
|
49
|
-
transaction = await get_read_only_transaction()
|
50
47
|
storage = await get_storage(service_name=SERVICE_NAME)
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
48
|
+
async with get_driver().transaction(read_only=True) as transaction:
|
49
|
+
kb = KnowledgeBoxORM(transaction, storage, kbid)
|
50
|
+
orm_resource = await kb.get(uuid)
|
51
|
+
if orm_resource is not None:
|
52
|
+
resouce_cache[uuid] = orm_resource
|
55
53
|
else:
|
56
54
|
orm_resource = resouce_cache.get(uuid)
|
57
55
|
return orm_resource
|
@@ -75,13 +73,11 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
|
|
75
73
|
logger.error(f"{rid} does not exist on DB")
|
76
74
|
return ""
|
77
75
|
|
78
|
-
field_type_int =
|
76
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
79
77
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
80
78
|
extracted_text = await field_obj.get_extracted_text()
|
81
79
|
if extracted_text is None:
|
82
|
-
logger.warning(
|
83
|
-
f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
|
84
|
-
)
|
80
|
+
logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
|
85
81
|
return ""
|
86
82
|
|
87
83
|
if split is not None:
|
@@ -94,8 +90,8 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
|
|
94
90
|
|
95
91
|
|
96
92
|
async def batchify(
|
97
|
-
producer: AsyncIterator[Any], size: int, batch_klass:
|
98
|
-
):
|
93
|
+
producer: AsyncIterator[Any], size: int, batch_klass: Type[T]
|
94
|
+
) -> AsyncGenerator[T, None]:
|
99
95
|
# NOTE: we are supposing all protobuffers have a data field
|
100
96
|
batch = []
|
101
97
|
async for item in producer:
|
nucliadb/train/lifecycle.py
CHANGED
@@ -18,6 +18,10 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
|
21
|
+
from contextlib import asynccontextmanager
|
22
|
+
|
23
|
+
from fastapi import FastAPI
|
24
|
+
|
21
25
|
from nucliadb.common.cluster.discovery.utils import (
|
22
26
|
setup_cluster_discovery,
|
23
27
|
teardown_cluster_discovery,
|
@@ -33,16 +37,16 @@ from nucliadb_telemetry.utils import clean_telemetry, setup_telemetry
|
|
33
37
|
from nucliadb_utils.utilities import start_audit_utility, stop_audit_utility
|
34
38
|
|
35
39
|
|
36
|
-
|
40
|
+
@asynccontextmanager
|
41
|
+
async def lifespan(app: FastAPI):
|
37
42
|
await setup_telemetry(SERVICE_NAME)
|
38
|
-
|
39
43
|
await setup_cluster_discovery()
|
40
44
|
await start_shard_manager()
|
41
45
|
await start_train_grpc(SERVICE_NAME)
|
42
46
|
await start_audit_utility(SERVICE_NAME)
|
43
47
|
|
48
|
+
yield
|
44
49
|
|
45
|
-
async def finalize() -> None:
|
46
50
|
await stop_audit_utility()
|
47
51
|
await stop_train_grpc()
|
48
52
|
await stop_shard_manager()
|
nucliadb/train/nodes.py
CHANGED
@@ -19,6 +19,15 @@
|
|
19
19
|
#
|
20
20
|
from typing import AsyncIterator, Optional
|
21
21
|
|
22
|
+
from nucliadb.common import datamanagers
|
23
|
+
from nucliadb.common.cluster import manager
|
24
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
25
|
+
|
26
|
+
# XXX: this keys shouldn't be exposed outside datamanagers
|
27
|
+
from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
|
28
|
+
from nucliadb.common.maindb.driver import Driver, Transaction
|
29
|
+
from nucliadb.ingest.orm.entities import EntitiesManager
|
30
|
+
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
|
22
31
|
from nucliadb_protos.train_pb2 import (
|
23
32
|
GetFieldsRequest,
|
24
33
|
GetParagraphsRequest,
|
@@ -30,15 +39,9 @@ from nucliadb_protos.train_pb2 import (
|
|
30
39
|
TrainSentence,
|
31
40
|
)
|
32
41
|
from nucliadb_protos.writer_pb2 import ShardObject
|
33
|
-
|
34
|
-
from nucliadb.common import datamanagers
|
35
|
-
from nucliadb.common.cluster import manager
|
36
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
37
|
-
from nucliadb.common.maindb.driver import Driver, Transaction
|
38
|
-
from nucliadb.ingest.orm.entities import EntitiesManager
|
39
|
-
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
|
40
|
-
from nucliadb.ingest.orm.resource import KB_RESOURCE_SLUG_BASE
|
42
|
+
from nucliadb_utils import const
|
41
43
|
from nucliadb_utils.storages.storage import Storage
|
44
|
+
from nucliadb_utils.utilities import has_feature
|
42
45
|
|
43
46
|
|
44
47
|
class TrainShardManager(manager.KBShardManager):
|
@@ -50,13 +53,13 @@ class TrainShardManager(manager.KBShardManager):
|
|
50
53
|
async def get_reader(self, kbid: str, shard: str) -> tuple[AbstractIndexNode, str]:
|
51
54
|
shards = await self.get_shards_by_kbid_inner(kbid)
|
52
55
|
try:
|
53
|
-
shard_object: ShardObject = next(
|
54
|
-
filter(lambda x: x.shard == shard, shards.shards)
|
55
|
-
)
|
56
|
+
shard_object: ShardObject = next(filter(lambda x: x.shard == shard, shards.shards))
|
56
57
|
except StopIteration:
|
57
58
|
raise KeyError("Shard not found")
|
58
59
|
|
59
|
-
node_obj, shard_id = manager.choose_node(
|
60
|
+
node_obj, shard_id = manager.choose_node(
|
61
|
+
shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
|
62
|
+
)
|
60
63
|
return node_obj, shard_id
|
61
64
|
|
62
65
|
async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
|
@@ -69,9 +72,7 @@ class TrainShardManager(manager.KBShardManager):
|
|
69
72
|
kbobj = KnowledgeBox(txn, self.storage, kbid)
|
70
73
|
return kbobj
|
71
74
|
|
72
|
-
async def get_kb_entities_manager(
|
73
|
-
self, txn: Transaction, kbid: str
|
74
|
-
) -> Optional[EntitiesManager]:
|
75
|
+
async def get_kb_entities_manager(self, txn: Transaction, kbid: str) -> Optional[EntitiesManager]:
|
75
76
|
kbobj = await self.get_kb_obj(txn, kbid)
|
76
77
|
if kbobj is None:
|
77
78
|
return None
|
@@ -79,9 +80,7 @@ class TrainShardManager(manager.KBShardManager):
|
|
79
80
|
manager = EntitiesManager(kbobj, txn)
|
80
81
|
return manager
|
81
82
|
|
82
|
-
async def kb_sentences(
|
83
|
-
self, request: GetSentencesRequest
|
84
|
-
) -> AsyncIterator[TrainSentence]:
|
83
|
+
async def kb_sentences(self, request: GetSentencesRequest) -> AsyncIterator[TrainSentence]:
|
85
84
|
async with self.driver.transaction() as txn:
|
86
85
|
kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
|
87
86
|
if request.uuid != "":
|
@@ -95,24 +94,18 @@ class TrainShardManager(manager.KBShardManager):
|
|
95
94
|
async for sentence in resource.iterate_sentences(request.metadata):
|
96
95
|
yield sentence
|
97
96
|
|
98
|
-
async def kb_paragraphs(
|
99
|
-
self, request: GetParagraphsRequest
|
100
|
-
) -> AsyncIterator[TrainParagraph]:
|
97
|
+
async def kb_paragraphs(self, request: GetParagraphsRequest) -> AsyncIterator[TrainParagraph]:
|
101
98
|
async with self.driver.transaction() as txn:
|
102
99
|
kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
|
103
100
|
if request.uuid != "":
|
104
101
|
# Filter by uuid
|
105
102
|
resource = await kb.get(request.uuid)
|
106
103
|
if resource:
|
107
|
-
async for paragraph in resource.iterate_paragraphs(
|
108
|
-
request.metadata
|
109
|
-
):
|
104
|
+
async for paragraph in resource.iterate_paragraphs(request.metadata):
|
110
105
|
yield paragraph
|
111
106
|
else:
|
112
107
|
async for resource in kb.iterate_resources():
|
113
|
-
async for paragraph in resource.iterate_paragraphs(
|
114
|
-
request.metadata
|
115
|
-
):
|
108
|
+
async for paragraph in resource.iterate_paragraphs(request.metadata):
|
116
109
|
yield paragraph
|
117
110
|
|
118
111
|
async def kb_fields(self, request: GetFieldsRequest) -> AsyncIterator[TrainField]:
|
@@ -129,15 +122,13 @@ class TrainShardManager(manager.KBShardManager):
|
|
129
122
|
async for field in resource.iterate_fields(request.metadata):
|
130
123
|
yield field
|
131
124
|
|
132
|
-
async def kb_resources(
|
133
|
-
self, request: GetResourcesRequest
|
134
|
-
) -> AsyncIterator[TrainResource]:
|
125
|
+
async def kb_resources(self, request: GetResourcesRequest) -> AsyncIterator[TrainResource]:
|
135
126
|
async with self.driver.transaction() as txn:
|
136
127
|
kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
|
137
128
|
base = KB_RESOURCE_SLUG_BASE.format(kbid=request.kb.uuid)
|
138
|
-
async for key in txn.keys(match=base
|
129
|
+
async for key in txn.keys(match=base):
|
139
130
|
# Fetch and Add wanted item
|
140
|
-
rid = await txn.get(key)
|
131
|
+
rid = await txn.get(key, for_update=False)
|
141
132
|
if rid is not None:
|
142
133
|
resource = await kb.get(rid.decode())
|
143
134
|
if resource is not None:
|
nucliadb/train/py.typed
ADDED
File without changes
|
nucliadb/train/servicer.py
CHANGED
@@ -18,10 +18,13 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
import traceback
|
21
|
-
from typing import Optional
|
22
21
|
|
23
22
|
import aiohttp
|
24
|
-
|
23
|
+
|
24
|
+
from nucliadb.common import datamanagers
|
25
|
+
from nucliadb.train.settings import settings
|
26
|
+
from nucliadb.train.utils import get_shard_manager
|
27
|
+
from nucliadb_protos import train_pb2_grpc
|
25
28
|
from nucliadb_protos.train_pb2 import (
|
26
29
|
GetFieldsRequest,
|
27
30
|
GetInfoRequest,
|
@@ -38,10 +41,6 @@ from nucliadb_protos.writer_pb2 import (
|
|
38
41
|
GetLabelsRequest,
|
39
42
|
GetLabelsResponse,
|
40
43
|
)
|
41
|
-
|
42
|
-
from nucliadb.train.settings import settings
|
43
|
-
from nucliadb.train.utils import get_shard_manager
|
44
|
-
from nucliadb_protos import train_pb2_grpc
|
45
44
|
from nucliadb_telemetry import errors
|
46
45
|
|
47
46
|
|
@@ -111,20 +110,15 @@ class TrainServicer(train_pb2_grpc.TrainServicer):
|
|
111
110
|
async def GetOntology( # type: ignore
|
112
111
|
self, request: GetLabelsRequest, context=None
|
113
112
|
) -> GetLabelsResponse:
|
114
|
-
async with self.proc.driver.transaction() as txn:
|
115
|
-
kbobj = await self.proc.get_kb_obj(txn, request.kb.uuid)
|
116
|
-
labels: Optional[Labels] = None
|
117
|
-
if kbobj is not None:
|
118
|
-
labels = await kbobj.get_labels()
|
119
|
-
|
120
113
|
response = GetLabelsResponse()
|
121
|
-
|
122
|
-
|
114
|
+
kbid = request.kb.uuid
|
115
|
+
labels = await datamanagers.atomic.labelset.get_all(kbid=kbid)
|
116
|
+
if labels is not None:
|
117
|
+
response.kb.uuid = kbid
|
118
|
+
response.status = GetLabelsResponse.Status.OK
|
119
|
+
response.labels.CopyFrom(labels)
|
123
120
|
else:
|
124
|
-
response.
|
125
|
-
if labels is not None:
|
126
|
-
response.labels.CopyFrom(labels)
|
127
|
-
|
121
|
+
response.status = GetLabelsResponse.Status.NOTFOUND
|
128
122
|
return response
|
129
123
|
|
130
124
|
async def GetOntologyCount( # type: ignore
|
@@ -132,9 +126,7 @@ class TrainServicer(train_pb2_grpc.TrainServicer):
|
|
132
126
|
) -> LabelsetsCount:
|
133
127
|
url = settings.internal_search_api.format(kbid=request.kb.uuid)
|
134
128
|
facets = [f"faceted=/p/{labelset}" for labelset in request.paragraph_labelsets]
|
135
|
-
facets.extend(
|
136
|
-
[f"faceted=/l/{labelset}" for labelset in request.resource_labelsets]
|
137
|
-
)
|
129
|
+
facets.extend([f"faceted=/l/{labelset}" for labelset in request.resource_labelsets])
|
138
130
|
query = "&".join(facets)
|
139
131
|
headers = {"X-NUCLIADB-ROLES": "READER"}
|
140
132
|
async with aiohttp.ClientSession() as sess:
|
nucliadb/train/settings.py
CHANGED
@@ -29,13 +29,9 @@ class Settings(DriverSettings):
|
|
29
29
|
nuclia_learning_url: Optional[str] = "https://nuclia.cloud/api/v1/learning/"
|
30
30
|
nuclia_learning_apikey: Optional[str] = None
|
31
31
|
|
32
|
-
internal_counter_api: str =
|
33
|
-
"http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/counters"
|
34
|
-
)
|
32
|
+
internal_counter_api: str = "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/counters"
|
35
33
|
|
36
|
-
internal_search_api: str =
|
37
|
-
"http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/search"
|
38
|
-
)
|
34
|
+
internal_search_api: str = "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/search"
|
39
35
|
|
40
36
|
|
41
37
|
settings = Settings()
|
nucliadb/train/types.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
-
from typing import Union
|
20
|
+
from typing import TypeVar, Union
|
21
21
|
|
22
22
|
from nucliadb_protos import dataset_pb2 as dpb
|
23
23
|
|
@@ -29,14 +29,17 @@ TrainBatch = Union[
|
|
29
29
|
dpb.QuestionAnswerStreamingBatch,
|
30
30
|
dpb.SentenceClassificationBatch,
|
31
31
|
dpb.TokenClassificationBatch,
|
32
|
+
dpb.FieldStreamingBatch,
|
32
33
|
]
|
33
34
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
35
|
+
T = TypeVar(
|
36
|
+
"T",
|
37
|
+
dpb.FieldClassificationBatch,
|
38
|
+
dpb.ImageClassificationBatch,
|
39
|
+
dpb.ParagraphClassificationBatch,
|
40
|
+
dpb.ParagraphStreamingBatch,
|
41
|
+
dpb.QuestionAnswerStreamingBatch,
|
42
|
+
dpb.SentenceClassificationBatch,
|
43
|
+
dpb.TokenClassificationBatch,
|
44
|
+
dpb.FieldStreamingBatch,
|
45
|
+
)
|
nucliadb/train/upload.py
CHANGED
@@ -19,11 +19,10 @@
|
|
19
19
|
#
|
20
20
|
import argparse
|
21
21
|
import asyncio
|
22
|
+
import importlib.metadata
|
22
23
|
from asyncio import tasks
|
23
24
|
from typing import Callable
|
24
25
|
|
25
|
-
import pkg_resources
|
26
|
-
|
27
26
|
from nucliadb.train.uploader import start_upload
|
28
27
|
from nucliadb_telemetry import errors
|
29
28
|
from nucliadb_telemetry.logs import setup_logging
|
@@ -33,9 +32,7 @@ from nucliadb_utils.settings import running_settings
|
|
33
32
|
def arg_parse():
|
34
33
|
parser = argparse.ArgumentParser(description="Upload data to Nuclia Learning API.")
|
35
34
|
|
36
|
-
parser.add_argument(
|
37
|
-
"-r", "--request", dest="request", help="Request UUID", required=True
|
38
|
-
)
|
35
|
+
parser.add_argument("-r", "--request", dest="request", help="Request UUID", required=True)
|
39
36
|
|
40
37
|
parser.add_argument("-k", "--kb", dest="kb", help="Knowledge Box", required=True)
|
41
38
|
|
@@ -75,7 +72,7 @@ def _cancel_all_tasks(loop):
|
|
75
72
|
def run() -> None:
|
76
73
|
setup_logging()
|
77
74
|
|
78
|
-
errors.setup_error_handling(
|
75
|
+
errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
|
79
76
|
|
80
77
|
if asyncio._get_running_loop() is not None:
|
81
78
|
raise RuntimeError("cannot be called from a running event loop")
|