@gradio/client 0.19.4 → 0.20.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +14 -0
- package/README.md +8 -1
- package/dist/client.d.ts +4 -0
- package/dist/client.d.ts.map +1 -1
- package/dist/constants.d.ts +4 -0
- package/dist/constants.d.ts.map +1 -1
- package/dist/helpers/api_info.d.ts +1 -0
- package/dist/helpers/api_info.d.ts.map +1 -1
- package/dist/helpers/data.d.ts.map +1 -1
- package/dist/helpers/init_helpers.d.ts +4 -1
- package/dist/helpers/init_helpers.d.ts.map +1 -1
- package/dist/index.js +229 -61
- package/dist/test/handlers.d.ts.map +1 -1
- package/dist/test/test_data.d.ts.map +1 -1
- package/dist/types.d.ts +6 -0
- package/dist/types.d.ts.map +1 -1
- package/dist/utils/duplicate.d.ts.map +1 -1
- package/dist/utils/post_data.d.ts.map +1 -1
- package/dist/utils/predict.d.ts.map +1 -1
- package/dist/utils/submit.d.ts.map +1 -1
- package/dist/utils/upload_files.d.ts.map +1 -1
- package/dist/utils/view_api.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/client.ts +70 -28
- package/src/constants.ts +5 -0
- package/src/helpers/api_info.ts +44 -17
- package/src/helpers/data.ts +3 -2
- package/src/helpers/init_helpers.ts +98 -9
- package/src/test/api_info.test.ts +69 -4
- package/src/test/data.test.ts +4 -4
- package/src/test/handlers.ts +249 -2
- package/src/test/init.test.ts +2 -2
- package/src/test/init_helpers.test.ts +53 -1
- package/src/test/test_data.ts +3 -0
- package/src/types.ts +6 -0
- package/src/utils/duplicate.ts +27 -2
- package/src/utils/post_data.ts +2 -1
- package/src/utils/predict.ts +4 -2
- package/src/utils/submit.ts +37 -8
- package/src/utils/upload_files.ts +2 -1
- package/src/utils/view_api.ts +7 -4
@@ -1 +1 @@
|
|
1
|
-
{"version":3,"file":"view_api.d.ts","sourceRoot":"","sources":["../../src/utils/view_api.ts"],"names":[],"mappings":"AAGA,OAAO,EAAE,MAAM,EAAE,MAAM,WAAW,CAAC;AAInC,wBAAsB,QAAQ,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,GAAG,CAAC,
|
1
|
+
{"version":3,"file":"view_api.d.ts","sourceRoot":"","sources":["../../src/utils/view_api.ts"],"names":[],"mappings":"AAGA,OAAO,EAAE,MAAM,EAAE,MAAM,WAAW,CAAC;AAInC,wBAAsB,QAAQ,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,GAAG,CAAC,CA8DzD"}
|
package/package.json
CHANGED
package/src/client.ts
CHANGED
@@ -23,8 +23,10 @@ import { submit } from "./utils/submit";
|
|
23
23
|
import { RE_SPACE_NAME, process_endpoint } from "./helpers/api_info";
|
24
24
|
import {
|
25
25
|
map_names_to_ids,
|
26
|
+
resolve_cookies,
|
26
27
|
resolve_config,
|
27
|
-
get_jwt
|
28
|
+
get_jwt,
|
29
|
+
parse_and_set_cookies
|
28
30
|
} from "./helpers/init_helpers";
|
29
31
|
import { check_space_status } from "./helpers/spaces";
|
30
32
|
import { open_stream } from "./utils/stream";
|
@@ -47,6 +49,8 @@ export class Client {
|
|
47
49
|
jwt: string | false = false;
|
48
50
|
last_status: Record<string, Status["stage"]> = {};
|
49
51
|
|
52
|
+
private cookies: string | null = null;
|
53
|
+
|
50
54
|
// streaming
|
51
55
|
stream_status = { open: false };
|
52
56
|
pending_stream_messages: Record<string, any[][]> = {};
|
@@ -56,7 +60,12 @@ export class Client {
|
|
56
60
|
heartbeat_event: EventSource | null = null;
|
57
61
|
|
58
62
|
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
59
|
-
|
63
|
+
const headers = new Headers(init?.headers || {});
|
64
|
+
if (this && this.cookies) {
|
65
|
+
headers.append("Cookie", this.cookies);
|
66
|
+
}
|
67
|
+
|
68
|
+
return fetch(input, { ...init, headers });
|
60
69
|
}
|
61
70
|
|
62
71
|
async stream(url: URL): Promise<EventSource> {
|
@@ -108,6 +117,7 @@ export class Client {
|
|
108
117
|
) => Promise<SubmitReturn>;
|
109
118
|
open_stream: () => Promise<void>;
|
110
119
|
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
|
120
|
+
private resolve_cookies: () => Promise<void>;
|
111
121
|
constructor(app_reference: string, options: ClientOptions = {}) {
|
112
122
|
this.app_reference = app_reference;
|
113
123
|
this.options = options;
|
@@ -120,6 +130,7 @@ export class Client {
|
|
120
130
|
this.predict = predict.bind(this);
|
121
131
|
this.open_stream = open_stream.bind(this);
|
122
132
|
this.resolve_config = resolve_config.bind(this);
|
133
|
+
this.resolve_cookies = resolve_cookies.bind(this);
|
123
134
|
this.upload = upload.bind(this);
|
124
135
|
}
|
125
136
|
|
@@ -135,33 +146,56 @@ export class Client {
|
|
135
146
|
}
|
136
147
|
|
137
148
|
try {
|
138
|
-
|
139
|
-
this.
|
140
|
-
|
141
|
-
if (config.space_id && this.options.hf_token) {
|
142
|
-
this.jwt = await get_jwt(config.space_id, this.options.hf_token);
|
143
|
-
}
|
149
|
+
if (this.options.auth) {
|
150
|
+
await this.resolve_cookies();
|
151
|
+
}
|
144
152
|
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
153
|
+
await this._resolve_config().then(({ config }) =>
|
154
|
+
this._resolve_hearbeat(config)
|
155
|
+
);
|
156
|
+
} catch (e: any) {
|
157
|
+
throw Error(e);
|
158
|
+
}
|
150
159
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
}
|
160
|
+
this.api_info = await this.view_api();
|
161
|
+
this.api_map = map_names_to_ids(this.config?.dependencies || []);
|
162
|
+
}
|
155
163
|
|
156
|
-
|
164
|
+
async _resolve_hearbeat(_config: Config): Promise<void> {
|
165
|
+
if (_config) {
|
166
|
+
this.config = _config;
|
167
|
+
if (this.config && this.config.connect_heartbeat) {
|
168
|
+
if (this.config.space_id && this.options.hf_token) {
|
169
|
+
this.jwt = await get_jwt(
|
170
|
+
this.config.space_id,
|
171
|
+
this.options.hf_token,
|
172
|
+
this.cookies
|
173
|
+
);
|
157
174
|
}
|
158
|
-
}
|
159
|
-
} catch (e) {
|
160
|
-
throw Error(CONFIG_ERROR_MSG + (e as Error).message);
|
175
|
+
}
|
161
176
|
}
|
162
177
|
|
163
|
-
|
164
|
-
|
178
|
+
if (_config.space_id && this.options.hf_token) {
|
179
|
+
this.jwt = await get_jwt(_config.space_id, this.options.hf_token);
|
180
|
+
}
|
181
|
+
|
182
|
+
if (this.config && this.config.connect_heartbeat) {
|
183
|
+
// connect to the heartbeat endpoint via GET request
|
184
|
+
const heartbeat_url = new URL(
|
185
|
+
`${this.config.root}/heartbeat/${this.session_hash}`
|
186
|
+
);
|
187
|
+
|
188
|
+
// if the jwt is available, add it to the query params
|
189
|
+
if (this.jwt) {
|
190
|
+
heartbeat_url.searchParams.set("__sign", this.jwt);
|
191
|
+
}
|
192
|
+
|
193
|
+
// Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
|
194
|
+
if (!this.heartbeat_event)
|
195
|
+
this.heartbeat_event = await this.stream(heartbeat_url);
|
196
|
+
} else {
|
197
|
+
this.heartbeat_event?.close();
|
198
|
+
}
|
165
199
|
}
|
166
200
|
|
167
201
|
static async connect(
|
@@ -201,9 +235,8 @@ export class Client {
|
|
201
235
|
}
|
202
236
|
|
203
237
|
return this.config_success(config);
|
204
|
-
} catch (e) {
|
205
|
-
|
206
|
-
if (space_id) {
|
238
|
+
} catch (e: any) {
|
239
|
+
if (space_id && status_callback) {
|
207
240
|
check_space_status(
|
208
241
|
space_id,
|
209
242
|
RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
|
@@ -217,6 +250,7 @@ export class Client {
|
|
217
250
|
load_status: "error",
|
218
251
|
detail: "NOT_FOUND"
|
219
252
|
});
|
253
|
+
throw Error(e);
|
220
254
|
}
|
221
255
|
}
|
222
256
|
}
|
@@ -246,6 +280,9 @@ export class Client {
|
|
246
280
|
}
|
247
281
|
|
248
282
|
async handle_space_success(status: SpaceStatus): Promise<Config | void> {
|
283
|
+
if (!this) {
|
284
|
+
throw new Error(CONFIG_ERROR_MSG);
|
285
|
+
}
|
249
286
|
const { status_callback } = this.options;
|
250
287
|
if (status_callback) status_callback(status);
|
251
288
|
if (status.status === "running") {
|
@@ -259,7 +296,6 @@ export class Client {
|
|
259
296
|
|
260
297
|
return _config as Config;
|
261
298
|
} catch (e) {
|
262
|
-
console.error(e);
|
263
299
|
if (status_callback) {
|
264
300
|
status_callback({
|
265
301
|
status: "error",
|
@@ -268,6 +304,7 @@ export class Client {
|
|
268
304
|
detail: "NOT_FOUND"
|
269
305
|
});
|
270
306
|
}
|
307
|
+
throw e;
|
271
308
|
}
|
272
309
|
}
|
273
310
|
}
|
@@ -333,7 +370,8 @@ export class Client {
|
|
333
370
|
const response = await this.fetch(`${root_url}/component_server/`, {
|
334
371
|
method: "POST",
|
335
372
|
body: body,
|
336
|
-
headers
|
373
|
+
headers,
|
374
|
+
credentials: "include"
|
337
375
|
});
|
338
376
|
|
339
377
|
if (!response.ok) {
|
@@ -349,6 +387,10 @@ export class Client {
|
|
349
387
|
}
|
350
388
|
}
|
351
389
|
|
390
|
+
public set_cookies(raw_cookies: string): void {
|
391
|
+
this.cookies = parse_and_set_cookies(raw_cookies).join("; ");
|
392
|
+
}
|
393
|
+
|
352
394
|
private prepare_return_obj(): client_return {
|
353
395
|
return {
|
354
396
|
config: this.config,
|
package/src/constants.ts
CHANGED
@@ -25,3 +25,8 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. ";
|
|
25
25
|
export const SPACE_STATUS_ERROR_MSG = "Could not get space status. ";
|
26
26
|
export const API_INFO_ERROR_MSG = "Could not get API info. ";
|
27
27
|
export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. ";
|
28
|
+
export const INVALID_URL_MSG = "Invalid URL. A full URL path is required.";
|
29
|
+
export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
|
30
|
+
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
|
31
|
+
export const MISSING_CREDENTIALS_MSG =
|
32
|
+
"Login credentials are required to access this space.";
|
package/src/helpers/api_info.ts
CHANGED
@@ -1,9 +1,14 @@
|
|
1
1
|
import type { Status } from "../types";
|
2
|
-
import {
|
2
|
+
import {
|
3
|
+
HOST_URL,
|
4
|
+
INVALID_URL_MSG,
|
5
|
+
QUEUE_FULL_MSG,
|
6
|
+
SPACE_METADATA_ERROR_MSG
|
7
|
+
} from "../constants";
|
3
8
|
import type { ApiData, ApiInfo, Config, JsApiData } from "../types";
|
4
9
|
import { determine_protocol } from "./init_helpers";
|
5
10
|
|
6
|
-
export const RE_SPACE_NAME = /^[
|
11
|
+
export const RE_SPACE_NAME = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/;
|
7
12
|
export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/;
|
8
13
|
|
9
14
|
export async function process_endpoint(
|
@@ -20,12 +25,13 @@ export async function process_endpoint(
|
|
20
25
|
headers.Authorization = `Bearer ${hf_token}`;
|
21
26
|
}
|
22
27
|
|
23
|
-
const _app_reference = app_reference.trim();
|
28
|
+
const _app_reference = app_reference.trim().replace(/\/$/, "");
|
24
29
|
|
25
30
|
if (RE_SPACE_NAME.test(_app_reference)) {
|
31
|
+
// app_reference is a HF space name
|
26
32
|
try {
|
27
33
|
const res = await fetch(
|
28
|
-
`https://huggingface.co/api/spaces/${_app_reference}
|
34
|
+
`https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`,
|
29
35
|
{ headers }
|
30
36
|
);
|
31
37
|
|
@@ -36,13 +42,12 @@ export async function process_endpoint(
|
|
36
42
|
...determine_protocol(_host)
|
37
43
|
};
|
38
44
|
} catch (e) {
|
39
|
-
throw new Error(
|
40
|
-
"Space metadata could not be loaded. " + (e as Error).message
|
41
|
-
);
|
45
|
+
throw new Error(SPACE_METADATA_ERROR_MSG);
|
42
46
|
}
|
43
47
|
}
|
44
48
|
|
45
49
|
if (RE_SPACE_DOMAIN.test(_app_reference)) {
|
50
|
+
// app_reference is a direct HF space domain
|
46
51
|
const { ws_protocol, http_protocol, host } =
|
47
52
|
determine_protocol(_app_reference);
|
48
53
|
|
@@ -60,6 +65,18 @@ export async function process_endpoint(
|
|
60
65
|
};
|
61
66
|
}
|
62
67
|
|
68
|
+
export const join_urls = (...urls: string[]): string => {
|
69
|
+
try {
|
70
|
+
return urls.reduce((base_url: string, part: string) => {
|
71
|
+
base_url = base_url.replace(/\/+$/, "");
|
72
|
+
part = part.replace(/^\/+/, "");
|
73
|
+
return new URL(part, base_url + "/").toString();
|
74
|
+
});
|
75
|
+
} catch (e) {
|
76
|
+
throw new Error(INVALID_URL_MSG);
|
77
|
+
}
|
78
|
+
};
|
79
|
+
|
63
80
|
export function transform_api_info(
|
64
81
|
api_info: ApiInfo<ApiData>,
|
65
82
|
config: Config,
|
@@ -77,27 +94,30 @@ export function transform_api_info(
|
|
77
94
|
Object.entries(api_info[category]).forEach(
|
78
95
|
([endpoint, { parameters, returns }]) => {
|
79
96
|
const dependencyIndex =
|
80
|
-
config.dependencies.
|
97
|
+
config.dependencies.find(
|
81
98
|
(dep) =>
|
82
99
|
dep.api_name === endpoint ||
|
83
100
|
dep.api_name === endpoint.replace("/", "")
|
84
|
-
) ||
|
101
|
+
)?.id ||
|
85
102
|
api_map[endpoint.replace("/", "")] ||
|
86
103
|
-1;
|
87
104
|
|
88
105
|
const dependencyTypes =
|
89
106
|
dependencyIndex !== -1
|
90
|
-
? config.dependencies
|
107
|
+
? config.dependencies.find((dep) => dep.id == dependencyIndex)
|
108
|
+
?.types
|
91
109
|
: { continuous: false, generator: false };
|
92
110
|
|
93
111
|
if (
|
94
112
|
dependencyIndex !== -1 &&
|
95
|
-
config.dependencies
|
96
|
-
parameters.length
|
113
|
+
config.dependencies.find((dep) => dep.id == dependencyIndex)?.inputs
|
114
|
+
?.length !== parameters.length
|
97
115
|
) {
|
98
|
-
const components = config.dependencies
|
99
|
-
|
100
|
-
|
116
|
+
const components = config.dependencies
|
117
|
+
.find((dep) => dep.id == dependencyIndex)!
|
118
|
+
.inputs.map(
|
119
|
+
(input) => config.components.find((c) => c.id === input)?.type
|
120
|
+
);
|
101
121
|
|
102
122
|
try {
|
103
123
|
components.forEach((comp, idx) => {
|
@@ -115,7 +135,9 @@ export function transform_api_info(
|
|
115
135
|
parameters.splice(idx, 0, new_param);
|
116
136
|
}
|
117
137
|
});
|
118
|
-
} catch (e) {
|
138
|
+
} catch (e) {
|
139
|
+
console.error(e);
|
140
|
+
}
|
119
141
|
}
|
120
142
|
|
121
143
|
const transform_type = (
|
@@ -201,6 +223,7 @@ export function get_description(
|
|
201
223
|
return type?.description;
|
202
224
|
}
|
203
225
|
|
226
|
+
/* eslint-disable complexity */
|
204
227
|
export function handle_message(
|
205
228
|
data: any,
|
206
229
|
last_status: Status["stage"]
|
@@ -308,7 +331,10 @@ export function handle_message(
|
|
308
331
|
message: !data.success ? data.output.error : undefined,
|
309
332
|
stage: data.success ? "complete" : "error",
|
310
333
|
code: data.code,
|
311
|
-
progress_data: data.progress_data
|
334
|
+
progress_data: data.progress_data,
|
335
|
+
changed_state_ids: data.success
|
336
|
+
? data.output.changed_state_ids
|
337
|
+
: undefined
|
312
338
|
},
|
313
339
|
data: data.success ? data.output : null
|
314
340
|
};
|
@@ -330,6 +356,7 @@ export function handle_message(
|
|
330
356
|
|
331
357
|
return { type: "none", status: { stage: "error", queue } };
|
332
358
|
}
|
359
|
+
/* eslint-enable complexity */
|
333
360
|
|
334
361
|
/**
|
335
362
|
* Maps the provided `data` to the parameters defined by the `/info` endpoint response.
|
package/src/helpers/data.ts
CHANGED
@@ -96,8 +96,9 @@ export async function walk_and_store_blobs(
|
|
96
96
|
}
|
97
97
|
|
98
98
|
export function skip_queue(id: number, config: Config): boolean {
|
99
|
-
|
100
|
-
|
99
|
+
let fn_queue = config?.dependencies?.find((dep) => dep.id == id)?.queue;
|
100
|
+
if (fn_queue != null) {
|
101
|
+
return !fn_queue;
|
101
102
|
}
|
102
103
|
return !config.enable_queue;
|
103
104
|
}
|
@@ -1,6 +1,15 @@
|
|
1
1
|
import type { Config } from "../types";
|
2
|
-
import {
|
2
|
+
import {
|
3
|
+
CONFIG_ERROR_MSG,
|
4
|
+
CONFIG_URL,
|
5
|
+
INVALID_CREDENTIALS_MSG,
|
6
|
+
LOGIN_URL,
|
7
|
+
MISSING_CREDENTIALS_MSG,
|
8
|
+
SPACE_METADATA_ERROR_MSG,
|
9
|
+
UNAUTHORIZED_MSG
|
10
|
+
} from "../constants";
|
3
11
|
import { Client } from "..";
|
12
|
+
import { join_urls, process_endpoint } from "./api_info";
|
4
13
|
|
5
14
|
/**
|
6
15
|
* This function is used to resolve the URL for making requests when the app has a root path.
|
@@ -25,12 +34,14 @@ export function resolve_root(
|
|
25
34
|
|
26
35
|
export async function get_jwt(
|
27
36
|
space: string,
|
28
|
-
token: `hf_${string}
|
37
|
+
token: `hf_${string}`,
|
38
|
+
cookies?: string | null
|
29
39
|
): Promise<string | false> {
|
30
40
|
try {
|
31
41
|
const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, {
|
32
42
|
headers: {
|
33
|
-
Authorization: `Bearer ${token}
|
43
|
+
Authorization: `Bearer ${token}`,
|
44
|
+
...(cookies ? { Cookie: cookies } : {})
|
34
45
|
}
|
35
46
|
});
|
36
47
|
|
@@ -47,8 +58,8 @@ export function map_names_to_ids(
|
|
47
58
|
): Record<string, number> {
|
48
59
|
let apis: Record<string, number> = {};
|
49
60
|
|
50
|
-
fns.forEach(({ api_name
|
51
|
-
if (api_name) apis[api_name] =
|
61
|
+
fns.forEach(({ api_name, id }) => {
|
62
|
+
if (api_name) apis[api_name] = id;
|
52
63
|
});
|
53
64
|
return apis;
|
54
65
|
}
|
@@ -75,15 +86,24 @@ export async function resolve_config(
|
|
75
86
|
config.root = config_root;
|
76
87
|
return { ...config, path } as Config;
|
77
88
|
} else if (endpoint) {
|
78
|
-
const
|
79
|
-
|
89
|
+
const config_url = join_urls(endpoint, CONFIG_URL);
|
90
|
+
const response = await this.fetch(config_url, {
|
91
|
+
headers,
|
92
|
+
credentials: "include"
|
80
93
|
});
|
81
94
|
|
95
|
+
if (response?.status === 401 && !this.options.auth) {
|
96
|
+
throw new Error(MISSING_CREDENTIALS_MSG);
|
97
|
+
} else if (response?.status === 401 && this.options.auth) {
|
98
|
+
throw new Error(INVALID_CREDENTIALS_MSG);
|
99
|
+
}
|
82
100
|
if (response?.status === 200) {
|
83
101
|
let config = await response.json();
|
84
102
|
config.path = config.path ?? "";
|
85
103
|
config.root = endpoint;
|
86
104
|
return config;
|
105
|
+
} else if (response?.status === 401) {
|
106
|
+
throw new Error(UNAUTHORIZED_MSG);
|
87
107
|
}
|
88
108
|
throw new Error(CONFIG_ERROR_MSG);
|
89
109
|
}
|
@@ -91,13 +111,70 @@ export async function resolve_config(
|
|
91
111
|
throw new Error(CONFIG_ERROR_MSG);
|
92
112
|
}
|
93
113
|
|
114
|
+
export async function resolve_cookies(this: Client): Promise<void> {
|
115
|
+
const { http_protocol, host } = await process_endpoint(
|
116
|
+
this.app_reference,
|
117
|
+
this.options.hf_token
|
118
|
+
);
|
119
|
+
|
120
|
+
try {
|
121
|
+
if (this.options.auth) {
|
122
|
+
const cookie_header = await get_cookie_header(
|
123
|
+
http_protocol,
|
124
|
+
host,
|
125
|
+
this.options.auth,
|
126
|
+
this.fetch,
|
127
|
+
this.options.hf_token
|
128
|
+
);
|
129
|
+
|
130
|
+
if (cookie_header) this.set_cookies(cookie_header);
|
131
|
+
}
|
132
|
+
} catch (e: unknown) {
|
133
|
+
throw Error((e as Error).message);
|
134
|
+
}
|
135
|
+
}
|
136
|
+
|
137
|
+
// separating this from client-bound resolve_cookies so that it can be used in duplicate
|
138
|
+
export async function get_cookie_header(
|
139
|
+
http_protocol: string,
|
140
|
+
host: string,
|
141
|
+
auth: [string, string],
|
142
|
+
_fetch: typeof fetch,
|
143
|
+
hf_token?: `hf_${string}`
|
144
|
+
): Promise<string | null> {
|
145
|
+
const formData = new FormData();
|
146
|
+
formData.append("username", auth?.[0]);
|
147
|
+
formData.append("password", auth?.[1]);
|
148
|
+
|
149
|
+
let headers: { Authorization?: string } = {};
|
150
|
+
|
151
|
+
if (hf_token) {
|
152
|
+
headers.Authorization = `Bearer ${hf_token}`;
|
153
|
+
}
|
154
|
+
|
155
|
+
const res = await _fetch(`${http_protocol}//${host}/${LOGIN_URL}`, {
|
156
|
+
headers,
|
157
|
+
method: "POST",
|
158
|
+
body: formData,
|
159
|
+
credentials: "include"
|
160
|
+
});
|
161
|
+
|
162
|
+
if (res.status === 200) {
|
163
|
+
return res.headers.get("set-cookie");
|
164
|
+
} else if (res.status === 401) {
|
165
|
+
throw new Error(INVALID_CREDENTIALS_MSG);
|
166
|
+
} else {
|
167
|
+
throw new Error(SPACE_METADATA_ERROR_MSG);
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
94
171
|
export function determine_protocol(endpoint: string): {
|
95
172
|
ws_protocol: "ws" | "wss";
|
96
173
|
http_protocol: "http:" | "https:";
|
97
174
|
host: string;
|
98
175
|
} {
|
99
176
|
if (endpoint.startsWith("http")) {
|
100
|
-
const { protocol, host } = new URL(endpoint);
|
177
|
+
const { protocol, host, pathname } = new URL(endpoint);
|
101
178
|
|
102
179
|
if (host.endsWith("hf.space")) {
|
103
180
|
return {
|
@@ -109,7 +186,7 @@ export function determine_protocol(endpoint: string): {
|
|
109
186
|
return {
|
110
187
|
ws_protocol: protocol === "https:" ? "wss" : "ws",
|
111
188
|
http_protocol: protocol as "http:" | "https:",
|
112
|
-
host
|
189
|
+
host: host + (pathname !== "/" ? pathname : "")
|
113
190
|
};
|
114
191
|
} else if (endpoint.startsWith("file:")) {
|
115
192
|
// This case is only expected to be used for the Wasm mode (Gradio-lite),
|
@@ -128,3 +205,15 @@ export function determine_protocol(endpoint: string): {
|
|
128
205
|
host: endpoint
|
129
206
|
};
|
130
207
|
}
|
208
|
+
|
209
|
+
export const parse_and_set_cookies = (cookie_header: string): string[] => {
|
210
|
+
let cookies: string[] = [];
|
211
|
+
const parts = cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/);
|
212
|
+
parts.forEach((cookie) => {
|
213
|
+
const [cookie_name, cookie_value] = cookie.split(";")[0].split("=");
|
214
|
+
if (cookie_name && cookie_value) {
|
215
|
+
cookies.push(`${cookie_name.trim()}=${cookie_value.trim()}`);
|
216
|
+
}
|
217
|
+
});
|
218
|
+
return cookies;
|
219
|
+
};
|
@@ -1,16 +1,22 @@
|
|
1
|
-
import {
|
1
|
+
import {
|
2
|
+
INVALID_URL_MSG,
|
3
|
+
QUEUE_FULL_MSG,
|
4
|
+
SPACE_METADATA_ERROR_MSG
|
5
|
+
} from "../constants";
|
2
6
|
import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest";
|
3
7
|
import {
|
4
8
|
handle_message,
|
5
9
|
get_description,
|
6
10
|
get_type,
|
7
11
|
process_endpoint,
|
12
|
+
join_urls,
|
8
13
|
map_data_to_params
|
9
14
|
} from "../helpers/api_info";
|
10
15
|
import { initialise_server } from "./server";
|
11
16
|
import { transformed_api_info } from "./test_data";
|
12
17
|
|
13
18
|
const server = initialise_server();
|
19
|
+
const IS_NODE = process.env.TEST_MODE === "node";
|
14
20
|
|
15
21
|
beforeAll(() => server.listen());
|
16
22
|
afterEach(() => server.resetHandlers());
|
@@ -435,9 +441,7 @@ describe("process_endpoint", () => {
|
|
435
441
|
try {
|
436
442
|
await process_endpoint(app_reference, hf_token);
|
437
443
|
} catch (error) {
|
438
|
-
expect(error.message).toEqual(
|
439
|
-
SPACE_METADATA_ERROR_MSG + "Unexpected end of JSON input"
|
440
|
-
);
|
444
|
+
expect(error.message).toEqual(SPACE_METADATA_ERROR_MSG);
|
441
445
|
}
|
442
446
|
});
|
443
447
|
|
@@ -455,6 +459,67 @@ describe("process_endpoint", () => {
|
|
455
459
|
const result = await process_endpoint("hmb/hello_world");
|
456
460
|
expect(result).toEqual(expected);
|
457
461
|
});
|
462
|
+
|
463
|
+
it("processes local server URLs correctly", async () => {
|
464
|
+
const local_url = "http://localhost:7860/gradio";
|
465
|
+
const response_local_url = await process_endpoint(local_url);
|
466
|
+
expect(response_local_url.space_id).toBe(false);
|
467
|
+
expect(response_local_url.host).toBe("localhost:7860/gradio");
|
468
|
+
|
469
|
+
const local_url_2 = "http://localhost:7860/gradio/";
|
470
|
+
const response_local_url_2 = await process_endpoint(local_url_2);
|
471
|
+
expect(response_local_url_2.space_id).toBe(false);
|
472
|
+
expect(response_local_url_2.host).toBe("localhost:7860/gradio");
|
473
|
+
});
|
474
|
+
|
475
|
+
it("handles hugging face space references", async () => {
|
476
|
+
const space_id = "hmb/hello_world";
|
477
|
+
|
478
|
+
const response = await process_endpoint(space_id);
|
479
|
+
expect(response.space_id).toBe(space_id);
|
480
|
+
expect(response.host).toContain("hf.space");
|
481
|
+
});
|
482
|
+
|
483
|
+
it("handles hugging face domain URLs", async () => {
|
484
|
+
const app_reference = "https://hmb-hello-world.hf.space/";
|
485
|
+
const response = await process_endpoint(app_reference);
|
486
|
+
expect(response.space_id).toBe("hmb-hello-world");
|
487
|
+
expect(response.host).toBe("hmb-hello-world.hf.space");
|
488
|
+
});
|
489
|
+
});
|
490
|
+
|
491
|
+
describe("join_urls", () => {
|
492
|
+
it("joins URLs correctly", () => {
|
493
|
+
expect(join_urls("http://localhost:7860", "/gradio")).toBe(
|
494
|
+
"http://localhost:7860/gradio"
|
495
|
+
);
|
496
|
+
expect(join_urls("http://localhost:7860/", "/gradio")).toBe(
|
497
|
+
"http://localhost:7860/gradio"
|
498
|
+
);
|
499
|
+
expect(join_urls("http://localhost:7860", "app/", "/gradio")).toBe(
|
500
|
+
"http://localhost:7860/app/gradio"
|
501
|
+
);
|
502
|
+
expect(join_urls("http://localhost:7860/", "/app/", "/gradio/")).toBe(
|
503
|
+
"http://localhost:7860/app/gradio/"
|
504
|
+
);
|
505
|
+
|
506
|
+
expect(join_urls("http://127.0.0.1:8000/app", "/config")).toBe(
|
507
|
+
"http://127.0.0.1:8000/app/config"
|
508
|
+
);
|
509
|
+
|
510
|
+
expect(join_urls("http://127.0.0.1:8000/app/gradio", "/config")).toBe(
|
511
|
+
"http://127.0.0.1:8000/app/gradio/config"
|
512
|
+
);
|
513
|
+
});
|
514
|
+
it("throws an error when the URLs are not valid", () => {
|
515
|
+
expect(() => join_urls("localhost:7860", "/gradio")).toThrowError(
|
516
|
+
INVALID_URL_MSG
|
517
|
+
);
|
518
|
+
|
519
|
+
expect(() => join_urls("localhost:7860", "/gradio", "app")).toThrowError(
|
520
|
+
INVALID_URL_MSG
|
521
|
+
);
|
522
|
+
});
|
458
523
|
});
|
459
524
|
|
460
525
|
describe("map_data_params", () => {
|
package/src/test/data.test.ts
CHANGED
@@ -195,7 +195,7 @@ describe("skip_queue", () => {
|
|
195
195
|
|
196
196
|
it("should not skip queue when global and dependency queue is enabled", () => {
|
197
197
|
config.enable_queue = true;
|
198
|
-
config.dependencies
|
198
|
+
config.dependencies.find((dep) => dep.id === id)!.queue = true;
|
199
199
|
|
200
200
|
const result = skip_queue(id, config_response);
|
201
201
|
|
@@ -204,7 +204,7 @@ describe("skip_queue", () => {
|
|
204
204
|
|
205
205
|
it("should not skip queue when global queue is disabled and dependency queue is enabled", () => {
|
206
206
|
config.enable_queue = false;
|
207
|
-
config.dependencies
|
207
|
+
config.dependencies.find((dep) => dep.id === id)!.queue = true;
|
208
208
|
|
209
209
|
const result = skip_queue(id, config_response);
|
210
210
|
|
@@ -213,7 +213,7 @@ describe("skip_queue", () => {
|
|
213
213
|
|
214
214
|
it("should should skip queue when global queue and dependency queue is disabled", () => {
|
215
215
|
config.enable_queue = false;
|
216
|
-
config.dependencies
|
216
|
+
config.dependencies.find((dep) => dep.id === id)!.queue = false;
|
217
217
|
|
218
218
|
const result = skip_queue(id, config_response);
|
219
219
|
|
@@ -222,7 +222,7 @@ describe("skip_queue", () => {
|
|
222
222
|
|
223
223
|
it("should should skip queue when global queue is enabled and dependency queue is disabled", () => {
|
224
224
|
config.enable_queue = true;
|
225
|
-
config.dependencies
|
225
|
+
config.dependencies.find((dep) => dep.id === id)!.queue = false;
|
226
226
|
|
227
227
|
const result = skip_queue(id, config_response);
|
228
228
|
|