@ai-sdk/replicate 2.0.16 → 2.0.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,342 @@
1
+ import {
2
+ AISDKError,
3
+ type Experimental_VideoModelV3,
4
+ type SharedV3Warning,
5
+ } from '@ai-sdk/provider';
6
+ import {
7
+ combineHeaders,
8
+ convertImageModelFileToDataUri,
9
+ createJsonResponseHandler,
10
+ delay,
11
+ type FetchFunction,
12
+ getFromApi,
13
+ lazySchema,
14
+ parseProviderOptions,
15
+ postJsonToApi,
16
+ type Resolvable,
17
+ resolve,
18
+ zodSchema,
19
+ } from '@ai-sdk/provider-utils';
20
+ import { z } from 'zod/v4';
21
+ import { replicateFailedResponseHandler } from './replicate-error';
22
+ import type { ReplicateVideoModelId } from './replicate-video-settings';
23
+
24
/**
 * Options accepted under `providerOptions.replicate` for video generation.
 *
 * Three keys steer the client (polling / blocking behaviour) and are never
 * sent to Replicate as model input. Every other key ends up in the
 * prediction `input` object.
 */
export type ReplicateVideoProviderOptions = {
  // ---- Client-side request / polling configuration ----

  /**
   * Pause between two status polls, in milliseconds.
   * Falls back to 2000 when omitted or `null`.
   */
  pollIntervalMs?: number | null;

  /**
   * Upper bound, in milliseconds, on how long polling may continue before a
   * timeout error is raised. Falls back to 300000 when omitted or `null`.
   */
  pollTimeoutMs?: number | null;

  /**
   * When set, the create-prediction request carries `prefer: wait=<value>`;
   * otherwise a bare `prefer: wait` header is sent.
   */
  maxWaitTimeInSeconds?: number | null;

  // ---- Common video generation options ----
  // Copied to the prediction input under the same name when not null/undefined.

  guidance_scale?: number | null;
  num_inference_steps?: number | null;

  // ---- Stable Video Diffusion specific ----
  // Copied to the prediction input under the same name when not null/undefined.

  motion_bucket_id?: number | null;
  cond_aug?: number | null;
  decoding_t?: number | null;
  video_length?: string | null;
  sizing_strategy?: string | null;
  frames_per_second?: number | null;

  // ---- MiniMax specific ----
  // Copied to the prediction input under the same name when not null/undefined.

  prompt_optimizer?: boolean | null;

  /**
   * Passthrough: any key not listed above is copied into the prediction
   * input unchanged.
   */
  [key: string]: unknown;
};
47
+
48
/**
 * Construction-time settings shared by every request a
 * `ReplicateVideoModel` issues.
 */
interface ReplicateVideoModelConfig {
  /** Value reported by the model's `provider` getter. */
  provider: string;

  /** Prefix prepended to the `/predictions` endpoints. */
  baseURL: string;

  /**
   * Headers (or a resolvable producing them) attached to the create and
   * poll requests.
   */
  headers?: Resolvable<Record<string, string | undefined>>;

  /** Custom fetch implementation handed to the HTTP helpers. */
  fetch?: FetchFunction;

  /** Hooks for overriding internals. */
  _internal?: {
    /** Supplies the timestamp reported in the response metadata. */
    currentDate?: () => Date;
  };
}
57
+
58
/**
 * Provider-option keys that only configure this client (polling cadence and
 * the `prefer: wait` header). They are never forwarded as model input.
 */
const clientOnlyOptionKeys = [
  'pollIntervalMs',
  'pollTimeoutMs',
  'maxWaitTimeInSeconds',
] as const;

/**
 * Provider-option keys that are forwarded to the prediction input under the
 * same name, but only when their value is neither `null` nor `undefined`.
 * The order here is the order in which they are written to the input object.
 */
const knownInputOptionKeys = [
  'guidance_scale',
  'num_inference_steps',
  'motion_bucket_id',
  'cond_aug',
  'decoding_t',
  'video_length',
  'sizing_strategy',
  'frames_per_second',
  'prompt_optimizer',
] as const;

/**
 * Every key with dedicated handling; anything outside this set is treated as
 * a passthrough input. Built once instead of on every loop iteration.
 */
const reservedOptionKeys: ReadonlySet<string> = new Set<string>([
  ...clientOnlyOptionKeys,
  ...knownInputOptionKeys,
]);

/**
 * Video model backed by the Replicate predictions API.
 *
 * A prediction is created with a `prefer: wait` header; if Replicate answers
 * before the prediction has finished, its `urls.get` endpoint is polled until
 * the status leaves `starting` / `processing`.
 */
export class ReplicateVideoModel implements Experimental_VideoModelV3 {
  readonly specificationVersion = 'v3';
  readonly maxVideosPerCall = 1; // Replicate video models support 1 video at a time

  get provider(): string {
    return this.config.provider;
  }

  /**
   * @param modelId Either `owner/name` or `owner/name:version`. The part
   *   after the first `:` (if any) is sent as the prediction `version`.
   * @param config Provider name, base URL, headers and fetch override.
   */
  constructor(
    readonly modelId: ReplicateVideoModelId,
    private readonly config: ReplicateVideoModelConfig,
  ) {}

  /**
   * Creates a prediction, waits for it to settle and returns the resulting
   * video URL.
   *
   * @param options Standardized call options (prompt, image, aspect ratio,
   *   resolution, duration, fps, seed, headers, abort signal) plus
   *   `providerOptions.replicate`.
   * @returns One `video/mp4` URL entry, the (always empty) warnings list,
   *   response metadata taken from the create request, and provider metadata
   *   holding the video URL, prediction id and metrics.
   * @throws AISDKError `REPLICATE_VIDEO_GENERATION_TIMEOUT` when polling
   *   exceeds `pollTimeoutMs`.
   * @throws AISDKError `REPLICATE_VIDEO_GENERATION_ABORTED` when the abort
   *   signal is set after a polling delay.
   * @throws AISDKError `REPLICATE_VIDEO_GENERATION_FAILED` /
   *   `REPLICATE_VIDEO_GENERATION_CANCELED` for those terminal statuses.
   * @throws AISDKError `REPLICATE_VIDEO_GENERATION_ERROR` when the settled
   *   prediction carries no output.
   */
  async doGenerate(
    options: Parameters<Experimental_VideoModelV3['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<Experimental_VideoModelV3['doGenerate']>>> {
    const currentDate = this.config._internal?.currentDate?.() ?? new Date();
    const warnings: SharedV3Warning[] = [];

    const replicateOptions = (await parseProviderOptions({
      provider: 'replicate',
      providerOptions: options.providerOptions,
      schema: replicateVideoProviderOptionsSchema,
    })) as ReplicateVideoProviderOptions | undefined;

    const [modelId, version] = this.modelId.split(':');
    const input = this.buildInput(options, replicateOptions);

    const maxWaitTimeInSeconds = replicateOptions?.maxWaitTimeInSeconds;
    const preferHeader: Record<string, string> = {
      prefer:
        maxWaitTimeInSeconds != null ? `wait=${maxWaitTimeInSeconds}` : 'wait',
    };

    // A pinned version goes through the generic endpoint with `version` in
    // the body; otherwise the model-scoped endpoint is used.
    const predictionUrl =
      version != null
        ? `${this.config.baseURL}/predictions`
        : `${this.config.baseURL}/models/${modelId}/predictions`;

    const { value: prediction, responseHeaders } = await postJsonToApi({
      url: predictionUrl,
      headers: combineHeaders(
        await resolve(this.config.headers),
        options.headers,
        preferHeader,
      ),
      body: {
        input,
        ...(version != null ? { version } : {}),
      },
      successfulResponseHandler: createJsonResponseHandler(
        replicatePredictionSchema,
      ),
      failedResponseHandler: replicateFailedResponseHandler,
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const finalPrediction = await this.pollUntilSettled(
      prediction,
      options,
      replicateOptions,
    );

    if (finalPrediction.status === 'failed') {
      throw new AISDKError({
        name: 'REPLICATE_VIDEO_GENERATION_FAILED',
        message: `Video generation failed: ${finalPrediction.error ?? 'Unknown error'}`,
      });
    }

    if (finalPrediction.status === 'canceled') {
      throw new AISDKError({
        name: 'REPLICATE_VIDEO_GENERATION_CANCELED',
        message: 'Video generation was canceled',
      });
    }

    const videoUrl = finalPrediction.output;
    if (!videoUrl) {
      throw new AISDKError({
        name: 'REPLICATE_VIDEO_GENERATION_ERROR',
        message: 'No video URL in response',
      });
    }

    return {
      videos: [
        {
          type: 'url',
          url: videoUrl,
          mediaType: 'video/mp4',
        },
      ],
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders,
      },
      providerMetadata: {
        replicate: {
          videos: [
            {
              url: videoUrl,
            },
          ],
          predictionId: finalPrediction.id,
          metrics: finalPrediction.metrics,
        },
      },
    };
  }

  /**
   * Assembles the prediction `input` from the standardized call options and
   * the Replicate-specific provider options.
   */
  private buildInput(
    options: Parameters<Experimental_VideoModelV3['doGenerate']>[0],
    replicateOptions: ReplicateVideoProviderOptions | undefined,
  ): Record<string, unknown> {
    const input: Record<string, unknown> = {};

    if (options.prompt != null) {
      input.prompt = options.prompt;
    }

    if (options.image != null) {
      // URL images are passed by reference; file images are inlined as a
      // data URI.
      input.image =
        options.image.type === 'url'
          ? options.image.url
          : convertImageModelFileToDataUri(options.image);
    }

    if (options.aspectRatio) {
      input.aspect_ratio = options.aspectRatio;
    }

    if (options.resolution) {
      input.size = options.resolution;
    }

    if (options.duration) {
      input.duration = options.duration;
    }

    if (options.fps) {
      input.fps = options.fps;
    }

    // Nullish check rather than truthiness: 0 is a legitimate seed and must
    // not be dropped.
    if (options.seed != null) {
      input.seed = options.seed;
    }

    if (replicateOptions != null) {
      // Known model inputs first, in their fixed order, skipping null/undefined.
      for (const key of knownInputOptionKeys) {
        const value = replicateOptions[key];
        if (value !== undefined && value !== null) {
          input[key] = value;
        }
      }

      // Then every remaining (passthrough) key, copied as-is.
      for (const [key, value] of Object.entries(replicateOptions)) {
        if (!reservedOptionKeys.has(key)) {
          input[key] = value;
        }
      }
    }

    return input;
  }

  /**
   * Polls the prediction's `urls.get` endpoint until its status is no longer
   * `starting` or `processing`. Returns the prediction unchanged when it has
   * already settled.
   */
  private async pollUntilSettled(
    initialPrediction: z.infer<typeof replicatePredictionSchema>,
    options: Parameters<Experimental_VideoModelV3['doGenerate']>[0],
    replicateOptions: ReplicateVideoProviderOptions | undefined,
  ): Promise<z.infer<typeof replicatePredictionSchema>> {
    const pollIntervalMs = replicateOptions?.pollIntervalMs ?? 2000; // 2 seconds
    const pollTimeoutMs = replicateOptions?.pollTimeoutMs ?? 300000; // 5 minutes
    const startTime = Date.now();

    let current = initialPrediction;

    while (current.status === 'starting' || current.status === 'processing') {
      if (Date.now() - startTime > pollTimeoutMs) {
        throw new AISDKError({
          name: 'REPLICATE_VIDEO_GENERATION_TIMEOUT',
          message: `Video generation timed out after ${pollTimeoutMs}ms`,
        });
      }

      await delay(pollIntervalMs);

      if (options.abortSignal?.aborted) {
        throw new AISDKError({
          name: 'REPLICATE_VIDEO_GENERATION_ABORTED',
          message: 'Video generation request was aborted',
        });
      }

      const { value: statusPrediction } = await getFromApi({
        url: current.urls.get,
        // Include per-call headers so polls carry the same overrides as the
        // create request.
        headers: combineHeaders(
          await resolve(this.config.headers),
          options.headers,
        ),
        successfulResponseHandler: createJsonResponseHandler(
          replicatePredictionSchema,
        ),
        failedResponseHandler: replicateFailedResponseHandler,
        abortSignal: options.abortSignal,
        fetch: this.config.fetch,
      });

      current = statusPrediction;
    }

    return current;
  }
}
307
+
308
/** Optional timing metrics attached to a prediction. */
const replicatePredictionMetricsSchema = z
  .object({
    predict_time: z.number().nullish(),
  })
  .nullish();

/**
 * Shape of a prediction as returned by both the create request and the
 * status poll. `urls.get` is the endpoint polled for updates; `output` is
 * used as the video URL once the prediction has settled.
 */
const replicatePredictionSchema = z.object({
  id: z.string(),
  status: z.enum(['starting', 'processing', 'succeeded', 'failed', 'canceled']),
  output: z.string().nullish(),
  error: z.string().nullish(),
  urls: z.object({ get: z.string() }),
  metrics: replicatePredictionMetricsSchema,
});
322
+
323
/**
 * Lazily-built validation schema for `providerOptions.replicate`.
 *
 * The three client-side settings must be positive numbers when present; the
 * model inputs are only checked for their primitive type. The object is
 * `.loose()` so that keys not listed here are kept for passthrough.
 */
const replicateVideoProviderOptionsSchema = lazySchema(() => {
  // Small factories keep the field list readable; each call yields a fresh schema.
  const positiveNumber = () => z.number().positive().nullish();
  const anyNumber = () => z.number().nullish();
  const anyString = () => z.string().nullish();

  const providerOptions = z
    .object({
      pollIntervalMs: positiveNumber(),
      pollTimeoutMs: positiveNumber(),
      maxWaitTimeInSeconds: positiveNumber(),
      guidance_scale: anyNumber(),
      num_inference_steps: anyNumber(),
      motion_bucket_id: anyNumber(),
      cond_aug: anyNumber(),
      decoding_t: anyNumber(),
      video_length: anyString(),
      sizing_strategy: anyString(),
      frames_per_second: anyNumber(),
      prompt_optimizer: z.boolean().nullish(),
    })
    .loose();

  return zodSchema(providerOptions);
});
@@ -0,0 +1,5 @@
1
/** Model identifiers listed explicitly so editors can suggest them. */
type KnownReplicateVideoModelId =
  | 'minimax/video-01'
  | 'minimax/video-01:6c1e4171-288a-4ca2-a738-894f0e87699d'
  | 'stability-ai/stable-video-diffusion:3f0457e4619daac51203dedb472816fd4af51f3149fa7a9e0b5ffcf1b8172438';

/**
 * Identifier of a Replicate video model, written as `owner/name` or
 * `owner/name:version`.
 *
 * The `(string & {})` member accepts any other string while keeping the
 * known literals available for autocompletion.
 */
export type ReplicateVideoModelId = KnownReplicateVideoModelId | (string & {});