guidellm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guidellm might be problematic. Click here for more details.
- guidellm/__init__.py +19 -0
- guidellm/backend/__init__.py +10 -0
- guidellm/backend/base.py +320 -0
- guidellm/backend/openai.py +168 -0
- guidellm/config.py +234 -0
- guidellm/core/__init__.py +24 -0
- guidellm/core/distribution.py +190 -0
- guidellm/core/report.py +321 -0
- guidellm/core/request.py +44 -0
- guidellm/core/result.py +545 -0
- guidellm/core/serializable.py +169 -0
- guidellm/executor/__init__.py +10 -0
- guidellm/executor/base.py +213 -0
- guidellm/executor/profile_generator.py +343 -0
- guidellm/logger.py +83 -0
- guidellm/main.py +336 -0
- guidellm/request/__init__.py +13 -0
- guidellm/request/base.py +194 -0
- guidellm/request/emulated.py +391 -0
- guidellm/request/file.py +76 -0
- guidellm/request/transformers.py +100 -0
- guidellm/scheduler/__init__.py +4 -0
- guidellm/scheduler/base.py +374 -0
- guidellm/scheduler/load_generator.py +196 -0
- guidellm/utils/__init__.py +40 -0
- guidellm/utils/injector.py +70 -0
- guidellm/utils/progress.py +196 -0
- guidellm/utils/text.py +455 -0
- guidellm/utils/transformers.py +151 -0
- guidellm-0.1.0.dist-info/LICENSE +201 -0
- guidellm-0.1.0.dist-info/METADATA +434 -0
- guidellm-0.1.0.dist-info/RECORD +35 -0
- guidellm-0.1.0.dist-info/WHEEL +5 -0
- guidellm-0.1.0.dist-info/entry_points.txt +3 -0
- guidellm-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from rich.console import Group
|
|
6
|
+
from rich.live import Live
|
|
7
|
+
from rich.panel import Panel
|
|
8
|
+
from rich.progress import (
|
|
9
|
+
BarColumn,
|
|
10
|
+
Progress,
|
|
11
|
+
SpinnerColumn,
|
|
12
|
+
TaskID,
|
|
13
|
+
TaskProgressColumn,
|
|
14
|
+
TextColumn,
|
|
15
|
+
TimeElapsedColumn,
|
|
16
|
+
TimeRemainingColumn,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
__all__ = ["BenchmarkReportProgress"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BenchmarkReportProgress:
    """
    Manages the progress display for benchmarks and report generation using Rich.

    This class provides a visual representation of the benchmarking process
    and report generation using Rich's progress bars and panels.
    """

    def __init__(self):
        """
        Initialize the BenchmarkReportProgress with default settings.

        This method sets up the progress displays for both individual benchmarks
        and the overall report, as well as initializing internal task management
        structures.
        """
        logger.info("Initializing BenchmarkReportProgress instance")

        # Per-benchmark row: [start time] spinner %complete description (req/sec avg)
        self.benchmarks_progress = Progress(
            TextColumn("[{task.fields[start_time_str]}]"),
            SpinnerColumn(),
            TaskProgressColumn(),
            TextColumn("{task.description}"),
            TextColumn(" "),
            TextColumn(
                "[bold cyan]({task.fields[req_per_sec]} req/sec avg)[/bold cyan]"
            ),
        )
        self.benchmarks_panel = Panel(
            self.benchmarks_progress,
            title="Benchmarks",
            title_align="left",
            expand=True,
        )
        # Overall report bar: (completed/total) benchmarks plus [elapsed < remaining]
        self.report_progress = Progress(
            SpinnerColumn(),
            TextColumn("Generating report..."),
            BarColumn(bar_width=None),
            TextColumn(
                "({task.fields[completed_benchmarks]}/{task.fields[total_benchmarks]})"
            ),
            TextColumn("["),
            TimeElapsedColumn(),
            TextColumn("<"),
            TimeRemainingColumn(),
            TextColumn("]"),
        )
        self.render_group = Group(self.benchmarks_panel, self.report_progress)
        self.live = Live(self.render_group, redirect_stdout=True, redirect_stderr=True)

        # Parallel lists indexed by benchmark position; report_task is assigned
        # in start().
        self.report_task: TaskID = None  # type: ignore # noqa: PGH003
        self.benchmark_tasks: List[TaskID] = []
        self.benchmark_tasks_started: List[bool] = []
        self.benchmark_tasks_completed: List[bool] = []
        self.benchmark_tasks_progress: List[float] = []

    def start(self, task_descriptions: List[str]) -> None:
        """
        Starts the live progress display and initializes benchmark tasks.

        :param task_descriptions: List of descriptions for each benchmark task.
        :type task_descriptions: List[str]
        """
        logger.info(
            "Starting BenchmarkReportProgress with task descriptions: {}",
            task_descriptions,
        )
        self.live.start()

        for task_description in task_descriptions:
            logger.debug("Adding task with description: {}", task_description)
            # Tasks start paused (start=False) with unknown totals; placeholders
            # are shown until the first update_benchmark call fills them in.
            task_id = self.benchmarks_progress.add_task(
                task_description,
                start=False,
                total=None,
                start_time_str="--:--:--",
                req_per_sec="#.##",
            )
            self.benchmark_tasks.append(task_id)
            self.benchmark_tasks_started.append(False)
            self.benchmark_tasks_completed.append(False)
            self.benchmark_tasks_progress.append(0)

        self.report_task = self.report_progress.add_task(
            "",
            total=len(self.benchmark_tasks) * 100,  # 100 points per report
            completed_benchmarks=0,
            total_benchmarks=len(task_descriptions),
        )
        logger.info("Initialized {} benchmark tasks", len(task_descriptions))

    def update_benchmark(
        self,
        index: int,
        description: str,
        completed: bool,
        completed_count: int,
        completed_total: int,
        start_time: float,
        req_per_sec: float,
    ) -> None:
        """
        Updates the progress of a specific benchmark task.

        :param index: Index of the benchmark task to update.
        :type index: int
        :param description: Description of the current benchmark task.
        :type description: str
        :param completed: Flag indicating if the benchmark is completed.
        :type completed: bool
        :param completed_count: Number of completed operations for the task.
        :type completed_count: int
        :param completed_total: Total number of operations for the task.
        :type completed_total: int
        :param start_time: Start time of the benchmark in timestamp format.
        :type start_time: float
        :param req_per_sec: Average requests per second.
        :type req_per_sec: float
        :raises ValueError: If trying to update a completed benchmark.
        """
        if self.benchmark_tasks_completed[index]:
            err = ValueError(f"Benchmark {index} already completed")
            logger.error("Error updating benchmark: {}", err)
            raise err

        if not self.benchmark_tasks_started[index]:
            self.benchmark_tasks_started[index] = True
            self.benchmarks_progress.start_task(self.benchmark_tasks[index])
            logger.info("Starting benchmark task at index {}", index)

        if completed:
            self.benchmark_tasks_completed[index] = True
            # Pin progress at 100 so the report total reflects a finished
            # benchmark even when completed_count < completed_total.
            self.benchmark_tasks_progress[index] = 100
            self.benchmarks_progress.stop_task(self.benchmark_tasks[index])
            logger.info("Completed benchmark task at index {}", index)
        else:
            # Guard against division by zero while the total is still unknown.
            self.benchmark_tasks_progress[index] = (
                completed_count / completed_total * 100 if completed_total else 0.0
            )

        self.benchmarks_progress.update(
            self.benchmark_tasks[index],
            description=description,
            total=completed_total,
            completed=completed_count if not completed else completed_total,
            req_per_sec=(f"{req_per_sec:.2f}" if req_per_sec else "#.##"),
            start_time_str=datetime.fromtimestamp(start_time).strftime("%H:%M:%S")
            if start_time
            else "--:--:--",
        )
        logger.debug(
            "Updated benchmark task at index {}: {}% complete",
            index,
            self.benchmark_tasks_progress[index],
        )
        self.report_progress.update(
            self.report_task,
            total=len(self.benchmark_tasks) * 100,
            completed=sum(self.benchmark_tasks_progress),
            completed_benchmarks=sum(self.benchmark_tasks_completed),
            total_benchmarks=len(self.benchmark_tasks),
        )

    def finish(self) -> None:
        """
        Marks the overall report task as finished and stops the live display.
        """
        logger.info("Finishing BenchmarkReportProgress")
        self.report_progress.update(
            self.report_task,
            total=len(self.benchmark_tasks) * 100,
            completed=len(self.benchmark_tasks) * 100,
            completed_benchmarks=len(self.benchmark_tasks),
            total_benchmarks=len(self.benchmark_tasks),
        )
        self.report_progress.stop_task(self.report_task)
        self.live.stop()
        logger.info("BenchmarkReportProgress finished and live display stopped")
|
guidellm/utils/text.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
|
|
8
|
+
import ftfy
|
|
9
|
+
import requests
|
|
10
|
+
import yaml
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
from guidellm.config import settings
|
|
14
|
+
|
|
15
|
+
__all__ = [
    "clean_text",
    "filter_text",
    "is_path",
    "is_path_like",
    "is_url",
    "load_text",
    "load_text_lines",
    "parse_text_objects",
    "split_lines_by_punctuation",
    "split_text",
]


# Abbreviated honorifics/titles that end in a period but do not terminate a
# sentence; used by split_lines_by_punctuation to avoid false sentence breaks.
NAME_TITLES = [
    "Mr.",
    "Mrs.",
    "Ms.",
    "Dr.",
    "Prof.",
    "Jr.",
    "Sr.",
    "St.",
    "Lt.",
    "Col.",
    "Gen.",
    "Rep.",
    "Sen.",
    "Gov.",
    "Pres.",
]
# Matches a sentence ending in . ! or ? (optionally followed by a quote)
# that precedes an upper-case letter.
# NOTE(review): not referenced by the code visible in this module.
SENTENCE_REGEX = r'[^.!?]*[.!?]["\']?\s*(?=[A-Z])'
# Longest file extension (including the dot) still treated as path-like.
MAX_EXTENSION_LENGTH = 8
# Maximum accepted path length (4096 on most Linux setups).
MAX_PATH_LENGTH = 4096
# Maps known file extensions to the canonical parse-format name used by
# parse_text_objects / load_text_lines.
EXTENSION_TYPES = {
    "csv": "csv",
    "jsonl": "jsonl",
    "json": "json",
    "yaml": "yaml",
    "yml": "yaml",
    "txt": "txt",
    "text": "txt",
}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def filter_text(
    text: str,
    filter_start: Optional[Union[str, int]] = None,
    filter_end: Optional[Union[str, int]] = None,
) -> str:
    """
    Filter text by start and end strings or indices.

    Both the start and end filters are interpreted in the coordinates of the
    original text, so providing both yields ``text[start:end]``.

    :param text: the text to filter
    :param filter_start: the start string or index to filter from
    :param filter_end: the end string or index to filter to
    :return: the filtered text
    :raises ValueError: if a filter is neither a string nor an integer
    """
    filter_start_index = -1
    filter_end_index = -1

    # `is not None` (not truthiness) so an explicit index of 0 is honored.
    if filter_start is not None and isinstance(filter_start, str):
        filter_start_index = text.index(filter_start)
    elif filter_start is not None:
        if not isinstance(filter_start, int):
            raise ValueError(f"Invalid filter start index: {filter_start}")
        filter_start_index = filter_start

    if filter_end is not None and isinstance(filter_end, str):
        filter_end_index = text.index(filter_end)
    elif filter_end is not None:
        if not isinstance(filter_end, int):
            raise ValueError(f"Invalid filter end index: {filter_end}")
        filter_end_index = filter_end

    # Apply the end cut first: both indices refer to the original text, so
    # slicing the end before the start keeps them consistent (slicing the
    # start first would shift the end index).
    if filter_end_index > -1:
        text = text[:filter_end_index]
    if filter_start_index > -1:
        text = text[filter_start_index:]

    return text
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def clean_text(
    text: str,
    fix_encoding: bool = True,
    clean_whitespace: bool = False,
    remove_empty_lines: bool = False,
    force_new_line_punctuation: bool = False,
) -> str:
    """
    Clean text by fixing encoding, normalizing whitespace, dropping empty
    lines, and/or re-breaking lines at sentence punctuation.

    :param text: the text to clean
    :param fix_encoding: True to repair the text's encoding via ftfy,
        False to leave as is
    :param clean_whitespace: True to collapse runs of whitespace within each
        line and strip the ends, False to leave as is
    :param remove_empty_lines: True to drop lines containing only whitespace,
        False to leave as is
    :param force_new_line_punctuation: True to re-split the text so each line
        ends at sentence punctuation (. ! ?), False to leave as is
    :return: The cleaned text
    """
    if fix_encoding:
        text = ftfy.fix_text(text)

    if clean_whitespace:
        collapsed = (re.sub(r"\s+", " ", line).strip() for line in text.splitlines())
        text = "\n".join(collapsed)

    if remove_empty_lines:
        nonempty = (line for line in text.splitlines() if line.strip())
        text = "\n".join(nonempty)

    if force_new_line_punctuation:
        # Flatten to a single line first so only punctuation drives the breaks.
        flattened = " ".join(line for line in text.splitlines() if line.strip())
        text = "\n".join(split_lines_by_punctuation(flattened))

    return text
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def split_lines_by_punctuation(text: str) -> List[str]:
    """
    Split text into lines based on sentence-ending punctuation.

    :param text: the text to split
    :return: the list of lines
    """
    lines: List[str] = []
    current = ""
    skip_next = False

    for position, char in enumerate(text):
        if skip_next:
            skip_next = False
            continue

        current += char

        # Only . ! ? can end a sentence.
        if char not in (".", "!", "?"):
            continue

        # Abbreviated titles (e.g. "Dr.") end in a period but not a sentence.
        if any(current.endswith(title) for title in NAME_TITLES):
            continue

        after_1 = text[position + 1] if position + 1 < len(text) else None
        after_2 = text[position + 2] if position + 2 < len(text) else None
        after_3 = text[position + 3] if position + 3 < len(text) else None

        followed_by_space = after_1 and after_1.isspace()
        followed_by_quote = after_1 in ("'", '"') and after_2 == " "

        # The punctuation must be followed by a space or a closing quote.
        if not followed_by_space and not followed_by_quote:
            continue

        # The next sentence must open with an upper-case letter (or a quote).
        opener = after_3 if followed_by_quote else after_2
        if not (opener and (opener.isupper() or opener in ("'", '"'))):
            continue

        # Pull a trailing quote into this sentence and skip over it.
        if followed_by_quote:
            current += after_1
            skip_next = True

        lines.append(current.strip())
        current = ""

    if current:
        lines.append(current.strip())

    return lines
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def is_url(url: str) -> bool:
    """
    Check if a string is a URL.

    :param url: the string to check
    :return: True if the string is a URL, False if not
    """
    try:
        parsed = urlparse(url)
    except Exception:  # noqa: BLE001
        return False

    # A URL needs both a scheme (http, ftp, ...) and a network location.
    return bool(parsed.scheme) and bool(parsed.netloc)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def is_path(path: Any) -> bool:
    """
    Check if a value is an existing filesystem path.

    :param path: the value to check
    :return: True if the value is a path that exists, False if not
    """
    # Only strings and Path objects can be paths at all.
    if not isinstance(path, (str, Path)):
        return False

    return Path(path).exists()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def is_path_like(path: Any, enforce_file: bool = False) -> bool:
    """
    Check if a value has a path-like structure; the path need not exist.

    :param path: the value to check
    :param enforce_file: True if the path should look like a file, False if not
    :return: True if the value is path-like, False if not
    """
    # Anything other than str/Path can't be a path.
    if not isinstance(path, (str, Path)):
        return False

    path_str = str(path)

    # Longer than any realistic path (4096 on most Linux setups).
    if len(path_str) > MAX_PATH_LENGTH:
        return False

    # URL schemes mean it's a URL, not a path.
    if path_str.startswith(("http", "ftp")):
        return False

    if not enforce_file:
        return True

    # A file needs an extension of a plausible length.
    suffix = Path(path_str).suffix
    return bool(suffix) and len(suffix) <= MAX_EXTENSION_LENGTH
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def split_text(text: str) -> Tuple[List[str], List[str], List[int]]:
    """
    Split text into words / tokens, the whitespace separators between words,
    and the token indices where each new line starts.

    :param text: the text to split
    :return: the tokens, the whitespace separators, and the new line indices
    """
    if not text or not text.strip():
        return [], [], []

    text = text.strip()
    tokens: List[str] = []
    separators: List[str] = []
    new_lines = [0]

    # Scan char by char, accumulating either a token run or a whitespace run.
    buffer = text[0]
    in_token = not text[0].isspace()

    for char in text[1:]:
        whitespace = char.isspace()

        if char == "\n":
            # Index of the token that will start the next line.
            new_lines.append(len(tokens) + 1)

        if whitespace == in_token:
            # The run kind flips here: flush the buffer to the matching list.
            (tokens if in_token else separators).append(buffer)
            buffer = char
            in_token = not in_token
        else:
            buffer += char

    # Flush the final run; a trailing token gets a synthetic space separator
    # so tokens and separators stay the same length.
    if buffer and in_token:
        tokens.append(buffer)
        separators.append(" ")
    elif buffer:
        separators.append(buffer)

    return tokens, separators, new_lines
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str:
    """
    Load text content from a URL, a local file, or pass raw text through.

    :param data: the URL, file path, or raw text to load from
    :type data: Union[str, Path]
    :param encoding: the encoding to use when reading a file
    :type encoding: str
    :return: the text content
    :rtype: str
    """
    logger.debug("Loading text: {}", data)

    if not data:
        return ""

    # Remote resource: fetch it over HTTP(S).
    if isinstance(data, str) and data.startswith("http"):
        response = requests.get(data, timeout=settings.request_timeout)
        response.raise_for_status()
        return response.text

    # A string that doesn't look like a file path is treated as raw text.
    if isinstance(data, str) and not is_path_like(data, enforce_file=True):
        return data

    # Otherwise assume a local file.
    data = data if isinstance(data, Path) else Path(data)

    if not data.exists():
        raise FileNotFoundError(f"File not found: {data}")

    if not data.is_file():
        raise IsADirectoryError(f"Path is a directory: {data}")

    return data.read_text(encoding=encoding)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def parse_text_objects(data: str, format_: str = "txt") -> List[Dict]:
|
|
350
|
+
"""
|
|
351
|
+
Parse text data into a list of dictionaries based on the format given
|
|
352
|
+
(csv, jsonl, json, yaml, txt).
|
|
353
|
+
|
|
354
|
+
:param data: the text data to parse
|
|
355
|
+
:param format_: the format of the data to parse:
|
|
356
|
+
'csv', 'jsonl', 'json', 'yaml', 'txt'
|
|
357
|
+
:return: the list of dictionaries parsed from the data, if text
|
|
358
|
+
then each line is a dictionary with a single key 'text'
|
|
359
|
+
"""
|
|
360
|
+
if not isinstance(data, str):
|
|
361
|
+
raise ValueError(f"Unsupported data given of type: {type(data)}")
|
|
362
|
+
|
|
363
|
+
if format_ == "csv":
|
|
364
|
+
reader = csv.DictReader(data.splitlines())
|
|
365
|
+
columns = reader.fieldnames
|
|
366
|
+
return [{col: row[col] for col in columns} for row in reader] # type: ignore # noqa: PGH003
|
|
367
|
+
|
|
368
|
+
if format_ == "jsonl":
|
|
369
|
+
return [json.loads(line) for line in data.splitlines() if line]
|
|
370
|
+
|
|
371
|
+
if format_ in ("json", "yaml"):
|
|
372
|
+
data = json.loads(data) if format_ == "json" else yaml.safe_load(data)
|
|
373
|
+
|
|
374
|
+
if not data:
|
|
375
|
+
return []
|
|
376
|
+
|
|
377
|
+
if isinstance(data, dict) and len(data) == 1:
|
|
378
|
+
logger.debug("Getting first value from JSON/YAML object: {}", data)
|
|
379
|
+
data = list(data.values())[0]
|
|
380
|
+
elif isinstance(data, dict):
|
|
381
|
+
logger.debug("Converting JSON/YAML object to list: {}", data)
|
|
382
|
+
data = list(data.values())
|
|
383
|
+
|
|
384
|
+
if not isinstance(data, list) or not isinstance(data[0], dict):
|
|
385
|
+
raise ValueError(f"Unsupported data structure given: {data}")
|
|
386
|
+
|
|
387
|
+
return data
|
|
388
|
+
|
|
389
|
+
if format_ == "txt":
|
|
390
|
+
return [{"text": line} for line in data.splitlines() if line]
|
|
391
|
+
|
|
392
|
+
raise ValueError(f"Unsupported format given: {format_}")
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def load_text_lines(
    data: Union[str, Path, List[Dict]],
    format_: Optional[str] = None,
    filters: Optional[List[str]] = None,
    encoding: Optional[str] = None,
) -> List[str]:
    """
    Load text lines from a file or data object with optional filtering and
    formatting.

    :param data: the data to load the text lines from
    :param format_: the format of the data to load, if not provided will be
        inferred. Supported formats: 'csv', 'jsonl', 'json', 'yaml', 'txt'
    :param filters: the keys to filter the data by when loading in order of
        preference. If not provided, will use the first key in the data object.
    :param encoding: the encoding to use when reading the file
    :return: the list of text lines
    """
    logger.debug(
        "Loading text lines with format {}, filters {}, encoding {} for data: {}",
        format_,
        filters,
        encoding,
        data,
    )

    if not data:
        return []

    # Infer the format from the extension when not given explicitly.
    if not format_ and isinstance(data, (str, Path)) and "." in str(data):
        extension = str(data).split(".")[-1]
        format_ = EXTENSION_TYPES.get(extension, "txt")
    elif not format_:
        format_ = "txt"

    # Fetch and clean the content when data is a path or URL.
    if isinstance(data, Path) or (isinstance(data, str) and data.startswith("http")):
        data = clean_text(load_text(data, encoding=encoding))

    # Parse raw text into a list of dictionaries for the chosen format.
    if isinstance(data, str):
        data = parse_text_objects(data, format_)

    if not isinstance(data, list):
        raise ValueError(f"Unsupported data given of type: {type(data)}")

    if not isinstance(data[0], dict):
        raise ValueError(f"Unsupported data item type given: {type(data[0])}")

    # Choose the first filter key that exists, defaulting to the first column.
    filter_ = next(iter(data[0]))
    for candidate in filters or []:
        if candidate in data[0]:
            filter_ = candidate
            break

    # Extract the selected column from every row.
    return [row[filter_] for row in data] if filter_ else [str(row) for row in data]
|