nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0003_allfields_key.py +1 -35
- migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
- migrations/0010_fix_corrupt_indexes.py +10 -10
- migrations/0011_materialize_labelset_ids.py +1 -16
- migrations/0012_rollover_shards.py +5 -10
- migrations/0014_rollover_shards.py +4 -5
- migrations/0015_targeted_rollover.py +5 -10
- migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
- migrations/0017_multiple_writable_shards.py +2 -4
- migrations/0018_purge_orphan_kbslugs.py +5 -7
- migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
- migrations/0020_drain_nodes_from_cluster.py +3 -3
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +30 -16
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +3 -11
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +174 -59
- nucliadb/common/cluster/rebalance.py +27 -29
- nucliadb/common/cluster/rollover.py +353 -194
- nucliadb/common/cluster/settings.py +6 -0
- nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +2 -6
- nucliadb/common/cluster/utils.py +29 -22
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +3 -0
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +7 -1
- nucliadb/common/datamanagers/atomic.py +22 -4
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +83 -37
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +41 -103
- nucliadb/common/datamanagers/rollover.py +76 -15
- nucliadb/common/datamanagers/synonyms.py +1 -1
- nucliadb/common/datamanagers/utils.py +15 -6
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +29 -7
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +3 -0
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +11 -42
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exporter.py +5 -11
- nucliadb/export_import/importer.py +5 -7
- nucliadb/export_import/models.py +3 -3
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +25 -37
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +21 -19
- nucliadb/ingest/consumer/consumer.py +82 -47
- nucliadb/ingest/consumer/materializer.py +5 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +19 -17
- nucliadb/ingest/consumer/shard_creator.py +2 -4
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +137 -105
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -16
- nucliadb/ingest/fields/link.py +5 -10
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +200 -213
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +322 -197
- nucliadb/ingest/orm/processor/__init__.py +2 -700
- nucliadb/ingest/orm/processor/auditing.py +4 -23
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +249 -402
- nucliadb/ingest/orm/utils.py +4 -4
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +64 -73
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -167
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +185 -412
- nucliadb/ingest/settings.py +10 -20
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +242 -55
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +47 -30
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +1 -12
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +21 -88
- nucliadb/reader/api/v1/export_import.py +1 -1
- nucliadb/reader/api/v1/knowledgebox.py +10 -10
- nucliadb/reader/api/v1/learning_config.py +2 -6
- nucliadb/reader/api/v1/resource.py +62 -88
- nucliadb/reader/api/v1/services.py +64 -83
- nucliadb/reader/app.py +12 -29
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -28
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +1 -2
- nucliadb/search/api/v1/ask.py +17 -10
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +16 -24
- nucliadb/search/api/v1/find.py +36 -36
- nucliadb/search/api/v1/knowledgebox.py +89 -60
- nucliadb/search/api/v1/resource/ask.py +2 -8
- nucliadb/search/api/v1/resource/search.py +49 -70
- nucliadb/search/api/v1/search.py +44 -210
- nucliadb/search/api/v1/suggest.py +39 -54
- nucliadb/search/app.py +12 -32
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +136 -187
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +25 -58
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +571 -123
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +817 -266
- nucliadb/search/search/chat/query.py +213 -309
- nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -53
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +187 -223
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +305 -150
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +3 -32
- nucliadb/search/search/summarize.py +7 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +8 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +7 -10
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +1 -3
- nucliadb/standalone/purge.py +1 -1
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +3 -6
- nucliadb/standalone/settings.py +9 -16
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +1 -1
- nucliadb/train/api/v1/trainset.py +2 -4
- nucliadb/train/app.py +10 -31
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +48 -39
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +19 -23
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +67 -14
- nucliadb/writer/api/v1/field.py +16 -269
- nucliadb/writer/api/v1/knowledgebox.py +218 -68
- nucliadb/writer/api/v1/resource.py +68 -88
- nucliadb/writer/api/v1/services.py +51 -70
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +114 -113
- nucliadb/writer/app.py +6 -43
- nucliadb/writer/back_pressure.py +16 -38
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -46
- nucliadb/writer/resource/field.py +25 -127
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +6 -2
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +49 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -433
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -764
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
- nucliadb/ingest/tests/unit/test_cache.py +0 -31
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -353
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -263
- nucliadb/search/api/v1/resource/chat.py +0 -174
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -466
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -153
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -525
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_migrations.py +0 -63
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -735
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
- nucliadb/tests/migrations/test_migration_0017.py +0 -76
- nucliadb/tests/migrations/test_migration_0018.py +0 -95
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
- nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -301
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -92
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -58
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -86
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -136
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -171
- nucliadb/tests/utils/broker_messages/fields.py +0 -197
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -221
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -101
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -191
- nucliadb/writer/tests/test_fields.py +0 -475
- nucliadb/writer/tests/test_files.py +0 -740
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
- nucliadb/writer/tests/test_resources.py +0 -476
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
- nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
nucliadb/train/generator.py
CHANGED
@@ -21,11 +21,11 @@
|
|
21
21
|
from typing import AsyncIterator, Optional
|
22
22
|
|
23
23
|
from fastapi import HTTPException
|
24
|
-
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
|
25
24
|
|
26
25
|
from nucliadb.train.generators.field_classifier import (
|
27
26
|
field_classification_batch_generator,
|
28
27
|
)
|
28
|
+
from nucliadb.train.generators.field_streaming import field_streaming_batch_generator
|
29
29
|
from nucliadb.train.generators.image_classifier import (
|
30
30
|
image_classification_batch_generator,
|
31
31
|
)
|
@@ -46,6 +46,7 @@ from nucliadb.train.generators.token_classifier import (
|
|
46
46
|
)
|
47
47
|
from nucliadb.train.types import TrainBatch
|
48
48
|
from nucliadb.train.utils import get_shard_manager
|
49
|
+
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
|
49
50
|
|
50
51
|
|
51
52
|
async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
|
@@ -59,34 +60,24 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
|
|
59
60
|
batch_generator: Optional[AsyncIterator[TrainBatch]] = None
|
60
61
|
|
61
62
|
if trainset.type == TaskType.FIELD_CLASSIFICATION:
|
62
|
-
batch_generator = field_classification_batch_generator(
|
63
|
-
kbid, trainset, node, shard_replica_id
|
64
|
-
)
|
63
|
+
batch_generator = field_classification_batch_generator(kbid, trainset, node, shard_replica_id)
|
65
64
|
elif trainset.type == TaskType.IMAGE_CLASSIFICATION:
|
66
|
-
batch_generator = image_classification_batch_generator(
|
67
|
-
kbid, trainset, node, shard_replica_id
|
68
|
-
)
|
65
|
+
batch_generator = image_classification_batch_generator(kbid, trainset, node, shard_replica_id)
|
69
66
|
elif trainset.type == TaskType.PARAGRAPH_CLASSIFICATION:
|
70
67
|
batch_generator = paragraph_classification_batch_generator(
|
71
68
|
kbid, trainset, node, shard_replica_id
|
72
69
|
)
|
73
70
|
elif trainset.type == TaskType.TOKEN_CLASSIFICATION:
|
74
|
-
batch_generator = token_classification_batch_generator(
|
75
|
-
kbid, trainset, node, shard_replica_id
|
76
|
-
)
|
71
|
+
batch_generator = token_classification_batch_generator(kbid, trainset, node, shard_replica_id)
|
77
72
|
elif trainset.type == TaskType.SENTENCE_CLASSIFICATION:
|
78
|
-
batch_generator = sentence_classification_batch_generator(
|
79
|
-
kbid, trainset, node, shard_replica_id
|
80
|
-
)
|
73
|
+
batch_generator = sentence_classification_batch_generator(kbid, trainset, node, shard_replica_id)
|
81
74
|
elif trainset.type == TaskType.PARAGRAPH_STREAMING:
|
82
|
-
batch_generator = paragraph_streaming_batch_generator(
|
83
|
-
kbid, trainset, node, shard_replica_id
|
84
|
-
)
|
75
|
+
batch_generator = paragraph_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
|
85
76
|
|
86
77
|
elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
|
87
|
-
batch_generator = question_answer_batch_generator(
|
88
|
-
|
89
|
-
)
|
78
|
+
batch_generator = question_answer_batch_generator(kbid, trainset, node, shard_replica_id)
|
79
|
+
elif trainset.type == TaskType.FIELD_STREAMING:
|
80
|
+
batch_generator = field_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
|
90
81
|
|
91
82
|
if batch_generator is None:
|
92
83
|
raise HTTPException(
|
@@ -20,7 +20,10 @@
|
|
20
20
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
|
-
from
|
23
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.train import logger
|
26
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
24
27
|
from nucliadb_protos.dataset_pb2 import (
|
25
28
|
FieldClassificationBatch,
|
26
29
|
Label,
|
@@ -29,11 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
29
32
|
)
|
30
33
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
31
34
|
|
32
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
33
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
34
|
-
from nucliadb.train import logger
|
35
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
36
|
-
|
37
35
|
|
38
36
|
def field_classification_batch_generator(
|
39
37
|
kbid: str,
|
@@ -41,15 +39,7 @@ def field_classification_batch_generator(
|
|
41
39
|
node: AbstractIndexNode,
|
42
40
|
shard_replica_id: str,
|
43
41
|
) -> AsyncGenerator[FieldClassificationBatch, None]:
|
44
|
-
|
45
|
-
raise HTTPException(
|
46
|
-
status_code=422,
|
47
|
-
detail="Paragraph Classification should be of 1 labelset",
|
48
|
-
)
|
49
|
-
|
50
|
-
generator = generate_field_classification_payloads(
|
51
|
-
kbid, trainset, node, shard_replica_id
|
52
|
-
)
|
42
|
+
generator = generate_field_classification_payloads(kbid, trainset, node, shard_replica_id)
|
53
43
|
batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
|
54
44
|
return batch_generator
|
55
45
|
|
@@ -95,13 +85,11 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> st
|
|
95
85
|
logger.error(f"{rid} does not exist on DB")
|
96
86
|
return ""
|
97
87
|
|
98
|
-
field_type_int =
|
88
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
99
89
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
100
90
|
extracted_text = await field_obj.get_extracted_text()
|
101
91
|
if extracted_text is None:
|
102
|
-
logger.warning(
|
103
|
-
f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
|
104
|
-
)
|
92
|
+
logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
|
105
93
|
return ""
|
106
94
|
|
107
95
|
text = ""
|
@@ -0,0 +1,156 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
#
|
20
|
+
|
21
|
+
from typing import AsyncGenerator, Optional
|
22
|
+
|
23
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.train import logger
|
26
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
27
|
+
from nucliadb_protos.dataset_pb2 import (
|
28
|
+
FieldSplitData,
|
29
|
+
FieldStreamingBatch,
|
30
|
+
TrainSet,
|
31
|
+
)
|
32
|
+
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
33
|
+
from nucliadb_protos.resources_pb2 import Basic, FieldComputedMetadata
|
34
|
+
from nucliadb_protos.utils_pb2 import ExtractedText
|
35
|
+
|
36
|
+
|
37
|
+
def field_streaming_batch_generator(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldStreamingBatch, None]:
    """Return the batched stream of field payloads for a FIELD_STREAMING trainset.

    The per-field ``FieldSplitData`` items produced by
    ``generate_field_streaming_payloads`` are grouped by ``batchify`` into
    ``FieldStreamingBatch`` messages of ``trainset.batch_size`` items each.
    """
    return batchify(
        generate_field_streaming_payloads(kbid, trainset, node, shard_replica_id),
        trainset.batch_size,
        FieldStreamingBatch,
    )
|
46
|
+
|
47
|
+
|
48
|
+
async def generate_field_streaming_payloads(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldSplitData, None]:
    """Yield one ``FieldSplitData`` per field streamed from a shard replica.

    The shard replica is queried through ``node.stream_get_fields``. The
    trainset filters are translated into prefixed label filters (``/l/``,
    ``/p/``, ``/m/``, ``/e/``, ``/f/``, ``/n/s/``).

    For every streamed field, the following are attached when available:
    the extracted text, the computed field metadata and the resource ``Basic``.
    The labels reported by the index are copied as well.

    ``document_item.field`` is parsed as ``/<type>/<field>`` (split defaults
    to ``"0"``) or ``/<type>/<field>/<split>``.

    Raises:
        ValueError: if a streamed field id has any other number of segments.
    """
    request = StreamRequest()
    request.shard_id.id = shard_replica_id

    # Translate trainset filters into index label prefixes, group by group,
    # keeping the same order in which they are appended to the request.
    filter_groups = (
        ("/l/", trainset.filter.labels),
        ("/p/", trainset.filter.paths),
        ("/m/", trainset.filter.metadata),
        ("/e/", trainset.filter.entities),
        ("/f/", trainset.filter.fields),
        ("/n/s/", trainset.filter.status),
    )
    for prefix, values in filter_groups:
        request.filter.labels.extend([f"{prefix}{value}" for value in values])

    async for document_item in node.stream_get_fields(request):
        field_parts = document_item.field.split("/")
        if len(field_parts) == 3:
            _, field_type, field = field_parts
            split = "0"
        elif len(field_parts) == 4:
            _, field_type, field, split = field_parts
        else:
            raise ValueError(f"Invalid field definition {document_item.field}")

        # Take the resource id straight from the streamed item. field_type,
        # field and split were already parsed above, so re-splitting a
        # "{uuid}{field}" string is unnecessary. Doing so would also fail for
        # 4-segment (split) field ids.
        rid = document_item.uuid

        tl = FieldSplitData()
        tl.rid = rid
        tl.field = field
        tl.field_type = field_type
        tl.split = split

        extracted = await get_field_text(kbid, rid, field, field_type)
        if extracted is not None:
            tl.text.CopyFrom(extracted)

        metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
        if metadata_obj is not None:
            tl.metadata.CopyFrom(metadata_obj)

        basic = await get_field_basic(kbid, rid, field, field_type)
        if basic is not None:
            tl.basic.CopyFrom(basic)

        tl.labels.extend(list(document_item.labels))

        yield tl
|
115
|
+
|
116
|
+
|
117
|
+
async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
    """Load the extracted text of one field of a resource.

    Returns ``None`` after logging an error when the resource cannot be found.
    Otherwise returns whatever the field's ``get_extracted_text()`` yields.
    That value may itself be ``None``.
    """
    resource = await get_resource_from_cache_or_db(kbid, rid)
    if resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    field_object = await resource.get_field(field, FIELD_TYPE_STR_TO_PB[field_type], load=False)
    return await field_object.get_extracted_text()
|
129
|
+
|
130
|
+
|
131
|
+
async def get_field_metadata(
|
132
|
+
kbid: str, rid: str, field: str, field_type: str
|
133
|
+
) -> Optional[FieldComputedMetadata]:
|
134
|
+
orm_resource = await get_resource_from_cache_or_db(kbid, rid)
|
135
|
+
|
136
|
+
if orm_resource is None:
|
137
|
+
logger.error(f"{rid} does not exist on DB")
|
138
|
+
return None
|
139
|
+
|
140
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
141
|
+
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
142
|
+
field_metadata = await field_obj.get_field_metadata()
|
143
|
+
|
144
|
+
return field_metadata
|
145
|
+
|
146
|
+
|
147
|
+
async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
|
148
|
+
orm_resource = await get_resource_from_cache_or_db(kbid, rid)
|
149
|
+
|
150
|
+
if orm_resource is None:
|
151
|
+
logger.error(f"{rid} does not exist on DB")
|
152
|
+
return None
|
153
|
+
|
154
|
+
basic = await orm_resource.get_basic()
|
155
|
+
|
156
|
+
return basic
|
@@ -21,6 +21,12 @@
|
|
21
21
|
import json
|
22
22
|
from typing import Any, AsyncGenerator
|
23
23
|
|
24
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
25
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
26
|
+
from nucliadb.ingest.fields.base import Field
|
27
|
+
from nucliadb.ingest.orm.resource import Resource
|
28
|
+
from nucliadb.train import logger
|
29
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
24
30
|
from nucliadb_protos.dataset_pb2 import (
|
25
31
|
ImageClassification,
|
26
32
|
ImageClassificationBatch,
|
@@ -29,12 +35,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
29
35
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
30
36
|
from nucliadb_protos.resources_pb2 import FieldType, PageStructure, VisualSelection
|
31
37
|
|
32
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
33
|
-
from nucliadb.ingest.fields.base import Field
|
34
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE, Resource
|
35
|
-
from nucliadb.train import logger
|
36
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
37
|
-
|
38
38
|
VISUALLY_ANNOTABLE_FIELDS = {FieldType.FILE, FieldType.LINK}
|
39
39
|
|
40
40
|
# PAWLS JSON format
|
@@ -47,9 +47,7 @@ def image_classification_batch_generator(
|
|
47
47
|
node: AbstractIndexNode,
|
48
48
|
shard_replica_id: str,
|
49
49
|
) -> AsyncGenerator[ImageClassificationBatch, None]:
|
50
|
-
generator = generate_image_classification_payloads(
|
51
|
-
kbid, trainset, node, shard_replica_id
|
52
|
-
)
|
50
|
+
generator = generate_image_classification_payloads(kbid, trainset, node, shard_replica_id)
|
53
51
|
batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
|
54
52
|
return batch_generator
|
55
53
|
|
@@ -71,7 +69,7 @@ async def generate_image_classification_payloads(
|
|
71
69
|
return
|
72
70
|
|
73
71
|
_, field_type_key, field_key = item.field.split("/")
|
74
|
-
field_type =
|
72
|
+
field_type = FIELD_TYPE_STR_TO_PB[field_type_key]
|
75
73
|
|
76
74
|
if field_type not in VISUALLY_ANNOTABLE_FIELDS:
|
77
75
|
continue
|
@@ -131,9 +129,7 @@ async def generate_image_classification_payloads(
|
|
131
129
|
yield ic
|
132
130
|
|
133
131
|
|
134
|
-
async def get_page_selections(
|
135
|
-
resource: Resource, field: Field
|
136
|
-
) -> dict[int, list[VisualSelection]]:
|
132
|
+
async def get_page_selections(resource: Resource, field: Field) -> dict[int, list[VisualSelection]]:
|
137
133
|
page_selections: dict[int, list[VisualSelection]] = {}
|
138
134
|
basic = await resource.get_basic()
|
139
135
|
if basic is None or basic.fieldmetadata is None:
|
@@ -144,7 +140,7 @@ async def get_page_selections(
|
|
144
140
|
for fieldmetadata in basic.fieldmetadata:
|
145
141
|
if (
|
146
142
|
fieldmetadata.field.field == field.id
|
147
|
-
and fieldmetadata.field.field_type ==
|
143
|
+
and fieldmetadata.field.field_type == FIELD_TYPE_STR_TO_PB[field.type]
|
148
144
|
):
|
149
145
|
for selection in fieldmetadata.page_selections:
|
150
146
|
page_selections[selection.page] = selection.visual # type: ignore
|
@@ -155,7 +151,7 @@ async def get_page_selections(
|
|
155
151
|
|
156
152
|
async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
|
157
153
|
page_structures: list[tuple[str, PageStructure]] = []
|
158
|
-
field_type =
|
154
|
+
field_type = FIELD_TYPE_STR_TO_PB[field.type]
|
159
155
|
if field_type == FieldType.FILE:
|
160
156
|
fed = await field.get_file_extracted_data() # type: ignore
|
161
157
|
if fed is None:
|
@@ -163,9 +159,7 @@ async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
|
|
163
159
|
|
164
160
|
fp = fed.file_pages_previews
|
165
161
|
if len(fp.pages) != len(fp.structures):
|
166
|
-
field_path =
|
167
|
-
f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
|
168
|
-
)
|
162
|
+
field_path = f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
|
169
163
|
logger.warning(
|
170
164
|
f"File extracted data has a different number of pages and structures! ({field_path})"
|
171
165
|
)
|
@@ -21,6 +21,9 @@
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
23
|
from fastapi import HTTPException
|
24
|
+
|
25
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
|
+
from nucliadb.train.generators.utils import batchify, get_paragraph
|
24
27
|
from nucliadb_protos.dataset_pb2 import (
|
25
28
|
Label,
|
26
29
|
ParagraphClassificationBatch,
|
@@ -29,9 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
29
32
|
)
|
30
33
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
31
34
|
|
32
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
33
|
-
from nucliadb.train.generators.utils import batchify, get_paragraph
|
34
|
-
|
35
35
|
|
36
36
|
def paragraph_classification_batch_generator(
|
37
37
|
kbid: str,
|
@@ -45,12 +45,8 @@ def paragraph_classification_batch_generator(
|
|
45
45
|
detail="Paragraph Classification should be of 1 labelset",
|
46
46
|
)
|
47
47
|
|
48
|
-
generator = generate_paragraph_classification_payloads(
|
49
|
-
|
50
|
-
)
|
51
|
-
batch_generator = batchify(
|
52
|
-
generator, trainset.batch_size, ParagraphClassificationBatch
|
53
|
-
)
|
48
|
+
generator = generate_paragraph_classification_payloads(kbid, trainset, node, shard_replica_id)
|
49
|
+
batch_generator = batchify(generator, trainset.batch_size, ParagraphClassificationBatch)
|
54
50
|
return batch_generator
|
55
51
|
|
56
52
|
|
@@ -20,6 +20,10 @@
|
|
20
20
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.train import logger
|
26
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
23
27
|
from nucliadb_protos.dataset_pb2 import (
|
24
28
|
ParagraphStreamingBatch,
|
25
29
|
ParagraphStreamItem,
|
@@ -27,11 +31,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
27
31
|
)
|
28
32
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
29
33
|
|
30
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
31
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
32
|
-
from nucliadb.train import logger
|
33
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
34
|
-
|
35
34
|
|
36
35
|
def paragraph_streaming_batch_generator(
|
37
36
|
kbid: str,
|
@@ -39,9 +38,7 @@ def paragraph_streaming_batch_generator(
|
|
39
38
|
node: AbstractIndexNode,
|
40
39
|
shard_replica_id: str,
|
41
40
|
) -> AsyncGenerator[ParagraphStreamingBatch, None]:
|
42
|
-
generator = generate_paragraph_streaming_payloads(
|
43
|
-
kbid, trainset, node, shard_replica_id
|
44
|
-
)
|
41
|
+
generator = generate_paragraph_streaming_payloads(kbid, trainset, node, shard_replica_id)
|
45
42
|
batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
|
46
43
|
return batch_generator
|
47
44
|
|
@@ -68,7 +65,7 @@ async def generate_paragraph_streaming_payloads(
|
|
68
65
|
logger.error(f"{rid} does not exist on DB")
|
69
66
|
continue
|
70
67
|
|
71
|
-
field_type_int =
|
68
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
72
69
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
73
70
|
|
74
71
|
extracted_text = await field_obj.get_extracted_text()
|
@@ -20,6 +20,14 @@
|
|
20
20
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
24
|
+
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
|
25
|
+
from nucliadb.train import logger
|
26
|
+
from nucliadb.train.generators.utils import (
|
27
|
+
batchify,
|
28
|
+
get_paragraph,
|
29
|
+
get_resource_from_cache_or_db,
|
30
|
+
)
|
23
31
|
from nucliadb_protos.dataset_pb2 import (
|
24
32
|
QuestionAnswerStreamingBatch,
|
25
33
|
QuestionAnswerStreamItem,
|
@@ -32,15 +40,6 @@ from nucliadb_protos.resources_pb2 import (
|
|
32
40
|
QuestionAnswerAnnotation,
|
33
41
|
)
|
34
42
|
|
35
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
36
|
-
from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID, KB_REVERSE
|
37
|
-
from nucliadb.train import logger
|
38
|
-
from nucliadb.train.generators.utils import (
|
39
|
-
batchify,
|
40
|
-
get_paragraph,
|
41
|
-
get_resource_from_cache_or_db,
|
42
|
-
)
|
43
|
-
|
44
43
|
|
45
44
|
def question_answer_batch_generator(
|
46
45
|
kbid: str,
|
@@ -48,12 +47,8 @@ def question_answer_batch_generator(
|
|
48
47
|
node: AbstractIndexNode,
|
49
48
|
shard_replica_id: str,
|
50
49
|
) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
|
51
|
-
generator = generate_question_answer_streaming_payloads(
|
52
|
-
|
53
|
-
)
|
54
|
-
batch_generator = batchify(
|
55
|
-
generator, trainset.batch_size, QuestionAnswerStreamingBatch
|
56
|
-
)
|
50
|
+
generator = generate_question_answer_streaming_payloads(kbid, trainset, node, shard_replica_id)
|
51
|
+
batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
|
57
52
|
return batch_generator
|
58
53
|
|
59
54
|
|
@@ -90,14 +85,18 @@ async def generate_question_answer_streaming_payloads(
|
|
90
85
|
item.cancelled_by_user = qa_annotation_pb.cancelled_by_user
|
91
86
|
yield item
|
92
87
|
|
93
|
-
field_type_int =
|
88
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
94
89
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
95
90
|
|
96
91
|
question_answers_pb = await field_obj.get_question_answers()
|
97
92
|
if question_answers_pb is not None:
|
98
|
-
for question_answer_pb in question_answers_pb.question_answer:
|
93
|
+
for question_answer_pb in question_answers_pb.question_answers.question_answer:
|
99
94
|
async for item in iter_stream_items(kbid, question_answer_pb):
|
100
95
|
yield item
|
96
|
+
for question_answer_pb in question_answers_pb.split_question_answers.values():
|
97
|
+
for split_question_answer_pb in question_answer_pb.question_answer:
|
98
|
+
async for item in iter_stream_items(kbid, split_question_answer_pb):
|
99
|
+
yield item
|
101
100
|
|
102
101
|
|
103
102
|
async def iter_stream_items(
|
@@ -109,7 +108,7 @@ async def iter_stream_items(
|
|
109
108
|
for paragraph_id in question_pb.ids_paragraphs:
|
110
109
|
try:
|
111
110
|
text = await get_paragraph(kbid, paragraph_id)
|
112
|
-
except Exception as exc: # pragma:
|
111
|
+
except Exception as exc: # pragma: no cover
|
113
112
|
logger.warning(
|
114
113
|
"Question paragraph couldn't be fetched while streaming Q&A",
|
115
114
|
extra={"kbid": kbid, "paragraph_id": paragraph_id},
|
@@ -128,7 +127,7 @@ async def iter_stream_items(
|
|
128
127
|
for paragraph_id in answer_pb.ids_paragraphs:
|
129
128
|
try:
|
130
129
|
text = await get_paragraph(kbid, paragraph_id)
|
131
|
-
except Exception as exc: # pragma:
|
130
|
+
except Exception as exc: # pragma: no cover
|
132
131
|
logger.warning(
|
133
132
|
"Answer paragraph couldn't be fetched while streaming Q&A",
|
134
133
|
extra={"kbid": kbid, "paragraph_id": paragraph_id},
|
@@ -141,4 +140,4 @@ async def iter_stream_items(
|
|
141
140
|
|
142
141
|
|
143
142
|
def is_same_field(field: FieldID, field_id: str, field_type: str) -> bool:
|
144
|
-
return field.field == field_id and
|
143
|
+
return field.field == field_id and FIELD_TYPE_PB_TO_STR[field.field_type] == field_type
|
@@ -21,6 +21,11 @@
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
|
23
23
|
from fastapi import HTTPException
|
24
|
+
|
25
|
+
from nucliadb.common.cluster.base import AbstractIndexNode
|
26
|
+
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
|
27
|
+
from nucliadb.train import logger
|
28
|
+
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
24
29
|
from nucliadb_protos.dataset_pb2 import (
|
25
30
|
Label,
|
26
31
|
MultipleTextSameLabels,
|
@@ -29,11 +34,6 @@ from nucliadb_protos.dataset_pb2 import (
|
|
29
34
|
)
|
30
35
|
from nucliadb_protos.nodereader_pb2 import StreamRequest
|
31
36
|
|
32
|
-
from nucliadb.common.cluster.base import AbstractIndexNode
|
33
|
-
from nucliadb.ingest.orm.resource import KB_REVERSE
|
34
|
-
from nucliadb.train import logger
|
35
|
-
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
|
36
|
-
|
37
37
|
|
38
38
|
def sentence_classification_batch_generator(
|
39
39
|
kbid: str,
|
@@ -47,12 +47,8 @@ def sentence_classification_batch_generator(
|
|
47
47
|
detail="Sentence Classification should be at least of 1 labelset",
|
48
48
|
)
|
49
49
|
|
50
|
-
generator = generate_sentence_classification_payloads(
|
51
|
-
|
52
|
-
)
|
53
|
-
batch_generator = batchify(
|
54
|
-
generator, trainset.batch_size, SentenceClassificationBatch
|
55
|
-
)
|
50
|
+
generator = generate_sentence_classification_payloads(kbid, trainset, node, shard_replica_id)
|
51
|
+
batch_generator = batchify(generator, trainset.batch_size, SentenceClassificationBatch)
|
56
52
|
return batch_generator
|
57
53
|
|
58
54
|
|
@@ -107,14 +103,12 @@ async def get_sentences(kbid: str, result: str) -> list[str]:
|
|
107
103
|
logger.error(f"{rid} does not exist on DB")
|
108
104
|
return []
|
109
105
|
|
110
|
-
field_type_int =
|
106
|
+
field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
|
111
107
|
field_obj = await orm_resource.get_field(field, field_type_int, load=False)
|
112
108
|
extracted_text = await field_obj.get_extracted_text()
|
113
109
|
field_metadata = await field_obj.get_field_metadata()
|
114
110
|
if extracted_text is None:
|
115
|
-
logger.warning(
|
116
|
-
f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
|
117
|
-
)
|
111
|
+
logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
|
118
112
|
return []
|
119
113
|
|
120
114
|
splitted_texts = []
|