@umituz/react-native-ai-pruna-provider 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +447 -0
- package/package.json +58 -0
- package/src/domain/entities/error.types.ts +52 -0
- package/src/domain/entities/pruna.types.ts +112 -0
- package/src/domain/types/index.ts +21 -0
- package/src/exports/domain.ts +46 -0
- package/src/exports/infrastructure.ts +47 -0
- package/src/exports/presentation.ts +14 -0
- package/src/index.ts +27 -0
- package/src/infrastructure/services/index.ts +12 -0
- package/src/infrastructure/services/pruna-api-client.ts +243 -0
- package/src/infrastructure/services/pruna-input-builder.ts +127 -0
- package/src/infrastructure/services/pruna-provider-subscription.ts +262 -0
- package/src/infrastructure/services/pruna-provider.constants.ts +83 -0
- package/src/infrastructure/services/pruna-provider.ts +211 -0
- package/src/infrastructure/services/pruna-queue-operations.ts +131 -0
- package/src/infrastructure/services/request-store.ts +148 -0
- package/src/infrastructure/utils/helpers/index.ts +25 -0
- package/src/infrastructure/utils/index.ts +30 -0
- package/src/infrastructure/utils/log-collector.ts +96 -0
- package/src/infrastructure/utils/pruna-error-handler.util.ts +119 -0
- package/src/infrastructure/utils/pruna-generation-state-manager.util.ts +98 -0
- package/src/infrastructure/utils/type-guards/index.ts +31 -0
- package/src/init/createAiProviderInitModule.ts +87 -0
- package/src/init/initializePrunaProvider.ts +35 -0
- package/src/presentation/hooks/index.ts +6 -0
- package/src/presentation/hooks/use-pruna-generation.ts +169 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
/**
 * Domain Layer Exports
 *
 * Re-exports the Pruna entity types, the error types (with `PrunaErrorType`
 * exported as a value), and the shared AI-provider contract types.
 */

// Pruna entity types
export type {
  PrunaConfig,
  PrunaModel,
  PrunaModelId,
  PrunaModelType,
  PrunaModelPricing,
  PrunaAspectRatio,
  PrunaResolution,
  PrunaJobInput,
  PrunaJobResult,
  PrunaLogEntry,
  PrunaJobStatusType,
  PrunaQueueStatus,
  PrunaSubscribeOptions,
  PrunaPredictionInput,
  PrunaPredictionResponse,
  PrunaFileUploadResponse,
} from "../domain/entities/pruna.types";

// Error value + error types
export {
  PrunaErrorType,
  type PrunaErrorCategory,
  type PrunaErrorInfo,
  type PrunaErrorMessages,
} from "../domain/entities/error.types";

// Shared AI-provider contract types
export type {
  ImageFeatureType,
  VideoFeatureType,
  AIProviderConfig,
  AIJobStatusType,
  AILogEntry,
  JobSubmission,
  JobStatus,
  ProviderProgressInfo,
  SubscribeOptions,
  RunOptions,
  ProviderCapabilities,
  ImageFeatureInputData,
  VideoFeatureInputData,
  IAIProvider,
} from "../domain/types";
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
/**
 * Infrastructure Layer Exports
 *
 * Public surface of the provider service, the request-store cleanup
 * functions, the utilities, and the provider constants.
 */

// From ../infrastructure/services
export {
  PrunaProvider,
  prunaProvider,
  cleanupRequestStore,
  stopAutomaticCleanup,
  type PrunaProviderType,
  type ActiveRequest,
} from "../infrastructure/services";

// From ../infrastructure/utils (error helpers, type guards/validators, generic helpers)
export {
  mapPrunaError,
  isPrunaErrorRetryable,
  getErrorMessage,
  getErrorMessageOr,
  formatErrorMessage,
  isPrunaModelId,
  isPrunaErrorType,
  isValidApiKey,
  isValidModelId,
  isValidPrompt,
  isValidTimeout,
  isDefined,
  removeNullish,
  generateUniqueId,
  sleep,
} from "../infrastructure/utils";

// Endpoint URLs, defaults and model constants
export {
  PRUNA_BASE_URL,
  PRUNA_PREDICTIONS_URL,
  PRUNA_FILES_URL,
  DEFAULT_PRUNA_CONFIG,
  UPLOAD_CONFIG,
  PRUNA_CAPABILITIES,
  VALID_PRUNA_MODELS,
  P_VIDEO_DEFAULTS,
  DEFAULT_ASPECT_RATIO,
} from "../infrastructure/services/pruna-provider.constants";
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
/**
 * Presentation Layer Exports
 *
 * The React hook plus the generation state manager it is paired with.
 */

export {
  usePrunaGeneration,
  type UsePrunaGenerationOptions,
  type UsePrunaGenerationResult,
} from "../presentation/hooks";

export {
  PrunaGenerationStateManager,
  type GenerationState,
  type GenerationStateOptions,
} from "../infrastructure/utils/pruna-generation-state-manager.util";
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/**
 * @umituz/react-native-ai-pruna-provider
 * Pruna AI provider for React Native - implements IAIProvider interface
 *
 * Supported models:
 *   p-image:      text-to-image
 *   p-image-edit: image-to-image
 *   p-video:      image-to-video
 */

// Layered public API: domain, infrastructure, presentation
export * from "./exports/domain";
export * from "./exports/infrastructure";
export * from "./exports/presentation";

// Init module factory
export { createAiProviderInitModule } from "./init/createAiProviderInitModule";
export type { AiProviderInitModuleConfig } from "./init/createAiProviderInitModule";

// Direct initialization
export { initializePrunaProvider } from "./init/initializePrunaProvider";
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
/**
 * Services Index
 *
 * Re-exports the provider class/instance (plus a type-only alias of the class)
 * and the request-store cleanup functions.
 */

export {
  PrunaProvider,
  prunaProvider,
  type PrunaProvider as PrunaProviderType,
} from "./pruna-provider";

export {
  cleanupRequestStore,
  stopAutomaticCleanup,
  type ActiveRequest,
} from "./request-store";
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Pruna API Client
|
|
3
|
+
* Low-level HTTP interactions with the Pruna AI API
|
|
4
|
+
*
|
|
5
|
+
* Endpoints:
|
|
6
|
+
* POST /v1/predictions — submit generation (with Try-Sync header for immediate results)
|
|
7
|
+
* POST /v1/files — upload images for p-video (requires file URL, not base64)
|
|
8
|
+
* GET {poll_url} — poll async results
|
|
9
|
+
*
|
|
10
|
+
* Authentication: `apikey` header
|
|
11
|
+
* Model selection: `Model` header
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import type { PrunaModelId, PrunaPredictionResponse, PrunaFileUploadResponse } from "../../domain/entities/pruna.types";
|
|
15
|
+
import { PRUNA_BASE_URL, PRUNA_PREDICTIONS_URL, PRUNA_FILES_URL } from "./pruna-provider.constants";
|
|
16
|
+
import { generationLogCollector } from "../utils/log-collector";
|
|
17
|
+
|
|
18
|
+
/** Tag attached to every log entry this module writes through `generationLogCollector`. */
const TAG = 'pruna-api';
|
|
19
|
+
|
|
20
|
+
/**
 * Decodes base64 image data, optionally carrying a `data:<mime>;base64,`
 * prefix, into raw bytes.
 *
 * @throws Error("Invalid image format. ...") when the payload is not valid
 *   base64 or decodes to zero bytes.
 */
function decodeBase64Image(base64Data: string): Uint8Array {
  const invalidFormatMessage = "Invalid image format. Please provide base64 or a valid URL.";

  // Strip the data-URI prefix if present (everything up to and including the marker).
  const marker = 'base64,';
  const markerIndex = base64Data.indexOf(marker);
  const raw = markerIndex >= 0 ? base64Data.slice(markerIndex + marker.length) : base64Data;

  let binaryStr: string;
  try {
    binaryStr = atob(raw);
  } catch {
    throw new Error(invalidFormatMessage);
  }
  // An empty payload would otherwise be uploaded as a 0-byte "image".
  if (binaryStr.length === 0) {
    throw new Error(invalidFormatMessage);
  }

  const bytes = new Uint8Array(binaryStr.length);
  for (let i = 0; i < binaryStr.length; i++) {
    bytes[i] = binaryStr.charCodeAt(i);
  }
  return bytes;
}

/**
 * Sniffs the image type from its leading magic bytes.
 * JPEG (`FF D8`) and WebP (`RIFF` <size> `WEBP`) are recognised; anything
 * else is labelled PNG.
 */
function detectImageMime(bytes: Uint8Array): { mime: string; ext: string } {
  if (bytes[0] === 0xff && bytes[1] === 0xd8) {
    return { mime: 'image/jpeg', ext: 'jpeg' };
  }
  // "RI" alone also matches other RIFF containers (WAV, AVI); require the WEBP fourcc at offset 8.
  const isRiff = bytes[0] === 0x52 && bytes[1] === 0x49 && bytes[2] === 0x46 && bytes[3] === 0x46;
  const isWebp = bytes[8] === 0x57 && bytes[9] === 0x45 && bytes[10] === 0x42 && bytes[11] === 0x50;
  if (isRiff && isWebp) {
    return { mime: 'image/webp', ext: 'webp' };
  }
  return { mime: 'image/png', ext: 'png' };
}

/**
 * Upload an image to Pruna's file storage and return a URL usable in
 * predictions. p-video requires a file URL rather than raw base64.
 *
 * Inputs that already start with `http` are returned unchanged without a
 * network call. Anything else is decoded as base64 and sent as multipart form
 * field `content`.
 *
 * @param base64Data - Base64 image (with or without a data-URI prefix) or an HTTP(S) URL.
 * @param apiKey - Pruna API key, sent in the `apikey` header.
 * @param sessionId - Log-collector session the upload is recorded under.
 * @returns `urls.get` from the upload response, or `${PRUNA_FILES_URL}/${id}` as a fallback.
 * @throws Error when the input is not usable base64, when the upload request
 *   fails, or when the response carries neither `urls.get` nor `id`.
 */
export async function uploadImageToFiles(
  base64Data: string,
  apiKey: string,
  sessionId: string,
): Promise<string> {
  // Already a URL — return as-is
  if (base64Data.startsWith('http')) {
    generationLogCollector.log(sessionId, TAG, 'Image already a URL, skipping upload');
    return base64Data;
  }

  generationLogCollector.log(sessionId, TAG, 'Uploading image to Pruna file storage...');

  const bytes = decodeBase64Image(base64Data);
  const { mime, ext } = detectImageMime(bytes);

  const formData = new FormData();
  // NOTE(review): React Native's Blob implementation has historically rejected
  // ArrayBuffer/TypedArray parts — confirm this path works on the targeted RN version.
  formData.append('content', new Blob([bytes], { type: mime }), `upload.${ext}`);

  const startTime = Date.now();

  const response = await fetch(PRUNA_FILES_URL, {
    method: 'POST',
    headers: { 'apikey': apiKey },
    body: formData,
  });

  if (!response.ok) {
    // The body may be non-JSON (→ statusText), JSON `null`, or lack a string
    // `message`; fall back to the HTTP status instead of crashing here.
    const errorBody: unknown = await response.json().catch(() => ({ message: response.statusText }));
    const bodyMessage =
      errorBody !== null && typeof errorBody === 'object'
        ? (errorBody as { message?: unknown }).message
        : undefined;
    const errorMessage = (typeof bodyMessage === 'string' && bodyMessage) || `File upload error: ${response.status}`;
    generationLogCollector.error(sessionId, TAG, `File upload failed: ${errorMessage}`);
    throw new Error(errorMessage);
  }

  const data = (await response.json()) as PrunaFileUploadResponse | null;
  // Never fabricate ".../undefined" when the response carries no id.
  const fileUrl = data?.urls?.get || (data?.id ? `${PRUNA_FILES_URL}/${data.id}` : '');
  if (!fileUrl) {
    const message = 'File upload response did not include a file URL or id.';
    generationLogCollector.error(sessionId, TAG, `File upload failed: ${message}`);
    throw new Error(message);
  }

  const elapsed = Date.now() - startTime;
  generationLogCollector.log(sessionId, TAG, `File upload completed in ${elapsed}ms → ${fileUrl}`);

  return fileUrl;
}
|
|
86
|
+
|
|
87
|
+
/**
 * Remove a `data:<mime>;base64,`-style prefix so only the base64 payload remains.
 * Inputs starting with `http` (URLs) and inputs without a `base64,` marker are
 * returned untouched.
 */
export function stripBase64Prefix(image: string): string {
  const marker = 'base64,';
  const isUrl = image.startsWith('http');
  if (isUrl || !image.includes(marker)) {
    return image;
  }
  // The segment following the first marker (up to any later marker).
  return image.split(marker)[1];
}
|
|
95
|
+
|
|
96
|
+
/**
 * Submit a prediction to Pruna AI.
 *
 * Sends `Try-Sync: true` so the API may answer with the finished result right
 * away; otherwise the response carries URLs to poll.
 *
 * @param model - Pruna model id, sent in the `Model` header.
 * @param input - Model-specific payload, sent as `{ input }` in the JSON body.
 * @param apiKey - Pruna API key, sent in the `apikey` header.
 * @param sessionId - Log-collector session the request is recorded under.
 * @returns The parsed response body (may contain a result URI or polling URLs).
 * @throws Error carrying a `statusCode` property when the API answers with a non-2xx status.
 */
export async function submitPrediction(
  model: PrunaModelId,
  input: Record<string, unknown>,
  apiKey: string,
  sessionId: string,
): Promise<PrunaPredictionResponse> {
  generationLogCollector.log(sessionId, TAG, `Submitting prediction for model: ${model}`, {
    inputKeys: Object.keys(input),
  });

  const startTime = Date.now();

  const response = await fetch(PRUNA_PREDICTIONS_URL, {
    method: 'POST',
    headers: {
      'apikey': apiKey,
      'Model': model,
      'Try-Sync': 'true',
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({ input }),
  });

  if (!response.ok) {
    // The body may be non-JSON (→ statusText), JSON `null`, or lack a string
    // `message`; fall back to the HTTP status rather than crashing here and
    // losing `statusCode`.
    const errorBody: unknown = await response.json().catch(() => ({ message: response.statusText }));
    const bodyMessage =
      errorBody !== null && typeof errorBody === 'object'
        ? (errorBody as { message?: unknown }).message
        : undefined;
    const errorMessage = (typeof bodyMessage === 'string' && bodyMessage) || `API error: ${response.status}`;

    generationLogCollector.error(sessionId, TAG, `Prediction failed (${response.status}): ${errorMessage}`);

    const error = new Error(errorMessage);
    (error as Error & { statusCode?: number }).statusCode = response.status;
    throw error;
  }

  const elapsed = Date.now() - startTime;
  const result: PrunaPredictionResponse = await response.json();

  generationLogCollector.log(sessionId, TAG, `Prediction response received in ${elapsed}ms`, {
    hasUri: !!extractUri(result),
    hasGetUrl: !!result.get_url,
    hasStatusUrl: !!result.status_url,
    status: result.status,
  });

  return result;
}
|
|
147
|
+
|
|
148
|
+
/** Throws the user-cancellation error when `signal` has been aborted. */
function throwIfPollCancelled(signal?: AbortSignal): void {
  if (signal?.aborted) {
    throw new Error("Request cancelled by user");
  }
}

/**
 * Resolves after `ms`, or as soon as `signal` aborts, whichever comes first,
 * so a cancellation does not have to wait out the full polling interval.
 * The caller is expected to re-check `signal` afterwards.
 */
function waitForPollInterval(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise<void>((resolve) => {
    const finish = (): void => {
      clearTimeout(timer);
      signal?.removeEventListener('abort', finish);
      resolve();
    };
    const timer = setTimeout(finish, ms);
    signal?.addEventListener('abort', finish);
  });
}

/**
 * Poll a prediction status URL until it yields a result URI, reports failure,
 * is cancelled, or runs out of attempts.
 *
 * Each attempt first waits `intervalMs`, then GETs the status URL. Relative
 * URLs are prefixed with `PRUNA_BASE_URL`. Non-2xx responses and network/JSON
 * errors are logged and treated as transient. A `failed` status stops polling
 * immediately with the API's error message.
 *
 * @param pollUrl - Absolute status URL, or a path relative to `PRUNA_BASE_URL`.
 * @param apiKey - Pruna API key, sent in the `apikey` header.
 * @param sessionId - Log-collector session the attempts are recorded under.
 * @param maxAttempts - Maximum number of status requests before timing out.
 * @param intervalMs - Delay before each status request.
 * @param signal - Optional abort signal; aborting interrupts the wait and any in-flight request.
 * @returns The result URI, resolved to an absolute URL.
 * @throws Error("Request cancelled by user") when `signal` aborts; the API's error
 *   message (or a default) when the prediction fails; a timeout error after `maxAttempts`.
 */
export async function pollForResult(
  pollUrl: string,
  apiKey: string,
  sessionId: string,
  maxAttempts: number,
  intervalMs: number,
  signal?: AbortSignal,
): Promise<string> {
  const fullPollUrl = pollUrl.startsWith('http') ? pollUrl : `${PRUNA_BASE_URL}${pollUrl}`;

  generationLogCollector.log(sessionId, TAG, `Starting polling at ${fullPollUrl} (max ${maxAttempts} attempts, ${intervalMs}ms interval)`);

  for (let i = 0; i < maxAttempts; i++) {
    throwIfPollCancelled(signal);
    await waitForPollInterval(intervalMs, signal);
    throwIfPollCancelled(signal);

    let statusData: PrunaPredictionResponse;
    try {
      const statusRes = await fetch(fullPollUrl, {
        headers: { 'apikey': apiKey },
        signal,
      });

      if (!statusRes.ok) {
        generationLogCollector.warn(sessionId, TAG, `Poll attempt ${i + 1}/${maxAttempts}: HTTP ${statusRes.status}, skipping...`);
        continue;
      }

      statusData = await statusRes.json();
    } catch (error) {
      // An abort surfaces here as a fetch rejection; report it as a cancellation.
      throwIfPollCancelled(signal);
      // Network / JSON errors are non-fatal — try again on the next attempt.
      const message = error instanceof Error ? error.message : String(error);
      generationLogCollector.warn(sessionId, TAG, `Poll attempt ${i + 1} error: ${message}`);
      continue;
    }

    // Status handling sits outside the try/catch. Otherwise a `failed`
    // prediction with a custom error message would be treated as a transient
    // poll error, retried until timeout, and its message lost.
    if (statusData.status === 'succeeded' || statusData.status === 'completed') {
      const uri = extractUri(statusData);
      if (uri) {
        generationLogCollector.log(sessionId, TAG, `Polling completed at attempt ${i + 1}/${maxAttempts}`);
        return resolveUri(uri);
      }
    } else if (statusData.status === 'failed') {
      const errorMessage = statusData.error || "Generation failed during processing.";
      generationLogCollector.error(sessionId, TAG, `Polling: generation failed — ${errorMessage}`);
      throw new Error(errorMessage);
    }

    // Still processing — log progress periodically
    if ((i + 1) % 10 === 0) {
      generationLogCollector.log(sessionId, TAG, `Polling: still processing (attempt ${i + 1}/${maxAttempts})...`);
    }
  }

  throw new Error("Generation timed out. Maximum polling attempts reached.");
}
|
|
218
|
+
|
|
219
|
+
/**
 * Find the result URI in a Pruna prediction response.
 *
 * Candidate fields are checked in a fixed priority order, and the first truthy
 * value wins:
 * 1. `generation_url`
 * 2. `output.url` (object output)
 * 3. `output` (string output)
 * 4. `data`
 * 5. `video_url`
 * 6. `output[0]` (array output)
 *
 * @returns The first truthy candidate, or `null` when none is present.
 */
export function extractUri(data: PrunaPredictionResponse): string | null {
  const { output } = data;

  if (data.generation_url) {
    return data.generation_url;
  }
  if (output && typeof output === 'object' && !Array.isArray(output)) {
    const objectUrl = (output as { url: string }).url;
    if (objectUrl) {
      return objectUrl;
    }
  }
  if (typeof output === 'string' && output) {
    return output;
  }
  if (data.data) {
    return data.data;
  }
  if (data.video_url) {
    return data.video_url;
  }
  if (Array.isArray(output) && output[0]) {
    return output[0];
  }
  return null;
}
|
|
234
|
+
|
|
235
|
+
/**
 * Turn a base-relative URI (one starting with `/`) into an absolute URL on
 * `PRUNA_BASE_URL`; any other value is returned unchanged.
 */
export function resolveUri(uri: string): string {
  const isRelative = uri.startsWith('/');
  return isRelative ? PRUNA_BASE_URL + uri : uri;
}
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Pruna Input Builder
|
|
3
|
+
* Builds model-specific payloads from generic input.
|
|
4
|
+
*
|
|
5
|
+
* Each Pruna model has strict schema requirements:
|
|
6
|
+
* p-image: { prompt, aspect_ratio? }
|
|
7
|
+
* p-image-edit: { images: string[], prompt, aspect_ratio? }
|
|
8
|
+
* p-video: { image: string (URL), prompt, duration, resolution, fps, draft, aspect_ratio, prompt_upsampling }
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import type { PrunaModelId, PrunaAspectRatio, PrunaResolution } from "../../domain/entities/pruna.types";
|
|
12
|
+
import { P_VIDEO_DEFAULTS, DEFAULT_ASPECT_RATIO } from "./pruna-provider.constants";
|
|
13
|
+
import { uploadImageToFiles, stripBase64Prefix } from "./pruna-api-client";
|
|
14
|
+
import { generationLogCollector } from "../utils/log-collector";
|
|
15
|
+
|
|
16
|
+
/** Tag attached to every log entry this module writes through `generationLogCollector`. */
const TAG = 'pruna-input-builder';
|
|
17
|
+
|
|
18
|
+
/**
 * Build the model-specific request payload from a generic input record.
 *
 * Every model requires a truthy `prompt`. `aspect_ratio` falls back to
 * `DEFAULT_ASPECT_RATIO` when missing or falsy. Dispatch is by model id.
 * The p-video builder may upload the source image (hence `apiKey`).
 *
 * @throws Error when the prompt is missing or the model id is not recognised.
 */
export async function buildModelInput(
  model: PrunaModelId,
  input: Record<string, unknown>,
  apiKey: string,
  sessionId: string,
): Promise<Record<string, unknown>> {
  const promptText = input.prompt as string | undefined;
  if (!promptText) {
    throw new Error("Prompt is required for all Pruna models.");
  }

  const ratio = (input.aspect_ratio as PrunaAspectRatio) || DEFAULT_ASPECT_RATIO;

  switch (model) {
    case 'p-image':
      return buildImageInput(promptText, ratio, input);
    case 'p-image-edit':
      return buildImageEditInput(promptText, ratio, input, sessionId);
    case 'p-video':
      return buildVideoInput(promptText, ratio, input, apiKey, sessionId);
    default:
      throw new Error(`Unknown Pruna model: ${model}`);
  }
}
|
|
49
|
+
|
|
50
|
+
/**
 * p-image payload: `{ prompt, aspect_ratio }`, plus `seed` whenever the caller
 * supplied one (any value other than `undefined`).
 */
function buildImageInput(
  prompt: string,
  aspectRatio: PrunaAspectRatio,
  input: Record<string, unknown>,
): Record<string, unknown> {
  const base = { prompt, aspect_ratio: aspectRatio };
  return input.seed === undefined ? base : { ...base, seed: input.seed };
}
|
|
63
|
+
|
|
64
|
+
/**
 * Normalizes one candidate image field (a string or an array of strings) into
 * a list of non-empty, prefix-stripped image strings. Non-string entries are
 * ignored. Returns `undefined` when the field yields no usable image, so the
 * caller can fall through to the next field.
 */
function collectEditImages(value: unknown): string[] | undefined {
  const candidates: unknown[] = Array.isArray(value) ? value : [value];
  const images = candidates
    .filter((item): item is string => typeof item === 'string')
    .map(stripBase64Prefix)
    .filter((image) => image.length > 0);
  return images.length > 0 ? images : undefined;
}

/**
 * p-image-edit payload: `{ images, prompt, aspect_ratio }`, plus `seed` when
 * the caller supplied one (any value other than `undefined`).
 *
 * Source images come from the first field, in the order `images`, `image`,
 * `image_url`, `image_urls`, that holds at least one non-empty string.
 * Data-URI prefixes are stripped; URLs pass through unchanged.
 *
 * @throws Error when none of those fields contains a usable image string.
 */
function buildImageEditInput(
  prompt: string,
  aspectRatio: PrunaAspectRatio,
  input: Record<string, unknown>,
  sessionId: string,
): Record<string, unknown> {
  // p-image-edit expects images array. An empty array or empty string falls
  // through to the next field rather than being sent to the API as-is.
  const images =
    collectEditImages(input.images) ??
    collectEditImages(input.image) ??
    collectEditImages(input.image_url) ??
    collectEditImages(input.image_urls);

  if (!images) {
    throw new Error("Image is required for p-image-edit. Provide 'image', 'images', 'image_url', or 'image_urls'.");
  }

  generationLogCollector.log(sessionId, TAG, `p-image-edit: ${images.length} image(s) prepared`);

  const payload: Record<string, unknown> = { images, prompt, aspect_ratio: aspectRatio };

  if (input.seed !== undefined) {
    payload.seed = input.seed;
  }

  return payload;
}
|
|
95
|
+
|
|
96
|
+
/**
 * p-video payload: `{ image, prompt, duration, resolution, fps, draft,
 * aspect_ratio, prompt_upsampling }`.
 *
 * The source image is taken from `image`, else `image_url`: the first one that
 * is a non-empty string. It is passed through `uploadImageToFiles`, which
 * uploads base64 data and returns URLs unchanged. `duration`, `resolution` and
 * `draft` fall back to `P_VIDEO_DEFAULTS` when null/undefined. `fps` and
 * `prompt_upsampling` always come from `P_VIDEO_DEFAULTS`.
 *
 * @throws Error when neither `image` nor `image_url` is a non-empty string.
 */
async function buildVideoInput(
  prompt: string,
  aspectRatio: PrunaAspectRatio,
  input: Record<string, unknown>,
  apiKey: string,
  sessionId: string,
): Promise<Record<string, unknown>> {
  // p-video requires an image file URL. Only non-empty strings qualify; any
  // other value (e.g. an object) would otherwise crash inside the upload with
  // an opaque TypeError and shadow a valid `image_url`.
  const rawImage = [input.image, input.image_url].find(
    (value): value is string => typeof value === 'string' && value.length > 0,
  );
  if (rawImage === undefined) {
    throw new Error("Image is required for p-video. Provide 'image' or 'image_url'.");
  }

  // Upload base64 to file storage if needed (p-video requires HTTPS URL)
  generationLogCollector.log(sessionId, TAG, 'p-video: preparing image for video generation...');
  const fileUrl = await uploadImageToFiles(rawImage, apiKey, sessionId);

  const duration = (input.duration as number) ?? P_VIDEO_DEFAULTS.duration;
  const resolution = (input.resolution as PrunaResolution) ?? P_VIDEO_DEFAULTS.resolution;
  const draft = (input.draft as boolean) ?? P_VIDEO_DEFAULTS.draft;

  return {
    image: fileUrl,
    prompt,
    duration,
    resolution,
    fps: P_VIDEO_DEFAULTS.fps,
    draft,
    aspect_ratio: aspectRatio,
    prompt_upsampling: P_VIDEO_DEFAULTS.promptUpsampling,
  };
}
|