@ai-sdk/fal 2.0.9 → 2.0.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,183 @@
1
+ import {
2
+ ImageModelV3,
3
+ NoSuchModelError,
4
+ ProviderV3,
5
+ SpeechModelV3,
6
+ TranscriptionModelV3,
7
+ } from '@ai-sdk/provider';
8
+ import type { FetchFunction } from '@ai-sdk/provider-utils';
9
+ import {
10
+ withoutTrailingSlash,
11
+ withUserAgentSuffix,
12
+ } from '@ai-sdk/provider-utils';
13
+ import { FalImageModel } from './fal-image-model';
14
+ import { FalImageModelId } from './fal-image-settings';
15
+ import { FalTranscriptionModelId } from './fal-transcription-options';
16
+ import { FalTranscriptionModel } from './fal-transcription-model';
17
+ import { FalSpeechModelId } from './fal-speech-settings';
18
+ import { FalSpeechModel } from './fal-speech-model';
19
+ import { VERSION } from './version';
20
+
21
export interface FalProviderSettings {
  /**
   * fal.ai API key. When omitted, the key is read from the `FAL_API_KEY`
   * environment variable, falling back to `FAL_KEY`.
   *
   * The key is resolved each time request headers are built, not when the
   * provider is created, so creating a provider without a key does not throw.
   */
  apiKey?: string;

  /**
   * Base URL for image generation API calls. Defaults to `https://fal.run`;
   * a trailing slash is removed.
   *
   * Note: speech and transcription models build their own request URLs and
   * currently do not use this setting.
   */
  baseURL?: string;

  /**
   * Custom headers to include in every request. They are merged after the
   * generated `Authorization` header, so an `Authorization` entry here
   * replaces it.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation. Use it as middleware to intercept requests,
   * or to supply a custom fetch implementation, e.g. for testing.
   */
  fetch?: FetchFunction;
}
45
+
46
export interface FalProvider extends ProviderV3 {
  /**
   * Shorthand for `imageModel`: returns a fal.ai image generation model.
   */
  image(modelId: FalImageModelId): ImageModelV3;

  /**
   * Returns a fal.ai image generation model for the given model ID.
   */
  imageModel(modelId: FalImageModelId): ImageModelV3;

  /**
   * Returns a fal.ai speech-to-text (transcription) model.
   */
  transcription(modelId: FalTranscriptionModelId): TranscriptionModelV3;

  /**
   * Returns a fal.ai text-to-speech model.
   */
  speech(modelId: FalSpeechModelId): SpeechModelV3;

  /**
   * fal.ai provides no embedding models; calling this always throws.
   *
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: string): never;
}
72
+
73
/** Default endpoint for fal.ai image model requests. */
const defaultBaseURL = 'https://fal.run';

/**
 * Resolves the fal.ai API key.
 *
 * Resolution order:
 * 1. the explicit `apiKey` argument, if it is a string (returned as-is);
 * 2. the `FAL_API_KEY` environment variable;
 * 3. the `FAL_KEY` environment variable.
 *
 * Environment variables that are set but empty are treated as unset. A blank
 * `FAL_API_KEY` (e.g. an unfilled `.env` entry) therefore no longer shadows a
 * valid `FAL_KEY` and cause an authentication failure at request time.
 *
 * @param apiKey - Key passed explicitly by the caller, if any.
 * @param description - Provider name used in error messages.
 * @returns The resolved API key.
 * @throws Error when the explicit key is not a string, when `process` is not
 *   available to read environment variables, or when no key can be found.
 */
function loadFalApiKey({
  apiKey,
  description = 'fal.ai',
}: {
  apiKey: string | undefined;
  description?: string;
}): string {
  if (typeof apiKey === 'string') {
    return apiKey;
  }

  if (apiKey != null) {
    throw new Error(`${description} API key must be a string.`);
  }

  if (typeof process === 'undefined') {
    throw new Error(
      `${description} API key is missing. Pass it using the 'apiKey' parameter. Environment variables are not supported in this environment.`,
    );
  }

  // An empty value is a placeholder, not a usable key: skip it and try the
  // next candidate instead of returning ''.
  const envApiKey = [process.env.FAL_API_KEY, process.env.FAL_KEY].find(
    (value) => value != null && value !== '',
  );

  if (envApiKey == null) {
    throw new Error(
      `${description} API key is missing. Pass it using the 'apiKey' parameter or set either the FAL_API_KEY or FAL_KEY environment variable.`,
    );
  }

  // Defensive: Node always stores strings in process.env, but a non-standard
  // `process` shim might not.
  if (typeof envApiKey !== 'string') {
    throw new Error(
      `${description} API key must be a string. The value of the environment variable is not a string.`,
    );
  }

  return envApiKey;
}
115
+
116
/**
 * Creates a fal.ai provider instance.
 *
 * Authentication headers are built when a request is made, so creating a
 * provider never throws for a missing API key. The error surfaces on the
 * first model call instead.
 *
 * @param options - Provider settings (API key, base URL, headers, fetch).
 * @returns A provider exposing image, speech and transcription models, both
 *   under their fal-specific names and the standard `ProviderV3` names.
 *   Language and embedding model lookups throw `NoSuchModelError`.
 */
export function createFal(options: FalProviderSettings = {}): FalProvider {
  // `withoutTrailingSlash` returns `string | undefined`; applying the default
  // afterwards keeps `baseURL` typed as a plain string.
  const baseURL = withoutTrailingSlash(options.baseURL) ?? defaultBaseURL;

  const getHeaders = () =>
    withUserAgentSuffix(
      {
        Authorization: `Key ${loadFalApiKey({ apiKey: options.apiKey })}`,
        // Spread last so user-supplied headers can override the defaults.
        ...options.headers,
      },
      `ai-sdk/fal/${VERSION}`,
    );

  const createImageModel = (modelId: FalImageModelId) =>
    new FalImageModel(modelId, {
      provider: 'fal.image',
      baseURL,
      headers: getHeaders,
      fetch: options.fetch,
    });

  const createSpeechModel = (modelId: FalSpeechModelId) =>
    new FalSpeechModel(modelId, {
      provider: 'fal.speech',
      url: ({ path }) => path,
      headers: getHeaders,
      fetch: options.fetch,
    });

  const createTranscriptionModel = (modelId: FalTranscriptionModelId) =>
    new FalTranscriptionModel(modelId, {
      provider: 'fal.transcription',
      url: ({ path }) => path,
      headers: getHeaders,
      fetch: options.fetch,
    });

  const embeddingModel = (modelId: string) => {
    throw new NoSuchModelError({
      modelId,
      modelType: 'embeddingModel',
    });
  };

  return {
    specificationVersion: 'v3' as const,
    imageModel: createImageModel,
    image: createImageModel,
    languageModel: (modelId: string) => {
      throw new NoSuchModelError({
        modelId,
        modelType: 'languageModel',
      });
    },
    speech: createSpeechModel,
    // Standard ProviderV3 entry points, so provider registries and generic
    // code can resolve fal speech and transcription models.
    speechModel: createSpeechModel,
    transcription: createTranscriptionModel,
    transcriptionModel: createTranscriptionModel,
    embeddingModel,
    textEmbeddingModel: embeddingModel,
  };
}
179
+
180
/**
 * Default fal.ai provider instance, created with no explicit settings. It
 * reads its API key from the environment (`FAL_API_KEY`, falling back to
 * `FAL_KEY`) when a request is made, so importing this module does not
 * require the key to be set.
 */
export const fal: FalProvider = createFal();
@@ -0,0 +1,128 @@
1
+ import { createTestServer } from '@ai-sdk/test-server/with-vitest';
2
+ import { createFal } from './fal-provider';
3
+ import { FalSpeechModel } from './fal-speech-model';
4
+ import { describe, it, expect } from 'vitest';
5
+
6
// Mocked endpoints: the synthesis request goes to fal.run, and the audio URL
// it returns is then downloaded from fal.media.
const server = createTestServer({
  'https://fal.run/fal-ai/minimax/speech-02-hd': {},
  'https://fal.media/files/test.mp3': {},
});

// Shared provider and model for tests that need no custom configuration.
const provider = createFal({ apiKey: 'test-api-key' });
const model = provider.speech('fal-ai/minimax/speech-02-hd');
13
+
14
describe('FalSpeechModel.doGenerate', () => {
  /**
   * Mocks both requests made by doGenerate: the JSON synthesis call (whose
   * body points at the audio file) and the binary audio download.
   * Returns the bytes served by the download endpoint.
   */
  function prepareResponses({
    jsonHeaders,
    audioHeaders,
  }: {
    jsonHeaders?: Record<string, string>;
    audioHeaders?: Record<string, string>;
  } = {}) {
    const audioBuffer = new Uint8Array(100);
    server.urls['https://fal.run/fal-ai/minimax/speech-02-hd'].response = {
      type: 'json-value',
      headers: {
        'content-type': 'application/json',
        ...jsonHeaders,
      },
      body: {
        audio: { url: 'https://fal.media/files/test.mp3' },
        duration_ms: 1234,
      },
    };
    server.urls['https://fal.media/files/test.mp3'].response = {
      type: 'binary',
      headers: {
        'content-type': 'audio/mp3',
        ...audioHeaders,
      },
      body: Buffer.from(audioBuffer),
    };
    return audioBuffer;
  }

  it('should pass text and default output_format', async () => {
    prepareResponses();

    await model.doGenerate({
      text: 'Hello from the AI SDK!',
    });

    expect(await server.calls[0].requestBodyJson).toMatchObject({
      text: 'Hello from the AI SDK!',
      output_format: 'url',
    });
  });

  it('should pass headers', async () => {
    prepareResponses();

    const customProvider = createFal({
      apiKey: 'test-api-key',
      headers: {
        'Custom-Provider-Header': 'provider-header-value',
      },
    });

    await customProvider.speech('fal-ai/minimax/speech-02-hd').doGenerate({
      text: 'Hello from the AI SDK!',
      headers: {
        'Custom-Request-Header': 'request-header-value',
      },
    });

    expect(server.calls[0].requestHeaders).toMatchObject({
      authorization: 'Key test-api-key',
      'content-type': 'application/json',
      'custom-provider-header': 'provider-header-value',
      'custom-request-header': 'request-header-value',
    });
  });

  it('should not send the API key when downloading the audio file', async () => {
    prepareResponses();

    await model.doGenerate({
      text: 'Hello from the AI SDK!',
    });

    // calls[1] is the download of the audio URL returned by the first call.
    expect(server.calls[1].requestHeaders.authorization).toBeUndefined();
  });

  it('should return audio data', async () => {
    const audio = prepareResponses();

    const result = await model.doGenerate({
      text: 'Hello from the AI SDK!',
    });

    expect(result.audio).toStrictEqual(audio);
  });

  it('should include response data with timestamp, modelId and headers', async () => {
    prepareResponses({ jsonHeaders: { 'x-request-id': 'test-request-id' } });

    const testDate = new Date(0);
    const customModel = new FalSpeechModel('fal-ai/minimax/speech-02-hd', {
      provider: 'fal.speech',
      url: ({ path }) => path,
      headers: () => ({}),
      _internal: { currentDate: () => testDate },
    });

    const result = await customModel.doGenerate({
      text: 'Hello from the AI SDK!',
    });

    expect(result.response).toMatchObject({
      timestamp: testDate,
      modelId: 'fal-ai/minimax/speech-02-hd',
      headers: expect.objectContaining({ 'x-request-id': 'test-request-id' }),
    });
  });

  it('should include warnings for unsupported settings', async () => {
    prepareResponses();

    const result = await model.doGenerate({
      text: 'Hello from the AI SDK!',
      language: 'en',
      // 'wav' is not a supported outputFormat: it triggers a warning and the
      // request falls back to 'url' (audio is still fetched from the URL).
      outputFormat: 'wav',
    });

    expect(result.warnings).toEqual([
      expect.objectContaining({ type: 'unsupported', feature: 'language' }),
      expect.objectContaining({
        type: 'unsupported',
        feature: 'outputFormat',
        details: "Unsupported outputFormat: wav. Using 'url' instead.",
      }),
    ]);
    expect(await server.calls[0].requestBodyJson).toMatchObject({
      output_format: 'url',
    });
  });
});
@@ -0,0 +1,156 @@
1
+ import { SpeechModelV3, SharedV3Warning } from '@ai-sdk/provider';
2
+ import {
3
+ combineHeaders,
4
+ createBinaryResponseHandler,
5
+ createJsonResponseHandler,
6
+ createStatusCodeErrorResponseHandler,
7
+ getFromApi,
8
+ parseProviderOptions,
9
+ postJsonToApi,
10
+ } from '@ai-sdk/provider-utils';
11
+ import { z } from 'zod/v4';
12
+ import { FalConfig } from './fal-config';
13
+ import { falFailedResponseHandler } from './fal-error';
14
+ import { FAL_EMOTIONS, FAL_LANGUAGE_BOOSTS } from './fal-api-types';
15
+ import { FalSpeechModelId } from './fal-speech-settings';
16
+
17
/**
 * Voice parameters accepted under `voice_setting`. Every field may be omitted
 * or set to null.
 */
const falVoiceSettingSchema = z
  .object({
    speed: z.number().nullish(),
    vol: z.number().nullish(),
    voice_id: z.string().nullish(),
    pitch: z.number().nullish(),
    english_normalization: z.boolean().nullish(),
    emotion: z.enum(FAL_EMOTIONS).nullish(),
  })
  .partial();

/**
 * Options accepted under `providerOptions.fal` for speech models. This is a
 * loose object, so keys not listed here are kept rather than stripped.
 */
const falSpeechProviderOptionsSchema = z.looseObject({
  voice_setting: falVoiceSettingSchema.nullish(),
  audio_setting: z.record(z.string(), z.unknown()).nullish(),
  language_boost: z.enum(FAL_LANGUAGE_BOOSTS).nullish(),
  pronunciation_dict: z.record(z.string(), z.string()).nullish(),
});

/** Parsed shape of the fal speech provider options. */
export type FalSpeechCallOptions = z.infer<
  typeof falSpeechProviderOptionsSchema
>;
37
+
38
/**
 * Configuration for `FalSpeechModel`: the shared fal configuration plus
 * internal overrides.
 */
interface FalSpeechModelConfig extends FalConfig {
  /** Internal hooks (e.g. for tests); not part of the public settings. */
  _internal?: {
    /** Clock for `response.timestamp`; `new Date()` is used when absent. */
    currentDate?: () => Date;
  };
}
43
+
44
/**
 * fal.ai text-to-speech model.
 *
 * A generation is two HTTP requests: a JSON POST to the model endpoint, whose
 * response contains an audio URL, followed by a download of that URL.
 */
export class FalSpeechModel implements SpeechModelV3 {
  readonly specificationVersion = 'v3';

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: FalSpeechModelId,
    private readonly config: FalSpeechModelConfig,
  ) {}

  /**
   * Builds the fal request body and collects warnings for call options that
   * cannot be honored.
   */
  private async getArgs({
    text,
    voice,
    outputFormat,
    speed,
    language,
    providerOptions,
  }: Parameters<SpeechModelV3['doGenerate']>[0]) {
    const warnings: SharedV3Warning[] = [];

    const falOptions = await parseProviderOptions({
      provider: 'fal',
      providerOptions,
      schema: falSpeechProviderOptionsSchema,
    });

    const requestBody = {
      text,
      // Always request a URL: doGenerate reads `audio.url` from the response
      // and downloads it. Hex responses are not handled yet, so asking fal for
      // hex output would break response parsing.
      output_format: 'url',
      voice,
      speed,
      // Provider options are spread last and may override the fields above.
      ...falOptions,
    };

    // Language is not directly supported; warn and ignore.
    if (language) {
      warnings.push({
        type: 'unsupported',
        feature: 'language',
        details:
          "fal speech models don't support 'language' directly; consider providerOptions.fal.language_boost",
      });
    }

    // Anything other than 'url' is reported and replaced with 'url'. This
    // includes 'hex' until hex response handling is supported.
    if (outputFormat && outputFormat !== 'url') {
      warnings.push({
        type: 'unsupported',
        feature: 'outputFormat',
        details: `Unsupported outputFormat: ${outputFormat}. Using 'url' instead.`,
      });
    }

    return { requestBody, warnings } as const;
  }

  /**
   * Synthesizes speech. The request is POSTed to `https://fal.run/<modelId>`
   * (passed through `config.url`), and the audio is then downloaded from the
   * returned `audio.url`.
   *
   * `response.headers` and `response.body` describe the first (JSON) call. The
   * download sends no provider headers, so the API key is not forwarded to
   * the file host.
   */
  async doGenerate(
    options: Parameters<SpeechModelV3['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<SpeechModelV3['doGenerate']>>> {
    const currentDate = this.config._internal?.currentDate?.() ?? new Date();
    const { requestBody, warnings } = await this.getArgs(options);

    const {
      value: json,
      responseHeaders,
      rawValue,
    } = await postJsonToApi({
      url: this.config.url({
        path: `https://fal.run/${this.modelId}`,
        modelId: this.modelId,
      }),
      headers: combineHeaders(this.config.headers(), options.headers),
      body: requestBody,
      failedResponseHandler: falFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        falSpeechResponseSchema,
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const audioUrl = json.audio.url;
    const { value: audio } = await getFromApi({
      url: audioUrl,
      failedResponseHandler: createStatusCodeErrorResponseHandler(),
      successfulResponseHandler: createBinaryResponseHandler(),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    return {
      audio,
      warnings,
      request: {
        body: JSON.stringify(requestBody),
      },
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders,
        body: rawValue,
      },
    };
  }
}
151
+
152
/**
 * JSON body returned by a fal speech request. The model reads only
 * `audio.url`; the audio bytes are fetched from that URL separately.
 */
const falSpeechResponseSchema = z.object({
  audio: z.object({
    url: z.string(),
  }),
  duration_ms: z.number().optional(),
  request_id: z.string().optional(),
});
@@ -0,0 +1,10 @@
1
/**
 * Known fal.ai text-to-speech model IDs. Any other string is accepted too,
 * so newly released models can be used without a package update.
 *
 * Catalogue: https://fal.ai/explore/search?categories=text-to-speech&q=newest
 */
export type FalSpeechModelId =
  // MiniMax
  | 'fal-ai/minimax/speech-02-hd'
  | 'fal-ai/minimax/speech-02-turbo'
  | 'fal-ai/minimax/voice-clone'
  | 'fal-ai/minimax/voice-design'
  // Dia
  | 'fal-ai/dia-tts'
  | 'fal-ai/dia-tts/voice-clone'
  // Resemble AI
  | 'resemble-ai/chatterboxhd/text-to-speech'
  // `string & {}` keeps editor autocompletion for the literals above while
  // still allowing arbitrary IDs.
  | (string & {});
@@ -0,0 +1,181 @@
1
+ import { createTestServer } from '@ai-sdk/test-server/with-vitest';
2
+ import { createFal } from './fal-provider';
3
+ import { FalTranscriptionModel } from './fal-transcription-model';
4
+ import { readFile } from 'node:fs/promises';
5
+ import path from 'node:path';
6
+ import { describe, it, expect } from 'vitest';
7
+
8
// Fixture audio, read once and shared by every test in this file.
const audioData = await readFile(path.join(__dirname, 'transcript-test.mp3'));

// Mocked fal queue endpoints: the submit URL and the per-request URL its
// response points to.
const server = createTestServer({
  'https://queue.fal.run/fal-ai/wizper': {},
  'https://queue.fal.run/fal-ai/wizper/requests/test-id': {},
});

// Shared provider and model for tests that need no custom configuration.
const provider = createFal({ apiKey: 'test-api-key' });
const model = provider.transcription('wizper');
16
+
17
describe('doGenerate', () => {
  /**
   * Mocks the queue submission (already COMPLETED) and the request URL that
   * serves the transcription result. `headers` is applied to both responses.
   */
  function prepareJsonResponse({
    headers,
  }: {
    headers?: Record<string, string>;
  } = {}) {
    server.urls['https://queue.fal.run/fal-ai/wizper'].response = {
      type: 'json-value',
      headers,
      body: {
        status: 'COMPLETED',
        request_id: 'test-id',
        response_url:
          'https://queue.fal.run/fal-ai/wizper/requests/test-id/result',
        status_url: 'https://queue.fal.run/fal-ai/wizper/requests/test-id',
        cancel_url:
          'https://queue.fal.run/fal-ai/wizper/requests/test-id/cancel',
        logs: null,
        metrics: {},
        queue_position: 0,
      },
    };
    server.urls[
      'https://queue.fal.run/fal-ai/wizper/requests/test-id'
    ].response = {
      type: 'json-value',
      headers,
      body: {
        text: 'Hello world!',
        chunks: [
          {
            text: 'Hello',
            timestamp: [0, 1],
            speaker: 'speaker_1',
          },
          {
            text: ' ',
            timestamp: [1, 2],
            speaker: 'speaker_1',
          },
          {
            text: 'world!',
            timestamp: [2, 3],
            speaker: 'speaker_1',
          },
        ],
        diarization_segments: [
          {
            speaker: 'speaker_1',
            timestamp: [0, 3],
          },
        ],
      },
    };
  }

  it('should pass the model', async () => {
    prepareJsonResponse();

    await model.doGenerate({
      audio: audioData,
      mediaType: 'audio/wav',
    });

    expect(await server.calls[0].requestBodyJson).toMatchObject({
      audio_url: expect.stringMatching(/^data:audio\//),
      task: 'transcribe',
      diarize: true,
      chunk_level: 'word',
    });
  });

  it('should pass headers', async () => {
    prepareJsonResponse();

    const customProvider = createFal({
      apiKey: 'test-api-key',
      headers: {
        'Custom-Provider-Header': 'provider-header-value',
      },
    });

    await customProvider.transcription('wizper').doGenerate({
      audio: audioData,
      mediaType: 'audio/wav',
      headers: {
        'Custom-Request-Header': 'request-header-value',
      },
    });

    expect(server.calls[0].requestHeaders).toMatchObject({
      authorization: 'Key test-api-key',
      'content-type': 'application/json',
      'custom-provider-header': 'provider-header-value',
      'custom-request-header': 'request-header-value',
    });
  });

  it('should extract the transcription text', async () => {
    prepareJsonResponse();

    const result = await model.doGenerate({
      audio: audioData,
      mediaType: 'audio/wav',
    });

    expect(result.text).toBe('Hello world!');
  });

  it('should include response data with timestamp, modelId and headers', async () => {
    prepareJsonResponse({
      headers: {
        'x-request-id': 'test-request-id',
        'x-ratelimit-remaining': '123',
      },
    });

    const testDate = new Date(0);
    const customModel = new FalTranscriptionModel('wizper', {
      provider: 'test-provider',
      url: ({ path }) => path,
      headers: () => ({}),
      _internal: {
        currentDate: () => testDate,
      },
    });

    const result = await customModel.doGenerate({
      audio: audioData,
      mediaType: 'audio/wav',
    });

    expect(result.response).toMatchObject({
      timestamp: testDate,
      modelId: 'wizper',
      headers: {
        'content-type': 'application/json',
        'x-request-id': 'test-request-id',
        'x-ratelimit-remaining': '123',
      },
    });
  });

  it('should use real date when no custom date provider is specified', async () => {
    prepareJsonResponse();

    // No `_internal.currentDate` here: the model must fall back to the clock.
    const customModel = new FalTranscriptionModel('wizper', {
      provider: 'test-provider',
      url: ({ path }) => path,
      headers: () => ({}),
    });

    const before = Date.now();
    const result = await customModel.doGenerate({
      audio: audioData,
      mediaType: 'audio/wav',
    });
    const after = Date.now();

    const timestamp = result.response.timestamp.getTime();
    expect(timestamp).toBeGreaterThanOrEqual(before);
    expect(timestamp).toBeLessThanOrEqual(after);
    expect(result.response.modelId).toBe('wizper');
  });
});