haystack-experimental 0.13.0__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_experimental/components/agents/__init__.py +16 -0
- haystack_experimental/components/agents/agent.py +633 -0
- haystack_experimental/components/agents/human_in_the_loop/__init__.py +35 -0
- haystack_experimental/components/agents/human_in_the_loop/breakpoint.py +63 -0
- haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +72 -0
- haystack_experimental/components/agents/human_in_the_loop/errors.py +28 -0
- haystack_experimental/components/agents/human_in_the_loop/policies.py +78 -0
- haystack_experimental/components/agents/human_in_the_loop/strategies.py +455 -0
- haystack_experimental/components/agents/human_in_the_loop/types.py +89 -0
- haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +209 -0
- haystack_experimental/components/generators/chat/openai.py +8 -10
- haystack_experimental/components/preprocessors/embedding_based_document_splitter.py +18 -6
- haystack_experimental/components/preprocessors/md_header_level_inferrer.py +146 -0
- haystack_experimental/components/summarizers/__init__.py +7 -0
- haystack_experimental/components/summarizers/llm_summarizer.py +317 -0
- haystack_experimental/core/__init__.py +3 -0
- haystack_experimental/core/pipeline/__init__.py +3 -0
- haystack_experimental/core/pipeline/breakpoint.py +119 -0
- haystack_experimental/dataclasses/__init__.py +3 -0
- haystack_experimental/dataclasses/breakpoints.py +53 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.1.dist-info}/METADATA +29 -14
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.1.dist-info}/RECORD +25 -7
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.1.dist-info}/WHEEL +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.1.dist-info}/licenses/LICENSE +0 -0
- {haystack_experimental-0.13.0.dist-info → haystack_experimental-0.14.1.dist-info}/licenses/LICENSE-MIT.txt +0 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from threading import Lock
|
|
7
|
+
from typing import Any, Optional
|
|
8
|
+
|
|
9
|
+
from haystack.core.serialization import default_to_dict
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.panel import Panel
|
|
12
|
+
from rich.prompt import Prompt
|
|
13
|
+
|
|
14
|
+
from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult
|
|
15
|
+
from haystack_experimental.components.agents.human_in_the_loop.types import ConfirmationUI
|
|
16
|
+
|
|
17
|
+
_ui_lock = Lock()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RichConsoleUI(ConfirmationUI):
    """Rich console interface for user interaction."""

    def __init__(self, console: Optional[Console] = None):
        # Fall back to a fresh Console when none is supplied.
        self.console = console or Console()

    def get_user_confirmation(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
    ) -> ConfirmationUIResult:
        """
        Get user confirmation for tool execution via rich console prompts.

        :param tool_name: The name of the tool to be executed.
        :param tool_description: The description of the tool.
        :param tool_params: The parameters to be passed to the tool.
        :returns: ConfirmationUIResult based on user input.
        """
        # Serialize console access so concurrent confirmations don't interleave prompts.
        with _ui_lock:
            self._display_tool_info(tool_name, tool_description, tool_params)
            # Prompt.ask re-prompts by itself whenever an invalid choice is entered.
            answer = Prompt.ask("\nYour choice", choices=["y", "n", "m"], default="y", console=self.console)
            return self._process_choice(answer, tool_params)

    def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None:
        """
        Display tool information and parameters in a rich panel.

        :param tool_name: The name of the tool to be executed.
        :param tool_description: The description of the tool.
        :param tool_params: The parameters to be passed to the tool.
        """
        body = [
            f"[bold yellow]Tool:[/bold yellow] {tool_name}",
            f"[bold yellow]Description:[/bold yellow] {tool_description}",
            "\n[bold yellow]Arguments:[/bold yellow]",
        ]
        if not tool_params:
            body.append(" (No arguments)")
        else:
            body.extend(f"[cyan]{name}:[/cyan] {value}" for name, value in tool_params.items())
        self.console.print(Panel("\n".join(body), title="🔧 Tool Execution Request", title_align="left"))

    def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult:
        """
        Process the user's choice and return the corresponding ConfirmationUIResult.

        :param choice: The user's choice ('y', 'n', or 'm').
        :param tool_params: The original tool parameters.
        :returns:
            ConfirmationUIResult based on user input.
        """
        if choice == "m":
            return self._modify_params(tool_params)
        if choice == "y":
            return ConfirmationUIResult(action="confirm")
        # Anything else is a rejection; optionally collect a feedback message.
        feedback = Prompt.ask("Feedback message (optional)", default="", console=self.console)
        return ConfirmationUIResult(action="reject", feedback=feedback or None)

    def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult:
        """
        Prompt the user to modify tool parameters.

        :param tool_params: The original tool parameters.
        :returns:
            ConfirmationUIResult with modified parameters.
        """
        updated: dict[str, Any] = {}
        for name, value in tool_params.items():
            is_text = isinstance(value, str)
            # Strings are shown verbatim so users don't have to type surrounding quotes.
            shown_default = value if is_text else json.dumps(value)
            while True:
                entered = Prompt.ask(f"Modify '{name}'", default=shown_default, console=self.console)
                if is_text:
                    # String parameters are taken as-is, never JSON-parsed.
                    updated[name] = entered
                    break
                try:
                    # Non-string parameters must round-trip through JSON.
                    updated[name] = json.loads(entered)
                    break
                except json.JSONDecodeError:
                    self.console.print("[red]❌ Invalid JSON, please try again.[/red]")
        return ConfirmationUIResult(action="modify", new_tool_params=updated)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the RichConsoleConfirmationUI to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # A Console object is not serializable; store None and recreate on load.
        return default_to_dict(self, console=None)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class SimpleConsoleUI(ConfirmationUI):
    """Simple console interface using standard input/output."""

    def get_user_confirmation(
        self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
    ) -> ConfirmationUIResult:
        """
        Get user confirmation for tool execution via simple console prompts.

        :param tool_name: The name of the tool to be executed.
        :param tool_description: The description of the tool.
        :param tool_params: The parameters to be passed to the tool.
        """
        # Serialize console access so concurrent confirmations don't interleave prompts.
        with _ui_lock:
            self._display_tool_info(tool_name, tool_description, tool_params)
            accepted = {"y", "yes", "n", "no", "m", "modify"}
            answer = input("Confirm execution? (y=confirm / n=reject / m=modify): ").strip().lower()
            # Keep re-prompting until one of the accepted answers is given.
            while answer not in accepted:
                print("Invalid input. Please enter 'y', 'n', or 'm'.")
                answer = input("Confirm execution? (y=confirm / n=reject / m=modify): ").strip().lower()
            return self._process_choice(answer, tool_params)

    def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None:
        """
        Display tool information and parameters in the console.

        :param tool_name: The name of the tool to be executed.
        :param tool_description: The description of the tool.
        :param tool_params: The parameters to be passed to the tool.
        """
        print("\n--- Tool Execution Request ---")
        print(f"Tool: {tool_name}")
        print(f"Description: {tool_description}")
        print("Arguments:")
        if not tool_params:
            print(" (No arguments)")
        else:
            for name, value in tool_params.items():
                print(f" {name}: {value}")
        print("-" * 30)

    def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult:
        """
        Process the user's choice and return the corresponding ConfirmationUIResult.

        :param choice: The user's choice ('y', 'n', or 'm').
        :param tool_params: The original tool parameters.
        :returns:
            ConfirmationUIResult based on user input.
        """
        if choice in ("m", "modify"):
            return self._modify_params(tool_params)
        if choice in ("y", "yes"):
            return ConfirmationUIResult(action="confirm")
        # Anything else is a rejection; optionally collect a feedback message.
        feedback = input("Feedback message (optional): ").strip()
        return ConfirmationUIResult(action="reject", feedback=feedback or None)

    def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult:
        """
        Prompt the user to modify tool parameters.

        :param tool_params: The original tool parameters.
        :returns:
            ConfirmationUIResult with modified parameters.
        """
        updated: dict[str, Any] = {}

        # Nothing to edit: return an empty modification result straight away.
        if not tool_params:
            print("No parameters to modify, skipping modification.")
            return ConfirmationUIResult(action="modify", new_tool_params=updated)

        for name, value in tool_params.items():
            is_text = isinstance(value, str)
            # Strings are shown verbatim so users don't have to type surrounding quotes.
            shown_default = json.dumps(value) if not is_text else value
            while True:
                # Empty input keeps the current value.
                entered = input(f"Modify '{name}' (current: {shown_default}): ").strip() or shown_default
                if is_text:
                    # String parameters are taken as-is, never JSON-parsed.
                    updated[name] = entered
                    break
                try:
                    # Non-string parameters must round-trip through JSON.
                    updated[name] = json.loads(entered)
                    break
                except json.JSONDecodeError:
                    print("❌ Invalid JSON, please try again.")

        return ConfirmationUIResult(action="modify", new_tool_params=updated)
|
|
@@ -3,12 +3,12 @@
|
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
from dataclasses import replace
|
|
6
|
-
from typing import Any, Optional
|
|
6
|
+
from typing import Any, Optional
|
|
7
7
|
|
|
8
8
|
from haystack import component
|
|
9
9
|
from haystack.components.generators.chat.openai import OpenAIChatGenerator as BaseOpenAIChatGenerator
|
|
10
10
|
from haystack.dataclasses import ChatMessage, StreamingCallbackT
|
|
11
|
-
from haystack.tools import
|
|
11
|
+
from haystack.tools import ToolsType
|
|
12
12
|
|
|
13
13
|
from haystack_experimental.utils.hallucination_risk_calculator.dataclasses import HallucinationScoreConfig
|
|
14
14
|
from haystack_experimental.utils.hallucination_risk_calculator.openai_planner import calculate_hallucination_metrics
|
|
@@ -59,7 +59,7 @@ class OpenAIChatGenerator(BaseOpenAIChatGenerator):
|
|
|
59
59
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
60
60
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
61
61
|
*,
|
|
62
|
-
tools: Optional[
|
|
62
|
+
tools: Optional[ToolsType] = None,
|
|
63
63
|
tools_strict: Optional[bool] = None,
|
|
64
64
|
hallucination_score_config: Optional[HallucinationScoreConfig] = None,
|
|
65
65
|
) -> dict[str, list[ChatMessage]]:
|
|
@@ -75,9 +75,8 @@ class OpenAIChatGenerator(BaseOpenAIChatGenerator):
|
|
|
75
75
|
override the parameters passed during component initialization.
|
|
76
76
|
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
|
|
77
77
|
:param tools:
|
|
78
|
-
A list of
|
|
79
|
-
`tools` parameter
|
|
80
|
-
`Tool` objects or a `Toolset` instance.
|
|
78
|
+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
|
|
79
|
+
If set, it will override the `tools` parameter provided during initialization.
|
|
81
80
|
:param tools_strict:
|
|
82
81
|
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
|
83
82
|
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
|
@@ -127,7 +126,7 @@ class OpenAIChatGenerator(BaseOpenAIChatGenerator):
|
|
|
127
126
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
128
127
|
generation_kwargs: Optional[dict[str, Any]] = None,
|
|
129
128
|
*,
|
|
130
|
-
tools: Optional[
|
|
129
|
+
tools: Optional[ToolsType] = None,
|
|
131
130
|
tools_strict: Optional[bool] = None,
|
|
132
131
|
hallucination_score_config: Optional[HallucinationScoreConfig] = None,
|
|
133
132
|
) -> dict[str, list[ChatMessage]]:
|
|
@@ -147,9 +146,8 @@ class OpenAIChatGenerator(BaseOpenAIChatGenerator):
|
|
|
147
146
|
override the parameters passed during component initialization.
|
|
148
147
|
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
|
|
149
148
|
:param tools:
|
|
150
|
-
A list of
|
|
151
|
-
`tools` parameter
|
|
152
|
-
`Tool` objects or a `Toolset` instance.
|
|
149
|
+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
|
|
150
|
+
If set, it will override the `tools` parameter provided during initialization.
|
|
153
151
|
:param tools_strict:
|
|
154
152
|
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
|
155
153
|
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
|
@@ -37,22 +37,31 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
37
37
|
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
|
|
38
38
|
from haystack_experimental.components.preprocessors import EmbeddingBasedDocumentSplitter
|
|
39
39
|
|
|
40
|
+
# Create a document with content that has a clear topic shift
|
|
40
41
|
doc = Document(
|
|
41
42
|
content="This is a first sentence. This is a second sentence. This is a third sentence. "
|
|
42
43
|
"Completely different topic. The same completely different topic."
|
|
43
44
|
)
|
|
44
45
|
|
|
46
|
+
# Initialize the embedder to calculate semantic similarities
|
|
45
47
|
embedder = SentenceTransformersDocumentEmbedder()
|
|
46
48
|
|
|
49
|
+
# Configure the splitter with parameters that control splitting behavior
|
|
47
50
|
splitter = EmbeddingBasedDocumentSplitter(
|
|
48
51
|
document_embedder=embedder,
|
|
49
|
-
sentences_per_group=2,
|
|
50
|
-
percentile=0.95,
|
|
51
|
-
min_length=50,
|
|
52
|
-
max_length=1000
|
|
52
|
+
sentences_per_group=2, # Group 2 sentences before calculating embeddings
|
|
53
|
+
percentile=0.95, # Split when cosine distance exceeds 95th percentile
|
|
54
|
+
min_length=50, # Merge splits shorter than 50 characters
|
|
55
|
+
max_length=1000 # Further split chunks longer than 1000 characters
|
|
53
56
|
)
|
|
54
57
|
splitter.warm_up()
|
|
55
58
|
result = splitter.run(documents=[doc])
|
|
59
|
+
|
|
60
|
+
# The result contains a list of Document objects, each representing a semantic chunk
|
|
61
|
+
# Each split document includes metadata: source_id, split_id, and page_number
|
|
62
|
+
print(f"Original document split into {len(result['documents'])} chunks")
|
|
63
|
+
for i, split_doc in enumerate(result['documents']):
|
|
64
|
+
print(f"Chunk {i}: {split_doc.content[:50]}...")
|
|
56
65
|
```
|
|
57
66
|
"""
|
|
58
67
|
|
|
@@ -78,8 +87,11 @@ class EmbeddingBasedDocumentSplitter:
|
|
|
78
87
|
:param min_length: Minimum length of splits in characters. Splits below this length will be merged.
|
|
79
88
|
:param max_length: Maximum length of splits in characters. Splits above this length will be recursively split.
|
|
80
89
|
:param language: Language for sentence tokenization.
|
|
81
|
-
:param use_split_rules: Whether to use additional split rules for sentence tokenization.
|
|
82
|
-
|
|
90
|
+
:param use_split_rules: Whether to use additional split rules for sentence tokenization. Applies additional
|
|
91
|
+
split rules from SentenceSplitter to the sentence spans.
|
|
92
|
+
:param extend_abbreviations: If True, the abbreviations used by NLTK's PunktTokenizer are extended by a list
|
|
93
|
+
of curated abbreviations. Currently supported languages are: en, de.
|
|
94
|
+
If False, the default abbreviations are used.
|
|
83
95
|
"""
|
|
84
96
|
self.document_embedder = document_embedder
|
|
85
97
|
self.sentences_per_group = sentences_per_group
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
from haystack import Document, component, logging
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@component
class MarkdownHeaderLevelInferrer:
    """
    Infers and rewrites header levels in Markdown text to normalize hierarchy.

    First header → Always becomes level 1 (#)
    Subsequent headers → Level increases if no content between headers, stays same if content exists
    Maximum level → Capped at 6 (######)

    ### Usage example
    ```python
    from haystack import Document
    from haystack_experimental.components.preprocessors import MarkdownHeaderLevelInferrer

    # Create a document with uniform header levels
    text = "## Title\nSome content\n## Section\nMore content\n## Subsection\nFinal content"
    doc = Document(content=text)

    # Initialize the inferrer and process the document
    inferrer = MarkdownHeaderLevelInferrer()
    result = inferrer.run([doc])

    # The headers are now normalized with proper hierarchy
    print(result["documents"][0].content)
    > # Title\nSome content\n## Section\nMore content\n### Subsection\nFinal content
    ```
    """

    def __init__(self):
        """Initializes the MarkdownHeaderLevelInferrer."""
        # Matches '#'-style headers (levels 1-6), tolerating trailing whitespace.
        self._header_pattern = re.compile(r"(?m)^(#{1,6})\s+(.+?)(?:\s*)$")

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]) -> dict:
        """
        Infers and rewrites the header levels in the content for documents that use uniform header levels.

        :param documents: list of Document objects to process.

        :returns:
            dict: a dictionary with the key 'documents' containing the processed Document objects.
        """
        if not documents:
            logger.warning("No documents provided to process")
            return {"documents": []}

        logger.debug(f"Inferring and rewriting header levels for {len(documents)} documents")
        processed_docs = [self._process_document(doc) for doc in documents]
        return {"documents": processed_docs}

    def _process_document(self, doc: Document) -> Document:
        """
        Processes a single document, inferring and rewriting header levels.

        :param doc: Document object to process.
        :returns:
            Document object with rewritten header levels.
        """
        if doc.content is None:
            logger.warning(f"Document {getattr(doc, 'id', '')} content is None; skipping header level inference.")
            return doc

        # Use the pre-compiled pattern directly instead of re-dispatching through the re module.
        matches = list(self._header_pattern.finditer(doc.content))
        if not matches:
            logger.info(f"No headers found in document {doc.id}; skipping header level inference.")
            return doc

        modified_text = MarkdownHeaderLevelInferrer._rewrite_headers(doc.content, matches)
        # Fix: original message was missing the space before the document id ("document{doc.id}").
        logger.info(f"Rewrote {len(matches)} headers with inferred levels in document {doc.id}.")
        return MarkdownHeaderLevelInferrer._build_final_document(doc, modified_text)

    @staticmethod
    def _rewrite_headers(content: str, matches: list[re.Match]) -> str:
        """
        Rewrites the headers in the content with inferred levels.

        :param content: Original Markdown content.
        :param matches: List of regex matches for headers.
        """
        modified_text = content
        # Running delta between original and rewritten text lengths; match positions
        # refer to the original content, so each replacement is shifted by this offset.
        offset = 0
        current_level = 1

        for i, match in enumerate(matches):
            original_header = match.group(0)
            header_text = match.group(2).strip()

            # Skip headers whose text is only whitespace.
            if not header_text:
                logger.warning(f"Skipping empty header at position {match.start()}")
                continue

            has_content = MarkdownHeaderLevelInferrer._has_content_between_headers(content, matches, i)
            inferred_level = MarkdownHeaderLevelInferrer._infer_level(i, current_level, has_content)
            current_level = inferred_level

            new_header = f"{'#' * inferred_level} {header_text}"
            start_pos = match.start() + offset
            end_pos = match.end() + offset
            modified_text = modified_text[:start_pos] + new_header + modified_text[end_pos:]
            offset += len(new_header) - len(original_header)

        return modified_text

    @staticmethod
    def _has_content_between_headers(content: str, matches: list[re.Match], i: int) -> bool:
        """Checks if there is content between the previous and current header."""
        if i == 0:
            return False
        prev_end = matches[i - 1].end()
        current_start = matches[i].start()
        content_between = content[prev_end:current_start]
        return bool(content_between.strip())

    @staticmethod
    def _infer_level(i: int, current_level: int, has_content: bool) -> int:
        """Infers the header level for the current header."""
        if i == 0:
            # The first header anchors the hierarchy at level 1.
            return 1
        if has_content:
            # Content between headers keeps the sibling at the same depth.
            return current_level
        # No intervening content: treat as a sub-header, capped at Markdown's max depth.
        return min(current_level + 1, 6)

    @staticmethod
    def _build_final_document(doc: Document, new_content: str) -> Document:
        """Creates a new Document with updated content, preserving other fields."""
        return Document(
            id=getattr(doc, "id", "") or "",
            content=new_content,
            blob=getattr(doc, "blob", None),
            meta=getattr(doc, "meta", {}) or {},
            score=getattr(doc, "score", None),
            embedding=getattr(doc, "embedding", None),
        )
|