@aj-archipelago/cortex 1.3.12 → 1.3.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -1
- package/config.js +5 -4
- package/helper-apps/cortex-file-handler/blobHandler.js +109 -84
- package/helper-apps/cortex-file-handler/constants.js +1 -0
- package/helper-apps/cortex-file-handler/fileChunker.js +1 -2
- package/helper-apps/cortex-file-handler/helper.js +45 -1
- package/helper-apps/cortex-file-handler/index.js +43 -55
- package/helper-apps/cortex-file-handler/package.json +3 -2
- package/helper-apps/cortex-file-handler/scripts/test-azure.sh +1 -1
- package/helper-apps/cortex-file-handler/start.js +14 -1
- package/helper-apps/cortex-file-handler/tests/blobHandler.test.js +292 -0
- package/helper-apps/cortex-file-handler/tests/fileChunker.test.js +3 -14
- package/helper-apps/cortex-file-handler/tests/start.test.js +2 -0
- package/package.json +1 -1
- package/server/plugins/azureVideoTranslatePlugin.js +279 -144
- package/server/plugins/replicateApiPlugin.js +54 -2
|
@@ -1,26 +1,223 @@
|
|
|
1
1
|
// AzureVideoTranslatePlugin.js
|
|
2
2
|
import ModelPlugin from "./modelPlugin.js";
|
|
3
3
|
import logger from "../../lib/logger.js";
|
|
4
|
-
import axios from "axios";
|
|
5
4
|
import { publishRequestProgress } from "../../lib/redisSubscription.js";
|
|
6
|
-
import
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
} catch (e) {
|
|
13
|
-
return false;
|
|
14
|
-
}
|
|
15
|
-
}
|
|
5
|
+
import crypto from 'crypto';
|
|
6
|
+
import axios from 'axios';
|
|
7
|
+
import {config} from "../../config.js";
|
|
8
|
+
|
|
9
|
+
// turn off any caching because we're polling the operation status
|
|
10
|
+
axios.defaults.cache = false;
|
|
16
11
|
|
|
17
12
|
class AzureVideoTranslatePlugin extends ModelPlugin {
|
|
13
|
+
static lastProcessingRate = null; // bytes per second
|
|
14
|
+
|
|
18
15
|
constructor(pathway, model) {
|
|
19
16
|
super(pathway, model);
|
|
20
|
-
this.
|
|
21
|
-
this.
|
|
22
|
-
this.
|
|
23
|
-
this.
|
|
17
|
+
this.subscriptionKey = config.get("azureVideoTranslationApiKey");
|
|
18
|
+
this.apiVersion = "2024-05-20-preview";
|
|
19
|
+
this.baseUrl = "";
|
|
20
|
+
this.startTime = null;
|
|
21
|
+
this.videoContentLength = null;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
async verifyVideoAccess(videoUrl) {
|
|
25
|
+
try {
|
|
26
|
+
const response = await axios.head(videoUrl);
|
|
27
|
+
|
|
28
|
+
const contentType = response.headers['content-type'];
|
|
29
|
+
const contentLength = parseInt(response.headers['content-length'], 10);
|
|
30
|
+
|
|
31
|
+
if (contentType && !contentType.includes('video/mp4')) {
|
|
32
|
+
logger.warn(`Warning: Video might not be in MP4 format. Content-Type: ${contentType}`);
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
const TYPICAL_BITRATE = 2.5 * 1024 * 1024; // 2.5 Mbps
|
|
36
|
+
const durationSeconds = Math.round((contentLength * 8) / TYPICAL_BITRATE);
|
|
37
|
+
|
|
38
|
+
return {
|
|
39
|
+
isAccessible: true,
|
|
40
|
+
contentLength,
|
|
41
|
+
durationSeconds: durationSeconds || 60
|
|
42
|
+
};
|
|
43
|
+
} catch (error) {
|
|
44
|
+
throw new Error(`Failed to access video: ${error.message}`);
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
async createTranslation(params) {
|
|
49
|
+
const { videoUrl, sourceLanguage, targetLanguage, voiceKind, translationId } = params;
|
|
50
|
+
|
|
51
|
+
const translation = {
|
|
52
|
+
id: translationId,
|
|
53
|
+
displayName: `${translationId}.mp4`,
|
|
54
|
+
description: `Translate video from ${sourceLanguage} to ${targetLanguage}`,
|
|
55
|
+
input: {
|
|
56
|
+
sourceLocale: sourceLanguage,
|
|
57
|
+
targetLocale: targetLanguage,
|
|
58
|
+
voiceKind: voiceKind,
|
|
59
|
+
videoFileUrl: videoUrl
|
|
60
|
+
}
|
|
61
|
+
};
|
|
62
|
+
|
|
63
|
+
const url = `${this.baseUrl}/translations/${translationId}?api-version=${this.apiVersion}`;
|
|
64
|
+
logger.debug(`Creating translation: ${url}`);
|
|
65
|
+
|
|
66
|
+
try {
|
|
67
|
+
const response = await axios.put(url, translation, {
|
|
68
|
+
headers: {
|
|
69
|
+
'Content-Type': 'application/json',
|
|
70
|
+
'Ocp-Apim-Subscription-Key': this.subscriptionKey,
|
|
71
|
+
}
|
|
72
|
+
});
|
|
73
|
+
|
|
74
|
+
const operationUrl = response.headers['operation-location'];
|
|
75
|
+
return { translation: response.data, operationUrl };
|
|
76
|
+
} catch (error) {
|
|
77
|
+
const errorText = error.response?.data || error.message;
|
|
78
|
+
throw new Error(`Failed to create translation: ${error.message}\nDetails: ${errorText}`);
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
async getTranslationStatus(translationId) {
|
|
83
|
+
const url = `${this.baseUrl}/translations/${translationId}?api-version=${this.apiVersion}`;
|
|
84
|
+
try {
|
|
85
|
+
const response = await axios.get(url, {
|
|
86
|
+
headers: {
|
|
87
|
+
'Ocp-Apim-Subscription-Key': this.subscriptionKey,
|
|
88
|
+
}
|
|
89
|
+
});
|
|
90
|
+
return response.data;
|
|
91
|
+
} catch (error) {
|
|
92
|
+
throw new Error(`Failed to get translation status: ${error.message}`);
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
async getIterationStatus(translationId, iterationId) {
|
|
97
|
+
const url = `${this.baseUrl}/translations/${translationId}/iterations/${iterationId}?api-version=${this.apiVersion}`;
|
|
98
|
+
|
|
99
|
+
try {
|
|
100
|
+
const response = await axios.get(url, {
|
|
101
|
+
headers: {
|
|
102
|
+
'Ocp-Apim-Subscription-Key': this.subscriptionKey,
|
|
103
|
+
}
|
|
104
|
+
});
|
|
105
|
+
return response.data;
|
|
106
|
+
} catch (error) {
|
|
107
|
+
const errorText = error.response?.data || error.message;
|
|
108
|
+
throw new Error(`Failed to get iteration status: ${error.message}\nDetails: ${errorText}`);
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
async pollOperation(operationUrl) {
|
|
113
|
+
try {
|
|
114
|
+
const response = await axios.get(operationUrl, {
|
|
115
|
+
headers: {
|
|
116
|
+
'Ocp-Apim-Subscription-Key': this.subscriptionKey,
|
|
117
|
+
}
|
|
118
|
+
});
|
|
119
|
+
return response.data;
|
|
120
|
+
} catch (error) {
|
|
121
|
+
const errorText = error.response?.data || error.message;
|
|
122
|
+
throw new Error(`Failed to poll operation: ${error.message}\nDetails: ${errorText}`);
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
async monitorOperation(operationUrlOrConfig, entityType = 'operation') {
|
|
127
|
+
|
|
128
|
+
let estimatedTotalTime = 0;
|
|
129
|
+
if (AzureVideoTranslatePlugin.lastProcessingRate && this.videoContentLength) {
|
|
130
|
+
estimatedTotalTime = this.videoContentLength / AzureVideoTranslatePlugin.lastProcessingRate;
|
|
131
|
+
} else {
|
|
132
|
+
// First run: estimate based on 1x calculated video duration
|
|
133
|
+
estimatedTotalTime = (this.videoContentLength * 8) / (2.5 * 1024 * 1024);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
// eslint-disable-next-line no-constant-condition
|
|
137
|
+
while (true) {
|
|
138
|
+
let status;
|
|
139
|
+
if (typeof operationUrlOrConfig === 'string') {
|
|
140
|
+
const operation = await this.pollOperation(operationUrlOrConfig);
|
|
141
|
+
status = operation;
|
|
142
|
+
} else {
|
|
143
|
+
const { translationId, iterationId } = operationUrlOrConfig;
|
|
144
|
+
const iteration = await this.getIterationStatus(translationId, iterationId);
|
|
145
|
+
status = iteration;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
logger.debug(`${entityType} status: ${JSON.stringify(status, null, 2)}`);
|
|
149
|
+
|
|
150
|
+
let progress = 0;
|
|
151
|
+
let estimatedProgress = 0;
|
|
152
|
+
let progressMessage = '';
|
|
153
|
+
switch (entityType) {
|
|
154
|
+
case 'translation':
|
|
155
|
+
progressMessage = 'Getting ready to translate video...';
|
|
156
|
+
break;
|
|
157
|
+
case 'iteration':
|
|
158
|
+
if (status.status === 'NotStarted') {
|
|
159
|
+
progressMessage = 'Waiting for translation to start...';
|
|
160
|
+
} else if (status.status === 'Running') {
|
|
161
|
+
progressMessage = 'Translating video...';
|
|
162
|
+
if (this.startTime) {
|
|
163
|
+
// Calculate progress based on elapsed time
|
|
164
|
+
const elapsedSeconds = (Date.now() - this.startTime) / 1000;
|
|
165
|
+
estimatedProgress = Math.min(0.95, elapsedSeconds / estimatedTotalTime);
|
|
166
|
+
const remainingSeconds = Math.max(0, estimatedTotalTime - elapsedSeconds);
|
|
167
|
+
if (remainingSeconds > 0) {
|
|
168
|
+
if (remainingSeconds < 60) {
|
|
169
|
+
const roundedSeconds = Math.ceil(remainingSeconds);
|
|
170
|
+
progressMessage = `Translating video... ${roundedSeconds} second${roundedSeconds !== 1 ? 's' : ''} remaining`;
|
|
171
|
+
} else {
|
|
172
|
+
const remainingMinutes = Math.ceil(remainingSeconds / 60);
|
|
173
|
+
progressMessage = `Translating video... ${remainingMinutes} minute${remainingMinutes !== 1 ? 's' : ''} remaining`;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
progress = status.percentComplete ? status.percentComplete / 100 : estimatedProgress;
|
|
177
|
+
} else {
|
|
178
|
+
this.startTime = Date.now();
|
|
179
|
+
estimatedProgress = 0;
|
|
180
|
+
}
|
|
181
|
+
} else if (status.status === 'Succeeded') {
|
|
182
|
+
progressMessage = 'Video translation complete.';
|
|
183
|
+
} else if (status.status === 'Failed') {
|
|
184
|
+
progressMessage = 'Video translation failed.';
|
|
185
|
+
}
|
|
186
|
+
break;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// Publish progress updates
|
|
190
|
+
publishRequestProgress({
|
|
191
|
+
requestId: this.requestId,
|
|
192
|
+
progress,
|
|
193
|
+
info: progressMessage
|
|
194
|
+
});
|
|
195
|
+
|
|
196
|
+
if (status.status === 'Succeeded') {
|
|
197
|
+
return status;
|
|
198
|
+
} else if (status.status === 'Failed') {
|
|
199
|
+
throw new Error(`${entityType} failed: ${status.error?.message || 'Unknown error'}`);
|
|
200
|
+
}
|
|
201
|
+
await new Promise(resolve => setTimeout(resolve, 5000));
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
async getTranslationOutput(translationId, iterationId) {
|
|
206
|
+
const iteration = await this.getIterationStatus(translationId, iterationId);
|
|
207
|
+
const translation = await this.getTranslationStatus(translationId);
|
|
208
|
+
if (iteration.result) {
|
|
209
|
+
const targetLocale = translation.input.targetLocale;
|
|
210
|
+
return {
|
|
211
|
+
outputVideoSubtitleWebVttFileUrl: iteration.result.sourceLocaleSubtitleWebvttFileUrl,
|
|
212
|
+
targetLocales: {
|
|
213
|
+
[targetLocale]: {
|
|
214
|
+
outputVideoFileUrl: iteration.result.translatedVideoFileUrl,
|
|
215
|
+
outputVideoSubtitleWebVttFileUrl: iteration.result.targetLocaleSubtitleWebvttFileUrl
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
};
|
|
219
|
+
}
|
|
220
|
+
return null;
|
|
24
221
|
}
|
|
25
222
|
|
|
26
223
|
getRequestParameters(_, parameters, __) {
|
|
@@ -37,150 +234,88 @@ class AzureVideoTranslatePlugin extends ModelPlugin {
|
|
|
37
234
|
);
|
|
38
235
|
}
|
|
39
236
|
|
|
40
|
-
handleStream(stream, onData, onEnd, onError) {
|
|
41
|
-
const timeout = setTimeout(() => {
|
|
42
|
-
onError(new Error('Stream timeout'));
|
|
43
|
-
}, 300000); // timeout
|
|
44
|
-
|
|
45
|
-
stream.on('data', (chunk) => {
|
|
46
|
-
clearTimeout(timeout);
|
|
47
|
-
const lines = chunk.toString().split('\n\n');
|
|
48
|
-
lines.forEach(line => {
|
|
49
|
-
if (line.startsWith('data: ')) {
|
|
50
|
-
const eventData = line.slice(6);
|
|
51
|
-
try {
|
|
52
|
-
this.handleEvent({ data: eventData }, onData);
|
|
53
|
-
} catch (error) {
|
|
54
|
-
onError(error);
|
|
55
|
-
}
|
|
56
|
-
}
|
|
57
|
-
});
|
|
58
|
-
});
|
|
59
|
-
stream.on('end', () => {
|
|
60
|
-
clearTimeout(timeout);
|
|
61
|
-
this.cleanup();
|
|
62
|
-
onEnd();
|
|
63
|
-
});
|
|
64
|
-
stream.on('error', (error) => {
|
|
65
|
-
clearTimeout(timeout);
|
|
66
|
-
console.error('Stream error:', error);
|
|
67
|
-
this.cleanup();
|
|
68
|
-
onError(error);
|
|
69
|
-
});
|
|
70
|
-
}
|
|
71
|
-
|
|
72
|
-
handleEvent(event, onData) {
|
|
73
|
-
const data = event.data;
|
|
74
|
-
this.jsonBuffer += data;
|
|
75
|
-
this.jsonDepth += (data.match(/{/g) || []).length - (data.match(/}/g) || []).length;
|
|
76
|
-
|
|
77
|
-
if (this.jsonDepth === 0 && this.jsonBuffer.trim()) {
|
|
78
|
-
logger.debug(this.jsonBuffer);
|
|
79
|
-
if (this.jsonBuffer.includes('Failed to run with exception')) {
|
|
80
|
-
this.cleanup();
|
|
81
|
-
throw new Error(this.jsonBuffer);
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
onData(this.jsonBuffer);
|
|
85
|
-
this.jsonBuffer = '';
|
|
86
|
-
this.jsonDepth = 0;
|
|
87
|
-
}
|
|
88
|
-
}
|
|
89
|
-
|
|
90
237
|
async execute(text, parameters, prompt, cortexRequest) {
|
|
91
|
-
if (!this.
|
|
92
|
-
throw new Error("
|
|
238
|
+
if (!this.subscriptionKey) {
|
|
239
|
+
throw new Error("Azure Video Translation subscription key is not set");
|
|
93
240
|
}
|
|
241
|
+
|
|
94
242
|
this.requestId = cortexRequest.requestId;
|
|
243
|
+
this.baseUrl = cortexRequest.url;
|
|
244
|
+
|
|
95
245
|
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
246
|
+
|
|
96
247
|
try {
|
|
97
|
-
const
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
248
|
+
const translationId = `cortex-translation-${this.requestId}`;
|
|
249
|
+
const videoUrl = requestParameters.sourcevideooraudiofilepath;
|
|
250
|
+
const sourceLanguage = requestParameters.sourcelocale;
|
|
251
|
+
const targetLanguage = requestParameters.targetlocale;
|
|
252
|
+
const voiceKind = requestParameters.voicekind || 'PlatformVoice';
|
|
253
|
+
const embedSubtitles = requestParameters.withoutsubtitleintranslatedvideofile === "false" ? true : false;
|
|
254
|
+
const speakerCount = parseInt(requestParameters.speakercount) || 0;
|
|
255
|
+
|
|
256
|
+
// Verify video access and get duration
|
|
257
|
+
const videoInfo = await this.verifyVideoAccess(videoUrl);
|
|
258
|
+
this.videoContentLength = videoInfo.contentLength;
|
|
259
|
+
logger.debug(`Video info: ${JSON.stringify(videoInfo, null, 2)}`);
|
|
260
|
+
|
|
261
|
+
// Create translation
|
|
262
|
+
const { operationUrl } = await this.createTranslation({
|
|
263
|
+
videoUrl, sourceLanguage, targetLanguage, voiceKind, translationId
|
|
104
264
|
});
|
|
105
265
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if (parsedData.progress !== undefined) {
|
|
114
|
-
let timeInfo = '';
|
|
115
|
-
if (parsedData.estimated_time_remaining && parsedData.elapsed_time) {
|
|
116
|
-
const minutes = Math.ceil(parsedData.estimated_time_remaining / 60);
|
|
117
|
-
timeInfo = minutes <= 2
|
|
118
|
-
? `Should be done soon (${parsedData.elapsed_time} elapsed)`
|
|
119
|
-
: `Estimated ${minutes} minutes remaining`;
|
|
120
|
-
}
|
|
266
|
+
logger.debug(`Starting translation monitoring with operation URL: ${operationUrl}`);
|
|
267
|
+
// Monitor translation creation
|
|
268
|
+
const operationStatus = await this.monitorOperation(operationUrl, 'translation');
|
|
269
|
+
logger.debug(`Translation operation completed with status: ${JSON.stringify(operationStatus, null, 2)}`);
|
|
270
|
+
|
|
271
|
+
const updatedTranslation = await this.getTranslationStatus(translationId);
|
|
272
|
+
logger.debug(`Translation status after operation: ${JSON.stringify(updatedTranslation, null, 2)}`);
|
|
121
273
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
requestId: this.requestId,
|
|
133
|
-
info: data
|
|
134
|
-
});
|
|
135
|
-
}
|
|
136
|
-
logger.debug('Data:', data);
|
|
137
|
-
|
|
138
|
-
// Extract JSON content if message contains targetLocales
|
|
139
|
-
const jsonMatch = data.match(/{[\s\S]*"targetLocales"[\s\S]*}/);
|
|
140
|
-
if (jsonMatch) {
|
|
141
|
-
const extractedJson = jsonMatch[0];
|
|
142
|
-
if (isValidJSON(extractedJson)) {
|
|
143
|
-
finalJson = extractedJson;
|
|
144
|
-
}
|
|
145
|
-
}
|
|
146
|
-
},
|
|
147
|
-
() => {
|
|
148
|
-
resolve(finalJson)
|
|
149
|
-
},
|
|
150
|
-
(error) => reject(error)
|
|
151
|
-
);
|
|
152
|
-
}).finally(() => this.cleanup());
|
|
274
|
+
// Create iteration
|
|
275
|
+
const iteration = {
|
|
276
|
+
id: crypto.randomUUID(),
|
|
277
|
+
displayName: translationId,
|
|
278
|
+
input: {
|
|
279
|
+
subtitleMaxCharCountPerSegment: 42,
|
|
280
|
+
exportSubtitleInVideo: embedSubtitles,
|
|
281
|
+
...(speakerCount > 0 && { speakerCount })
|
|
282
|
+
}
|
|
283
|
+
};
|
|
153
284
|
|
|
154
|
-
|
|
155
|
-
this.
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
285
|
+
logger.debug(`Creating iteration: ${JSON.stringify(iteration, null, 2)}`);
|
|
286
|
+
const iterationUrl = `${this.baseUrl}/translations/${translationId}/iterations/${iteration.id}?api-version=${this.apiVersion}`;
|
|
287
|
+
try {
|
|
288
|
+
const iterationResponse = await axios.put(iterationUrl, iteration, {
|
|
289
|
+
headers: {
|
|
290
|
+
'Content-Type': 'application/json',
|
|
291
|
+
'Ocp-Apim-Subscription-Key': this.subscriptionKey,
|
|
292
|
+
'Cache-Control': 'no-cache',
|
|
293
|
+
'Pragma': 'no-cache'
|
|
294
|
+
}
|
|
295
|
+
});
|
|
159
296
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
return response;
|
|
168
|
-
}
|
|
297
|
+
const iterationOperationUrl = iterationResponse.headers['operation-location'];
|
|
298
|
+
await this.monitorOperation(iterationOperationUrl, 'iteration');
|
|
299
|
+
|
|
300
|
+
// Update processing rate for future estimates
|
|
301
|
+
const totalSeconds = (Date.now() - this.startTime) / 1000;
|
|
302
|
+
AzureVideoTranslatePlugin.lastProcessingRate = this.videoContentLength / totalSeconds;
|
|
303
|
+
logger.debug(`Updated processing rate: ${AzureVideoTranslatePlugin.lastProcessingRate} bytes/second`);
|
|
169
304
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
305
|
+
const output = await this.getTranslationOutput(translationId, iteration.id);
|
|
306
|
+
return JSON.stringify(output);
|
|
307
|
+
} catch (error) {
|
|
308
|
+
const errorText = error.response?.data || error.message;
|
|
309
|
+
throw new Error(`Failed to create iteration: ${error.message}\nDetails: ${errorText}`);
|
|
310
|
+
}
|
|
311
|
+
} catch (error) {
|
|
312
|
+
logger.error(`Error in video translation: ${error.message}`);
|
|
313
|
+
throw error;
|
|
176
314
|
}
|
|
177
315
|
}
|
|
178
316
|
|
|
179
317
|
cleanup() {
|
|
180
|
-
|
|
181
|
-
this.eventSource.close();
|
|
182
|
-
this.eventSource = null;
|
|
183
|
-
}
|
|
318
|
+
// No cleanup needed for direct API implementation
|
|
184
319
|
}
|
|
185
320
|
}
|
|
186
321
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
// replicateApiPlugin.js
|
|
2
2
|
import ModelPlugin from "./modelPlugin.js";
|
|
3
3
|
import logger from "../../lib/logger.js";
|
|
4
|
+
import axios from "axios";
|
|
4
5
|
|
|
5
6
|
class ReplicateApiPlugin extends ModelPlugin {
|
|
6
7
|
constructor(pathway, model) {
|
|
@@ -106,10 +107,61 @@ class ReplicateApiPlugin extends ModelPlugin {
|
|
|
106
107
|
cortexRequest.data = requestParameters;
|
|
107
108
|
cortexRequest.params = requestParameters.params;
|
|
108
109
|
|
|
109
|
-
|
|
110
|
+
// Make initial request to start prediction
|
|
111
|
+
const stringifiedResponse = await this.executeRequest(cortexRequest);
|
|
112
|
+
const parsedResponse = JSON.parse(stringifiedResponse);
|
|
113
|
+
|
|
114
|
+
// If we got a completed response, return it
|
|
115
|
+
if (parsedResponse?.status === "succeeded") {
|
|
116
|
+
return stringifiedResponse;
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
logger.info("Replicate API returned a non-completed response.");
|
|
120
|
+
|
|
121
|
+
if (!parsedResponse?.id) {
|
|
122
|
+
throw new Error("No prediction ID returned from Replicate API");
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
// Get the prediction ID and polling URL
|
|
126
|
+
const predictionId = parsedResponse.id;
|
|
127
|
+
const pollUrl = parsedResponse.urls?.get;
|
|
128
|
+
|
|
129
|
+
if (!pollUrl) {
|
|
130
|
+
throw new Error("No polling URL returned from Replicate API");
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
// Poll for results
|
|
134
|
+
const maxAttempts = 60; // 5 minutes with 5 second intervals
|
|
135
|
+
const pollInterval = 5000;
|
|
136
|
+
|
|
137
|
+
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
|
138
|
+
try {
|
|
139
|
+
const pollResponse = await axios.get(pollUrl, {
|
|
140
|
+
headers: cortexRequest.headers
|
|
141
|
+
});
|
|
142
|
+
|
|
143
|
+
logger.info("Polling Replicate API - attempt " + attempt);
|
|
144
|
+
const status = pollResponse.data?.status;
|
|
145
|
+
|
|
146
|
+
if (status === "succeeded") {
|
|
147
|
+
logger.info("Replicate API returned a completed response after polling");
|
|
148
|
+
return JSON.stringify(pollResponse.data);
|
|
149
|
+
} else if (status === "failed" || status === "canceled") {
|
|
150
|
+
throw new Error(`Prediction ${status}: ${pollResponse.data?.error || "Unknown error"}`);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
// Wait before next poll
|
|
154
|
+
await new Promise(resolve => setTimeout(resolve, pollInterval));
|
|
155
|
+
} catch (error) {
|
|
156
|
+
logger.error(`Error polling prediction ${predictionId}: ${error.message}`);
|
|
157
|
+
throw error;
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
throw new Error(`Prediction ${predictionId} timed out after ${maxAttempts * pollInterval / 1000} seconds`);
|
|
110
162
|
}
|
|
111
163
|
|
|
112
|
-
//
|
|
164
|
+
// Stringify the response from the Replicate API
|
|
113
165
|
parseResponse(data) {
|
|
114
166
|
if (data.data) {
|
|
115
167
|
return JSON.stringify(data.data);
|