@mastra/rag 0.1.19-alpha.2 → 0.1.19-alpha.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +15 -0
- package/dist/_tsup-dts-rollup.d.cts +237 -5
- package/dist/_tsup-dts-rollup.d.ts +237 -5
- package/dist/index.cjs +3980 -14
- package/dist/index.js +3977 -11
- package/package.json +2 -4
- package/src/document/document.test.ts +123 -2
- package/src/document/document.ts +15 -21
- package/src/document/extractors/index.ts +5 -0
- package/src/document/extractors/keywords.test.ts +119 -0
- package/src/document/extractors/keywords.ts +123 -0
- package/src/document/extractors/questions.test.ts +120 -0
- package/src/document/extractors/questions.ts +126 -0
- package/src/document/extractors/summary.test.ts +107 -0
- package/src/document/extractors/summary.ts +130 -0
- package/src/document/extractors/title.test.ts +121 -0
- package/src/document/extractors/title.ts +210 -0
- package/src/document/extractors/types.ts +40 -0
- package/src/document/types.ts +5 -33
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
2
|
+
import { PromptTemplate, defaultQuestionExtractPrompt, TextNode, BaseExtractor } from 'llamaindex';
|
|
3
|
+
import type { QuestionExtractPrompt, BaseNode } from 'llamaindex';
|
|
4
|
+
import { baseLLM, STRIP_REGEX } from './types';
|
|
5
|
+
import type { QuestionAnswerExtractArgs } from './types';
|
|
6
|
+
|
|
7
|
+
type ExtractQuestion = {
|
|
8
|
+
/**
|
|
9
|
+
* Questions extracted from the node as a string (may be empty if extraction fails).
|
|
10
|
+
*/
|
|
11
|
+
questionsThisExcerptCanAnswer: string;
|
|
12
|
+
};
|
|
13
|
+
|
|
14
|
+
/**
|
|
15
|
+
* Extract questions from a list of nodes.
|
|
16
|
+
*/
|
|
17
|
+
export class QuestionsAnsweredExtractor extends BaseExtractor {
|
|
18
|
+
/**
|
|
19
|
+
* MastraLanguageModel instance.
|
|
20
|
+
* @type {MastraLanguageModel}
|
|
21
|
+
*/
|
|
22
|
+
llm: MastraLanguageModel;
|
|
23
|
+
|
|
24
|
+
/**
|
|
25
|
+
* Number of questions to generate.
|
|
26
|
+
* @type {number}
|
|
27
|
+
* @default 5
|
|
28
|
+
*/
|
|
29
|
+
questions: number = 5;
|
|
30
|
+
|
|
31
|
+
/**
|
|
32
|
+
* The prompt template to use for the question extractor.
|
|
33
|
+
* @type {string}
|
|
34
|
+
*/
|
|
35
|
+
promptTemplate: QuestionExtractPrompt;
|
|
36
|
+
|
|
37
|
+
/**
|
|
38
|
+
* Whether to use metadata for embeddings only
|
|
39
|
+
* @type {boolean}
|
|
40
|
+
* @default false
|
|
41
|
+
*/
|
|
42
|
+
embeddingOnly: boolean = false;
|
|
43
|
+
|
|
44
|
+
/**
|
|
45
|
+
* Constructor for the QuestionsAnsweredExtractor class.
|
|
46
|
+
* @param {MastraLanguageModel} llm MastraLanguageModel instance.
|
|
47
|
+
* @param {number} questions Number of questions to generate.
|
|
48
|
+
* @param {QuestionExtractPrompt['template']} promptTemplate Optional custom prompt template (should include {context}).
|
|
49
|
+
* @param {boolean} embeddingOnly Whether to use metadata for embeddings only.
|
|
50
|
+
*/
|
|
51
|
+
constructor(options?: QuestionAnswerExtractArgs) {
|
|
52
|
+
if (options?.questions && options.questions < 1) throw new Error('Questions must be greater than 0');
|
|
53
|
+
|
|
54
|
+
super();
|
|
55
|
+
|
|
56
|
+
this.llm = options?.llm ?? baseLLM;
|
|
57
|
+
this.questions = options?.questions ?? 5;
|
|
58
|
+
this.promptTemplate = options?.promptTemplate
|
|
59
|
+
? new PromptTemplate({
|
|
60
|
+
templateVars: ['numQuestions', 'context'],
|
|
61
|
+
template: options.promptTemplate,
|
|
62
|
+
}).partialFormat({
|
|
63
|
+
numQuestions: '5',
|
|
64
|
+
})
|
|
65
|
+
: defaultQuestionExtractPrompt;
|
|
66
|
+
this.embeddingOnly = options?.embeddingOnly ?? false;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
/**
|
|
70
|
+
* Extract answered questions from a node.
|
|
71
|
+
* @param {BaseNode} node Node to extract questions from.
|
|
72
|
+
* @returns {Promise<Array<ExtractQuestion> | Array<{}>>} Questions extracted from the node.
|
|
73
|
+
*/
|
|
74
|
+
async extractQuestionsFromNode(node: BaseNode): Promise<ExtractQuestion> {
|
|
75
|
+
const text = node.getContent(this.metadataMode);
|
|
76
|
+
if (!text || text.trim() === '') {
|
|
77
|
+
return { questionsThisExcerptCanAnswer: '' };
|
|
78
|
+
}
|
|
79
|
+
if (this.isTextNodeOnly && !(node instanceof TextNode)) {
|
|
80
|
+
return { questionsThisExcerptCanAnswer: '' };
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
const contextStr = node.getContent(this.metadataMode);
|
|
84
|
+
|
|
85
|
+
const prompt = this.promptTemplate.format({
|
|
86
|
+
context: contextStr,
|
|
87
|
+
numQuestions: this.questions.toString(),
|
|
88
|
+
});
|
|
89
|
+
|
|
90
|
+
const questions = await this.llm.doGenerate({
|
|
91
|
+
inputFormat: 'messages',
|
|
92
|
+
mode: { type: 'regular' },
|
|
93
|
+
prompt: [
|
|
94
|
+
{
|
|
95
|
+
role: 'user',
|
|
96
|
+
content: [{ type: 'text', text: prompt }],
|
|
97
|
+
},
|
|
98
|
+
],
|
|
99
|
+
});
|
|
100
|
+
|
|
101
|
+
let result = '';
|
|
102
|
+
try {
|
|
103
|
+
if (typeof questions.text === 'string') {
|
|
104
|
+
result = questions.text.replace(STRIP_REGEX, '').trim();
|
|
105
|
+
} else {
|
|
106
|
+
console.warn('Question extraction LLM output was not a string:', questions.text);
|
|
107
|
+
}
|
|
108
|
+
} catch (err) {
|
|
109
|
+
console.warn('Question extraction failed:', err);
|
|
110
|
+
}
|
|
111
|
+
return {
|
|
112
|
+
questionsThisExcerptCanAnswer: result,
|
|
113
|
+
};
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
/**
|
|
117
|
+
* Extract answered questions from a list of nodes.
|
|
118
|
+
* @param {BaseNode[]} nodes Nodes to extract questions from.
|
|
119
|
+
* @returns {Promise<Array<ExtractQuestion> | Array<{}>>} Questions extracted from the nodes.
|
|
120
|
+
*/
|
|
121
|
+
async extract(nodes: BaseNode[]): Promise<Array<ExtractQuestion> | Array<object>> {
|
|
122
|
+
const results = await Promise.all(nodes.map(node => this.extractQuestionsFromNode(node)));
|
|
123
|
+
|
|
124
|
+
return results;
|
|
125
|
+
}
|
|
126
|
+
}
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import { createOpenAI } from '@ai-sdk/openai';
|
|
2
|
+
import { TextNode } from 'llamaindex';
|
|
3
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
4
|
+
import { SummaryExtractor } from './summary';
|
|
5
|
+
|
|
6
|
+
const openai = createOpenAI({
|
|
7
|
+
apiKey: process.env.OPENAI_API_KEY,
|
|
8
|
+
});
|
|
9
|
+
|
|
10
|
+
const model = openai('gpt-4o');
|
|
11
|
+
|
|
12
|
+
vi.setConfig({ testTimeout: 10_000, hookTimeout: 10_000 });
|
|
13
|
+
|
|
14
|
+
describe('SummaryExtractor', () => {
|
|
15
|
+
it('can use a custom model from the test suite', async () => {
|
|
16
|
+
const extractor = new SummaryExtractor({ llm: model });
|
|
17
|
+
const node = new TextNode({ text: 'A summary test using a custom model.' });
|
|
18
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
19
|
+
expect(typeof summary).toBe('string');
|
|
20
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
21
|
+
});
|
|
22
|
+
it('extracts summary from normal text', async () => {
|
|
23
|
+
const extractor = new SummaryExtractor();
|
|
24
|
+
const node = new TextNode({ text: 'This is a test document.' });
|
|
25
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
26
|
+
expect(typeof summary).toBe('string');
|
|
27
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
28
|
+
});
|
|
29
|
+
|
|
30
|
+
it('handles empty input gracefully', async () => {
|
|
31
|
+
const extractor = new SummaryExtractor();
|
|
32
|
+
const node = new TextNode({ text: '' });
|
|
33
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
34
|
+
expect(summary).toBe('');
|
|
35
|
+
});
|
|
36
|
+
|
|
37
|
+
it('supports prompt customization', async () => {
|
|
38
|
+
const extractor = new SummaryExtractor({ promptTemplate: 'Summarize: {context}' });
|
|
39
|
+
const node = new TextNode({ text: 'Test document for prompt customization.' });
|
|
40
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
41
|
+
expect(typeof summary).toBe('string');
|
|
42
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
43
|
+
});
|
|
44
|
+
|
|
45
|
+
it('handles very long input', async () => {
|
|
46
|
+
const extractor = new SummaryExtractor();
|
|
47
|
+
const longText = 'A'.repeat(1000);
|
|
48
|
+
const node = new TextNode({ text: longText });
|
|
49
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
50
|
+
expect(typeof summary).toBe('string');
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
it('handles whitespace only input', async () => {
|
|
54
|
+
const extractor = new SummaryExtractor();
|
|
55
|
+
const node = new TextNode({ text: ' ' });
|
|
56
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
57
|
+
expect(summary).toBe('');
|
|
58
|
+
});
|
|
59
|
+
|
|
60
|
+
it('handles special characters and emojis', async () => {
|
|
61
|
+
const extractor = new SummaryExtractor();
|
|
62
|
+
const node = new TextNode({ text: '🚀✨🔥' });
|
|
63
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
64
|
+
expect(typeof summary).toBe('string');
|
|
65
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
66
|
+
});
|
|
67
|
+
|
|
68
|
+
it('handles numbers only', async () => {
|
|
69
|
+
const extractor = new SummaryExtractor();
|
|
70
|
+
const node = new TextNode({ text: '1234567890' });
|
|
71
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
72
|
+
expect(typeof summary).toBe('string');
|
|
73
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
74
|
+
});
|
|
75
|
+
|
|
76
|
+
it('handles HTML tags', async () => {
|
|
77
|
+
const extractor = new SummaryExtractor();
|
|
78
|
+
const node = new TextNode({ text: '<h1>Test</h1>' });
|
|
79
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
80
|
+
expect(typeof summary).toBe('string');
|
|
81
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
82
|
+
});
|
|
83
|
+
|
|
84
|
+
it('handles non-English text', async () => {
|
|
85
|
+
const extractor = new SummaryExtractor();
|
|
86
|
+
const node = new TextNode({ text: '这是一个测试文档。' });
|
|
87
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
88
|
+
expect(typeof summary).toBe('string');
|
|
89
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
90
|
+
});
|
|
91
|
+
|
|
92
|
+
it('handles duplicate/repeated text', async () => {
|
|
93
|
+
const extractor = new SummaryExtractor();
|
|
94
|
+
const node = new TextNode({ text: 'repeat repeat repeat' });
|
|
95
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
96
|
+
expect(typeof summary).toBe('string');
|
|
97
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
98
|
+
});
|
|
99
|
+
|
|
100
|
+
it('handles only punctuation', async () => {
|
|
101
|
+
const extractor = new SummaryExtractor();
|
|
102
|
+
const node = new TextNode({ text: '!!!???...' });
|
|
103
|
+
const summary = await extractor.generateNodeSummary(node);
|
|
104
|
+
expect(typeof summary).toBe('string');
|
|
105
|
+
expect(summary.length).toBeGreaterThan(0);
|
|
106
|
+
});
|
|
107
|
+
});
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
2
|
+
import { PromptTemplate, defaultSummaryPrompt, TextNode, BaseExtractor } from 'llamaindex';
|
|
3
|
+
import type { SummaryPrompt, BaseNode } from 'llamaindex';
|
|
4
|
+
import { baseLLM, STRIP_REGEX } from './types';
|
|
5
|
+
import type { SummaryExtractArgs } from './types';
|
|
6
|
+
|
|
7
|
+
type ExtractSummary = {
|
|
8
|
+
sectionSummary?: string;
|
|
9
|
+
prevSectionSummary?: string;
|
|
10
|
+
nextSectionSummary?: string;
|
|
11
|
+
};
|
|
12
|
+
|
|
13
|
+
/**
|
|
14
|
+
* Summarize an array of nodes using a custom LLM.
|
|
15
|
+
*
|
|
16
|
+
* @param nodes Array of node-like objects
|
|
17
|
+
* @param options Summary extraction options
|
|
18
|
+
* @returns Array of summary results
|
|
19
|
+
*/
|
|
20
|
+
export class SummaryExtractor extends BaseExtractor {
|
|
21
|
+
/**
|
|
22
|
+
* MastraLanguageModel instance.
|
|
23
|
+
* @type {MastraLanguageModel}
|
|
24
|
+
*/
|
|
25
|
+
private llm: MastraLanguageModel;
|
|
26
|
+
/**
|
|
27
|
+
* List of summaries to extract: 'self', 'prev', 'next'
|
|
28
|
+
* @type {string[]}
|
|
29
|
+
*/
|
|
30
|
+
summaries: string[];
|
|
31
|
+
|
|
32
|
+
/**
|
|
33
|
+
* The prompt template to use for the summary extractor.
|
|
34
|
+
* @type {string}
|
|
35
|
+
*/
|
|
36
|
+
promptTemplate: SummaryPrompt;
|
|
37
|
+
|
|
38
|
+
private selfSummary: boolean;
|
|
39
|
+
private prevSummary: boolean;
|
|
40
|
+
private nextSummary: boolean;
|
|
41
|
+
|
|
42
|
+
constructor(options?: SummaryExtractArgs) {
|
|
43
|
+
const summaries = options?.summaries ?? ['self'];
|
|
44
|
+
|
|
45
|
+
if (summaries && !summaries.some(s => ['self', 'prev', 'next'].includes(s)))
|
|
46
|
+
throw new Error("Summaries must be one of 'self', 'prev', 'next'");
|
|
47
|
+
|
|
48
|
+
super();
|
|
49
|
+
|
|
50
|
+
this.llm = options?.llm ?? baseLLM;
|
|
51
|
+
this.summaries = summaries;
|
|
52
|
+
this.promptTemplate = options?.promptTemplate
|
|
53
|
+
? new PromptTemplate({
|
|
54
|
+
templateVars: ['context'],
|
|
55
|
+
template: options.promptTemplate,
|
|
56
|
+
})
|
|
57
|
+
: defaultSummaryPrompt;
|
|
58
|
+
|
|
59
|
+
this.selfSummary = summaries?.includes('self') ?? false;
|
|
60
|
+
this.prevSummary = summaries?.includes('prev') ?? false;
|
|
61
|
+
this.nextSummary = summaries?.includes('next') ?? false;
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
/**
|
|
65
|
+
* Extract summary from a node.
|
|
66
|
+
* @param {BaseNode} node Node to extract summary from.
|
|
67
|
+
* @returns {Promise<string>} Summary extracted from the node.
|
|
68
|
+
*/
|
|
69
|
+
async generateNodeSummary(node: BaseNode): Promise<string> {
|
|
70
|
+
const text = node.getContent(this.metadataMode);
|
|
71
|
+
if (!text || text.trim() === '') {
|
|
72
|
+
return '';
|
|
73
|
+
}
|
|
74
|
+
if (this.isTextNodeOnly && !(node instanceof TextNode)) {
|
|
75
|
+
return '';
|
|
76
|
+
}
|
|
77
|
+
const context = node.getContent(this.metadataMode);
|
|
78
|
+
|
|
79
|
+
const prompt = this.promptTemplate.format({
|
|
80
|
+
context,
|
|
81
|
+
});
|
|
82
|
+
|
|
83
|
+
const result = await this.llm.doGenerate({
|
|
84
|
+
inputFormat: 'messages',
|
|
85
|
+
mode: { type: 'regular' },
|
|
86
|
+
prompt: [
|
|
87
|
+
{
|
|
88
|
+
role: 'user',
|
|
89
|
+
content: [{ type: 'text', text: prompt }],
|
|
90
|
+
},
|
|
91
|
+
],
|
|
92
|
+
});
|
|
93
|
+
|
|
94
|
+
let summary = '';
|
|
95
|
+
if (typeof result.text === 'string') {
|
|
96
|
+
summary = result.text.trim();
|
|
97
|
+
} else {
|
|
98
|
+
console.warn('Summary extraction LLM output was not a string:', result.text);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
return summary.replace(STRIP_REGEX, '');
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/**
|
|
105
|
+
* Extract summaries from a list of nodes.
|
|
106
|
+
* @param {BaseNode[]} nodes Nodes to extract summaries from.
|
|
107
|
+
* @returns {Promise<ExtractSummary[]>} Summaries extracted from the nodes.
|
|
108
|
+
*/
|
|
109
|
+
async extract(nodes: BaseNode[]): Promise<ExtractSummary[]> {
|
|
110
|
+
if (!nodes.every(n => n instanceof TextNode)) throw new Error('Only `TextNode` is allowed for `Summary` extractor');
|
|
111
|
+
|
|
112
|
+
const nodeSummaries = await Promise.all(nodes.map(node => this.generateNodeSummary(node)));
|
|
113
|
+
|
|
114
|
+
const metadataList: ExtractSummary[] = nodes.map(() => ({}));
|
|
115
|
+
|
|
116
|
+
for (let i = 0; i < nodes.length; i++) {
|
|
117
|
+
if (i > 0 && this.prevSummary && nodeSummaries[i - 1]) {
|
|
118
|
+
metadataList[i]!['prevSectionSummary'] = nodeSummaries[i - 1];
|
|
119
|
+
}
|
|
120
|
+
if (i < nodes.length - 1 && this.nextSummary && nodeSummaries[i + 1]) {
|
|
121
|
+
metadataList[i]!['nextSectionSummary'] = nodeSummaries[i + 1];
|
|
122
|
+
}
|
|
123
|
+
if (this.selfSummary && nodeSummaries[i]) {
|
|
124
|
+
metadataList[i]!['sectionSummary'] = nodeSummaries[i];
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
return metadataList;
|
|
129
|
+
}
|
|
130
|
+
}
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import { createOpenAI } from '@ai-sdk/openai';
|
|
2
|
+
import { TextNode } from 'llamaindex';
|
|
3
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
4
|
+
import { TitleExtractor } from './title';
|
|
5
|
+
|
|
6
|
+
const openai = createOpenAI({
|
|
7
|
+
apiKey: process.env.OPENAI_API_KEY,
|
|
8
|
+
});
|
|
9
|
+
|
|
10
|
+
const model = openai('gpt-4o');
|
|
11
|
+
|
|
12
|
+
vi.setConfig({ testTimeout: 10_000, hookTimeout: 10_000 });
|
|
13
|
+
|
|
14
|
+
describe('TitleExtractor', () => {
|
|
15
|
+
it('can use a custom model from the test suite', async () => {
|
|
16
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
17
|
+
const node = new TextNode({ text: 'A title test using a custom model.' });
|
|
18
|
+
const titles = await extractor.extract([node]);
|
|
19
|
+
expect(Array.isArray(titles)).toBe(true);
|
|
20
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
21
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
22
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
23
|
+
});
|
|
24
|
+
|
|
25
|
+
it('extracts title', async () => {
|
|
26
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
27
|
+
const node = new TextNode({ text: 'This is a test document.' });
|
|
28
|
+
const titles = await extractor.extract([node]);
|
|
29
|
+
expect(Array.isArray(titles)).toBe(true);
|
|
30
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
31
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
32
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
33
|
+
});
|
|
34
|
+
|
|
35
|
+
it('handles empty input gracefully', async () => {
|
|
36
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
37
|
+
const node = new TextNode({ text: '' });
|
|
38
|
+
const titles = await extractor.extract([node]);
|
|
39
|
+
expect(titles[0].documentTitle).toBe('');
|
|
40
|
+
});
|
|
41
|
+
|
|
42
|
+
it('supports prompt customization', async () => {
|
|
43
|
+
const extractor = new TitleExtractor({ llm: model, nodeTemplate: 'Title for: {context}' });
|
|
44
|
+
const node = new TextNode({ text: 'Test document for prompt customization.' });
|
|
45
|
+
const titles = await extractor.extract([node]);
|
|
46
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
47
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
48
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
49
|
+
});
|
|
50
|
+
|
|
51
|
+
it('handles very long input', async () => {
|
|
52
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
53
|
+
const longText = 'A'.repeat(1000);
|
|
54
|
+
const node = new TextNode({ text: longText });
|
|
55
|
+
const titles = await extractor.extract([node]);
|
|
56
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
57
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
58
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
59
|
+
});
|
|
60
|
+
|
|
61
|
+
it('handles whitespace only input', async () => {
|
|
62
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
63
|
+
const node = new TextNode({ text: ' ' });
|
|
64
|
+
const titles = await extractor.extract([node]);
|
|
65
|
+
expect(titles[0].documentTitle).toBe('');
|
|
66
|
+
});
|
|
67
|
+
|
|
68
|
+
it('handles special characters and emojis', async () => {
|
|
69
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
70
|
+
const node = new TextNode({ text: '🚀✨🔥' });
|
|
71
|
+
const titles = await extractor.extract([node]);
|
|
72
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
73
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
74
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
75
|
+
});
|
|
76
|
+
|
|
77
|
+
it('handles numbers only', async () => {
|
|
78
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
79
|
+
const node = new TextNode({ text: '1234567890' });
|
|
80
|
+
const titles = await extractor.extract([node]);
|
|
81
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
82
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
83
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
84
|
+
});
|
|
85
|
+
|
|
86
|
+
it('handles HTML tags', async () => {
|
|
87
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
88
|
+
const node = new TextNode({ text: '<h1>Test</h1>' });
|
|
89
|
+
const titles = await extractor.extract([node]);
|
|
90
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
91
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
92
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
93
|
+
});
|
|
94
|
+
|
|
95
|
+
it('handles non-English text', async () => {
|
|
96
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
97
|
+
const node = new TextNode({ text: '这是一个测试文档。' });
|
|
98
|
+
const titles = await extractor.extract([node]);
|
|
99
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
100
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
101
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
102
|
+
});
|
|
103
|
+
|
|
104
|
+
it('handles duplicate/repeated text', async () => {
|
|
105
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
106
|
+
const node = new TextNode({ text: 'repeat repeat repeat' });
|
|
107
|
+
const titles = await extractor.extract([node]);
|
|
108
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
109
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
110
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
111
|
+
});
|
|
112
|
+
|
|
113
|
+
it('handles only punctuation', async () => {
|
|
114
|
+
const extractor = new TitleExtractor({ llm: model });
|
|
115
|
+
const node = new TextNode({ text: '!!!???...' });
|
|
116
|
+
const titles = await extractor.extract([node]);
|
|
117
|
+
expect(titles[0]).toHaveProperty('documentTitle');
|
|
118
|
+
expect(typeof titles[0].documentTitle).toBe('string');
|
|
119
|
+
expect(titles[0].documentTitle.length).toBeGreaterThan(0);
|
|
120
|
+
});
|
|
121
|
+
});
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
2
|
+
import {
|
|
3
|
+
PromptTemplate,
|
|
4
|
+
defaultTitleCombinePromptTemplate,
|
|
5
|
+
defaultTitleExtractorPromptTemplate,
|
|
6
|
+
MetadataMode,
|
|
7
|
+
TextNode,
|
|
8
|
+
BaseExtractor,
|
|
9
|
+
} from 'llamaindex';
|
|
10
|
+
import type { TitleCombinePrompt, TitleExtractorPrompt, BaseNode } from 'llamaindex';
|
|
11
|
+
import { baseLLM } from './types';
|
|
12
|
+
import type { TitleExtractorsArgs } from './types';
|
|
13
|
+
|
|
14
|
+
type ExtractTitle = {
|
|
15
|
+
documentTitle: string;
|
|
16
|
+
};
|
|
17
|
+
|
|
18
|
+
/**
|
|
19
|
+
* Extract title from a list of nodes.
|
|
20
|
+
*/
|
|
21
|
+
export class TitleExtractor extends BaseExtractor {
|
|
22
|
+
/**
|
|
23
|
+
* MastraLanguageModel instance.
|
|
24
|
+
* @type {MastraLanguageModel}
|
|
25
|
+
*/
|
|
26
|
+
llm: MastraLanguageModel;
|
|
27
|
+
|
|
28
|
+
/**
|
|
29
|
+
* Can work for mixture of text and non-text nodes
|
|
30
|
+
* @type {boolean}
|
|
31
|
+
* @default false
|
|
32
|
+
*/
|
|
33
|
+
isTextNodeOnly: boolean = false;
|
|
34
|
+
|
|
35
|
+
/**
|
|
36
|
+
* Number of nodes to extract titles from.
|
|
37
|
+
* @type {number}
|
|
38
|
+
* @default 5
|
|
39
|
+
*/
|
|
40
|
+
nodes: number = 5;
|
|
41
|
+
|
|
42
|
+
/**
|
|
43
|
+
* The prompt template to use for the title extractor.
|
|
44
|
+
* @type {string}
|
|
45
|
+
*/
|
|
46
|
+
nodeTemplate: TitleExtractorPrompt;
|
|
47
|
+
|
|
48
|
+
/**
|
|
49
|
+
* The prompt template to merge titles with.
|
|
50
|
+
* @type {string}
|
|
51
|
+
*/
|
|
52
|
+
combineTemplate: TitleCombinePrompt;
|
|
53
|
+
|
|
54
|
+
/**
|
|
55
|
+
* Constructor for the TitleExtractor class.
|
|
56
|
+
* @param {MastraLanguageModel} llm MastraLanguageModel instance.
|
|
57
|
+
* @param {number} nodes Number of nodes to extract titles from.
|
|
58
|
+
* @param {TitleExtractorPrompt} nodeTemplate The prompt template to use for the title extractor.
|
|
59
|
+
* @param {string} combineTemplate The prompt template to merge titles with.
|
|
60
|
+
*/
|
|
61
|
+
constructor(options?: TitleExtractorsArgs) {
|
|
62
|
+
super();
|
|
63
|
+
|
|
64
|
+
this.llm = options?.llm ?? baseLLM;
|
|
65
|
+
this.nodes = options?.nodes ?? 5;
|
|
66
|
+
|
|
67
|
+
this.nodeTemplate = options?.nodeTemplate
|
|
68
|
+
? new PromptTemplate({
|
|
69
|
+
templateVars: ['context'],
|
|
70
|
+
template: options.nodeTemplate,
|
|
71
|
+
})
|
|
72
|
+
: defaultTitleExtractorPromptTemplate;
|
|
73
|
+
|
|
74
|
+
this.combineTemplate = options?.combineTemplate
|
|
75
|
+
? new PromptTemplate({
|
|
76
|
+
templateVars: ['context'],
|
|
77
|
+
template: options.combineTemplate,
|
|
78
|
+
})
|
|
79
|
+
: defaultTitleCombinePromptTemplate;
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
/**
|
|
83
|
+
* Extract titles from a list of nodes.
|
|
84
|
+
* @param {BaseNode[]} nodes Nodes to extract titles from.
|
|
85
|
+
* @returns {Promise<BaseNode<ExtractTitle>[]>} Titles extracted from the nodes.
|
|
86
|
+
*/
|
|
87
|
+
async extract(nodes: BaseNode[]): Promise<Array<ExtractTitle>> {
|
|
88
|
+
// Prepare output array in original node order
|
|
89
|
+
const results: ExtractTitle[] = new Array(nodes.length);
|
|
90
|
+
// Keep track of nodes with content to extract
|
|
91
|
+
const nodesToExtractTitle: BaseNode[] = [];
|
|
92
|
+
const nodeIndexes: number[] = [];
|
|
93
|
+
|
|
94
|
+
nodes.forEach((node, idx) => {
|
|
95
|
+
const text = node.getContent(this.metadataMode);
|
|
96
|
+
if (!text || text.trim() === '') {
|
|
97
|
+
results[idx] = { documentTitle: '' };
|
|
98
|
+
} else {
|
|
99
|
+
nodesToExtractTitle.push(node);
|
|
100
|
+
nodeIndexes.push(idx);
|
|
101
|
+
}
|
|
102
|
+
});
|
|
103
|
+
|
|
104
|
+
if (nodesToExtractTitle.length) {
|
|
105
|
+
const filteredNodes = this.filterNodes(nodesToExtractTitle);
|
|
106
|
+
if (filteredNodes.length) {
|
|
107
|
+
const nodesByDocument = this.separateNodesByDocument(filteredNodes);
|
|
108
|
+
const titlesByDocument = await this.extractTitles(nodesByDocument);
|
|
109
|
+
filteredNodes.forEach((node, i) => {
|
|
110
|
+
const nodeIndex = nodeIndexes[i];
|
|
111
|
+
const groupKey = node.sourceNode?.nodeId ?? node.id_;
|
|
112
|
+
if (typeof nodeIndex === 'number') {
|
|
113
|
+
results[nodeIndex] = {
|
|
114
|
+
documentTitle: titlesByDocument[groupKey] ?? '',
|
|
115
|
+
};
|
|
116
|
+
}
|
|
117
|
+
});
|
|
118
|
+
}
|
|
119
|
+
}
|
|
120
|
+
return results;
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
private filterNodes(nodes: BaseNode[]): BaseNode[] {
|
|
124
|
+
return nodes.filter(node => {
|
|
125
|
+
if (this.isTextNodeOnly && !(node instanceof TextNode)) {
|
|
126
|
+
return false;
|
|
127
|
+
}
|
|
128
|
+
return true;
|
|
129
|
+
});
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
private separateNodesByDocument(nodes: BaseNode[]): Record<string, BaseNode[]> {
|
|
133
|
+
const nodesByDocument: Record<string, BaseNode[]> = {};
|
|
134
|
+
|
|
135
|
+
for (const node of nodes) {
|
|
136
|
+
const groupKey = node.sourceNode?.nodeId ?? node.id_;
|
|
137
|
+
nodesByDocument[groupKey] = nodesByDocument[groupKey] || [];
|
|
138
|
+
nodesByDocument[groupKey].push(node);
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
return nodesByDocument;
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
private async extractTitles(nodesByDocument: Record<string, BaseNode[]>): Promise<Record<string, string>> {
|
|
145
|
+
const titlesByDocument: Record<string, string> = {};
|
|
146
|
+
|
|
147
|
+
for (const [key, nodes] of Object.entries(nodesByDocument)) {
|
|
148
|
+
const titleCandidates = await this.getTitlesCandidates(nodes);
|
|
149
|
+
const combinedTitles = titleCandidates.join(', ');
|
|
150
|
+
const completion = await this.llm.doGenerate({
|
|
151
|
+
inputFormat: 'messages',
|
|
152
|
+
mode: { type: 'regular' },
|
|
153
|
+
prompt: [
|
|
154
|
+
{
|
|
155
|
+
role: 'user',
|
|
156
|
+
content: [
|
|
157
|
+
{
|
|
158
|
+
type: 'text',
|
|
159
|
+
text: this.combineTemplate.format({
|
|
160
|
+
context: combinedTitles,
|
|
161
|
+
}),
|
|
162
|
+
},
|
|
163
|
+
],
|
|
164
|
+
},
|
|
165
|
+
],
|
|
166
|
+
});
|
|
167
|
+
|
|
168
|
+
let title = '';
|
|
169
|
+
if (typeof completion.text === 'string') {
|
|
170
|
+
title = completion.text.trim();
|
|
171
|
+
} else {
|
|
172
|
+
console.warn('Title extraction LLM output was not a string:', completion.text);
|
|
173
|
+
}
|
|
174
|
+
titlesByDocument[key] = title;
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
return titlesByDocument;
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
private async getTitlesCandidates(nodes: BaseNode[]): Promise<string[]> {
|
|
181
|
+
const titleJobs = nodes.map(async node => {
|
|
182
|
+
const completion = await this.llm.doGenerate({
|
|
183
|
+
inputFormat: 'messages',
|
|
184
|
+
mode: { type: 'regular' },
|
|
185
|
+
prompt: [
|
|
186
|
+
{
|
|
187
|
+
role: 'user',
|
|
188
|
+
content: [
|
|
189
|
+
{
|
|
190
|
+
type: 'text',
|
|
191
|
+
text: this.nodeTemplate.format({
|
|
192
|
+
context: node.getContent(MetadataMode.ALL),
|
|
193
|
+
}),
|
|
194
|
+
},
|
|
195
|
+
],
|
|
196
|
+
},
|
|
197
|
+
],
|
|
198
|
+
});
|
|
199
|
+
|
|
200
|
+
if (typeof completion.text === 'string') {
|
|
201
|
+
return completion.text.trim();
|
|
202
|
+
} else {
|
|
203
|
+
console.warn('Title candidate extraction LLM output was not a string:', completion.text);
|
|
204
|
+
return '';
|
|
205
|
+
}
|
|
206
|
+
});
|
|
207
|
+
|
|
208
|
+
return await Promise.all(titleJobs);
|
|
209
|
+
}
|
|
210
|
+
}
|