@sogni-ai/sogni-client 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/dist/version.d.ts +1 -1
- package/dist/version.js +1 -1
- package/package.json +5 -3
- package/src/Account/CurrentAccount.ts +101 -0
- package/src/Account/index.ts +243 -0
- package/src/Account/types.ts +90 -0
- package/src/ApiClient/WebSocketClient/ErrorCode.ts +15 -0
- package/src/ApiClient/WebSocketClient/events.ts +94 -0
- package/src/ApiClient/WebSocketClient/index.ts +203 -0
- package/src/ApiClient/WebSocketClient/messages.ts +7 -0
- package/src/ApiClient/WebSocketClient/types.ts +1 -0
- package/src/ApiClient/events.ts +20 -0
- package/src/ApiClient/index.ts +124 -0
- package/src/ApiGroup.ts +25 -0
- package/src/Projects/Job.ts +124 -0
- package/src/Projects/Project.ts +185 -0
- package/src/Projects/createJobRequestMessage.ts +99 -0
- package/src/Projects/index.ts +340 -0
- package/src/Projects/models.json +8906 -0
- package/src/Projects/types/EstimationResponse.ts +45 -0
- package/src/Projects/types/events.ts +78 -0
- package/src/Projects/types/index.ts +146 -0
- package/src/Stats/index.ts +15 -0
- package/src/Stats/types.ts +34 -0
- package/src/events.ts +5 -0
- package/src/index.ts +120 -0
- package/src/lib/DataEntity.ts +38 -0
- package/src/lib/DefaultLogger.ts +47 -0
- package/src/lib/EIP712Helper.ts +57 -0
- package/src/lib/RestClient.ts +76 -0
- package/src/lib/TypedEventEmitter.ts +66 -0
- package/src/lib/base64.ts +9 -0
- package/src/lib/getUUID.ts +8 -0
- package/src/lib/isNodejs.ts +4 -0
- package/src/types/ErrorData.ts +6 -0
- package/src/types/json.ts +5 -0
- package/src/version.ts +1 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import { MessageType, SocketMessageMap } from './messages';
|
|
2
|
+
import { SocketEventMap } from './events';
|
|
3
|
+
import RestClient, { AuthData } from '../../lib/RestClient';
|
|
4
|
+
import { SupernetType } from './types';
|
|
5
|
+
import WebSocket, { CloseEvent, ErrorEvent, MessageEvent } from 'isomorphic-ws';
|
|
6
|
+
import { base64Decode, base64Encode } from '../../lib/base64';
|
|
7
|
+
import isNodejs from '../../lib/isNodejs';
|
|
8
|
+
import Cookie from 'js-cookie';
|
|
9
|
+
import { LIB_VERSION } from '../../version';
|
|
10
|
+
import { Logger } from '../../lib/DefaultLogger';
|
|
11
|
+
|
|
12
|
+
/** Interval in milliseconds between keep-alive pings; pings are only sent in Node.js (see `startPing`). */
const PING_INTERVAL = 15000;

/**
 * WebSocket transport for the Sogni API.
 *
 * Builds the connection URL from the app id and the selected supernet. It authenticates with the
 * token from the inherited auth data: as a header in Node.js, or through the `authorization`
 * cookie in the browser. Every decoded server message is re-emitted as an event named after its
 * `type`.
 */
class WebSocketClient extends RestClient<SocketEventMap> {
  appId: string;
  baseUrl: string;
  private socket: WebSocket | null = null;
  private _supernetType: SupernetType;
  private _pingInterval: NodeJS.Timeout | null = null;

  /**
   * @param baseUrl - WebSocket server URL; query parameters are appended on `connect()`.
   * @param appId - Application id sent as the `appId` query parameter.
   * @param supernetType - Network to connect to initially.
   * @param logger - Logger for connection diagnostics.
   */
  constructor(baseUrl: string, appId: string, supernetType: SupernetType, logger: Logger) {
    super(baseUrl, logger);
    this.appId = appId;
    this.baseUrl = baseUrl;
    this._supernetType = supernetType;
  }

  /**
   * Current auth data.
   *
   * Declared alongside the setter on purpose. An accessor defined on a subclass replaces the
   * inherited accessor pair entirely, so a setter-only override would make reads of `auth`
   * return `undefined`.
   */
  get auth(): AuthData | null {
    return this._auth;
  }

  /**
   * Store auth data used by the next `connect()`.
   *
   * In the browser the token is also written to, or removed from, the `authorization` cookie on
   * `.sogni.ai` (expires after 1 day). Browser WebSockets cannot send custom headers.
   */
  set auth(auth: AuthData | null) {
    if (!isNodejs) {
      if (auth) {
        Cookie.set('authorization', auth.token, {
          domain: '.sogni.ai',
          expires: 1
        });
      } else {
        Cookie.remove('authorization', {
          domain: '.sogni.ai'
        });
      }
    }
    this._auth = auth;
  }

  /** Network type used for the current/next connection. */
  get supernetType(): SupernetType {
    return this._supernetType;
  }

  /** `true` while a socket object exists, including while it is still connecting. */
  get isConnected(): boolean {
    return !!this.socket;
  }

  /**
   * Open a new socket, closing any existing one first.
   * Event handlers are attached and, in Node.js, the keep-alive ping is started.
   */
  connect() {
    if (this.socket) {
      this.disconnect();
    }
    const userAgent = `Sogni/${LIB_VERSION} (sogni-client)`;
    const url = new URL(this.baseUrl);
    url.searchParams.set('appId', this.appId);
    url.searchParams.set('clientName', userAgent);
    url.searchParams.set('clientType', 'artist');
    //At this point 'relaxed' does not work as expected, so we use 'fast' or empty
    url.searchParams.set('forceWorkerId', this._supernetType === 'fast' ? 'fast' : '');
    let params;
    // In Node.js, ws package is used, so we need to set the auth header
    if (isNodejs) {
      const headers: Record<string, string> = {
        'User-Agent': userAgent
      };
      // Node's HTTP client rejects `undefined` header values, so only send a token that exists.
      const token = this._auth?.token;
      if (token) {
        headers.Authorization = token;
      }
      params = { headers };
    }
    this.socket = new WebSocket(url.toString(), params);
    this.socket.onerror = this.handleError.bind(this);
    this.socket.onmessage = this.handleMessage.bind(this);
    this.socket.onopen = this.handleOpen.bind(this);
    this.socket.onclose = this.handleClose.bind(this);
    this.startPing(this.socket);
  }

  /**
   * Close the current socket, if any.
   *
   * `this.socket` is cleared before closing. `handleClose` then sees a foreign target and does
   * not emit `disconnected` for a manual disconnect.
   */
  disconnect() {
    if (!this.socket) {
      return;
    }
    const socket = this.socket;
    this.socket = null;
    socket.onerror = null;
    socket.onmessage = null;
    socket.onopen = null;
    this.stopPing();
    socket.close();
  }

  /** Start the Node.js keep-alive ping for `socket`, replacing any previous interval. */
  private startPing(socket: WebSocket) {
    if (!isNodejs) {
      return;
    }
    // Never leave an older interval running alongside the new one.
    this.stopPing();
    this._pingInterval = setInterval(() => {
      // `ws` throws when pinging a socket that is still connecting; an exception thrown in a
      // timer callback is uncaught, so skip the tick instead.
      if (socket.readyState === WebSocket.OPEN) {
        socket.ping();
      }
    }, PING_INTERVAL);
  }

  /** Clear the keep-alive ping interval, if one is running. */
  private stopPing() {
    if (this._pingInterval) {
      clearInterval(this._pingInterval);
      this._pingInterval = null;
    }
  }

  /** Switch to another supernet and reconnect immediately. */
  switchNetwork(supernetType: SupernetType) {
    this._supernetType = supernetType;
    this.disconnect();
    this.connect();
  }

  /**
   * Ensure the WebSocket connection is open, waiting if necessary and throwing an error if it fails.
   *
   * While the socket is connecting, it polls once per second for up to 10 attempts. On timeout,
   * or if the socket ends up in any state other than OPEN, it disconnects and throws.
   * @private
   */
  private async waitForConnection(): Promise<void> {
    if (!this.socket) {
      throw new Error('WebSocket not connected');
    }
    if (this.socket.readyState === WebSocket.OPEN) {
      return;
    }
    let attempts = 10;
    while (this.socket?.readyState === WebSocket.CONNECTING) {
      this._logger.info('Waiting for WebSocket connection...');
      await new Promise((resolve) => setTimeout(resolve, 1000));
      attempts--;
      if (attempts === 0) {
        this.disconnect();
        throw new Error('WebSocket connection timeout');
      }
    }
    //@ts-expect-error State may change between checks
    if (this.socket?.readyState !== WebSocket.OPEN) {
      this.disconnect();
      throw new Error('WebSocket connection failed');
    }
  }

  /** Re-emit the socket opening as `connected` with the current network. */
  private handleOpen() {
    this.emit('connected', { network: this._supernetType });
  }

  /**
   * Handle a close event. It only acts when the closed socket is still the current one, so that
   * manual disconnects and replaced sockets stay silent.
   */
  private handleClose(e: CloseEvent) {
    if (e.target === this.socket) {
      this._logger.info('WebSocket disconnected, cleanup', e);
      this.disconnect();
      this.emit('disconnected', {
        code: e.code,
        reason: e.reason
      });
    }
  }

  /** Log socket errors; closing is handled by `handleClose`. */
  private handleError(e: ErrorEvent) {
    this._logger.error('WebSocket error:', e);
  }

  /**
   * Decode an incoming frame and emit it as `data.type`.
   *
   * Frames are JSON envelopes `{ type, data }`, where `data` (optional) is base64-encoded JSON.
   * `jobID`/`imgID` string fields of the payload are upper-cased for consistency.
   */
  private handleMessage(e: MessageEvent) {
    let dataPromise: Promise<string>;
    if (typeof e.data === 'string') {
      // Text frames arrive as strings in both Node.js and the browser.
      dataPromise = Promise.resolve(e.data);
    } else if (isNodejs) {
      // In Node.js, binary frames arrive as a Buffer.
      dataPromise = Promise.resolve(e.data.toString());
    } else {
      // In the browser, binary frames arrive as a Blob.
      const data = e.data as unknown as Blob;
      dataPromise = data.text();
    }
    dataPromise
      .then((str: string) => {
        const data = JSON.parse(str);
        let payload = null;
        if (data.data) {
          payload = JSON.parse(base64Decode(data.data));
        }
        // Messages may carry no payload at all; only normalize ids when there is one.
        if (payload) {
          for (const idKey of ['jobID', 'imgID']) {
            if (typeof payload[idKey] === 'string') {
              payload[idKey] = payload[idKey].toUpperCase();
            }
          }
        }
        this._logger.debug('WebSocket:', data.type, payload);
        this.emit(data.type, payload);
      })
      .catch((err: unknown) => {
        this._logger.error('Failed to parse WebSocket message:', err);
      });
  }

  /**
   * Send a typed message once the socket is open.
   * The payload is JSON-encoded, base64-encoded and wrapped in a `{ type, data }` envelope.
   * @throws Error from `waitForConnection` if the socket is missing or fails to open.
   */
  async send<T extends MessageType>(messageType: T, data: SocketMessageMap[T]) {
    await this.waitForConnection();
    this._logger.debug('WebSocket send:', messageType, data);
    this.socket!.send(
      JSON.stringify({ type: messageType, data: base64Encode(JSON.stringify(data)) })
    );
  }
}

export default WebSocketClient;
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
/**
 * Network the client connects to.
 * Note: WebSocketClient currently forwards only 'fast' as `forceWorkerId`; 'relaxed' is sent as
 * an empty value.
 */
export type SupernetType = 'relaxed' | 'fast';
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import ErrorCode from './WebSocketClient/ErrorCode';
|
|
2
|
+
import { SupernetType } from './WebSocketClient/types';
|
|
3
|
+
|
|
4
|
+
/**
 * Events emitted by `ApiClient`, keyed by event name; each value is the payload type.
 * Kept as a type alias (not an interface) so it satisfies index-signature event-map constraints.
 */
export type ApiClientEvents = {
  /**
   * @event ApiClient#connected - The client has connected to the server.
   */
  connected: { network: SupernetType };
  /**
   * @event ApiClient#disconnected - The client has been disconnected by the server: either
   * authentication was lost or the server is unreachable. This event is not triggered when
   * the client disconnects manually.
   */
  disconnected: { code: ErrorCode; reason: string };
};
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import RestClient from '../lib/RestClient';
|
|
2
|
+
import WebSocketClient from './WebSocketClient';
|
|
3
|
+
import { jwtDecode } from 'jwt-decode';
|
|
4
|
+
import TypedEventEmitter from '../lib/TypedEventEmitter';
|
|
5
|
+
import { ApiClientEvents } from './events';
|
|
6
|
+
import { ServerConnectData, ServerDisconnectData } from './WebSocketClient/events';
|
|
7
|
+
import { isNotRecoverable } from './WebSocketClient/ErrorCode';
|
|
8
|
+
import { JSONValue } from '../types/json';
|
|
9
|
+
import { SupernetType } from './WebSocketClient/types';
|
|
10
|
+
import { Logger } from '../lib/DefaultLogger';
|
|
11
|
+
|
|
12
|
+
/** Automatic reconnection attempts after a recoverable socket disconnect, before giving up. */
const WS_RECONNECT_ATTEMPTS = 5;

/** Envelope of a successful API response carrying payload `data`. */
export interface ApiReponse<D = JSONValue> {
  status: 'success';
  data: D;
}

/** Envelope of a failed API response. */
/** @inline */
export interface ApiErrorResponse {
  status: 'error';
  /** Human-readable error message (also used as the `ApiError` message). */
  message: string;
  /** Numeric application error code. */
  errorCode: number;
}
|
|
25
|
+
|
|
26
|
+
/**
 * Error raised for an API error response.
 * The message is taken from the payload; the full payload and status are kept for callers.
 */
export class ApiError extends Error {
  /** Status code of the failed response (presumably the HTTP status — confirm with callers). */
  status: number;
  /** Error envelope returned by the server. */
  payload: ApiErrorResponse;

  constructor(status: number, payload: ApiErrorResponse) {
    super(payload.message);
    // Restore the prototype chain so `instanceof ApiError` also works when compiled to ES5.
    Object.setPrototypeOf(this, new.target.prototype);
    this.name = 'ApiError';
    this.status = status;
    this.payload = payload;
  }
}
|
|
35
|
+
|
|
36
|
+
/**
 * Authentication state decoded from the JWT passed to `ApiClient.authenticate`.
 * @inline
 */
interface AuthData {
  /** Raw JWT. */
  token: string;
  /** Wallet address from the token's `addr` claim. */
  walletAddress: string;
  /** Expiry computed from the token's `exp` claim. */
  expiresAt: Date;
}
|
|
44
|
+
|
|
45
|
+
/**
 * Top-level API client. It owns the REST and WebSocket clients and keeps the decoded auth token.
 * It re-emits socket connection events, retrying recoverable disconnects before reporting them.
 */
class ApiClient extends TypedEventEmitter<ApiClientEvents> {
  readonly appId: string;
  readonly logger: Logger;
  private _rest: RestClient;
  private _socket: WebSocketClient;
  private _auth: AuthData | null = null;
  private _reconnectAttempts = WS_RECONNECT_ATTEMPTS;

  /**
   * @param baseUrl - REST API base URL.
   * @param socketUrl - WebSocket server URL.
   * @param appId - Application id sent to the socket server.
   * @param networkType - Initial supernet.
   * @param logger - Logger shared with the underlying clients.
   */
  constructor(
    baseUrl: string,
    socketUrl: string,
    appId: string,
    networkType: SupernetType,
    logger: Logger
  ) {
    super();
    this.appId = appId;
    this.logger = logger;
    this._rest = new RestClient(baseUrl, logger);
    this._socket = new WebSocketClient(socketUrl, appId, networkType, logger);
    this._socket.on('connected', this.handleSocketConnect.bind(this));
    this._socket.on('disconnected', this.handleSocketDisconnect.bind(this));
  }

  /** `true` when a token is stored and has not expired yet. */
  get isAuthenticated(): boolean {
    return !!this._auth && this._auth.expiresAt > new Date();
  }

  /** Current auth data, or `null` when missing or expired. */
  get auth(): AuthData | null {
    return this._auth && this._auth.expiresAt > new Date() ? this._auth : null;
  }

  get socket(): WebSocketClient {
    return this._socket;
  }

  get rest(): RestClient {
    return this._rest;
  }

  /**
   * Decode `token` (a JWT with `addr` and `exp` claims), hand it to the REST and socket clients,
   * and open the socket.
   * @throws if `jwtDecode` cannot decode the token.
   */
  authenticate(token: string) {
    const decoded = jwtDecode<{ addr: string; env: string; iat: number; exp: number }>(token);
    this._auth = {
      token,
      walletAddress: decoded.addr,
      expiresAt: new Date(decoded.exp * 1000)
    };
    this.rest.auth = { token };
    this.socket.auth = { token };
    this.socket.connect();
  }

  /**
   * Forget the credentials and close the socket.
   *
   * The token is also cleared from the REST and socket clients (which removes the browser auth
   * cookie). Later REST calls or reconnects then cannot keep using the old token.
   */
  removeAuth() {
    this._auth = null;
    this.rest.auth = null;
    this.socket.auth = null;
    this.socket.disconnect();
  }

  /** Reset the retry budget on a successful connection and re-emit `connected`. */
  handleSocketConnect({ network }: ServerConnectData) {
    this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
    this.emit('connected', { network });
  }

  /**
   * Handle a server-side disconnect.
   *
   * - A missing or non-recoverable code drops the credentials and emits `disconnected`.
   * - A recoverable code retries after 1 s, up to `WS_RECONNECT_ATTEMPTS` times.
   * - Once the retries are used up, it emits `disconnected` and resets the budget.
   */
  handleSocketDisconnect(data: ServerDisconnectData) {
    if (!data.code || isNotRecoverable(data.code)) {
      this.removeAuth();
      this.emit('disconnected', data);
      this.logger.error('Not recoverable socket error', data);
      return;
    }
    if (this._reconnectAttempts <= 0) {
      this.emit('disconnected', data);
      this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
      return;
    }
    this._reconnectAttempts--;
    setTimeout(() => {
      // Credentials may have been removed during the delay; don't reconnect a logged-out client.
      if (this._auth) {
        this.socket.connect();
      }
    }, 1000);
  }
}

export default ApiClient;
|
package/src/ApiGroup.ts
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { AbstractProvider } from 'ethers';
|
|
2
|
+
import ApiClient from './ApiClient';
|
|
3
|
+
import EIP712Helper from './lib/EIP712Helper';
|
|
4
|
+
import TypedEventEmitter, { EventMap } from './lib/TypedEventEmitter';
|
|
5
|
+
|
|
6
|
+
/** Dependencies shared by every API group. */
export interface ApiConfig {
  client: ApiClient;
  provider: AbstractProvider;
  eip712: EIP712Helper;
}

/**
 * Base class for API groups. It exposes the shared API client, ethers provider and EIP-712
 * helper to subclasses, and makes each group a typed event emitter over `E`.
 */
abstract class ApiGroup<E extends EventMap = {}> extends TypedEventEmitter<E> {
  protected client: ApiClient;
  protected provider: AbstractProvider;
  protected eip712: EIP712Helper;

  constructor(config: ApiConfig) {
    super();
    const { client, provider, eip712 } = config;
    this.client = client;
    this.provider = provider;
    this.eip712 = eip712;
  }
}

export default ApiGroup;
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
2
|
+
import ErrorData from '../types/ErrorData';
|
|
3
|
+
|
|
4
|
+
/** Lifecycle state of a single job. */
export type JobStatus = 'pending' | 'initiating' | 'processing' | 'completed' | 'failed' | 'canceled';

/**
 * Raw job state, as held by `Job`.
 * @inline
 */
export interface JobData {
  id: string;
  status: JobStatus;
  /** Steps completed so far. */
  step: number;
  /** Total steps the worker will perform. */
  stepCount: number;
  /** Seed used for generation; available once the job is completed. */
  seed?: number;
  isNSFW?: boolean;
  /** Presumably set when the user canceled the job — verify against the job update handler. */
  userCanceled?: boolean;
  /** Latest preview image URL. */
  previewUrl?: string;
  /** Result image URL; may be null if canceled or NSFW-filtered. */
  resultUrl?: string | null;
  error?: ErrorData;
}

/** Events emitted by `Job` in addition to the base entity events. */
export interface JobEventMap extends EntityEvents {
  /** Progress percentage (0-100). */
  progress: number;
  /** Result URL of the completed job. */
  completed: string;
  failed: ErrorData;
}
|
|
33
|
+
|
|
34
|
+
/**
 * A single image-generation job.
 * Wraps `JobData` and turns data updates into `progress`, `completed` and `failed` events.
 */
class Job extends DataEntity<JobData, JobEventMap> {
  constructor(data: JobData) {
    super(data);
    this.on('updated', this.handleUpdated.bind(this));
  }

  get id() {
    return this.data.id;
  }

  /**
   * Current status of the job.
   */
  get status() {
    return this.data.status;
  }

  /**
   * Progress of the job in percentage (0-100).
   * Returns 0 while `stepCount` is 0 or unknown, instead of NaN/Infinity.
   */
  get progress() {
    const { step, stepCount } = this.data;
    if (!stepCount) {
      return 0;
    }
    return Math.round((step / stepCount) * 100);
  }

  /**
   * Current step of the job.
   */
  get step() {
    return this.data.step;
  }

  /**
   * Total number of steps that worker will perform.
   */
  get stepCount() {
    return this.data.stepCount;
  }

  /**
   * Seed used to generate the image. This property is only available when the job is completed.
   */
  get seed() {
    return this.data.seed;
  }

  /**
   * Last preview image URL generated by the worker.
   */
  get previewUrl() {
    return this.data.previewUrl;
  }

  /**
   * URL to the result image. It could be null if the job was canceled, or if it triggered the
   * NSFW filter while the filter was not disabled explicitly.
   */
  get resultUrl() {
    return this.data.resultUrl;
  }

  /** Result URL when available, otherwise the latest preview URL. */
  get imageUrl() {
    return this.data.resultUrl || this.data.previewUrl;
  }

  get error() {
    return this.data.error;
  }

  /**
   * Whether the image is NSFW or not. Only makes sense if job is completed.
   * If NSFW filter is disabled, this property will always be false.
   * If NSFW filter is enabled and the image is NSFW, image will not be available for download.
   */
  get isNSFW() {
    return !!this.data.isNSFW;
  }

  /** Translate changed data keys into job events. */
  private handleUpdated(keys: string[]) {
    if (keys.includes('step') || keys.includes('stepCount')) {
      this.emit('progress', this.progress);
    }
    if (keys.includes('status') && this.status === 'completed') {
      // NOTE(review): resultUrl may be null (canceled / NSFW-filtered) although the event is typed as string.
      this.emit('completed', this.resultUrl!);
    }
    if (keys.includes('status') && this.status === 'failed') {
      this.emit('failed', this.data.error!);
    }
  }
}

export default Job;
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import Job, { JobData } from './Job';
|
|
2
|
+
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
3
|
+
import { ProjectParams } from './types';
|
|
4
|
+
import cloneDeep from 'lodash/cloneDeep';
|
|
5
|
+
import ErrorData from '../types/ErrorData';
|
|
6
|
+
import getUUID from '../lib/getUUID';
|
|
7
|
+
|
|
8
|
+
/** Lifecycle state of a project. */
export type ProjectStatus =
  | 'pending'
  | 'queued'
  | 'processing'
  | 'completed'
  | 'failed';

/**
 * Raw project state, as held by `Project`.
 * @inline
 */
export interface ProjectData {
  id: string;
  /** Creation time of the local project object. */
  startedAt: Date;
  /** Parameters the project was created with. */
  params: ProjectParams;
  /** Position in the queue; -1 until known. */
  queuePosition: number;
  status: ProjectStatus;
  error?: ErrorData;
}

/** Project snapshot including its jobs, as returned by `Project.toJSON()`. */
/** @inline */
export interface SerializedProject extends ProjectData {
  jobs: JobData[];
}

/** Events emitted by `Project` in addition to the base entity events. */
export interface ProjectEventMap extends EntityEvents {
  /** Aggregate progress percentage (0-100). */
  progress: number;
  /** Result URLs of all completed jobs. */
  completed: string[];
  failed: ErrorData;
  jobCompleted: Job;
  jobFailed: Job;
}
|
|
33
|
+
|
|
34
|
+
/**
 * A generation project: the submitted parameters plus the jobs reported for it.
 * It emits aggregated `progress`, `completed` and `failed` events, and forwards job-level
 * completion and failure as `jobCompleted` / `jobFailed`.
 */
class Project extends DataEntity<ProjectData, ProjectEventMap> {
  private _jobs: Job[] = [];
  private _lastEmitedProgress = -1;

  /**
   * @param data - Project parameters. A fresh id is generated, and the project starts in
   * `pending` status with queue position -1.
   */
  constructor(data: ProjectParams) {
    super({
      id: getUUID(),
      startedAt: new Date(),
      params: data,
      queuePosition: -1,
      status: 'pending'
    });

    this.on('updated', this.handleUpdated.bind(this));
  }

  get id() {
    return this.data.id;
  }

  get params() {
    return this.data.params;
  }

  get status() {
    return this.data.status;
  }

  get error() {
    return this.data.error;
  }

  /**
   * Progress of the project in percentage (0-100).
   * Returns 0 while the total step count is 0 or unknown. A NaN result would never equal the
   * last emitted value, so it would be re-emitted on every update.
   */
  get progress() {
    // Worker can reduce the number of steps in the job, so we need to calculate the progress based on the actual number of steps
    const stepsPerJob = this._jobs.length ? this._jobs[0].stepCount : this.data.params.steps;
    const totalSteps = stepsPerJob * this.data.params.numberOfImages;
    if (!totalSteps) {
      return 0;
    }
    const stepsDone = this._jobs.reduce((acc, job) => acc + job.step, 0);
    return Math.round((stepsDone / totalSteps) * 100);
  }

  get queuePosition() {
    return this.data.queuePosition;
  }

  /**
   * List of jobs in the project. Note that jobs will be added to this list as
   * workers start processing them. So initially this list will be empty.
   * Subscribe to project `updated` event to get notified about any update, including new jobs.
   * @example
   * project.on('updated', (keys) => {
   *   if (keys.includes('jobs')) {
   *     // Project jobs have been updated
   *   }
   * });
   */
  get jobs() {
    return this._jobs.slice(0);
  }

  /**
   * List of result URLs for all completed jobs in the project.
   */
  get resultUrls() {
    return this._jobs
      .map((job) => job.resultUrl)
      .filter((url): url is string => !!url);
  }

  /**
   * Wait for the project to complete, then return the result URLs, or throw an error if the project fails.
   * NOTE(review): the listeners registered here are never removed; repeated calls accumulate them.
   * @returns Promise<string[]> - Promise that resolves to the list of result URLs
   * @throws ErrorData
   */
  waitForCompletion(): Promise<string[]> {
    if (this.status === 'completed') {
      return Promise.resolve(this.resultUrls);
    }
    if (this.status === 'failed') {
      return Promise.reject(this.error);
    }

    return new Promise((resolve, reject) => {
      this.on('completed', (images) => {
        resolve(images);
      });
      this.on('failed', (error) => {
        reject(error);
      });
    });
  }

  /**
   * Find a job by id
   * @param id
   */
  job(id: string) {
    return this._jobs.find((job) => job.id === id);
  }

  /** Emit progress changes, and emit completion or failure once the status allows it. */
  private handleUpdated(keys: string[]) {
    const progress = this.progress;
    if (progress !== this._lastEmitedProgress) {
      this.emit('progress', progress);
      this._lastEmitedProgress = progress;
    }
    if (!keys.includes('status') && !keys.includes('jobs')) {
      return;
    }
    const allJobsDone = this._jobs.every((job) =>
      ['completed', 'failed', 'canceled'].includes(job.status)
    );
    if (this.data.status === 'completed' && allJobsDone) {
      this.emit('completed', this.resultUrls);
      return;
    }
    if (this.data.status === 'failed') {
      this.emit('failed', this.data.error!);
    }
  }

  /**
   * This is internal method to add a job to the project. Do not call this directly.
   * Job updates are re-emitted as project `updated` events with key `jobs`.
   * @internal
   * @param data
   */
  _addJob(data: JobData) {
    const job = new Job(data);
    this._jobs.push(job);
    job.on('updated', () => {
      this.emit('updated', ['jobs']);
    });
    job.on('completed', () => {
      this.emit('jobCompleted', job);
    });
    job.on('failed', () => {
      this.emit('jobFailed', job);
    });
    return job;
  }

  /**
   * Get full project data snapshot. Can be used to serialize the project and store it in a database.
   */
  toJSON(): SerializedProject {
    const data = cloneDeep(this.data);
    return {
      ...data,
      jobs: this._jobs.map((job) => job.toJSON())
    };
  }
}

export default Project;
|