langchain 0.0.89 → 0.0.91
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/callbacks/handlers/console.cjs +2 -0
- package/dist/callbacks/handlers/console.js +2 -0
- package/dist/callbacks/handlers/tracer.cjs +0 -4
- package/dist/callbacks/handlers/tracer.d.ts +1 -1
- package/dist/callbacks/handlers/tracer.js +0 -4
- package/dist/callbacks/handlers/tracer_langchain.cjs +55 -3
- package/dist/callbacks/handlers/tracer_langchain.d.ts +21 -1
- package/dist/callbacks/handlers/tracer_langchain.js +55 -3
- package/dist/callbacks/handlers/tracer_langchain_v1.d.ts +1 -1
- package/dist/llms/ai21.cjs +189 -0
- package/dist/llms/ai21.d.ts +81 -0
- package/dist/llms/ai21.js +185 -0
- package/dist/prompts/base.d.ts +1 -0
- package/dist/prompts/chat.d.ts +2 -1
- package/dist/prompts/index.cjs +3 -1
- package/dist/prompts/index.d.ts +1 -0
- package/dist/prompts/index.js +1 -0
- package/dist/prompts/pipeline.cjs +76 -0
- package/dist/prompts/pipeline.d.ts +24 -0
- package/dist/prompts/pipeline.js +72 -0
- package/dist/vectorstores/chroma.cjs +2 -2
- package/dist/vectorstores/chroma.js +2 -2
- package/dist/vectorstores/pinecone.cjs +15 -3
- package/dist/vectorstores/pinecone.js +15 -3
- package/dist/vectorstores/prisma.cjs +65 -25
- package/dist/vectorstores/prisma.d.ts +27 -10
- package/dist/vectorstores/prisma.js +65 -25
- package/dist/vectorstores/typeorm.cjs +187 -0
- package/dist/vectorstores/typeorm.d.ts +32 -0
- package/dist/vectorstores/typeorm.js +182 -0
- package/llms/ai21.cjs +1 -0
- package/llms/ai21.d.ts +1 -0
- package/llms/ai21.js +1 -0
- package/package.json +25 -3
- package/vectorstores/typeorm.cjs +1 -0
- package/vectorstores/typeorm.d.ts +1 -0
- package/vectorstores/typeorm.js +1 -0
|
@@ -52,7 +52,6 @@ class BaseTracer extends base_js_1.BaseCallbackHandler {
|
|
|
52
52
|
name: llm.name,
|
|
53
53
|
parent_run_id: parentRunId,
|
|
54
54
|
start_time: Date.now(),
|
|
55
|
-
end_time: 0,
|
|
56
55
|
serialized: llm,
|
|
57
56
|
inputs: { prompts },
|
|
58
57
|
execution_order,
|
|
@@ -71,7 +70,6 @@ class BaseTracer extends base_js_1.BaseCallbackHandler {
|
|
|
71
70
|
name: llm.name,
|
|
72
71
|
parent_run_id: parentRunId,
|
|
73
72
|
start_time: Date.now(),
|
|
74
|
-
end_time: 0,
|
|
75
73
|
serialized: llm,
|
|
76
74
|
inputs: { messages },
|
|
77
75
|
execution_order,
|
|
@@ -110,7 +108,6 @@ class BaseTracer extends base_js_1.BaseCallbackHandler {
|
|
|
110
108
|
name: chain.name,
|
|
111
109
|
parent_run_id: parentRunId,
|
|
112
110
|
start_time: Date.now(),
|
|
113
|
-
end_time: 0,
|
|
114
111
|
serialized: chain,
|
|
115
112
|
inputs,
|
|
116
113
|
execution_order,
|
|
@@ -148,7 +145,6 @@ class BaseTracer extends base_js_1.BaseCallbackHandler {
|
|
|
148
145
|
name: tool.name,
|
|
149
146
|
parent_run_id: parentRunId,
|
|
150
147
|
start_time: Date.now(),
|
|
151
|
-
end_time: 0,
|
|
152
148
|
serialized: tool,
|
|
153
149
|
inputs: { input },
|
|
154
150
|
execution_order,
|
|
@@ -49,7 +49,6 @@ export class BaseTracer extends BaseCallbackHandler {
|
|
|
49
49
|
name: llm.name,
|
|
50
50
|
parent_run_id: parentRunId,
|
|
51
51
|
start_time: Date.now(),
|
|
52
|
-
end_time: 0,
|
|
53
52
|
serialized: llm,
|
|
54
53
|
inputs: { prompts },
|
|
55
54
|
execution_order,
|
|
@@ -68,7 +67,6 @@ export class BaseTracer extends BaseCallbackHandler {
|
|
|
68
67
|
name: llm.name,
|
|
69
68
|
parent_run_id: parentRunId,
|
|
70
69
|
start_time: Date.now(),
|
|
71
|
-
end_time: 0,
|
|
72
70
|
serialized: llm,
|
|
73
71
|
inputs: { messages },
|
|
74
72
|
execution_order,
|
|
@@ -107,7 +105,6 @@ export class BaseTracer extends BaseCallbackHandler {
|
|
|
107
105
|
name: chain.name,
|
|
108
106
|
parent_run_id: parentRunId,
|
|
109
107
|
start_time: Date.now(),
|
|
110
|
-
end_time: 0,
|
|
111
108
|
serialized: chain,
|
|
112
109
|
inputs,
|
|
113
110
|
execution_order,
|
|
@@ -145,7 +142,6 @@ export class BaseTracer extends BaseCallbackHandler {
|
|
|
145
142
|
name: tool.name,
|
|
146
143
|
parent_run_id: parentRunId,
|
|
147
144
|
start_time: Date.now(),
|
|
148
|
-
end_time: 0,
|
|
149
145
|
serialized: tool,
|
|
150
146
|
inputs: { input },
|
|
151
147
|
execution_order,
|
|
@@ -70,19 +70,22 @@ class LangChainTracer extends tracer_js_1.BaseTracer {
|
|
|
70
70
|
start_time: run.start_time,
|
|
71
71
|
end_time: run.end_time,
|
|
72
72
|
run_type: run.run_type,
|
|
73
|
-
|
|
73
|
+
// example_id is only set for the root run
|
|
74
|
+
reference_example_id: run.parent_run_id ? undefined : example_id,
|
|
74
75
|
extra: runExtra,
|
|
76
|
+
parent_run_id: run.parent_run_id,
|
|
75
77
|
execution_order: run.execution_order,
|
|
76
78
|
serialized: run.serialized,
|
|
77
79
|
error: run.error,
|
|
78
80
|
inputs: run.inputs,
|
|
79
81
|
outputs: run.outputs ?? {},
|
|
80
82
|
session_name: this.sessionName,
|
|
81
|
-
child_runs:
|
|
83
|
+
child_runs: [],
|
|
82
84
|
};
|
|
83
85
|
return persistedRun;
|
|
84
86
|
}
|
|
85
|
-
async persistRun(
|
|
87
|
+
async persistRun(_run) { }
|
|
88
|
+
async _persistRunSingle(run) {
|
|
86
89
|
const persistedRun = await this._convertToCreate(run, this.exampleId);
|
|
87
90
|
const endpoint = `${this.endpoint}/runs`;
|
|
88
91
|
const response = await this.caller.call(fetch, endpoint, {
|
|
@@ -98,5 +101,54 @@ class LangChainTracer extends tracer_js_1.BaseTracer {
|
|
|
98
101
|
throw new Error(`Failed to persist run: ${response.status} ${response.statusText} ${body}`);
|
|
99
102
|
}
|
|
100
103
|
}
|
|
104
|
+
async _updateRunSingle(run) {
|
|
105
|
+
const runUpdate = {
|
|
106
|
+
end_time: run.end_time,
|
|
107
|
+
error: run.error,
|
|
108
|
+
outputs: run.outputs,
|
|
109
|
+
parent_run_id: run.parent_run_id,
|
|
110
|
+
reference_example_id: run.reference_example_id,
|
|
111
|
+
};
|
|
112
|
+
const endpoint = `${this.endpoint}/runs/${run.id}`;
|
|
113
|
+
const response = await this.caller.call(fetch, endpoint, {
|
|
114
|
+
method: "PATCH",
|
|
115
|
+
headers: this.headers,
|
|
116
|
+
body: JSON.stringify(runUpdate),
|
|
117
|
+
signal: AbortSignal.timeout(this.timeout),
|
|
118
|
+
});
|
|
119
|
+
// consume the response body to release the connection
|
|
120
|
+
// https://undici.nodejs.org/#/?id=garbage-collection
|
|
121
|
+
const body = await response.text();
|
|
122
|
+
if (!response.ok) {
|
|
123
|
+
throw new Error(`Failed to update run: ${response.status} ${response.statusText} ${body}`);
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
async onLLMStart(run) {
|
|
127
|
+
await this._persistRunSingle(run);
|
|
128
|
+
}
|
|
129
|
+
async onLLMEnd(run) {
|
|
130
|
+
await this._updateRunSingle(run);
|
|
131
|
+
}
|
|
132
|
+
async onLLMError(run) {
|
|
133
|
+
await this._updateRunSingle(run);
|
|
134
|
+
}
|
|
135
|
+
async onChainStart(run) {
|
|
136
|
+
await this._persistRunSingle(run);
|
|
137
|
+
}
|
|
138
|
+
async onChainEnd(run) {
|
|
139
|
+
await this._updateRunSingle(run);
|
|
140
|
+
}
|
|
141
|
+
async onChainError(run) {
|
|
142
|
+
await this._updateRunSingle(run);
|
|
143
|
+
}
|
|
144
|
+
async onToolStart(run) {
|
|
145
|
+
await this._persistRunSingle(run);
|
|
146
|
+
}
|
|
147
|
+
async onToolEnd(run) {
|
|
148
|
+
await this._updateRunSingle(run);
|
|
149
|
+
}
|
|
150
|
+
async onToolError(run) {
|
|
151
|
+
await this._updateRunSingle(run);
|
|
152
|
+
}
|
|
101
153
|
}
|
|
102
154
|
exports.LangChainTracer = LangChainTracer;
|
|
@@ -1,9 +1,18 @@
|
|
|
1
1
|
import { AsyncCaller, AsyncCallerParams } from "../../util/async_caller.js";
|
|
2
2
|
import { BaseTracer, Run, BaseRun } from "./tracer.js";
|
|
3
|
+
import { RunOutputs } from "../../schema/index.js";
|
|
3
4
|
export interface RunCreate extends BaseRun {
|
|
5
|
+
parent_run_id?: string;
|
|
4
6
|
child_runs: this[];
|
|
5
7
|
session_name?: string;
|
|
6
8
|
}
|
|
9
|
+
export interface RunUpdate {
|
|
10
|
+
end_time?: number;
|
|
11
|
+
error?: string;
|
|
12
|
+
outputs?: RunOutputs;
|
|
13
|
+
parent_run_id?: string;
|
|
14
|
+
reference_example_id?: string;
|
|
15
|
+
}
|
|
7
16
|
export interface LangChainTracerFields {
|
|
8
17
|
exampleId?: string;
|
|
9
18
|
sessionName?: string;
|
|
@@ -20,5 +29,16 @@ export declare class LangChainTracer extends BaseTracer implements LangChainTrac
|
|
|
20
29
|
timeout: number;
|
|
21
30
|
constructor({ exampleId, sessionName, callerParams, timeout, }?: LangChainTracerFields);
|
|
22
31
|
private _convertToCreate;
|
|
23
|
-
protected persistRun(
|
|
32
|
+
protected persistRun(_run: Run): Promise<void>;
|
|
33
|
+
protected _persistRunSingle(run: Run): Promise<void>;
|
|
34
|
+
protected _updateRunSingle(run: Run): Promise<void>;
|
|
35
|
+
onLLMStart(run: Run): Promise<void>;
|
|
36
|
+
onLLMEnd(run: Run): Promise<void>;
|
|
37
|
+
onLLMError(run: Run): Promise<void>;
|
|
38
|
+
onChainStart(run: Run): Promise<void>;
|
|
39
|
+
onChainEnd(run: Run): Promise<void>;
|
|
40
|
+
onChainError(run: Run): Promise<void>;
|
|
41
|
+
onToolStart(run: Run): Promise<void>;
|
|
42
|
+
onToolEnd(run: Run): Promise<void>;
|
|
43
|
+
onToolError(run: Run): Promise<void>;
|
|
24
44
|
}
|
|
@@ -67,19 +67,22 @@ export class LangChainTracer extends BaseTracer {
|
|
|
67
67
|
start_time: run.start_time,
|
|
68
68
|
end_time: run.end_time,
|
|
69
69
|
run_type: run.run_type,
|
|
70
|
-
|
|
70
|
+
// example_id is only set for the root run
|
|
71
|
+
reference_example_id: run.parent_run_id ? undefined : example_id,
|
|
71
72
|
extra: runExtra,
|
|
73
|
+
parent_run_id: run.parent_run_id,
|
|
72
74
|
execution_order: run.execution_order,
|
|
73
75
|
serialized: run.serialized,
|
|
74
76
|
error: run.error,
|
|
75
77
|
inputs: run.inputs,
|
|
76
78
|
outputs: run.outputs ?? {},
|
|
77
79
|
session_name: this.sessionName,
|
|
78
|
-
child_runs:
|
|
80
|
+
child_runs: [],
|
|
79
81
|
};
|
|
80
82
|
return persistedRun;
|
|
81
83
|
}
|
|
82
|
-
async persistRun(
|
|
84
|
+
async persistRun(_run) { }
|
|
85
|
+
async _persistRunSingle(run) {
|
|
83
86
|
const persistedRun = await this._convertToCreate(run, this.exampleId);
|
|
84
87
|
const endpoint = `${this.endpoint}/runs`;
|
|
85
88
|
const response = await this.caller.call(fetch, endpoint, {
|
|
@@ -95,4 +98,53 @@ export class LangChainTracer extends BaseTracer {
|
|
|
95
98
|
throw new Error(`Failed to persist run: ${response.status} ${response.statusText} ${body}`);
|
|
96
99
|
}
|
|
97
100
|
}
|
|
101
|
+
async _updateRunSingle(run) {
|
|
102
|
+
const runUpdate = {
|
|
103
|
+
end_time: run.end_time,
|
|
104
|
+
error: run.error,
|
|
105
|
+
outputs: run.outputs,
|
|
106
|
+
parent_run_id: run.parent_run_id,
|
|
107
|
+
reference_example_id: run.reference_example_id,
|
|
108
|
+
};
|
|
109
|
+
const endpoint = `${this.endpoint}/runs/${run.id}`;
|
|
110
|
+
const response = await this.caller.call(fetch, endpoint, {
|
|
111
|
+
method: "PATCH",
|
|
112
|
+
headers: this.headers,
|
|
113
|
+
body: JSON.stringify(runUpdate),
|
|
114
|
+
signal: AbortSignal.timeout(this.timeout),
|
|
115
|
+
});
|
|
116
|
+
// consume the response body to release the connection
|
|
117
|
+
// https://undici.nodejs.org/#/?id=garbage-collection
|
|
118
|
+
const body = await response.text();
|
|
119
|
+
if (!response.ok) {
|
|
120
|
+
throw new Error(`Failed to update run: ${response.status} ${response.statusText} ${body}`);
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
async onLLMStart(run) {
|
|
124
|
+
await this._persistRunSingle(run);
|
|
125
|
+
}
|
|
126
|
+
async onLLMEnd(run) {
|
|
127
|
+
await this._updateRunSingle(run);
|
|
128
|
+
}
|
|
129
|
+
async onLLMError(run) {
|
|
130
|
+
await this._updateRunSingle(run);
|
|
131
|
+
}
|
|
132
|
+
async onChainStart(run) {
|
|
133
|
+
await this._persistRunSingle(run);
|
|
134
|
+
}
|
|
135
|
+
async onChainEnd(run) {
|
|
136
|
+
await this._updateRunSingle(run);
|
|
137
|
+
}
|
|
138
|
+
async onChainError(run) {
|
|
139
|
+
await this._updateRunSingle(run);
|
|
140
|
+
}
|
|
141
|
+
async onToolStart(run) {
|
|
142
|
+
await this._persistRunSingle(run);
|
|
143
|
+
}
|
|
144
|
+
async onToolEnd(run) {
|
|
145
|
+
await this._updateRunSingle(run);
|
|
146
|
+
}
|
|
147
|
+
async onToolError(run) {
|
|
148
|
+
await this._updateRunSingle(run);
|
|
149
|
+
}
|
|
98
150
|
}
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.AI21 = void 0;
|
|
4
|
+
const base_js_1 = require("./base.cjs");
|
|
5
|
+
const env_js_1 = require("../util/env.cjs");
|
|
6
|
+
class AI21 extends base_js_1.LLM {
|
|
7
|
+
constructor(fields) {
|
|
8
|
+
super(fields ?? {});
|
|
9
|
+
Object.defineProperty(this, "model", {
|
|
10
|
+
enumerable: true,
|
|
11
|
+
configurable: true,
|
|
12
|
+
writable: true,
|
|
13
|
+
value: "j2-jumbo-instruct"
|
|
14
|
+
});
|
|
15
|
+
Object.defineProperty(this, "temperature", {
|
|
16
|
+
enumerable: true,
|
|
17
|
+
configurable: true,
|
|
18
|
+
writable: true,
|
|
19
|
+
value: 0.7
|
|
20
|
+
});
|
|
21
|
+
Object.defineProperty(this, "maxTokens", {
|
|
22
|
+
enumerable: true,
|
|
23
|
+
configurable: true,
|
|
24
|
+
writable: true,
|
|
25
|
+
value: 1024
|
|
26
|
+
});
|
|
27
|
+
Object.defineProperty(this, "minTokens", {
|
|
28
|
+
enumerable: true,
|
|
29
|
+
configurable: true,
|
|
30
|
+
writable: true,
|
|
31
|
+
value: 0
|
|
32
|
+
});
|
|
33
|
+
Object.defineProperty(this, "topP", {
|
|
34
|
+
enumerable: true,
|
|
35
|
+
configurable: true,
|
|
36
|
+
writable: true,
|
|
37
|
+
value: 1
|
|
38
|
+
});
|
|
39
|
+
Object.defineProperty(this, "presencePenalty", {
|
|
40
|
+
enumerable: true,
|
|
41
|
+
configurable: true,
|
|
42
|
+
writable: true,
|
|
43
|
+
value: AI21.getDefaultAI21PenaltyData()
|
|
44
|
+
});
|
|
45
|
+
Object.defineProperty(this, "countPenalty", {
|
|
46
|
+
enumerable: true,
|
|
47
|
+
configurable: true,
|
|
48
|
+
writable: true,
|
|
49
|
+
value: AI21.getDefaultAI21PenaltyData()
|
|
50
|
+
});
|
|
51
|
+
Object.defineProperty(this, "frequencyPenalty", {
|
|
52
|
+
enumerable: true,
|
|
53
|
+
configurable: true,
|
|
54
|
+
writable: true,
|
|
55
|
+
value: AI21.getDefaultAI21PenaltyData()
|
|
56
|
+
});
|
|
57
|
+
Object.defineProperty(this, "numResults", {
|
|
58
|
+
enumerable: true,
|
|
59
|
+
configurable: true,
|
|
60
|
+
writable: true,
|
|
61
|
+
value: 1
|
|
62
|
+
});
|
|
63
|
+
Object.defineProperty(this, "logitBias", {
|
|
64
|
+
enumerable: true,
|
|
65
|
+
configurable: true,
|
|
66
|
+
writable: true,
|
|
67
|
+
value: void 0
|
|
68
|
+
});
|
|
69
|
+
Object.defineProperty(this, "ai21ApiKey", {
|
|
70
|
+
enumerable: true,
|
|
71
|
+
configurable: true,
|
|
72
|
+
writable: true,
|
|
73
|
+
value: void 0
|
|
74
|
+
});
|
|
75
|
+
Object.defineProperty(this, "stop", {
|
|
76
|
+
enumerable: true,
|
|
77
|
+
configurable: true,
|
|
78
|
+
writable: true,
|
|
79
|
+
value: void 0
|
|
80
|
+
});
|
|
81
|
+
Object.defineProperty(this, "baseUrl", {
|
|
82
|
+
enumerable: true,
|
|
83
|
+
configurable: true,
|
|
84
|
+
writable: true,
|
|
85
|
+
value: void 0
|
|
86
|
+
});
|
|
87
|
+
this.model = fields?.model ?? this.model;
|
|
88
|
+
this.temperature = fields?.temperature ?? this.temperature;
|
|
89
|
+
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
|
|
90
|
+
this.minTokens = fields?.minTokens ?? this.minTokens;
|
|
91
|
+
this.topP = fields?.topP ?? this.topP;
|
|
92
|
+
this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
|
|
93
|
+
this.countPenalty = fields?.countPenalty ?? this.countPenalty;
|
|
94
|
+
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
|
|
95
|
+
this.numResults = fields?.numResults ?? this.numResults;
|
|
96
|
+
this.logitBias = fields?.logitBias;
|
|
97
|
+
this.ai21ApiKey =
|
|
98
|
+
fields?.ai21ApiKey ?? (0, env_js_1.getEnvironmentVariable)("AI21_API_KEY");
|
|
99
|
+
this.stop = fields?.stop;
|
|
100
|
+
this.baseUrl = fields?.baseUrl;
|
|
101
|
+
}
|
|
102
|
+
validateEnvironment() {
|
|
103
|
+
if (!this.ai21ApiKey) {
|
|
104
|
+
throw new Error(`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`);
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
static getDefaultAI21PenaltyData() {
|
|
108
|
+
return {
|
|
109
|
+
scale: 0,
|
|
110
|
+
applyToWhitespaces: true,
|
|
111
|
+
applyToPunctuations: true,
|
|
112
|
+
applyToNumbers: true,
|
|
113
|
+
applyToStopwords: true,
|
|
114
|
+
applyToEmojis: true,
|
|
115
|
+
};
|
|
116
|
+
}
|
|
117
|
+
/** Get the type of LLM. */
|
|
118
|
+
_llmType() {
|
|
119
|
+
return "ai21";
|
|
120
|
+
}
|
|
121
|
+
/** Get the default parameters for calling AI21 API. */
|
|
122
|
+
get defaultParams() {
|
|
123
|
+
return {
|
|
124
|
+
temperature: this.temperature,
|
|
125
|
+
maxTokens: this.maxTokens,
|
|
126
|
+
minTokens: this.minTokens,
|
|
127
|
+
topP: this.topP,
|
|
128
|
+
presencePenalty: this.presencePenalty,
|
|
129
|
+
countPenalty: this.countPenalty,
|
|
130
|
+
frequencyPenalty: this.frequencyPenalty,
|
|
131
|
+
numResults: this.numResults,
|
|
132
|
+
logitBias: this.logitBias,
|
|
133
|
+
};
|
|
134
|
+
}
|
|
135
|
+
/** Get the identifying parameters for this LLM. */
|
|
136
|
+
get identifyingParams() {
|
|
137
|
+
return { ...this.defaultParams, model: this.model };
|
|
138
|
+
}
|
|
139
|
+
/** Call out to AI21's complete endpoint.
|
|
140
|
+
Args:
|
|
141
|
+
prompt: The prompt to pass into the model.
|
|
142
|
+
stop: Optional list of stop words to use when generating.
|
|
143
|
+
|
|
144
|
+
Returns:
|
|
145
|
+
The string generated by the model.
|
|
146
|
+
|
|
147
|
+
Example:
|
|
148
|
+
let response = ai21._call("Tell me a joke.");
|
|
149
|
+
*/
|
|
150
|
+
async _call(prompt, options) {
|
|
151
|
+
let stop = options?.stop;
|
|
152
|
+
this.validateEnvironment();
|
|
153
|
+
if (this.stop && stop && this.stop.length > 0 && stop.length > 0) {
|
|
154
|
+
throw new Error("`stop` found in both the input and default params.");
|
|
155
|
+
}
|
|
156
|
+
stop = this.stop ?? stop ?? [];
|
|
157
|
+
const baseUrl = this.baseUrl ?? this.model === "j1-grande-instruct"
|
|
158
|
+
? "https://api.ai21.com/studio/v1/experimental"
|
|
159
|
+
: "https://api.ai21.com/studio/v1";
|
|
160
|
+
const url = `${baseUrl}/${this.model}/complete`;
|
|
161
|
+
const headers = {
|
|
162
|
+
Authorization: `Bearer ${this.ai21ApiKey}`,
|
|
163
|
+
"Content-Type": "application/json",
|
|
164
|
+
};
|
|
165
|
+
const data = { prompt, stopSequences: stop, ...this.defaultParams };
|
|
166
|
+
const responseData = await this.caller.callWithOptions({}, async () => {
|
|
167
|
+
const response = await fetch(url, {
|
|
168
|
+
method: "POST",
|
|
169
|
+
headers,
|
|
170
|
+
body: JSON.stringify(data),
|
|
171
|
+
signal: options.signal,
|
|
172
|
+
});
|
|
173
|
+
if (!response.ok) {
|
|
174
|
+
const error = new Error(`AI21 call failed with status code ${response.status}`);
|
|
175
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
176
|
+
error.response = response;
|
|
177
|
+
throw error;
|
|
178
|
+
}
|
|
179
|
+
return response.json();
|
|
180
|
+
});
|
|
181
|
+
if (!responseData.completions ||
|
|
182
|
+
responseData.completions.length === 0 ||
|
|
183
|
+
!responseData.completions[0].data) {
|
|
184
|
+
throw new Error("No completions found in response");
|
|
185
|
+
}
|
|
186
|
+
return responseData.completions[0].data.text ?? "";
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
exports.AI21 = AI21;
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import { LLM, BaseLLMParams } from "./base.js";
|
|
2
|
+
export type AI21PenaltyData = {
|
|
3
|
+
scale: number;
|
|
4
|
+
applyToWhitespaces: boolean;
|
|
5
|
+
applyToPunctuations: boolean;
|
|
6
|
+
applyToNumbers: boolean;
|
|
7
|
+
applyToStopwords: boolean;
|
|
8
|
+
applyToEmojis: boolean;
|
|
9
|
+
};
|
|
10
|
+
export interface AI21Input extends BaseLLMParams {
|
|
11
|
+
ai21ApiKey?: string;
|
|
12
|
+
model?: string;
|
|
13
|
+
temperature?: number;
|
|
14
|
+
minTokens?: number;
|
|
15
|
+
maxTokens?: number;
|
|
16
|
+
topP?: number;
|
|
17
|
+
presencePenalty?: AI21PenaltyData;
|
|
18
|
+
countPenalty?: AI21PenaltyData;
|
|
19
|
+
frequencyPenalty?: AI21PenaltyData;
|
|
20
|
+
numResults?: number;
|
|
21
|
+
logitBias?: Record<string, number>;
|
|
22
|
+
stop?: string[];
|
|
23
|
+
baseUrl?: string;
|
|
24
|
+
}
|
|
25
|
+
export declare class AI21 extends LLM implements AI21Input {
|
|
26
|
+
model: string;
|
|
27
|
+
temperature: number;
|
|
28
|
+
maxTokens: number;
|
|
29
|
+
minTokens: number;
|
|
30
|
+
topP: number;
|
|
31
|
+
presencePenalty: AI21PenaltyData;
|
|
32
|
+
countPenalty: AI21PenaltyData;
|
|
33
|
+
frequencyPenalty: AI21PenaltyData;
|
|
34
|
+
numResults: number;
|
|
35
|
+
logitBias?: Record<string, number>;
|
|
36
|
+
ai21ApiKey?: string;
|
|
37
|
+
stop?: string[];
|
|
38
|
+
baseUrl?: string;
|
|
39
|
+
constructor(fields?: AI21Input);
|
|
40
|
+
validateEnvironment(): void;
|
|
41
|
+
static getDefaultAI21PenaltyData(): AI21PenaltyData;
|
|
42
|
+
/** Get the type of LLM. */
|
|
43
|
+
_llmType(): string;
|
|
44
|
+
/** Get the default parameters for calling AI21 API. */
|
|
45
|
+
get defaultParams(): {
|
|
46
|
+
temperature: number;
|
|
47
|
+
maxTokens: number;
|
|
48
|
+
minTokens: number;
|
|
49
|
+
topP: number;
|
|
50
|
+
presencePenalty: AI21PenaltyData;
|
|
51
|
+
countPenalty: AI21PenaltyData;
|
|
52
|
+
frequencyPenalty: AI21PenaltyData;
|
|
53
|
+
numResults: number;
|
|
54
|
+
logitBias: Record<string, number> | undefined;
|
|
55
|
+
};
|
|
56
|
+
/** Get the identifying parameters for this LLM. */
|
|
57
|
+
get identifyingParams(): {
|
|
58
|
+
model: string;
|
|
59
|
+
temperature: number;
|
|
60
|
+
maxTokens: number;
|
|
61
|
+
minTokens: number;
|
|
62
|
+
topP: number;
|
|
63
|
+
presencePenalty: AI21PenaltyData;
|
|
64
|
+
countPenalty: AI21PenaltyData;
|
|
65
|
+
frequencyPenalty: AI21PenaltyData;
|
|
66
|
+
numResults: number;
|
|
67
|
+
logitBias: Record<string, number> | undefined;
|
|
68
|
+
};
|
|
69
|
+
/** Call out to AI21's complete endpoint.
|
|
70
|
+
Args:
|
|
71
|
+
prompt: The prompt to pass into the model.
|
|
72
|
+
stop: Optional list of stop words to use when generating.
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
The string generated by the model.
|
|
76
|
+
|
|
77
|
+
Example:
|
|
78
|
+
let response = ai21._call("Tell me a joke.");
|
|
79
|
+
*/
|
|
80
|
+
_call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
|
|
81
|
+
}
|