@sogni-ai/sogni-client 0.4.0-aplha.1 → 0.4.0-aplha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Account/CurrentAccount.d.ts +12 -2
- package/dist/Account/CurrentAccount.js.map +1 -1
- package/dist/Account/index.d.ts +6 -7
- package/dist/Account/index.js +12 -8
- package/dist/Account/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/events.d.ts +7 -1
- package/dist/ApiClient/WebSocketClient/index.d.ts +1 -1
- package/dist/ApiClient/WebSocketClient/index.js +14 -5
- package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/messages.d.ts +2 -0
- package/dist/Projects/Job.d.ts +20 -1
- package/dist/Projects/Job.js +72 -1
- package/dist/Projects/Job.js.map +1 -1
- package/dist/Projects/Project.d.ts +21 -3
- package/dist/Projects/Project.js +110 -2
- package/dist/Projects/Project.js.map +1 -1
- package/dist/Projects/createJobRequestMessage.d.ts +1 -61
- package/dist/Projects/createJobRequestMessage.js +5 -1
- package/dist/Projects/createJobRequestMessage.js.map +1 -1
- package/dist/Projects/index.d.ts +22 -3
- package/dist/Projects/index.js +70 -12
- package/dist/Projects/index.js.map +1 -1
- package/dist/Projects/types/RawProject.d.ts +87 -0
- package/dist/Projects/types/RawProject.js +3 -0
- package/dist/Projects/types/RawProject.js.map +1 -0
- package/dist/Projects/types/index.d.ts +4 -0
- package/dist/lib/DataEntity.d.ts +1 -0
- package/dist/lib/DataEntity.js +2 -0
- package/dist/lib/DataEntity.js.map +1 -1
- package/dist/lib/base64.js +8 -6
- package/dist/lib/base64.js.map +1 -1
- package/dist/types/ErrorData.d.ts +1 -0
- package/dist/version.d.ts +1 -1
- package/dist/version.js +1 -1
- package/dist/version.js.map +1 -1
- package/package.json +1 -1
- package/src/Account/CurrentAccount.ts +11 -1
- package/src/Account/index.ts +13 -9
- package/src/ApiClient/WebSocketClient/events.ts +5 -1
- package/src/ApiClient/WebSocketClient/index.ts +15 -6
- package/src/ApiClient/WebSocketClient/messages.ts +2 -0
- package/src/Projects/Job.ts +82 -1
- package/src/Projects/Project.ts +134 -5
- package/src/Projects/createJobRequestMessage.ts +5 -1
- package/src/Projects/index.ts +75 -14
- package/src/Projects/types/RawProject.ts +121 -0
- package/src/Projects/types/index.ts +4 -0
- package/src/lib/DataEntity.ts +3 -0
- package/src/lib/base64.ts +8 -4
- package/src/types/ErrorData.ts +1 -0
- package/src/version.ts +1 -1
package/src/Account/index.ts
CHANGED
|
@@ -259,24 +259,28 @@ class AccountApi extends ApiGroup {
|
|
|
259
259
|
|
|
260
260
|
/**
|
|
261
261
|
* Switch between fast and relaxed networks.
|
|
262
|
-
*
|
|
263
|
-
*
|
|
262
|
+
* This will change default network used to process projects. After switching, you will updated
|
|
263
|
+
* list of AI models available for on selected network.
|
|
264
264
|
*
|
|
265
265
|
* @example Switch to the fast network
|
|
266
266
|
* ```typescript
|
|
267
|
-
* client.apiClient.once('connected', ({ network }) => {
|
|
268
|
-
* console.log('Switched to the network:', network);
|
|
269
|
-
* });
|
|
270
267
|
* await client.account.switchNetwork('fast');
|
|
268
|
+
* console.log('Switched to the fast network, now lets wait until we get list of models');
|
|
269
|
+
* await client.projects.waitForModels();
|
|
271
270
|
* ```
|
|
272
|
-
* @param network
|
|
271
|
+
* @param network - Network type to switch to
|
|
273
272
|
*/
|
|
274
|
-
async switchNetwork(network: SupernetType) {
|
|
273
|
+
async switchNetwork(network: SupernetType): Promise<SupernetType> {
|
|
275
274
|
this.currentAccount._update({
|
|
276
|
-
networkStatus: '
|
|
275
|
+
networkStatus: 'switching',
|
|
277
276
|
network: null
|
|
278
277
|
});
|
|
279
|
-
this.client.socket.switchNetwork(network);
|
|
278
|
+
const newNetwork = await this.client.socket.switchNetwork(network);
|
|
279
|
+
this.currentAccount._update({
|
|
280
|
+
networkStatus: 'connected',
|
|
281
|
+
network: newNetwork
|
|
282
|
+
});
|
|
283
|
+
return newNetwork;
|
|
280
284
|
}
|
|
281
285
|
|
|
282
286
|
/**
|
|
@@ -12,7 +12,7 @@ export type JobErrorData = {
|
|
|
12
12
|
imgID?: string;
|
|
13
13
|
isFromWorker: boolean;
|
|
14
14
|
error_message: string;
|
|
15
|
-
error: number;
|
|
15
|
+
error: number | string;
|
|
16
16
|
};
|
|
17
17
|
|
|
18
18
|
export type JobProgressData = {
|
|
@@ -63,6 +63,10 @@ export type SocketEventMap = {
|
|
|
63
63
|
* @event WebSocketClient#balanceUpdate - Received balance update
|
|
64
64
|
*/
|
|
65
65
|
balanceUpdate: BalanceData;
|
|
66
|
+
/**
|
|
67
|
+
* @event WebSocketClient#changeNetwork - Default network changed
|
|
68
|
+
*/
|
|
69
|
+
changeNetwork: { network: SupernetType };
|
|
66
70
|
/**
|
|
67
71
|
* @event WebSocketClient#jobError - Job error occurred
|
|
68
72
|
*/
|
|
@@ -19,9 +19,13 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
19
19
|
private _pingInterval: NodeJS.Timeout | null = null;
|
|
20
20
|
|
|
21
21
|
constructor(baseUrl: string, appId: string, supernetType: SupernetType, logger: Logger) {
|
|
22
|
-
|
|
22
|
+
const _baseUrl = new URL(baseUrl);
|
|
23
|
+
if (_baseUrl.protocol === 'wss:') {
|
|
24
|
+
_baseUrl.protocol = 'https:';
|
|
25
|
+
}
|
|
26
|
+
super(_baseUrl.toString(), logger);
|
|
23
27
|
this.appId = appId;
|
|
24
|
-
this.baseUrl =
|
|
28
|
+
this.baseUrl = _baseUrl.toString();
|
|
25
29
|
this._supernetType = supernetType;
|
|
26
30
|
}
|
|
27
31
|
|
|
@@ -56,6 +60,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
56
60
|
}
|
|
57
61
|
const userAgent = `Sogni/${LIB_VERSION} (sogni-client)`;
|
|
58
62
|
const url = new URL(this.baseUrl);
|
|
63
|
+
url.protocol = 'wss:';
|
|
59
64
|
url.searchParams.set('appId', this.appId);
|
|
60
65
|
url.searchParams.set('clientName', userAgent);
|
|
61
66
|
url.searchParams.set('clientType', 'artist');
|
|
@@ -108,10 +113,14 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
108
113
|
}
|
|
109
114
|
}
|
|
110
115
|
|
|
111
|
-
switchNetwork(supernetType: SupernetType) {
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
116
|
+
switchNetwork(supernetType: SupernetType): Promise<SupernetType> {
|
|
117
|
+
return new Promise<SupernetType>(async (resolve, reject) => {
|
|
118
|
+
this.once('changeNetwork', ({ network }) => {
|
|
119
|
+
this._supernetType = network;
|
|
120
|
+
resolve(network);
|
|
121
|
+
});
|
|
122
|
+
await this.send('changeNetwork', supernetType);
|
|
123
|
+
});
|
|
115
124
|
}
|
|
116
125
|
|
|
117
126
|
/**
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import { JobRequestRaw } from '../../Projects/createJobRequestMessage';
|
|
2
|
+
import { SupernetType } from './types';
|
|
2
3
|
|
|
3
4
|
export interface SocketMessageMap {
|
|
4
5
|
jobRequest: JobRequestRaw;
|
|
6
|
+
changeNetwork: SupernetType;
|
|
5
7
|
}
|
|
6
8
|
|
|
7
9
|
export type MessageType = keyof SocketMessageMap;
|
package/src/Projects/Job.ts
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
2
2
|
import ErrorData from '../types/ErrorData';
|
|
3
|
+
import { RawJob, RawProject } from './types/RawProject';
|
|
4
|
+
import ProjectsApi from './index';
|
|
5
|
+
import { Logger } from '../lib/DefaultLogger';
|
|
6
|
+
import getUUID from '../lib/getUUID';
|
|
3
7
|
|
|
4
8
|
export type JobStatus =
|
|
5
9
|
| 'pending'
|
|
@@ -9,11 +13,23 @@ export type JobStatus =
|
|
|
9
13
|
| 'failed'
|
|
10
14
|
| 'canceled';
|
|
11
15
|
|
|
16
|
+
const JOB_STATUS_MAP: Record<RawJob['status'], JobStatus> = {
|
|
17
|
+
created: 'pending',
|
|
18
|
+
queued: 'pending',
|
|
19
|
+
assigned: 'initiating',
|
|
20
|
+
initiatingModel: 'initiating',
|
|
21
|
+
jobStarted: 'processing',
|
|
22
|
+
jobProgress: 'processing',
|
|
23
|
+
jobCompleted: 'completed',
|
|
24
|
+
jobError: 'failed'
|
|
25
|
+
};
|
|
26
|
+
|
|
12
27
|
/**
|
|
13
28
|
* @inline
|
|
14
29
|
*/
|
|
15
30
|
export interface JobData {
|
|
16
31
|
id: string;
|
|
32
|
+
projectId: string;
|
|
17
33
|
status: JobStatus;
|
|
18
34
|
step: number;
|
|
19
35
|
stepCount: number;
|
|
@@ -32,9 +48,37 @@ export interface JobEventMap extends EntityEvents {
|
|
|
32
48
|
failed: ErrorData;
|
|
33
49
|
}
|
|
34
50
|
|
|
51
|
+
export interface JobOptions {
|
|
52
|
+
api: ProjectsApi;
|
|
53
|
+
logger: Logger;
|
|
54
|
+
}
|
|
55
|
+
|
|
35
56
|
class Job extends DataEntity<JobData, JobEventMap> {
|
|
36
|
-
|
|
57
|
+
static fromRaw(rawProject: RawProject, rawJob: RawJob, options: JobOptions) {
|
|
58
|
+
return new Job(
|
|
59
|
+
{
|
|
60
|
+
id: rawJob.imgID || getUUID(),
|
|
61
|
+
projectId: rawProject.id,
|
|
62
|
+
status: JOB_STATUS_MAP[rawJob.status],
|
|
63
|
+
step: rawJob.performedSteps,
|
|
64
|
+
stepCount: rawProject.stepCount,
|
|
65
|
+
workerName: rawJob.worker.name,
|
|
66
|
+
seed: rawJob.seedUsed,
|
|
67
|
+
isNSFW: rawJob.triggeredNSFWFilter
|
|
68
|
+
},
|
|
69
|
+
options
|
|
70
|
+
);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
private readonly _api: ProjectsApi;
|
|
74
|
+
private readonly _logger: Logger;
|
|
75
|
+
|
|
76
|
+
constructor(data: JobData, options: JobOptions) {
|
|
37
77
|
super(data);
|
|
78
|
+
|
|
79
|
+
this._api = options.api;
|
|
80
|
+
this._logger = options.logger;
|
|
81
|
+
|
|
38
82
|
this.on('updated', this.handleUpdated.bind(this));
|
|
39
83
|
}
|
|
40
84
|
|
|
@@ -42,6 +86,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
42
86
|
return this.data.id;
|
|
43
87
|
}
|
|
44
88
|
|
|
89
|
+
get projectId() {
|
|
90
|
+
return this.data.projectId;
|
|
91
|
+
}
|
|
92
|
+
|
|
45
93
|
/**
|
|
46
94
|
* Current status of the job.
|
|
47
95
|
*/
|
|
@@ -49,6 +97,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
49
97
|
return this.data.status;
|
|
50
98
|
}
|
|
51
99
|
|
|
100
|
+
get finished() {
|
|
101
|
+
return ['completed', 'failed', 'canceled'].includes(this.status);
|
|
102
|
+
}
|
|
103
|
+
|
|
52
104
|
/**
|
|
53
105
|
* Progress of the job in percentage (0-100).
|
|
54
106
|
*/
|
|
@@ -116,6 +168,35 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
116
168
|
return this.data.workerName;
|
|
117
169
|
}
|
|
118
170
|
|
|
171
|
+
/**
|
|
172
|
+
* Syncs the job data with the data received from the REST API.
|
|
173
|
+
* @internal
|
|
174
|
+
* @param data
|
|
175
|
+
*/
|
|
176
|
+
async _syncWithRestData(data: RawJob) {
|
|
177
|
+
const delta: Partial<JobData> = {
|
|
178
|
+
step: data.performedSteps,
|
|
179
|
+
workerName: data.worker.name,
|
|
180
|
+
seed: data.seedUsed,
|
|
181
|
+
isNSFW: data.triggeredNSFWFilter
|
|
182
|
+
};
|
|
183
|
+
if (JOB_STATUS_MAP[data.status]) {
|
|
184
|
+
delta.status = JOB_STATUS_MAP[data.status];
|
|
185
|
+
}
|
|
186
|
+
if (!this.data.resultUrl && delta.status === 'completed' && !data.triggeredNSFWFilter) {
|
|
187
|
+
try {
|
|
188
|
+
delta.resultUrl = await this._api.downloadUrl({
|
|
189
|
+
jobId: this.projectId,
|
|
190
|
+
imageId: this.id,
|
|
191
|
+
type: 'complete'
|
|
192
|
+
});
|
|
193
|
+
} catch (error) {
|
|
194
|
+
this._logger.error(error);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
this._update(delta);
|
|
198
|
+
}
|
|
199
|
+
|
|
119
200
|
private handleUpdated(keys: string[]) {
|
|
120
201
|
if (keys.includes('step') || keys.includes('stepCount')) {
|
|
121
202
|
this.emit('progress', this.progress);
|
package/src/Projects/Project.ts
CHANGED
|
@@ -1,11 +1,34 @@
|
|
|
1
|
-
import Job, { JobData } from './Job';
|
|
1
|
+
import Job, { JobData, JobStatus } from './Job';
|
|
2
2
|
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
3
3
|
import { ProjectParams } from './types';
|
|
4
4
|
import cloneDeep from 'lodash/cloneDeep';
|
|
5
5
|
import ErrorData from '../types/ErrorData';
|
|
6
6
|
import getUUID from '../lib/getUUID';
|
|
7
|
+
import { RawJob, RawProject } from './types/RawProject';
|
|
8
|
+
import ProjectsApi from './index';
|
|
9
|
+
import { Logger } from '../lib/DefaultLogger';
|
|
7
10
|
|
|
8
|
-
|
|
11
|
+
// If project is not finished and had no updates for 1 minute, force refresh
|
|
12
|
+
const PROJECT_TIMEOUT = 60 * 1000;
|
|
13
|
+
const MAX_FAILED_SYNC_ATTEMPTS = 3;
|
|
14
|
+
|
|
15
|
+
export type ProjectStatus =
|
|
16
|
+
| 'pending'
|
|
17
|
+
| 'queued'
|
|
18
|
+
| 'processing'
|
|
19
|
+
| 'completed'
|
|
20
|
+
| 'failed'
|
|
21
|
+
| 'canceled';
|
|
22
|
+
|
|
23
|
+
const PROJECT_STATUS_MAP: Record<RawProject['status'], ProjectStatus> = {
|
|
24
|
+
pending: 'pending',
|
|
25
|
+
active: 'queued',
|
|
26
|
+
assigned: 'processing',
|
|
27
|
+
progress: 'processing',
|
|
28
|
+
completed: 'completed',
|
|
29
|
+
errored: 'failed',
|
|
30
|
+
cancelled: 'canceled'
|
|
31
|
+
};
|
|
9
32
|
|
|
10
33
|
/**
|
|
11
34
|
* @inline
|
|
@@ -31,11 +54,20 @@ export interface ProjectEventMap extends EntityEvents {
|
|
|
31
54
|
jobFailed: Job;
|
|
32
55
|
}
|
|
33
56
|
|
|
57
|
+
export interface ProjectOptions {
|
|
58
|
+
api: ProjectsApi;
|
|
59
|
+
logger: Logger;
|
|
60
|
+
}
|
|
61
|
+
|
|
34
62
|
class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
35
63
|
private _jobs: Job[] = [];
|
|
36
64
|
private _lastEmitedProgress = -1;
|
|
65
|
+
private readonly _api: ProjectsApi;
|
|
66
|
+
private readonly _logger: Logger;
|
|
67
|
+
private _timeout: NodeJS.Timeout | null = null;
|
|
68
|
+
private _failedSyncAttempts = 0;
|
|
37
69
|
|
|
38
|
-
constructor(data: ProjectParams) {
|
|
70
|
+
constructor(data: ProjectParams, options: ProjectOptions) {
|
|
39
71
|
super({
|
|
40
72
|
id: getUUID(),
|
|
41
73
|
startedAt: new Date(),
|
|
@@ -44,6 +76,11 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
44
76
|
status: 'pending'
|
|
45
77
|
});
|
|
46
78
|
|
|
79
|
+
this._api = options.api;
|
|
80
|
+
this._logger = options.logger;
|
|
81
|
+
|
|
82
|
+
this._timeout = setInterval(this._checkForTimeout.bind(this), PROJECT_TIMEOUT);
|
|
83
|
+
|
|
47
84
|
this.on('updated', this.handleUpdated.bind(this));
|
|
48
85
|
}
|
|
49
86
|
|
|
@@ -59,6 +96,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
59
96
|
return this.data.status;
|
|
60
97
|
}
|
|
61
98
|
|
|
99
|
+
get finished() {
|
|
100
|
+
return ['completed', 'failed', 'canceled'].includes(this.status);
|
|
101
|
+
}
|
|
102
|
+
|
|
62
103
|
get error() {
|
|
63
104
|
return this.data.error;
|
|
64
105
|
}
|
|
@@ -137,6 +178,11 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
137
178
|
this.emit('progress', progress);
|
|
138
179
|
this._lastEmitedProgress = progress;
|
|
139
180
|
}
|
|
181
|
+
// If project is finished stop watching for timeout
|
|
182
|
+
if (this._timeout && this.finished) {
|
|
183
|
+
clearInterval(this._timeout!);
|
|
184
|
+
this._timeout = null;
|
|
185
|
+
}
|
|
140
186
|
if (keys.includes('status') || keys.includes('jobs')) {
|
|
141
187
|
const allJobsDone = this.jobs.every((job) =>
|
|
142
188
|
['completed', 'failed', 'canceled'].includes(job.status)
|
|
@@ -155,21 +201,104 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
155
201
|
* @internal
|
|
156
202
|
* @param data
|
|
157
203
|
*/
|
|
158
|
-
_addJob(data: JobData) {
|
|
159
|
-
const job =
|
|
204
|
+
_addJob(data: JobData | Job) {
|
|
205
|
+
const job =
|
|
206
|
+
data instanceof Job ? data : new Job(data, { api: this._api, logger: this._logger });
|
|
160
207
|
this._jobs.push(job);
|
|
161
208
|
job.on('updated', () => {
|
|
209
|
+
this.lastUpdated = new Date();
|
|
162
210
|
this.emit('updated', ['jobs']);
|
|
163
211
|
});
|
|
164
212
|
job.on('completed', () => {
|
|
165
213
|
this.emit('jobCompleted', job);
|
|
214
|
+
this._handleJobFinished(job);
|
|
166
215
|
});
|
|
167
216
|
job.on('failed', () => {
|
|
168
217
|
this.emit('jobFailed', job);
|
|
218
|
+
this._handleJobFinished(job);
|
|
169
219
|
});
|
|
170
220
|
return job;
|
|
171
221
|
}
|
|
172
222
|
|
|
223
|
+
private _handleJobFinished(job: Job) {
|
|
224
|
+
const finalStatus: JobStatus[] = ['completed', 'failed', 'canceled'];
|
|
225
|
+
const allJobsDone = this.jobs.every((job) => finalStatus.includes(job.status));
|
|
226
|
+
// If all jobs are done and project is not already failed or completed, update the project status
|
|
227
|
+
if (allJobsDone && this.status !== 'failed' && this.status !== 'completed') {
|
|
228
|
+
const allJobsFailed = this.jobs.every((job) => job.status === 'failed');
|
|
229
|
+
if (allJobsFailed) {
|
|
230
|
+
this._update({ status: 'failed' });
|
|
231
|
+
} else {
|
|
232
|
+
this._update({ status: 'completed' });
|
|
233
|
+
}
|
|
234
|
+
}
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
private _checkForTimeout() {
|
|
238
|
+
if (this.lastUpdated.getTime() + PROJECT_TIMEOUT < Date.now()) {
|
|
239
|
+
this._syncToServer().catch((error) => {
|
|
240
|
+
this._logger.error(error);
|
|
241
|
+
this._failedSyncAttempts++;
|
|
242
|
+
if (this._failedSyncAttempts > MAX_FAILED_SYNC_ATTEMPTS) {
|
|
243
|
+
this._logger.error(
|
|
244
|
+
`Failed to sync project data after ${MAX_FAILED_SYNC_ATTEMPTS} attempts. Stopping further attempts.`
|
|
245
|
+
);
|
|
246
|
+
clearInterval(this._timeout!);
|
|
247
|
+
this._timeout = null;
|
|
248
|
+
}
|
|
249
|
+
});
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
/**
|
|
254
|
+
* Sync project data with the data received from the REST API.
|
|
255
|
+
* @internal
|
|
256
|
+
*/
|
|
257
|
+
async _syncToServer() {
|
|
258
|
+
const data = await this._api.get(this.id);
|
|
259
|
+
const jobData = data.completedWorkerJobs.reduce((acc: Record<string, RawJob>, job) => {
|
|
260
|
+
const jobId = job.imgID || getUUID();
|
|
261
|
+
acc[jobId] = job;
|
|
262
|
+
return acc;
|
|
263
|
+
}, {});
|
|
264
|
+
for (const job of this._jobs) {
|
|
265
|
+
const restJob = jobData[job.id];
|
|
266
|
+
// This should never happen, but just in case we log a warning
|
|
267
|
+
if (!restJob) {
|
|
268
|
+
this._logger.warn(`Job with id ${job.id} not found in the REST project data`);
|
|
269
|
+
return;
|
|
270
|
+
}
|
|
271
|
+
try {
|
|
272
|
+
await job._syncWithRestData(restJob);
|
|
273
|
+
} catch (error) {
|
|
274
|
+
this._logger.error(error);
|
|
275
|
+
this._logger.error(`Failed to sync job ${job.id}`);
|
|
276
|
+
}
|
|
277
|
+
delete jobData[job.id];
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
// If there are any jobs left in jobData, it means they are new jobs that are not in the project yet
|
|
281
|
+
if (Object.keys(jobData).length) {
|
|
282
|
+
for (const job of Object.values(jobData)) {
|
|
283
|
+
const jobInstance = Job.fromRaw(data, job, { api: this._api, logger: this._logger });
|
|
284
|
+
this._addJob(jobInstance);
|
|
285
|
+
}
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
const delta: Partial<ProjectData> = {
|
|
289
|
+
params: {
|
|
290
|
+
...this.data.params,
|
|
291
|
+
numberOfImages: data.imageCount,
|
|
292
|
+
steps: data.stepCount,
|
|
293
|
+
numberOfPreviews: data.previewCount
|
|
294
|
+
}
|
|
295
|
+
};
|
|
296
|
+
if (PROJECT_STATUS_MAP[data.status]) {
|
|
297
|
+
delta.status = PROJECT_STATUS_MAP[data.status];
|
|
298
|
+
}
|
|
299
|
+
this._update(delta);
|
|
300
|
+
}
|
|
301
|
+
|
|
173
302
|
/**
|
|
174
303
|
* Get full project data snapshot. Can be used to serialize the project and store it in a database.
|
|
175
304
|
*/
|
|
@@ -67,7 +67,7 @@ function getTemplate() {
|
|
|
67
67
|
|
|
68
68
|
function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
69
69
|
const template = getTemplate();
|
|
70
|
-
|
|
70
|
+
const jobRequest: Record<string, any> = {
|
|
71
71
|
...template,
|
|
72
72
|
keyFrames: [
|
|
73
73
|
{
|
|
@@ -92,6 +92,10 @@ function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
|
92
92
|
jobID: id,
|
|
93
93
|
disableSafety: !!params.disableNSFWFilter
|
|
94
94
|
};
|
|
95
|
+
if (params.network) {
|
|
96
|
+
jobRequest.network = params.network;
|
|
97
|
+
}
|
|
98
|
+
return jobRequest;
|
|
95
99
|
}
|
|
96
100
|
|
|
97
101
|
export type JobRequestRaw = ReturnType<typeof createJobRequestMessage>;
|
package/src/Projects/index.ts
CHANGED
|
@@ -14,9 +14,28 @@ import { ApiError, ApiReponse } from '../ApiClient';
|
|
|
14
14
|
import { EstimationResponse } from './types/EstimationResponse';
|
|
15
15
|
import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
|
|
16
16
|
import getUUID from '../lib/getUUID';
|
|
17
|
+
import { RawProject } from './types/RawProject';
|
|
18
|
+
import ErrorData from '../types/ErrorData';
|
|
17
19
|
|
|
18
20
|
const GARBAGE_COLLECT_TIMEOUT = 10000;
|
|
19
21
|
|
|
22
|
+
function mapErrorCodes(code: string): number {
|
|
23
|
+
switch (code) {
|
|
24
|
+
case 'serverRestarting':
|
|
25
|
+
return 5001;
|
|
26
|
+
case 'workerDisconnected':
|
|
27
|
+
return 5002;
|
|
28
|
+
case 'jobTimedOut':
|
|
29
|
+
return 5003;
|
|
30
|
+
case 'artistCanceled':
|
|
31
|
+
return 5004;
|
|
32
|
+
case 'ç':
|
|
33
|
+
return 5005;
|
|
34
|
+
default:
|
|
35
|
+
return 5000;
|
|
36
|
+
}
|
|
37
|
+
}
|
|
38
|
+
|
|
20
39
|
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
21
40
|
private _availableModels: AvailableModel[] = [];
|
|
22
41
|
private projects: Project[] = [];
|
|
@@ -28,6 +47,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
28
47
|
constructor(config: ApiConfig) {
|
|
29
48
|
super(config);
|
|
30
49
|
// Listen to server events and emit them as project and job events
|
|
50
|
+
this.client.socket.on('changeNetwork', this.handleChangeNetwork.bind(this));
|
|
31
51
|
this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
|
|
32
52
|
this.client.socket.on('jobState', this.handleJobState.bind(this));
|
|
33
53
|
this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
|
|
@@ -40,6 +60,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
40
60
|
this.on('job', this.handleJobEvent.bind(this));
|
|
41
61
|
}
|
|
42
62
|
|
|
63
|
+
private handleChangeNetwork() {
|
|
64
|
+
this._availableModels = [];
|
|
65
|
+
this.emit('availableModels', this._availableModels);
|
|
66
|
+
}
|
|
67
|
+
|
|
43
68
|
private handleSwarmModels(data: SocketEventMap['swarmModels']) {
|
|
44
69
|
const modelIndex = models.reduce((acc: Record<string, any>, model) => {
|
|
45
70
|
acc[model.modelId] = model;
|
|
@@ -137,14 +162,25 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
137
162
|
}
|
|
138
163
|
|
|
139
164
|
private handleJobError(data: JobErrorData) {
|
|
165
|
+
const errorCode = Number(data.error);
|
|
166
|
+
let error: ErrorData;
|
|
167
|
+
if (!isNaN(errorCode)) {
|
|
168
|
+
error = {
|
|
169
|
+
code: errorCode,
|
|
170
|
+
message: data.error_message
|
|
171
|
+
};
|
|
172
|
+
} else {
|
|
173
|
+
error = {
|
|
174
|
+
code: mapErrorCodes(data.error as string),
|
|
175
|
+
originalCode: data.error?.toString(),
|
|
176
|
+
message: data.error_message
|
|
177
|
+
};
|
|
178
|
+
}
|
|
140
179
|
if (!data.imgID) {
|
|
141
180
|
this.emit('project', {
|
|
142
181
|
type: 'error',
|
|
143
182
|
projectId: data.jobID,
|
|
144
|
-
error
|
|
145
|
-
code: Number(data.error),
|
|
146
|
-
message: data.error_message
|
|
147
|
-
}
|
|
183
|
+
error
|
|
148
184
|
});
|
|
149
185
|
return;
|
|
150
186
|
}
|
|
@@ -152,10 +188,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
152
188
|
type: 'error',
|
|
153
189
|
projectId: data.jobID,
|
|
154
190
|
jobId: data.imgID,
|
|
155
|
-
error:
|
|
156
|
-
code: Number(data.error),
|
|
157
|
-
message: data.error_message
|
|
158
|
-
}
|
|
191
|
+
error: error
|
|
159
192
|
});
|
|
160
193
|
}
|
|
161
194
|
|
|
@@ -182,7 +215,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
182
215
|
error: event.error
|
|
183
216
|
});
|
|
184
217
|
}
|
|
185
|
-
if (project.
|
|
218
|
+
if (project.finished) {
|
|
219
|
+
// Sync project data with the server and remove it from the list after some time
|
|
220
|
+
project._syncToServer().catch((e) => {
|
|
221
|
+
this.client.logger.error(e);
|
|
222
|
+
});
|
|
186
223
|
setTimeout(() => {
|
|
187
224
|
this.projects = this.projects.filter((p) => p.id !== event.projectId);
|
|
188
225
|
}, GARBAGE_COLLECT_TIMEOUT);
|
|
@@ -198,6 +235,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
198
235
|
if (!job) {
|
|
199
236
|
job = project._addJob({
|
|
200
237
|
id: event.jobId,
|
|
238
|
+
projectId: event.projectId,
|
|
201
239
|
status: 'pending',
|
|
202
240
|
step: 0,
|
|
203
241
|
stepCount: project.params.steps
|
|
@@ -205,10 +243,10 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
205
243
|
}
|
|
206
244
|
switch (event.type) {
|
|
207
245
|
case 'initiating':
|
|
208
|
-
job._update({ status: 'initiating' });
|
|
246
|
+
job._update({ status: 'initiating', workerName: event.workerName });
|
|
209
247
|
break;
|
|
210
248
|
case 'started':
|
|
211
|
-
job._update({ status: 'processing' });
|
|
249
|
+
job._update({ status: 'processing', workerName: event.workerName });
|
|
212
250
|
break;
|
|
213
251
|
case 'progress':
|
|
214
252
|
job._update({
|
|
@@ -278,7 +316,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
278
316
|
* @param data
|
|
279
317
|
*/
|
|
280
318
|
async create(data: ProjectParams): Promise<Project> {
|
|
281
|
-
const project = new Project({ ...data });
|
|
319
|
+
const project = new Project({ ...data }, { api: this, logger: this.client.logger });
|
|
282
320
|
if (data.startingImage) {
|
|
283
321
|
await this.uploadGuideImage(project.id, data.startingImage);
|
|
284
322
|
}
|
|
@@ -288,6 +326,19 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
288
326
|
return project;
|
|
289
327
|
}
|
|
290
328
|
|
|
329
|
+
/**
|
|
330
|
+
* Get project by id, this API returns project data from the server only if the project is
|
|
331
|
+
* completed or failed. If the project is still processing, it will throw 404 error.
|
|
332
|
+
* @internal
|
|
333
|
+
* @param projectId
|
|
334
|
+
*/
|
|
335
|
+
async get(projectId: string) {
|
|
336
|
+
const { data } = await this.client.rest.get<ApiReponse<RawProject>>(
|
|
337
|
+
`/v1/projects/${projectId}`
|
|
338
|
+
);
|
|
339
|
+
return data;
|
|
340
|
+
}
|
|
341
|
+
|
|
291
342
|
private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
|
|
292
343
|
const imageId = getUUID();
|
|
293
344
|
const presignedUrl = await this.uploadUrl({
|
|
@@ -330,7 +381,12 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
330
381
|
};
|
|
331
382
|
}
|
|
332
383
|
|
|
333
|
-
|
|
384
|
+
/**
|
|
385
|
+
* Get upload URL for image
|
|
386
|
+
* @internal
|
|
387
|
+
* @param params
|
|
388
|
+
*/
|
|
389
|
+
async uploadUrl(params: ImageUrlParams) {
|
|
334
390
|
const r = await this.client.rest.get<ApiReponse<{ uploadUrl: string }>>(
|
|
335
391
|
`/v1/image/uploadUrl`,
|
|
336
392
|
params
|
|
@@ -338,7 +394,12 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
338
394
|
return r.data.uploadUrl;
|
|
339
395
|
}
|
|
340
396
|
|
|
341
|
-
|
|
397
|
+
/**
|
|
398
|
+
* Get download URL for image
|
|
399
|
+
* @internal
|
|
400
|
+
* @param params
|
|
401
|
+
*/
|
|
402
|
+
async downloadUrl(params: ImageUrlParams) {
|
|
342
403
|
const r = await this.client.rest.get<ApiReponse<{ downloadUrl: string }>>(
|
|
343
404
|
`/v1/image/downloadUrl`,
|
|
344
405
|
params
|