qtype 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qtype/application/facade.py +14 -15
- qtype/cli.py +1 -1
- qtype/commands/generate.py +1 -1
- qtype/commands/run.py +7 -3
- qtype/dsl/domain_types.py +24 -3
- qtype/dsl/model.py +56 -3
- qtype/interpreter/base/executor_context.py +18 -1
- qtype/interpreter/base/factory.py +33 -66
- qtype/interpreter/conversions.py +15 -6
- qtype/interpreter/converters.py +14 -12
- qtype/interpreter/executors/bedrock_reranker_executor.py +195 -0
- qtype/interpreter/executors/document_search_executor.py +37 -46
- qtype/interpreter/executors/field_extractor_executor.py +10 -5
- qtype/interpreter/executors/index_upsert_executor.py +114 -110
- qtype/interpreter/flow.py +35 -32
- qtype/semantic/checker.py +79 -19
- qtype/semantic/model.py +43 -3
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/METADATA +12 -11
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/RECORD +23 -22
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/WHEEL +0 -0
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/entry_points.txt +0 -0
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {qtype-0.1.1.dist-info → qtype-0.1.2.dist-info}/top_level.txt +0 -0

qtype/interpreter/executors/bedrock_reranker_executor.py (new file)

@@ -0,0 +1,195 @@
+"""BedrockReranker executor for reordering search results by relevance."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import AsyncIterator
+
+from pydantic import BaseModel
+
+from qtype.base.types import PrimitiveTypeEnum
+from qtype.dsl.domain_types import RAGChunk, SearchResult
+from qtype.interpreter.auth.aws import aws
+from qtype.interpreter.base.base_step_executor import StepExecutor
+from qtype.interpreter.base.executor_context import ExecutorContext
+from qtype.interpreter.types import FlowMessage
+from qtype.semantic.model import BedrockReranker, ListType
+
+logger = logging.getLogger(__name__)
+
+
+class BedrockRerankerExecutor(StepExecutor):
+    """Executor for BedrockReranker steps that reorder search results by relevance."""
+
+    def __init__(
+        self, step: BedrockReranker, context: ExecutorContext, **dependencies
+    ):
+        super().__init__(step, context, **dependencies)
+        if not isinstance(step, BedrockReranker):
+            raise ValueError(
+                "BedrockRerankerExecutor can only execute BedrockReranker steps."
+            )
+        self.step: BedrockReranker = step
+
+    async def process_message(
+        self,
+        message: FlowMessage,
+    ) -> AsyncIterator[FlowMessage]:
+        """Process a single FlowMessage for the BedrockReranker step.
+
+        Args:
+            message: The FlowMessage to process.
+
+        Yields:
+            FlowMessage with reranked results.
+        """
+        try:
+            # Get the inputs
+            query = self._query(message)
+            docs = self._docs(message)
+
+            if len(docs) == 0:
+                # No documents to rerank, yield original message
+                yield message.copy_with_variables(
+                    {self.step.outputs[0].id: docs}
+                )
+                return
+
+            # Get session for region info
+            if self.step.auth is not None:
+                with aws(self.step.auth, self.context.secret_manager) as s:
+                    region_name = s.region_name
+            else:
+                import boto3
+
+                region_name = boto3.Session().region_name
+
+            # Convert the types
+            queries = [
+                {
+                    "type": "TEXT",
+                    "textQuery": {"text": query},
+                }
+            ]
+            documents = []
+
+            for doc in docs:
+                if isinstance(doc.content, RAGChunk):
+                    documents.append(
+                        {
+                            "type": "INLINE",
+                            "inlineDocumentSource": {
+                                "type": "TEXT",
+                                "textDocument": {"text": str(doc.content)},
+                            },
+                        }
+                    )
+                elif isinstance(doc.content, dict):
+                    documents.append(
+                        {
+                            "type": "INLINE",
+                            "inlineDocumentSource": {
+                                "type": "JSON",
+                                "jsonDocument": doc.content,
+                            },
+                        }
+                    )
+                elif isinstance(doc.content, BaseModel):
+                    documents.append(
+                        {
+                            "type": "INLINE",
+                            "inlineDocumentSource": {
+                                "type": "JSON",
+                                "jsonDocument": doc.content.model_dump(),
+                            },
+                        }
+                    )
+                else:
+                    raise ValueError(
+                        f"Unsupported document content type for BedrockReranker: {type(doc.content)}"
+                    )
+
+            reranking_configuration = {
+                "type": "BEDROCK_RERANKING_MODEL",
+                "bedrockRerankingConfiguration": {
+                    "numberOfResults": self.step.num_results or len(docs),
+                    "modelConfiguration": {
+                        "modelArn": f"arn:aws:bedrock:{region_name}::foundation-model/{self.step.model_id}"
+                    },
+                },
+            }
+
+            def _call_bedrock_rerank():
+                """Create client and call rerank in executor thread."""
+                if self.step.auth is not None:
+                    with aws(self.step.auth, self.context.secret_manager) as s:
+                        client = s.client("bedrock-agent-runtime")
+                        return client.rerank(
+                            queries=queries,
+                            sources=documents,
+                            rerankingConfiguration=reranking_configuration,
+                        )
+                else:
+                    import boto3
+
+                    session = boto3.Session()
+                    client = session.client("bedrock-agent-runtime")
+                    return client.rerank(
+                        queries=queries,
+                        sources=documents,
+                        rerankingConfiguration=reranking_configuration,
+                    )
+
+            loop = asyncio.get_running_loop()
+            response = await loop.run_in_executor(
+                self.context.thread_pool, _call_bedrock_rerank
+            )
+
+            results = []
+            for d in response["results"]:
+                doc = docs[d["index"]]
+                new_score = d["relevanceScore"]
+                results.append(doc.copy(update={"score": new_score}))
+
+            # Update the message with reranked results
+            yield message.copy_with_variables(
+                {self.step.outputs[0].id: results}
+            )
+        except Exception as e:
+            logger.error(f"Reranking failed: {e}", exc_info=True)
+            # Emit error event to stream so frontend can display it
+            await self.stream_emitter.error(str(e))
+            message.set_error(self.step.id, e)
+            yield message
+
+    def _query(self, message: FlowMessage) -> str:
+        """Extract the query string from the FlowMessage.
+
+        Args:
+            message: The FlowMessage containing the query variable.
+        Returns:
+            The query string.
+        """
+        for i in self.step.inputs:
+            if i.type == PrimitiveTypeEnum.text:
+                return message.variables[i.id]
+        raise ValueError(
+            f"No text input found for BedrockReranker step {self.step.id}"
+        )
+
+    def _docs(self, message: FlowMessage) -> list[SearchResult]:
+        """Extract the list of SearchResult documents from the FlowMessage.
+
+        Args:
+            message: The FlowMessage containing the document variable.
+        Returns:
+            The list of SearchResult documents.
+        """
+        for i in self.step.inputs:
+            if i.type == ListType(element_type="SearchResult"):
+                docs = message.variables[i.id]
+                return docs
+        raise ValueError(
+            f"No list of SearchResults input found for BedrockReranker step {self.step.id}"
+        )

qtype/interpreter/executors/document_search_executor.py

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from typing import AsyncIterator
 
-from qtype.dsl.domain_types import
+from qtype.dsl.domain_types import SearchResult
 from qtype.interpreter.base.base_step_executor import StepExecutor
 from qtype.interpreter.base.executor_context import ExecutorContext
 from qtype.interpreter.conversions import to_opensearch_client

@@ -29,6 +31,17 @@ class DocumentSearchExecutor(StepExecutor):
         )
         self.index_name = self.step.index.name
 
+    async def finalize(self) -> AsyncIterator[FlowMessage]:
+        """Clean up resources after all messages are processed."""
+        if hasattr(self, "client") and self.client:
+            try:
+                await self.client.close()
+            except Exception:
+                pass
+        # Make this an async generator
+        return
+        yield  # type: ignore[unreachable]
+
     async def process_message(
         self,
         message: FlowMessage,

@@ -39,7 +52,7 @@ class DocumentSearchExecutor(StepExecutor):
             message: The FlowMessage to process.
 
         Yields:
-
+            A list of dictionaries with _source, _search_score, and _search_id fields.
         """
         input_id = self.step.inputs[0].id
         output_id = self.step.outputs[0].id

@@ -58,62 +71,40 @@ class DocumentSearchExecutor(StepExecutor):
             # Build the search query
             search_body = {
                 "query": {
-                    "multi_match": {
-                        "query": query_text,
-                        "fields": ["content^2", "title", "*"],
-                        "type": "best_fields",
-                    }
+                    "multi_match": {"query": query_text} | self.step.query_args
                 },
-                "size":
+                "size": self.step.default_top_k,
             }
 
             # Apply any filters if specified
             if self.step.filters:
-"
-"
-],
-}
+                search_body["query"] = {
+                    "bool": {
+                        "must": [search_body["query"]],
+                        "filter": [
+                            {"term": {k: v}}
+                            for k, v in self.step.filters.items()
+                        ],
                     }
+                }
 
-            # Execute the search
-            response = self.client.search(
+            # Execute the search asynchronously using AsyncOpenSearch
+            response = await self.client.search(
                 index=self.index_name, body=search_body
            )
 
-            # Process each hit and yield as
+            # Process each hit and yield as SearchResult
+            # TODO: add support for decomposing a RAGSearchResult for hybrid search
+            search_results = []
             for hit in response["hits"]["hits"]:
-                # Build metadata from the source, excluding content field
-                metadata = {
-                    k: v for k, v in source.items() if k not in ["content"]
-                }
-                # Create a RAGChunk from the search result
-                # Use the document ID as both chunk_id and document_id
-                chunk = RAGChunk(
-                    content=content,
-                    chunk_id=doc_id,
-                    document_id=source.get("document_id", doc_id),
-                    vector=None,  # Document search doesn't return embeddings
-                    metadata=metadata,
+                search_results.append(
+                    SearchResult(
+                        content=hit["_source"],
+                        doc_id=hit["_id"],
+                        score=hit["_score"],
+                    )
                 )
-                # Wrap in RAGSearchResult with the score
-                search_result = RAGSearchResult(chunk=chunk, score=score)
-                # Yield result for each document
-                yield message.copy_with_variables({output_id: search_result})
+            yield message.copy_with_variables({output_id: search_results})
 
         except Exception as e:
             # Emit error event to stream so frontend can display it
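
To see what the reworked DocumentSearchExecutor now sends to OpenSearch, here is a small sketch of the search body it assembles. The query text, query_args, default_top_k, and filters values are hypothetical examples, not defaults from this package.

# Hedged sketch of the search body built by the new code path.
query_text = "reset password"
query_args = {"fields": ["content^2", "title"]}  # stands in for step.query_args
filters = {"category": "faq"}                    # stands in for step.filters

search_body = {
    "query": {"multi_match": {"query": query_text} | query_args},
    "size": 5,                                   # stands in for step.default_top_k
}
if filters:
    search_body["query"] = {
        "bool": {
            "must": [search_body["query"]],
            "filter": [{"term": {k: v}} for k, v in filters.items()],
        }
    }
# The multi_match clause ends up nested under bool.must, with one term filter per entry.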

qtype/interpreter/executors/field_extractor_executor.py

@@ -132,12 +132,17 @@ class FieldExtractorExecutor(StepExecutor):
             matches = self.jsonpath_expr.find(input_dict)
 
             if not matches:
-(
+                if self.step.fail_on_missing:
+                    raise ValueError(
+                        (
+                            f"JSONPath expression '{self.step.json_path}' "
+                            f"did not match any data in input"
+                        )
                     )
+                else:
+                    # Yield message with None output
+                    yield message.copy_with_variables({output_id: None})
+                    return
 
             await self.stream_emitter.status(
                 f"JSONPath matched {len(matches)} value(s)"
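
The FieldExtractorExecutor change makes missing JSONPath matches configurable. Below is a minimal sketch of that behavior, assuming the jsonpath_expr above comes from the jsonpath-ng library (an assumption based on the .find() call; the executor may use a different JSONPath implementation):

# Hedged sketch of the fail_on_missing behavior, assuming jsonpath-ng.
from jsonpath_ng import parse

def extract(data: dict, json_path: str, fail_on_missing: bool):
    matches = parse(json_path).find(data)
    if not matches:
        if fail_on_missing:
            raise ValueError(
                f"JSONPath expression '{json_path}' did not match any data in input"
            )
        # Mirrors the executor yielding the message with a None output.
        return None
    return [m.value for m in matches]

print(extract({"a": {"b": 1}}, "a.b", fail_on_missing=True))   # [1]
print(extract({"a": {}}, "a.b", fail_on_missing=False))        # None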
qtype/interpreter/executors/index_upsert_executor.py

@@ -3,9 +3,12 @@
 from __future__ import annotations
 
 import logging
+import uuid
 from typing import AsyncIterator
 
 from llama_index.core.schema import TextNode
+from opensearchpy import AsyncOpenSearch
+from pydantic import BaseModel
 
 from qtype.dsl.domain_types import RAGChunk, RAGDocument
 from qtype.interpreter.base.batch_step_executor import BatchedStepExecutor

@@ -39,21 +42,32 @@ class IndexUpsertExecutor(BatchedStepExecutor):
             self._vector_store, _ = to_llama_vector_store_and_retriever(
                 self.step.index, self.context.secret_manager
             )
-            self._opensearch_client = None
             self.index_type = "vector"
         elif isinstance(self.step.index, DocumentIndex):
             # Document index for text-based search
-            self._opensearch_client = to_opensearch_client(
+            self._opensearch_client: AsyncOpenSearch = to_opensearch_client(
                 self.step.index, self.context.secret_manager
             )
             self._vector_store = None
             self.index_type = "document"
             self.index_name = self.step.index.name
+            self._document_index: DocumentIndex = self.step.index
         else:
             raise ValueError(
                 f"Unsupported index type: {type(self.step.index)}"
             )
 
+    async def finalize(self) -> AsyncIterator[FlowMessage]:
+        """Clean up resources after all messages are processed."""
+        if hasattr(self, "_opensearch_client") and self._opensearch_client:
+            try:
+                await self._opensearch_client.close()
+            except Exception:
+                pass
+        # Make this an async generator
+        return
+        yield  # type: ignore[unreachable]
+
     async def process_batch(
         self, batch: list[FlowMessage]
     ) -> AsyncIterator[FlowMessage]:

@@ -68,58 +82,15 @@ class IndexUpsertExecutor(BatchedStepExecutor):
         logger.debug(
             f"Executing IndexUpsert step: {self.step.id} with batch size: {len(batch)}"
         )
+        if len(batch) == 0:
+            return
 
         try:
-            # Collect all RAGChunks or RAGDocuments from the batch
-            items_to_upsert = []
-            for message in batch:
-                input_data = message.variables.get(input_var.id)
-                if input_data is None:
-                    logger.warning(
-                        f"No data found for input: {input_var.id} in message"
-                    )
-                    continue
-                if not isinstance(input_data, (RAGChunk, RAGDocument)):
-                    raise ValueError(
-                        f"IndexUpsert only supports RAGChunk or RAGDocument "
-                        f"inputs. Got: {type(input_data)}"
-                    )
-                items_to_upsert.append(input_data)
-            # Upsert to appropriate index type
-            if items_to_upsert:
-                if self.index_type == "vector":
-                    await self._upsert_to_vector_store(items_to_upsert)
-                else: # document index
-                    await self._upsert_to_document_index(items_to_upsert)
-                logger.debug(
-                    f"Successfully upserted {len(items_to_upsert)} items "
-                    f"to {self.index_type} index in batch"
-                )
-                # Emit status update
-                index_type_display = (
-                    "vector index"
-                    if self.index_type == "vector"
-                    else "document index"
-                )
-                await self.stream_emitter.status(
-                    f"Upserted {len(items_to_upsert)} items to "
-                    f"{index_type_display}"
-                )
-            # Yield all input messages back (IndexUpsert typically doesn't have outputs)
-            for message in batch:
+            if self.index_type == "vector":
+                result_iter = self._upsert_to_vector_store(batch)
+            else:
+                result_iter = self._upsert_to_document_index(batch)
+            async for message in result_iter:
                 yield message
 
         except Exception as e:

@@ -133,13 +104,27 @@ class IndexUpsertExecutor(BatchedStepExecutor):
             yield message
 
     async def _upsert_to_vector_store(
-        self,
-    ) ->
+        self, batch: list[FlowMessage]
+    ) -> AsyncIterator[FlowMessage]:
         """Upsert items to vector store.
 
         Args:
             items: List of RAGChunk or RAGDocument objects
         """
+        # safe since semantic validation checks input length
+        input_var = self.step.inputs[0]
+
+        # Collect all RAGChunks or RAGDocuments from the batch inputs
+        items = []
+        for message in batch:
+            input_data = message.variables.get(input_var.id)
+            if not isinstance(input_data, (RAGChunk, RAGDocument)):
+                raise ValueError(
+                    f"IndexUpsert only supports RAGChunk or RAGDocument "
+                    f"inputs. Got: {type(input_data)}"
+                )
+            items.append(input_data)
+
         # Convert to LlamaIndex TextNode objects
         nodes = []
         for item in items:

@@ -162,67 +147,86 @@ class IndexUpsertExecutor(BatchedStepExecutor):
 
         # Batch upsert all nodes to the vector store
         await self._vector_store.async_add(nodes)
+        num_inserted = len(items)
+
+        # Emit status update
+        await self.stream_emitter.status(
+            f"Upserted {num_inserted} items to index {self.step.index.name}"
+        )
+        for message in batch:
+            yield message
 
     async def _upsert_to_document_index(
-        self,
-    ) ->
+        self, batch: list[FlowMessage]
+    ) -> AsyncIterator[FlowMessage]:
         """Upsert items to document index using bulk API.
 
         Args:
+            batch: List of FlowMessages containing documents to upsert
         """
+
         bulk_body = []
-if
-            # Check for errors
-            if response.get("errors"):
-                error_items = [
-                    item
-                    for item in response["items"]
-                    if "error" in item.get("index", {})
-                ]
-                logger.warning(
-                    f"Bulk upsert had {len(error_items)} errors: {error_items}"
+        message_by_id: dict[str, FlowMessage] = {}
+
+        for message in batch:
+            # Collect all input variables into a single document dict
+            doc_dict = {}
+            for input_var in self.step.inputs:
+                value = message.variables.get(input_var.id)
+
+                # Convert to dict if it's a Pydantic model
+                if isinstance(value, BaseModel):
+                    value = value.model_dump()
+
+                # Merge into document dict
+                if isinstance(value, dict):
+                    doc_dict.update(value)
+                else:
+                    # Primitive types - use variable name as field name
+                    doc_dict[input_var.id] = value
+
+            # Determine the document id field
+            id_field = None
+            if self._document_index.id_field is not None:
+                id_field = self._document_index.id_field
+                if id_field not in doc_dict:
+                    raise ValueError(
+                        f"Specified id_field '{id_field}' not found in inputs"
+                    )
+            else:
+                # Auto-detect with fallback
+                for field in ["_id", "id", "doc_id", "document_id"]:
+                    if field in doc_dict:
+                        id_field = field
+                        break
+            if id_field is not None:
+                doc_id = str(doc_dict[id_field])
+            else:
+                # Generate a UUID if no id field found
+                doc_id = str(uuid.uuid4())
+
+            # Add bulk action and document
+            bulk_body.append(
+                {"index": {"_index": self.index_name, "_id": doc_id}}
             )
+            bulk_body.append(doc_dict)
+            message_by_id[doc_id] = message
+
+        # Execute bulk request asynchronously
+        response = await self._opensearch_client.bulk(body=bulk_body)
+
+        num_inserted = 0
+        for item in response["items"]:
+            doc_id = item["index"]["_id"]
+            message = message_by_id[doc_id]
+            if "error" in item.get("index", {}):
+                message.set_error(
+                    self.step.id,
+                    Exception(item["index"]["error"]),
+                )
+            else:
+                num_inserted += 1
+            yield message
+        await self.stream_emitter.status(
+            f"Upserted {num_inserted} items to index {self.step.index.name}, {len(batch) - num_inserted} errors occurred."
+        )