kreuzberg 3.11.4__py3-none-any.whl → 3.13.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (51)
  1. kreuzberg/__init__.py +14 -13
  2. kreuzberg/__main__.py +0 -2
  3. kreuzberg/_api/main.py +119 -9
  4. kreuzberg/_chunker.py +0 -15
  5. kreuzberg/_config.py +212 -292
  6. kreuzberg/_document_classification.py +20 -47
  7. kreuzberg/_entity_extraction.py +1 -122
  8. kreuzberg/_extractors/_base.py +4 -71
  9. kreuzberg/_extractors/_email.py +1 -15
  10. kreuzberg/_extractors/_html.py +9 -12
  11. kreuzberg/_extractors/_image.py +1 -25
  12. kreuzberg/_extractors/_pandoc.py +10 -147
  13. kreuzberg/_extractors/_pdf.py +38 -94
  14. kreuzberg/_extractors/_presentation.py +0 -99
  15. kreuzberg/_extractors/_spread_sheet.py +13 -55
  16. kreuzberg/_extractors/_structured.py +1 -4
  17. kreuzberg/_gmft.py +14 -199
  18. kreuzberg/_language_detection.py +1 -36
  19. kreuzberg/_mcp/__init__.py +0 -2
  20. kreuzberg/_mcp/server.py +3 -10
  21. kreuzberg/_mime_types.py +1 -19
  22. kreuzberg/_ocr/_base.py +4 -76
  23. kreuzberg/_ocr/_easyocr.py +124 -186
  24. kreuzberg/_ocr/_paddleocr.py +154 -224
  25. kreuzberg/_ocr/_table_extractor.py +184 -0
  26. kreuzberg/_ocr/_tesseract.py +797 -361
  27. kreuzberg/_playa.py +5 -31
  28. kreuzberg/_registry.py +0 -36
  29. kreuzberg/_types.py +588 -93
  30. kreuzberg/_utils/_cache.py +84 -138
  31. kreuzberg/_utils/_device.py +0 -74
  32. kreuzberg/_utils/_document_cache.py +0 -75
  33. kreuzberg/_utils/_errors.py +0 -50
  34. kreuzberg/_utils/_ocr_cache.py +136 -0
  35. kreuzberg/_utils/_pdf_lock.py +0 -16
  36. kreuzberg/_utils/_process_pool.py +17 -64
  37. kreuzberg/_utils/_quality.py +0 -60
  38. kreuzberg/_utils/_ref.py +32 -0
  39. kreuzberg/_utils/_serialization.py +0 -30
  40. kreuzberg/_utils/_string.py +9 -59
  41. kreuzberg/_utils/_sync.py +0 -77
  42. kreuzberg/_utils/_table.py +49 -101
  43. kreuzberg/_utils/_tmp.py +0 -9
  44. kreuzberg/cli.py +54 -74
  45. kreuzberg/extraction.py +39 -32
  46. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/METADATA +19 -15
  47. kreuzberg-3.13.1.dist-info/RECORD +57 -0
  48. kreuzberg-3.11.4.dist-info/RECORD +0 -54
  49. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/WHEEL +0 -0
  50. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/entry_points.txt +0 -0
  51. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_utils/_quality.py CHANGED
@@ -1,14 +1,10 @@
-"""Quality post-processing utilities for extracted text."""
-
 from __future__ import annotations
 
 import re
 from functools import reduce
 from typing import Any
 
-# Pre-compiled patterns for performance
 _OCR_ARTIFACTS = {
-    # Common OCR misreads
     "scattered_chars": re.compile(r"\b[a-zA-Z]\s{2,}[a-zA-Z]\s{2,}[a-zA-Z]\b"),
     "repeated_punctuation": re.compile(r"[.]{3,}|[-]{3,}|[_]{3,}"),
     "isolated_punctuation": re.compile(r"\s[.,;:!?]\s"),
@@ -17,7 +13,6 @@ _OCR_ARTIFACTS = {
     "broken_sentences": re.compile(r"[a-z]\s{3,}[A-Z][a-z]"),
 }
 
-# Combined pattern for faster OCR penalty calculation
 _COMBINED_OCR_PATTERN = re.compile(
     r"(?P<scattered>\b[a-zA-Z]\s{2,}[a-zA-Z]\s{2,}[a-zA-Z]\b)|"
     r"(?P<repeated>[.]{3,}|[-]{3,}|[_]{3,})|"
@@ -27,14 +22,12 @@ _COMBINED_OCR_PATTERN = re.compile(
     r"(?P<broken>[a-z]\s{3,}[A-Z][a-z])"
 )
 
-# Pre-compiled patterns for text normalization
 _WHITESPACE_NORMALIZE = re.compile(r"[ \t\f\v\r\xa0\u2000-\u200b\u2028\u2029\u3000]+")
 _NEWLINE_NORMALIZE = re.compile(r"\n\s*\n\s*\n+")
 _SENTENCE_DETECT = re.compile(r"[.!?]\s+[A-Z]")
 _PUNCTUATION_DETECT = re.compile(r"[.!?]")
 
 _SCRIPT_PATTERNS = {
-    # JavaScript and CSS content
     "js_functions": re.compile(r"function\s+\w+\s*\([^)]*\)\s*\{[^}]*\}", re.IGNORECASE),
     "css_rules": re.compile(r"\.[a-zA-Z][\w-]*\s*\{[^}]*\}", re.IGNORECASE),
     "script_tags": re.compile(r"<script[^>]*>.*?</script>", re.DOTALL | re.IGNORECASE),
@@ -51,39 +44,24 @@ _NAVIGATION_PATTERNS = {
 
 
 def calculate_quality_score(text: str, metadata: dict[str, Any] | None = None) -> float:
-    """Calculate overall quality score for extracted text.
-
-    Args:
-        text: The extracted text content
-        metadata: Optional metadata for additional scoring
-
-    Returns:
-        Quality score between 0.0 and 1.0
-    """
     if not text or not text.strip():
         return 0.0
 
-    # Initialize score
     score = 1.0
     total_chars = len(text)
 
-    # Penalize OCR artifacts
     ocr_penalty = _calculate_ocr_penalty(text, total_chars)
     score -= ocr_penalty * 0.3
 
-    # Penalize script/style content
     script_penalty = _calculate_script_penalty(text, total_chars)
     score -= script_penalty * 0.2
 
-    # Penalize navigation content
     nav_penalty = _calculate_navigation_penalty(text, total_chars)
     score -= nav_penalty * 0.1
 
-    # Bonus for structure (sentences, paragraphs)
     structure_bonus = _calculate_structure_bonus(text)
     score += structure_bonus * 0.2
 
-    # Bonus for metadata richness
    if metadata:
         metadata_bonus = _calculate_metadata_bonus(metadata)
         score += metadata_bonus * 0.1
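
To make the weighting concrete: the score starts at 1.0, the three penalties are subtracted at weights 0.3, 0.2, and 0.1, and the two bonuses are added at 0.2 and 0.1 (the clamp of the final result to the 0.0-1.0 range promised by the removed docstring falls outside this hunk). A worked example with assumed component values:

    ocr_penalty, script_penalty, nav_penalty = 0.4, 0.1, 0.0
    structure_bonus, metadata_bonus = 0.5, 0.4

    score = 1.0
    score -= ocr_penalty * 0.3      # 1.00 - 0.12 = 0.88
    score -= script_penalty * 0.2   # 0.88 - 0.02 = 0.86
    score -= nav_penalty * 0.1      # 0.86 - 0.00 = 0.86
    score += structure_bonus * 0.2  # 0.86 + 0.10 = 0.96
    score += metadata_bonus * 0.1   # 0.96 + 0.04 = 1.00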
@@ -92,27 +70,15 @@ def calculate_quality_score(text: str, metadata: dict[str, Any] | None = None) -> float:
 
 
 def clean_extracted_text(text: str) -> str:
-    """Clean extracted text by removing artifacts and improving quality.
-
-    Args:
-        text: The raw extracted text
-
-    Returns:
-        Cleaned text with artifacts removed
-    """
     if not text:
         return text
 
-    # Remove script and style content using functools.reduce for single pass
     text = reduce(lambda t, pattern: pattern.sub(" ", t), _SCRIPT_PATTERNS.values(), text)
 
-    # Clean OCR artifacts
     text = _clean_ocr_artifacts(text)
 
-    # Clean navigation elements
     text = _clean_navigation_elements(text)
 
-    # Normalize whitespace using pre-compiled patterns
     text = _WHITESPACE_NORMALIZE.sub(" ", text)
     text = _NEWLINE_NORMALIZE.sub("\n\n", text)
 
@@ -120,72 +86,57 @@ def clean_extracted_text(text: str) -> str:
 
 
 def _calculate_ocr_penalty(text: str, total_chars: int) -> float:
-    """Calculate penalty for OCR artifacts."""
     if total_chars == 0:
         return 0.0
 
-    # Use combined pattern for single-pass processing
     artifact_chars = sum(len(match.group()) for match in _COMBINED_OCR_PATTERN.finditer(text))
     return min(1.0, artifact_chars / total_chars)
 
 
 def _calculate_script_penalty(text: str, total_chars: int) -> float:
-    """Calculate penalty for script/style content."""
     if total_chars == 0:
         return 0.0
 
-    # Use sum with generator expression for single-pass calculation
     script_chars = sum(len(match) for pattern in _SCRIPT_PATTERNS.values() for match in pattern.findall(text))
 
     return min(1.0, script_chars / total_chars)
 
 
 def _calculate_navigation_penalty(text: str, total_chars: int) -> float:
-    """Calculate penalty for navigation content."""
     if total_chars == 0:
         return 0.0
 
-    # Use sum with generator expression for single-pass calculation
     nav_chars = sum(len(match) for pattern in _NAVIGATION_PATTERNS.values() for match in pattern.findall(text))
 
     return min(1.0, nav_chars / total_chars)
 
 
 def _calculate_structure_bonus(text: str) -> float:
-    """Calculate bonus for proper text structure."""
     if not text:
         return 0.0
 
-    # Count sentences (rough heuristic)
     sentence_count = len(_SENTENCE_DETECT.findall(text))
 
-    # Count paragraphs
     paragraph_count = len(text.split("\n\n"))
 
-    # Calculate structure score
     words = len(text.split())
     if words == 0:
         return 0.0
 
-    # Good structure: reasonable sentence and paragraph distribution
     avg_words_per_sentence = words / max(1, sentence_count)
     avg_words_per_paragraph = words / max(1, paragraph_count)
 
     structure_score = 0.0
 
-    # Bonus for reasonable sentence length (10-30 words)
     if 10 <= avg_words_per_sentence <= 30:
         structure_score += 0.3
 
-    # Bonus for reasonable paragraph length (50-300 words)
     if 50 <= avg_words_per_paragraph <= 300:
         structure_score += 0.3
 
-    # Bonus for having multiple paragraphs
     if paragraph_count > 1:
         structure_score += 0.2
 
-    # Bonus for having punctuation
     if _PUNCTUATION_DETECT.search(text):
         structure_score += 0.2
 
@@ -193,7 +144,6 @@ def _calculate_structure_bonus(text: str) -> float:
 
 
 def _calculate_metadata_bonus(metadata: dict[str, Any]) -> float:
-    """Calculate bonus for rich metadata."""
     if not metadata:
         return 0.0
 
@@ -204,30 +154,20 @@ def _calculate_metadata_bonus(metadata: dict[str, Any]) -> float:
 
 
 def _clean_ocr_artifacts(text: str) -> str:
-    """Remove common OCR artifacts from text."""
-    # Fix scattered characters (likely OCR errors)
     text = _OCR_ARTIFACTS["scattered_chars"].sub(lambda m: m.group().replace(" ", ""), text)
 
-    # Clean repeated punctuation
     text = _OCR_ARTIFACTS["repeated_punctuation"].sub("...", text)
 
-    # Fix isolated punctuation
     text = _OCR_ARTIFACTS["isolated_punctuation"].sub(" ", text)
 
-    # Remove malformed words with numbers mixed in
     text = _OCR_ARTIFACTS["malformed_words"].sub(" ", text)
 
-    # Normalize excessive whitespace
     return _OCR_ARTIFACTS["excessive_whitespace"].sub(" ", text)
 
 
 def _clean_navigation_elements(text: str) -> str:
-    """Remove navigation elements from text."""
-    # Remove navigation words
     text = _NAVIGATION_PATTERNS["nav_words"].sub(" ", text)
 
-    # Remove breadcrumbs
     text = _NAVIGATION_PATTERNS["breadcrumbs"].sub(" ", text)
 
-    # Remove pagination
     return _NAVIGATION_PATTERNS["pagination"].sub(" ", text)
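
Taken together, the surviving body of clean_extracted_text strips script/style content in one reduce() pass, repairs OCR artifacts, removes navigation noise, and collapses whitespace; calculate_quality_score turns the same signals into a single number. A usage sketch against the private module path from the file list (exactly which navigation strings are caught depends on _NAVIGATION_PATTERNS, which this diff shows only partially, and the metadata keys that earn a bonus are an assumption):

    from kreuzberg._utils._quality import calculate_quality_score, clean_extracted_text

    raw = "T  h  e   body   of   the   page.....\n\n\n\nHome > Docs > Page 2 of 9"
    cleaned = clean_extracted_text(raw)     # artifacts and nav noise removed
    score = calculate_quality_score(cleaned, metadata={"title": "Example"})
    assert 0.0 <= score <= 1.0              # the range the removed docstring promised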
kreuzberg/_utils/_ref.py ADDED
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+T = TypeVar("T")
+
+
+class Ref(Generic[T]):
+    _instances: ClassVar[dict[str, Any]] = {}
+
+    def __init__(self, name: str, factory: Callable[[], T]) -> None:
+        self.name = name
+        self.factory = factory
+
+    def get(self) -> T:
+        if self.name not in self._instances:
+            self._instances[self.name] = self.factory()
+        return cast("T", self._instances[self.name])
+
+    def clear(self) -> None:
+        if self.name in self._instances:
+            del self._instances[self.name]
+
+    def is_initialized(self) -> bool:
+        return self.name in self._instances
+
+    @classmethod
+    def clear_all(cls) -> None:
+        cls._instances.clear()
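
The new Ref class is a small named-singleton holder: every instance shares one class-level registry keyed by name, the factory runs lazily on the first get(), and clear()/clear_all() drop cached values (handy in tests). Because the registry is keyed by the string name, two Ref objects created with the same name share a single instance. A usage sketch; ExpensiveClient is a hypothetical factory target:

    from kreuzberg._utils._ref import Ref

    class ExpensiveClient:
        def __init__(self) -> None:
            print("constructed once")

    client_ref = Ref("expensive_client", ExpensiveClient)

    assert not client_ref.is_initialized()
    a = client_ref.get()        # factory runs here: prints "constructed once"
    b = client_ref.get()        # cached: the factory is not called again
    assert a is b

    client_ref.clear()          # forget this entry; the next get() rebuilds it
    Ref.clear_all()             # or reset every named instance at once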
kreuzberg/_utils/_serialization.py CHANGED
@@ -1,5 +1,3 @@
-"""Fast serialization utilities using msgspec."""
-
 from __future__ import annotations
 
 from dataclasses import is_dataclass
@@ -12,7 +10,6 @@ from msgspec.msgpack import decode, encode
 T = TypeVar("T")
 
 
-# Define dict method names in priority order
 _DICT_METHOD_NAMES = (
     "to_dict",
     "as_dict",
@@ -25,21 +22,18 @@ _DICT_METHOD_NAMES = (
 
 
 def encode_hook(obj: Any) -> Any:
-    """Custom encoder for complex objects."""
     if callable(obj):
         return None
 
     if isinstance(obj, Exception):
         return {"message": str(obj), "type": type(obj).__name__}
 
-    # Check for dict-like methods more efficiently using any() with generator
     for attr_name in _DICT_METHOD_NAMES:
         method = getattr(obj, attr_name, None)
         if method is not None and callable(method):
             return method()
 
     if is_dataclass(obj) and not isinstance(obj, type):
-        # Use msgspec.to_builtins for more efficient conversion
         return msgspec.to_builtins(obj)
 
     if hasattr(obj, "save") and hasattr(obj, "format"):
@@ -49,18 +43,6 @@ def encode_hook(obj: Any) -> Any:
 
 
 def deserialize(value: str | bytes, target_type: type[T]) -> T:
-    """Deserialize bytes/string to target type.
-
-    Args:
-        value: Serialized data
-        target_type: Type to deserialize to
-
-    Returns:
-        Deserialized object
-
-    Raises:
-        ValueError: If deserialization fails
-    """
     try:
         return decode(cast("bytes", value), type=target_type, strict=False)
     except MsgspecError as e:
@@ -68,18 +50,6 @@ def deserialize(value: str | bytes, target_type: type[T]) -> T:
 
 
 def serialize(value: Any, **kwargs: Any) -> bytes:
-    """Serialize value to bytes.
-
-    Args:
-        value: Object to serialize
-        **kwargs: Additional data to merge with value if it's a dict
-
-    Returns:
-        Serialized bytes
-
-    Raises:
-        ValueError: If serialization fails
-    """
     if isinstance(value, dict) and kwargs:
         value = value | kwargs
 
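Given the msgspec.msgpack imports, serialize/deserialize presumably form a msgpack roundtrip, with encode_hook covering exceptions, dict-like objects (via the _DICT_METHOD_NAMES probe), and dataclasses. A sketch; the ExtractionStats dataclass is invented for illustration:

    from dataclasses import dataclass

    from kreuzberg._utils._serialization import deserialize, serialize

    @dataclass
    class ExtractionStats:
        pages: int
        ocr_used: bool

    payload = serialize({"pages": 3, "ocr_used": True})   # msgpack-encoded bytes
    stats = deserialize(payload, ExtractionStats)          # typed, strict=False decode
    assert stats == ExtractionStats(pages=3, ocr_used=True)

    # Extra kwargs are merged into a dict value before encoding:
    payload = serialize({"pages": 3}, ocr_used=False)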
kreuzberg/_utils/_string.py CHANGED
@@ -7,52 +7,33 @@ from functools import lru_cache
 
 import chardetng_py
 
-# Compile regex patterns once at module level for performance
 _WHITESPACE_PATTERN = re.compile(r"[ \t\f\v\r\xa0\u2000-\u200b\u2028\u2029\u3000]+")
 _NEWLINES_PATTERN = re.compile(r"\n+")
 _MOJIBAKE_PATTERNS = {
-    # Hebrew as Cyrillic patterns
     "hebrew_as_cyrillic": re.compile(r"[\u0400-\u04FF]{3,}"),
-    # Control characters that shouldn't appear in text
     "control_chars": re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F-\x9F]"),
-    # Unicode replacement characters
     "replacement_chars": re.compile(r"\uFFFD+"),
-    # Isolated combining marks (likely encoding issues)
     "isolated_combining": re.compile(r"[\u0300-\u036F](?![^\u0300-\u036F])"),
 }
 
-# Simple cache for encoding detection (in-memory, session-scoped)
 _encoding_cache: dict[str, str] = {}
 
 
 @lru_cache(maxsize=128)
 def _get_encoding_cache_key(data_hash: str, size: int) -> str:
-    """Generate cache key for encoding detection."""
-    # Use string interpolation which is faster than format strings for simple cases
     return f"{data_hash}:{size}"
 
 
 def safe_decode(byte_data: bytes, encoding: str | None = None) -> str:
-    """Decode a byte string safely with mojibake detection and correction.
-
-    Args:
-        byte_data: The byte string to decode.
-        encoding: The encoding to use when decoding the byte string.
-
-    Returns:
-        The decoded string with mojibake detection and correction.
-    """
     if not byte_data:
         return ""
 
-    # Try provided encoding first (fastest path)
     if encoding:
         with suppress(UnicodeDecodeError, LookupError):
             decoded = byte_data.decode(encoding)
             return _fix_mojibake(decoded)
 
-    # Check cache for similar content (performance optimization)
-    data_hash = hashlib.sha256(byte_data[:1024]).hexdigest()[:16]  # Hash first 1KB
+    data_hash = hashlib.sha256(byte_data[:1024]).hexdigest()[:16]
     cache_key = _get_encoding_cache_key(data_hash, len(byte_data))
 
     if cache_key in _encoding_cache:
@@ -61,25 +42,22 @@ def safe_decode(byte_data: bytes, encoding: str | None = None) -> str:
             decoded = byte_data.decode(cached_encoding)
             return _fix_mojibake(decoded)
 
-    # Use chardetng for better performance than charset-normalizer
     detected_encoding = chardetng_py.detect(byte_data)
     if detected_encoding:
         with suppress(UnicodeDecodeError, LookupError):
             decoded = byte_data.decode(detected_encoding)
-            # Cache successful encoding detection
-            if len(_encoding_cache) < 1000:  # Prevent unlimited growth
+            if len(_encoding_cache) < 1000:  # Prevent unlimited growth ~keep
                 _encoding_cache[cache_key] = detected_encoding
             return _fix_mojibake(decoded)
 
-    # Try multiple encodings with confidence scoring
     encodings_to_try = [
         "utf-8",
-        "windows-1255",  # Hebrew
-        "iso-8859-8",  # Hebrew
-        "windows-1256",  # Arabic
-        "iso-8859-6",  # Arabic
-        "windows-1252",  # Western European
-        "cp1251",  # Cyrillic
+        "windows-1255",  # Hebrew ~keep
+        "iso-8859-8",  # Hebrew ~keep
+        "windows-1256",  # Arabic ~keep
+        "iso-8859-6",  # Arabic ~keep
+        "windows-1252",  # Western European ~keep
+        "cp1251",  # Cyrillic ~keep
     ]
 
     best_result = None
@@ -96,12 +74,10 @@ def safe_decode(byte_data: bytes, encoding: str | None = None) -> str:
     if best_result and best_confidence > 0.5:
         return _fix_mojibake(best_result)
 
-    # Final fallback
     return byte_data.decode("latin-1", errors="replace")
 
 
 def _calculate_text_confidence(text: str) -> float:
-    """Calculate confidence score for decoded text quality."""
     if not text:
         return 0.0
 
@@ -109,77 +85,51 @@ def _calculate_text_confidence(text: str) -> float:
     if total_chars == 0:
         return 0.0
 
-    # Check for common encoding problems - compile patterns once
     replacement_count = len(_MOJIBAKE_PATTERNS["replacement_chars"].findall(text))
     control_count = len(_MOJIBAKE_PATTERNS["control_chars"].findall(text))
 
-    # Penalize replacement and control characters
     penalty = (replacement_count + control_count * 2) / total_chars
 
-    # Bonus for readable character ranges - more efficient counting
-    # Use generator expression with early termination
     readable_chars = sum(1 for c in text if c.isprintable() or c.isspace())
     readability_score = readable_chars / total_chars
 
-    # Check for suspicious Cyrillic that might be misencoded Hebrew
     cyrillic_matches = _MOJIBAKE_PATTERNS["hebrew_as_cyrillic"].findall(text)
     if cyrillic_matches:
-        # Calculate total length more efficiently
        cyrillic_length = sum(len(match) for match in cyrillic_matches)
         if cyrillic_length > total_chars * 0.1:
-            penalty += 0.3  # Heavy penalty for likely mojibake
+            penalty += 0.3
 
     return max(0.0, min(1.0, readability_score - penalty))
 
 
 def _fix_mojibake(text: str) -> str:
-    """Attempt to fix common mojibake patterns."""
     if not text:
         return text
 
-    # Remove control characters
     text = _MOJIBAKE_PATTERNS["control_chars"].sub("", text)
 
-    # Remove replacement characters
     text = _MOJIBAKE_PATTERNS["replacement_chars"].sub("", text)
 
-    # Remove isolated combining marks
     text = _MOJIBAKE_PATTERNS["isolated_combining"].sub("", text)
 
-    # Try to fix Hebrew encoded as Cyrillic (common Windows-1255 -> CP1251 confusion)
     if _MOJIBAKE_PATTERNS["hebrew_as_cyrillic"].search(text):
-        # This is a heuristic fix - in practice, you'd need actual character mapping
-        # For now, we flag it for manual review by keeping the text but adding a marker
         pass
 
     return text
 
 
 def normalize_spaces(text: str) -> str:
-    """Normalize spaces while preserving line breaks and paragraph structure.
-
-    Args:
-        text: The text to normalize.
-
-    Returns:
-        The normalized text with proper spacing.
-    """
     if not text or not text.strip():
         return ""
 
-    # Split by double newlines to preserve paragraph breaks
     paragraphs = text.split("\n\n")
 
     result_paragraphs = []
 
     for paragraph in paragraphs:
-        # Use pre-compiled patterns for better performance
-        # Replace multiple whitespace (except newlines) with single space
         cleaned = _WHITESPACE_PATTERN.sub(" ", paragraph)
-        # Clean up multiple newlines within paragraph (keep single newlines)
         cleaned = _NEWLINES_PATTERN.sub("\n", cleaned)
 
-        # Process lines efficiently - manual loop avoids double strip() calls
         lines = []
         for line in cleaned.split("\n"):
             stripped_line = line.strip()
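
The decode path that survives this diff is a fallback chain: caller-supplied encoding, cached detection, chardetng, then a fixed list of regional encodings ranked by _calculate_text_confidence, and finally latin-1 with replacement characters. A usage sketch (detection on inputs this short is not guaranteed, and the tail of normalize_spaces lies outside these hunks, so the exact outputs are assumptions):

    from kreuzberg._utils._string import normalize_spaces, safe_decode

    hebrew = "שלום עולם".encode("windows-1255")
    print(safe_decode(hebrew))                           # expected: שלום עולם
    print(safe_decode(hebrew, encoding="windows-1255"))  # fast path: hint tried first

    messy = "a   b\t c\n\n\n\nnext    paragraph"
    print(normalize_spaces(messy))                       # likely: "a b c\n\nnext paragraph"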
kreuzberg/_utils/_sync.py CHANGED
@@ -18,17 +18,6 @@ P = ParamSpec("P")
 
 
 async def run_sync(sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-    """Run a synchronous function in an asynchronous context.
-
-    Args:
-        sync_fn: The synchronous function to run.
-        *args: The positional arguments to pass to the function.
-        **kwargs: The keyword arguments to pass to the function.
-
-    Returns:
-        The result of the synchronous function.
-    """
-    # Optimize: only create partial if we have kwargs
     if kwargs:
         handler = partial(sync_fn, **kwargs)
         return cast("T", await any_io_run_sync(handler, *args, abandon_on_cancel=True))  # pyright: ignore [reportCallIssue]
@@ -36,14 +25,6 @@ async def run_sync(sync_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
 
 
 async def run_taskgroup(*async_tasks: Awaitable[Any]) -> list[Any]:
-    """Run a list of coroutines concurrently.
-
-    Args:
-        *async_tasks: The list of coroutines to run.
-
-    Returns:
-        The results of the coroutines.
-    """
     results: list[Any] = [None] * len(async_tasks)
 
     async def run_task(index: int, task: Awaitable[T]) -> None:
@@ -57,15 +38,6 @@ async def run_taskgroup(*async_tasks: Awaitable[Any]) -> list[Any]:
 
 
 async def run_taskgroup_batched(*async_tasks: Awaitable[Any], batch_size: int) -> list[Any]:
-    """Run a list of coroutines concurrently in batches.
-
-    Args:
-        *async_tasks: The list of coroutines to run.
-        batch_size: The size of each batch.
-
-    Returns:
-        The results of the coroutines.
-    """
     results: list[Any] = []
 
     for i in range(0, len(async_tasks), batch_size):
@@ -76,25 +48,6 @@ async def run_taskgroup_batched(*async_tasks: Awaitable[Any], batch_size: int) -> list[Any]:
 
 
 async def run_maybe_sync(fn: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
-    """Executes a callable function and handles both synchronous and asynchronous
-    results.
-
-    This function invokes the provided callable `sync_fn` with the given
-    arguments and keyword arguments. If the result of `sync_fn` is awaitable,
-    it awaits the result before returning it. Otherwise, the result is returned
-    directly.
-
-    Args:
-        fn: The callable to be executed. It can produce either a
-            synchronous or asynchronous result.
-        *args: Positional arguments to pass to `sync_fn`.
-        **kwargs: Keyword arguments to pass to `sync_fn`.
-
-    Returns:
-        The result of `sync_fn` invocation. If the result is awaitable, the
-        awaited value is returned. Otherwise, the synchronous result is
-        returned.
-    """
     result = fn(*args, **kwargs)
     if isawaitable(result):
         return cast("T", await result)
@@ -102,40 +55,10 @@ async def run_maybe_sync(fn: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
 
 
 def run_maybe_async(fn: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
-    """Runs a synchronous or asynchronous function, resolving the output.
-
-    Determines if the provided function is synchronous or asynchronous. If synchronous,
-    executes it directly. If asynchronous, it runs the function within the event loop
-    using anyio. The return value is resolved regardless of the function type.
-
-    Args:
-        fn: The function to be executed, which can
-            either be synchronous or asynchronous.
-        *args: Positional arguments to be passed to the function.
-        **kwargs: Keyword arguments to be passed to the function.
-
-    Returns:
-        T: The return value of the executed function, resolved if asynchronous.
-    """
     return cast("T", fn(*args, **kwargs) if not iscoroutinefunction(fn) else anyio.run(partial(fn, **kwargs), *args))
 
 
 def run_sync_only(fn: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
-    """Runs a function, but only if it's synchronous. Raises error if async.
-
-    This is used for pure sync code paths where we cannot handle async functions.
-
-    Args:
-        fn: The function to be executed, must be synchronous.
-        *args: Positional arguments to be passed to the function.
-        **kwargs: Keyword arguments to be passed to the function.
-
-    Returns:
-        T: The return value of the executed function.
-
-    Raises:
-        RuntimeError: If the function is asynchronous.
-    """
     if iscoroutinefunction(fn):
         raise RuntimeError(f"Cannot run async function {fn.__name__} in sync-only context")
     return cast("T", fn(*args, **kwargs))
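
A sketch of how these helpers compose, using only the signatures shown above; parse_page is a hypothetical blocking worker:

    import anyio

    from kreuzberg._utils._sync import run_maybe_async, run_sync, run_taskgroup_batched

    def parse_page(path: str, *, strict: bool = False) -> str:
        return f"parsed {path} (strict={strict})"

    async def main() -> None:
        # Offload one blocking call to a worker thread.
        first = await run_sync(parse_page, "a.pdf", strict=True)

        # Fan out several coroutines in order-preserving batches of two.
        rest = await run_taskgroup_batched(
            run_sync(parse_page, "b.pdf"),
            run_sync(parse_page, "c.pdf"),
            run_sync(parse_page, "d.pdf"),
            batch_size=2,
        )
        print(first, rest)

    anyio.run(main)

    # From pure sync code, run_maybe_async calls sync functions directly
    # and drives async ones to completion via anyio.run:
    print(run_maybe_async(parse_page, "e.pdf"))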