nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0002_rollover_shards.py +1 -2
- migrations/0003_allfields_key.py +2 -37
- migrations/0004_rollover_shards.py +1 -2
- migrations/0005_rollover_shards.py +1 -2
- migrations/0006_rollover_shards.py +2 -4
- migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
- migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
- migrations/0010_fix_corrupt_indexes.py +11 -12
- migrations/0011_materialize_labelset_ids.py +2 -18
- migrations/0012_rollover_shards.py +6 -12
- migrations/0013_rollover_shards.py +2 -4
- migrations/0014_rollover_shards.py +5 -7
- migrations/0015_targeted_rollover.py +6 -12
- migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
- migrations/0017_multiple_writable_shards.py +3 -6
- migrations/0018_purge_orphan_kbslugs.py +59 -0
- migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
- migrations/0020_drain_nodes_from_cluster.py +83 -0
- nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
- nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
- migrations/0023_backfill_pg_catalog.py +80 -0
- migrations/0025_assign_models_to_kbs_v2.py +113 -0
- migrations/0026_fix_high_cardinality_content_types.py +61 -0
- migrations/0027_rollover_texts3.py +73 -0
- nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
- migrations/pg/0002_catalog.py +42 -0
- nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
- nucliadb/common/cluster/base.py +41 -24
- nucliadb/common/cluster/discovery/base.py +6 -14
- nucliadb/common/cluster/discovery/k8s.py +9 -19
- nucliadb/common/cluster/discovery/manual.py +1 -3
- nucliadb/common/cluster/discovery/single.py +1 -2
- nucliadb/common/cluster/discovery/utils.py +1 -3
- nucliadb/common/cluster/grpc_node_dummy.py +11 -16
- nucliadb/common/cluster/index_node.py +10 -19
- nucliadb/common/cluster/manager.py +223 -102
- nucliadb/common/cluster/rebalance.py +42 -37
- nucliadb/common/cluster/rollover.py +377 -204
- nucliadb/common/cluster/settings.py +16 -9
- nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
- nucliadb/common/cluster/standalone/index_node.py +4 -11
- nucliadb/common/cluster/standalone/service.py +2 -6
- nucliadb/common/cluster/standalone/utils.py +9 -6
- nucliadb/common/cluster/utils.py +43 -29
- nucliadb/common/constants.py +20 -0
- nucliadb/common/context/__init__.py +6 -4
- nucliadb/common/context/fastapi.py +8 -5
- nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
- nucliadb/common/datamanagers/__init__.py +24 -5
- nucliadb/common/datamanagers/atomic.py +102 -0
- nucliadb/common/datamanagers/cluster.py +5 -5
- nucliadb/common/datamanagers/entities.py +6 -16
- nucliadb/common/datamanagers/fields.py +84 -0
- nucliadb/common/datamanagers/kb.py +101 -24
- nucliadb/common/datamanagers/labels.py +26 -56
- nucliadb/common/datamanagers/processing.py +2 -6
- nucliadb/common/datamanagers/resources.py +214 -117
- nucliadb/common/datamanagers/rollover.py +77 -16
- nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
- nucliadb/common/datamanagers/utils.py +19 -11
- nucliadb/common/datamanagers/vectorsets.py +110 -0
- nucliadb/common/external_index_providers/base.py +257 -0
- nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
- nucliadb/common/external_index_providers/manager.py +101 -0
- nucliadb/common/external_index_providers/pinecone.py +933 -0
- nucliadb/common/external_index_providers/settings.py +52 -0
- nucliadb/common/http_clients/auth.py +3 -6
- nucliadb/common/http_clients/processing.py +6 -11
- nucliadb/common/http_clients/utils.py +1 -3
- nucliadb/common/ids.py +240 -0
- nucliadb/common/locking.py +43 -13
- nucliadb/common/maindb/driver.py +11 -35
- nucliadb/common/maindb/exceptions.py +6 -6
- nucliadb/common/maindb/local.py +22 -9
- nucliadb/common/maindb/pg.py +206 -111
- nucliadb/common/maindb/utils.py +13 -44
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +260 -0
- nucliadb/export_import/datamanager.py +25 -19
- nucliadb/export_import/exceptions.py +8 -0
- nucliadb/export_import/exporter.py +20 -7
- nucliadb/export_import/importer.py +6 -11
- nucliadb/export_import/models.py +5 -5
- nucliadb/export_import/tasks.py +4 -4
- nucliadb/export_import/utils.py +94 -54
- nucliadb/health.py +1 -3
- nucliadb/ingest/app.py +15 -11
- nucliadb/ingest/consumer/auditing.py +30 -147
- nucliadb/ingest/consumer/consumer.py +96 -52
- nucliadb/ingest/consumer/materializer.py +10 -12
- nucliadb/ingest/consumer/pull.py +12 -27
- nucliadb/ingest/consumer/service.py +20 -19
- nucliadb/ingest/consumer/shard_creator.py +7 -14
- nucliadb/ingest/consumer/utils.py +1 -3
- nucliadb/ingest/fields/base.py +139 -188
- nucliadb/ingest/fields/conversation.py +18 -5
- nucliadb/ingest/fields/exceptions.py +1 -4
- nucliadb/ingest/fields/file.py +7 -25
- nucliadb/ingest/fields/link.py +11 -16
- nucliadb/ingest/fields/text.py +9 -4
- nucliadb/ingest/orm/brain.py +255 -262
- nucliadb/ingest/orm/broker_message.py +181 -0
- nucliadb/ingest/orm/entities.py +36 -51
- nucliadb/ingest/orm/exceptions.py +12 -0
- nucliadb/ingest/orm/knowledgebox.py +334 -278
- nucliadb/ingest/orm/processor/__init__.py +2 -697
- nucliadb/ingest/orm/processor/auditing.py +117 -0
- nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
- nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
- nucliadb/ingest/orm/processor/processor.py +752 -0
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
- nucliadb/ingest/orm/resource.py +280 -520
- nucliadb/ingest/orm/utils.py +25 -31
- nucliadb/ingest/partitions.py +3 -9
- nucliadb/ingest/processing.py +76 -81
- nucliadb/ingest/py.typed +0 -0
- nucliadb/ingest/serialize.py +37 -173
- nucliadb/ingest/service/__init__.py +1 -3
- nucliadb/ingest/service/writer.py +186 -577
- nucliadb/ingest/settings.py +13 -22
- nucliadb/ingest/utils.py +3 -6
- nucliadb/learning_proxy.py +264 -51
- nucliadb/metrics_exporter.py +30 -19
- nucliadb/middleware/__init__.py +1 -3
- nucliadb/migrator/command.py +1 -3
- nucliadb/migrator/datamanager.py +13 -13
- nucliadb/migrator/migrator.py +57 -37
- nucliadb/migrator/settings.py +2 -1
- nucliadb/migrator/utils.py +18 -10
- nucliadb/purge/__init__.py +139 -33
- nucliadb/purge/orphan_shards.py +7 -13
- nucliadb/reader/__init__.py +1 -3
- nucliadb/reader/api/models.py +3 -14
- nucliadb/reader/api/v1/__init__.py +0 -1
- nucliadb/reader/api/v1/download.py +27 -94
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +13 -13
- nucliadb/reader/api/v1/learning_config.py +8 -12
- nucliadb/reader/api/v1/resource.py +67 -93
- nucliadb/reader/api/v1/services.py +70 -125
- nucliadb/reader/app.py +16 -46
- nucliadb/reader/lifecycle.py +18 -4
- nucliadb/reader/py.typed +0 -0
- nucliadb/reader/reader/notifications.py +10 -31
- nucliadb/search/__init__.py +1 -3
- nucliadb/search/api/v1/__init__.py +2 -2
- nucliadb/search/api/v1/ask.py +112 -0
- nucliadb/search/api/v1/catalog.py +184 -0
- nucliadb/search/api/v1/feedback.py +17 -25
- nucliadb/search/api/v1/find.py +41 -41
- nucliadb/search/api/v1/knowledgebox.py +90 -62
- nucliadb/search/api/v1/predict_proxy.py +2 -2
- nucliadb/search/api/v1/resource/ask.py +66 -117
- nucliadb/search/api/v1/resource/search.py +51 -72
- nucliadb/search/api/v1/router.py +1 -0
- nucliadb/search/api/v1/search.py +50 -197
- nucliadb/search/api/v1/suggest.py +40 -54
- nucliadb/search/api/v1/summarize.py +9 -5
- nucliadb/search/api/v1/utils.py +2 -1
- nucliadb/search/app.py +16 -48
- nucliadb/search/lifecycle.py +10 -3
- nucliadb/search/predict.py +176 -188
- nucliadb/search/py.typed +0 -0
- nucliadb/search/requesters/utils.py +41 -63
- nucliadb/search/search/cache.py +149 -20
- nucliadb/search/search/chat/ask.py +918 -0
- nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
- nucliadb/search/search/chat/images.py +41 -17
- nucliadb/search/search/chat/prompt.py +851 -282
- nucliadb/search/search/chat/query.py +274 -267
- nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
- nucliadb/search/search/fetch.py +43 -36
- nucliadb/search/search/filters.py +9 -15
- nucliadb/search/search/find.py +214 -54
- nucliadb/search/search/find_merge.py +408 -391
- nucliadb/search/search/hydrator.py +191 -0
- nucliadb/search/search/merge.py +198 -234
- nucliadb/search/search/metrics.py +73 -2
- nucliadb/search/search/paragraphs.py +64 -106
- nucliadb/search/search/pgcatalog.py +233 -0
- nucliadb/search/search/predict_proxy.py +1 -1
- nucliadb/search/search/query.py +386 -257
- nucliadb/search/search/query_parser/exceptions.py +22 -0
- nucliadb/search/search/query_parser/models.py +101 -0
- nucliadb/search/search/query_parser/parser.py +183 -0
- nucliadb/search/search/rank_fusion.py +204 -0
- nucliadb/search/search/rerankers.py +270 -0
- nucliadb/search/search/shards.py +4 -38
- nucliadb/search/search/summarize.py +14 -18
- nucliadb/search/search/utils.py +27 -4
- nucliadb/search/settings.py +15 -1
- nucliadb/standalone/api_router.py +4 -10
- nucliadb/standalone/app.py +17 -14
- nucliadb/standalone/auth.py +7 -21
- nucliadb/standalone/config.py +9 -12
- nucliadb/standalone/introspect.py +5 -5
- nucliadb/standalone/lifecycle.py +26 -25
- nucliadb/standalone/migrations.py +58 -0
- nucliadb/standalone/purge.py +9 -8
- nucliadb/standalone/py.typed +0 -0
- nucliadb/standalone/run.py +25 -18
- nucliadb/standalone/settings.py +10 -14
- nucliadb/standalone/versions.py +15 -5
- nucliadb/tasks/consumer.py +8 -12
- nucliadb/tasks/producer.py +7 -6
- nucliadb/tests/config.py +53 -0
- nucliadb/train/__init__.py +1 -3
- nucliadb/train/api/utils.py +1 -2
- nucliadb/train/api/v1/shards.py +2 -2
- nucliadb/train/api/v1/trainset.py +4 -6
- nucliadb/train/app.py +14 -47
- nucliadb/train/generator.py +10 -19
- nucliadb/train/generators/field_classifier.py +7 -19
- nucliadb/train/generators/field_streaming.py +156 -0
- nucliadb/train/generators/image_classifier.py +12 -18
- nucliadb/train/generators/paragraph_classifier.py +5 -9
- nucliadb/train/generators/paragraph_streaming.py +6 -9
- nucliadb/train/generators/question_answer_streaming.py +19 -20
- nucliadb/train/generators/sentence_classifier.py +9 -15
- nucliadb/train/generators/token_classifier.py +45 -36
- nucliadb/train/generators/utils.py +14 -18
- nucliadb/train/lifecycle.py +7 -3
- nucliadb/train/nodes.py +23 -32
- nucliadb/train/py.typed +0 -0
- nucliadb/train/servicer.py +13 -21
- nucliadb/train/settings.py +2 -6
- nucliadb/train/types.py +13 -10
- nucliadb/train/upload.py +3 -6
- nucliadb/train/uploader.py +20 -25
- nucliadb/train/utils.py +1 -1
- nucliadb/writer/__init__.py +1 -3
- nucliadb/writer/api/constants.py +0 -5
- nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
- nucliadb/writer/api/v1/export_import.py +102 -49
- nucliadb/writer/api/v1/field.py +196 -620
- nucliadb/writer/api/v1/knowledgebox.py +221 -71
- nucliadb/writer/api/v1/learning_config.py +2 -2
- nucliadb/writer/api/v1/resource.py +114 -216
- nucliadb/writer/api/v1/services.py +64 -132
- nucliadb/writer/api/v1/slug.py +61 -0
- nucliadb/writer/api/v1/transaction.py +67 -0
- nucliadb/writer/api/v1/upload.py +184 -215
- nucliadb/writer/app.py +11 -61
- nucliadb/writer/back_pressure.py +62 -43
- nucliadb/writer/exceptions.py +0 -4
- nucliadb/writer/lifecycle.py +21 -15
- nucliadb/writer/py.typed +0 -0
- nucliadb/writer/resource/audit.py +2 -1
- nucliadb/writer/resource/basic.py +48 -62
- nucliadb/writer/resource/field.py +45 -135
- nucliadb/writer/resource/origin.py +1 -2
- nucliadb/writer/settings.py +14 -5
- nucliadb/writer/tus/__init__.py +17 -15
- nucliadb/writer/tus/azure.py +111 -0
- nucliadb/writer/tus/dm.py +17 -5
- nucliadb/writer/tus/exceptions.py +1 -3
- nucliadb/writer/tus/gcs.py +56 -84
- nucliadb/writer/tus/local.py +21 -37
- nucliadb/writer/tus/s3.py +28 -68
- nucliadb/writer/tus/storage.py +5 -56
- nucliadb/writer/vectorsets.py +125 -0
- nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
- nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
- nucliadb/common/maindb/redis.py +0 -194
- nucliadb/common/maindb/tikv.py +0 -412
- nucliadb/ingest/fields/layout.py +0 -58
- nucliadb/ingest/tests/conftest.py +0 -30
- nucliadb/ingest/tests/fixtures.py +0 -771
- nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
- nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
- nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
- nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
- nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
- nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
- nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
- nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
- nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
- nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
- nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
- nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
- nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
- nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
- nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
- nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
- nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
- nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
- nucliadb/ingest/tests/unit/test_partitions.py +0 -40
- nucliadb/ingest/tests/unit/test_processing.py +0 -171
- nucliadb/middleware/transaction.py +0 -117
- nucliadb/reader/api/v1/learning_collector.py +0 -63
- nucliadb/reader/tests/__init__.py +0 -19
- nucliadb/reader/tests/conftest.py +0 -31
- nucliadb/reader/tests/fixtures.py +0 -136
- nucliadb/reader/tests/test_list_resources.py +0 -75
- nucliadb/reader/tests/test_reader_file_download.py +0 -273
- nucliadb/reader/tests/test_reader_resource.py +0 -379
- nucliadb/reader/tests/test_reader_resource_field.py +0 -219
- nucliadb/search/api/v1/chat.py +0 -258
- nucliadb/search/api/v1/resource/chat.py +0 -94
- nucliadb/search/tests/__init__.py +0 -19
- nucliadb/search/tests/conftest.py +0 -33
- nucliadb/search/tests/fixtures.py +0 -199
- nucliadb/search/tests/node.py +0 -465
- nucliadb/search/tests/unit/__init__.py +0 -18
- nucliadb/search/tests/unit/api/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
- nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
- nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
- nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
- nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
- nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
- nucliadb/search/tests/unit/search/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
- nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
- nucliadb/search/tests/unit/search/search/__init__.py +0 -19
- nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
- nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
- nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
- nucliadb/search/tests/unit/search/test_fetch.py +0 -108
- nucliadb/search/tests/unit/search/test_filters.py +0 -125
- nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
- nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
- nucliadb/search/tests/unit/search/test_query.py +0 -201
- nucliadb/search/tests/unit/test_app.py +0 -79
- nucliadb/search/tests/unit/test_find_merge.py +0 -112
- nucliadb/search/tests/unit/test_merge.py +0 -34
- nucliadb/search/tests/unit/test_predict.py +0 -584
- nucliadb/standalone/tests/__init__.py +0 -19
- nucliadb/standalone/tests/conftest.py +0 -33
- nucliadb/standalone/tests/fixtures.py +0 -38
- nucliadb/standalone/tests/unit/__init__.py +0 -18
- nucliadb/standalone/tests/unit/test_api_router.py +0 -61
- nucliadb/standalone/tests/unit/test_auth.py +0 -169
- nucliadb/standalone/tests/unit/test_introspect.py +0 -35
- nucliadb/standalone/tests/unit/test_versions.py +0 -68
- nucliadb/tests/benchmarks/__init__.py +0 -19
- nucliadb/tests/benchmarks/test_search.py +0 -99
- nucliadb/tests/conftest.py +0 -32
- nucliadb/tests/fixtures.py +0 -736
- nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
- nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
- nucliadb/tests/migrations/__init__.py +0 -19
- nucliadb/tests/migrations/test_migration_0017.py +0 -80
- nucliadb/tests/tikv.py +0 -240
- nucliadb/tests/unit/__init__.py +0 -19
- nucliadb/tests/unit/common/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
- nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
- nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
- nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
- nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
- nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
- nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
- nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
- nucliadb/tests/unit/common/maindb/__init__.py +0 -18
- nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
- nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
- nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
- nucliadb/tests/unit/common/test_context.py +0 -36
- nucliadb/tests/unit/export_import/__init__.py +0 -19
- nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
- nucliadb/tests/unit/export_import/test_utils.py +0 -294
- nucliadb/tests/unit/migrator/__init__.py +0 -19
- nucliadb/tests/unit/migrator/test_migrator.py +0 -87
- nucliadb/tests/unit/tasks/__init__.py +0 -19
- nucliadb/tests/unit/tasks/conftest.py +0 -42
- nucliadb/tests/unit/tasks/test_consumer.py +0 -93
- nucliadb/tests/unit/tasks/test_producer.py +0 -95
- nucliadb/tests/unit/tasks/test_tasks.py +0 -60
- nucliadb/tests/unit/test_field_ids.py +0 -49
- nucliadb/tests/unit/test_health.py +0 -84
- nucliadb/tests/unit/test_kb_slugs.py +0 -54
- nucliadb/tests/unit/test_learning_proxy.py +0 -252
- nucliadb/tests/unit/test_metrics_exporter.py +0 -77
- nucliadb/tests/unit/test_purge.py +0 -138
- nucliadb/tests/utils/__init__.py +0 -74
- nucliadb/tests/utils/aiohttp_session.py +0 -44
- nucliadb/tests/utils/broker_messages/__init__.py +0 -167
- nucliadb/tests/utils/broker_messages/fields.py +0 -181
- nucliadb/tests/utils/broker_messages/helpers.py +0 -33
- nucliadb/tests/utils/entities.py +0 -78
- nucliadb/train/api/v1/check.py +0 -60
- nucliadb/train/tests/__init__.py +0 -19
- nucliadb/train/tests/conftest.py +0 -29
- nucliadb/train/tests/fixtures.py +0 -342
- nucliadb/train/tests/test_field_classification.py +0 -122
- nucliadb/train/tests/test_get_entities.py +0 -80
- nucliadb/train/tests/test_get_info.py +0 -51
- nucliadb/train/tests/test_get_ontology.py +0 -34
- nucliadb/train/tests/test_get_ontology_count.py +0 -63
- nucliadb/train/tests/test_image_classification.py +0 -222
- nucliadb/train/tests/test_list_fields.py +0 -39
- nucliadb/train/tests/test_list_paragraphs.py +0 -73
- nucliadb/train/tests/test_list_resources.py +0 -39
- nucliadb/train/tests/test_list_sentences.py +0 -71
- nucliadb/train/tests/test_paragraph_classification.py +0 -123
- nucliadb/train/tests/test_paragraph_streaming.py +0 -118
- nucliadb/train/tests/test_question_answer_streaming.py +0 -239
- nucliadb/train/tests/test_sentence_classification.py +0 -143
- nucliadb/train/tests/test_token_classification.py +0 -136
- nucliadb/train/tests/utils.py +0 -108
- nucliadb/writer/layouts/__init__.py +0 -51
- nucliadb/writer/layouts/v1.py +0 -59
- nucliadb/writer/resource/vectors.py +0 -120
- nucliadb/writer/tests/__init__.py +0 -19
- nucliadb/writer/tests/conftest.py +0 -31
- nucliadb/writer/tests/fixtures.py +0 -192
- nucliadb/writer/tests/test_fields.py +0 -486
- nucliadb/writer/tests/test_files.py +0 -743
- nucliadb/writer/tests/test_knowledgebox.py +0 -49
- nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
- nucliadb/writer/tests/test_resources.py +0 -546
- nucliadb/writer/tests/test_service.py +0 -137
- nucliadb/writer/tests/test_tus.py +0 -203
- nucliadb/writer/tests/utils.py +0 -35
- nucliadb/writer/tus/pg.py +0 -125
- nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
- nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
- {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
- /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
- /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
- /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
- {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
nucliadb/tests/config.py
ADDED
@@ -0,0 +1,53 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+# This is a test fixture which is useful outside nucliadb tests. In particular
+# it is used for the testbed. Keeping it under src so it can be imported outside
+def reset_config():
+    from nucliadb.common.cluster import settings as cluster_settings
+    from nucliadb.ingest import settings as ingest_settings
+    from nucliadb.train import settings as train_settings
+    from nucliadb.writer import settings as writer_settings
+    from nucliadb_utils import settings as utils_settings
+    from nucliadb_utils.cache import settings as cache_settings
+
+    all_settings = [
+        cluster_settings.settings,
+        ingest_settings.settings,
+        train_settings.settings,
+        writer_settings.settings,
+        cache_settings.settings,
+        utils_settings.audit_settings,
+        utils_settings.http_settings,
+        utils_settings.indexing_settings,
+        utils_settings.nuclia_settings,
+        utils_settings.nucliadb_settings,
+        utils_settings.storage_settings,
+        utils_settings.transaction_settings,
+    ]
+    for settings in all_settings:
+        defaults = type(settings)()
+        for attr, _value in settings:
+            default_value = getattr(defaults, attr)
+            setattr(settings, attr, default_value)
+
+    from nucliadb.common.cluster import manager
+
+    manager.INDEX_NODES.clear()
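
The reset_config() helper above is meant to be imported by test suites outside this package (e.g. a testbed). A minimal sketch of how it could be wired into pytest follows; the fixture name and autouse choice are assumptions, not part of the package:

# Hypothetical conftest.py in an external test suite; only reset_config()
# itself comes from nucliadb.tests.config.
import pytest

from nucliadb.tests.config import reset_config


@pytest.fixture(autouse=True)
def clean_nucliadb_settings():
    # Let the test mutate whatever settings it needs...
    yield
    # ...then restore every settings object to its defaults and clear the
    # in-memory index node registry.
    reset_config()
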
nucliadb/train/__init__.py
CHANGED
@@ -28,9 +28,7 @@ logger = logging.getLogger(SERVICE_NAME)
 class EndpointFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
         return (
-            record.args is not None
-            and len(record.args) >= 3
-            and record.args[2] not in ("/", "/metrics")  # type: ignore
+            record.args is not None and len(record.args) >= 3 and record.args[2] not in ("/", "/metrics")  # type: ignore
         )
 
 
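
The EndpointFilter above drops access-log records for "/" and "/metrics". How it gets attached is not shown in this diff; a typical wiring, assuming the standard uvicorn access logger, would be:

# Illustrative only; the logger name is an assumption about the deployment.
import logging

from nucliadb.train import EndpointFilter

logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
# Requests to "/" and "/metrics" no longer emit access-log lines.
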
nucliadb/train/api/utils.py
CHANGED
@@ -21,9 +21,8 @@
 
 from typing import Optional
 
-from nucliadb_protos.dataset_pb2 import TrainSet
-
 from nucliadb.train.utils import get_shard_manager
+from nucliadb_protos.dataset_pb2 import TrainSet
 
 
 async def get_kb_partitions(kbid: str, prefix: Optional[str] = None):
nucliadb/train/api/v1/shards.py
CHANGED
@@ -21,7 +21,7 @@
 
 from fastapi import HTTPException, Request
 from fastapi.responses import StreamingResponse
-from fastapi_versioning import version
+from fastapi_versioning import version
 
 from nucliadb.train.api.utils import get_kb_partitions, get_train
 from nucliadb.train.api.v1.router import KB_PREFIX, api
@@ -34,7 +34,7 @@ from nucliadb_utils.authentication import requires_one
     f"/{KB_PREFIX}/{{kbid}}/trainset/{{shard}}",
     tags=["Object Response"],
     status_code=200,
-
+    summary="Return Train Stream",
 )
 @requires_one([NucliaDBRoles.READER])
 @version(1)
nucliadb/train/api/v1/trainset.py
CHANGED
@@ -21,7 +21,7 @@
 from typing import Optional
 
 from fastapi import Request
-from fastapi_versioning import version
+from fastapi_versioning import version
 
 from nucliadb.train.api.utils import get_kb_partitions
 from nucliadb.train.api.v1.router import KB_PREFIX, api
@@ -34,7 +34,7 @@ from nucliadb_utils.authentication import requires_one
     f"/{KB_PREFIX}/{{kbid}}/trainset",
     tags=["Train"],
     status_code=200,
-
+    summary="Return Train call",
     response_model=TrainSetPartitions,
 )
 @requires_one([NucliaDBRoles.READER])
@@ -47,14 +47,12 @@ async def get_partitions_all(request: Request, kbid: str) -> TrainSetPartitions:
     f"/{KB_PREFIX}/{{kbid}}/trainset/{{prefix}}",
     tags=["Train"],
     status_code=200,
-
+    summary="Return Train call",
     response_model=TrainSetPartitions,
 )
 @requires_one([NucliaDBRoles.READER])
 @version(1)
-async def get_partitions_prefix(
-    request: Request, kbid: str, prefix: str
-) -> TrainSetPartitions:
+async def get_partitions_prefix(request: Request, kbid: str, prefix: str) -> TrainSetPartitions:
     return await get_partitions(kbid, prefix=prefix)
 
 
nucliadb/train/app.py
CHANGED
@@ -17,77 +17,44 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import
+import importlib.metadata
+
 from fastapi import FastAPI
-from fastapi.responses import JSONResponse
 from starlette.middleware import Middleware
 from starlette.middleware.authentication import AuthenticationMiddleware
-from starlette.middleware.cors import CORSMiddleware
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import HTMLResponse
 
-from nucliadb.middleware.transaction import ReadOnlyTransactionMiddleware
 from nucliadb.train import API_PREFIX
 from nucliadb.train.api.v1.router import api
-from nucliadb.train.lifecycle import
+from nucliadb.train.lifecycle import lifespan
 from nucliadb_telemetry import errors
-from
+from nucliadb_telemetry.fastapi.utils import (
+    client_disconnect_handler,
+    global_exception_handler,
+)
+from nucliadb_utils.audit.stream import AuditMiddleware
 from nucliadb_utils.authentication import NucliaCloudAuthenticationBackend
 from nucliadb_utils.fastapi.openapi import extend_openapi
 from nucliadb_utils.fastapi.versioning import VersionedFastAPI
-from nucliadb_utils.settings import
-from nucliadb_utils.utilities import
+from nucliadb_utils.settings import running_settings
+from nucliadb_utils.utilities import get_audit
 
 middleware = []
-
-if has_feature(const.Features.CORS_MIDDLEWARE, default=False):
-    middleware.append(
-        Middleware(
-            CORSMiddleware,
-            allow_origins=http_settings.cors_origins,
-            allow_methods=["*"],
-            # Authorization will be exluded from * in the future, (CORS non-wildcard request-header).
-            # Browsers already showing deprecation notices, so it needs to be specified explicitly
-            allow_headers=["*", "Authorization"],
-        )
-    )
-
 middleware.extend(
     [
-        Middleware(
-
-        ),
-        Middleware(ReadOnlyTransactionMiddleware),
+        Middleware(AuthenticationMiddleware, backend=NucliaCloudAuthenticationBackend()),
+        Middleware(AuditMiddleware, audit_utility_getter=get_audit),
     ]
 )
 
-errors.setup_error_handling(
-
-
-on_startup = [initialize]
-on_shutdown = [finalize]
-
-
-async def global_exception_handler(request: Request, exc: Exception):
-    errors.capture_exception(exc)
-    return JSONResponse(
-        status_code=500,
-        content={"detail": "Something went wrong, please contact your administrator"},
-    )
-
-
-async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
-    return JSONResponse(
-        status_code=200,
-        content={"detail": "Client disconnected while an operation was in course"},
-    )
+errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
 
 
 fastapi_settings = dict(
     debug=running_settings.debug,
     middleware=middleware,
-
-    on_shutdown=on_shutdown,
+    lifespan=lifespan,
     exception_handlers={
         Exception: global_exception_handler,
         ClientDisconnect: client_disconnect_handler,
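
The startup/shutdown hooks (on_startup=[initialize], on_shutdown=[finalize]) are replaced by a lifespan context manager imported from nucliadb.train.lifecycle. A sketch of that migration pattern, with placeholder initialize/finalize functions standing in for the real lifecycle code:

# Sketch only; the real implementation lives in nucliadb/train/lifecycle.py.
from contextlib import asynccontextmanager

from fastapi import FastAPI


async def initialize() -> None:  # placeholder for the former on_startup hook
    ...


async def finalize() -> None:  # placeholder for the former on_shutdown hook
    ...


@asynccontextmanager
async def lifespan(app: FastAPI):
    await initialize()  # runs once before the app starts serving
    yield               # the application handles requests here
    await finalize()    # runs once on shutdown


app = FastAPI(lifespan=lifespan)
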
nucliadb/train/generator.py
CHANGED
@@ -21,11 +21,11 @@
 from typing import AsyncIterator, Optional
 
 from fastapi import HTTPException
-from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
 
 from nucliadb.train.generators.field_classifier import (
     field_classification_batch_generator,
 )
+from nucliadb.train.generators.field_streaming import field_streaming_batch_generator
 from nucliadb.train.generators.image_classifier import (
     image_classification_batch_generator,
 )
@@ -46,6 +46,7 @@ from nucliadb.train.generators.token_classifier import (
 )
 from nucliadb.train.types import TrainBatch
 from nucliadb.train.utils import get_shard_manager
+from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
 
 
 async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
@@ -59,34 +60,24 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
     batch_generator: Optional[AsyncIterator[TrainBatch]] = None
 
     if trainset.type == TaskType.FIELD_CLASSIFICATION:
-        batch_generator = field_classification_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = field_classification_batch_generator(kbid, trainset, node, shard_replica_id)
     elif trainset.type == TaskType.IMAGE_CLASSIFICATION:
-        batch_generator = image_classification_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = image_classification_batch_generator(kbid, trainset, node, shard_replica_id)
     elif trainset.type == TaskType.PARAGRAPH_CLASSIFICATION:
         batch_generator = paragraph_classification_batch_generator(
             kbid, trainset, node, shard_replica_id
         )
     elif trainset.type == TaskType.TOKEN_CLASSIFICATION:
-        batch_generator = token_classification_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = token_classification_batch_generator(kbid, trainset, node, shard_replica_id)
     elif trainset.type == TaskType.SENTENCE_CLASSIFICATION:
-        batch_generator = sentence_classification_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = sentence_classification_batch_generator(kbid, trainset, node, shard_replica_id)
     elif trainset.type == TaskType.PARAGRAPH_STREAMING:
-        batch_generator = paragraph_streaming_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = paragraph_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
 
     elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
-        batch_generator = question_answer_batch_generator(
-            kbid, trainset, node, shard_replica_id
-        )
+        batch_generator = question_answer_batch_generator(kbid, trainset, node, shard_replica_id)
+    elif trainset.type == TaskType.FIELD_STREAMING:
+        batch_generator = field_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
 
     if batch_generator is None:
         raise HTTPException(
nucliadb/train/generators/field_classifier.py
CHANGED
@@ -20,7 +20,10 @@
 
 from typing import AsyncGenerator
 
-from
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.train import logger
+from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
     FieldClassificationBatch,
     Label,
@@ -29,11 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
 )
 from nucliadb_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
-from nucliadb.ingest.orm.resource import KB_REVERSE
-from nucliadb.train import logger
-from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
-
 
 def field_classification_batch_generator(
     kbid: str,
@@ -41,15 +39,7 @@ def field_classification_batch_generator(
     node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[FieldClassificationBatch, None]:
-
-        raise HTTPException(
-            status_code=422,
-            detail="Paragraph Classification should be of 1 labelset",
-        )
-
-    generator = generate_field_classification_payloads(
-        kbid, trainset, node, shard_replica_id
-    )
+    generator = generate_field_classification_payloads(kbid, trainset, node, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
     return batch_generator
 
@@ -95,13 +85,11 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> str:
         logger.error(f"{rid} does not exist on DB")
         return ""
 
-    field_type_int =
+    field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
     field_obj = await orm_resource.get_field(field, field_type_int, load=False)
     extracted_text = await field_obj.get_extracted_text()
     if extracted_text is None:
-        logger.warning(
-            f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
-        )
+        logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
         return ""
 
     text = ""
nucliadb/train/generators/field_streaming.py
ADDED
@@ -0,0 +1,156 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+from typing import AsyncGenerator, Optional
+
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.train import logger
+from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
+from nucliadb_protos.dataset_pb2 import (
+    FieldSplitData,
+    FieldStreamingBatch,
+    TrainSet,
+)
+from nucliadb_protos.nodereader_pb2 import StreamRequest
+from nucliadb_protos.resources_pb2 import Basic, FieldComputedMetadata
+from nucliadb_protos.utils_pb2 import ExtractedText
+
+
+def field_streaming_batch_generator(
+    kbid: str,
+    trainset: TrainSet,
+    node: AbstractIndexNode,
+    shard_replica_id: str,
+) -> AsyncGenerator[FieldStreamingBatch, None]:
+    generator = generate_field_streaming_payloads(kbid, trainset, node, shard_replica_id)
+    batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
+    return batch_generator
+
+
+async def generate_field_streaming_payloads(
+    kbid: str,
+    trainset: TrainSet,
+    node: AbstractIndexNode,
+    shard_replica_id: str,
+) -> AsyncGenerator[FieldSplitData, None]:
+    # Query how many resources has each label
+    request = StreamRequest()
+    request.shard_id.id = shard_replica_id
+
+    for label in trainset.filter.labels:
+        request.filter.labels.append(f"/l/{label}")
+
+    for path in trainset.filter.paths:
+        request.filter.labels.append(f"/p/{path}")
+
+    for metadata in trainset.filter.metadata:
+        request.filter.labels.append(f"/m/{metadata}")
+
+    for entity in trainset.filter.entities:
+        request.filter.labels.append(f"/e/{entity}")
+
+    for field in trainset.filter.fields:
+        request.filter.labels.append(f"/f/{field}")
+
+    for status in trainset.filter.status:
+        request.filter.labels.append(f"/n/s/{status}")
+    total = 0
+
+    async for document_item in node.stream_get_fields(request):
+        text_labels = []
+        for label in document_item.labels:
+            text_labels.append(label)
+
+        field_id = f"{document_item.uuid}{document_item.field}"
+        total += 1
+
+        field_parts = document_item.field.split("/")
+        if len(field_parts) == 3:
+            _, field_type, field = field_parts
+            split = "0"
+        elif len(field_parts) == 4:
+            _, field_type, field, split = field_parts
+        else:
+            raise Exception(f"Invalid field definition {document_item.field}")
+
+        tl = FieldSplitData()
+        rid, field_type, field = field_id.split("/")
+        tl.rid = document_item.uuid
+        tl.field = field
+        tl.field_type = field_type
+        tl.split = split
+        extracted = await get_field_text(kbid, rid, field, field_type)
+        if extracted is not None:
+            tl.text.CopyFrom(extracted)
+
+        metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
+        if metadata_obj is not None:
+            tl.metadata.CopyFrom(metadata_obj)
+
+        basic = await get_field_basic(kbid, rid, field, field_type)
+        if basic is not None:
+            tl.basic.CopyFrom(basic)
+
+        tl.labels.extend(text_labels)
+
+        yield tl
+
+
+async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
+    orm_resource = await get_resource_from_cache_or_db(kbid, rid)
+
+    if orm_resource is None:
+        logger.error(f"{rid} does not exist on DB")
+        return None
+
+    field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
+    field_obj = await orm_resource.get_field(field, field_type_int, load=False)
+    extracted_text = await field_obj.get_extracted_text()
+
+    return extracted_text
+
+
+async def get_field_metadata(
+    kbid: str, rid: str, field: str, field_type: str
+) -> Optional[FieldComputedMetadata]:
+    orm_resource = await get_resource_from_cache_or_db(kbid, rid)
+
+    if orm_resource is None:
+        logger.error(f"{rid} does not exist on DB")
+        return None
+
+    field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
+    field_obj = await orm_resource.get_field(field, field_type_int, load=False)
+    field_metadata = await field_obj.get_field_metadata()
+
+    return field_metadata
+
+
+async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
+    orm_resource = await get_resource_from_cache_or_db(kbid, rid)
+
+    if orm_resource is None:
+        logger.error(f"{rid} does not exist on DB")
+        return None
+
+    basic = await orm_resource.get_basic()
+
+    return basic
nucliadb/train/generators/image_classifier.py
CHANGED
@@ -21,6 +21,12 @@
 import json
 from typing import Any, AsyncGenerator
 
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.ingest.fields.base import Field
+from nucliadb.ingest.orm.resource import Resource
+from nucliadb.train import logger
+from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
     ImageClassification,
     ImageClassificationBatch,
@@ -29,12 +35,6 @@ from nucliadb_protos.dataset_pb2 import (
 )
 from nucliadb_protos.nodereader_pb2 import StreamRequest
 from nucliadb_protos.resources_pb2 import FieldType, PageStructure, VisualSelection
-from nucliadb.common.cluster.base import AbstractIndexNode
-from nucliadb.ingest.fields.base import Field
-from nucliadb.ingest.orm.resource import KB_REVERSE, Resource
-from nucliadb.train import logger
-from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
-
 VISUALLY_ANNOTABLE_FIELDS = {FieldType.FILE, FieldType.LINK}
 
 # PAWLS JSON format
@@ -47,9 +47,7 @@ def image_classification_batch_generator(
     node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ImageClassificationBatch, None]:
-    generator = generate_image_classification_payloads(
-        kbid, trainset, node, shard_replica_id
-    )
+    generator = generate_image_classification_payloads(kbid, trainset, node, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
     return batch_generator
 
@@ -71,7 +69,7 @@ async def generate_image_classification_payloads(
             return
 
         _, field_type_key, field_key = item.field.split("/")
-        field_type =
+        field_type = FIELD_TYPE_STR_TO_PB[field_type_key]
 
         if field_type not in VISUALLY_ANNOTABLE_FIELDS:
             continue
@@ -131,9 +129,7 @@ async def generate_image_classification_payloads(
             yield ic
 
 
-async def get_page_selections(
-    resource: Resource, field: Field
-) -> dict[int, list[VisualSelection]]:
+async def get_page_selections(resource: Resource, field: Field) -> dict[int, list[VisualSelection]]:
     page_selections: dict[int, list[VisualSelection]] = {}
     basic = await resource.get_basic()
     if basic is None or basic.fieldmetadata is None:
@@ -144,7 +140,7 @@ async def get_page_selections(
     for fieldmetadata in basic.fieldmetadata:
         if (
             fieldmetadata.field.field == field.id
-            and fieldmetadata.field.field_type ==
+            and fieldmetadata.field.field_type == FIELD_TYPE_STR_TO_PB[field.type]
         ):
             for selection in fieldmetadata.page_selections:
                 page_selections[selection.page] = selection.visual  # type: ignore
@@ -155,7 +151,7 @@ async def get_page_selections(
 
 async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
     page_structures: list[tuple[str, PageStructure]] = []
-    field_type =
+    field_type = FIELD_TYPE_STR_TO_PB[field.type]
     if field_type == FieldType.FILE:
         fed = await field.get_file_extracted_data()  # type: ignore
         if fed is None:
@@ -163,9 +159,7 @@ async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
 
         fp = fed.file_pages_previews
         if len(fp.pages) != len(fp.structures):
-            field_path =
-                f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
-            )
+            field_path = f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
             logger.warning(
                 f"File extracted data has a different number of pages and structures! ({field_path})"
             )
nucliadb/train/generators/paragraph_classifier.py
CHANGED
@@ -21,6 +21,9 @@
 from typing import AsyncGenerator
 
 from fastapi import HTTPException
+
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.train.generators.utils import batchify, get_paragraph
 from nucliadb_protos.dataset_pb2 import (
     Label,
     ParagraphClassificationBatch,
@@ -29,9 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
 )
 from nucliadb_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
-from nucliadb.train.generators.utils import batchify, get_paragraph
-
 
 def paragraph_classification_batch_generator(
     kbid: str,
@@ -45,12 +45,8 @@ def paragraph_classification_batch_generator(
         detail="Paragraph Classification should be of 1 labelset",
     )
 
-    generator = generate_paragraph_classification_payloads(
-        kbid, trainset, node, shard_replica_id
-    )
-    batch_generator = batchify(
-        generator, trainset.batch_size, ParagraphClassificationBatch
-    )
+    generator = generate_paragraph_classification_payloads(kbid, trainset, node, shard_replica_id)
+    batch_generator = batchify(generator, trainset.batch_size, ParagraphClassificationBatch)
     return batch_generator
 
 
nucliadb/train/generators/paragraph_streaming.py
CHANGED
@@ -20,6 +20,10 @@
 
 from typing import AsyncGenerator
 
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
+from nucliadb.train import logger
+from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
 from nucliadb_protos.dataset_pb2 import (
     ParagraphStreamingBatch,
     ParagraphStreamItem,
@@ -27,11 +31,6 @@ from nucliadb_protos.dataset_pb2 import (
 )
 from nucliadb_protos.nodereader_pb2 import StreamRequest
 
-from nucliadb.common.cluster.base import AbstractIndexNode
-from nucliadb.ingest.orm.resource import KB_REVERSE
-from nucliadb.train import logger
-from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
-
 
 def paragraph_streaming_batch_generator(
     kbid: str,
@@ -39,9 +38,7 @@ def paragraph_streaming_batch_generator(
     node: AbstractIndexNode,
     shard_replica_id: str,
 ) -> AsyncGenerator[ParagraphStreamingBatch, None]:
-    generator = generate_paragraph_streaming_payloads(
-        kbid, trainset, node, shard_replica_id
-    )
+    generator = generate_paragraph_streaming_payloads(kbid, trainset, node, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
     return batch_generator
 
@@ -68,7 +65,7 @@ async def generate_paragraph_streaming_payloads(
             logger.error(f"{rid} does not exist on DB")
             continue
 
-        field_type_int =
+        field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
         field_obj = await orm_resource.get_field(field, field_type_int, load=False)
 
         extracted_text = await field_obj.get_extracted_text()