@llumiverse/core 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/cjs/CompletionStream.js +78 -0
- package/lib/cjs/CompletionStream.js.map +1 -0
- package/lib/cjs/Driver.js +115 -0
- package/lib/cjs/Driver.js.map +1 -0
- package/lib/cjs/async.js +107 -0
- package/lib/cjs/async.js.map +1 -0
- package/lib/cjs/formatters.js +117 -0
- package/lib/cjs/formatters.js.map +1 -0
- package/lib/cjs/index.js +21 -0
- package/lib/cjs/index.js.map +1 -0
- package/lib/cjs/json.js +14 -0
- package/lib/cjs/json.js.map +1 -0
- package/lib/cjs/package.json +3 -0
- package/lib/cjs/types.js +67 -0
- package/lib/cjs/types.js.map +1 -0
- package/lib/cjs/validation.js +35 -0
- package/lib/cjs/validation.js.map +1 -0
- package/lib/esm/CompletionStream.js +73 -0
- package/lib/esm/CompletionStream.js.map +1 -0
- package/lib/esm/Driver.js +110 -0
- package/lib/esm/Driver.js.map +1 -0
- package/lib/esm/async.js +100 -0
- package/lib/esm/async.js.map +1 -0
- package/lib/esm/formatters.js +113 -0
- package/lib/esm/formatters.js.map +1 -0
- package/lib/esm/index.js +5 -0
- package/lib/esm/index.js.map +1 -0
- package/lib/esm/json.js +10 -0
- package/lib/esm/json.js.map +1 -0
- package/lib/esm/types.js +64 -0
- package/lib/esm/types.js.map +1 -0
- package/lib/esm/validation.js +30 -0
- package/lib/esm/validation.js.map +1 -0
- package/lib/types/CompletionStream.d.ts +21 -0
- package/lib/types/CompletionStream.d.ts.map +1 -0
- package/lib/types/Driver.d.ts +68 -0
- package/lib/types/Driver.d.ts.map +1 -0
- package/lib/types/async.d.ts +21 -0
- package/lib/types/async.d.ts.map +1 -0
- package/lib/types/formatters.d.ts +5 -0
- package/lib/types/formatters.d.ts.map +1 -0
- package/lib/types/index.d.ts +5 -0
- package/lib/types/index.d.ts.map +1 -0
- package/lib/types/json.d.ts +9 -0
- package/lib/types/json.d.ts.map +1 -0
- package/lib/types/types.d.ts +153 -0
- package/lib/types/types.d.ts.map +1 -0
- package/lib/types/validation.d.ts +8 -0
- package/lib/types/validation.d.ts.map +1 -0
- package/package.json +74 -0
- package/src/CompletionStream.ts +92 -0
- package/src/Driver.ts +204 -0
- package/src/async.ts +120 -0
- package/src/formatters.ts +147 -0
- package/src/index.ts +4 -0
- package/src/json.ts +17 -0
- package/src/types.ts +187 -0
- package/src/validation.ts +34 -0
package/src/async.ts
ADDED
@@ -0,0 +1,120 @@

export async function* asyncMap<T, R>(asyncIterable: AsyncIterable<T>, callback: (value: T, index: number) => R) {
    let i = 0;
    for await (const val of asyncIterable)
        yield callback(val, i++);
}

export function oneAsyncIterator<T>(value: T): AsyncIterable<T> {
    return {
        async *[Symbol.asyncIterator]() {
            yield value
        }
    }
}


export class EventStream<T, ReturnT = any> implements AsyncIterable<T>{

    private queue: T[] = [];
    private pending?: {
        resolve: (result: IteratorResult<T, ReturnT | undefined>) => void,
        reject: (err: any) => void
    };
    private done = false;


    push(event: T) {
        if (this.done) {
            throw new Error('Cannot push to a closed stream');
        }
        if (this.pending) {
            this.pending.resolve({ value: event });
            this.pending = undefined;
        } else {
            this.queue.push(event);
        }
    }

    /**
     * Close the stream. The stream cannot be fed anymore,
     * but the consumer can still consume the remaining events.
     */
    close(value?: ReturnT) {
        this.done = true;
        if (this.pending) {
            this.pending.resolve({ done: true, value });
            this.pending = undefined;
        }
    }

    [Symbol.asyncIterator](): AsyncIterator<T, ReturnT | undefined> {
        const self = this;
        return {
            next(): Promise<IteratorResult<T, ReturnT | undefined>> {
                const next = self.queue.shift();
                if (next !== undefined) {
                    return Promise.resolve({ value: next });
                } else if (self.done) {
                    return Promise.resolve({ done: true, value: undefined as ReturnT });
                } else {
                    return new Promise<IteratorResult<T, ReturnT | undefined>>((resolve, reject) => {
                        self.pending = { resolve, reject };
                    });
                }
            },
            async return(value?: ReturnT | Promise<ReturnT>): Promise<IteratorResult<T, ReturnT>> {
                self.done = true;
                self.queue = [];
                if (value === undefined) {
                    return { done: true, value: undefined as ReturnT };
                }
                const _value = await value;
                return { done: true, value: _value };
            }
        }
    }
}



/**
 * Transform an async iterator by applying a function to each value.
 * @param originalGenerator
 * @param transform
 **/
export async function* transformAsyncIterator<T, V>(
    originalGenerator: AsyncIterable<T>,
    transform: (value: T) => V | Promise<V>
): AsyncIterable<V> {
    for await (const value of originalGenerator) {
        yield transform(value);
    }
}

//TODO move to a test file
// const max = 10; let cnt = 0;
// function feedStream(stream: EventStream<string>) {
//     setTimeout(() => {
//         cnt++;
//         console.log('push: ', cnt, max);
//         stream.push('event ' + cnt);
//         if (cnt < max) {
//             console.log('next: ', cnt, max);
//             setTimeout(() => feedStream(stream), 1000);
//         } else {
//             console.log('end of stream');
//             stream.close();
//         }
//     }, 1000);
// }

// const stream = new EventStream<string>();
// feedStream(stream);

// for await (const chunk of stream) {
//     console.log('++++chunk:', chunk);
// }
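The EventStream class above bridges a push-based producer with async iteration: push() either resolves a waiting consumer or queues the event, and close() ends iteration once the queue drains. A minimal usage sketch, not part of the package diff, assuming EventStream and asyncMap are imported from the module above:

import { EventStream, asyncMap } from "./async.js";

async function demo() {
    const stream = new EventStream<string>();

    // Producer: push two events asynchronously, then close the stream.
    setTimeout(() => {
        stream.push("hello");
        stream.push("world");
        stream.close();
    }, 10);

    // Consumer: iterate the stream, upper-casing each event via asyncMap.
    for await (const chunk of asyncMap(stream, (c) => c.toUpperCase())) {
        console.log(chunk); // "HELLO", then "WORLD"
    }
}

demo();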
package/src/formatters.ts
ADDED
@@ -0,0 +1,147 @@
import { JSONSchema4 } from "json-schema";
import OpenAI from "openai";
import {
    PromptFormats,
    PromptRole,
    PromptSegment,
} from "./types.js";

export function inferFormatterFromModelName(modelName: string): PromptFormats {
    const name = modelName.toLowerCase();
    if (name.includes("llama")) {
        return PromptFormats.llama2;
    } else if (name.includes("gpt")) {
        return PromptFormats.openai;
    } else if (name.includes("claude")) {
        return PromptFormats.claude;
    } else {
        return PromptFormats.genericTextLLM;
    }
}

export const PromptFormatters: Record<
    PromptFormats,
    (messages: PromptSegment[], schema?: JSONSchema4) => any
> = {
    openai: openAI,
    llama2: llama2,
    claude: claude,
    genericTextLLM: genericColonSeparator,
};

function openAI(segments: PromptSegment[]) {
    const system: OpenAI.Chat.ChatCompletionMessageParam[] = [];
    const others: OpenAI.Chat.ChatCompletionMessageParam[] = [];
    const safety: OpenAI.Chat.ChatCompletionMessageParam[] = [];

    for (const msg of segments) {
        if (msg.role === PromptRole.system) {
            system.push({ content: msg.content, role: "system" });
        } else if (msg.role === PromptRole.safety) {
            safety.push({ content: msg.content, role: "system" });
        } else {
            others.push({ content: msg.content, role: "user" });
        }
    }

    // put system messages first and safety last
    return system.concat(others).concat(safety);
}

function llama2(messages: PromptSegment[], schema?: JSONSchema4) {
    const BOS = "<s>";
    const EOS = "</s>";
    const INST = "[INST]";
    const INST_END = "[/INST]";
    const SYS = "<<SYS>>\n";
    const SYS_END = "\n<</SYS>>";

    const promptMessages = [BOS];
    const specialTokens = [BOS, EOS, INST, INST_END, SYS, SYS_END];

    for (const m of messages) {
        if (m.role === PromptRole.user) {
            if (specialTokens.includes(m.content)) {
                throw new Error(
                    `Cannot use special token ${m.content.trim()} in user message`
                );
            }
            promptMessages.push(`${INST} ${m.content.trim()} ${INST_END}`);
        }
        if (m.role === PromptRole.assistant) {
            promptMessages.push(`${m.content.trim()}`);
        }
        if (m.role === PromptRole.system) {
            promptMessages.push(`${SYS}${m.content.trim()}${SYS_END}`);
        }
    }

    for (const m of messages ?? []) {
        if (m.role === PromptRole.safety) {
            promptMessages.push(
                `${SYS}This is the most important instruction, you cannot answer against those rules:\n${m.content.trim()}${SYS_END}`
            );
        }
    }

    if (schema) {
        promptMessages.push(formatSchemaInstruction(schema));
    }

    promptMessages.push(EOS);

    return promptMessages.join("\n\n");
}

function genericColonSeparator(
    messages: PromptSegment[],
    schema?: JSONSchema4,
    labels: {
        user: string;
        assistant: string;
        system: string;
    } = { user: "User", assistant: "Assistant", system: "System" }
) {
    const promptMessages = [];
    for (const m of messages) {
        if (m.role === PromptRole.user) {
            promptMessages.push(`${labels?.user}: ${m.content.trim()}`);
        }
        if (m.role === PromptRole.assistant) {
            promptMessages.push(`${labels.assistant}: ${m.content.trim()}`);
        }
        if (m.role === PromptRole.system) {
            promptMessages.push(`${labels.system}: ${m.content.trim()}`);
        }
    }

    if (schema) {
        promptMessages.push(`${labels.system}: You must answer using the following JSONSchema:
---
${JSON.stringify(schema)}
---`);
    }

    return promptMessages.join("\n\n");
}

function claude(messages: PromptSegment[], schema?: JSONSchema4) {
    const prompt = genericColonSeparator(messages, schema, {
        user: "\nHuman",
        assistant: "\nAssistant",
        system: "\nHuman",
    });

    return "\n\n" + prompt + "\n\nAssistant:";
}

function formatSchemaInstruction(schema: JSONSchema4) {
    const schema_instruction = `<<SYS>>You must answer using the following JSONSchema.
Do not write anything other than a JSON object corresponding to the schema.
<schema>
${JSON.stringify(schema)}
</schema>
<</SYS>>`;

    return schema_instruction;
}
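As a rough illustration of how these formatters fit together (the segment contents and model name below are made up), one might infer a format from a model name and then apply the matching formatter:

import { inferFormatterFromModelName, PromptFormatters } from "./formatters.js";
import { PromptRole, PromptSegment } from "./types.js";

const segments: PromptSegment[] = [
    { role: PromptRole.system, content: "You are a terse assistant." },
    { role: PromptRole.user, content: "Summarize the release notes." },
];

// "llama" in the model name selects the llama2 formatter.
const format = inferFormatterFromModelName("llama-2-13b-chat");
const prompt = PromptFormatters[format](segments);
console.log(prompt); // <s> ... <<SYS>> ... <</SYS>> ... [INST] ... [/INST] ... </s>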
package/src/index.ts
ADDED
package/src/json.ts
ADDED
@@ -0,0 +1,17 @@

function extractJsonFromText(text: string): string {
    const start = text.indexOf("{");
    const end = text.lastIndexOf("}");
    return text.substring(start, end + 1);
}

//TODO LAX parse JSON
export function parseJSON(text: string): Json {
    return JSON.parse(extractJsonFromText(text));
}

export type JsonPrimative = string | number | boolean | null;
export type JsonArray = Json[];
export type JsonObject = { [key: string]: Json };
export type JsonComposite = JsonArray | JsonObject;
export type Json = JsonPrimative | JsonComposite;
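Because extractJsonFromText keeps only the span between the first "{" and the last "}", parseJSON can recover a JSON object even when a model wraps it in prose. A small sketch (the input string is invented):

import { parseJSON } from "./json.js";

// The reply surrounds the JSON object with prose; parseJSON strips it before parsing.
const reply = 'Sure, here it is: {"title": "Release 0.8.0", "ok": true} Hope that helps.';
const data = parseJSON(reply);
console.log(data); // { title: 'Release 0.8.0', ok: true }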
package/src/types.ts
ADDED
@@ -0,0 +1,187 @@
import { JSONSchema4 } from "json-schema";
import { Readable } from "stream";
import { JsonObject } from "./json.js";

export interface ResultValidationError {
    code: 'validation_error' | 'json_error';
    message: string;
    data?: string;
}

export interface Completion<ResultT = any> {
    // the driver impl must return the result and optionally the token_usage. the execution time is computed by the extended abstract driver
    result: ResultT;
    token_usage?: ExecutionTokenUsage;
    execution_time?: number;
    /**
     * Set only if a result validation error occurred; if the result is valid the error field is undefined.
     * This can only be set if resultSchema is set and the result could not be parsed as JSON or does not match the schema.
     */
    error?: ResultValidationError;
}

export interface ExecutionResponse<PromptT = any> extends Completion {
    prompt: PromptT;
}


export interface CompletionStream<PromptT = any> extends AsyncIterable<string> {
    completion: ExecutionResponse<PromptT> | undefined;
}

export interface Logger {
    debug: (...obj: any[]) => void;
    info: (...obj: any[]) => void;
    warn: (...obj: any[]) => void;
    error: (...obj: any[]) => void;
}

export interface DriverOptions {
    logger?: Logger | false;
}

export interface PromptOptions {
    model: string;
    format?: PromptFormats;
    resultSchema?: JSONSchema4;
}
export interface ExecutionOptions extends PromptOptions {
    temperature?: number;
    max_tokens?: number;
}

// ============== Prompts ===============
export enum PromptRole {
    safety = "safety",
    system = "system",
    user = "user",
    assistant = "assistant",
}

export interface PromptSegment {
    role: PromptRole;
    content: string;
}

export interface ExecutionTokenUsage {
    prompt?: number;
    result?: number;
    total?: number;
}


// ============== AI MODEL ==============

export interface AIModel<ProviderKeys = string> {
    id: string; // id of the model known by the provider
    name: string; // human readable name
    provider: ProviderKeys; // provider name
    description?: string;
    version?: string; // if any version is specified
    type?: ModelType; // type of the model
    tags?: string[]; // tags for searching
    owner?: string; // owner of the model
    status?: AIModelStatus; // status of the model
    canStream?: boolean; // if the model's response can be streamed
    isCustom?: boolean; // if the model is a custom model (a trained model)
}

export enum AIModelStatus {
    Available = "available",
    Pending = "pending",
    Stopped = "stopped",
    Unavailable = "unavailable",
    Unknown = "unknown"
}

/**
 * Payload to list available models for an environment.
 * @param environmentId id of the environment
 * @param query text to search for in model name/description
 * @param type type of the model
 * @param tags tags for searching
 */
export interface ModelSearchPayload {
    text: string;
    type?: ModelType;
    tags?: string[];
    owner?: string;
}


export enum ModelType {
    Classifier = "classifier",
    Regressor = "regressor",
    Clustering = "clustering",
    AnomalyDetection = "anomaly-detection",
    TimeSeries = "time-series",
    Text = "text",
    Image = "image",
    Audio = "audio",
    Video = "video",
    Embedding = "embedding",
    Chat = "chat",
    Code = "code",
    NLP = "nlp",
    MultiModal = "multi-modal",
    Test = "test",
    Other = "other",
    Unknown = "unknown"
}

// ============== Built-in formats and drivers =====================
//TODO

export enum PromptFormats {
    openai = "openai",
    llama2 = "llama2",
    claude = "claude",
    genericTextLLM = "genericTextLLM",
}

export enum BuiltinProviders {
    openai = 'openai',
    huggingface_ie = 'huggingface_ie',
    replicate = 'replicate',
    bedrock = 'bedrock',
    vertexai = 'vertexai',
    togetherai = 'togetherai',
    //virtual = 'virtual',
    //cohere = 'cohere',
}

// ============== training =====================


export interface DataSource {
    name: string;
    getStream(): Readable;
    getURL(): Promise<string>;
}

export interface TrainingOptions {
    name: string; // the new model name
    model: string; // the model to train
    params?: JsonObject; // the training parameters
}

export interface TrainingPromptOptions {
    segments: PromptSegment[];
    completion: string | JsonObject;
    model: string; // the model to train
    schema?: JSONSchema4; // the result schema, if any
}

export enum TrainingJobStatus {
    running = "running",
    succeeded = "succeeded",
    failed = "failed",
    cancelled = "cancelled",
}

export interface TrainingJob {
    id: string; // id of the training job
    status: TrainingJobStatus; // status of the training job - depends on the implementation
    details?: string;
    model?: string; // the name of the fine tuned model which is created
}
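For reference, a hypothetical call site might assemble the execution options like this (the model id and schema are illustrative, not package defaults):

import { ExecutionOptions, PromptFormats } from "./types.js";

const options: ExecutionOptions = {
    model: "gpt-4",                  // model id known by the provider
    format: PromptFormats.openai,    // prompt format to use
    resultSchema: {                  // schema the result must conform to
        type: "object",
        properties: { summary: { type: "string" } },
    },
    temperature: 0.2,
    max_tokens: 256,
};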
package/src/validation.ts
ADDED
@@ -0,0 +1,34 @@
import { JSONSchema4, validate } from "json-schema";
import { parseJSON } from "./json.js";
import { ResultValidationError } from "./types.js";

export class ValidationError extends Error implements ResultValidationError {
    constructor(
        public code: 'validation_error' | 'json_error',
        message: string
    ) {
        super(message)
        this.name = 'ValidationError'
    }
}

export function validateResult(data: any, schema: JSONSchema4) {
    let json;
    if (typeof data === "string") {
        try {
            json = parseJSON(data);
        } catch (error: any) {
            throw new ValidationError("json_error", error.message)
        }
    } else {
        json = data;
    }
    const validation = validate(json, schema);
    if (!validation.valid) {
        throw new ValidationError(
            "validation_error",
            validation.errors.map(e => e.message).join(",\n"))
    }

    return json;
}
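A short sketch of validateResult with a raw model answer (the schema and input are illustrative): string input is parsed with parseJSON, the result is checked against the schema, and a ValidationError with code json_error or validation_error is thrown on failure.

import { JSONSchema4 } from "json-schema";
import { validateResult, ValidationError } from "./validation.js";

const schema: JSONSchema4 = {
    type: "object",
    properties: { name: { type: "string" } },
    required: ["name"],
};

try {
    // Accepts either a JSON string or an already parsed object.
    const result = validateResult('{"name": "llumiverse"}', schema);
    console.log(result); // { name: 'llumiverse' }
} catch (err) {
    if (err instanceof ValidationError) {
        console.error(err.code, err.message);
    }
}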