@nahisaho/katashiro-llm 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/dist/LLMClient.d.ts +64 -0
- package/dist/LLMClient.d.ts.map +1 -0
- package/dist/LLMClient.js +139 -0
- package/dist/LLMClient.js.map +1 -0
- package/dist/PromptManager.d.ts +66 -0
- package/dist/PromptManager.d.ts.map +1 -0
- package/dist/PromptManager.js +121 -0
- package/dist/PromptManager.js.map +1 -0
- package/dist/TokenCounter.d.ts +43 -0
- package/dist/TokenCounter.d.ts.map +1 -0
- package/dist/TokenCounter.js +100 -0
- package/dist/TokenCounter.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +17 -0
- package/dist/index.js.map +1 -0
- package/dist/providers/AzureOpenAILLMProvider.d.ts +82 -0
- package/dist/providers/AzureOpenAILLMProvider.d.ts.map +1 -0
- package/dist/providers/AzureOpenAILLMProvider.js +339 -0
- package/dist/providers/AzureOpenAILLMProvider.js.map +1 -0
- package/dist/providers/BaseLLMProvider.d.ts +51 -0
- package/dist/providers/BaseLLMProvider.d.ts.map +1 -0
- package/dist/providers/BaseLLMProvider.js +72 -0
- package/dist/providers/BaseLLMProvider.js.map +1 -0
- package/dist/providers/LLMFactory.d.ts +75 -0
- package/dist/providers/LLMFactory.d.ts.map +1 -0
- package/dist/providers/LLMFactory.js +149 -0
- package/dist/providers/LLMFactory.js.map +1 -0
- package/dist/providers/MockLLMProvider.d.ts +57 -0
- package/dist/providers/MockLLMProvider.d.ts.map +1 -0
- package/dist/providers/MockLLMProvider.js +120 -0
- package/dist/providers/MockLLMProvider.js.map +1 -0
- package/dist/providers/OllamaLLMProvider.d.ts +73 -0
- package/dist/providers/OllamaLLMProvider.d.ts.map +1 -0
- package/dist/providers/OllamaLLMProvider.js +242 -0
- package/dist/providers/OllamaLLMProvider.js.map +1 -0
- package/dist/providers/OpenAILLMProvider.d.ts +87 -0
- package/dist/providers/OpenAILLMProvider.d.ts.map +1 -0
- package/dist/providers/OpenAILLMProvider.js +349 -0
- package/dist/providers/OpenAILLMProvider.js.map +1 -0
- package/dist/providers/index.d.ts +17 -0
- package/dist/providers/index.d.ts.map +1 -0
- package/dist/providers/index.js +19 -0
- package/dist/providers/index.js.map +1 -0
- package/dist/types.d.ts +251 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +8 -0
- package/dist/types.js.map +1 -0
- package/package.json +51 -0
- package/src/LLMClient.ts +171 -0
- package/src/PromptManager.ts +156 -0
- package/src/TokenCounter.ts +114 -0
- package/src/index.ts +35 -0
- package/src/providers/AzureOpenAILLMProvider.ts +494 -0
- package/src/providers/BaseLLMProvider.ts +110 -0
- package/src/providers/LLMFactory.ts +216 -0
- package/src/providers/MockLLMProvider.ts +173 -0
- package/src/providers/OllamaLLMProvider.ts +322 -0
- package/src/providers/OpenAILLMProvider.ts +500 -0
- package/src/providers/index.ts +35 -0
- package/src/types.ts +268 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Token Counter - トークン数カウント
|
|
3
|
+
*
|
|
4
|
+
* @requirement REQ-LLM-001
|
|
5
|
+
* @design DES-KATASHIRO-003-LLM §3.1
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
/**
|
|
9
|
+
* モデル別トークン推定係数
|
|
10
|
+
*/
|
|
11
|
+
const MODEL_TOKEN_FACTORS: Record<string, number> = {
|
|
12
|
+
'gpt-4o': 4.0,
|
|
13
|
+
'gpt-4o-mini': 4.0,
|
|
14
|
+
'gpt-4-turbo': 4.0,
|
|
15
|
+
'gpt-4': 4.0,
|
|
16
|
+
'gpt-3.5-turbo': 4.0,
|
|
17
|
+
'claude-3-5-sonnet': 4.5,
|
|
18
|
+
'claude-3-opus': 4.5,
|
|
19
|
+
'claude-3-sonnet': 4.5,
|
|
20
|
+
'claude-3-haiku': 4.5,
|
|
21
|
+
'gemini-pro': 4.0,
|
|
22
|
+
'gemini-1.5-pro': 4.0,
|
|
23
|
+
default: 4.0,
|
|
24
|
+
};
|
|
25
|
+
|
|
26
|
+
/**
|
|
27
|
+
* TokenCounter - トークン数カウントユーティリティ
|
|
28
|
+
*/
|
|
29
|
+
export class TokenCounter {
|
|
30
|
+
/**
|
|
31
|
+
* テキストのトークン数を推定
|
|
32
|
+
*/
|
|
33
|
+
estimate(text: string, model?: string): number {
|
|
34
|
+
const factor = this.getTokenFactor(model);
|
|
35
|
+
return Math.ceil(text.length / factor);
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
/**
|
|
39
|
+
* メッセージ配列のトークン数を推定
|
|
40
|
+
*/
|
|
41
|
+
estimateMessages(
|
|
42
|
+
messages: Array<{ role: string; content: string }>,
|
|
43
|
+
model?: string
|
|
44
|
+
): number {
|
|
45
|
+
let total = 0;
|
|
46
|
+
for (const message of messages) {
|
|
47
|
+
// ロールのオーバーヘッド(約4トークン)
|
|
48
|
+
total += 4;
|
|
49
|
+
total += this.estimate(message.content, model);
|
|
50
|
+
}
|
|
51
|
+
// メッセージ区切りのオーバーヘッド
|
|
52
|
+
total += 3;
|
|
53
|
+
return total;
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/**
|
|
57
|
+
* トークン数が制限内か確認
|
|
58
|
+
*/
|
|
59
|
+
isWithinLimit(text: string, limit: number, model?: string): boolean {
|
|
60
|
+
return this.estimate(text, model) <= limit;
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* 制限内に収まるようにテキストを切り詰め
|
|
65
|
+
*/
|
|
66
|
+
truncateToLimit(text: string, limit: number, model?: string): string {
|
|
67
|
+
const currentTokens = this.estimate(text, model);
|
|
68
|
+
if (currentTokens <= limit) {
|
|
69
|
+
return text;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
const factor = this.getTokenFactor(model);
|
|
73
|
+
const targetLength = Math.floor(limit * factor * 0.9); // 10%マージン
|
|
74
|
+
return text.slice(0, targetLength) + '...';
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/**
|
|
78
|
+
* モデル別トークン係数取得
|
|
79
|
+
*/
|
|
80
|
+
private getTokenFactor(model?: string): number {
|
|
81
|
+
if (!model) {
|
|
82
|
+
return MODEL_TOKEN_FACTORS.default ?? 4.0;
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
// モデル名の部分一致
|
|
86
|
+
for (const [key, factor] of Object.entries(MODEL_TOKEN_FACTORS)) {
|
|
87
|
+
if (model.includes(key)) {
|
|
88
|
+
return factor;
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
return MODEL_TOKEN_FACTORS.default ?? 4.0;
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// シングルトン
|
|
97
|
+
let tokenCounterInstance: TokenCounter | null = null;
|
|
98
|
+
|
|
99
|
+
/**
|
|
100
|
+
* TokenCounter シングルトン取得
|
|
101
|
+
*/
|
|
102
|
+
export function getTokenCounter(): TokenCounter {
|
|
103
|
+
if (!tokenCounterInstance) {
|
|
104
|
+
tokenCounterInstance = new TokenCounter();
|
|
105
|
+
}
|
|
106
|
+
return tokenCounterInstance;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
/**
|
|
110
|
+
* TokenCounter リセット(テスト用)
|
|
111
|
+
*/
|
|
112
|
+
export function resetTokenCounter(): void {
|
|
113
|
+
tokenCounterInstance = null;
|
|
114
|
+
}
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/**
 * LLM Package - Main Entry Point
 *
 * Re-exports the package's public API: shared types, provider
 * implementations, the high-level client, prompt templating, and
 * token-count estimation.
 *
 * @requirement REQ-LLM-001〜006
 * @design DES-KATASHIRO-003-LLM
 */

// Shared request/response/provider types.
export * from './types.js';

// Provider implementations and the provider factory.
export * from './providers/index.js';

// High-level client (class plus singleton accessors).
export {
  LLMClient,
  getLLMClient,
  initLLMClient,
  resetLLMClient,
} from './LLMClient.js';

// Prompt template management (class plus singleton accessors).
export {
  PromptManager,
  getPromptManager,
  resetPromptManager,
  type TemplateVariables,
} from './PromptManager.js';

// Token-count estimation (class plus singleton accessors).
export {
  TokenCounter,
  getTokenCounter,
  resetTokenCounter,
} from './TokenCounter.js';
|
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Azure OpenAI LLM Provider
|
|
3
|
+
*
|
|
4
|
+
* Azure OpenAI Service provider for chat completions
|
|
5
|
+
*
|
|
6
|
+
* @requirement REQ-LLM-001
|
|
7
|
+
* @design DES-KATASHIRO-003-LLM
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import type { z, ZodType } from 'zod';
|
|
11
|
+
import type {
|
|
12
|
+
ProviderConfig,
|
|
13
|
+
GenerateRequest,
|
|
14
|
+
GenerateResponse,
|
|
15
|
+
StreamChunk,
|
|
16
|
+
Message,
|
|
17
|
+
TokenUsage,
|
|
18
|
+
ToolCall,
|
|
19
|
+
ToolDefinition,
|
|
20
|
+
FinishReason,
|
|
21
|
+
} from '../types.js';
|
|
22
|
+
import { BaseLLMProvider } from './BaseLLMProvider.js';
|
|
23
|
+
|
|
24
|
+
/**
 * Azure OpenAI provider configuration.
 *
 * Each field falls back to an environment variable or default when
 * omitted (see the AzureOpenAILLMProvider constructor).
 */
export interface AzureOpenAIProviderConfig extends ProviderConfig {
  /** Azure OpenAI resource endpoint (e.g. https://your-resource.openai.azure.com); falls back to AZURE_OPENAI_ENDPOINT. */
  endpoint?: string;
  /** API key; falls back to AZURE_OPENAI_API_KEY. */
  apiKey?: string;
  /** Deployment name used in the request path; falls back to AZURE_OPENAI_DEPLOYMENT, then defaultModel. */
  deploymentName?: string;
  /** API version query parameter; defaults to '2024-02-15-preview'. */
  apiVersion?: string;
}
|
|
37
|
+
|
|
38
|
+
/**
 * Azure OpenAI chat message wire format (snake_case, matching the REST API).
 */
interface AzureMessage {
  role: 'system' | 'user' | 'assistant' | 'tool';
  /** Text content; may be null (e.g. when the assistant returns only tool calls). */
  content: string | null;
  /** Optional participant name. */
  name?: string;
  /** For role 'tool': id of the tool call this message answers. */
  tool_call_id?: string;
  /** Tool invocations requested by the assistant. */
  tool_calls?: Array<{
    id: string;
    type: 'function';
    function: {
      name: string;
      /** JSON-encoded argument string. */
      arguments: string;
    };
  }>;
}
|
|
55
|
+
|
|
56
|
+
/**
 * Azure chat-completion response body (non-streaming).
 */
interface AzureChatResponse {
  id: string;
  object: 'chat.completion';
  created: number;
  model: string;
  /** Candidate completions; this provider reads choices[0]. */
  choices: Array<{
    index: number;
    message: AzureMessage;
    finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter';
  }>;
  /** Token accounting for the request/response pair. */
  usage: {
    prompt_tokens: number;
    completion_tokens: number;
    total_tokens: number;
  };
}
|
|
75
|
+
|
|
76
|
+
/**
 * Azure streaming (SSE) chunk. Deltas arrive incrementally: tool-call
 * fields may be split across chunks and keyed by `index`.
 */
interface AzureStreamChunk {
  id: string;
  object: 'chat.completion.chunk';
  created: number;
  model: string;
  choices: Array<{
    index: number;
    /** Incremental payload; every field is optional per chunk. */
    delta: {
      role?: string;
      content?: string;
      tool_calls?: Array<{
        /** Position of the tool call being accumulated. */
        index: number;
        id?: string;
        type?: 'function';
        function?: {
          name?: string;
          /** Argument JSON fragment; concatenated across chunks. */
          arguments?: string;
        };
      }>;
    };
    /** Set on the final chunk of a choice; null before that. */
    finish_reason: string | null;
  }>;
  /** Token accounting; only present on some chunks. */
  usage?: {
    prompt_tokens: number;
    completion_tokens: number;
    total_tokens: number;
  };
}
|
|
107
|
+
|
|
108
|
+
/**
 * Azure OpenAI LLM provider.
 *
 * Text generation via the Azure OpenAI Service chat-completions REST API
 * (deployment-scoped URL path, `api-key` header authentication).
 *
 * @example
 * ```typescript
 * const provider = new AzureOpenAILLMProvider({
 *   endpoint: 'https://your-resource.openai.azure.com',
 *   apiKey: process.env.AZURE_OPENAI_API_KEY,
 *   deploymentName: 'gpt-4o',
 *   apiVersion: '2024-02-15-preview',
 * });
 *
 * const response = await provider.generate({
 *   messages: [{ role: 'user', content: 'Hello!' }],
 * });
 * ```
 */
export class AzureOpenAILLMProvider extends BaseLLMProvider {
  readonly name = 'azure-openai';
  // NOTE: Azure deployment ids use 'gpt-35-turbo' (no dot), unlike the
  // 'gpt-3.5-turbo' id used by openai.com.
  readonly supportedModels = [
    'gpt-4o',
    'gpt-4o-mini',
    'gpt-4-turbo',
    'gpt-4',
    'gpt-35-turbo',
  ];

  private readonly endpoint: string;
  private readonly apiKey: string;
  private readonly deploymentName: string;
  private readonly apiVersion: string;

  /**
   * @param config - Provider configuration; endpoint / apiKey /
   *   deploymentName fall back to the AZURE_OPENAI_* environment variables.
   * @throws Error when endpoint, API key, or deployment name cannot be
   *   resolved from config or environment.
   */
  constructor(config: AzureOpenAIProviderConfig = {}) {
    super(config);

    this.endpoint = config.endpoint ?? process.env.AZURE_OPENAI_ENDPOINT ?? '';
    this.apiKey = config.apiKey ?? process.env.AZURE_OPENAI_API_KEY ?? '';
    this.deploymentName =
      config.deploymentName ??
      process.env.AZURE_OPENAI_DEPLOYMENT ??
      config.defaultModel ??
      '';
    this.apiVersion = config.apiVersion ?? '2024-02-15-preview';

    // Validate eagerly so misconfiguration fails at construction time.
    if (!this.endpoint) {
      throw new Error(
        'Azure OpenAI endpoint is required. Set AZURE_OPENAI_ENDPOINT or provide endpoint in config.'
      );
    }
    if (!this.apiKey) {
      throw new Error(
        'Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY or provide apiKey in config.'
      );
    }
    if (!this.deploymentName) {
      throw new Error(
        'Azure OpenAI deployment name is required. Set AZURE_OPENAI_DEPLOYMENT or provide deploymentName in config.'
      );
    }
  }

  /** On Azure the deployment name doubles as the default model id. */
  protected getDefaultModel(): string {
    return this.deploymentName;
  }

  /**
   * Converts internal Message objects into the Azure wire format
   * (snake_case keys; non-string content is JSON-stringified).
   */
  private convertMessages(messages: Message[]): AzureMessage[] {
    return messages.map((msg) => {
      const converted: AzureMessage = {
        role: msg.role,
        content:
          typeof msg.content === 'string'
            ? msg.content
            : JSON.stringify(msg.content),
      };

      // Only attach optional fields that are actually present.
      if (msg.name) converted.name = msg.name;
      if (msg.toolCallId) converted.tool_call_id = msg.toolCallId;
      if (msg.toolCalls) {
        converted.tool_calls = msg.toolCalls.map((tc) => ({
          id: tc.id,
          type: tc.type,
          function: tc.function,
        }));
      }

      return converted;
    });
  }

  /** Converts tool definitions to the API's `tools` array shape. */
  private convertTools(tools?: ToolDefinition[]): unknown[] | undefined {
    if (!tools) return undefined;

    return tools.map((tool) => ({
      type: tool.type,
      function: {
        name: tool.function.name,
        description: tool.function.description,
        parameters: tool.function.parameters,
      },
    }));
  }

  /**
   * Maps the API's finish_reason onto the internal FinishReason union.
   * Unknown or null reasons are reported as 'stop'.
   */
  private convertFinishReason(reason: string | null): FinishReason {
    switch (reason) {
      case 'stop':
        return 'stop';
      case 'length':
        return 'length';
      case 'tool_calls':
        return 'tool_calls';
      case 'content_filter':
        return 'content_filter';
      default:
        return 'stop';
    }
  }

  /**
   * Non-streaming text generation.
   *
   * Sends one chat-completions request and maps the first choice to a
   * GenerateResponse. The request is aborted after config.timeout ms
   * (default 30000).
   *
   * @throws Error on non-2xx HTTP status or an empty choices array.
   */
  async generate(request: GenerateRequest): Promise<GenerateResponse> {
    // Normalize any trailing slash so the path join below is stable.
    const baseUrl = this.endpoint.endsWith('/')
      ? this.endpoint.slice(0, -1)
      : this.endpoint;

    const url = `${baseUrl}/openai/deployments/${this.deploymentName}/chat/completions?api-version=${this.apiVersion}`;

    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
      'api-key': this.apiKey,
    };

    // Undefined optional fields are dropped by JSON.stringify.
    const body: Record<string, unknown> = {
      messages: this.convertMessages(request.messages),
      temperature: request.temperature,
      max_tokens: request.maxTokens,
      top_p: request.topP,
      stop: request.stopSequences,
      user: request.user,
    };

    if (request.tools) {
      body.tools = this.convertTools(request.tools);
      body.tool_choice = request.toolChoice;
    }

    if (request.responseFormat) {
      body.response_format = {
        type: request.responseFormat.type,
      };
    }

    // Abort the request after the configured timeout (default 30s).
    const controller = new AbortController();
    const timeoutId = setTimeout(
      () => controller.abort(),
      this.config.timeout ?? 30000
    );

    try {
      const response = await fetch(url, {
        method: 'POST',
        headers,
        body: JSON.stringify(body),
        signal: controller.signal,
      });

      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(
          `Azure OpenAI API error: ${response.status} - ${errorText}`
        );
      }

      const data = (await response.json()) as AzureChatResponse;
      const choice = data.choices[0];

      if (!choice) {
        throw new Error('No choices returned from Azure OpenAI API');
      }

      const toolCalls: ToolCall[] | undefined = choice.message.tool_calls?.map(
        (tc) => ({
          id: tc.id,
          type: tc.type,
          function: tc.function,
        })
      );

      const usage: TokenUsage = {
        promptTokens: data.usage.prompt_tokens,
        completionTokens: data.usage.completion_tokens,
        totalTokens: data.usage.total_tokens,
      };

      return {
        id: data.id,
        model: data.model,
        content: choice.message.content ?? '',
        toolCalls,
        usage,
        finishReason: this.convertFinishReason(choice.finish_reason),
      };
    } finally {
      // Always clear the timer, on success and on error/abort alike.
      clearTimeout(timeoutId);
    }
  }

  /**
   * Streaming generation over server-sent events.
   *
   * Yields 'content' chunks as text deltas arrive, accumulates partial
   * tool calls by index and yields them as 'tool_call' chunks when a
   * choice finishes with 'tool_calls', yields 'usage' when the API
   * reports it, and 'done' on the [DONE] sentinel.
   *
   * NOTE(review): unlike generate(), this path sets no AbortController
   * timeout and ignores request.responseFormat — confirm whether that is
   * intentional.
   *
   * @throws Error on non-2xx HTTP status or a missing response body.
   */
  async *generateStream(request: GenerateRequest): AsyncGenerator<StreamChunk> {
    // Normalize any trailing slash so the path join below is stable.
    const baseUrl = this.endpoint.endsWith('/')
      ? this.endpoint.slice(0, -1)
      : this.endpoint;

    const url = `${baseUrl}/openai/deployments/${this.deploymentName}/chat/completions?api-version=${this.apiVersion}`;

    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
      'api-key': this.apiKey,
    };

    const body: Record<string, unknown> = {
      messages: this.convertMessages(request.messages),
      temperature: request.temperature,
      max_tokens: request.maxTokens,
      top_p: request.topP,
      stop: request.stopSequences,
      user: request.user,
      stream: true,
    };

    if (request.tools) {
      body.tools = this.convertTools(request.tools);
      body.tool_choice = request.toolChoice;
    }

    const response = await fetch(url, {
      method: 'POST',
      headers,
      body: JSON.stringify(body),
    });

    if (!response.ok || !response.body) {
      const errorText = await response.text();
      throw new Error(
        `Azure OpenAI API error: ${response.status} - ${errorText}`
      );
    }

    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    // Carries a partial SSE line across network chunk boundaries.
    let buffer = '';

    // Accumulates tool-call fragments keyed by their choice-local index.
    const toolCallsMap = new Map<
      number,
      { id: string; name: string; arguments: string }
    >();

    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        // stream:true keeps multi-byte sequences split across reads intact.
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        // Last element may be an incomplete line; keep it for the next read.
        buffer = lines.pop() ?? '';

        for (const line of lines) {
          if (!line.startsWith('data: ')) continue;
          const data = line.slice(6).trim();
          if (data === '[DONE]') {
            yield { type: 'done' };
            continue;
          }

          try {
            const chunk = JSON.parse(data) as AzureStreamChunk;
            const delta = chunk.choices[0]?.delta;

            if (delta?.content) {
              yield {
                type: 'content',
                content: delta.content,
              };
            }

            if (delta?.tool_calls) {
              // Merge this fragment into the per-index accumulator:
              // id/name arrive once, arguments arrive as concatenable
              // JSON fragments.
              for (const tc of delta.tool_calls) {
                const existing = toolCallsMap.get(tc.index) ?? {
                  id: '',
                  name: '',
                  arguments: '',
                };
                if (tc.id) existing.id = tc.id;
                if (tc.function?.name) existing.name = tc.function.name;
                if (tc.function?.arguments) {
                  existing.arguments += tc.function.arguments;
                }
                toolCallsMap.set(tc.index, existing);
              }
            }

            if (chunk.usage) {
              yield {
                type: 'usage',
                usage: {
                  promptTokens: chunk.usage.prompt_tokens,
                  completionTokens: chunk.usage.completion_tokens,
                  totalTokens: chunk.usage.total_tokens,
                },
              };
            }

            // When the choice finishes with tool calls, emit the fully
            // accumulated calls.
            const finishReason = chunk.choices[0]?.finish_reason;
            if (finishReason === 'tool_calls') {
              for (const [, tc] of toolCallsMap) {
                yield {
                  type: 'tool_call',
                  toolCall: {
                    id: tc.id,
                    type: 'function',
                    function: {
                      name: tc.name,
                      arguments: tc.arguments,
                    },
                  },
                };
              }
            }
          } catch {
            // Malformed SSE JSON is silently skipped (deliberate
            // best-effort; NOTE(review): consider logging).
          }
        }
      }
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Structured output generation.
   *
   * Appends a user message instructing the model to answer with JSON
   * matching the given schema, requests json_object response format,
   * then parses and validates the reply with the Zod schema.
   *
   * @throws SyntaxError when the reply is not valid JSON; ZodError when
   *   it does not satisfy the schema.
   */
  override async generateStructured<T extends ZodType>(
    request: GenerateRequest,
    schema: T
  ): Promise<z.infer<T>> {
    const jsonSchema = this.zodToJsonSchema(schema);

    const enhancedRequest: GenerateRequest = {
      ...request,
      responseFormat: { type: 'json_object' },
      messages: [
        ...request.messages,
        {
          role: 'user',
          content: `Respond with valid JSON matching this schema:\n${JSON.stringify(jsonSchema, null, 2)}`,
        },
      ],
    };

    const response = await this.generate(enhancedRequest);
    const parsed = JSON.parse(response.content);
    return schema.parse(parsed);
  }

  /**
   * Approximate token count: ASCII characters at ~4 chars/token,
   * non-ASCII characters at ~2 chars/token. Heuristic only — not a real
   * tokenizer.
   */
  override async countTokens(text: string, _model?: string): Promise<number> {
    const englishChars = text.replace(/[^\x00-\x7F]/g, '').length;
    const nonEnglishChars = text.length - englishChars;
    return Math.ceil(englishChars / 4 + nonEnglishChars / 2);
  }
}
|