nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- migrations/0023_backfill_pg_catalog.py +8 -4
- migrations/0028_extracted_vectors_reference.py +1 -1
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +8 -4
- migrations/0039_backfill_converation_splits_metadata.py +106 -0
- migrations/0040_migrate_search_configurations.py +79 -0
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/__init__.py +79 -0
- nucliadb/common/catalog/dummy.py +36 -0
- nucliadb/common/catalog/interface.py +85 -0
- nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
- nucliadb/common/catalog/utils.py +56 -0
- nucliadb/common/cluster/manager.py +8 -23
- nucliadb/common/cluster/rebalance.py +484 -112
- nucliadb/common/cluster/rollover.py +36 -9
- nucliadb/common/cluster/settings.py +4 -9
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +9 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +5 -34
- nucliadb/common/external_index_providers/settings.py +1 -27
- nucliadb/common/filter_expression.py +129 -41
- nucliadb/common/http_clients/exceptions.py +8 -0
- nucliadb/common/http_clients/processing.py +16 -23
- nucliadb/common/http_clients/utils.py +3 -0
- nucliadb/common/ids.py +82 -58
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +22 -5
- nucliadb/common/vector_index_config.py +1 -1
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +10 -8
- nucliadb/ingest/consumer/service.py +5 -30
- nucliadb/ingest/consumer/shard_creator.py +16 -5
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +37 -49
- nucliadb/ingest/fields/conversation.py +55 -9
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +89 -57
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +128 -113
- nucliadb/ingest/orm/knowledgebox.py +91 -59
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +98 -153
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +82 -71
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/partitions.py +12 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +15 -114
- nucliadb/ingest/settings.py +36 -15
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +23 -26
- nucliadb/metrics_exporter.py +20 -6
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +4 -11
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +37 -9
- nucliadb/reader/api/v1/learning_config.py +33 -14
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +3 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +15 -19
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +328 -0
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +28 -8
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +33 -19
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -42
- nucliadb/search/search/chat/ask.py +131 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +453 -32
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +49 -0
- nucliadb/search/search/hydrator/fields.py +217 -0
- nucliadb/search/search/hydrator/images.py +130 -0
- nucliadb/search/search/hydrator/paragraphs.py +323 -0
- nucliadb/search/search/hydrator/resources.py +60 -0
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +24 -7
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +44 -18
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -48
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +5 -6
- nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
- nucliadb/search/search/query_parser/parsers/common.py +21 -13
- nucliadb/search/search/query_parser/parsers/find.py +6 -29
- nucliadb/search/search/query_parser/parsers/graph.py +18 -28
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -56
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +6 -7
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +5 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +4 -10
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +15 -14
- nucliadb/writer/api/v1/knowledgebox.py +18 -56
- nucliadb/writer/api/v1/learning_config.py +5 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +43 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +5 -7
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +15 -22
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb/common/external_index_providers/pinecone.py +0 -894
- nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
- nucliadb/search/search/hydrator.py +0 -197
- nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
@@ -18,104 +18,56 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from
-
-from nidx_protos.nodereader_pb2 import
-    DocumentScored,
-    GraphSearchResponse,
-    ParagraphResult,
-    ParagraphSearchResponse,
-    SearchResponse,
-    VectorSearchResponse,
-)
+from collections.abc import Iterable
+
+from nidx_protos.nodereader_pb2 import GraphSearchResponse, SearchResponse

 from nucliadb.common.external_index_providers.base import TextBlockMatch
-from nucliadb.common.ids import ParagraphId
-from nucliadb.
+from nucliadb.common.ids import ParagraphId
+from nucliadb.models.internal.augment import AugmentedParagraph, Paragraph, ParagraphText
+from nucliadb.search.augmentor.paragraphs import augment_paragraphs
+from nucliadb.search.augmentor.resources import augment_resources_deep
 from nucliadb.search.search.cut import cut_page
 from nucliadb.search.search.hydrator import (
     ResourceHydrationOptions,
     TextBlockHydrationOptions,
-    hydrate_resource_metadata,
-    hydrate_text_block,
-    text_block_to_find_paragraph,
 )
 from nucliadb.search.search.merge import merge_relations_results
+from nucliadb.search.search.metrics import merge_observer
+from nucliadb.search.search.paragraphs import highlight_paragraph
 from nucliadb.search.search.query_parser.models import UnitRetrieval
-from nucliadb.search.search.
-from
-
-    Reranker,
-    RerankingOptions,
-)
-from nucliadb_models.common import FieldTypeName
-from nucliadb_models.resource import ExtractedDataTypeName, Resource
+from nucliadb.search.search.rerankers import RerankableItem, Reranker, RerankingOptions
+from nucliadb_models.resource import Resource
+from nucliadb_models.retrieval import RerankerScore
 from nucliadb_models.search import (
-    SCORE_TYPE,
     FindField,
+    FindParagraph,
     FindResource,
     KnowledgeboxFindResults,
     MinScore,
-    ResourceProperties,
-    TextPosition,
 )
 from nucliadb_telemetry import metrics

-from .metrics import merge_observer
-
 FIND_FETCH_OPS_DISTRIBUTION = metrics.Histogram(
     "nucliadb_find_fetch_operations",
     buckets=[1, 5, 10, 20, 30, 40, 50, 60, 80, 100, 200],
 )

-# Constant score given to all graph results until we implement graph scoring
-FAKE_GRAPH_SCORE = 1.0
-

 @merge_observer.wrap({"type": "find_merge"})
 async def build_find_response(
-
+    search_response: SearchResponse,
+    merged_text_blocks: list[TextBlockMatch],
+    graph_response: GraphSearchResponse,
     *,
     retrieval: UnitRetrieval,
     kbid: str,
     query: str,
-    rephrased_query:
-    rank_fusion_algorithm: RankFusionAlgorithm,
+    rephrased_query: str | None,
     reranker: Reranker,
-
-
-    field_type_filter: list[FieldTypeName] = [],
-    highlight: bool = False,
+    resource_hydration_options: ResourceHydrationOptions,
+    text_block_hydration_options: TextBlockHydrationOptions,
 ) -> KnowledgeboxFindResults:
-    # XXX: we shouldn't need a min score that we haven't used. Previous
-    # implementations got this value from the proto request (i.e., default to 0)
-    min_score_bm25 = 0.0
-    if retrieval.query.keyword is not None:
-        min_score_bm25 = retrieval.query.keyword.min_score
-    min_score_semantic = 0.0
-    if retrieval.query.semantic is not None:
-        min_score_semantic = retrieval.query.semantic.min_score
-
-    # merge
-    search_response = merge_shard_responses(search_responses)
-
-    keyword_results = keyword_results_to_text_block_matches(search_response.paragraph.results)
-    semantic_results = semantic_results_to_text_block_matches(
-        filter(
-            lambda x: x.score >= min_score_semantic,
-            search_response.vector.documents,
-        )
-    )
-    graph_results = graph_results_to_text_block_matches(search_response.graph)
-
-    merged_text_blocks = rank_fusion_algorithm.fuse(
-        {
-            IndexSource.KEYWORD: keyword_results,
-            IndexSource.SEMANTIC: semantic_results,
-            IndexSource.GRAPH: graph_results,
-        }
-    )
-
     # cut
     # we assume pagination + predict reranker is forbidden and has been already
     # enforced/validated by the query parsing.
@@ -126,14 +78,12 @@ async def build_find_response(
     text_blocks_page, next_page = cut_page(merged_text_blocks, retrieval.top_k)

     # hydrate and rerank
-
-
+    reranking_options = RerankingOptions(
+        kbid=kbid,
+        # if we have a rephrased query, we assume it'll be better for the
+        # reranker model. Otherwise, use the user query
+        query=rephrased_query or query,
     )
-    text_block_hydration_options = TextBlockHydrationOptions(
-        highlight=highlight,
-        ematches=search_response.paragraph.ematches,  # type: ignore
-    )
-    reranking_options = RerankingOptions(kbid=kbid, query=query)
     text_blocks, resources, best_matches = await hydrate_and_rerank(
         text_blocks_page,
         kbid,
@@ -148,12 +98,41 @@ async def build_find_response(
     entry_points = []
     if retrieval.query.relation is not None:
         entry_points = retrieval.query.relation.entry_points
-    relations = await merge_relations_results([
+    relations = await merge_relations_results([graph_response], entry_points)

     # compose response
     find_resources = compose_find_resources(text_blocks, resources)

-
+    # Compute some misc values for the response
+
+    # XXX: we shouldn't need a min score that we haven't used. Previous
+    # implementations got this value from the proto request (i.e., default to 0)
+    min_score_bm25 = 0.0
+    if retrieval.query.keyword is not None:
+        min_score_bm25 = retrieval.query.keyword.min_score
+    min_score_semantic = 0.0
+    if retrieval.query.semantic is not None:
+        min_score_semantic = retrieval.query.semantic.min_score
+
+    # Bw/c with pagination, next page can be obtained from different places. The
+    # meaning is whether a greater top_k would have returned more results.
+    # Although it doesn't take into account matches on the same paragraphs, an
+    # estimate is good enough
+    next_page = (
+        # when rank fusion window is greater than top_k or the reranker window
+        next_page
+        # when the keyword index already has more results
+        or search_response.paragraph.next_page
+        # when rank fusion window is greater than top_k
+        or len(merged_text_blocks) > retrieval.top_k
+        # when the sum of all indexes makes more than top_k
+        or (
+            len(search_response.paragraph.results)
+            + len(search_response.vector.documents)
+            + len([True for path in graph_response.graph if path.metadata.paragraph_id])
+            > retrieval.top_k
+        )
+    )
     total_paragraphs = search_response.paragraph.total

     find_results = KnowledgeboxFindResults(
@@ -171,212 +150,6 @@ async def build_find_response(
     return find_results


-def merge_shard_responses(
-    responses: list[SearchResponse],
-) -> SearchResponse:
-    """Merge search responses into a single response as if there were no shards
-    involved.
-
-    ATENTION! This is not a complete merge, we are only merging the fields
-    needed to compose a /find response.
-
-    """
-    paragraphs = []
-    vectors = []
-    graphs = []
-    for response in responses:
-        paragraphs.append(response.paragraph)
-        vectors.append(response.vector)
-        graphs.append(response.graph)
-
-    merged = SearchResponse(
-        paragraph=merge_shards_keyword_responses(paragraphs),
-        vector=merge_shards_semantic_responses(vectors),
-        graph=merge_shards_graph_responses(graphs),
-    )
-    return merged
-
-
-def merge_shards_keyword_responses(
-    keyword_responses: list[ParagraphSearchResponse],
-) -> ParagraphSearchResponse:
-    """Merge keyword (paragraph) search responses into a single response as if
-    there were no shards involved.
-
-    ATENTION! This is not a complete merge, we are only merging the fields
-    needed to compose a /find response.
-
-    """
-    merged = ParagraphSearchResponse()
-    for response in keyword_responses:
-        merged.query = response.query
-        merged.next_page = merged.next_page or response.next_page
-        merged.total += response.total
-        merged.results.extend(response.results)
-        merged.ematches.extend(response.ematches)
-
-    return merged
-
-
-def merge_shards_semantic_responses(
-    semantic_responses: list[VectorSearchResponse],
-) -> VectorSearchResponse:
-    """Merge semantic (vector) search responses into a single response as if
-    there were no shards involved.
-
-    ATENTION! This is not a complete merge, we are only merging the fields
-    needed to compose a /find response.
-
-    """
-    merged = VectorSearchResponse()
-    for response in semantic_responses:
-        merged.documents.extend(response.documents)
-
-    return merged
-
-
-def merge_shards_graph_responses(
-    graph_responses: list[GraphSearchResponse],
-):
-    merged = GraphSearchResponse()
-
-    for response in graph_responses:
-        nodes_offset = len(merged.nodes)
-        relations_offset = len(merged.relations)
-
-        # paths contain indexes to nodes and relations, we must offset them
-        # while merging responses to maintain valid data
-        for path in response.graph:
-            merged_path = GraphSearchResponse.Path()
-            merged_path.CopyFrom(path)
-            merged_path.source += nodes_offset
-            merged_path.relation += relations_offset
-            merged_path.destination += nodes_offset
-            merged.graph.append(merged_path)
-
-        merged.nodes.extend(response.nodes)
-        merged.relations.extend(response.relations)
-
-    return merged
-
-
-def keyword_result_to_text_block_match(item: ParagraphResult) -> TextBlockMatch:
-    fuzzy_result = len(item.matches) > 0
-    return TextBlockMatch(
-        paragraph_id=ParagraphId.from_string(item.paragraph),
-        score=item.score.bm25,
-        score_type=SCORE_TYPE.BM25,
-        order=0,  # NOTE: this will be filled later
-        text="",  # NOTE: this will be filled later too
-        position=TextPosition(
-            page_number=item.metadata.position.page_number,
-            index=item.metadata.position.index,
-            start=item.start,
-            end=item.end,
-            start_seconds=[x for x in item.metadata.position.start_seconds],
-            end_seconds=[x for x in item.metadata.position.end_seconds],
-        ),
-        # XXX: we should split labels
-        field_labels=[],
-        paragraph_labels=list(item.labels),
-        fuzzy_search=fuzzy_result,
-        is_a_table=item.metadata.representation.is_a_table,
-        representation_file=item.metadata.representation.file,
-        page_with_visual=item.metadata.page_with_visual,
-    )
-
-
-def keyword_results_to_text_block_matches(items: Iterable[ParagraphResult]) -> list[TextBlockMatch]:
-    return [keyword_result_to_text_block_match(item) for item in items]
-
-
-class InvalidDocId(Exception):
-    """Raised while parsing an invalid id coming from semantic search"""
-
-    def __init__(self, invalid_vector_id: str):
-        self.invalid_vector_id = invalid_vector_id
-        super().__init__(f"Invalid vector ID: {invalid_vector_id}")
-
-
-def semantic_result_to_text_block_match(item: DocumentScored) -> TextBlockMatch:
-    try:
-        vector_id = VectorId.from_string(item.doc_id.id)
-    except (IndexError, ValueError):
-        raise InvalidDocId(item.doc_id.id)
-
-    return TextBlockMatch(
-        paragraph_id=ParagraphId.from_vector_id(vector_id),
-        score=item.score,
-        score_type=SCORE_TYPE.VECTOR,
-        order=0,  # NOTE: this will be filled later
-        text="",  # NOTE: this will be filled later too
-        position=TextPosition(
-            page_number=item.metadata.position.page_number,
-            index=item.metadata.position.index,
-            start=vector_id.vector_start,
-            end=vector_id.vector_end,
-            start_seconds=[x for x in item.metadata.position.start_seconds],
-            end_seconds=[x for x in item.metadata.position.end_seconds],
-        ),
-        # XXX: we should split labels
-        field_labels=[],
-        paragraph_labels=list(item.labels),
-        fuzzy_search=False,  # semantic search doesn't have fuzziness
-        is_a_table=item.metadata.representation.is_a_table,
-        representation_file=item.metadata.representation.file,
-        page_with_visual=item.metadata.page_with_visual,
-    )
-
-
-def semantic_results_to_text_block_matches(items: Iterable[DocumentScored]) -> list[TextBlockMatch]:
-    text_blocks: list[TextBlockMatch] = []
-    for item in items:
-        try:
-            text_block = semantic_result_to_text_block_match(item)
-        except InvalidDocId as exc:
-            logger.warning(f"Skipping invalid doc_id: {exc.invalid_vector_id}")
-            continue
-        text_blocks.append(text_block)
-    return text_blocks
-
-
-def graph_results_to_text_block_matches(item: GraphSearchResponse) -> list[TextBlockMatch]:
-    matches = []
-    for path in item.graph:
-        metadata = path.metadata
-
-        if not metadata.paragraph_id:
-            continue
-
-        paragraph_id = ParagraphId.from_string(metadata.paragraph_id)
-        matches.append(
-            TextBlockMatch(
-                paragraph_id=paragraph_id,
-                score=FAKE_GRAPH_SCORE,
-                score_type=SCORE_TYPE.RELATION_RELEVANCE,
-                order=0,  # NOTE: this will be filled later
-                text="",  # NOTE: this will be filled later too
-                position=TextPosition(
-                    page_number=0,
-                    index=0,
-                    start=paragraph_id.paragraph_start,
-                    end=paragraph_id.paragraph_end,
-                    start_seconds=[],
-                    end_seconds=[],
-                ),
-                # XXX: we should split labels
-                field_labels=[],
-                paragraph_labels=[],
-                fuzzy_search=False,  # TODO: this depends on the query, should we populate it?
-                is_a_table=False,
-                representation_file="",
-                page_with_visual=False,
-            )
-        )
-
-    return matches
-
-
 @merge_observer.wrap({"type": "hydrate_and_rerank"})
 async def hydrate_and_rerank(
     text_blocks: Iterable[TextBlockMatch],
@@ -398,11 +171,12 @@ async def hydrate_and_rerank(
     """
     max_operations = asyncio.Semaphore(50)

-    # Iterate text blocks
-    #
+    # Iterate text blocks to create an "index" for faster access by id and get a
+    # list of text block ids and resource ids to hydrate
     text_blocks_by_id: dict[str, TextBlockMatch] = {}  # useful for faster access to text blocks later
-
-
+    resources_to_hydrate = set()
+    text_block_id_to_hydrate = set()
+
     for text_block in text_blocks:
         rid = text_block.paragraph_id.rid
         paragraph_id = text_block.paragraph_id.full()
@@ -417,41 +191,48 @@ async def hydrate_and_rerank(
         # ones we see now, so we'll skip this step and recompute the resources
         # later
         if not reranker.needs_extra_results:
-
-            resource_hydration_ops[rid] = asyncio.create_task(
-                hydrate_resource_metadata(
-                    kbid,
-                    rid,
-                    options=resource_hydration_options,
-                    concurrency_control=max_operations,
-                    service_name=SERVICE_NAME,
-                )
-            )
+            resources_to_hydrate.add(rid)

-
-
-
-
-                    text_block,
-                    text_block_hydration_options,
-                    concurrency_control=max_operations,
-                )
-            )
-        )
+        if text_block_hydration_options.only_hydrate_empty and text_block.text:
+            pass
+        else:
+            text_block_id_to_hydrate.add(paragraph_id)

     # hydrate only the strictly needed before rerank
-    hydrated_text_blocks: list[TextBlockMatch]
-    hydrated_resources: list[Union[Resource, None]]
-
     ops = [
-
-
+        augment_paragraphs(
+            kbid,
+            given=[
+                Paragraph.from_text_block_match(text_blocks_by_id[paragraph_id])
+                for paragraph_id in text_block_id_to_hydrate
+            ],
+            select=[ParagraphText()],
+            concurrency_control=max_operations,
+        ),
+        augment_resources_deep(
+            kbid,
+            given=list(resources_to_hydrate),
+            opts=resource_hydration_options,
+            concurrency_control=max_operations,
+        ),
     ]
-    FIND_FETCH_OPS_DISTRIBUTION.observe(len(
+    FIND_FETCH_OPS_DISTRIBUTION.observe(len(text_block_id_to_hydrate) + len(resources_to_hydrate))
     results = await asyncio.gather(*ops)

-
-
+    augmented_paragraphs: dict[ParagraphId, AugmentedParagraph | None] = results[0]  # type: ignore
+    augmented_resources: dict[str, Resource | None] = results[1]  # type: ignore
+
+    # add hydrated text to our text blocks
+    for text_block in text_blocks:
+        augmented = augmented_paragraphs.get(text_block.paragraph_id, None)
+        if augmented is not None and augmented.text is not None:
+            if text_block_hydration_options.highlight:
+                text = highlight_paragraph(
+                    augmented.text, words=[], ematches=text_block_hydration_options.ematches
+                )
+            else:
+                text = augmented.text
+            text_block.text = text

     # with the hydrated text, rerank and apply new scores to the text blocks
     to_rerank = [
@@ -461,7 +242,7 @@ async def hydrate_and_rerank(
             score_type=text_block.score_type,
            content=text_block.text or "",  # TODO: add a warning, this shouldn't usually happen
         )
-        for text_block in
+        for text_block in text_blocks
     ]
     reranked = await reranker.rerank(to_rerank, reranking_options)

@@ -476,7 +257,7 @@ async def hydrate_and_rerank(
         score_type = item.score_type

         text_block = text_blocks_by_id[paragraph_id]
-        text_block.score
+        text_block.scores.append(RerankerScore(score=score))
         text_block.score_type = score_type

         matches.append((paragraph_id, score))
@@ -485,7 +266,7 @@ async def hydrate_and_rerank(

     best_matches = []
     best_text_blocks = []
-
+    resources_to_hydrate.clear()
     for order, (paragraph_id, _) in enumerate(matches):
         text_block = text_blocks_by_id[paragraph_id]
         text_block.order = order
@@ -495,24 +276,19 @@ async def hydrate_and_rerank(
         # now we have removed the text block surplus, fetch resource metadata
         if reranker.needs_extra_results:
             rid = ParagraphId.from_string(paragraph_id).rid
-
-            resource_hydration_ops[rid] = asyncio.create_task(
-                hydrate_resource_metadata(
-                    kbid,
-                    rid,
-                    options=resource_hydration_options,
-                    concurrency_control=max_operations,
-                    service_name=SERVICE_NAME,
-                )
-            )
+            resources_to_hydrate.add(rid)

     # Finally, fetch resource metadata if we haven't already done it
     if reranker.needs_extra_results:
-
-
-
+        FIND_FETCH_OPS_DISTRIBUTION.observe(len(resources_to_hydrate))
+        augmented_resources = await augment_resources_deep(
+            kbid,
+            given=list(resources_to_hydrate),
+            opts=resource_hydration_options,
+            concurrency_control=max_operations,
+        )

-    resources = [resource for resource in
+    resources = [resource for resource in augmented_resources.values() if resource is not None]

     return best_text_blocks, resources, best_matches

@@ -547,5 +323,22 @@ def compose_find_resources(
     return find_resources


+def text_block_to_find_paragraph(text_block: TextBlockMatch) -> FindParagraph:
+    return FindParagraph(
+        id=text_block.paragraph_id.full(),
+        text=text_block.text or "",
+        score=text_block.score,
+        score_type=text_block.score_type,
+        order=text_block.order,
+        labels=text_block.paragraph_labels,
+        fuzzy_result=text_block.fuzzy_search,
+        is_a_table=text_block.is_a_table,
+        reference=text_block.representation_file,
+        page_with_visual=text_block.page_with_visual,
+        position=text_block.position,
+        relevant_relations=text_block.relevant_relations,
+    )
+
+
 def _round(x: float) -> float:
     return round(x, ndigits=3)
@@ -19,8 +19,9 @@
 import heapq
 import json
 from collections import defaultdict
+from collections.abc import Collection, Iterable
 from dataclasses import dataclass
-from typing import Any
+from typing import Any

 from nidx_protos import nodereader_pb2
 from nuclia_models.predict.generative_responses import (
@@ -55,6 +56,7 @@ from nucliadb_models.internal.predict import (
     RerankModel,
 )
 from nucliadb_models.resource import ExtractedDataTypeName
+from nucliadb_models.retrieval import GraphScore
 from nucliadb_models.search import (
     SCORE_TYPE,
     AskRequest,
@@ -112,11 +114,11 @@ SCHEMA = {
 }

 PROMPT = """\
-You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user
+You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user's question.

 For each provided **(head_entity, relationship, tail_entity)**, you must:
 1. Assign a **relevance score** between **0** and **10**.
-2. **0** means “this relationship can
+2. **0** means “this relationship can't be relevant at all to the question.”
 3. **10** means “this relationship is extremely relevant to the question.”
 4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be.
 5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language.
@@ -318,8 +320,8 @@ async def get_graph_results(
     graph_strategy: GraphStrategy,
     text_block_reranker: Reranker,
     metrics: Metrics,
-    generative_model:
-    shards:
+    generative_model: str | None = None,
+    shards: list[str] | None = None,
 ) -> tuple[KnowledgeboxFindResults, FindRequest]:
     relations = Relations(entities={})
     explored_entities: set[FrozenRelationNode] = set()
@@ -465,7 +467,7 @@ async def fuzzy_search_entities(
 async def fuzzy_search_entities(
     kbid: str,
     query: str,
-) ->
+) -> RelatedEntities | None:
     """Fuzzy find entities in KB given a query using the same methodology as /suggest, but split by words."""

     # Build an OR for each word in the query matching with fuzzy any word in any
@@ -493,7 +495,7 @@ async def fuzzy_search_entities(
     # merge shard results while deduplicating repeated entities across shards
     unique_entities: set[RelatedEntity] = set()
     for response in results:
-        unique_entities.update(
+        unique_entities.update(RelatedEntity(family=e.subtype, value=e.value) for e in response.nodes)

     return RelatedEntities(entities=list(unique_entities), total=len(unique_entities))

@@ -572,7 +574,7 @@ async def rank_relations_generative(
     kbid: str,
     user: str,
     top_k: int,
-    generative_model:
+    generative_model: str | None = None,
     score_threshold: float = 2,
     max_rels_to_eval: int = 100,
 ) -> tuple[Relations, dict[str, list[float]]]:
@@ -650,7 +652,7 @@ async def rank_relations_generative(
     if response_json is None or status is None or status.code != "0":
         raise ValueError("No JSON response found")

-    scored_unique_triplets: list[dict[str,
+    scored_unique_triplets: list[dict[str, str | Any]] = response_json.object["triplets"]

     if len(scored_unique_triplets) != len(unique_triplets):
         raise ValueError("Mismatch between input and output triplets")
@@ -716,7 +718,7 @@ def build_text_blocks_from_relations(
     This is a hacky way to generate paragraphs from relations, and it is not the intended use of TextBlockMatch.
     """
     # Build a set of unique triplets with their scores
-    triplets: dict[tuple[str, str, str], tuple[float, Relations,
+    triplets: dict[tuple[str, str, str], tuple[float, Relations, ParagraphId | None]] = defaultdict(
         lambda: (0.0, Relations(entities={}), None)
     )
     paragraph_count = 0
@@ -758,7 +760,7 @@ def build_text_blocks_from_relations(
         TextBlockMatch(
             # XXX: Even though we are setting a paragraph_id, the text is not coming from the paragraph
             paragraph_id=p_id,
-            score=score,
+            scores=[GraphScore(score=score)],
             score_type=SCORE_TYPE.RELATION_RELEVANCE,
             order=0,
             text=f"- {ent} {rel} {tail}",  # Manually build the text
@@ -902,7 +904,7 @@ def relations_match_to_text_block_match(
     parsed_paragraph_id = paragraph_match.paragraph_id
     return TextBlockMatch(
         paragraph_id=parsed_paragraph_id,
-        score=paragraph_match.score,
+        scores=[GraphScore(score=paragraph_match.score)],
         score_type=SCORE_TYPE.RELATION_RELEVANCE,
         order=0,  # NOTE: this will be filled later
         text="",  # NOTE: this will be filled later too