ai-compare-candidates 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.editorconfig +51 -0
- package/.vscode/settings.json +3 -0
- package/.yarnrc.yml +16 -0
- package/LICENSE +6 -0
- package/README.md +77 -0
- package/dist/index.cjs +24 -0
- package/dist/index.cjs.map +1 -0
- package/dist/index.d.cts +134 -0
- package/dist/index.d.cts.map +1 -0
- package/dist/index.d.mts +134 -0
- package/dist/index.d.mts.map +1 -0
- package/dist/index.mjs +24 -0
- package/dist/index.mjs.map +1 -0
- package/example/.editorconfig +51 -0
- package/example/.vscode/extensions.json +13 -0
- package/example/.vscode/settings.json +5 -0
- package/example/README.md +21 -0
- package/example/index.html +21 -0
- package/example/package.json +37 -0
- package/example/postcss.config.js +29 -0
- package/example/public/favicon.ico +0 -0
- package/example/public/icons/favicon-128x128.png +0 -0
- package/example/public/icons/favicon-16x16.png +0 -0
- package/example/public/icons/favicon-32x32.png +0 -0
- package/example/public/icons/favicon-96x96.png +0 -0
- package/example/quasar.config.ts +222 -0
- package/example/src/App.vue +5 -0
- package/example/src/boot/electronHuggingFaceFix.ts +8 -0
- package/example/src/boot/icons.ts +20 -0
- package/example/src/css/app.scss +1 -0
- package/example/src/css/quasar.variables.scss +25 -0
- package/example/src/env.d.ts +7 -0
- package/example/src/layouts/app.vue +147 -0
- package/example/src/router/index.ts +37 -0
- package/example/src/router/routes.ts +8 -0
- package/example/src/stores/index.ts +32 -0
- package/example/src/stores/store.ts +19 -0
- package/example/tsconfig.json +3 -0
- package/package.json +55 -0
- package/src/index.ts +478 -0
- package/tsconfig.json +12 -0
- package/tsconfig.node.json +12 -0
- package/tsconfig.tsbuildinfo +1 -0
- package/tsdown.config.ts +12 -0
package/src/index.ts
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
import{
|
|
2
|
+
env,
|
|
3
|
+
pipeline,
|
|
4
|
+
AutoTokenizer,
|
|
5
|
+
TextGenerationPipeline,
|
|
6
|
+
ProgressInfo,
|
|
7
|
+
ProgressCallback,
|
|
8
|
+
SummarizationPipeline,
|
|
9
|
+
FeatureExtractionPipeline,
|
|
10
|
+
PreTrainedTokenizer,
|
|
11
|
+
TextGenerationConfig
|
|
12
|
+
}from '@huggingface/transformers';
|
|
13
|
+
import {MemoryVectorStore} from '@langchain/classic/vectorstores/memory';
|
|
14
|
+
import {Embeddings} from '@langchain/core/embeddings';
|
|
15
|
+
import lodash from 'lodash';
|
|
16
|
+
import jsan from 'jsan';
|
|
17
|
+
|
|
18
|
+
export class AICompareCandidates extends Embeddings{
|
|
19
|
+
/** The `env` object imported from `@huggingface/transformers`; the static block below configures it. */
readonly env = env;
/** When true, progress events, prompts and intermediate results are written with `console.log`. */
DEBUG = true;

//Text-generation pipeline state, managed by loadGenerator() and checkGeneratorLoaded()
generator: TextGenerationPipeline | null = null;
generatorModelName = 'Xenova/LaMini-GPT-774M';
generatorPromise: Promise<TextGenerationPipeline> | null = null;
/** Progress events of the generator load are merged into this object. */
generatorProgressInfo: ProgressInfo = {} as ProgressInfo;
generatorProgressCallback: ProgressCallback | null = null;

//Summarisation pipeline state, managed by loadSummariser() and checkSummariserLoaded()
summariser: SummarizationPipeline | null = null;
summariserModelName = 'Xenova/distilbart-cnn-12-6';
summariserPromise: Promise<SummarizationPipeline> | null = null;
/** Progress events of the summariser load are merged into this object. */
summariserProgressInfo: ProgressInfo = {} as ProgressInfo;
summariserProgressCallback: ProgressCallback | null = null;

//Feature-extraction pipeline state, managed by loadEmbedder() and checkEmbedderLoaded()
embedder: FeatureExtractionPipeline | null = null;
embedderModelName = 'Xenova/all-MiniLM-L12-v2';
embedderPromise: Promise<FeatureExtractionPipeline> | null = null;
/** Progress events of the embedder load are merged into this object. */
embedderProgressInfo: ProgressInfo = {} as ProgressInfo;
embedderProgressCallback: ProgressCallback | null = null;

//Tokeniser state, managed by loadTokeniser() and checkTokeniserLoaded()
tokeniser: PreTrainedTokenizer | null = null;
/** Starts out as the generator model name; this initialiser must stay below `generatorModelName`. */
tokeniserModelName = this.generatorModelName;
tokeniserPromise: Promise<PreTrainedTokenizer> | null = null;
/** Progress events of the tokeniser load are merged into this object. */
tokeniserProgressInfo: ProgressInfo = {} as ProgressInfo;
tokeniserProgressCallback: ProgressCallback | null = null;

//Generation options passed to the generator for the search areas prompt in compareCandidates()
generateSearchAreasMaxNewTokens = 64;
generateSearchAreasTemperature = 0.35;
generateSearchAreasRepetitionPenalty = 1.5;

//Generation options passed to the generator for the ranking prompt in compareCandidates()
rankingMaxNewTokens = 64;
rankingTemperature = 0.35;
rankingRepetitionPenalty = 1.5;

/** compareCandidates() summarises shortlisted documents that have more whitespace-separated words than this. */
targetSummarisedStringTokenCount = 420;

//Runs once when the class is defined: empty local model path, remote models allowed, local models disallowed
static {
    Object.assign(env, {
        localModelPath: '',
        allowRemoteModels: true,
        allowLocalModels: false
    });
}

/** Creates an instance, passing an empty parameter object to the `Embeddings` base class. */
constructor() {
    super({});
}
|
|
65
|
+
|
|
66
|
+
/**
 * Loads the text-generation pipeline on WebGPU and stores it in `this.generator`.
 *
 * @param progressCallback Optional callback invoked with every progress event; once given it is kept for later loads.
 * @param modelName Optional model name; a non-empty string replaces `this.generatorModelName`.
 * @returns The loaded pipeline.
 * @throws Error when no generator model name is set. Errors raised while loading are rethrown.
 */
async loadGenerator({
    progressCallback,
    modelName = ''
}: AICompareCandidates.LoadArguments = <AICompareCandidates.LoadArguments>{}) {
    if (typeof modelName === 'string' && modelName) this.generatorModelName = modelName;
    if (!this.generatorModelName) throw new Error('Invalid generator model name');
    if (progressCallback) this.generatorProgressCallback = progressCallback;
    //ts-ignore is needed due to frequent error TS2590: Expression produces a union type that is too complex to represent.
    //@ts-ignore
    this.generatorPromise = pipeline('text-generation', this.generatorModelName, {
        device: 'webgpu',
        progress_callback: progressInfo => {
            if (this.DEBUG) console.log(jsan.stringify(progressInfo));
            //Keep the latest progress fields available on the instance
            Object.assign(this.generatorProgressInfo, progressInfo);
            return this.generatorProgressCallback?.(progressInfo);
        }
    });
    const pendingLoad = this.generatorPromise;
    try {
        this.generator = await pendingLoad;
    } catch (error) {
        //Forget the rejected promise (unless a newer load replaced it) so that checkGeneratorLoaded() can retry
        if (this.generatorPromise === pendingLoad) this.generatorPromise = null;
        throw error;
    }
    return this.generator;
}
|
|
86
|
+
|
|
87
|
+
/**
 * Makes sure `this.generator` is available, starting a load with the current settings if none was started yet.
 *
 * @throws Error when the generator is still missing afterwards. Load errors are rethrown to the caller.
 */
async checkGeneratorLoaded() {
    //Await the load started here so that a failure reaches the caller instead of becoming an unhandled rejection
    if (!this.generatorPromise) await this.loadGenerator();
    //A load started elsewhere is still in flight: wait for it to settle
    else if (!this.generator) await this.generatorPromise;
    if (!this.generator) throw new Error('Unable to load generator');
}
|
|
92
|
+
|
|
93
|
+
/**
 * Loads the summarisation pipeline on WebGPU and stores it in `this.summariser`.
 *
 * @param progressCallback Optional callback invoked with every progress event; once given it is kept for later loads.
 * @param modelName Optional model name; a non-empty string replaces `this.summariserModelName`.
 * @returns The loaded pipeline.
 * @throws Error when no summariser model name is set. Errors raised while loading are rethrown.
 */
async loadSummariser({
    progressCallback,
    modelName = ''
}: AICompareCandidates.LoadArguments = <AICompareCandidates.LoadArguments>{}) {
    if (typeof modelName === 'string' && modelName) this.summariserModelName = modelName;
    if (!this.summariserModelName) throw new Error('Invalid summariser model name');
    if (progressCallback) this.summariserProgressCallback = progressCallback;
    //ts-ignore is needed due to frequent error TS2590: Expression produces a union type that is too complex to represent.
    //@ts-ignore
    this.summariserPromise = pipeline('summarization', this.summariserModelName, {
        device: 'webgpu',
        progress_callback: progressInfo => {
            if (this.DEBUG) console.log(jsan.stringify(progressInfo));
            //Keep the latest progress fields available on the instance
            Object.assign(this.summariserProgressInfo, progressInfo);
            return this.summariserProgressCallback?.(progressInfo);
        }
    });
    const pendingLoad = this.summariserPromise;
    try {
        this.summariser = await pendingLoad;
    } catch (error) {
        //Forget the rejected promise (unless a newer load replaced it) so that checkSummariserLoaded() can retry
        if (this.summariserPromise === pendingLoad) this.summariserPromise = null;
        throw error;
    }
    return this.summariser;
}
|
|
113
|
+
|
|
114
|
+
/**
 * Makes sure `this.summariser` is available, starting a load with the current settings if none was started yet.
 *
 * @throws Error when the summariser is still missing afterwards. Load errors are rethrown to the caller.
 */
async checkSummariserLoaded() {
    //Await the load started here so that a failure reaches the caller instead of becoming an unhandled rejection
    if (!this.summariserPromise) await this.loadSummariser();
    //A load started elsewhere is still in flight: wait for it to settle
    else if (!this.summariser) await this.summariserPromise;
    if (!this.summariser) throw new Error('Unable to load summariser');
}
|
|
119
|
+
|
|
120
|
+
/**
 * Loads the feature-extraction pipeline on WebGPU and stores it in `this.embedder`.
 *
 * @param progressCallback Optional callback invoked with every progress event; once given it is kept for later loads.
 * @param modelName Optional model name; a non-empty string replaces `this.embedderModelName`.
 * @returns The loaded pipeline.
 * @throws Error when no embedder model name is set. Errors raised while loading are rethrown.
 */
async loadEmbedder({
    progressCallback,
    modelName = ''
}: AICompareCandidates.LoadArguments = <AICompareCandidates.LoadArguments>{}) {
    if (typeof modelName === 'string' && modelName) this.embedderModelName = modelName;
    if (!this.embedderModelName) throw new Error('Invalid embedder model name');
    if (progressCallback) this.embedderProgressCallback = progressCallback;
    //ts-ignore is needed due to frequent error TS2590: Expression produces a union type that is too complex to represent.
    //@ts-ignore
    this.embedderPromise = pipeline('feature-extraction', this.embedderModelName, {
        device: 'webgpu',
        progress_callback: progressInfo => {
            if (this.DEBUG) console.log(jsan.stringify(progressInfo));
            //Keep the latest progress fields available on the instance
            Object.assign(this.embedderProgressInfo, progressInfo);
            return this.embedderProgressCallback?.(progressInfo);
        }
    });
    const pendingLoad = this.embedderPromise;
    try {
        this.embedder = await pendingLoad;
    } catch (error) {
        //Forget the rejected promise (unless a newer load replaced it) so that checkEmbedderLoaded() can retry
        if (this.embedderPromise === pendingLoad) this.embedderPromise = null;
        throw error;
    }
    return this.embedder;
}
|
|
140
|
+
|
|
141
|
+
/**
 * Makes sure `this.embedder` is available, starting a load with the current settings if none was started yet.
 *
 * @throws Error when the embedder is still missing afterwards. Load errors are rethrown to the caller.
 */
async checkEmbedderLoaded() {
    //Await the load started here so that a failure reaches the caller instead of becoming an unhandled rejection
    if (!this.embedderPromise) await this.loadEmbedder();
    //A load started elsewhere is still in flight: wait for it to settle
    else if (!this.embedder) await this.embedderPromise;
    if (!this.embedder) throw new Error('Unable to load embedder');
}
|
|
146
|
+
|
|
147
|
+
/**
 * Loads the tokeniser with `AutoTokenizer.from_pretrained` and stores it in `this.tokeniser`.
 *
 * @param progressCallback Optional callback invoked with every progress event; once given it is kept for later loads.
 * @param modelName Optional model name; a non-empty string replaces `this.tokeniserModelName`.
 * @returns The loaded tokeniser.
 * @throws Error when no tokeniser model name is set. Errors raised while loading are rethrown.
 */
async loadTokeniser({
    progressCallback,
    modelName = ''
}: AICompareCandidates.LoadArguments = <AICompareCandidates.LoadArguments>{}) {
    if (typeof modelName === 'string' && modelName) this.tokeniserModelName = modelName;
    if (!this.tokeniserModelName) throw new Error('Invalid tokeniser model name');
    if (progressCallback) this.tokeniserProgressCallback = progressCallback;
    //ts-ignore is needed due to frequent error TS2590: Expression produces a union type that is too complex to represent.
    //@ts-ignore
    this.tokeniserPromise = AutoTokenizer.from_pretrained(this.tokeniserModelName, {
        progress_callback: progressInfo => {
            if (this.DEBUG) console.log(jsan.stringify(progressInfo));
            //Keep the latest progress fields available on the instance
            Object.assign(this.tokeniserProgressInfo, progressInfo);
            return this.tokeniserProgressCallback?.(progressInfo);
        }
    });
    const pendingLoad = this.tokeniserPromise;
    try {
        this.tokeniser = await pendingLoad;
    } catch (error) {
        //Forget the rejected promise (unless a newer load replaced it) so that checkTokeniserLoaded() can retry
        if (this.tokeniserPromise === pendingLoad) this.tokeniserPromise = null;
        throw error;
    }
    return this.tokeniser;
}
|
|
166
|
+
|
|
167
|
+
/**
 * Makes sure `this.tokeniser` is available, starting a load with the current settings if none was started yet.
 *
 * @throws Error when the tokeniser is still missing afterwards. Load errors are rethrown to the caller.
 */
async checkTokeniserLoaded() {
    //Await the load started here so that a failure reaches the caller instead of becoming an unhandled rejection
    if (!this.tokeniserPromise) await this.loadTokeniser();
    //A load started elsewhere is still in flight: wait for it to settle
    else if (!this.tokeniser) await this.tokeniserPromise;
    if (!this.tokeniser) throw new Error('Unable to load tokeniser');
}
|
|
172
|
+
|
|
173
|
+
/**
 * Embeds one text with the feature-extraction pipeline, loading the embedder first when necessary.
 *
 * @param text Text to embed.
 * @returns The `data` of the mean-pooled, normalised pipeline output copied into a plain array.
 */
async embedQuery(text: string): Promise<number[]> {
    await this.checkEmbedderLoaded();
    const embedding = await this.embedder?.(text, {
        pooling: 'mean',
        normalize: true
    });
    return Array.from(embedding?.data);
}
|
|
180
|
+
|
|
181
|
+
/**
 * Embeds several texts by calling embedQuery() for each of them.
 *
 * @param texts Texts to embed.
 * @returns One embedding per text, in the order of `texts`.
 */
async embedDocuments(texts: string[]): Promise<number[][]> {
    const pendingEmbeddings = texts.map(text => this.embedQuery(text));
    return Promise.all(pendingEmbeddings);
}
|
|
184
|
+
|
|
185
|
+
/**
 * Wraps an instruction in the "### Instruction: / ### Response:" prompt layout.
 *
 * @param prompt Instruction text to place in the template.
 * @returns The complete prompt, ending with the "### Response:" marker.
 */
generatePromptTemplate(prompt: string) {
    const preamble = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.';
    return `${preamble}\n\n### Instruction:\n${prompt}\n\n### Response:`;
}
|
|
191
|
+
|
|
192
|
+
/**
 * Builds the default instruction asking for the subject areas relevant to a problem.
 *
 * @param problemDescription Description of the issues, inserted between double quotes.
 * @returns The instruction text.
 */
defaultGenerateSearchAreasInstruction(problemDescription: string) {
    return `List the relevant subject areas for the following issues. Limit your response to 100 words.\nIssues: "${problemDescription}"`;
}
|
|
195
|
+
|
|
196
|
+
/**
 * Renders a candidate as a text document with one "Label: value" line per enumerable property,
 * framed by "Start of Candidate #index" and "End of Candidate #index" lines.
 *
 * @param candidate Candidate to render; values of type "object" are serialised with jsan, everything else with String().
 * @param index Number shown in the first and last line.
 * @returns The document text.
 */
defaultConvertCandidateToDocument<Candidate>({
    candidate,
    index
}: AICompareCandidates.ConvertCandidateToDocumentArguments<Candidate> = <AICompareCandidates.ConvertCandidateToDocumentArguments<Candidate>>{}) {
    const lines = ['Start of Candidate #' + index];
    for (const key in candidate) {
        const value = candidate[key];
        const renderedValue = typeof value === 'object' ? jsan.stringify(value) : String(value);
        //Property names are shown in Start Case, e.g. "firstName" becomes "First Name"
        lines.push(lodash.startCase(key) + ': ' + renderedValue);
    }
    lines.push('End of Candidate #' + index);
    return lines.join('\n');
}
|
|
205
|
+
|
|
206
|
+
/**
 * Builds the default instruction asking the generator to rank the shortlisted candidates.
 *
 * @param problemDescription Ranking criterion; line breaks are replaced by spaces.
 * @param summaries Candidate documents, separated by blank lines in the prompt.
 * @param candidatesForFinalSelection Number of candidates the model is asked to rank.
 * @param candidateIdentifierField Field name shown in Start Case in the expected answer format.
 * @returns The instruction text.
 */
defaultGenerateRankingInstruction({
    problemDescription,
    summaries,
    candidatesForFinalSelection,
    candidateIdentifierField
}: AICompareCandidates.GenerateRankingInstructionArguments = <AICompareCandidates.GenerateRankingInstructionArguments>{}) {
    const singleLineProblem = problemDescription.replace(/(\r|\n)/g, ' ');
    const identifierLabel = lodash.startCase(candidateIdentifierField);
    const lines = [
        'Strictly follow these rules:',
        '1. Rank ONLY the top ' + candidatesForFinalSelection + ' candidates with one 15-word sentence explaining why',
        '2. Rank the candidates based on "' + singleLineProblem + '"',
        '3. If unclear, say "Insufficient information to determine"',
        '',
        'Candidates:',
        '',
        summaries.join('\n\n'),
        '',
        'Format exactly:',
        '#1. "Full ' + identifierLabel + '": 15-word explanation',
        '#2. ...'
    ];
    return lines.join('\n');
}
|
|
221
|
+
|
|
222
|
+
/**
 * Like `String.prototype.indexOf`, but searching for a regular expression.
 *
 * @param text Text to search.
 * @param regex Pattern to look for.
 * @param startIndex Position passed to `slice` from which the search starts.
 * @returns Index of the first match relative to the whole text, or a negative number when there is none.
 */
regexIndexOf(text: string, regex: RegExp, startIndex: number) {
    const offset = text.slice(startIndex).search(regex);
    if (offset < 0) return offset;
    return offset + startIndex;
}
|
|
226
|
+
|
|
227
|
+
/**
 * Reads the value of the identifier field back out of a candidate document.
 *
 * The field label is looked up as Start Case text, then case-insensitively, then by the raw field name.
 * The value is the text after the ':' that follows the label on the same line, up to the end of that line.
 * When the label is not followed by ':' the value starts after the next run of whitespace instead.
 *
 * @param candidateDocument Document text, e.g. as produced by defaultConvertCandidateToDocument().
 * @param candidateIdentifierField Name of the field whose value identifies the candidate.
 * @returns The trimmed identifier, or '' when it cannot be found.
 */
defaultExtractIdentifierFromCandidateDocument({
    candidateDocument,
    candidateIdentifierField
}: AICompareCandidates.ExtractIdentifierFromCandidateDocumentArguments = <AICompareCandidates.ExtractIdentifierFromCandidateDocumentArguments>{}) {
    if (this.DEBUG) console.log(candidateDocument, candidateIdentifierField);
    const label = lodash.startCase(candidateIdentifierField);
    const lowerCaseDocument = candidateDocument.toLowerCase();

    //Locate the end of the label, from the strictest to the most lenient lookup
    const lookups: [string, string][] = [
        [candidateDocument, label],
        [lowerCaseDocument, label.toLowerCase()],
        [lowerCaseDocument, candidateIdentifierField.toLowerCase()]
    ];
    let labelEnd = -1;
    for (const [haystack, needle] of lookups) {
        if (!needle) continue;
        const found = haystack.indexOf(needle);
        if (found >= 0) {
            labelEnd = found + needle.length;
            break;
        }
    }
    if (this.DEBUG) console.log(labelEnd);
    //The result must not depend on DEBUG: give up only when the label is really missing
    if (labelEnd < 0) return '';

    //Prefer the ':' separator on the label's own line and start right after it
    let lineEnd = candidateDocument.indexOf('\n', labelEnd);
    if (lineEnd < 0) lineEnd = candidateDocument.length;
    const separatorIndex = candidateDocument.indexOf(':', labelEnd);
    let valueStart: number;
    if (separatorIndex >= 0 && separatorIndex < lineEnd) {
        valueStart = separatorIndex + 1;
    } else {
        //No separator: the value starts after the whitespace that follows the label
        const whitespace = /\s+/.exec(candidateDocument.slice(labelEnd));
        if (!whitespace) return '';
        valueStart = labelEnd + whitespace.index + whitespace[0].length;
    }
    if (this.DEBUG) console.log(valueStart);

    let valueEnd = candidateDocument.indexOf('\n', valueStart);
    if (valueEnd < 0) valueEnd = candidateDocument.length;
    if (this.DEBUG) console.log(valueEnd);
    return candidateDocument.substring(valueStart, valueEnd).trim();
}
|
|
253
|
+
|
|
254
|
+
/**
 * Collects the quoted identifiers from ranking lines such as `#1. "Some Name": explanation`.
 *
 * @param rationale Ranking text produced by the generator.
 * @returns The quoted texts in order of appearance.
 */
defaultExtractIdentifiersFromRationale(rationale: string) {
    const rankingLine = /^\s*#\s*\d+\s*\.?\s*"([^"]+)"/gm;
    const identifiers: string[] = [];
    for (const found of rationale.matchAll(rankingLine)) {
        const quoted = found[1];
        if (quoted) identifiers.push(quoted);
    }
    return identifiers;
}
|
|
260
|
+
|
|
261
|
+
/**
 * Selects the candidates that best fit a problem description.
 *
 * 1. Every candidate is converted to a text document and embedded into an in-memory vector store.
 * 2. The generator turns the problem description into a list of subject areas, which is used as the
 *    vector search query to shortlist `candidatesForInitialSelection` documents.
 * 3. Shortlisted documents with more than `targetSummarisedStringTokenCount` words are shortened with the summariser.
 * 4. The generator ranks the shortlist and the identifiers quoted in its answer are mapped back to candidates.
 * 5. When no candidate can be recovered from the answer, the best vector search results are returned instead.
 *
 * @param candidates Non-empty list of candidates.
 * @param problemDescription Description of the problem the candidates are compared against.
 * @param generateSearchAreasInstruction Builds the instruction for step 2.
 * @param convertCandidateToDocument Converts a candidate to its text document.
 * @param candidatesForInitialSelection Size of the shortlist (positive, at most the number of candidates).
 * @param candidatesForFinalSelection Maximum number of candidates returned (positive, at most the shortlist size).
 * @param generateRankingInstruction Builds the instruction for step 4.
 * @param extractIdentifiersFromRationale Extracts candidate identifiers from the ranking answer.
 * @param extractIdentifierFromCandidateDocument Extracts the identifier from a candidate document (step 5).
 * @param candidateIdentifierField Field identifying a candidate; defaults to the first key of the first candidate.
 * @param getSummarisableSubstringIndices Optionally limits summarisation to a part of a document.
 * @returns The cleaned ranking answer and the selected candidates.
 * @throws Error on invalid arguments, when a model cannot be loaded, when a prompt exceeds the tokeniser's
 *         `model_max_length`, or when the generator returns no text for the search areas.
 */
async compareCandidates<Candidate>({
    candidates,
    problemDescription = '',
    generateSearchAreasInstruction = this.defaultGenerateSearchAreasInstruction.bind(this),
    convertCandidateToDocument = this.defaultConvertCandidateToDocument.bind(this),
    candidatesForInitialSelection = 2,
    candidatesForFinalSelection = 1,
    generateRankingInstruction = this.defaultGenerateRankingInstruction.bind(this),
    extractIdentifiersFromRationale = this.defaultExtractIdentifiersFromRationale.bind(this),
    extractIdentifierFromCandidateDocument = this.defaultExtractIdentifierFromCandidateDocument.bind(this),
    candidateIdentifierField = undefined,
    getSummarisableSubstringIndices
}: AICompareCandidates.CompareArguments<Candidate> = <AICompareCandidates.CompareArguments<Candidate>>{}): Promise<AICompareCandidates.CompareCandidatesReturn<Candidate> | void> {
    if (!Array.isArray(candidates) || candidates.length <= 0) throw new Error('No candidates provided');
    candidatesForInitialSelection = lodash.toSafeInteger(candidatesForInitialSelection);
    if (candidatesForInitialSelection <= 0) throw new Error('Candidates for initial selection must be a positive integer bigger than 0');
    candidatesForFinalSelection = lodash.toSafeInteger(candidatesForFinalSelection);
    if (candidatesForFinalSelection <= 0) throw new Error('Candidates for final selection must be a positive integer bigger than 0');
    if (candidatesForInitialSelection < candidatesForFinalSelection) throw new Error('Candidates for initial selection must be equal or more than candidates for final selection');
    if (candidatesForInitialSelection > candidates.length) throw new Error('There are ' + candidatesForInitialSelection + ' candidates for initial selection which is more than the total number of candidates of ' + candidates.length);
    if (candidatesForFinalSelection > candidates.length) throw new Error('There are ' + candidatesForFinalSelection + ' candidates for final selection which is more than the total number of candidates of ' + candidates.length);

    //A const binding keeps its narrowed type inside the callbacks below
    const identifierField: keyof Candidate | undefined = candidateIdentifierField || (Object.keys((candidates[0] ?? {}) as object)[0] as keyof Candidate | undefined);
    if (!identifierField) throw new Error('No candidate identifier field');

    const responseMarker = '### Response:';
    //Returns the text after the response marker, or the whole text when the marker is absent
    const textAfterResponseMarker = (text: string) => {
        const markerIndex = text.indexOf(responseMarker);
        return (markerIndex >= 0 ? text.substring(markerIndex + responseMarker.length) : text).trim();
    };
    //Counts whitespace-separated words; an empty or blank text has none
    const countWords = (text: string) => {
        const trimmed = text.trim();
        return trimmed ? trimmed.split(/\s+/).length : 0;
    };
    //Maps an identifier to a candidate: exact match first, then containment in either direction
    const findCandidateByIdentifier = (identifier: string): Candidate | undefined => {
        const wanted = String(identifier ?? '').trim().toLowerCase();
        //An empty identifier would match every candidate through includes('')
        if (!wanted) return undefined;
        const identifierOf = (candidate: Candidate) => String(candidate[identifierField]).toLowerCase();
        return candidates.find(candidate => identifierOf(candidate) === wanted) ??
            candidates.find(candidate => identifierOf(candidate).includes(wanted)) ??
            candidates.find(candidate => {
                const value = identifierOf(candidate);
                return value !== '' && wanted.includes(value);
            });
    };

    await this.checkEmbedderLoaded();
    if (!this.embedder) return;
    const candidateDocuments = candidates.map((candidate, index) => convertCandidateToDocument({
        candidate,
        index
    }));
    const vectorStore = await MemoryVectorStore.fromTexts(
        lodash.cloneDeep(candidateDocuments),
        candidateDocuments.map((document, index) => index),
        this
    );

    const searchAreasPromptTemplate = this.generatePromptTemplate(generateSearchAreasInstruction(problemDescription));
    if (this.DEBUG) console.log('Formatted search areas prompt: ' + searchAreasPromptTemplate);
    await this.checkTokeniserLoaded();
    const tokeniser = this.tokeniser;
    if (!tokeniser) return;
    const searchAreasPromptTokens = tokeniser.encode(searchAreasPromptTemplate);
    if (searchAreasPromptTokens.length > tokeniser.model_max_length) throw new Error('Search areas instruction prompt is too long for the tokeniser model');

    await this.checkGeneratorLoaded();
    const generator = this.generator;
    if (!generator) return;
    const pad_token_id = tokeniser.pad_token_id ?? tokeniser.sep_token_id ?? 0;
    const eos_token_id = tokeniser.sep_token_id ?? 2;
    const searchAreasReplyArray = await generator(searchAreasPromptTemplate, {
        max_new_tokens: this.generateSearchAreasMaxNewTokens,
        temperature: this.generateSearchAreasTemperature,
        repetition_penalty: this.generateSearchAreasRepetitionPenalty,
        pad_token_id,
        eos_token_id
    });
    //The pipeline may return a flat or a nested array of results
    const searchAreasReply = Array.isArray(searchAreasReplyArray?.[0]) ? searchAreasReplyArray?.[0]?.[0] : searchAreasReplyArray?.[0];
    if (!searchAreasReply?.generated_text) throw new Error('No generated text for search areas');
    if (this.DEBUG) console.log('Generated search areas response: ' + searchAreasReply.generated_text);

    const vectorSearchQuery = textAfterResponseMarker(searchAreasReply.generated_text.toString());
    if (this.DEBUG) console.log('Vector search query: ' + vectorSearchQuery);
    const queryResult = await vectorStore.similaritySearch(vectorSearchQuery, candidatesForInitialSelection);
    if (this.DEBUG) console.log('Vector search results: ', queryResult);

    const wordBudget = this.targetSummarisedStringTokenCount;
    let summaries: string[] = [];
    //only bother doing summarisation if there are candidates which exceed the token count
    if (queryResult.some(result => countWords(result.pageContent) > wordBudget)) {
        await this.checkSummariserLoaded();
        if (!this.summariser) return;
        const outcomes = await Promise.allSettled(queryResult.map(async result => {
            const content = result.pageContent;
            if (!content || typeof content !== 'string') return '';
            if (countWords(content) <= wordBudget) return content;
            const summarisableSubstringIndices: AICompareCandidates.SummarisableSubstringIndices = {
                start: 0,
                end: content.length
            };
            if (getSummarisableSubstringIndices) Object.assign(summarisableSubstringIndices, getSummarisableSubstringIndices(content));
            const start = lodash.clamp(lodash.toSafeInteger(summarisableSubstringIndices.start), 0, content.length);
            //The end may not lie before the start, otherwise the three parts below would overlap
            const end = lodash.clamp(lodash.toSafeInteger(summarisableSubstringIndices.end), start, content.length);
            const contentBefore = content.substring(0, start);
            const summarisableSubstring = content.substring(start, end);
            const contentAfter = content.substring(end);
            //Words left for the summary once the untouched parts are accounted for
            const targetSummarisedSubstringTokenCount = Math.max(1, wordBudget - countWords(contentBefore) - countWords(contentAfter));
            const summarisedSubstringArray = await this.summariser?.(summarisableSubstring, <TextGenerationConfig>{
                max_length: targetSummarisedSubstringTokenCount
            });
            const summarisedSubstring = Array.isArray(summarisedSubstringArray?.[0]) ? summarisedSubstringArray?.[0]?.[0] : summarisedSubstringArray?.[0];
            //Keep at most the first targetSummarisedSubstringTokenCount words of the summary
            const summaryWords = (summarisedSubstring?.summary_text ?? '').trim().split(/\s+/).slice(0, targetSummarisedSubstringTokenCount);
            const summarisedString = contentBefore + summaryWords.join(' ') + contentAfter;
            if (this.DEBUG) console.log('Summarised candidate: ' + summarisedString);
            return summarisedString;
        }));
        for (const outcome of outcomes) {
            if (outcome.status === 'fulfilled') {
                if (outcome.value) summaries.push(outcome.value);
            } else if (this.DEBUG) {
                //A document that could not be summarised is left out of the ranking prompt
                console.log('Unable to summarise candidate: ', outcome.reason);
            }
        }
    } else {
        summaries = queryResult.map(result => result.pageContent);
    }

    const rankingPromptTemplate = this.generatePromptTemplate(generateRankingInstruction({
        problemDescription,
        summaries,
        candidatesForFinalSelection,
        candidateIdentifierField: String(identifierField)
    }));
    if (this.DEBUG) console.log('Formatted ranking prompt: ' + rankingPromptTemplate);
    const rankingPromptTokens = tokeniser.encode(rankingPromptTemplate);
    if (rankingPromptTokens.length > tokeniser.model_max_length) throw new Error('Ranking instruction prompt is too long for the tokeniser model');
    const rankingArray = await generator(rankingPromptTemplate, {
        max_new_tokens: this.rankingMaxNewTokens,
        temperature: this.rankingTemperature,
        repetition_penalty: this.rankingRepetitionPenalty,
        pad_token_id,
        eos_token_id
    });
    const ranking = Array.isArray(rankingArray?.[0]) ? rankingArray?.[0]?.[0] : rankingArray?.[0];
    //Strip bold markers, <s>/</s> tokens and bracketed fragments from the answer
    const cleanedRanking = (ranking?.generated_text ?? '').toString().trim().replace(/(\*\*)|(<\/?s>)|(\[.*?\])\s*/g, '');
    if (this.DEBUG) console.log('Generated rationale: ' + cleanedRanking);
    const rationale = textAfterResponseMarker(cleanedRanking);

    let selectedCandidates: Candidate[] = [];
    if (rationale) {
        //Resolve every identifier first so that unmatched or repeated ones do not use up the final selection
        selectedCandidates = lodash.uniq(lodash.compact(
            extractIdentifiersFromRationale(rationale).map(identifier => findCandidateByIdentifier(identifier))
        )).slice(0, candidatesForFinalSelection);
    }

    if (selectedCandidates.length <= 0) {
        //Fall back to the vector search order, best result first
        selectedCandidates = lodash.uniq(lodash.compact(queryResult.map(result => {
            const identifier = extractIdentifierFromCandidateDocument({
                candidateDocument: result.pageContent,
                candidateIdentifierField: String(identifierField)
            });
            if (this.DEBUG) console.log('Extracted identifier from candidate document: ' + identifier);
            //When the identifier gives no match, the document text itself still identifies the candidate
            return findCandidateByIdentifier(identifier) ?? candidates[candidateDocuments.indexOf(result.pageContent)];
        }))).slice(0, candidatesForFinalSelection);
    }
    if (this.DEBUG) console.log('Selected candidates', selectedCandidates);

    return {
        rationale,
        selectedCandidates
    };
}
|
|
428
|
+
};
|
|
429
|
+
|
|
430
|
+
/** Argument and result types of the `AICompareCandidates` class. */
export namespace AICompareCandidates {
    /** Arguments of the load methods (`loadGenerator`, `loadSummariser`, `loadEmbedder`, `loadTokeniser`). */
    export interface LoadArguments {
        /** Receives every progress event of the load. */
        progressCallback?: ProgressCallback;
        /** Model to load; optional because the load methods default it to '' and then keep the current model name. */
        modelName?: string;
    }

    /** Character range of a candidate document that may be replaced by a summary. */
    export interface SummarisableSubstringIndices {
        /** Index of the first character of the range. */
        start: number;
        /** Index just past the last character of the range. */
        end: number;
    }

    /** Arguments of `compareCandidates`. */
    export interface CompareArguments<Candidate> {
        /** Candidates to choose from. */
        candidates: Candidate[];
        /** Description of the problem the candidates are compared against. */
        problemDescription: string;
        /** Builds the instruction that asks the generator for search areas. */
        generateSearchAreasInstruction?: (problemDescription: string) => string;
        /** Converts a candidate to the text document that is embedded and ranked. */
        convertCandidateToDocument?: (convertCandidateToDocumentArguments: ConvertCandidateToDocumentArguments<Candidate>) => string;
        /** Number of candidates shortlisted by the vector search (default 2). */
        candidatesForInitialSelection?: number;
        /** Maximum number of candidates returned (default 1). */
        candidatesForFinalSelection?: number;
        /** Builds the instruction that asks the generator to rank the shortlist. */
        generateRankingInstruction?: (generateRankingInstructionArguments: GenerateRankingInstructionArguments) => string;
        /** Extracts candidate identifiers from the ranking answer. */
        extractIdentifiersFromRationale?: (rationale: string) => string[];
        /** Extracts the identifier from a candidate document. */
        extractIdentifierFromCandidateDocument?: (extractIdentifierFromCandidateDocumentArguments: ExtractIdentifierFromCandidateDocumentArguments) => string;
        /** Field identifying a candidate; defaults to the first key of the first candidate. */
        candidateIdentifierField?: keyof Candidate;
        /** Limits summarisation to a part of a candidate document. */
        getSummarisableSubstringIndices?: (candidateDocument: string) => SummarisableSubstringIndices;
    }

    /** Arguments of a `convertCandidateToDocument` function. */
    export interface ConvertCandidateToDocumentArguments<Candidate> {
        candidate: Candidate;
        /** Position of the candidate in the `candidates` array. */
        index: number;
    }

    /** Arguments of an `extractIdentifierFromCandidateDocument` function. */
    export interface ExtractIdentifierFromCandidateDocumentArguments {
        candidateDocument: string;
        candidateIdentifierField: string;
    }

    /** Arguments of a `generateRankingInstruction` function. */
    export interface GenerateRankingInstructionArguments {
        problemDescription: string;
        /** Shortlisted candidate documents, summarised where necessary. */
        summaries: string[];
        candidatesForFinalSelection: number;
        candidateIdentifierField: string;
    }

    /** Result of `compareCandidates`. */
    export interface CompareCandidatesReturn<Candidate> {
        /** Candidates selected from the ranking or, failing that, from the vector search. */
        selectedCandidates: Candidate[];
        /** Cleaned ranking answer of the generator. */
        rationale: string;
    }
}
|
|
477
|
+
|
|
478
|
+
export default AICompareCandidates;
|
package/tsconfig.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
{
|
|
2
|
+
"compilerOptions": {
|
|
3
|
+
"allowSyntheticDefaultImports": true,
|
|
4
|
+
"declarationMap": true,
|
|
5
|
+
"jsx": "preserve",
|
|
6
|
+
"composite": true,
|
|
7
|
+
"module": "esnext",
|
|
8
|
+
"moduleResolution": "node",
|
|
9
|
+
"skipLibCheck": true
|
|
10
|
+
},
|
|
11
|
+
"extends": "./example/.quasar/tsconfig.json"
|
|
12
|
+
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
{
|
|
2
|
+
"compilerOptions": {
|
|
3
|
+
"allowSyntheticDefaultImports": true,
|
|
4
|
+
"declarationMap": true,
|
|
5
|
+
"jsx": "preserve",
|
|
6
|
+
"composite": true,
|
|
7
|
+
"module": "esnext",
|
|
8
|
+
"moduleResolution": "node",
|
|
9
|
+
"skipLibCheck": true
|
|
10
|
+
},
|
|
11
|
+
"extends": "./example/.quasar/tsconfig.json"
|
|
12
|
+
}
|