@ai-sdk/replicate 0.0.0-70e0935a-20260114150030 → 0.0.0-98261322-20260122142521

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,268 @@
1
+ import type { ImageModelV3, SharedV3Warning } from '@ai-sdk/provider';
2
+ import type { Resolvable } from '@ai-sdk/provider-utils';
3
+ import {
4
+ combineHeaders,
5
+ convertImageModelFileToDataUri,
6
+ createBinaryResponseHandler,
7
+ createJsonResponseHandler,
8
+ FetchFunction,
9
+ getFromApi,
10
+ InferSchema,
11
+ lazySchema,
12
+ parseProviderOptions,
13
+ postJsonToApi,
14
+ resolve,
15
+ zodSchema,
16
+ } from '@ai-sdk/provider-utils';
17
+ import { z } from 'zod/v4';
18
+ import { replicateFailedResponseHandler } from './replicate-error';
19
+ import { ReplicateImageModelId } from './replicate-image-settings';
20
+
21
/**
 * Settings the provider factory hands to every Replicate image model.
 */
interface ReplicateImageModelConfig {
  /** Provider name reported by the model (the factory passes `replicate`). */
  provider: string;

  /** URL prefix; the prediction endpoint paths are appended to it. */
  baseURL: string;

  /** Request headers, either a plain record or a value `resolve` can unwrap. */
  headers?: Resolvable<Record<string, string | undefined>>;

  /** Optional fetch implementation used for every HTTP request. */
  fetch?: FetchFunction;

  /** Internal hooks (used to make results deterministic). */
  _internal?: {
    /** Overrides the timestamp reported in the response metadata. */
    currentDate?: () => Date;
  };
}

/** Matches IDs of the Flux-2 model family, e.g. `black-forest-labs/flux-2-pro`. */
const FLUX_2_MODEL_PATTERN = /^black-forest-labs\/flux-2-/;

/**
 * Number of reference images a Flux-2 model accepts, sent as `input_image`,
 * `input_image_2`, ..., `input_image_8`.
 */
const MAX_FLUX_2_INPUT_IMAGES = 8;
34
+
35
/**
 * Image model backed by Replicate predictions.
 *
 * Unversioned IDs (`owner/name`) are posted to
 * `{baseURL}/models/{owner/name}/predictions`. Versioned IDs
 * (`owner/name:version`) are posted to `{baseURL}/predictions` with the
 * version in the request body. Every output URL returned by the prediction is
 * then downloaded with a binary response handler.
 */
export class ReplicateImageModel implements ImageModelV3 {
  readonly specificationVersion = 'v3';

  /**
   * Maximum number of *generated* images per `doGenerate` call. The requested
   * count is forwarded to Replicate as `num_outputs`.
   *
   * This is 1 for every model, including Flux-2. The Flux-2 limit of 8 applies
   * to *input* reference images, not to outputs. Reporting it here would let
   * the SDK batch several generated images into a single request.
   */
  get maxImagesPerCall(): number {
    return 1;
  }

  get provider(): string {
    return this.config.provider;
  }

  /** Whether this model is part of the Flux-2 family (multi-image input, no mask). */
  private get isFlux2Model(): boolean {
    return FLUX_2_MODEL_PATTERN.test(this.modelId);
  }

  constructor(
    readonly modelId: ReplicateImageModelId,
    private readonly config: ReplicateImageModelConfig,
  ) {}

  async doGenerate({
    prompt,
    n,
    aspectRatio,
    size,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask,
  }: Parameters<ImageModelV3['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV3['doGenerate']>>
  > {
    const warnings: Array<SharedV3Warning> = [];

    // Versioned model IDs have the form `owner/name:version`.
    const [modelId, version] = this.modelId.split(':');

    const currentDate = this.config._internal?.currentDate?.() ?? new Date();

    const replicateOptions = await parseProviderOptions({
      provider: 'replicate',
      providerOptions,
      schema: replicateImageProviderOptionsSchema,
    });

    // File warnings are pushed before mask warnings (order is observable).
    const imageInputs = this.getImageInputs(files, warnings);
    const maskInput = this.getMaskInput(mask, warnings);

    // maxWaitTimeInSeconds controls the `prefer` header; every other option
    // is forwarded as part of the prediction input.
    const { maxWaitTimeInSeconds, ...inputOptions } = replicateOptions ?? {};

    // - undefined/null: default sync wait (prefer: wait)
    // - positive number: custom wait duration (prefer: wait=N)
    const preferHeader: Record<string, string> =
      maxWaitTimeInSeconds != null
        ? { prefer: `wait=${maxWaitTimeInSeconds}` }
        : { prefer: 'wait' };

    const {
      value: { output },
      responseHeaders,
    } = await postJsonToApi({
      url:
        // different endpoints for versioned vs unversioned models:
        version != null
          ? `${this.config.baseURL}/predictions`
          : `${this.config.baseURL}/models/${modelId}/predictions`,

      headers: combineHeaders(
        await resolve(this.config.headers),
        headers,
        preferHeader,
      ),

      body: {
        input: {
          prompt,
          aspect_ratio: aspectRatio,
          size,
          seed,
          num_outputs: n,
          ...imageInputs,
          ...(maskInput != null ? { mask: maskInput } : {}),
          // spread last so model-specific options can override the defaults above
          ...inputOptions,
        },
        // for versioned models, include the version in the body:
        ...(version != null ? { version } : {}),
      },

      successfulResponseHandler: createJsonResponseHandler(
        replicateImageResponseSchema,
      ),
      failedResponseHandler: replicateFailedResponseHandler,
      abortSignal,
      fetch: this.config.fetch,
    });

    // Replicate returns either a single URL or a list; download them in parallel.
    const outputArray = Array.isArray(output) ? output : [output];
    const images = await Promise.all(
      outputArray.map(async url => {
        const { value: image } = await getFromApi({
          url,
          successfulResponseHandler: createBinaryResponseHandler(),
          failedResponseHandler: replicateFailedResponseHandler,
          abortSignal,
          fetch: this.config.fetch,
        });
        return image;
      }),
    );

    return {
      images,
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders,
      },
    };
  }

  /**
   * Maps input files to this model's image parameter(s) as data URIs, and
   * records a warning for each group of files that had to be dropped.
   *
   * Flux-2 models use `input_image`, `input_image_2`, ... (at most
   * MAX_FLUX_2_INPUT_IMAGES). All other models take a single `image`.
   */
  private getImageInputs(
    files: Parameters<ImageModelV3['doGenerate']>[0]['files'],
    warnings: Array<SharedV3Warning>,
  ): Record<string, string> {
    if (files == null || files.length === 0) {
      return {};
    }

    if (!this.isFlux2Model) {
      const imageInputs = { image: convertImageModelFileToDataUri(files[0]) };
      if (files.length > 1) {
        warnings.push({
          type: 'other',
          message:
            'This Replicate model only supports a single input image. Additional images are ignored.',
        });
      }
      return imageInputs;
    }

    const imageInputs: Record<string, string> = {};
    const count = Math.min(files.length, MAX_FLUX_2_INPUT_IMAGES);
    for (let i = 0; i < count; i++) {
      const key = i === 0 ? 'input_image' : `input_image_${i + 1}`;
      imageInputs[key] = convertImageModelFileToDataUri(files[i]);
    }
    if (files.length > MAX_FLUX_2_INPUT_IMAGES) {
      warnings.push({
        type: 'other',
        message: `Flux-2 models support up to ${MAX_FLUX_2_INPUT_IMAGES} input images. Additional images are ignored.`,
      });
    }
    return imageInputs;
  }

  /**
   * Converts the mask to a data URI. Flux-2 models do not accept a mask, so
   * for them the mask is dropped and a warning is recorded.
   */
  private getMaskInput(
    mask: Parameters<ImageModelV3['doGenerate']>[0]['mask'],
    warnings: Array<SharedV3Warning>,
  ): string | undefined {
    if (mask == null) {
      return undefined;
    }
    if (this.isFlux2Model) {
      warnings.push({
        type: 'other',
        message:
          'Flux-2 models do not support mask input. The mask will be ignored.',
      });
      return undefined;
    }
    return convertImageModelFileToDataUri(mask);
  }
}
205
+
206
/**
 * Part of a Replicate prediction response that `doGenerate` reads: `output`
 * holds either a list of image URLs or a single image URL.
 */
const replicateImageResponseSchema = z.object({
  output: z.array(z.string()).or(z.string()),
});
209
+
210
/**
 * Provider options schema for Replicate image generation.
 *
 * Note: Different Replicate models support different parameters.
 * This schema validates common parameters. Any other (model-specific) keys
 * are preserved by the loose object and forwarded in the prediction `input`
 * together with the validated ones. `maxWaitTimeInSeconds` is the exception:
 * it is turned into the `prefer` request header instead.
 */
export const replicateImageProviderOptionsSchema = lazySchema(() =>
  zodSchema(
    z
      .object({
        /**
         * Maximum time in seconds to wait for the prediction to complete in sync mode.
         * By default, Replicate uses sync mode with a 60-second timeout.
         *
         * - When not specified: Uses default 60-second sync wait (`prefer: wait`)
         * - When set to a positive number: Uses that duration (`prefer: wait=N`)
         */
        maxWaitTimeInSeconds: z.number().positive().nullish(),

        /**
         * Guidance scale for classifier-free guidance.
         * Higher values make the output more closely match the prompt.
         */
        guidance_scale: z.number().nullish(),

        /**
         * Number of denoising steps. More steps = higher quality but slower.
         */
        num_inference_steps: z.number().nullish(),

        /**
         * Negative prompt to guide what to avoid in the generation.
         */
        negative_prompt: z.string().nullish(),

        /**
         * Output image format.
         */
        output_format: z.enum(['png', 'jpg', 'webp']).nullish(),

        /**
         * Output image quality (1-100). Only applies to jpg and webp.
         */
        output_quality: z.number().min(1).max(100).nullish(),

        /**
         * Strength of the transformation for img2img (0-1).
         * Lower values keep more of the original image.
         */
        strength: z.number().min(0).max(1).nullish(),
      })
      // Keep unknown (model-specific) keys. In Zod v4, `.loose()` replaces
      // the deprecated `.passthrough()` and has the same semantics.
      .loose(),
  ),
);

export type ReplicateImageProviderOptions = InferSchema<
  typeof replicateImageProviderOptionsSchema
>;
@@ -0,0 +1,36 @@
1
/**
 * Identifier of a Replicate image model. The form is `owner/name`, optionally
 * suffixed with `:version` to pin a specific model version.
 *
 * The listed IDs exist for editor autocompletion. The trailing
 * `(string & {})` member means any other string is accepted too.
 */
export type ReplicateImageModelId =
  // Text-to-image models
  | 'black-forest-labs/flux-1.1-pro'
  | 'black-forest-labs/flux-1.1-pro-ultra'
  | 'black-forest-labs/flux-dev'
  | 'black-forest-labs/flux-pro'
  | 'black-forest-labs/flux-schnell'
  | 'bytedance/sdxl-lightning-4step'
  | 'fofr/aura-flow'
  | 'fofr/latent-consistency-model'
  | 'fofr/realvisxl-v3-multi-controlnet-lora'
  | 'fofr/sdxl-emoji'
  | 'fofr/sdxl-multi-controlnet-lora'
  | 'ideogram-ai/ideogram-v2'
  | 'ideogram-ai/ideogram-v2-turbo'
  | 'lucataco/dreamshaper-xl-turbo'
  | 'lucataco/open-dalle-v1.1'
  | 'lucataco/realvisxl-v2.0'
  | 'lucataco/realvisxl2-lcm'
  | 'luma/photon'
  | 'luma/photon-flash'
  | 'nvidia/sana'
  | 'playgroundai/playground-v2.5-1024px-aesthetic'
  | 'recraft-ai/recraft-v3'
  | 'recraft-ai/recraft-v3-svg'
  | 'stability-ai/stable-diffusion-3.5-large'
  | 'stability-ai/stable-diffusion-3.5-large-turbo'
  | 'stability-ai/stable-diffusion-3.5-medium'
  | 'tstramer/material-diffusion'
  // Inpainting and image-editing models (accept a mask)
  | 'black-forest-labs/flux-fill-pro'
  | 'black-forest-labs/flux-fill-dev'
  // Flux-2 models (up to 8 reference images, no mask)
  | 'black-forest-labs/flux-2-pro'
  | 'black-forest-labs/flux-2-dev'
  // Any other model ID, without losing autocompletion for the ones above
  | (string & {});
@@ -0,0 +1,24 @@
1
// Unit tests for the Replicate provider factory.
import { describe, it, expect } from 'vitest';
import { NoSuchModelError } from '@ai-sdk/provider';
import { createReplicate } from './replicate-provider';
import { ReplicateImageModel } from './replicate-image-model';

describe('createReplicate', () => {
  it('creates a provider with required settings', () => {
    const provider = createReplicate({ apiToken: 'test-token' });
    expect(provider.image).toBeDefined();
  });

  it('creates a provider with custom settings', () => {
    const provider = createReplicate({
      apiToken: 'test-token',
      baseURL: 'https://custom.replicate.com',
    });
    expect(provider.image).toBeDefined();
    expect(provider.image('black-forest-labs/flux-schnell')).toBeInstanceOf(
      ReplicateImageModel,
    );
  });

  it('creates an image model instance', () => {
    const provider = createReplicate({ apiToken: 'test-token' });
    const model = provider.image('black-forest-labs/flux-schnell');
    expect(model).toBeInstanceOf(ReplicateImageModel);
    expect(model.modelId).toBe('black-forest-labs/flux-schnell');
    expect(model.provider).toBe('replicate');
    expect(model.specificationVersion).toBe('v3');
  });

  it('exposes imageModel as an alias of image', () => {
    const provider = createReplicate({ apiToken: 'test-token' });
    const model = provider.imageModel('black-forest-labs/flux-2-pro');
    expect(model).toBeInstanceOf(ReplicateImageModel);
    expect(model.modelId).toBe('black-forest-labs/flux-2-pro');
  });

  it('throws NoSuchModelError for unsupported model types', () => {
    const provider = createReplicate({ apiToken: 'test-token' });
    expect(() => provider.languageModel('some-model')).toThrow(
      NoSuchModelError,
    );
    expect(() => provider.embeddingModel('some-model')).toThrow(
      NoSuchModelError,
    );
    expect(() => provider.textEmbeddingModel('some-model')).toThrow(
      NoSuchModelError,
    );
  });
});
@@ -0,0 +1,99 @@
1
import { NoSuchModelError, ProviderV3 } from '@ai-sdk/provider';
import type { FetchFunction } from '@ai-sdk/provider-utils';
import {
  loadApiKey,
  withoutTrailingSlash,
  withUserAgentSuffix,
} from '@ai-sdk/provider-utils';
import { ReplicateImageModel } from './replicate-image-model';
import { ReplicateImageModelId } from './replicate-image-settings';
import { VERSION } from './version';

export interface ReplicateProviderSettings {
  /**
API token that is sent using the `Authorization` header.
It defaults to the `REPLICATE_API_TOKEN` environment variable.
   */
  apiToken?: string;

  /**
Use a different URL prefix for API calls, e.g. to use proxy servers.
The default prefix is `https://api.replicate.com/v1`.
A trailing slash is removed.
   */
  baseURL?: string;

  /**
Custom headers to include in the requests.
   */
  headers?: Record<string, string>;

  /**
Custom fetch implementation. You can use it as a middleware to intercept requests,
or to provide a custom fetch implementation for e.g. testing.
   */
  fetch?: FetchFunction;
}

export interface ReplicateProvider extends ProviderV3 {
  /**
   * Creates a Replicate image generation model.
   */
  image(modelId: ReplicateImageModelId): ReplicateImageModel;

  /**
   * Creates a Replicate image generation model.
   */
  imageModel(modelId: ReplicateImageModelId): ReplicateImageModel;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: string): never;
}

/**
 * Create a Replicate provider instance.
 *
 * The API token is resolved each time an image model is created, not when the
 * provider is created. A missing token is therefore reported by `image()` /
 * `imageModel()`. Language and embedding models are not supported and throw
 * `NoSuchModelError`.
 */
export function createReplicate(
  options: ReplicateProviderSettings = {},
): ReplicateProvider {
  // Normalize so that `${baseURL}/predictions` never contains a double slash.
  const baseURL =
    withoutTrailingSlash(options.baseURL) ?? 'https://api.replicate.com/v1';

  const createImageModel = (modelId: ReplicateImageModelId) =>
    new ReplicateImageModel(modelId, {
      provider: 'replicate',
      baseURL,
      headers: withUserAgentSuffix(
        {
          Authorization: `Bearer ${loadApiKey({
            apiKey: options.apiToken,
            environmentVariableName: 'REPLICATE_API_TOKEN',
            description: 'Replicate',
          })}`,
          ...options.headers,
        },
        `ai-sdk/replicate/${VERSION}`,
      ),
      fetch: options.fetch,
    });

  const embeddingModel = (modelId: string) => {
    throw new NoSuchModelError({
      modelId,
      modelType: 'embeddingModel',
    });
  };

  return {
    specificationVersion: 'v3' as const,
    image: createImageModel,
    imageModel: createImageModel,
    languageModel: (modelId: string) => {
      throw new NoSuchModelError({
        modelId,
        modelType: 'languageModel',
      });
    },
    embeddingModel,
    textEmbeddingModel: embeddingModel,
  };
}

/**
 * Default Replicate provider instance.
 */
export const replicate = createReplicate();
package/src/version.ts ADDED
@@ -0,0 +1,6 @@
1
// `__PACKAGE_VERSION__` is substituted at build time. When it is absent
// (e.g. when running the sources directly), a placeholder version is used.
declare const __PACKAGE_VERSION__: string | undefined;

/** Version of this package; appended to the `User-Agent` header by the provider. */
export const VERSION: string =
  typeof __PACKAGE_VERSION__ === 'undefined'
    ? '0.0.0-test'
    : __PACKAGE_VERSION__;