wandb 0.17.8rc1__py3-none-win_amd64.whl → 0.17.9__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package_readme.md +47 -53
- wandb/__init__.py +12 -6
- wandb/__init__.pyi +112 -2
- wandb/bin/wandb-core +0 -0
- wandb/data_types.py +1 -0
- wandb/env.py +13 -0
- wandb/integration/keras/__init__.py +2 -5
- wandb/integration/keras/callbacks/metrics_logger.py +10 -4
- wandb/integration/keras/callbacks/model_checkpoint.py +0 -5
- wandb/integration/keras/keras.py +12 -1
- wandb/integration/openai/fine_tuning.py +5 -5
- wandb/integration/tensorboard/log.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +31 -21
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +23 -21
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +23 -21
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +4 -0
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/artifact.py +9 -11
- wandb/sdk/artifacts/artifact_manifest_entry.py +10 -2
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +31 -0
- wandb/sdk/internal/system/assets/trainium.py +2 -1
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +5 -3
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/wandb_init.py +5 -1
- wandb/sdk/wandb_manager.py +0 -3
- wandb/sdk/wandb_require.py +22 -1
- wandb/sdk/wandb_run.py +14 -4
- wandb/sdk/wandb_settings.py +32 -10
- wandb/sdk/wandb_setup.py +3 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/METADATA +48 -54
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/RECORD +42 -43
- wandb/testing/relay.py +0 -874
- /wandb/{viz.py → sdk/lib/viz.py} +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/WHEEL +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/licenses/LICENSE +0 -0
wandb/testing/relay.py
DELETED
@@ -1,874 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import json
|
3
|
-
import logging
|
4
|
-
import socket
|
5
|
-
import sys
|
6
|
-
import threading
|
7
|
-
import traceback
|
8
|
-
import urllib.parse
|
9
|
-
from collections import defaultdict, deque
|
10
|
-
from copy import deepcopy
|
11
|
-
from typing import (
|
12
|
-
TYPE_CHECKING,
|
13
|
-
Any,
|
14
|
-
Callable,
|
15
|
-
Dict,
|
16
|
-
Iterable,
|
17
|
-
List,
|
18
|
-
Mapping,
|
19
|
-
Optional,
|
20
|
-
Union,
|
21
|
-
)
|
22
|
-
|
23
|
-
import flask
|
24
|
-
import pandas as pd
|
25
|
-
import requests
|
26
|
-
import responses
|
27
|
-
|
28
|
-
import wandb
|
29
|
-
import wandb.util
|
30
|
-
from wandb.sdk.lib.timer import Timer
|
31
|
-
|
32
|
-
try:
|
33
|
-
from typing import Literal, TypedDict
|
34
|
-
except ImportError:
|
35
|
-
from typing_extensions import Literal, TypedDict
|
36
|
-
|
37
|
-
if sys.version_info >= (3, 8):
|
38
|
-
from typing import Protocol
|
39
|
-
else:
|
40
|
-
from typing_extensions import Protocol
|
41
|
-
|
42
|
-
if TYPE_CHECKING:
|
43
|
-
from typing import Deque
|
44
|
-
|
45
|
-
class RawRequestResponse(TypedDict):
    """One snooped request/response pair, as captured by the relay server."""

    url: str
    request: Optional[Any]
    response: Dict[str, Any]
    time_elapsed: float  # seconds
-
# Names of the resolvers registered in `QueryResolver.__init__`.
# NOTE: previously this Literal was missing "uploaded_files_legacy",
# "create_artifact" and "delete_run", all of which QueryResolver registers,
# so `Resolver` entries for those names failed type checking.
ResolverName = Literal[
    "upsert_bucket",
    "upload_files",
    "uploaded_files",
    "uploaded_files_legacy",
    "preempting",
    "upsert_sweep",
    "create_artifact",
    "delete_run",
]
-
class Resolver(TypedDict):
    """Pairs a resolver name with the callable that processes a request/response pair."""

    name: ResolverName
    resolver: Callable[[Any], Optional[Dict[str, Any]]]
-
class DeliberateHTTPError(Exception):
    """An HTTP error raised on purpose to simulate a failing backend."""

    def __init__(self, message, status_code: int = 500):
        super().__init__()
        self.message = message
        self.status_code = status_code

    def get_response(self):
        """Render this error as a Flask response with the stored status code."""
        return flask.Response(self.message, status=self.status_code)

    def __repr__(self):
        return f"DeliberateHTTPError({self.message!r}, {self.status_code!r})"
-
@dataclasses.dataclass
class RunAttrs:
    """Simple data class holding the attributes of a snooped run."""

    name: str
    display_name: str
    description: str
    sweep_name: str
    project: Dict[str, Any]
    config: Dict[str, Any]
    # git metadata is optional: not every run carries repo/commit info
    remote: Optional[str] = None
    commit: Optional[str] = None
-
class Context:
    """A container used to store the snooped state/data of a test.

    Includes raw requests and responses, parsed and processed data, and a number of
    convenience methods and properties for accessing the data.
    """

    def __init__(self) -> None:
        # parsed/merged data, keyed by the individual wandb run ids
        self._entries = defaultdict(dict)
        # raw request/response pairs, in arrival order
        self.raw_data: List[RawRequestResponse] = []
        # lazily-built, concatenated file contents across all runs
        self._history: Optional[pd.DataFrame] = None
        self._events: Optional[pd.DataFrame] = None
        self._summary: Optional[pd.DataFrame] = None
        self._config: Optional[Dict[str, Any]] = None
        self._output: Optional[Any] = None

    def upsert(self, entry: Dict[str, Any]) -> None:
        """Recursively merge `entry` into the stored entry with the same name/id."""
        try:
            entry_id: str = entry["name"]
        except KeyError:
            entry_id = entry["id"]
        self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])

    # mapping interface
    def __getitem__(self, key: str) -> Any:
        return self._entries[key]

    def keys(self) -> Iterable[str]:
        return self._entries.keys()

    def get_file_contents(self, file_name: str) -> pd.DataFrame:
        """Concatenate the parsed contents of `file_name` across all runs."""
        frames = []

        for entry_id in self._entries:
            # - extract the chunks stored for `file_name`
            # - sort by offset (will be useful when relay server goes async)
            # - flatten the per-chunk record lists and build a dataframe
            chunks = self._entries[entry_id].get("files", {}).get(file_name, [])
            chunks.sort(key=lambda chunk: chunk["offset"])
            records = [record for chunk in chunks for record in chunk["content"]]
            frame = pd.DataFrame.from_records(records)
            frame["__run_id"] = entry_id
            frames.append(frame)

        return pd.concat(frames)

    # attributes to use in assertions
    @property
    def entries(self) -> Dict[str, Any]:
        return deepcopy(self._entries)

    @property
    def history(self) -> pd.DataFrame:
        # todo: caveat: this assumes that all assertions happen at the end of a test
        if self._history is None:
            self._history = self.get_file_contents("wandb-history.jsonl")
        return deepcopy(self._history)

    @property
    def events(self) -> pd.DataFrame:
        if self._events is None:
            self._events = self.get_file_contents("wandb-events.jsonl")
        return deepcopy(self._events)

    @property
    def summary(self) -> pd.DataFrame:
        if self._summary is None:
            raw = self.get_file_contents("wandb-summary.json")
            # a run summary may be updated multiple times, but only the last
            # update matters; multiple runs can be stored in this context, so
            # group by run id and keep the final row for each run.
            self._summary = raw.groupby("__run_id").last().reset_index(level=["__run_id"])
        return deepcopy(self._summary)

    @property
    def output(self) -> pd.DataFrame:
        if self._output is None:
            self._output = self.get_file_contents("output.log")
        return deepcopy(self._output)

    @property
    def config(self) -> Dict[str, Any]:
        if self._config is None:
            self._config = {k: v["config"] for (k, v) in self._entries.items()}
        return deepcopy(self._config)

    # convenience data access methods
    def get_run_telemetry(self, run_id: str) -> Dict[str, Any]:
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")

    def get_run_metrics(self, run_id: str) -> Dict[str, Any]:
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")

    def get_run_summary(
        self, run_id: str, include_private: bool = False
    ) -> Dict[str, Any]:
        # the per-run summary frame has at most one row; convert it to a dict
        # and return the first (and only) record, or {} if the run is unknown.
        run_rows = self.summary[self.summary["__run_id"] == run_id]
        if not include_private:
            run_rows = run_rows.filter(regex="^[^_]", axis=1)
        records = run_rows.to_dict(orient="records")
        return records[0] if len(records) > 0 else {}

    def get_run_history(
        self, run_id: str, include_private: bool = False
    ) -> pd.DataFrame:
        run_rows = self.history[self.history["__run_id"] == run_id]
        return run_rows if include_private else run_rows.filter(regex="^[^_]", axis=1)

    def get_run_uploaded_files(self, run_id: str) -> Dict[str, Any]:
        return self.entries.get(run_id, {}).get("uploaded", [])

    def get_run_stats(self, run_id: str) -> pd.DataFrame:
        return self.events[self.events["__run_id"] == run_id]

    def get_run_attrs(self, run_id: str) -> Optional[RunAttrs]:
        """Bundle the stored attributes of `run_id` into a RunAttrs, or None."""
        entry = self._entries.get(run_id)
        if not entry:
            return None

        return RunAttrs(
            name=entry["name"],
            display_name=entry["displayName"],
            description=entry["description"],
            sweep_name=entry["sweepName"],
            project=entry["project"],
            config=entry["config"],
            remote=entry.get("repo"),
            commit=entry.get("commit"),
        )

    def get_run(self, run_id: str) -> Dict[str, Any]:
        return self._entries.get(run_id, {})

    # todo: add getter (by run_id) utilities for other properties
-
class QueryResolver:
    """Resolve request/response pairs against a set of known patterns.

    This extracts and processes useful data to be later stored in a Context object.
    """

    def __init__(self):
        # registry of named resolvers, each tried in order by `resolve`
        self.resolvers: List[Resolver] = [
            {
                "name": "upsert_bucket",
                "resolver": self.resolve_upsert_bucket,
            },
            {
                "name": "upload_files",
                "resolver": self.resolve_upload_files,
            },
            {
                "name": "uploaded_files",
                "resolver": self.resolve_uploaded_files,
            },
            {
                "name": "uploaded_files_legacy",
                "resolver": self.resolve_uploaded_files_legacy,
            },
            {
                "name": "preempting",
                "resolver": self.resolve_preempting,
            },
            {
                "name": "upsert_sweep",
                "resolver": self.resolve_upsert_sweep,
            },
            {
                "name": "create_artifact",
                "resolver": self.resolve_create_artifact,
            },
            {
                "name": "delete_run",
                "resolver": self.resolve_delete_run,
            },
        ]

    @staticmethod
    def resolve_upsert_bucket(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Merge upsertBucket request variables with the returned bucket."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        query = response_data.get("data", {}).get("upsertBucket") is not None
        if query:
            # keep only the variables that were actually provided
            data = {
                k: v for (k, v) in request_data["variables"].items() if v is not None
            }
            data.update(response_data["data"]["upsertBucket"].get("bucket"))
            if "config" in data:
                # the config travels as a JSON string; store it parsed
                data["config"] = json.loads(data["config"])
            return data
        return None

    @staticmethod
    def resolve_delete_run(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the variables and result of a deleteRun mutation."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        query = "query" in request_data and "deleteRun" in request_data["query"]
        if query:
            data = {
                k: v for (k, v) in request_data["variables"].items() if v is not None
            }
            data.update(response_data["data"]["deleteRun"])
            return data
        return None

    @staticmethod
    def resolve_upload_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Collect per-file offset/content chunks from a file-stream POST."""
        if not isinstance(request_data, dict):
            return None

        query = request_data.get("files") is not None
        if query:
            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
            # the run id is the third segment of the file-stream path
            name = kwargs.get("path").split("/")[2]
            files = defaultdict(list)
            for file_name, file_value in request_data["files"].items():
                content = []
                for k in file_value.get("content", []):
                    # content lines are usually JSON; fall back to raw text
                    try:
                        content.append(json.loads(k))
                    except json.decoder.JSONDecodeError:
                        content.append([k])

                files[file_name].append(
                    {"offset": file_value.get("offset"), "content": content}
                )

            post_processed_data = {
                "name": name,
                "dropped": [request_data["dropped"]]
                if "dropped" in request_data
                else [],
                "files": files,
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_uploaded_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record file names announced via the CreateRunFiles mutation."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None

        query = "CreateRunFiles" in request_data.get("query", "")
        if query:
            run_name = request_data["variables"]["run"]
            files = ((response_data.get("data") or {}).get("createRunFiles") or {}).get(
                "files", {}
            )
            post_processed_data = {
                "name": run_name,
                "uploaded": [file["name"] for file in files] if files else [""],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_uploaded_files_legacy(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        # This is a legacy resolver for uploaded files
        # No longer used by tests but leaving it here in case we need it in the future
        # Please refer to upload_urls() in internal_api.py for more details
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None

        query = "RunUploadUrls" in request_data.get("query", "")
        if query:
            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
            name = request_data["variables"]["run"]
            files = (
                response_data.get("data", {})
                .get("model", {})
                .get("bucket", {})
                .get("files", {})
                .get("edges", [])
            )
            # note: we count all attempts to upload files
            post_processed_data = {
                "name": name,
                "uploaded": [files[0].get("node", {}).get("name")] if files else [""],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_preempting(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record a preempting notification sent over the file stream."""
        if not isinstance(request_data, dict):
            return None
        query = "preempting" in request_data
        if query:
            name = kwargs.get("path").split("/")[2]
            post_processed_data = {
                "name": name,
                "preempting": [request_data["preempting"]],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_upsert_sweep(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the sweep payload from an upsertSweep response."""
        if not isinstance(response_data, dict):
            return None
        query = response_data.get("data", {}).get("upsertSweep") is not None
        if query:
            data = response_data["data"]["upsertSweep"].get("sweep")
            return data
        return None

    def resolve_create_artifact(
        self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Pair createArtifact mutation variables with the created artifact."""
        if not isinstance(request_data, dict):
            return None
        query = (
            "createArtifact(" in request_data.get("query", "")
            and request_data.get("variables") is not None
            and response_data is not None
        )
        if query:
            name = request_data["variables"]["runName"]
            post_processed_data = {
                "name": name,
                "create_artifact": [
                    {
                        "variables": request_data["variables"],
                        "response": response_data["data"]["createArtifact"]["artifact"],
                    }
                ],
            }
            return post_processed_data
        return None

    def resolve(
        self,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """Run every registered resolver and collect the non-None results.

        BUGFIX: the return annotation was `Optional[Dict[str, Any]]`, but this
        method always returns a (possibly empty) list of dicts.
        """
        results = []
        for resolver in self.resolvers:
            result = resolver.get("resolver")(request_data, response_data, **kwargs)
            if result is not None:
                results.append(result)
        return results
-
class TokenizedCircularPattern:
    """A circular token stream that gates when an injected response applies.

    "1" applies the response, "0" passes the request through, and "2" (STOP)
    parks the pattern permanently: once the STOP token is at the front, the
    pattern never rotates again.
    """

    APPLY_TOKEN = "1"
    PASS_TOKEN = "0"
    STOP_TOKEN = "2"

    def __init__(self, pattern: str):
        if not pattern:
            raise ValueError("Pattern cannot be empty")

        known_tokens = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}
        if set(pattern) - known_tokens:
            raise ValueError(f"Pattern can only contain {known_tokens}")
        self.pattern = deque(pattern)

    def next(self):
        """Advance one token, unless parked on STOP."""
        if self.pattern[0] != self.STOP_TOKEN:
            self.pattern.rotate(-1)

    def should_apply(self) -> bool:
        """Whether the current front token requests an injection."""
        return self.pattern[0] == self.APPLY_TOKEN
-
@dataclasses.dataclass
class InjectedResponse:
    """A canned response returned in place of relaying a matching request."""

    method: str
    url: str
    body: Union[str, Exception]
    status: int = 200
    content_type: str = "text/plain"
    # todo: add more fields for other types of responses?
    custom_match_fn: Optional[Callable[..., bool]] = None
    # BUGFIX: this used to be a plain default (`= TokenizedCircularPattern("1")`),
    # i.e. a single pattern instance shared by every InjectedResponse; rotating
    # one injection's pattern silently advanced all of them. A default_factory
    # gives each instance its own pattern.
    application_pattern: TokenizedCircularPattern = dataclasses.field(
        default_factory=lambda: TokenizedCircularPattern("1")
    )

    # application_pattern defines the pattern of the response injection
    # as the requests come in.
    # 0 == do not inject the response
    # 1 == inject the response
    # 2 == stop using the response (END token)
    #
    # - when no END token is present, the pattern is repeated indefinitely
    # - when END token is present, the pattern is applied until the END token is reached
    # - to replicate the current behavior:
    #   - use application_pattern = "1" if wanting to apply the pattern to all requests
    #   - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
    #
    # Examples of application_pattern:
    # 1. application_pattern = "1012"
    #    - inject the response for the first request
    #    - do not inject the response for the second request
    #    - inject the response for the third request
    #    - stop using the response starting from the fourth request onwards
    # 2. application_pattern = "110"
    #    repeat the following pattern indefinitely:
    #    - inject the response for the first request
    #    - inject the response for the second request
    #    - stop using the response for the third request

    def __eq__(
        self,
        other: Union["InjectedResponse", requests.Request, requests.PreparedRequest],
    ):
        """Check InjectedResponse object equality.

        We use this to check if this response should be injected as a replacement of
        `other`.

        :param other:
        :return:
        """
        if not isinstance(
            other, (InjectedResponse, requests.Request, requests.PreparedRequest)
        ):
            return False

        # always check the method and url
        ret = self.method == other.method and self.url == other.url
        # use custom_match_fn to check, e.g. the request body content
        if ret and self.custom_match_fn is not None:
            ret = self.custom_match_fn(self, other)
        return ret

    def to_dict(self):
        """Serialize the injectable fields for `responses.RequestsMock.add`."""
        excluded_fields = {"application_pattern", "custom_match_fn"}
        return {
            k: self.__getattribute__(k)
            for k in self.__dict__
            if (not k.startswith("_") and k not in excluded_fields)
        }
-
class RelayControlProtocol(Protocol):
    """Structural interface for out-of-band control of a running relay server."""

    def process(self, request: "flask.Request") -> None: ...  # pragma: no cover

    def control(
        self, request: "flask.Request"
    ) -> Mapping[str, str]: ...  # pragma: no cover
-
class RelayServer:
    """A Flask proxy between the SDK under test and the real W&B backend.

    Relays /graphql, /files and /storage traffic to `base_url`, optionally
    substituting `InjectedResponse`s, and snoops every request/response pair
    into `self.context` (a `Context`) for later assertions.
    """

    def __init__(
        self,
        base_url: str,
        inject: Optional[List[InjectedResponse]] = None,
        control: Optional[RelayControlProtocol] = None,
        verbose: bool = False,
    ) -> None:
        # todo for the future:
        #  - consider switching from Flask to Quart
        #  - async app will allow for better failure injection/poor network perf
        self.relay_control = control
        self.app = flask.Flask(__name__)
        self.app.logger.setLevel(logging.INFO)
        # DeliberateHTTPError carries its own status/body; render it directly
        self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
        self.app.add_url_rule(
            rule="/graphql",
            endpoint="graphql",
            view_func=self.graphql,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/files/<path:path>",
            endpoint="files",
            view_func=self.file_stream,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/storage",
            endpoint="storage",
            view_func=self.storage,
            methods=["PUT", "GET"],
        )
        self.app.add_url_rule(
            rule="/storage/<path:path>",
            endpoint="storage_file",
            view_func=self.storage_file,
            methods=["PUT", "GET"],
        )
        if control:
            self.app.add_url_rule(
                rule="/_control",
                endpoint="_control",
                view_func=self.control,
                methods=["POST"],
            )
        # @app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])
        self.port = self._get_free_port()
        self.base_url = urllib.parse.urlparse(base_url)
        self.session = requests.Session()
        self.relay_url = f"http://127.0.0.1:{self.port}"

        # todo: add an option to add custom resolvers
        self.resolver = QueryResolver()
        # recursively merge-able object to store state
        self.context = Context()

        # injected responses
        self.inject = inject or []

        # useful when debugging:
        # self.after_request_fn = self.app.after_request(self.after_request_fn)
        self.verbose = verbose

    @staticmethod
    def handle_http_exception(e):
        # Convert a raised DeliberateHTTPError into its Flask response.
        response = e.get_response()
        return response

    @staticmethod
    def _get_free_port() -> int:
        # Bind to port 0 so the OS picks an unused port for the relay.
        sock = socket.socket()
        sock.bind(("", 0))

        _, port = sock.getsockname()
        return port

    def start(self) -> None:
        # run server in a separate thread
        relay_server_thread = threading.Thread(
            target=self.app.run,
            kwargs={"port": self.port},
            daemon=True,
        )
        relay_server_thread.start()

    def after_request_fn(self, response: "requests.Response") -> "requests.Response":
        # todo: this is useful for debugging, but should be removed in the future
        # flask.request.url = self.relay_url + flask.request.url
        print(flask.request)
        print(flask.request.get_json())
        print(response)
        print(response.json())
        return response

    def relay(
        self,
        request: "flask.Request",
    ) -> Union["responses.Response", "requests.Response", None]:
        """Forward `request` to the real backend, applying injected responses.

        Returns None when a matched injection carries a ConnectionResetError
        body, which the callers use to simulate a dropped connection.
        """
        # replace the relay url with the real backend url (self.base_url)
        url = (
            urllib.parse.urlparse(request.url)
            ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
            .geturl()
        )
        headers = {key: value for (key, value) in request.headers if key != "Host"}
        prepared_relayed_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=request.get_data(),
            json=request.get_json(),
        ).prepare()

        if self.verbose:
            print("*****************")
            print("RELAY REQUEST:")
            print(prepared_relayed_request.url)
            print(prepared_relayed_request.method)
            print(prepared_relayed_request.headers)
            print(prepared_relayed_request.body)
            print("*****************")

        for injected_response in self.inject:
            # where are we in the application pattern?
            should_apply = injected_response.application_pattern.should_apply()
            # check if an injected response matches the request
            if injected_response != prepared_relayed_request or not should_apply:
                continue

            if self.verbose:
                print("*****************")
                print("INJECTING RESPONSE:")
                print(injected_response.to_dict())
                print("*****************")
            # rotate the injection pattern
            injected_response.application_pattern.next()

            # TODO: allow access to the request object when making the mocked response
            with responses.RequestsMock() as mocked_responses:
                # do the actual injection
                resp = injected_response.to_dict()

                if isinstance(resp["body"], ConnectionResetError):
                    return None

                mocked_responses.add(**resp)
                relayed_response = self.session.send(prepared_relayed_request)

                return relayed_response

        # normal case: no injected response matches the request
        relayed_response = self.session.send(prepared_relayed_request)
        return relayed_response

    def snoop_context(
        self,
        request: "flask.Request",
        response: "requests.Response",
        time_elapsed: float,
        **kwargs: Any,
    ) -> None:
        """Record the raw pair and feed it through the resolvers into `self.context`."""
        request_data = request.get_json()
        response_data = response.json() or {}

        if self.relay_control:
            self.relay_control.process(request)

        # store raw data
        raw_data: RawRequestResponse = {
            "url": request.url,
            "request": request_data,
            "response": response_data,
            "time_elapsed": time_elapsed,
        }
        self.context.raw_data.append(raw_data)

        try:
            snooped_context = self.resolver.resolve(
                request_data,
                response_data,
                **kwargs,
            )
            for entry in snooped_context:
                self.context.upsert(entry)
        except Exception as e:
            # best-effort: a resolver failure must not break the relayed request
            print("Failed to resolve context: ", e)
            traceback.print_exc()
            snooped_context = None

        return None

    def graphql(self) -> Mapping[str, str]:
        # View for POST /graphql: relay, snoop, and echo the backend's JSON.
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("GRAPHQL REQUEST:")
            print(request.get_json())
            print("GRAPHQL RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        # snoop work to extract the context
        self.snoop_context(request, relayed_response, timer.elapsed)
        if self.verbose:
            print("*****************")
            print("SNOOPED CONTEXT:")
            print(self.context.entries)
            print(len(self.context.raw_data))
            print("*****************")

        return relayed_response.json()

    def file_stream(self, path) -> Mapping[str, str]:
        # View for POST /files/<path>: relay the file-stream traffic.
        request = flask.request

        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            connection = request.environ["werkzeug.socket"]  # Get the socket object
            connection.shutdown(socket.SHUT_RDWR)
            connection.close()

        if self.verbose:
            print("*****************")
            print("FILE STREAM REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("FILE STREAM RESPONSE:")
            print(relayed_response)
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        # the path carries the run id; resolvers parse it out of kwargs
        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def storage(self) -> Mapping[str, str]:
        # View for PUT/GET /storage.
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE REQUEST:")
            print(request.get_json())
            print("STORAGE RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed)

        return relayed_response.json()

    def storage_file(self, path) -> Mapping[str, str]:
        # View for PUT/GET /storage/<path>.
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE FILE REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("STORAGE FILE RESPONSE:")
            print(relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def control(self) -> Mapping[str, str]:
        # View for POST /_control: delegate to the relay-control handler.
        assert self.relay_control
        return self.relay_control.control(flask.request)