@llumiverse/drivers 0.11.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. package/README.md +1 -0
  2. package/lib/cjs/bedrock/index.js +74 -22
  3. package/lib/cjs/bedrock/index.js.map +1 -1
  4. package/lib/cjs/groq/index.js +112 -0
  5. package/lib/cjs/groq/index.js.map +1 -0
  6. package/lib/cjs/huggingface_ie.js +6 -0
  7. package/lib/cjs/huggingface_ie.js.map +1 -1
  8. package/lib/cjs/index.js +1 -0
  9. package/lib/cjs/index.js.map +1 -1
  10. package/lib/cjs/mistral/index.js +8 -4
  11. package/lib/cjs/mistral/index.js.map +1 -1
  12. package/lib/cjs/openai.js +40 -18
  13. package/lib/cjs/openai.js.map +1 -1
  14. package/lib/cjs/replicate.js +1 -0
  15. package/lib/cjs/replicate.js.map +1 -1
  16. package/lib/cjs/togetherai/index.js +4 -1
  17. package/lib/cjs/togetherai/index.js.map +1 -1
  18. package/lib/cjs/vertexai/models/gemini.js +15 -3
  19. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  20. package/lib/cjs/vertexai/models/palm-model-base.js +3 -1
  21. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
  22. package/lib/esm/bedrock/index.js +74 -22
  23. package/lib/esm/bedrock/index.js.map +1 -1
  24. package/lib/esm/groq/index.js +105 -0
  25. package/lib/esm/groq/index.js.map +1 -0
  26. package/lib/esm/huggingface_ie.js +6 -0
  27. package/lib/esm/huggingface_ie.js.map +1 -1
  28. package/lib/esm/index.js +1 -0
  29. package/lib/esm/index.js.map +1 -1
  30. package/lib/esm/mistral/index.js +9 -5
  31. package/lib/esm/mistral/index.js.map +1 -1
  32. package/lib/esm/openai.js +40 -18
  33. package/lib/esm/openai.js.map +1 -1
  34. package/lib/esm/replicate.js +1 -0
  35. package/lib/esm/replicate.js.map +1 -1
  36. package/lib/esm/togetherai/index.js +4 -1
  37. package/lib/esm/togetherai/index.js.map +1 -1
  38. package/lib/esm/vertexai/models/gemini.js +16 -4
  39. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  40. package/lib/esm/vertexai/models/palm-model-base.js +3 -1
  41. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
  42. package/lib/types/bedrock/index.d.ts.map +1 -1
  43. package/lib/types/{src/mistral → groq}/index.d.ts +8 -8
  44. package/lib/types/groq/index.d.ts.map +1 -0
  45. package/lib/types/huggingface_ie.d.ts +2 -0
  46. package/lib/types/huggingface_ie.d.ts.map +1 -1
  47. package/lib/types/index.d.ts +1 -0
  48. package/lib/types/index.d.ts.map +1 -1
  49. package/lib/types/mistral/index.d.ts.map +1 -1
  50. package/lib/types/openai.d.ts.map +1 -1
  51. package/lib/types/replicate.d.ts +1 -0
  52. package/lib/types/replicate.d.ts.map +1 -1
  53. package/lib/types/togetherai/index.d.ts.map +1 -1
  54. package/lib/types/togetherai/interfaces.d.ts +15 -0
  55. package/lib/types/togetherai/interfaces.d.ts.map +1 -1
  56. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  57. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
  58. package/package.json +3 -2
  59. package/src/bedrock/index.ts +69 -21
  60. package/src/groq/index.ts +134 -0
  61. package/src/huggingface_ie.ts +6 -0
  62. package/src/index.ts +1 -1
  63. package/src/mistral/index.ts +11 -7
  64. package/src/mistral/types.ts +2 -2
  65. package/src/openai.ts +43 -20
  66. package/src/replicate.ts +1 -0
  67. package/src/togetherai/index.ts +6 -4
  68. package/src/togetherai/interfaces.ts +16 -0
  69. package/src/vertexai/models/gemini.ts +13 -5
  70. package/src/vertexai/models/palm-model-base.ts +3 -1
  71. package/lib/cjs/vertexai/utils/prompts.js +0 -52
  72. package/lib/cjs/vertexai/utils/prompts.js.map +0 -1
  73. package/lib/esm/src/bedrock/index.js +0 -375
  74. package/lib/esm/src/bedrock/index.js.map +0 -1
  75. package/lib/esm/src/bedrock/s3.js +0 -53
  76. package/lib/esm/src/bedrock/s3.js.map +0 -1
  77. package/lib/esm/src/huggingface_ie.js +0 -173
  78. package/lib/esm/src/huggingface_ie.js.map +0 -1
  79. package/lib/esm/src/index.js +0 -9
  80. package/lib/esm/src/index.js.map +0 -1
  81. package/lib/esm/src/mistral/index.js +0 -145
  82. package/lib/esm/src/mistral/index.js.map +0 -1
  83. package/lib/esm/src/mistral/types.js +0 -80
  84. package/lib/esm/src/mistral/types.js.map +0 -1
  85. package/lib/esm/src/openai.js +0 -195
  86. package/lib/esm/src/openai.js.map +0 -1
  87. package/lib/esm/src/replicate.js +0 -281
  88. package/lib/esm/src/replicate.js.map +0 -1
  89. package/lib/esm/src/test/TestErrorCompletionStream.js +0 -16
  90. package/lib/esm/src/test/TestErrorCompletionStream.js.map +0 -1
  91. package/lib/esm/src/test/TestValidationErrorCompletionStream.js +0 -20
  92. package/lib/esm/src/test/TestValidationErrorCompletionStream.js.map +0 -1
  93. package/lib/esm/src/test/index.js +0 -91
  94. package/lib/esm/src/test/index.js.map +0 -1
  95. package/lib/esm/src/test/utils.js +0 -25
  96. package/lib/esm/src/test/utils.js.map +0 -1
  97. package/lib/esm/src/togetherai/index.js +0 -89
  98. package/lib/esm/src/togetherai/index.js.map +0 -1
  99. package/lib/esm/src/togetherai/interfaces.js +0 -2
  100. package/lib/esm/src/togetherai/interfaces.js.map +0 -1
  101. package/lib/esm/src/vertexai/debug.js +0 -6
  102. package/lib/esm/src/vertexai/debug.js.map +0 -1
  103. package/lib/esm/src/vertexai/embeddings/embeddings-text.js +0 -19
  104. package/lib/esm/src/vertexai/embeddings/embeddings-text.js.map +0 -1
  105. package/lib/esm/src/vertexai/index.js +0 -73
  106. package/lib/esm/src/vertexai/index.js.map +0 -1
  107. package/lib/esm/src/vertexai/models/codey-chat.js +0 -61
  108. package/lib/esm/src/vertexai/models/codey-chat.js.map +0 -1
  109. package/lib/esm/src/vertexai/models/codey-text.js +0 -31
  110. package/lib/esm/src/vertexai/models/codey-text.js.map +0 -1
  111. package/lib/esm/src/vertexai/models/gemini.js +0 -136
  112. package/lib/esm/src/vertexai/models/gemini.js.map +0 -1
  113. package/lib/esm/src/vertexai/models/palm-model-base.js +0 -53
  114. package/lib/esm/src/vertexai/models/palm-model-base.js.map +0 -1
  115. package/lib/esm/src/vertexai/models/palm2-chat.js +0 -61
  116. package/lib/esm/src/vertexai/models/palm2-chat.js.map +0 -1
  117. package/lib/esm/src/vertexai/models/palm2-text.js +0 -31
  118. package/lib/esm/src/vertexai/models/palm2-text.js.map +0 -1
  119. package/lib/esm/src/vertexai/models.js +0 -87
  120. package/lib/esm/src/vertexai/models.js.map +0 -1
  121. package/lib/esm/src/vertexai/utils/prompts.js +0 -47
  122. package/lib/esm/src/vertexai/utils/prompts.js.map +0 -1
  123. package/lib/esm/src/vertexai/utils/tensor.js +0 -82
  124. package/lib/esm/src/vertexai/utils/tensor.js.map +0 -1
  125. package/lib/esm/tsconfig.tsbuildinfo +0 -1
  126. package/lib/esm/vertexai/utils/prompts.js +0 -47
  127. package/lib/esm/vertexai/utils/prompts.js.map +0 -1
  128. package/lib/types/src/bedrock/index.d.ts +0 -94
  129. package/lib/types/src/bedrock/s3.d.ts +0 -16
  130. package/lib/types/src/huggingface_ie.d.ts +0 -30
  131. package/lib/types/src/index.d.ts +0 -8
  132. package/lib/types/src/mistral/types.d.ts +0 -130
  133. package/lib/types/src/openai.d.ts +0 -30
  134. package/lib/types/src/replicate.d.ts +0 -47
  135. package/lib/types/src/test/TestErrorCompletionStream.d.ts +0 -8
  136. package/lib/types/src/test/TestValidationErrorCompletionStream.d.ts +0 -8
  137. package/lib/types/src/test/index.d.ts +0 -23
  138. package/lib/types/src/test/utils.d.ts +0 -4
  139. package/lib/types/src/togetherai/index.d.ts +0 -21
  140. package/lib/types/src/togetherai/interfaces.d.ts +0 -80
  141. package/lib/types/src/vertexai/debug.d.ts +0 -1
  142. package/lib/types/src/vertexai/embeddings/embeddings-text.d.ts +0 -9
  143. package/lib/types/src/vertexai/index.d.ts +0 -21
  144. package/lib/types/src/vertexai/models/codey-chat.d.ts +0 -50
  145. package/lib/types/src/vertexai/models/codey-text.d.ts +0 -38
  146. package/lib/types/src/vertexai/models/gemini.d.ts +0 -10
  147. package/lib/types/src/vertexai/models/palm-model-base.d.ts +0 -60
  148. package/lib/types/src/vertexai/models/palm2-chat.d.ts +0 -60
  149. package/lib/types/src/vertexai/models/palm2-text.d.ts +0 -38
  150. package/lib/types/src/vertexai/models.d.ts +0 -13
  151. package/lib/types/src/vertexai/utils/prompts.d.ts +0 -19
  152. package/lib/types/src/vertexai/utils/tensor.d.ts +0 -5
  153. package/lib/types/vertexai/utils/prompts.d.ts +0 -20
  154. package/lib/types/vertexai/utils/prompts.d.ts.map +0 -1
package/src/bedrock/index.ts CHANGED
@@ -91,29 +91,33 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
  const body = decoder.decode(response.body);
  const result = JSON.parse(body);

- const getText = () => {
- if (result.completion) {
- return result.completion;
- } else if (result.generation) {
- return result.generation;
+ const getTextAnsStopReason = (): string[] => {
+ if (result.generation) {
+ // LLAMA2
+ return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
  } else if (result.generations) {
- return result.generations[0].text;
+ // COHERE
+ return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
  } else if (result.completions) {
  //A21
- return result.completions[0].data?.text;
- } else if (result.content) { // calude
- return result.content[0]?.text || '';
- //result.stop_reason --> the stop reason
+ return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
+ } else if (result.content) {
+ // anthropic claude
+ return [result.content[0]?.text || '', claudeFinishReason(result.stop_reason)];
  } else if (result.outputs) {
  // mistral
- return result.outputs[0]?.text;
- //result.outputs[0]?.stop_reason --> the stop reason
+ return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
+ } else if (result.results) {
+ // Amazon Titan
+ return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)];
+ } else if (result.completion) { // TODO: who uses this?
+ return [result.completion];
  } else {
- return result.toString();
+ return [result.toString()];
  }
  };

- const text = getText();
+ const [text, finish_reason] = getTextAnsStopReason();

  const promptLength = typeof prompt === 'string' ? prompt.length :
  (prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0);
@@ -123,7 +127,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
  result: text?.length,
  prompt: promptLength,
  total: text?.length + promptLength,
- }
+ },
+ finish_reason
  }
  }

@@ -136,7 +141,11 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
  contentType: "application/json",
  body: JSON.stringify(payload),
  });
- return this.extractDataFromResponse(prompt, res);
+ const completion = this.extractDataFromResponse(prompt, res);
+ if (options.include_original_response) {
+ completion.original_response = res;
+ }
+ return completion;
  }

  protected async canStream(options: ExecutionOptions): Promise<boolean> {
@@ -167,9 +176,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP

  return transformAsyncIterator(res.body, (stream: ResponseStream) => {
  const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
- if (segment.delta) {
+ if (segment.delta) { // who is this?
  return segment.delta.text || '';
- } else if (segment.completion) {
+ } else if (segment.completion) { // who is this?
  return segment.completion;
  } else if (segment.completions) {
  return segment.completions[0].data?.text;
@@ -181,6 +190,11 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
  // mistral.mixtral-8x7b-instruct-v0:1
  return segment.outputs[0].text;
  //segment.outputs[0].stop_reason;
+ } else if (segment.outputText) {
+ // Amazon Titan
+ return segment.outputText;
+ //completionReason
+ // token count too
  } else {
  segment.toString();
  }
@@ -230,12 +244,12 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
  } as CohereRequestPayload;
  } else if (contains(options.model, "amazon")) {
  return {
- inputText: prompt,
+ inputText: "User: " + (prompt as string) + "\nBot:", // see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-response
  textGenerationConfig: {
  temperature: options.temperature,
- topP: 0.9,
+ topP: options.top_p,
  maxTokenCount: options.max_tokens,
- stopSequences: ["\n"],
+ //stopSequences: ["\n"],
  },
  } as AmazonRequestPayload;
  } else if (contains(options.model, "mistral")) {
@@ -502,4 +516,38 @@ function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): Tra
  }


+ function claudeFinishReason(reason: string | undefined) {
+ if (!reason) return undefined;
+ switch (reason) {
+ case 'end_turn': return "stop";
+ case 'max_tokens': return "length";
+ default: return reason; //stop_sequence
+ }
+ }

+ function cohereFinishReason(reason: string | undefined) {
+ if (!reason) return undefined;
+ switch (reason) {
+ case 'COMPLETE': return "stop";
+ case 'MAX_TOKENS': return "length";
+ default: return reason;
+ }
+ }
+
+ function a21FinishReason(reason: string | undefined) {
+ if (!reason) return undefined;
+ switch (reason) {
+ case 'endoftext': return "stop";
+ case 'length': return "length";
+ default: return reason;
+ }
+ }
+
+ function titanFinishReason(reason: string | undefined) {
+ if (!reason) return undefined;
+ switch (reason) {
+ case 'FINISH': return "stop";
+ case 'LENGTH': return "length";
+ default: return reason;
+ }
+ }
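
Taken together, these Bedrock hunks fold each provider's native stop signal into the two shared values "stop" and "length" (LLAMA2 and Mistral already emit them natively, so they pass through untouched). A standalone sketch of the resulting mapping, not the package's code, with the values transcribed from the four switch statements above:

    // Sketch: provider stop signals -> normalized finish_reason values,
    // transcribed from claudeFinishReason/cohereFinishReason/a21FinishReason/titanFinishReason.
    const finishReasonMap: Record<string, Record<string, string>> = {
        claude: { end_turn: "stop", max_tokens: "length" }, // stop_sequence passes through unchanged
        cohere: { COMPLETE: "stop", MAX_TOKENS: "length" },
        ai21: { endoftext: "stop", length: "length" },
        titan: { FINISH: "stop", LENGTH: "length" },
    };

    function normalizeFinishReason(provider: string, reason: string | undefined): string | undefined {
        if (!reason) return undefined;
        return finishReasonMap[provider]?.[reason] ?? reason; // unknown reasons pass through
    }

    // e.g. normalizeFinishReason("cohere", "MAX_TOKENS") === "length"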
package/src/groq/index.ts ADDED
@@ -0,0 +1,134 @@
+ import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
+ import { transformAsyncIterator } from "@llumiverse/core/async";
+ import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
+ import Groq from "groq-sdk";
+
+
+ interface GroqDriverOptions extends DriverOptions {
+ apiKey: string;
+ endpoint_url?: string;
+ }
+
+
+ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMessage[]> {
+ provider: string;
+ apiKey: string;
+ client: Groq;
+ endpointUrl?: string;
+
+ constructor(options: GroqDriverOptions) {
+ super(options);
+ this.provider = BuiltinProviders.groq;
+ this.apiKey = options.apiKey;
+ this.client = new Groq({
+ apiKey: options.apiKey,
+ baseURL: options.endpoint_url
+ });
+ }
+
+ // protected canStream(options: ExecutionOptions): Promise<boolean> {
+ // if (options.resultSchema) {
+ // // not yet streamign json responses
+ // return Promise.resolve(false);
+ // } else {
+ // return Promise.resolve(true);
+ // }
+ // }
+
+ getResponseFormat(_options: ExecutionOptions): Groq.Chat.Completions.CompletionCreateParams.ResponseFormat | undefined {
+ //TODO: when forcing json_object type the streaming is not supported.
+ // either implement canStream as above or comment the code below:
+ // const responseFormatJson: Groq.Chat.Completions.CompletionCreateParams.ResponseFormat = {
+ // type: "json_object",
+ // }
+
+ // return _options.resultSchema ? responseFormatJson : undefined;
+ return undefined;
+ }
+
+ protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
+ const messages = formatOpenAILikePrompt(segments);
+ //Add JSON instruction is schema is provided
+ if (opts.resultSchema) {
+ messages.push({
+ role: "user",
+ content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
+ });
+ }
+ return messages;
+ }
+
+ async requestCompletion(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<Completion<any>> {
+
+
+ const res = await this.client.chat.completions.create({
+ model: options.model,
+ messages: messages,
+ max_tokens: options.max_tokens,
+ temperature: options.temperature,
+ response_format: this.getResponseFormat(options),
+ });
+
+
+ const choice = res.choices[0];
+ const result = choice.message.content;
+
+ return {
+ result: result,
+ token_usage: {
+ prompt: res.usage?.prompt_tokens,
+ result: res.usage?.completion_tokens,
+ total: res.usage?.total_tokens,
+ },
+ finish_reason: choice.finish_reason,
+ original_response: options.include_original_response ? res : undefined,
+ };
+ }
+
+ async requestCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<AsyncIterable<string>> {
+
+ const res = await this.client.chat.completions.create({
+ model: options.model,
+ messages: messages,
+ max_tokens: options.max_tokens,
+ temperature: options.temperature,
+ response_format: this.getResponseFormat(options),
+ stream: true
+ });
+
+ return transformAsyncIterator(res, (res) => res.choices[0].delta.content || '');
+
+ }
+
+ async listModels(): Promise<AIModel<string>[]> {
+ const models = await this.client.models.list();
+
+ if (!models.data) {
+ throw new Error("No models found");
+ }
+
+ const aimodels = models.data?.map(m => {
+ if (!m.id) {
+ throw new Error("Model id is missing");
+ }
+ return {
+ id: m.id,
+ name: m.id,
+ description: undefined,
+ provider: this.provider,
+ owner: m.owned_by || '',
+ }
+ });
+
+ return aimodels;
+ }
+
+ validateConnection(): Promise<boolean> {
+ throw new Error("Method not implemented.");
+ }
+
+ async generateEmbeddings({ }: EmbeddingsOptions): Promise<EmbeddingsResult> {
+ throw new Error("Method not implemented.");
+ }
+
+ }
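
The new driver follows the same shape as the existing ones. A minimal usage sketch calling the methods visible in the diff directly (normally these are invoked through AbstractDriver's execution pipeline; the API key and model id are placeholders, and ExecutionOptions may require more fields than shown, hence the cast):

    import { GroqDriver } from "@llumiverse/drivers";

    async function main() {
        const driver = new GroqDriver({ apiKey: process.env.GROQ_API_KEY ?? "" });

        // List the models visible to this Groq account.
        const models = await driver.listModels();
        console.log(models.map((m) => m.id));

        // Request a completion with an OpenAI-style message list.
        const completion = await driver.requestCompletion(
            [{ role: "user", content: "Say hello in one word." }],
            { model: "mixtral-8x7b-32768", temperature: 0.7, max_tokens: 16 } as any, // placeholder options
        );
        console.log(completion.result, completion.finish_reason);
    }

    main().catch(console.error);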
package/src/huggingface_ie.ts CHANGED
@@ -92,6 +92,10 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
  },
  });

+ let finish_reason = res.details?.finish_reason;
+ if (finish_reason === "eos_token") {
+ finish_reason = "stop";
+ }
  return {
  result: res.generated_text,
  token_usage: {
@@ -99,6 +103,8 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
  prompt: prompt.length,
  total: res.generated_text.length + prompt.length,
  },
+ finish_reason,
+ original_response: options.include_original_response ? res : undefined,
  };

  }
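
This hunk is one instance of a pattern applied across every driver in 0.12: when the new include_original_response execution option is set, the raw provider payload is attached to the completion as original_response. The pattern in isolation (a sketch with local stand-in types; in the package the shapes come from @llumiverse/core's Completion and ExecutionOptions):

    // Sketch of the recurring 0.12 pattern: surface the raw provider response
    // only when the caller opts in, so completions stay small by default.
    interface CompletionLike {
        result: unknown;
        original_response?: unknown;
    }

    function withOriginalResponse<T extends CompletionLike, R>(
        completion: T,
        res: R,
        options: { include_original_response?: boolean },
    ): T {
        if (options.include_original_response) {
            completion.original_response = res; // raw provider payload, useful for debugging
        }
        return completion;
    }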
package/src/index.ts CHANGED
@@ -6,4 +6,4 @@ export * from "./replicate.js";
  export * from "./test/index.js";
  export * from "./togetherai/index.js";
  export * from "./vertexai/index.js";
-
+ export * from "./groq/index.js";
package/src/mistral/index.ts CHANGED
@@ -1,8 +1,8 @@
- import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
+ import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
  import { transformSSEStream } from "@llumiverse/core/async";
  import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
  import { FetchClient } from "api-fetch-client";
- import { CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
+ import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";

  //TODO retry on 429
  //const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
@@ -23,7 +23,7 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open

  constructor(options: MistralAIDriverOptions) {
  super(options);
- this.provider = "MistralAI";
+ this.provider = BuiltinProviders.mistralai;
  this.apiKey = options.apiKey;
  //this.client = new MistralClient(options.apiKey, options.endpointUrl);
  this.client = new FetchClient(options.endpoint_url || ENDPOINT).withHeaders({
@@ -70,9 +70,10 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
  temperature: options.temperature,
  responseFormat: this.getResponseFormat(options),
  })
- })
+ }) as ChatCompletionResponse;

- const result = res.choices[0]?.message.content;
+ const choice = res.choices[0];
+ const result = choice.message.content;

  return {
  result: result,
@@ -80,7 +81,9 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
  prompt: res.usage.prompt_tokens,
  result: res.usage.completion_tokens,
  total: res.usage.total_tokens,
- }
+ },
+ finish_reason: choice.finish_reason,
+ original_response: options.include_original_response ? res : undefined,
  };
  }

@@ -112,7 +115,8 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
  id: m.id,
  name: m.id,
  description: undefined,
- provider: m.owned_by,
+ provider: this.provider,
+ owner: m.owned_by,
  }
  });

package/src/mistral/types.ts CHANGED
@@ -79,7 +79,7 @@ export interface ChatCompletionResponseChoice {
  role: string;
  content: string;
  };
- finish_reason: string;
+ finish_reason: string; // "stop" "length" "model_length" "error" "tool_calls"
  }

  export interface ChatCompletionResponseChunkChoice {
@@ -89,7 +89,7 @@ export interface ChatCompletionResponseChunkChoice {
  content?: string;
  tool_calls?: ToolCalls[];
  };
- finish_reason: string;
+ finish_reason: string; // "stop" "length" "model_length" "error" "tool_calls"
  }

  export interface ChatCompletionResponse {
package/src/openai.ts CHANGED
@@ -15,8 +15,8 @@ import {
  TrainingOptions,
  TrainingPromptOptions,
  } from "@llumiverse/core";
- import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
  import { asyncMap } from "@llumiverse/core/async";
+ import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
  import OpenAI from "openai";
  import { Stream } from "openai/streaming";

@@ -59,16 +59,20 @@ export class OpenAIDriver extends AbstractDriver<
  total: result.usage?.total_tokens,
  };

+ const choice = result.choices[0];
+ const finish_reason = choice.finish_reason;
+
  //if no schema, return content
  if (!options.resultSchema) {
  return {
- result: result.choices[0]?.message.content as string,
+ result: choice.message.content as string,
  token_usage: tokenInfo,
+ finish_reason
  }
  }

  //we have a schema: get the content and return after validation
- const data = result.choices[0]?.message.function_call?.arguments as any;
+ const data = choice?.message.tool_calls?.[0].function.arguments;
  if (!data) {
  this.logger?.error("[OpenAI] Response is not valid", result);
  throw new Error("Response is not valid: no data");
@@ -76,7 +80,8 @@ export class OpenAIDriver extends AbstractDriver<

  return {
  result: data,
- token_usage: tokenInfo
+ token_usage: tokenInfo,
+ finish_reason
  };
  }

@@ -84,7 +89,7 @@ export class OpenAIDriver extends AbstractDriver<
  const mapFn = options.resultSchema
  ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
  return (
- chunk.choices[0]?.delta?.function_call?.arguments ?? ""
+ chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? ""
  );
  }
  : (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
@@ -98,17 +103,22 @@ export class OpenAIDriver extends AbstractDriver<
  temperature: options.temperature,
  n: 1,
  max_tokens: options.max_tokens,
- functions: options.resultSchema
+ tools: options.resultSchema
  ? [
  {
- name: "format_output",
- parameters: options.resultSchema as any,
- },
+ function: {
+ name: "format_output",
+ parameters: options.resultSchema as any,
+ },
+ type: "function"
+ } as OpenAI.Chat.ChatCompletionTool,
  ]
  : undefined,
- function_call: options.resultSchema
- ? { name: "format_output" }
- : undefined,
+ tool_choice: options.resultSchema
+ ? {
+ type: 'function',
+ function: { name: "format_output" }
+ } : undefined,
  })) as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;

  return asyncMap(stream, mapFn);
@@ -118,9 +128,12 @@ export class OpenAIDriver extends AbstractDriver<
  const functions = options.resultSchema
  ? [
  {
- name: "format_output",
- parameters: options.resultSchema as any,
- },
+ function: {
+ name: "format_output",
+ parameters: options.resultSchema as any,
+ },
+ type: 'function'
+ } as OpenAI.Chat.ChatCompletionTool,
  ]
  : undefined;

@@ -131,13 +144,23 @@ export class OpenAIDriver extends AbstractDriver<
  temperature: options.temperature,
  n: 1,
  max_tokens: options.max_tokens,
- functions: functions,
- function_call: options.resultSchema
- ? { name: "format_output" }
- : undefined,
+ tools: functions,
+ tool_choice: options.resultSchema
+ ? {
+ type: 'function',
+ function: { name: "format_output" }
+ } : undefined,
+ // functions: functions,
+ // function_call: options.resultSchema
+ // ? { name: "format_output" }
+ // : undefined,
  });

- return this.extractDataFromResponse(options, res);
+ const completion = this.extractDataFromResponse(options, res);
+ if (options.include_original_response) {
+ completion.original_response = res;
+ }
+ return completion;
  }

  createTrainingPrompt(options: TrainingPromptOptions): string {
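
These openai.ts changes migrate schema-constrained output from the deprecated functions/function_call request parameters to tools/tool_choice, and read the arguments back from message.tool_calls instead of message.function_call. The same request shape against the openai SDK directly, outside the driver (a sketch; the model id and schema are placeholders):

    import OpenAI from "openai";

    const openai = new OpenAI(); // reads OPENAI_API_KEY from the environment

    async function main() {
        const res = await openai.chat.completions.create({
            model: "gpt-4-turbo-preview", // placeholder model id
            messages: [{ role: "user", content: "Name a city and its country." }],
            // `tools` replaces the deprecated top-level `functions` parameter
            tools: [{
                type: "function",
                function: {
                    name: "format_output",
                    parameters: {
                        type: "object",
                        properties: { city: { type: "string" }, country: { type: "string" } },
                    },
                },
            }],
            // `tool_choice` replaces `function_call` to force the function call
            tool_choice: { type: "function", function: { name: "format_output" } },
        });
        // arguments now arrive on tool_calls rather than function_call
        console.log(res.choices[0]?.message.tool_calls?.[0]?.function.arguments);
    }

    main().catch(console.error);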
package/src/replicate.ts CHANGED
@@ -138,6 +138,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
  prompt: prompt.length,
  total: res.output.length + prompt.length,
  },
+ original_response: options.include_original_response ? res : undefined,
  };
  }

package/src/togetherai/index.ts CHANGED
@@ -1,7 +1,7 @@
  import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
  import { transformSSEStream } from "@llumiverse/core/async";
  import { FetchClient } from "api-fetch-client";
- import { TogetherModelInfo } from "./interfaces.js";
+ import { TextCompletion, TogetherModelInfo } from "./interfaces.js";

  interface TogetherAIDriverOptions extends DriverOptions {
  apiKey: string;
@@ -42,9 +42,9 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
  "[/INST]"
  ],
  }
- })
-
- const text = res.choices[0]?.text ?? '';
+ }) as TextCompletion;
+ const choice = res.choices[0];
+ const text = choice.text ?? '';
  const usage = res.usage || {};
  return {
  result: text,
@@ -53,6 +53,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
  result: usage.completion_tokens,
  total: usage.total_tokens,
  },
+ finish_reason: choice.finish_reason,
+ original_response: options.include_original_response ? res : undefined,
  }
  }

package/src/togetherai/interfaces.ts CHANGED
@@ -85,4 +85,20 @@ export interface TogetherModelInfo {
  link: string;
  descriptionLink: string;
  depth: Depth;
+ }
+
+ export interface TextCompletion {
+ id: string;
+ choices: {
+ text: string,
+ finish_reason: string, // stop | length ?
+ }[];
+ usage: {
+ prompt_tokens: number;
+ completion_tokens: number;
+ total_tokens: number;
+ }
+ created: number;
+ model: string;
+ object: string;
  }
package/src/vertexai/models/gemini.ts CHANGED
@@ -1,4 +1,4 @@
- import { Content, GenerateContentRequest, HarmBlockThreshold, HarmCategory, TextPart } from "@google-cloud/vertexai";
+ import { Content, FinishReason, GenerateContentRequest, HarmBlockThreshold, HarmCategory, TextPart } from "@google-cloud/vertexai";
  import { AIModel, Completion, ExecutionOptions, ExecutionTokenUsage, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
  import { asyncMap } from "@llumiverse/core/async";
  import { VertexAIDriver } from "../index.js";
@@ -7,12 +7,13 @@ import { ModelDefinition } from "../models.js";
  function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions) {
  return driver.vertexai.preview.getGenerativeModel({
  model: options.model,
- //TODO pass in the options
+ //TODO pass in the options
  safety_settings: [{
  category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
  threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
  }],
  generation_config: {
+ candidate_count: 1,
  temperature: options.temperature,
  max_output_tokens: options.max_tokens
  },
@@ -111,9 +112,14 @@ export class GeminiModelDefinition implements ModelDefinition<GenerateContentReq
  total: usage?.totalTokenCount,
  }

- let result: any;
+ let finish_reason: string | undefined, result: any;
  const candidate = response.candidates[0];
  if (candidate) {
+ switch (candidate.finishReason) {
+ case FinishReason.MAX_TOKENS: finish_reason = "length"; break;
+ case FinishReason.STOP: finish_reason = "stop"; break;
+ default: finish_reason = candidate.finishReason;
+ }
  const content = candidate.content;
  if (content) {
  result = collectTextParts(content);
@@ -126,8 +132,10 @@ export class GeminiModelDefinition implements ModelDefinition<GenerateContentReq

  return {
  result: result ?? '',
- token_usage
- };
+ token_usage,
+ finish_reason,
+ original_response: options.include_original_response ? response : undefined,
+ } as Completion;
  }

  async requestCompletionStream(driver: VertexAIDriver, prompt: GenerateContentRequest, options: ExecutionOptions): Promise<AsyncIterable<string>> {
package/src/vertexai/models/palm-model-base.ts CHANGED
@@ -89,7 +89,9 @@ export abstract class AbstractPalmModelDefinition<NonStreamingPromptT extends No
  prompt: inputTokens,
  result: outputTokens,
  total: inputTokens && outputTokens ? inputTokens + outputTokens : undefined,
- }
+ },
+ //finish_reason not available
+ original_response: options.include_original_response ? response : undefined,
  } as Completion;
  }
