@huyooo/ai-chat-shared 0.2.12 → 0.2.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/markdown.ts CHANGED
@@ -1,12 +1,236 @@
1
1
  /**
2
2
  * Markdown 渲染工具
3
3
  * 提供统一的 Markdown 渲染函数,供 React 和 Vue 版本使用
4
+ * 支持 LaTeX 数学公式渲染和 Mermaid 图表
4
5
  */
5
6
 
6
7
  import { marked, Renderer } from 'marked'
7
8
  import DOMPurify from 'dompurify'
9
+ import katex from 'katex'
10
+ import mermaid from 'mermaid'
8
11
  import { highlightCode } from './highlighter'
9
12
 
/** Set to true once `mermaid.initialize()` has run, so the configuration is applied only once. */
let mermaidInitialized = false
15
+
/**
 * Encode text as unpadded base64url over its UTF-8 bytes.
 *
 * Raw source cannot be stored directly in an HTML attribute because attribute
 * values have their newlines normalized, so the text is carried base64-encoded.
 * The URL-safe alphabet ('-' and '_', with no '=' padding) is used so that a '+'
 * can never be turned into a space somewhere along the way.
 */
function encodeBase64Utf8(text: string): string {
  // btoa only accepts "binary strings" (one char per byte), so map each UTF-8 byte to a char
  const binaryString = Array.from(new TextEncoder().encode(text), (byte) => String.fromCharCode(byte)).join('')
  const standardBase64 = btoa(binaryString)
  // base64 -> base64url: '+' -> '-', '/' -> '_', strip trailing '=' padding
  const urlSafe = standardBase64.replace(/\+/g, '-').replace(/\//g, '_')
  return urlSafe.replace(/=+$/g, '')
}
29
+
/**
 * Encode Mermaid source as UTF-8 base64url so it can be placed safely in an HTML attribute.
 * Exported so front-end components outside the markdown rendering path can reuse the
 * exact same encoding instead of re-implementing it.
 */
export function encodeMermaidCodeToBase64Url(code: string): string {
  const encoded = encodeBase64Utf8(code)
  return encoded
}
37
+
/**
 * Decode an unpadded base64url string back into text, interpreting the bytes as UTF-8.
 *
 * Whitespace is first mapped back to '+', to tolerate transports that turned '+'
 * into a space. Missing '=' padding is restored. A length with remainder 1 is
 * impossible for base64; it is passed through unpadded so that `atob` throws on it.
 */
function decodeBase64Utf8(b64: string): string {
  // Undo '+' -> ' ' mangling, then translate the base64url alphabet back to standard base64
  const standard = b64.replace(/\s/g, '+').replace(/-/g, '+').replace(/_/g, '/')

  const remainder = standard.length % 4
  let padding = ''
  if (remainder === 2) {
    padding = '=='
  } else if (remainder === 3) {
    padding = '='
  }

  const binaryString = atob(standard + padding)
  // atob yields one char per byte (code units 0-255); rebuild the bytes and decode as UTF-8
  const bytes = Uint8Array.from(binaryString, (ch) => ch.charCodeAt(0))
  return new TextDecoder().decode(bytes)
}
59
+
/**
 * Initialize Mermaid (only needs to run once; later calls are no-ops).
 *
 * `securityLevel` is `'strict'` (Mermaid's default) rather than `'loose'`. Diagram
 * source comes from chat content, and the SVG Mermaid produces is inserted into the
 * page with `innerHTML`. Per Mermaid's configuration docs, `'loose'` allows HTML tags
 * in text and enables click functionality. With `'loose'`, link URLs in `click`
 * directives are not sanitized, so a `javascript:` link could end up clickable.
 * `'strict'` encodes HTML tags in text and disables click interactions.
 */
export function initMermaid(): void {
  if (mermaidInitialized) return

  mermaid.initialize({
    startOnLoad: false,
    theme: 'dark',
    // Untrusted input: keep Mermaid's own sanitization enabled
    securityLevel: 'strict',
    fontFamily: 'inherit',
    flowchart: {
      useMaxWidth: true,
      htmlLabels: true,
      curve: 'basis',
    },
    sequence: {
      useMaxWidth: true,
    },
    gantt: {
      useMaxWidth: true,
    },
  })

  mermaidInitialized = true
}
86
+
/**
 * Render the Mermaid diagrams found in the page (or in a container).
 * Called after the user clicks the "show diagram" button.
 *
 * For each `.mermaid-placeholder` that is neither rendered nor failed yet, this:
 * 1. Reads the base64url source from `data-mermaid-code-b64` and decodes it.
 * 2. Renders it with `mermaid.render`, one element at a time.
 *
 * On success the SVG replaces the element's contents and the element gets
 * `.mermaid-rendered`. On failure the element gets `.mermaid-error-container` and
 * shows an error message. If rendering (not decoding) failed, the escaped source
 * is shown as well.
 *
 * @param container element to search within (optional; defaults to `document`)
 */
export async function renderMermaidDiagrams(container?: Element): Promise<void> {
  if (!mermaidInitialized) {
    initMermaid()
  }

  const root = container || document
  // Only pick placeholders that have not been rendered or marked as failed yet
  const elements = root.querySelectorAll('.mermaid-placeholder:not(.mermaid-rendered):not(.mermaid-error-container)')

  for (const el of elements) {
    const b64 = el.getAttribute('data-mermaid-code-b64')
    if (!b64) continue

    let code = ''
    try {
      code = decodeBase64Utf8(b64)
    } catch (e) {
      console.warn('Mermaid 解码失败:', e)
      el.innerHTML = `<div class="mermaid-error">图表解码失败</div>`
      el.classList.remove('mermaid-placeholder')
      el.classList.add('mermaid-error-container')
      continue
    }

    // Show a loading indicator while rendering
    el.innerHTML = `<div class="mermaid-loading">渲染中...</div>`

    try {
      const id = `mermaid-${Math.random().toString(36).slice(2, 11)}`
      const { svg } = await mermaid.render(id, code)
      el.innerHTML = svg
      el.classList.remove('mermaid-placeholder')
      el.classList.add('mermaid-rendered')
      el.removeAttribute('data-mermaid-code-b64')
    } catch (error) {
      console.warn('Mermaid 渲染失败:', error)
      // Show the error plus the source. '&' is escaped first so entity-like text in the
      // source (e.g. "&lt;") is displayed literally instead of being decoded by the browser.
      const escapedCode = code.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
      el.innerHTML = `<div class="mermaid-error">图表渲染失败</div><pre class="mermaid-source">${escapedCode}</pre>`
      el.classList.remove('mermaid-placeholder')
      el.classList.add('mermaid-error-container')
    }
  }
}
135
+
/**
 * URL protocols that KaTeX may link to via `\href` / `\url`.
 * `_relative` is the protocol value KaTeX reports for relative URLs.
 */
const SAFE_KATEX_LINK_PROTOCOLS = new Set(['http', 'https', 'mailto', '_relative'])

/**
 * Render a LaTeX string to HTML with KaTeX.
 *
 * The input comes from chat content, so it is not trusted wholesale. Only `\href`
 * and `\url` with a safe protocol are allowed. Other trust-gated commands
 * (`\htmlStyle`, `\htmlClass`, `\htmlData`, `\includegraphics`, and links to
 * `javascript:` etc.) are rejected and rendered as unsupported by KaTeX.
 * This matters because some callers (e.g. the fenced-LaTeX view) use this HTML
 * without a further sanitization pass.
 *
 * @param latex LaTeX source
 * @param displayMode true for block (display) math, false for inline math
 * @returns rendered HTML; if KaTeX throws, the HTML-escaped source wrapped in `$$…$$` / `$…$`
 */
function renderLatex(latex: string, displayMode: boolean): string {
  try {
    return katex.renderToString(latex, {
      displayMode,
      throwOnError: false,
      strict: false,
      trust: (context) =>
        (context.command === '\\href' || context.command === '\\url') &&
        SAFE_KATEX_LINK_PROTOCOLS.has(context.protocol ?? ''),
      output: 'html',
    })
  } catch (error) {
    // Fall back to the original source, escaped so it cannot be interpreted as markup
    console.warn('LaTeX 渲染失败:', error)
    const escaped = latex.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
    return displayMode
      ? `<div class="latex-error">$$${escaped}$$</div>`
      : `<span class="latex-error">$${escaped}$</span>`
  }
}
159
+
/**
 * Render a fenced LaTeX code block (```latex / ```katex) to HTML, for the
 * "show code first, render on demand" view toggle.
 * Only the HTML string is returned; the container is styled by `.latex-block`, and the
 * global `.chat-scrollbar` class keeps its scrollbar consistent with the rest of the app.
 */
export function renderLatexBlockToHtml(code: string): string {
  const formulaHtml = renderLatex(code.trim(), true)
  return '<div class="latex-block chat-scrollbar">' + formulaHtml + '</div>'
}
168
+
169
+ /**
170
+ * 预处理文本中的 LaTeX 公式
171
+ * 将 LaTeX 公式替换为占位符,避免被 markdown 解析器处理
172
+ * @param text 原始文本
173
+ * @returns { processed: 处理后文本, placeholders: 占位符映射 }
174
+ */
175
+ function preprocessLatex(text: string): { processed: string; placeholders: Map<string, string> } {
176
+ const placeholders = new Map<string, string>()
177
+ let counter = 0
178
+
179
+ // 生成唯一占位符(使用 HTML 注释格式,避免被 markdown 解析)
180
+ const createPlaceholder = () => {
181
+ const placeholder = `<!--LATEX:${counter++}-->`
182
+ return placeholder
183
+ }
184
+
185
+ let processed = text
186
+
187
+ // 处理块级公式 $$...$$ (优先处理,避免和行内公式冲突)
188
+ processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, latex) => {
189
+ const placeholder = createPlaceholder()
190
+ const rendered = renderLatex(latex.trim(), true)
191
+ placeholders.set(placeholder, `<div class="latex-block chat-scrollbar">${rendered}</div>`)
192
+ return placeholder
193
+ })
194
+
195
+ // 处理块级公式 \[...\]
196
+ processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, latex) => {
197
+ const placeholder = createPlaceholder()
198
+ const rendered = renderLatex(latex.trim(), true)
199
+ placeholders.set(placeholder, `<div class="latex-block chat-scrollbar">${rendered}</div>`)
200
+ return placeholder
201
+ })
202
+
203
+ // 处理行内公式 $...$ (单个 $,不能跨行)
204
+ // 注意:避免匹配 $$ 和转义的 \$
205
+ processed = processed.replace(/(?<!\$)\$(?!\$)((?:\\.|[^$\n])+?)\$(?!\$)/g, (_, latex) => {
206
+ const placeholder = createPlaceholder()
207
+ const rendered = renderLatex(latex.trim(), false)
208
+ placeholders.set(placeholder, `<span class="latex-inline">${rendered}</span>`)
209
+ return placeholder
210
+ })
211
+
212
+ // 处理行内公式 \(...\)
213
+ processed = processed.replace(/\\\(([\s\S]*?)\\\)/g, (_, latex) => {
214
+ const placeholder = createPlaceholder()
215
+ const rendered = renderLatex(latex.trim(), false)
216
+ placeholders.set(placeholder, `<span class="latex-inline">${rendered}</span>`)
217
+ return placeholder
218
+ })
219
+
220
+ return { processed, placeholders }
221
+ }
222
+
/**
 * Post-process: replace each placeholder produced during LaTeX pre-processing with
 * its rendered HTML.
 *
 * A replacer function is used instead of a replacement string. `String.prototype.replace`
 * interprets `$$`, `$&`, `` $` `` and `$'` inside a replacement string, which would
 * corrupt rendered HTML that contains `$` (for example a `$$…$$` error fallback would
 * collapse to `$…$`).
 *
 * @param html HTML produced by the markdown renderer, still containing placeholders
 * @param placeholders placeholder -> rendered HTML
 * @returns HTML with every found placeholder replaced (first occurrence of each)
 */
function postprocessLatex(html: string, placeholders: Map<string, string>): string {
  let result = html
  placeholders.forEach((rendered, placeholder) => {
    result = result.replace(placeholder, () => rendered)
  })
  return result
}
233
+
10
234
  /**
11
235
  * 渲染 Markdown 为 HTML
12
236
  * @param markdown Markdown 文本
@@ -15,19 +239,33 @@ import { highlightCode } from './highlighter'
15
239
  export function renderMarkdown(markdown: string): string {
16
240
  if (!markdown) return ''
17
241
 
242
+ // 预处理 LaTeX 公式
243
+ const { processed, placeholders } = preprocessLatex(markdown)
244
+
18
245
  // 创建自定义渲染器(继承默认渲染器)
19
246
  const renderer = new Renderer()
20
247
 
21
- // 自定义代码块渲染
248
+ // 自定义代码块渲染(支持 Mermaid)
22
249
  renderer.code = (code: string, language?: string) => {
250
+ const lang = (language || 'plaintext').toLowerCase()
251
+
252
+ // Mermaid 图表:生成占位符,延迟渲染
253
+ if (lang === 'mermaid') {
254
+ // 注意:HTML attribute 会把换行规范化成空格,直接写入会破坏 mermaid 语法
255
+ // 所以这里用 base64(UTF-8) 存储源码,渲染时再解码
256
+ const b64 = encodeBase64Utf8(code)
257
+ return `<div class="mermaid-placeholder" data-mermaid-code-b64="${b64}"><div class="mermaid-loading">加载图表中...</div></div>`
258
+ }
259
+
23
260
  const highlighted = highlightCode(code, language)
24
- const lang = language || 'plaintext'
25
- return `<pre class="markdown-code-block"><code class="language-${lang}">${highlighted}</code></pre>`
261
+ // 滚动条样式与系统对齐:直接复用全局 `.chat-scrollbar`
262
+ return `<pre class="markdown-code-block chat-scrollbar"><code class="language-${lang}">${highlighted}</code></pre>`
26
263
  }
27
264
 
28
265
  // 自定义表格渲染(添加样式类和滚动容器)
29
266
  renderer.table = (header: string, body: string) => {
30
- return `<div class="markdown-table-wrapper"><table class="markdown-table"><thead>${header}</thead><tbody>${body}</tbody></table></div>`
267
+ // 滚动条样式与系统对齐:直接复用全局 `.chat-scrollbar`
268
+ return `<div class="markdown-table-wrapper chat-scrollbar"><table class="markdown-table"><thead>${header}</thead><tbody>${body}</tbody></table></div>`
31
269
  }
32
270
 
33
271
  // 自定义链接渲染(添加 target="_blank" 和 rel="noopener noreferrer")
@@ -56,9 +294,13 @@ export function renderMarkdown(markdown: string): string {
56
294
  })
57
295
 
58
296
  // 渲染 Markdown(marked 12.x 返回字符串,不是 Promise)
59
- let html = marked(markdown, { renderer }) as string
297
+ let html = marked(processed, { renderer }) as string
298
+
299
+ // 后处理:将占位符替换回渲染后的 LaTeX
300
+ html = postprocessLatex(html, placeholders)
60
301
 
61
302
  // 使用 DOMPurify 清理 HTML,防止 XSS
303
+ // 允许 KaTeX 生成的 SVG 和相关元素
62
304
  html = DOMPurify.sanitize(html, {
63
305
  ALLOWED_TAGS: [
64
306
  'p', 'br', 'strong', 'em', 'u', 's', 'code', 'pre',
@@ -67,13 +309,27 @@ export function renderMarkdown(markdown: string): string {
67
309
  'table', 'thead', 'tbody', 'tr', 'th', 'td',
68
310
  'a', 'img', 'hr',
69
311
  'div', 'span',
312
+ // KaTeX 相关标签
313
+ 'math', 'semantics', 'mrow', 'mi', 'mn', 'mo', 'msup', 'msub',
314
+ 'mfrac', 'mroot', 'msqrt', 'mtext', 'mspace', 'mtable', 'mtr', 'mtd',
315
+ 'annotation', 'mover', 'munder', 'munderover', 'menclose', 'mpadded',
316
+ 'svg', 'line', 'path', 'rect', 'circle', 'g', 'use', 'defs', 'symbol',
70
317
  ],
71
318
  ALLOWED_ATTR: [
72
319
  'href', 'title', 'target', 'rel', 'class',
73
320
  'src', 'alt', 'width', 'height',
321
+ // Mermaid 相关属性
322
+ 'data-mermaid-code-b64',
323
+ // KaTeX 相关属性
324
+ 'style', 'xmlns', 'viewBox', 'preserveAspectRatio',
325
+ 'd', 'x', 'y', 'x1', 'y1', 'x2', 'y2', 'r', 'cx', 'cy',
326
+ 'fill', 'stroke', 'stroke-width', 'transform',
327
+ 'xlink:href', 'aria-hidden', 'focusable', 'role',
328
+ 'mathvariant', 'encoding', 'stretchy', 'fence', 'separator',
329
+ 'lspace', 'rspace', 'minsize', 'maxsize', 'accent', 'accentunder',
74
330
  ],
331
+ ADD_ATTR: ['xmlns:xlink'],
75
332
  })
76
333
 
77
334
  return html
78
335
  }
79
-
package/src/parser.ts CHANGED
@@ -33,7 +33,7 @@ export function parseContent(raw: string): ContentBlock[] {
33
33
  if (match.index > lastIndex) {
34
34
  const textContent = raw.slice(lastIndex, match.index)
35
35
  if (textContent.trim()) {
36
- blocks.push(createTextBlock(textContent))
36
+ blocks.push(createTextBlock(textContent, undefined, true))
37
37
  }
38
38
  }
39
39
 
@@ -49,7 +49,7 @@ export function parseContent(raw: string): ContentBlock[] {
49
49
  if (lastIndex < raw.length) {
50
50
  const textContent = raw.slice(lastIndex)
51
51
  if (textContent.trim()) {
52
- blocks.push(createTextBlock(textContent))
52
+ blocks.push(createTextBlock(textContent, undefined, true))
53
53
  }
54
54
  }
55
55
 
@@ -57,18 +57,18 @@ export function parseContent(raw: string): ContentBlock[] {
57
57
  }
58
58
 
/**
 * Create a text block.
 * @param content the block's text
 * @param id optional stable id; when omitted (or empty) a fresh one from `generateId()` is used
 * @param trimContent when true (the default), leading/trailing whitespace is stripped from `content`
 * @returns a `type: 'text'` block
 */
function createTextBlock(content: string, id?: string, trimContent = true): TextBlock {
  const blockId = id || generateId()
  const text = trimContent ? content.trim() : content
  return { id: blockId, type: 'text', content: text }
}
67
67
 
68
68
  /** 创建代码块 */
69
- function createCodeBlock(content: string, language?: string, filename?: string): CodeBlock {
69
+ function createCodeBlock(content: string, language?: string, filename?: string, id?: string): CodeBlock {
70
70
  return {
71
- id: generateId(),
71
+ id: id || generateId(),
72
72
  type: 'code',
73
73
  content,
74
74
  language,
@@ -90,6 +90,12 @@ export interface StreamParseState {
90
90
  codeLanguage?: string
91
91
  /** 当前代码块内容 */
92
92
  codeContent: string
93
+ /** 当前代码块 ID(用于稳定 key,避免流式时每次 render 都生成新 id) */
94
+ codeBlockId: string | null
95
+ /** 当前文本块内容(流式追加,避免把一句话拆成很多 text block) */
96
+ textContent: string
97
+ /** 当前文本块 ID(稳定 key) */
98
+ textBlockId: string | null
93
99
  }
94
100
 
95
101
  /** 创建初始流式解析状态 */
@@ -100,9 +106,25 @@ export function createStreamParseState(): StreamParseState {
100
106
  inCodeBlock: false,
101
107
  codeLanguage: undefined,
102
108
  codeContent: '',
109
+ codeBlockId: null,
110
+ textContent: '',
111
+ textBlockId: null,
103
112
  }
104
113
  }
105
114
 
/**
 * Append streamed text to the pending (not yet flushed) text block.
 * The block's stable id is allocated lazily on the first non-empty append, so empty
 * strings never allocate an id.
 */
function appendText(state: StreamParseState, text: string): void {
  if (text.length === 0) return

  if (state.textBlockId === null) {
    state.textBlockId = generateId()
  }
  state.textContent = state.textContent + text
}
120
+
/**
 * Move the pending text into `state.blocks` as one text block (keeping its stable id
 * and its text untrimmed), then reset the pending-text state.
 * In the visible code this runs when a code fence opens.
 *
 * Whitespace-only pending text produces no block, but the pending state is still reset.
 * Otherwise that whitespace would linger past the following code block, get prepended
 * to the next text block, and that block would reuse an id allocated before the code block.
 */
function flushTextBlock(state: StreamParseState): void {
  if (state.textContent.trim()) {
    // trimContent = false: keep the streamed text verbatim
    state.blocks.push(createTextBlock(state.textContent, state.textBlockId || undefined, false))
  }
  state.textContent = ''
  state.textBlockId = null
}
127
+
106
128
  /**
107
129
  * 流式解析(增量更新)
108
130
  * @param chunk 新增的文本块
@@ -110,7 +132,7 @@ export function createStreamParseState(): StreamParseState {
110
132
  * @returns 更新后的状态
111
133
  */
112
134
  export function parseContentStream(chunk: string, state: StreamParseState): StreamParseState {
113
- const newState = { ...state }
135
+ const newState: StreamParseState = { ...state, blocks: [...state.blocks] }
114
136
  newState.buffer += chunk
115
137
 
116
138
  // 处理缓冲区
@@ -121,14 +143,19 @@ export function parseContentStream(chunk: string, state: StreamParseState): Stre
121
143
  if (endIndex !== -1) {
122
144
  // 找到结束标记
123
145
  newState.codeContent += newState.buffer.slice(0, endIndex)
124
- newState.blocks.push(createCodeBlock(
125
- newState.codeContent.trim(),
126
- newState.codeLanguage
127
- ))
146
+ newState.blocks.push(
147
+ createCodeBlock(
148
+ newState.codeContent.trim(),
149
+ newState.codeLanguage,
150
+ undefined,
151
+ newState.codeBlockId || undefined
152
+ )
153
+ )
128
154
  newState.buffer = newState.buffer.slice(endIndex + 3)
129
155
  newState.inCodeBlock = false
130
156
  newState.codeLanguage = undefined
131
157
  newState.codeContent = ''
158
+ newState.codeBlockId = null
132
159
  } else {
133
160
  // 没找到结束标记,继续累积
134
161
  newState.codeContent += newState.buffer
@@ -141,9 +168,8 @@ export function parseContentStream(chunk: string, state: StreamParseState): Stre
141
168
  if (startIndex !== -1) {
142
169
  // 找到开始标记
143
170
  const beforeCode = newState.buffer.slice(0, startIndex)
144
- if (beforeCode.trim()) {
145
- newState.blocks.push(createTextBlock(beforeCode))
146
- }
171
+ appendText(newState, beforeCode)
172
+ flushTextBlock(newState)
147
173
 
148
174
  // 解析语言标识
149
175
  const afterStart = newState.buffer.slice(startIndex + 3)
@@ -153,25 +179,22 @@ export function parseContentStream(chunk: string, state: StreamParseState): Stre
153
179
  newState.buffer = afterStart.slice(newlineIndex + 1)
154
180
  newState.inCodeBlock = true
155
181
  newState.codeContent = ''
182
+ newState.codeBlockId = generateId()
156
183
  } else {
157
184
  // 语言行还没完整,等待更多数据
158
185
  break
159
186
  }
160
187
  } else {
161
- // 没有代码块标记,检查是否有不完整的 ``` 开头
188
+ // 没有代码块标记:追加到“当前文本块”里(不 flush),避免碎片化
189
+ // 同时检查末尾是否可能是不完整的 ```,如果是就保留在 buffer 等待后续补全
162
190
  const lastBackticks = newState.buffer.lastIndexOf('`')
163
191
  if (lastBackticks !== -1 && newState.buffer.length - lastBackticks < 3) {
164
192
  // 可能是不完整的 ```,保留
165
193
  const safeText = newState.buffer.slice(0, lastBackticks)
166
- if (safeText.trim()) {
167
- newState.blocks.push(createTextBlock(safeText))
168
- }
194
+ appendText(newState, safeText)
169
195
  newState.buffer = newState.buffer.slice(lastBackticks)
170
196
  } else {
171
- // 安全的纯文本
172
- if (newState.buffer.trim()) {
173
- newState.blocks.push(createTextBlock(newState.buffer))
174
- }
197
+ appendText(newState, newState.buffer)
175
198
  newState.buffer = ''
176
199
  }
177
200
  break
@@ -192,13 +215,20 @@ export function finishStreamParse(state: StreamParseState): ContentBlock[] {
192
215
 
193
216
  if (state.inCodeBlock) {
194
217
  // 未闭合的代码块,作为代码块处理
195
- blocks.push(createCodeBlock(
196
- (state.codeContent + state.buffer).trim(),
197
- state.codeLanguage
198
- ))
199
- } else if (state.buffer.trim()) {
200
- // 剩余文本
201
- blocks.push(createTextBlock(state.buffer))
218
+ blocks.push(
219
+ createCodeBlock(
220
+ (state.codeContent + state.buffer).trim(),
221
+ state.codeLanguage,
222
+ undefined,
223
+ state.codeBlockId || undefined
224
+ )
225
+ )
226
+ } else {
227
+ // 剩余文本 + 累积文本:合并成一个 text block(稳定 id),避免碎片化
228
+ const merged = state.textContent + state.buffer
229
+ if (merged.trim()) {
230
+ blocks.push(createTextBlock(merged, state.textBlockId || undefined, false))
231
+ }
202
232
  }
203
233
 
204
234
  return blocks