@dxos/plugin-transformer 0.7.5-main.b19bfc8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +8 -0
- package/README.md +15 -0
- package/dist/lib/browser/index.mjs +52 -0
- package/dist/lib/browser/index.mjs.map +7 -0
- package/dist/lib/browser/meta.json +1 -0
- package/dist/lib/browser/types/index.mjs +1 -0
- package/dist/lib/browser/types/index.mjs.map +7 -0
- package/dist/lib/node/index.cjs +71 -0
- package/dist/lib/node/index.cjs.map +7 -0
- package/dist/lib/node/meta.json +1 -0
- package/dist/lib/node/types/index.cjs +2 -0
- package/dist/lib/node/types/index.cjs.map +7 -0
- package/dist/lib/node-esm/index.mjs +54 -0
- package/dist/lib/node-esm/index.mjs.map +7 -0
- package/dist/lib/node-esm/meta.json +1 -0
- package/dist/lib/node-esm/types/index.mjs +2 -0
- package/dist/lib/node-esm/types/index.mjs.map +7 -0
- package/dist/types/src/TransformerPlugin.d.ts +2 -0
- package/dist/types/src/TransformerPlugin.d.ts.map +1 -0
- package/dist/types/src/capabilities/index.d.ts +1 -0
- package/dist/types/src/capabilities/index.d.ts.map +1 -0
- package/dist/types/src/components/DebugInfo.d.ts +14 -0
- package/dist/types/src/components/DebugInfo.d.ts.map +1 -0
- package/dist/types/src/components/Voice.d.ts +7 -0
- package/dist/types/src/components/Voice.d.ts.map +1 -0
- package/dist/types/src/components/Voice.stories.d.ts +8 -0
- package/dist/types/src/components/Voice.stories.d.ts.map +1 -0
- package/dist/types/src/hooks/index.d.ts +3 -0
- package/dist/types/src/hooks/index.d.ts.map +1 -0
- package/dist/types/src/hooks/useAudioStream.d.ts +12 -0
- package/dist/types/src/hooks/useAudioStream.d.ts.map +1 -0
- package/dist/types/src/hooks/usePipeline.d.ts +41 -0
- package/dist/types/src/hooks/usePipeline.d.ts.map +1 -0
- package/dist/types/src/index.d.ts +3 -0
- package/dist/types/src/index.d.ts.map +1 -0
- package/dist/types/src/meta.d.ts +10 -0
- package/dist/types/src/meta.d.ts.map +1 -0
- package/dist/types/src/testing/model.test.d.ts +1 -0
- package/dist/types/src/testing/model.test.d.ts.map +1 -0
- package/dist/types/src/testing/node-pipeline.d.ts +12 -0
- package/dist/types/src/testing/node-pipeline.d.ts.map +1 -0
- package/dist/types/src/testing/pipeline.d.ts +28 -0
- package/dist/types/src/testing/pipeline.d.ts.map +1 -0
- package/dist/types/src/testing/pipeline.test.d.ts +2 -0
- package/dist/types/src/testing/pipeline.test.d.ts.map +1 -0
- package/dist/types/src/testing/web-pipeline.d.ts +12 -0
- package/dist/types/src/testing/web-pipeline.d.ts.map +1 -0
- package/dist/types/src/translations.d.ts +9 -0
- package/dist/types/src/translations.d.ts.map +1 -0
- package/dist/types/src/types/index.d.ts +1 -0
- package/dist/types/src/types/index.d.ts.map +1 -0
- package/dist/types/tsconfig.tsbuildinfo +1 -0
- package/package.json +80 -0
- package/src/TransformerPlugin.tsx +34 -0
- package/src/capabilities/index.ts +3 -0
- package/src/components/DebugInfo.tsx +79 -0
- package/src/components/Voice.stories.tsx +32 -0
- package/src/components/Voice.tsx +110 -0
- package/src/hooks/index.ts +6 -0
- package/src/hooks/useAudioStream.ts +252 -0
- package/src/hooks/usePipeline.ts +153 -0
- package/src/index.ts +7 -0
- package/src/meta.ts +16 -0
- package/src/testing/model.test.ts +3 -0
- package/src/testing/node-pipeline.ts +35 -0
- package/src/testing/pipeline.test.ts +90 -0
- package/src/testing/pipeline.ts +74 -0
- package/src/testing/web-pipeline.ts +46 -0
- package/src/translations.ts +15 -0
- package/src/types/index.ts +3 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2025 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { type Pipeline, pipeline, env } from '@xenova/transformers';
|
|
6
|
+
import { useState, useRef, useEffect } from 'react';
|
|
7
|
+
|
|
8
|
+
import { invariant } from '@dxos/invariant';
|
|
9
|
+
import { log } from '@dxos/log';
|
|
10
|
+
|
|
11
|
+
/**
 * Minimal ambient WebGPU typings.
 *
 * Only the members this module touches are declared: `navigator.gpu.requestAdapter()`,
 * `GPUAdapter.requestAdapterInfo()` and the descriptive fields of `GPUAdapterInfo`.
 * NOTE(review): `requestAdapterInfo()` may be missing in newer browsers — confirm against the WebGPU spec.
 */
declare global {
  /** Descriptive strings reported for a GPU adapter. */
  interface GPUAdapterInfo {
    vendor: string;
    architecture: string;
    description: string;
  }

  /** Handle to a GPU adapter; only the info query is declared here. */
  interface GPUAdapter {
    requestAdapterInfo(): Promise<GPUAdapterInfo>;
  }

  /** Adds the optional `gpu` entry point, which is absent when WebGPU is unsupported. */
  interface Navigator {
    gpu?: {
      requestAdapter(): Promise<GPUAdapter | null>;
    };
  }
}
|
|
29
|
+
|
|
30
|
+
// Configure cache and runtime settings (applied once, as a module-level side effect on import).
env.allowLocalModels = true;
env.cacheDir = './.cache';

// Configure ONNX runtime for WebGPU.
{
  // The backend settings object is accessed untyped, hence the `any` cast.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const onnxBackend = env.backends.onnx as any;
  onnxBackend.wasm.numThreads = 1;
  onnxBackend.provider = 'webgpu';
  // NOTE(review): profiling is switched on unconditionally — confirm this is intended outside of debugging.
  onnxBackend.webgpu = { profilingMode: true };
}
|
|
40
|
+
|
|
41
|
+
/**
 * Options accepted by `usePipeline`.
 */
export type PipelineConfig = {
  /** Model identifier handed to the transformers `pipeline()` factory. */
  model: string;

  /** The model is only loaded while this is truthy. */
  active?: boolean;

  /** When truthy, model loading progress is logged. */
  debug?: boolean;
};
|
|
46
|
+
|
|
47
|
+
/**
 * Loading state reported by `usePipeline`.
 */
export type PipelineState = {
  /** True while a model load is in progress. */
  isLoading: boolean;

  /** True once the model has been loaded; reset to false when the hook's effect is cleaned up. */
  isLoaded: boolean;

  /** Message describing the last load failure, or `null` when there is none. */
  error: string | null;

  /** Human-readable summary of WebGPU availability (empty until detection has run). */
  gpuInfo: string;
};
|
|
53
|
+
|
|
54
|
+
// TODO(burdon): Document external API.
/**
 * Options forwarded unchanged to the loaded pipeline by `transcribe`.
 *
 * The snake_case names are kept as-is because the object is passed straight through.
 * NOTE(review): units are presumably Hz for `sampling_rate` and seconds for the `_s` fields —
 * confirm against the transformers.js documentation.
 */
export type TranscriptionOptions = {
  language: string;
  sampling_rate: number;
  chunk_length_s: number;
  stride_length_s: number;
  return_timestamps: boolean;
};
|
|
62
|
+
|
|
63
|
+
/**
 * React hook that loads an automatic-speech-recognition pipeline and exposes a `transcribe` function.
 *
 * While `active` is truthy the model named by `model` is loaded; it is reloaded when either value changes.
 * A load that is still in flight when the effect is cleaned up (unmount, or `active`/`model` changed)
 * is abandoned: its result is neither stored nor reflected in the returned state.
 *
 * Returns the current `PipelineState` fields plus `transcribe`.
 */
export const usePipeline = ({ active, model, debug }: PipelineConfig) => {
  const [state, setState] = useState<PipelineState>({
    gpuInfo: '',
    isLoaded: false,
    isLoading: false,
    error: null,
  });

  const pipelineRef = useRef<Pipeline | null>(null);

  // `debug` is read through a ref so that toggling it does not re-run the effect and reload the model.
  const debugRef = useRef(debug);
  debugRef.current = debug;

  // TODO(burdon): Factor out loading model. Separate tests.
  useEffect(() => {
    // Set by the cleanup below; once true, this run must not touch the ref or the state any more.
    let cancelled = false;

    const update = (patch: Partial<PipelineState>) => {
      if (!cancelled) {
        setState((prev) => ({ ...prev, ...patch }));
      }
    };

    // Probes WebGPU and returns a human-readable summary; never throws.
    const detectGpu = async (): Promise<string> => {
      if (!navigator.gpu) {
        log.warn('WebGPU is not supported, falling back to CPU');
        return 'WebGPU not supported (using CPU)';
      }

      try {
        const adapter = await navigator.gpu.requestAdapter();
        if (!adapter) {
          throw new Error('No GPU adapter found');
        }

        // Try to get adapter info if available.
        // NOTE(review): if `requestAdapterInfo` is missing in the browser, the call throws and is reported below.
        try {
          const adapterInfo = await adapter.requestAdapterInfo();
          return adapterInfo
            ? `${adapterInfo.description || 'GPU'} (${adapterInfo.vendor || 'Unknown'})`
            : 'GPU Available (details unknown)';
        } catch (err) {
          log.warn('could not get GPU info', { err });
          return 'GPU Available (details unavailable)';
        }
      } catch (err) {
        log.warn('WebGPU initialization failed', { err });
        return 'GPU initialization failed (using CPU)';
      }
    };

    const loadModel = async () => {
      try {
        update({ isLoading: true, error: null });

        // Check WebGPU support.
        update({ gpuInfo: await detectGpu() });

        const pipe = await pipeline('automatic-speech-recognition', model, {
          quantized: true,
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          progress_callback: (progress: any) => {
            // Only download-progress events carry a numeric `progress`, and it is already a percentage (0-100).
            if (debugRef.current && typeof progress?.progress === 'number') {
              log(`loading model: ${Math.round(progress.progress)}%`);
            }
          },
        });

        if (cancelled) {
          // A newer effect run (or unmount) superseded this load; drop the result.
          return;
        }

        pipelineRef.current = pipe;
        update({ isLoaded: true, isLoading: false });
        log.info('model loaded successfully');
      } catch (err) {
        log.error('error loading model', { err });
        const message = err instanceof Error ? err.message : String(err);
        update({ isLoading: false, error: 'error loading model: ' + message });
      }
    };

    if (active) {
      void loadModel();
    }

    return () => {
      cancelled = true;
      pipelineRef.current = null;
      // Also clear `isLoading`, otherwise an abandoned in-flight load would leave it stuck at true.
      setState((prev) => ({ ...prev, isLoaded: false, isLoading: false }));
    };
  }, [active, model]);

  /**
   * Runs the loaded pipeline on the given audio samples, passing `options` through unchanged.
   * Fails the `invariant` check if no pipeline is currently loaded.
   */
  const transcribe = async (audioData: Float32Array, options: TranscriptionOptions) => {
    invariant(pipelineRef.current, 'pipeline not initialized');
    const result = await pipelineRef.current(audioData, options);
    return result;
  };

  return {
    ...state,
    transcribe,
  };
};
|
package/src/index.ts
ADDED
package/src/meta.ts
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2023 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { type PluginMeta } from '@dxos/app-framework';
|
|
6
|
+
|
|
7
|
+
/** Identifier of the transformer plugin; also used as `meta.id`. */
export const TRANSFORMER_PLUGIN = 'dxos.org/plugin/transformer';

/**
 * Static metadata describing the plugin.
 * `satisfies` validates the shape against `PluginMeta` while keeping the literal types of the fields.
 */
export const meta = {
  id: TRANSFORMER_PLUGIN,
  name: 'Transformer',
  description: 'Run local transformers.',
  icon: 'ph--cpu--regular',
  source: 'https://github.com/dxos/dxos/tree/main/packages/plugins/experimental/plugin-transformer',
  tags: ['experimental'],
} satisfies PluginMeta;
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2025 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { pipeline as xenovaPipeline } from '@xenova/transformers';
|
|
6
|
+
|
|
7
|
+
import { type EmbeddingOutput, RagPipeline } from './pipeline';
|
|
8
|
+
|
|
9
|
+
/**
 * Node implementation using Xenova transformers.
 *
 * Each pipeline is created lazily on first use and then shared by all later calls, so the
 * model is not re-instantiated for every text that is embedded or every prompt that is run.
 */
export class NodeRagPipeline extends RagPipeline {
  /**
   * Wraps an async factory so that it runs at most once per successful result.
   * Concurrent callers share the same in-flight promise; a failed attempt is forgotten so a later call can retry.
   */
  private static memoize<T>(factory: () => Promise<T>): () => Promise<T> {
    let pending: Promise<T> | undefined;
    return () => {
      if (!pending) {
        pending = factory().catch((err: unknown) => {
          pending = undefined;
          throw err;
        });
      }
      return pending;
    };
  }

  private readonly embeddingModel: string;
  private readonly textModel: string;

  // The factories read the model names when first invoked, i.e. after the constructor has assigned them.
  private readonly loadEmbedder = NodeRagPipeline.memoize(() =>
    xenovaPipeline('feature-extraction', this.embeddingModel),
  );

  private readonly loadGenerator = NodeRagPipeline.memoize(() =>
    xenovaPipeline('text2text-generation', this.textModel),
  );

  /**
   * @param embeddingModel Model used for feature extraction.
   * @param textModel Model used for text2text generation.
   */
  constructor(embeddingModel = 'Xenova/all-MiniLM-L6-v2', textModel = 'Xenova/t5-small') {
    super();
    this.embeddingModel = embeddingModel;
    this.textModel = textModel;
  }

  /**
   * Embeds the text with mean pooling and normalization and returns the values as a plain array.
   */
  protected async generateEmbeddings(text: string | string[]): Promise<EmbeddingOutput> {
    const embedder = await this.loadEmbedder();
    const result = await embedder(text, { pooling: 'mean', normalize: true });
    return { data: Array.from(result.data) };
  }

  /**
   * Generates a single sequence of at most 50 new tokens for the prompt.
   */
  protected async generateText(prompt: string) {
    const generator = await this.loadGenerator();
    return generator(prompt, {
      num_return_sequences: 1,
      max_new_tokens: 50,
    });
  }
}
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2025 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { AutoTokenizer, pipeline } from '@huggingface/transformers';
|
|
6
|
+
import { describe, test } from 'vitest';
|
|
7
|
+
|
|
8
|
+
import { log } from '@dxos/log';
|
|
9
|
+
|
|
10
|
+
import { NodeRagPipeline } from './node-pipeline';
|
|
11
|
+
|
|
12
|
+
// Smoke tests for the transformers library itself; all are skipped by default.
describe('transformers', () => {
  test.skip('tokenizer', async () => {
    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
    const tokens = await tokenizer('I love transformers!');
    log.info('tokens', { tokens });
  });

  test.skip('sentiment', async ({ expect }) => {
    const sentiment = await pipeline('sentiment-analysis');
    const result = await sentiment('I love transformers!');
    expect(result[0]).to.include({ label: 'POSITIVE' });
    log.info('result', { result });
  }, 30_000);

  // No assertions yet: the output is only logged, so the test context is not destructured.
  test.skip('run embeddings', async () => {
    const content = [
      'create map of london',
      'create function to generate the fibonacci sequence',
      'create a table with',
    ];

    const embedding = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
    const output = await embedding(content.join('\n'), { pooling: 'mean', normalize: true });
    log.info('output', { output });
  });

  // No assertions yet: the output is only logged, so the test context is not destructured.
  test.skip('generation', async () => {
    const generator = await pipeline('text-generation', 'Xenova/gpt2');
    const output = await generator('create');
    log.info('output', { output });
  });
});
|
|
44
|
+
|
|
45
|
+
/**
 * Note: The RAG pipeline is designed for browser environments with WebGPU support.
 * The Xenova transformers library primarily targets browser environments and may not
 * work correctly in Node.js/Vitest testing environments.
 *
 * To test this functionality:
 * 1. Use in a browser environment
 * 2. Ensure WebGPU is available
 * 3. Consider using a test framework that supports browser environments (like Playwright or Cypress)
 */
describe('rag pipeline', () => {
  // Commands used as the retrieval corpus for the (skipped) prediction test.
  const knowledgeBase = [
    //
    'create a new table.',
    'create a new kanban.',
    'create a map of london.',
    'create a function to generate the fibonacci sequence.',
  ];

  test.skip('prediction pipeline', async ({ expect }) => {
    const ragPipeline = new NodeRagPipeline();
    const completions = await ragPipeline.generateCompletions('create a function', knowledgeBase);

    expect(completions).toBeDefined();
    expect(Array.isArray(completions)).toBe(true);
    expect(completions.length).toBeGreaterThan(0);

    log.info('result', { result: completions });
  }, 120_000);

  test('semantic similarity with local pipeline', async ({ expect }) => {
    const semanticKnowledgeBase = [
      'The quick brown fox jumps over the lazy dog.',
      'JavaScript is a programming language.',
      'Python is used for data science.',
    ];

    const ragPipeline = new NodeRagPipeline();
    const completions = await ragPipeline.generateCompletions('What can I use for coding?', semanticKnowledgeBase);

    expect(completions).toBeDefined();
    expect(Array.isArray(completions)).toBe(true);

    log.info('result', { result: completions });
  }, 120_000);
});
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2025 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { log } from '@dxos/log';
|
|
6
|
+
|
|
7
|
+
/**
 * Result of an embedding request.
 */
export type EmbeddingOutput = {
  /** Embedding values as a plain `number[]`; implementations convert typed arrays before returning. */
  data: number[];
};
|
|
13
|
+
|
|
14
|
+
/**
 * Base class for RAG pipeline implementations.
 *
 * Subclasses supply the embedding and text-generation back ends; this class ranks the knowledge
 * base against the input by cosine similarity and builds the prompt from the best matches.
 */
export abstract class RagPipeline {
  /**
   * Generate embeddings for a text or array of texts.
   */
  protected abstract generateEmbeddings(text: string | string[]): Promise<EmbeddingOutput>;

  /**
   * Generate text completions given a prompt.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  protected abstract generateText(prompt: string): Promise<any>;

  /**
   * Calculate cosine similarity between two vectors.
   *
   * @returns The cosine of the angle between `a` and `b`, or 0 when either vector has zero
   *   magnitude (the value is undefined there and would otherwise be NaN, which breaks sorting).
   * @throws Error if the vectors have different lengths.
   */
  protected cosineSimilarity(a: Float32Array | number[], b: Float32Array | number[]): number {
    if (a.length !== b.length) {
      throw new Error(`cannot compare vectors of different lengths: ${a.length} vs ${b.length}`);
    }

    let dotProduct = 0;
    let squaredA = 0;
    let squaredB = 0;
    for (let i = 0; i < a.length; i++) {
      const x = a[i];
      const y = b[i];
      dotProduct += x * y;
      squaredA += x * x;
      squaredB += y * y;
    }

    const denominator = Math.sqrt(squaredA) * Math.sqrt(squaredB);
    return denominator === 0 ? 0 : dotProduct / denominator;
  }

  /**
   * Generate completions using the RAG pipeline.
   *
   * @param input User input to complete.
   * @param knowledgeBase Candidate context strings; each one is embedded and scored against `input`.
   * @param topK Maximum number of best-scoring contexts included in the prompt (defaults to 3).
   * @returns Whatever the subclass's `generateText` returns for the assembled prompt.
   */
  async generateCompletions(input: string, knowledgeBase: string[], topK = 3) {
    // Generate embeddings for input and knowledge base.
    const inputEmbedding = await this.generateEmbeddings(input);
    const knowledgeEmbeddings = await Promise.all(knowledgeBase.map((text) => this.generateEmbeddings(text)));

    // Rank the knowledge base by similarity to the input and keep the most relevant contexts.
    const topContexts = knowledgeEmbeddings
      .map((emb, idx) => ({
        score: this.cosineSimilarity(inputEmbedding.data, emb.data),
        text: knowledgeBase[idx],
      }))
      .sort((a, b) => b.score - a.score)
      .slice(0, Math.max(0, topK))
      .map((item) => item.text);

    log.info('topContexts', { topContexts });

    // Generate completions using retrieved context.
    // The prompt previously was a fixed placeholder, which discarded both the context and the input.
    const prompt = `Context: ${topContexts.join('\n')}\nQuery: ${input}\nResponse:`;

    log.info('prompt', { prompt });
    return this.generateText(prompt);
  }
}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
//
|
|
2
|
+
// Copyright 2025 DXOS.org
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
import { pipeline } from '@huggingface/transformers';
|
|
6
|
+
|
|
7
|
+
import { type EmbeddingOutput, RagPipeline } from './pipeline';
|
|
8
|
+
|
|
9
|
+
// TODO(burdon): Workers.
|
|
10
|
+
// https://huggingface.co/docs/transformers.js/tutorials/react#step-3-design-the-user-interface
|
|
11
|
+
// TODO(burdon): WebGPU.
|
|
12
|
+
// https://huggingface.co/docs/transformers.js/guides/webgpu
|
|
13
|
+
|
|
14
|
+
/**
 * Web-based implementation using Hugging Face transformers.
 *
 * Each pipeline is created lazily on first use and then shared by all later calls, so the
 * model is not re-instantiated for every text that is embedded or every prompt that is run.
 */
export class WebRagPipeline extends RagPipeline {
  /**
   * Wraps an async factory so that it runs at most once per successful result.
   * Concurrent callers share the same in-flight promise; a failed attempt is forgotten so a later call can retry.
   */
  private static memoize<T>(factory: () => Promise<T>): () => Promise<T> {
    let pending: Promise<T> | undefined;
    return () => {
      if (!pending) {
        pending = factory().catch((err: unknown) => {
          pending = undefined;
          throw err;
        });
      }
      return pending;
    };
  }

  private readonly embeddingModel: string;
  private readonly textModel: string;

  // The factories read the model names when first invoked, i.e. after the constructor has assigned them.
  private readonly loadEmbedder = WebRagPipeline.memoize(() =>
    pipeline('feature-extraction', this.embeddingModel, {
      revision: 'main',
    }),
  );

  private readonly loadGenerator = WebRagPipeline.memoize(() =>
    pipeline('text-generation', this.textModel, {
      revision: 'main',
    }),
  );

  /**
   * @param embeddingModel Model used for feature extraction.
   * @param textModel Model used for text generation.
   */
  constructor(embeddingModel = 'Xenova/all-MiniLM-L6-v2', textModel = 'Xenova/distilgpt2') {
    super();
    this.embeddingModel = embeddingModel;
    this.textModel = textModel;
  }

  /**
   * Embeds the text with mean pooling and normalization and returns the values as a plain array.
   */
  protected async generateEmbeddings(text: string | string[]): Promise<EmbeddingOutput> {
    const embedder = await this.loadEmbedder();
    const result = await embedder(text, { pooling: 'mean', normalize: true });
    // Convert DataArray to number[]
    return { data: Array.from(result.data) };
  }

  /**
   * Generates three candidate sequences of at most 50 new tokens each, at temperature 0.7.
   */
  protected async generateText(prompt: string) {
    const generator = await this.loadGenerator();
    return generator(prompt, {
      max_new_tokens: 50,
      num_return_sequences: 3,
      temperature: 0.7,
    });
  }
}
|