@botpress/zai 2.1.20 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +696 -0
- package/README.md +79 -2
- package/dist/index.d.ts +85 -14
- package/dist/index.js +3 -0
- package/dist/operations/group.js +369 -0
- package/dist/operations/rate.js +350 -0
- package/dist/operations/sort.js +450 -0
- package/e2e/data/cache.jsonl +289 -0
- package/package.json +1 -1
- package/src/index.ts +3 -0
- package/src/operations/group.ts +543 -0
- package/src/operations/rate.ts +518 -0
- package/src/operations/sort.ts +618 -0
|
@@ -0,0 +1,518 @@
|
|
|
1
|
+
// eslint-disable consistent-type-definitions
|
|
2
|
+
import { z } from '@bpinternal/zui'
|
|
3
|
+
import pLimit from 'p-limit'
|
|
4
|
+
import { ZaiContext } from '../context'
|
|
5
|
+
import { Response } from '../response'
|
|
6
|
+
import { getTokenizer } from '../tokenizer'
|
|
7
|
+
import { fastHash, stringify } from '../utils'
|
|
8
|
+
import { Zai } from '../zai'
|
|
9
|
+
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
|
|
10
|
+
|
|
11
|
+
// Rating scale: every label the model may emit, mapped to its numeric score (1 = worst, 5 = best)
const RATING_VALUES = {
  very_bad: 1,
  bad: 2,
  average: 3,
  good: 4,
  very_good: 5,
} as const

/** Union of the five rating labels: 'very_bad' | 'bad' | 'average' | 'good' | 'very_good'. */
type RatingLabel = keyof typeof RATING_VALUES
|
|
21
|
+
|
|
22
|
+
/**
 * Label descriptions parsed from the model's criteria-generation answer, keyed by criterion name.
 * Each entry explains what every point of the 1-5 scale means for that criterion.
 */
type EvaluationCriteria = {
  [criterion: string]: {
    very_bad: string
    bad: string
    average: string
    good: string
    very_good: string
  }
}

/**
 * What to rate against: either a free-form description (the model then derives the criteria itself),
 * or a map from criterion name to a description of that criterion.
 */
export type RatingInstructions = Record<string, string> | string
|
|
35
|
+
|
|
36
|
+
/** Tuning knobs for `zai.rate`. Every field is optional; `_Options` below validates ranges and applies defaults. */
export type Options = {
  /** The maximum number of tokens per item; longer items are truncated. Defaults to 250. */
  tokensPerItem?: number
  /** The maximum number of items sent to the model in a single rating request. Defaults to 50. */
  maxItemsPerChunk?: number
}

// Runtime schema for `Options`: tokensPerItem in [1, 100000] (default 250), maxItemsPerChunk in [1, 100] (default 50)
const _Options = z.object({
  tokensPerItem: z.number().min(1).max(100_000).optional().describe('The maximum number of tokens per item').default(250),
  maxItemsPerChunk: z
    .number()
    .min(1)
    .max(100)
    .optional()
    .describe('The maximum number of items to rate per chunk')
    .default(50),
})
|
|
59
|
+
|
|
60
|
+
/**
 * Full rating record for one item; every score is on the 1-5 scale and `total` is their sum.
 * - String instructions: criteria are chosen by the model, so their names are only known at runtime.
 * - Record instructions: one score per key of the instructions object.
 */
export type RatingResult<T extends RatingInstructions> = T extends string
  ? { total: number; [criterion: string]: number }
  : T extends Record<string, string>
    ? { [Criterion in keyof T]: number } & { total: number }
    : never

/** Simplified view of a rating: only the total for string instructions, the full record otherwise. */
export type SimplifiedRatingResult<T extends RatingInstructions> = T extends string ? number : RatingResult<T>
|
|
75
|
+
|
|
76
|
+
declare module '@botpress/zai' {
  interface Zai {
    /**
     * Rates every item of `input` on a 1-5 scale (very_bad = 1 … very_good = 5) for each evaluation criterion.
     *
     * - `instructions` as a string: the model first derives a few evaluation criteria from it. The
     *   simplified result for each item is its `total`, i.e. the SUM of its per-criterion scores. It is
     *   therefore not bounded to 1-5; it grows with the number of criteria.
     * - `instructions` as a Record: its keys name the criteria and its values describe them. Each item's
     *   result holds a 1-5 score per criterion plus `total`, the sum of those scores.
     *
     * Items the model fails to rate default to `average` (3) on every criterion.
     *
     * @param input - Items to rate; each is stringified and truncated to `options.tokensPerItem` tokens.
     * @param instructions - Free-form rating goal, or a map of criterion name to description.
     * @param options - Chunking limits (see `Options`).
     * @returns A Response resolving to one rating result per input item, in input order.
     */
    rate<T, I extends RatingInstructions>(
      input: Array<T>,
      instructions: I,
      options?: Options
    ): Response<Array<RatingResult<I>>, Array<SimplifiedRatingResult<I>>>
  }
}
|
|
89
|
+
|
|
90
|
+
/** Stop sequence ending the model's rating output; also appended to the prompts and few-shot answers. */
const END = '■END■'

/** Score used when the model omits an item or a criterion, or emits a label that is not recognised. */
const DEFAULT_RATING: number = RATING_VALUES.average

/** Inverse of RATING_VALUES: renders stored numeric scores back as labels in few-shot examples. */
const LABEL_BY_VALUE: Record<number, RatingLabel> = {
  1: 'very_bad',
  2: 'bad',
  3: 'average',
  4: 'good',
  5: 'very_good',
}

/** Builds the user prompt asking the model to describe (and, for string instructions, choose) the criteria. */
const buildCriteriaPrompt = (instructions: RatingInstructions, criteriaKeys: string[]): string =>
  typeof instructions === 'string'
    ? `Generate 3-5 evaluation criteria for: "${instructions}"

For each criterion, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions.

Output format (JSON):
{
  "criterion1_name": {
    "very_bad": "description",
    "bad": "description",
    "average": "description",
    "good": "description",
    "very_good": "description"
  },
  "criterion2_name": { ... }
}

Keep criterion names short (1-2 words, lowercase, use underscores).
Keep descriptions brief (5-10 words each).`
    : `For these evaluation criteria, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions for each:

${criteriaKeys.map((key) => `- ${key}: ${instructions[key]}`).join('\n')}

Output format (JSON):
{
  "${criteriaKeys[0]}": {
    "very_bad": "description",
    "bad": "description",
    "average": "description",
    "good": "description",
    "very_good": "description"
  }
  ${criteriaKeys.length > 1 ? '...' : ''}
}

Keep descriptions brief (5-10 words each).`

/** Extracts the criteria JSON object from the model's answer, with or without a ```json fence. */
const parseCriteriaJson = (text: string): EvaluationCriteria => {
  const jsonMatch = text.match(/```(?:json)?\s*(\{[\s\S]*?\})\s*```/) || text.match(/(\{[\s\S]*\})/)
  const json = jsonMatch?.[1]
  if (!json) {
    throw new Error('Failed to parse evaluation criteria JSON')
  }
  return JSON.parse(json) as EvaluationCriteria
}

/**
 * Converts a label emitted by the model ("very_good", "Very Good", "very-good", …) to its 1-5 score.
 * Unknown labels fall back to DEFAULT_RATING. Only own keys of RATING_VALUES are accepted, so inherited
 * names such as "constructor" can never leak a non-numeric value into the scores.
 */
const labelToRating = (label: string): number => {
  const normalized = label.toLowerCase().replace(/[\s-]+/g, '_')
  return Object.prototype.hasOwnProperty.call(RATING_VALUES, normalized)
    ? RATING_VALUES[normalized as RatingLabel]
    : DEFAULT_RATING
}

/**
 * Builds a complete rating record: exactly one score per expected criterion (missing ones default to
 * DEFAULT_RATING) plus `total`, the sum of those scores.
 */
const completeRatings = (scores: Map<string, number>, criteriaKeys: string[]): Record<string, number> => {
  const entries = criteriaKeys.map((key): [string, number] => [key, scores.get(key) ?? DEFAULT_RATING])
  const total = entries.reduce((sum, [, value]) => sum + value, 0)
  const withTotal: Array<[string, number]> = [...entries, ['total', total]]
  return Object.fromEntries(withTotal)
}

/**
 * Parses the model's rating lines (`■<index>:<criterion>=<label>;…■`) for a chunk of `chunkLength` items.
 * - Lines whose index is out of range are ignored; a repeated index overwrites the earlier line.
 * - Criterion names are matched exactly first, then case-insensitively. Criteria that were not
 *   requested are dropped, so they can never inflate `total`.
 * - Every item gets a complete record, so the output length always equals `chunkLength`.
 */
const parseChunkRatings = (
  text: string,
  chunkLength: number,
  criteriaKeys: string[]
): Array<Record<string, number>> => {
  const exactKeys = new Set(criteriaKeys)
  const keysByLowercase = new Map<string, string>(criteriaKeys.map((key): [string, string] => [key.toLowerCase(), key]))
  const resolveCriterion = (name: string): string | undefined =>
    exactKeys.has(name) ? name : keysByLowercase.get(name.toLowerCase())

  const parsed = new Map<number, Map<string, number>>()
  const lineRegex = /■(\d+):([^■]+)■/g
  let match: RegExpExecArray | null

  while ((match = lineRegex.exec(text)) !== null) {
    const idx = parseInt(match[1] ?? '', 10)
    if (isNaN(idx) || idx < 0 || idx >= chunkLength) {
      continue
    }

    const scores = new Map<string, number>()
    for (const pair of (match[2] ?? '').split(';')) {
      const [rawCriterion, rawLabel] = pair.split('=').map((part) => part.trim())
      const criterion = rawCriterion ? resolveCriterion(rawCriterion) : undefined
      if (!criterion || !rawLabel) {
        continue
      }
      scores.set(criterion, labelToRating(rawLabel))
    }
    parsed.set(idx, scores)
  }

  return Array.from({ length: chunkLength }, (_, idx) =>
    completeRatings(parsed.get(idx) ?? new Map<string, number>(), criteriaKeys)
  )
}

/**
 * Rates every item of `input` against a set of evaluation criteria.
 *
 * - Phase 1: the model describes each criterion with five labels (very_bad … very_good). For string
 *   instructions it also chooses the criteria. For record instructions the caller's keys ARE the criteria
 *   and are always the keys of the result, even if the model renames or drops one.
 * - Phase 2: items are split into chunks bounded by a token budget and `maxItemsPerChunk`.
 * - Phase 3: chunks are rated in parallel (at most 10 requests in flight). Adapter examples are reused
 *   as few-shot prompts, and an exact cached match short-circuits the model call.
 * - Phase 4: chunk results are flattened in input order. When a task id and adapter are configured,
 *   the run is saved as an example for active learning.
 *
 * @returns One record per input item: a 1-5 score per criterion plus `total`, their sum.
 * @throws If aborted, if `_options` are invalid, if record instructions have no keys,
 *   or if the model's criteria cannot be parsed.
 */
const rate = async <T, I extends RatingInstructions>(
  input: Array<T>,
  instructions: I,
  _options: Options | undefined,
  ctx: ZaiContext
): Promise<Array<RatingResult<I>>> => {
  ctx.controller.signal.throwIfAborted()
  const options = _Options.parse(_options ?? {})

  // Nothing to rate: skip tokenizer/model setup and every LLM call
  if (input.length === 0) {
    return []
  }

  const isStringInstructions = typeof instructions === 'string'
  const criteriaKeys: string[] = isStringInstructions ? [] : Object.keys(instructions as Record<string, string>)
  if (!isStringInstructions && criteriaKeys.length === 0) {
    // Without criteria the prompt would reference `undefined` and results would carry no scores
    throw new Error('At least one rating criterion is required')
  }

  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  const taskId = ctx.taskId
  const taskType = 'zai.rate'

  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER

  // Phase 1: have the model describe every criterion on the 1-5 label scale
  const { extracted: evaluationCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating evaluation criteria for rating items on a 1-5 scale.
Each criterion must have exactly 5 labels: very_bad (1), bad (2), average (3), good (4), very_good (5).
Output valid JSON only.`,
    messages: [
      {
        type: 'text',
        role: 'user',
        content: buildCriteriaPrompt(instructions, criteriaKeys),
      },
    ],
    transform: parseCriteriaJson,
  })

  // For record instructions the caller's keys are authoritative: the result must be keyed exactly as
  // RatingResult<I> promises, whatever names the model used in its criteria answer.
  const finalCriteriaKeys = isStringInstructions ? Object.keys(evaluationCriteria) : criteriaKeys
  if (finalCriteriaKeys.length === 0) {
    throw new Error('No evaluation criteria generated')
  }

  // The criteria section is identical for every chunk, so render it once
  const criteriaDescription = finalCriteriaKeys
    .map((key) => {
      const labels = evaluationCriteria[key]
      if (!labels && !isStringInstructions) {
        // The model skipped one of the caller's criteria: fall back to the caller's own description
        return `**${key}**: ${(instructions as Record<string, string>)[key]}`
      }
      return `**${key}**:
- very_bad (1): ${labels?.very_bad}
- bad (2): ${labels?.bad}
- average (3): ${labels?.average}
- good (4): ${labels?.good}
- very_good (5): ${labels?.very_good}`
    })
    .join('\n\n')

  /** Renders items as `■<index>: <truncated item>■` lines, the format the rating prompt refers to. */
  const formatItems = (items: Array<unknown>): string =>
    items.map((item, idx) => `■${idx}: ${tokenizer.truncate(stringify(item, false), options.tokensPerItem)}■`).join('\n')

  // Phase 2: chunk items by token budget and item count; ~30% of the budget is kept for criteria/instructions
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.3)
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX

  const chunks: Array<typeof input> = []
  let currentChunk: typeof input = []
  let currentChunkTokens = 0

  for (const element of input) {
    const elementTokens = tokenizer.count(tokenizer.truncate(stringify(element, false), options.tokensPerItem))
    const chunkIsFull =
      currentChunkTokens + elementTokens > TOKENS_ITEMS_MAX || currentChunk.length >= options.maxItemsPerChunk

    if (chunkIsFull && currentChunk.length > 0) {
      chunks.push(currentChunk)
      currentChunk = []
      currentChunkTokens = 0
    }

    currentChunk.push(element)
    currentChunkTokens += elementTokens
  }

  if (currentChunk.length > 0) {
    chunks.push(currentChunk)
  }

  // Phase 3: rate each chunk
  type ChunkResult = {
    ratings: Array<Record<string, number>>
    meta: { cost: { input: number; output: number }; latency: number; tokens: { input: number; output: number } }
  }

  const rateChunk = async (chunk: typeof input): Promise<ChunkResult> => {
    ctx.controller.signal.throwIfAborted()

    // Get examples from adapter for active learning
    const chunkInputStr = JSON.stringify(chunk)
    const examples =
      taskId && ctx.adapter
        ? await ctx.adapter.getExamples<string, Array<Record<string, number>>>({
            input: chunkInputStr.slice(0, 1000), // Limit search string length
            taskType,
            taskId,
          })
        : []

    // Exact match (cache hit): only trusted when it covers every item of this chunk
    const cacheKey = fastHash(
      stringify({
        taskId,
        taskType,
        input: chunkInputStr,
        instructions: stringify(instructions),
      })
    )
    const exactMatch = examples.find((x) => x.key === cacheKey)
    if (exactMatch && Array.isArray(exactMatch.output) && exactMatch.output.length === chunk.length) {
      return {
        ratings: exactMatch.output,
        meta: { cost: { input: 0, output: 0 }, latency: 0, tokens: { input: 0, output: 0 } },
      }
    }

    // Few-shot examples: a user/assistant pair is added only once the whole example formatted successfully,
    // so a malformed example can never leave an unanswered user turn in the conversation
    const exampleMessages: Array<{ type: 'text'; role: 'user' | 'assistant'; content: string }> = []

    for (const example of examples.slice(0, 5)) {
      try {
        const exampleOutput = example.output
        if (!Array.isArray(exampleOutput) || exampleOutput.length === 0) {
          continue
        }

        const exampleInput = JSON.parse(example.input)
        const userContent = `Expert Example - Items to rate:
${formatItems(Array.isArray(exampleInput) ? exampleInput : [exampleInput])}

Rate each item on all criteria.`

        const formattedRatings = exampleOutput
          .map((rating, idx) => {
            const pairs = finalCriteriaKeys
              .map((criterion) => {
                const value = rating[criterion]
                return typeof value === 'number' ? `${criterion}=${LABEL_BY_VALUE[value] ?? 'average'}` : null
              })
              .filter(Boolean)
              .join(';')
            return `■${idx}:${pairs}■`
          })
          .join('\n')

        exampleMessages.push(
          { type: 'text', role: 'user', content: userContent },
          { type: 'text', role: 'assistant', content: `${formattedRatings}\n${END}` }
        )

        if (example.explanation) {
          exampleMessages.push({
            type: 'text',
            role: 'assistant',
            content: `Reasoning: ${example.explanation}`,
          })
        }
      } catch {
        // Skip malformed examples (unparseable input or non-object rating entries)
      }
    }

    const { extracted, meta } = await ctx.generateContent({
      systemPrompt: `You are rating items based on evaluation criteria.

Evaluation Criteria:
${criteriaDescription}

For each item, rate it on EACH criterion using one of these labels:
very_bad, bad, average, good, very_good

Output format:
■0:criterion1=label;criterion2=label;criterion3=label■
■1:criterion1=label;criterion2=label;criterion3=label■
${END}

IMPORTANT:
- Rate every item (■0 to ■${chunk.length - 1})
- Use exact criterion names: ${finalCriteriaKeys.join(', ')}
- Use exact label names: very_bad, bad, average, good, very_good
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...exampleMessages,
        {
          type: 'text',
          role: 'user',
          content: `Items to rate (■0 to ■${chunk.length - 1}):
${formatItems(chunk)}

Rate each item on all criteria.
Output format: ■index:criterion1=label;criterion2=label■
${END}`,
        },
      ],
      transform: (text) => parseChunkRatings(text, chunk.length, finalCriteriaKeys),
    })

    return { ratings: extracted, meta }
  }

  // Process chunks in parallel, at most 10 model calls in flight
  const limit = pLimit(10)
  const ratedChunks = await Promise.all(chunks.map((chunk) => limit(() => rateChunk(chunk))))

  // Phase 4: flatten results (chunks are in input order) and accumulate metadata
  const allRatings = ratedChunks.flatMap((result) => result.ratings) as Array<RatingResult<I>>

  const totalMeta = ratedChunks.reduce(
    (acc, result) => ({
      cost: {
        input: acc.cost.input + result.meta.cost.input,
        output: acc.cost.output + result.meta.cost.output,
      },
      latency: Math.max(acc.latency, result.meta.latency), // chunks run concurrently, so wall time ≈ slowest chunk
      tokens: {
        input: acc.tokens.input + result.meta.tokens.input,
        output: acc.tokens.output + result.meta.tokens.output,
      },
    }),
    {
      cost: { input: 0, output: 0 },
      latency: 0,
      tokens: { input: 0, output: 0 },
    }
  )

  // Save example for active learning
  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: JSON.stringify(input),
        instructions: stringify(instructions),
      })
    )

    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: JSON.stringify(input),
      output: allRatings,
      instructions: typeof instructions === 'string' ? instructions : JSON.stringify(instructions),
      metadata: {
        cost: {
          input: totalMeta.cost.input,
          output: totalMeta.cost.output,
        },
        latency: totalMeta.latency,
        model: ctx.modelId,
        tokens: {
          input: totalMeta.tokens.input,
          output: totalMeta.tokens.output,
        },
      },
    })
  }

  return allRatings
}
|
|
491
|
+
|
|
492
|
+
Zai.prototype.rate = function <T, I extends RatingInstructions>(
  this: Zai,
  input: Array<T>,
  instructions: I,
  _options?: Options
): Response<Array<RatingResult<I>>, Array<SimplifiedRatingResult<I>>> {
  // A fresh context per call, bound to this instance's client, model, task id and adapter
  const ctx = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.rate',
    adapter: this.adapter,
  })

  const pending = rate(input, instructions, _options, ctx)

  // Simplified view: a bare total per item for string instructions, the untouched records (with total) otherwise
  const simplify = (ratings: Array<RatingResult<I>>): Array<SimplifiedRatingResult<I>> => {
    if (typeof instructions !== 'string') {
      return ratings as Array<SimplifiedRatingResult<I>>
    }
    return ratings.map((rating) => rating.total as SimplifiedRatingResult<I>)
  }

  return new Response<Array<RatingResult<I>>, Array<SimplifiedRatingResult<I>>>(ctx, pending, simplify)
}
|