maque-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. maque/__init__.py +30 -0
  2. maque/__main__.py +926 -0
  3. maque/ai_platform/__init__.py +0 -0
  4. maque/ai_platform/crawl.py +45 -0
  5. maque/ai_platform/metrics.py +258 -0
  6. maque/ai_platform/nlp_preprocess.py +67 -0
  7. maque/ai_platform/webpage_screen_shot.py +195 -0
  8. maque/algorithms/__init__.py +78 -0
  9. maque/algorithms/bezier.py +15 -0
  10. maque/algorithms/bktree.py +117 -0
  11. maque/algorithms/core.py +104 -0
  12. maque/algorithms/hilbert.py +16 -0
  13. maque/algorithms/rate_function.py +92 -0
  14. maque/algorithms/transform.py +27 -0
  15. maque/algorithms/trie.py +272 -0
  16. maque/algorithms/utils.py +63 -0
  17. maque/algorithms/video.py +587 -0
  18. maque/api/__init__.py +1 -0
  19. maque/api/common.py +110 -0
  20. maque/api/fetch.py +26 -0
  21. maque/api/static/icon.png +0 -0
  22. maque/api/static/redoc.standalone.js +1782 -0
  23. maque/api/static/swagger-ui-bundle.js +3 -0
  24. maque/api/static/swagger-ui.css +3 -0
  25. maque/cli/__init__.py +1 -0
  26. maque/cli/clean_invisible_chars.py +324 -0
  27. maque/cli/core.py +34 -0
  28. maque/cli/groups/__init__.py +26 -0
  29. maque/cli/groups/config.py +205 -0
  30. maque/cli/groups/data.py +615 -0
  31. maque/cli/groups/doctor.py +259 -0
  32. maque/cli/groups/embedding.py +222 -0
  33. maque/cli/groups/git.py +29 -0
  34. maque/cli/groups/help.py +410 -0
  35. maque/cli/groups/llm.py +223 -0
  36. maque/cli/groups/mcp.py +241 -0
  37. maque/cli/groups/mllm.py +1795 -0
  38. maque/cli/groups/mllm_simple.py +60 -0
  39. maque/cli/groups/quant.py +210 -0
  40. maque/cli/groups/service.py +490 -0
  41. maque/cli/groups/system.py +570 -0
  42. maque/cli/mllm_run.py +1451 -0
  43. maque/cli/script.py +52 -0
  44. maque/cli/tree.py +49 -0
  45. maque/clustering/__init__.py +52 -0
  46. maque/clustering/analyzer.py +347 -0
  47. maque/clustering/clusterers.py +464 -0
  48. maque/clustering/sampler.py +134 -0
  49. maque/clustering/visualizer.py +205 -0
  50. maque/constant.py +13 -0
  51. maque/core.py +133 -0
  52. maque/cv/__init__.py +1 -0
  53. maque/cv/image.py +219 -0
  54. maque/cv/utils.py +68 -0
  55. maque/cv/video/__init__.py +3 -0
  56. maque/cv/video/keyframe_extractor.py +368 -0
  57. maque/embedding/__init__.py +43 -0
  58. maque/embedding/base.py +56 -0
  59. maque/embedding/multimodal.py +308 -0
  60. maque/embedding/server.py +523 -0
  61. maque/embedding/text.py +311 -0
  62. maque/git/__init__.py +24 -0
  63. maque/git/pure_git.py +912 -0
  64. maque/io/__init__.py +29 -0
  65. maque/io/core.py +38 -0
  66. maque/io/ops.py +194 -0
  67. maque/llm/__init__.py +111 -0
  68. maque/llm/backend.py +416 -0
  69. maque/llm/base.py +411 -0
  70. maque/llm/server.py +366 -0
  71. maque/mcp_server.py +1096 -0
  72. maque/mllm_data_processor_pipeline/__init__.py +17 -0
  73. maque/mllm_data_processor_pipeline/core.py +341 -0
  74. maque/mllm_data_processor_pipeline/example.py +291 -0
  75. maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
  76. maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
  77. maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
  78. maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
  79. maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
  80. maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
  81. maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
  82. maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
  83. maque/mllm_data_processor_pipeline/web_app.py +317 -0
  84. maque/nlp/__init__.py +14 -0
  85. maque/nlp/ngram.py +9 -0
  86. maque/nlp/parser.py +63 -0
  87. maque/nlp/risk_matcher.py +543 -0
  88. maque/nlp/sentence_splitter.py +202 -0
  89. maque/nlp/simple_tradition_cvt.py +31 -0
  90. maque/performance/__init__.py +21 -0
  91. maque/performance/_measure_time.py +70 -0
  92. maque/performance/_profiler.py +367 -0
  93. maque/performance/_stat_memory.py +51 -0
  94. maque/pipelines/__init__.py +15 -0
  95. maque/pipelines/clustering.py +252 -0
  96. maque/quantization/__init__.py +42 -0
  97. maque/quantization/auto_round.py +120 -0
  98. maque/quantization/base.py +145 -0
  99. maque/quantization/bitsandbytes.py +127 -0
  100. maque/quantization/llm_compressor.py +102 -0
  101. maque/retriever/__init__.py +35 -0
  102. maque/retriever/chroma.py +654 -0
  103. maque/retriever/document.py +140 -0
  104. maque/retriever/milvus.py +1140 -0
  105. maque/table_ops/__init__.py +1 -0
  106. maque/table_ops/core.py +133 -0
  107. maque/table_viewer/__init__.py +4 -0
  108. maque/table_viewer/download_assets.py +57 -0
  109. maque/table_viewer/server.py +698 -0
  110. maque/table_viewer/static/element-plus-icons.js +5791 -0
  111. maque/table_viewer/static/element-plus.css +1 -0
  112. maque/table_viewer/static/element-plus.js +65236 -0
  113. maque/table_viewer/static/main.css +268 -0
  114. maque/table_viewer/static/main.js +669 -0
  115. maque/table_viewer/static/vue.global.js +18227 -0
  116. maque/table_viewer/templates/index.html +401 -0
  117. maque/utils/__init__.py +56 -0
  118. maque/utils/color.py +68 -0
  119. maque/utils/color_string.py +45 -0
  120. maque/utils/compress.py +66 -0
  121. maque/utils/constant.py +183 -0
  122. maque/utils/core.py +261 -0
  123. maque/utils/cursor.py +143 -0
  124. maque/utils/distance.py +58 -0
  125. maque/utils/docker.py +96 -0
  126. maque/utils/downloads.py +51 -0
  127. maque/utils/excel_helper.py +542 -0
  128. maque/utils/helper_metrics.py +121 -0
  129. maque/utils/helper_parser.py +168 -0
  130. maque/utils/net.py +64 -0
  131. maque/utils/nvidia_stat.py +140 -0
  132. maque/utils/ops.py +53 -0
  133. maque/utils/packages.py +31 -0
  134. maque/utils/path.py +57 -0
  135. maque/utils/tar.py +260 -0
  136. maque/utils/untar.py +129 -0
  137. maque/web/__init__.py +0 -0
  138. maque/web/image_downloader.py +1410 -0
  139. maque-0.2.1.dist-info/METADATA +450 -0
  140. maque-0.2.1.dist-info/RECORD +143 -0
  141. maque-0.2.1.dist-info/WHEEL +4 -0
  142. maque-0.2.1.dist-info/entry_points.txt +3 -0
  143. maque-0.2.1.dist-info/licenses/LICENSE +21 -0
File without changes
@@ -0,0 +1,45 @@
1
+ import asyncio
2
+
3
+ from aiolimiter.compat import wait_for
4
+ from crawl4ai import AsyncWebCrawler, CacheMode
5
+
6
+
7
async def capture_and_save_screenshot(url: str, output_path: str):
    """Crawl *url* with crawl4ai, print its markdown and save a screenshot.

    The page is fetched with the cache bypassed, simulated user activity,
    navigator overrides and overlay removal. If the crawl succeeds and a
    screenshot was returned, the base64-encoded screenshot is decoded and
    written to *output_path*; otherwise a failure message is printed.

    Args:
        url: Address of the page to crawl.
        output_path: File the decoded screenshot bytes are written to.
    """
    import base64

    async with AsyncWebCrawler(verbose=True, headless=False) as crawler:
        result = await crawler.arun(
            url=url,
            # Required: without it no screenshot is taken, result.screenshot
            # stays empty and this function could never save anything.
            screenshot=True,
            cache_mode=CacheMode.BYPASS,
            # magic=True,  # with magic mode the two options below are unnecessary
            simulate_user=True,  # Causes random mouse movements and clicks
            override_navigator=True,  # Makes the browser appear more like a real user
            remove_overlay_elements=True,  # Remove popups/modals
            page_timeout=60000,  # Increased timeout for protection checks
            excluded_tags=['nav', 'footer'],
        )
        print(result.markdown)

        if result.success and result.screenshot:
            screenshot_data = base64.b64decode(result.screenshot)
            with open(output_path, "wb") as f:
                f.write(screenshot_data)
            print(f"Screenshot saved successfully to {output_path}")
        else:
            print("Failed to capture screenshot")
40
+
41
+
42
if __name__ == "__main__":
    # Pages tried previously:
    #   https://www.gradio.app/guides/object-detection-from-video
    #   https://www.autohome.com.cn
    target_url = "https://www.zhihu.com/question/654186093/answer/3483543427"
    asyncio.run(capture_and_save_screenshot(target_url, "screenshot.png"))
@@ -0,0 +1,258 @@
1
+ import os
2
+ from tabulate import tabulate
3
+ from datetime import datetime
4
+ from maque.io import yaml_dump
5
+ from rich.console import Console
6
+ from rich.markdown import Markdown
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ if TYPE_CHECKING:
11
+ from pandas import DataFrame
12
+
13
+
14
def truncate_labels(labels, max_length=50):
    """Shorten every label that is longer than *max_length* characters.

    Each label is converted to ``str`` first, so non-string class labels
    (e.g. integers from a numeric label column) no longer raise ``TypeError``.
    A label longer than *max_length* is cut to its first *max_length*
    characters and suffixed with ``"..."`` (so the result can be up to
    ``max_length + 3`` characters long); shorter labels are returned as-is.

    Args:
        labels: Iterable of labels.
        max_length: Maximum number of characters kept before the ellipsis.

    Returns:
        List of str labels, in input order.
    """
    truncated_labels = []
    for label in labels:
        text = str(label)
        truncated_labels.append(text[:max_length] + "..." if len(text) > max_length else text)
    return truncated_labels
24
+
25
+
26
class MetricsCalculator:
    """Compute and format classification metrics for a prediction/label column pair.

    All metrics are computed once, in ``__init__``, and cached in ``self.metrics``,
    a dict with the keys ``accuracy``, ``precision`` / ``recall`` (weighted),
    ``confusion_matrix`` (sklearn layout: rows are true labels, columns are
    predicted labels, both in ``self.all_labels`` order) and
    ``classification_report`` (a filtered sklearn report dict).

    Args:
        df: DataFrame holding both columns.
        pred_col: Name of the prediction column.
        label_col: Name of the ground-truth column.
        include_macro_micro_avg: When False (default) the report keeps only the
            per-class entries and ``weighted avg``; when True every aggregate
            entry sklearn returns is kept as well.
        remove_matrix_zero_row: When True, confusion-matrix rows whose label has
            zero support (never appears in ``label_col``) are dropped. Columns
            are never dropped.
    """

    def __init__(self, df: "DataFrame", pred_col: str = 'predict', label_col: str = 'label',
                 include_macro_micro_avg=False,
                 remove_matrix_zero_row=False,
                 ):
        self.df = df
        self.y_pred = df[pred_col]
        self.y_true = df[label_col]
        # Every label seen in either column, sorted for a stable row/column order.
        self.all_labels = sorted(set(self.y_true.unique()) | set(self.y_pred.unique()))
        # Set by _calculate_metrics: labels that survive the zero-support filter.
        self.needed_labels = None
        self.remove_matrix_zero_row = remove_matrix_zero_row
        self.include_macro_micro_avg = include_macro_micro_avg
        self.metrics = self._calculate_metrics()

    @staticmethod
    def _report_key(label):
        """Return the key sklearn's ``classification_report`` uses for *label*."""
        # sklearn names per-class entries with "%s" % label, i.e. str(label), so
        # non-string labels (ints, floats) must be compared in string form.
        return str(label)

    @staticmethod
    def _filter_report(report, all_labels, include_macro_micro_avg):
        """Filter a sklearn classification-report dict.

        Keeps only per-class entries and ``weighted avg`` unless
        *include_macro_micro_avg* is true, then drops dict entries whose
        ``support`` is 0. Non-dict entries (sklearn's scalar ``accuracy``) have
        no support and are kept unchanged.
        """
        label_keys = {MetricsCalculator._report_key(label) for label in all_labels}
        if not include_macro_micro_avg:
            report = {key: value for key, value in report.items()
                      if key in label_keys or key == 'weighted avg'}
        return {key: value for key, value in report.items()
                if not isinstance(value, dict) or value['support'] > 0}

    def _row_labels(self):
        """Labels of the confusion-matrix rows (filtered when remove_matrix_zero_row)."""
        return self.needed_labels if self.remove_matrix_zero_row else self.all_labels

    def plot_confusion_matrix(self, save_path: str = None, figsize=(12, 10), font_scale=1.2, font_path=None,
                              x_rotation=45, y_rotation=0):
        """Draw the confusion matrix as a seaborn heatmap.

        Side effects: applies ``sns.set_theme`` and sets matplotlib's global
        ``font.sans-serif`` to SimHei (with ``axes.unicode_minus`` off) so that
        Chinese labels can render. Tick labels are truncated to 6 characters.

        Args:
            save_path: If given, the figure is saved there before ``plt.show()``.
            figsize: Matplotlib figure size.
            font_scale: Passed to ``sns.set_theme``.
            font_path: Optional font file used for the title, axis labels and ticks.
            x_rotation: Rotation of the x tick labels (degrees).
            y_rotation: Rotation of the y tick labels (degrees).
        """
        import warnings

        import matplotlib
        import matplotlib.pyplot as plt
        import seaborn as sns
        from matplotlib.font_manager import FontProperties

        # set_theme resets font.sans-serif, so it must run BEFORE the font override,
        # otherwise the SimHei setting is silently discarded.
        sns.set_theme(font_scale=font_scale)
        matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font
        matplotlib.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

        conf_matrix = self.metrics['confusion_matrix']

        # Columns always cover every label; rows may have been filtered.
        x_labels = truncate_labels(self.all_labels, max_length=6)
        y_labels = truncate_labels(self._row_labels(), max_length=6)
        num_classes = len(x_labels)

        font_prop = FontProperties(fname=font_path) if font_path else None
        # Tick font shrinks as the number of classes grows (never below 8).
        dynamic_font_size = max(8, 20 - num_classes)
        tick_font_prop = FontProperties(fname=font_path, size=dynamic_font_size) if font_path else None

        fig = plt.figure(figsize=figsize)
        try:
            # Silence UserWarnings (e.g. missing glyphs) while drawing and saving.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                            xticklabels=x_labels, yticklabels=y_labels)

                plt.title("Confusion Matrix", fontproperties=font_prop)
                plt.xlabel("Predicted Labels", fontproperties=font_prop)
                plt.ylabel("True Labels", fontproperties=font_prop)

                plt.xticks(ha="right", fontproperties=tick_font_prop, rotation=x_rotation)
                plt.yticks(fontproperties=tick_font_prop, rotation=y_rotation)

                plt.tight_layout()

                if save_path:
                    plt.savefig(save_path)
                plt.show()
        finally:
            # Release the figure; repeated calls otherwise accumulate open figures.
            plt.close(fig)

    def _calculate_metrics(self):
        """Compute accuracy, weighted precision/recall, confusion matrix and report."""
        from sklearn.metrics import (accuracy_score, classification_report, confusion_matrix,
                                     precision_score, recall_score)

        accuracy = accuracy_score(self.y_true, self.y_pred)
        precision = precision_score(self.y_true, self.y_pred, labels=self.all_labels,
                                    average='weighted', zero_division=0)
        recall = recall_score(self.y_true, self.y_pred, labels=self.all_labels,
                              average='weighted', zero_division=0)
        conf_matrix = confusion_matrix(self.y_true, self.y_pred, labels=self.all_labels)

        report = classification_report(self.y_true, self.y_pred, labels=self.all_labels,
                                       output_dict=True, zero_division=0)
        report = self._filter_report(report, self.all_labels, self.include_macro_micro_avg)

        # Map the (string) report keys back to the original label objects.
        needed_idx_list = [i for i, label in enumerate(self.all_labels)
                           if self._report_key(label) in report]
        self.needed_labels = [self.all_labels[i] for i in needed_idx_list]

        if self.remove_matrix_zero_row:
            conf_matrix = conf_matrix[needed_idx_list]

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'confusion_matrix': conf_matrix,
            'classification_report': report,
        }

    def get_metrics(self):
        """Return the cached metrics dict computed in ``__init__``."""
        return self.metrics

    def format_classification_report_as_markdown(self):
        """Render the per-class/average rows of the report as a markdown table.

        Scalar report entries (e.g. ``accuracy``) are skipped; ``|`` and newlines
        in label names are escaped so they cannot break the table.
        """
        report = self.metrics['classification_report']
        header = "| Label | Precision | Recall | F1-score | Support |\n"
        separator = "|-------|-----------|--------|----------|---------|\n"
        rows = []
        for label, metrics in report.items():
            if isinstance(metrics, dict):
                name = str(label).replace('\n', ' ').replace('|', '\\|')
                rows.append(
                    f"| {name} | {metrics['precision']:.2f} | {metrics['recall']:.2f} | {metrics['f1-score']:.2f} | {metrics['support']:.0f} |")
        return header + separator + "\n".join(rows)

    def clean_label_for_markdown(self, label, max_length=20):
        """Make *label* safe to show in a markdown table cell.

        Newlines become spaces; ``|`` and ``-`` are backslash-escaped; ``<``/``>``
        become HTML entities; the text is cut to *max_length* characters plus
        ``"..."``; an all-whitespace result becomes ``"(empty)"``.
        """
        label = str(label).replace('\n', ' ')

        label = label.replace("|", "\\|")
        label = label.replace("-", "\\-")
        label = label.replace("<", "&lt;")
        label = label.replace(">", "&gt;")

        if len(label) > max_length:
            label = label[:max_length] + "..."

        label = label.strip()
        if not label:
            label = "(empty)"

        return label

    def format_confusion_matrix_as_markdown(self, max_label_length=20):
        """Render the confusion matrix as a markdown table.

        The header lists every label (the matrix always keeps all predicted-label
        columns); the row labels follow the matrix rows, which are filtered when
        ``remove_matrix_zero_row`` is set. Counts use thousands separators.
        """
        matrix = self.metrics['confusion_matrix']

        column_labels = [self.clean_label_for_markdown(label, max_label_length)
                         for label in self.all_labels]
        row_labels = [self.clean_label_for_markdown(label, max_label_length)
                      for label in self._row_labels()]

        header = "| 真实值/预测值 | " + " | ".join(column_labels) + " |\n"
        # One alignment marker per column, +1 for the leading row-label column.
        separator = "| " + " | ".join([":---:"] * (len(column_labels) + 1)) + " |\n"

        rows = [f"| {row_label} | " + " | ".join(f"{num:,}" for num in row) + " |"
                for row_label, row in zip(row_labels, matrix)]

        return header + separator + "\n".join(rows)
205
+
206
+
207
def save_pred_metrics(df: "DataFrame", pred_col: str, label_col: str,
                      record_folder='record', config=None, prompt=None, font_path=None,
                      plot_confusion_matrix=False,
                      ):
    """Print and persist an evaluation record for a prediction DataFrame.

    Creates ``<record_folder>/记录时间-<timestamp>/`` (timestamp formatted as
    ``%m月%d日%H时%M分%S秒``) and writes into it:

    * ``metrics.md``: overview table, classification report and confusion matrix
      (the same markdown is also rendered to the console with rich);
    * ``confusion_matrix.png``: only when *plot_confusion_matrix* is true
      (plotting errors are printed as a warning, not raised);
    * ``prompt.yaml`` / ``config.yaml``: only when *prompt* / *config* are truthy;
    * ``result.xlsx`` and ``bad_case.xlsx`` (rows where prediction != label), or
      ``result.csv`` / ``bad_case.csv`` if writing Excel fails.

    Args:
        df: DataFrame with prediction and label columns.
        pred_col: Prediction column name.
        label_col: Ground-truth column name.
        record_folder: Parent directory for the timestamped record folder.
        config: Optional object dumped to ``config.yaml``.
        prompt: Optional object dumped to ``prompt.yaml``.
        font_path: Optional font file for the confusion-matrix plot.
        plot_confusion_matrix: Whether to render ``confusion_matrix.png``.

    Returns:
        ``pathlib.Path`` of the record folder that was created.
    """
    metrics_calculator = MetricsCalculator(df, pred_col=pred_col, label_col=label_col)
    metrics = metrics_calculator.get_metrics()

    table = [["指标概览", "Accuracy", "Precision", "Recall"],
             ["值", metrics['accuracy'], metrics['precision'], metrics['recall']]]
    md = f"\n\n### 指标概览\n\n{tabulate(table, headers='firstrow', tablefmt='github')}"
    metrics_md = metrics_calculator.format_classification_report_as_markdown()
    confusion_matrix_md = metrics_calculator.format_confusion_matrix_as_markdown()
    md += (f"\n\n### Classification Report\n{metrics_md}\n"
           f"\n### Confusion Matrix\n{confusion_matrix_md}")

    now = datetime.now().strftime("%m月%d日%H时%M分%S秒")
    record_folder = Path(record_folder) / f'记录时间-{now}'
    record_folder.mkdir(parents=True, exist_ok=True)
    Console().print(Markdown(md))

    with open(os.path.join(record_folder, 'metrics.md'), 'w', encoding='utf-8') as f:
        f.write(md)

    if plot_confusion_matrix:
        try:
            metrics_calculator.plot_confusion_matrix(
                save_path=os.path.join(record_folder, 'confusion_matrix.png'),
                font_path=font_path,
                x_rotation=45,
                y_rotation=0,
            )
        except Exception as e:
            print(f"warning: Failed to plot confusion matrix: {e}")

    if prompt:
        yaml_dump(os.path.join(record_folder, 'prompt.yaml'), prompt)
    if config:
        yaml_dump(os.path.join(record_folder, 'config.yaml'), config)

    bad_case_df = df[df[pred_col] != df[label_col]]
    try:
        df.to_excel(os.path.join(record_folder, 'result.xlsx'), index=False, engine='openpyxl')
        bad_case_df.to_excel(os.path.join(record_folder, 'bad_case.xlsx'), index=False, engine='openpyxl')
    except Exception as e:
        # Excel export is best-effort and falls back to CSV, but report the real
        # cause so unrelated failures are not misreported as a missing openpyxl.
        if isinstance(e, ImportError):
            print("No module named 'openpyxl'. Please install it with 'pip install openpyxl'.\n"
                  "Save result.csv and bad_case.csv instead.")
        else:
            print(f"warning: Failed to save Excel files: {e}\n"
                  "Save result.csv and bad_case.csv instead.")
        df.to_csv(os.path.join(record_folder, 'result.csv'), index=False)
        bad_case_df.to_csv(os.path.join(record_folder, 'bad_case.csv'), index=False)

    return record_folder
258
+
@@ -0,0 +1,67 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import io
4
+ import jionlp as jio
5
+ from maque.nlp.deduplicate import EditSimilarity
6
+
7
# Launch this app with: streamlit run <this script> --server.maxUploadSize=1024
# Module-level similarity helper, used by the "去重" (deduplicate) operation below.
simi = EditSimilarity()
9
+
10
+
11
def load_data(file):
    """Read the uploaded Excel workbook into a pandas DataFrame."""
    frame = pd.read_excel(file)
    return frame
13
+
14
+
15
def save_df_to_excel(df):
    """Serialise *df* (without its index) to an in-memory .xlsx file.

    Returns:
        An ``io.BytesIO`` rewound to position 0, ready to be handed to
        ``st.download_button``.
    """
    buffer = io.BytesIO()
    with pd.ExcelWriter(buffer, engine='openpyxl') as excel_writer:
        df.to_excel(excel_writer, index=False)
    buffer.seek(0)
    return buffer
21
+
22
+
23
# NOTE(review): maque.nlp.deduplicate (source of EditSimilarity) does not appear
# in this package's file listing - verify that import resolves.
st.title('文本预处理工具')

uploaded_file = st.file_uploader("选择一个 Excel 文件", type=['xlsx'])
# NOTE(review): 'replay' is initialised here but never read in this script.
if 'replay' not in st.session_state:
    st.session_state['replay'] = False

if uploaded_file is not None:
    df = load_data(uploaded_file)

    st.write("数据预览:")
    st.dataframe(df.head(10))
    # Column names of the uploaded sheet.
    columns = df.columns.tolist()

    # Columns the selected operations are applied to.
    selected_columns = st.multiselect("选择要清理的列", columns)
    # Operations: deduplicate / remove redundant characters / clean text.
    # The redundant-character option is handled below, so it must be offered
    # here too; otherwise that branch is unreachable.
    operation = st.multiselect("选择操作", ["去重", "删除文本中的冗余字符", "清洗文本"])
    dedup_threshold = 0.7
    if "去重" in operation:
        dedup_threshold = st.slider("选择去重阈值(越小越严格)", 0.0, 1.0, 0.7)

    if st.button('执行'):
        for column in selected_columns:
            if "去重" in operation:
                simi.load_from_df(df, target_col=column)
                df = simi.deduplicate(threshold=dedup_threshold)
            # Apply the jionlp helpers to str cells only; NaN and numbers are
            # passed through unchanged (the helpers presumably expect text).
            if "删除文本中的冗余字符" in operation:
                df[column] = df[column].apply(
                    lambda x: jio.remove_redundant_char(x) if isinstance(x, str) else x)
            if "清洗文本" in operation:
                df[column] = df[column].apply(
                    lambda x: jio.clean_text(x) if isinstance(x, str) else x)

        st.write("清理后的数据预览:")
        st.dataframe(df.head(10))

        cleaned_data = save_df_to_excel(df)
        st.download_button(
            label="下载处理后的数据",
            data=cleaned_data,
            file_name="cleaned_data.xlsx",
            # MIME type for .xlsx (application/vnd.ms-excel is the legacy .xls type).
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        )
else:
    st.info("请上传一个 Excel 文件。")
@@ -0,0 +1,195 @@
1
+ # pip install playwright
2
+ # playwright install
3
+
4
+ import asyncio
5
+ from playwright.async_api import async_playwright, TimeoutError as PlaywrightTimeoutError
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import traceback
9
+ from loguru import logger
10
+
11
+
12
+ class ScreenshotTaker:
13
+ def __init__(self, viewport_width=1920, viewport_height=1080, max_concurrent_tasks=16, retry_limit=3):
14
+ self.viewport_width = viewport_width
15
+ self.viewport_height = viewport_height
16
+ self.retry_limit = retry_limit # Maximum retry attempts for loading a page
17
+ self.browser = None
18
+ self.semaphore = asyncio.Semaphore(max_concurrent_tasks) # Limit concurrent tasks
19
+
20
+ async def _initialize_browser(self):
21
+ """Initialize the browser if not already done."""
22
+ if self.browser is None:
23
+ try:
24
+ self.playwright = await async_playwright().start()
25
+ self.browser = await self.playwright.chromium.launch(headless=True,
26
+ args=['--disable-web-security'], # Handle CORS
27
+ )
28
+ except Exception as e:
29
+ logger.error(f"Failed to initialize browser: {e}")
30
+ raise
31
+
32
+ async def _load_page_with_retries(self, page, url):
33
+ """Attempt to load a page with retries if it fails to load."""
34
+ for attempt in range(self.retry_limit):
35
+ try:
36
+ await page.goto(url, wait_until="networkidle", timeout=15000) # Short timeout for quicker retries
37
+ return True
38
+ except (PlaywrightTimeoutError, Exception) as e:
39
+ logger.warning(f"Attempt {attempt + 1} to load {url} failed: {e}")
40
+ await asyncio.sleep(1.1) # Wait briefly before retrying
41
+
42
+ logger.error(f"Failed to load {url} after {self.retry_limit} attempts.")
43
+ return False
44
+
45
+ async def capture_screenshot_in_context(self, context, url, screenshot_path=None):
46
+ """Capture a full-page screenshot of a single URL within a shared context."""
47
+ async with self.semaphore: # Limit concurrency
48
+ try:
49
+ # Open a new page in the shared context
50
+ page = await context.new_page()
51
+
52
+ # Attempt to load the page with retries
53
+ if not await self._load_page_with_retries(page, url):
54
+ await page.close()
55
+ return None # Skip screenshot if page load fails
56
+
57
+ logger.debug(f"Loading {url}")
58
+ # Scroll progressively to ensure lazy-loaded content is displayed
59
+ # 滚动次数
60
+ # scroll_height = await page.evaluate("document.body.scrollHeight")
61
+ # viewport_height = await page.evaluate("window.innerHeight")
62
+ # scroll_times = max(1, scroll_height // viewport_height)
63
+ # logger.info(f"Scrolling {scroll_times} times")
64
+ # for _ in range(scroll_times):
65
+ # await page.evaluate("""() => { window.scrollBy(0, window.innerHeight); }""")
66
+ # await asyncio.sleep(1/scroll_times)
67
+ await page.evaluate("""() => { window.scrollTo(0, document.body.scrollHeight); }""")
68
+ await asyncio.sleep(1.1) # Allow time for content to load
69
+
70
+ # Adjust iframe heights if present
71
+ for iframe_element in await page.query_selector_all("iframe"):
72
+ try:
73
+ frame = await iframe_element.content_frame()
74
+ if frame:
75
+ logger.debug(f"Handling iframe in {url}")
76
+ frame_height = await frame.evaluate("document.body.scrollHeight")
77
+ await iframe_element.evaluate(f"el => el.style.height = '{frame_height}px'")
78
+ await asyncio.sleep(1.1)
79
+ except Exception as iframe_error:
80
+ logger.warning(f"Error handling iframe in {url}: {iframe_error}")
81
+
82
+ # Capture and store screenshot as bytes
83
+ screenshot_bytes = await page.screenshot(full_page=True)
84
+ image = Image.open(BytesIO(screenshot_bytes))
85
+
86
+ # Save screenshot if a path is specified
87
+ if screenshot_path:
88
+ image.save(screenshot_path)
89
+ logger.info(f"Screenshot saved at {screenshot_path}")
90
+
91
+ await page.close() # Close page to release resources
92
+ return image
93
+
94
+ except Exception as e:
95
+ logger.error(f"Unexpected error capturing screenshot for {url}: {e}")
96
+ traceback.print_exc()
97
+ return None # Return None if an error occurs
98
+
99
+ async def _capture_batch_screenshots(self, urls, screenshot_paths=None):
100
+ """
101
+ Capture screenshots for a batch of URLs within a single context.
102
+
103
+ Parameters:
104
+ - urls: list of URLs to capture
105
+ - screenshot_paths: optional list of paths to save each screenshot.
106
+ If None, screenshots will not be saved to files.
107
+
108
+ Returns:
109
+ - List of PIL image objects for each URL.
110
+ """
111
+ await self._initialize_browser()
112
+ screenshot_paths = screenshot_paths or [None] * len(urls)
113
+
114
+ # Create a new context for this batch
115
+ context = await self.browser.new_context(
116
+ viewport={"width": self.viewport_width, "height": self.viewport_height}
117
+ )
118
+
119
+ # Execute all screenshot tasks within this context
120
+ tasks = [self.capture_screenshot_in_context(context, url, path) for url, path in zip(urls, screenshot_paths)]
121
+ images = await asyncio.gather(*tasks, return_exceptions=True)
122
+
123
+ # Close context after batch is completed
124
+ await context.close()
125
+
126
+ # Filter out None or exceptions from results
127
+ images = [img for img in images if isinstance(img, Image.Image)]
128
+ return images
129
+
130
+ async def capture_multi_screenshots(self, urls, batch_size=3, max_batches=3, screenshot_paths=None):
131
+ """
132
+ Capture screenshots for multiple URLs in multiple batches, with specified batch size and concurrency.
133
+
134
+ Parameters:
135
+ - urls: list of URLs to capture
136
+ - batch_size: number of URLs per batch
137
+ - max_batches: maximum number of concurrent batches
138
+ - screenshot_paths: optional list of paths to save each screenshot.
139
+ If None, screenshots will not be saved to files.
140
+
141
+ Returns:
142
+ - List of PIL image objects for each URL.
143
+ """
144
+ await self._initialize_browser()
145
+ screenshot_paths = screenshot_paths or [None] * len(urls)
146
+
147
+ # Split URLs into batches
148
+ url_batches = [urls[i:i + batch_size] for i in range(0, len(urls), batch_size)]
149
+ path_batches = [screenshot_paths[i:i + batch_size] for i in range(0, len(screenshot_paths), batch_size)]
150
+
151
+ # Limit concurrent batch execution
152
+ batch_semaphore = asyncio.Semaphore(max_batches)
153
+
154
+ async def run_batch(urls, paths):
155
+ async with batch_semaphore:
156
+ return await self._capture_batch_screenshots(urls, paths)
157
+
158
+ # Schedule all batch tasks
159
+ tasks = [run_batch(url_batch, path_batch) for url_batch, path_batch in zip(url_batches, path_batches)]
160
+ batch_images = await asyncio.gather(*tasks, return_exceptions=True)
161
+
162
+ # Flatten the list of images from all batches and filter out None or exceptions
163
+ images = [img for batch in batch_images if isinstance(batch, list) for img in batch]
164
+ return images
165
+
166
+ async def close(self):
167
+ """Close the browser and Playwright if they are open."""
168
+ if self.browser:
169
+ await self.browser.close()
170
+ self.browser = None
171
+ if hasattr(self, 'playwright'):
172
+ await self.playwright.stop()
173
+
174
+
175
if __name__ == "__main__":
    async def main():
        """Demo: screenshot one URL 20 times, in batches of 3 with up to 3 batches at once."""
        screenshot_taker = ScreenshotTaker()
        url1 = "https://baidu.com"
        urls = [url1] * 20  # Example: large batch of URLs
        paths = [f"screenshot_{i}.png" for i in range(len(urls))]

        try:
            # Batch screenshot and get a list of PIL image objects
            images = await screenshot_taker.capture_multi_screenshots(urls,
                                                                      batch_size=3,
                                                                      max_batches=3,
                                                                      screenshot_paths=paths)
        finally:
            # Always release the browser, even if capturing raised.
            await screenshot_taker.close()
        return images

    asyncio.run(main())
194
+
195
+
@@ -0,0 +1,78 @@
1
"""
Algorithms module - data structure and algorithm implementations.

Collects assorted algorithm implementations: trie trees, transform algorithms,
deduplication algorithms, math functions, and more.
"""

# Data structures
from .trie import *
from .bktree import *

# Transform algorithms
from .hilbert import *
from .transform import *

# Deduplication algorithms
from .video import *  # video deduplication

# Math functions
from .bezier import *
from .core import *  # functions core
from .rate_function import *
from .utils import *  # functions utils

# Public API: ``from maque.algorithms import *`` exports exactly these names,
# even though the star imports above may bring more names into this package.
__all__ = [
    # Data structures
    "BKTree",
    "brute_query",
    "levenshtein",
    "Trie",
    "PyTrie",
    "HatTrie",
    "MarisaTrie",
    "DaTrie",
    "AutomatonTrie",
    "Benchmark",
    # Transforms
    "get_hilbert_1d_array",
    "repeat",
    # Video deduplication
    "VideoFrameDeduplicator",
    # Math functions
    "bezier",
    "dict_topk",
    "topk",
    "random_idx",
    "clamp",
    "choose_using_cache",
    "choose",
    "get_num_args",
    "get_parameters",
    # Rate functions
    "linear",
    "smooth",
    "rush_into",
    "rush_from",
    "slow_into",
    "double_smooth",
    "there_and_back",
    "there_and_back_with_pause",
    "running_start",
    "not_quite_there",
    "wiggle",
    "squish_rate_func",
    "lingering",
    "exponential_decay",
    # Utils
    "exists",
    "default",
    "cast_tuple",
    "null_context",
    "pick_and_pop",
    "group_dict_by_key",
    "string_begins_with",
    "group_by_key_prefix",
    "groupby_prefix_and_trim",
    "num_to_groups",
    "find_first",
]
@@ -0,0 +1,15 @@
1
+ from .core import choose
2
+
3
+
4
def bezier(points):
    """Build the Bezier curve defined by the control *points*.

    Returns a function of ``t`` that evaluates the Bernstein form
    ``sum(choose(n, k) * (1 - t)**(n - k) * t**k * points[k])`` with
    ``n = len(points) - 1``. Control points may be plain numbers or any
    objects that support scalar multiplication and addition (e.g. numpy
    arrays, presumably the common case).
    """
    degree = len(points) - 1

    def curve(t):
        weighted_terms = (
            ((1 - t) ** (degree - index)) * (t ** index) * choose(degree, index) * control
            for index, control in enumerate(points)
        )
        return sum(weighted_terms)

    return curve