wandb 0.17.8rc1__py3-none-win32.whl → 0.17.9__py3-none-win32.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +47 -53
- wandb/__init__.py +12 -6
- wandb/__init__.pyi +112 -2
- wandb/bin/wandb-core +0 -0
- wandb/data_types.py +1 -0
- wandb/env.py +13 -0
- wandb/integration/keras/__init__.py +2 -5
- wandb/integration/keras/callbacks/metrics_logger.py +10 -4
- wandb/integration/keras/callbacks/model_checkpoint.py +0 -5
- wandb/integration/keras/keras.py +12 -1
- wandb/integration/openai/fine_tuning.py +5 -5
- wandb/integration/tensorboard/log.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +31 -21
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +23 -21
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +23 -21
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +4 -0
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/artifact.py +9 -11
- wandb/sdk/artifacts/artifact_manifest_entry.py +10 -2
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +31 -0
- wandb/sdk/internal/system/assets/trainium.py +2 -1
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +5 -3
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/wandb_init.py +5 -1
- wandb/sdk/wandb_manager.py +0 -3
- wandb/sdk/wandb_require.py +22 -1
- wandb/sdk/wandb_run.py +14 -4
- wandb/sdk/wandb_settings.py +32 -10
- wandb/sdk/wandb_setup.py +3 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/METADATA +48 -54
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/RECORD +42 -43
- wandb/testing/relay.py +0 -874
- /wandb/{viz.py → sdk/lib/viz.py} +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/WHEEL +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/licenses/LICENSE +0 -0
wandb/testing/relay.py
DELETED
@@ -1,874 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import json
|
3
|
-
import logging
|
4
|
-
import socket
|
5
|
-
import sys
|
6
|
-
import threading
|
7
|
-
import traceback
|
8
|
-
import urllib.parse
|
9
|
-
from collections import defaultdict, deque
|
10
|
-
from copy import deepcopy
|
11
|
-
from typing import (
|
12
|
-
TYPE_CHECKING,
|
13
|
-
Any,
|
14
|
-
Callable,
|
15
|
-
Dict,
|
16
|
-
Iterable,
|
17
|
-
List,
|
18
|
-
Mapping,
|
19
|
-
Optional,
|
20
|
-
Union,
|
21
|
-
)
|
22
|
-
|
23
|
-
import flask
|
24
|
-
import pandas as pd
|
25
|
-
import requests
|
26
|
-
import responses
|
27
|
-
|
28
|
-
import wandb
|
29
|
-
import wandb.util
|
30
|
-
from wandb.sdk.lib.timer import Timer
|
31
|
-
|
32
|
-
try:
|
33
|
-
from typing import Literal, TypedDict
|
34
|
-
except ImportError:
|
35
|
-
from typing_extensions import Literal, TypedDict
|
36
|
-
|
37
|
-
if sys.version_info >= (3, 8):
|
38
|
-
from typing import Protocol
|
39
|
-
else:
|
40
|
-
from typing_extensions import Protocol
|
41
|
-
|
42
|
-
if TYPE_CHECKING:
|
43
|
-
from typing import Deque
|
44
|
-
|
45
|
-
class RawRequestResponse(TypedDict):
    """One raw request/response pair captured by the relay server."""

    url: str
    request: Optional[Any]
    response: Dict[str, Any]
    time_elapsed: float  # seconds


# Names of all resolvers registered by ``QueryResolver``.
ResolverName = Literal[
    "upsert_bucket",
    "upload_files",
    "uploaded_files",
    "uploaded_files_legacy",
    "preempting",
    "upsert_sweep",
    "create_artifact",
    "delete_run",
]


class Resolver(TypedDict):
    """A named resolver function.

    ``resolver`` is called as ``resolver(request_data, response_data, **kwargs)``
    and returns a dict of extracted data, or ``None`` when it does not apply.
    """

    name: ResolverName
    resolver: Callable[..., Optional[Dict[str, Any]]]
|
62
|
-
|
63
|
-
|
64
|
-
class DeliberateHTTPError(Exception):
    """An exception that is turned into a plain HTTP error response.

    :param message: response body, also used as the exception message.
    :param status_code: HTTP status of the response (default 500).
    """

    def __init__(self, message, status_code: int = 500):
        # pass the message on so that str(err) and err.args carry it
        super().__init__(message)
        self.message = message
        self.status_code = status_code

    def get_response(self):
        """Build a Flask response with ``message`` as body and ``status_code`` as status."""
        return flask.Response(self.message, status=self.status_code)

    def __repr__(self):
        return f"DeliberateHTTPError({self.message!r}, {self.status_code!r})"
|
75
|
-
|
76
|
-
|
77
|
-
@dataclasses.dataclass
class RunAttrs:
    """Plain record holding the main attributes of a single run."""

    # required attributes, in positional-constructor order
    name: str
    display_name: str
    description: str
    sweep_name: str
    project: Dict[str, Any]
    config: Dict[str, Any]
    # optional repository remote and commit; None when not provided
    remote: Optional[str] = dataclasses.field(default=None)
    commit: Optional[str] = dataclasses.field(default=None)
|
89
|
-
|
90
|
-
|
91
|
-
class Context:
    """A container used to store the snooped state/data of a test.

    Includes raw requests and responses, parsed and processed data, and a number of
    convenience methods and properties for accessing the data.

    The ``history``/``events``/``summary``/``output``/``config`` properties are
    computed on first access and cached afterwards, and they return deep copies.
    """

    def __init__(self) -> None:
        # parsed/merged data. keys are the individual wandb run id's.
        self._entries = defaultdict(dict)
        # container for raw requests and responses:
        self.raw_data: List[RawRequestResponse] = []
        # lazily filled caches of the concatenated file contents for all runs:
        self._history: Optional[pd.DataFrame] = None
        self._events: Optional[pd.DataFrame] = None
        self._summary: Optional[pd.DataFrame] = None
        self._config: Optional[Dict[str, Any]] = None
        self._output: Optional[Any] = None

    def upsert(self, entry: Dict[str, Any]) -> None:
        """Merge ``entry`` into the data stored for its run.

        The run key is ``entry["name"]``, or ``entry["id"]`` when there is no name;
        merging is delegated to ``wandb.util.merge_dicts``.

        Raises:
            KeyError: if ``entry`` has neither a ``"name"`` nor an ``"id"`` key.
        """
        try:
            entry_id: str = entry["name"]
        except KeyError:
            entry_id = entry["id"]
        self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])

    # mapping interface
    def __getitem__(self, key: str) -> Any:
        # note: _entries is a defaultdict, so an unknown key creates an empty entry
        return self._entries[key]

    def keys(self) -> Iterable[str]:
        return self._entries.keys()

    def get_file_contents(self, file_name: str) -> pd.DataFrame:
        """Concatenate the contents of ``file_name`` across all runs.

        For each run, the file's chunks are ordered by ``offset`` and their
        ``content`` lists are flattened into one list of records; every row gets
        a ``__run_id`` column holding the run key.

        Returns:
            A DataFrame with the records of all runs, or an empty DataFrame with
            only a ``__run_id`` column when no runs have been recorded.
        """
        dfs = []
        for entry_id, entry in self._entries.items():
            chunks = entry.get("files", {}).get(file_name, [])
            # sorted() instead of list.sort(): reading must not reorder stored data
            # (sorting will be useful when relay server goes async)
            chunks = sorted(chunks, key=lambda chunk: chunk["offset"])
            records = [record for chunk in chunks for record in chunk["content"]]
            df = pd.DataFrame.from_records(records)
            df["__run_id"] = entry_id
            dfs.append(df)

        if not dfs:
            # pd.concat raises ValueError on an empty list
            return pd.DataFrame(columns=["__run_id"])
        return pd.concat(dfs)

    # attributes to use in assertions
    @property
    def entries(self) -> Dict[str, Any]:
        return deepcopy(self._entries)

    @property
    def history(self) -> pd.DataFrame:
        # todo: caveat: this assumes that all assertions happen at the end of a test
        if self._history is None:
            self._history = self.get_file_contents("wandb-history.jsonl")
        return deepcopy(self._history)

    @property
    def events(self) -> pd.DataFrame:
        if self._events is None:
            self._events = self.get_file_contents("wandb-events.jsonl")
        return deepcopy(self._events)

    @property
    def summary(self) -> pd.DataFrame:
        if self._summary is None:
            _summary = self.get_file_contents("wandb-summary.json")
            # run summary may be updated multiple times,
            # but we are only interested in the last one.
            # we can have multiple runs saved to context,
            # so we need to group by run id and take the
            # last one for each run.
            self._summary = (
                _summary.groupby("__run_id").last().reset_index(level=["__run_id"])
            )
        return deepcopy(self._summary)

    @property
    def output(self) -> pd.DataFrame:
        if self._output is None:
            self._output = self.get_file_contents("output.log")
        return deepcopy(self._output)

    @property
    def config(self) -> Dict[str, Any]:
        if self._config is None:
            self._config = {k: v["config"] for (k, v) in self._entries.items()}
        return deepcopy(self._config)

    # convenience data access methods
    def get_run_telemetry(self, run_id: str) -> Any:
        """Return ``config["_wandb"]["value"]["t"]`` of the run, or None if absent."""
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")

    def get_run_metrics(self, run_id: str) -> Any:
        """Return ``config["_wandb"]["value"]["m"]`` of the run, or None if absent."""
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")

    def get_run_summary(
        self, run_id: str, include_private: bool = False
    ) -> Dict[str, Any]:
        """Return the last summary of ``run_id`` as a dict ({} if there is none).

        Columns starting with ``_`` (including ``__run_id``) are dropped unless
        ``include_private`` is true.
        """
        summary = self.summary  # one deep copy instead of two
        run_summary = summary[summary["__run_id"] == run_id]
        if not include_private:
            run_summary = run_summary.filter(regex="^[^_]", axis=1)
        # the summary has at most one row per run id
        records = run_summary.to_dict(orient="records")
        return records[0] if records else {}

    def get_run_history(
        self, run_id: str, include_private: bool = False
    ) -> pd.DataFrame:
        """Return the history rows of ``run_id``.

        Columns starting with ``_`` are dropped unless ``include_private`` is true.
        """
        history = self.history  # one deep copy instead of two
        run_history = history[history["__run_id"] == run_id]
        if not include_private:
            run_history = run_history.filter(regex="^[^_]", axis=1)
        return run_history

    def get_run_uploaded_files(self, run_id: str) -> List[Any]:
        """Return a copy of the run's ``"uploaded"`` list ([] if absent)."""
        # copy only the requested list rather than every entry
        return deepcopy(self._entries.get(run_id, {}).get("uploaded", []))

    def get_run_stats(self, run_id: str) -> pd.DataFrame:
        """Return the ``wandb-events.jsonl`` rows of ``run_id``."""
        events = self.events  # one deep copy instead of two
        return events[events["__run_id"] == run_id]

    def get_run_attrs(self, run_id: str) -> Optional["RunAttrs"]:
        """Build a ``RunAttrs`` from the stored entry, or return None if there is none.

        Raises:
            KeyError: if the entry lacks one of the required run fields.
        """
        run_entry = self._entries.get(run_id)
        if not run_entry:
            return None

        return RunAttrs(
            name=run_entry["name"],
            display_name=run_entry["displayName"],
            description=run_entry["description"],
            sweep_name=run_entry["sweepName"],
            project=run_entry["project"],
            config=run_entry["config"],
            remote=run_entry.get("repo"),
            commit=run_entry.get("commit"),
        )

    def get_run(self, run_id: str) -> Dict[str, Any]:
        """Return the stored entry of ``run_id`` (not a copy), or {} if unknown."""
        return self._entries.get(run_id, {})

    # todo: add getter (by run_id) utilities for other properties
|
272
|
-
|
273
|
-
|
274
|
-
class QueryResolver:
    """Resolve request/response pairs against a set of known patterns.

    This extracts and processes useful data to be later stored in a Context object.
    Every resolver is called as ``resolver(request_data, response_data, **kwargs)``
    and returns a dict of extracted data, or ``None`` when the pair does not match
    or lacks the payload the resolver needs.
    """

    def __init__(self):
        self.resolvers: List[Resolver] = [
            {
                "name": "upsert_bucket",
                "resolver": self.resolve_upsert_bucket,
            },
            {
                "name": "upload_files",
                "resolver": self.resolve_upload_files,
            },
            {
                "name": "uploaded_files",
                "resolver": self.resolve_uploaded_files,
            },
            {
                "name": "uploaded_files_legacy",
                "resolver": self.resolve_uploaded_files_legacy,
            },
            {
                "name": "preempting",
                "resolver": self.resolve_preempting,
            },
            {
                "name": "upsert_sweep",
                "resolver": self.resolve_upsert_sweep,
            },
            {
                "name": "create_artifact",
                "resolver": self.resolve_create_artifact,
            },
            {
                "name": "delete_run",
                "resolver": self.resolve_delete_run,
            },
        ]

    @staticmethod
    def _graphql_data(response_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``data`` member of a GraphQL response; {} if missing or null."""
        # GraphQL error responses may carry "data": null
        return response_data.get("data") or {}

    @staticmethod
    def _non_null_variables(request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the request's GraphQL variables, dropping those that are None."""
        variables = request_data.get("variables") or {}
        return {k: v for (k, v) in variables.items() if v is not None}

    @staticmethod
    def resolve_upsert_bucket(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Merge the request variables with the returned bucket of an upsertBucket call."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        upsert_bucket = QueryResolver._graphql_data(response_data).get("upsertBucket")
        if upsert_bucket is None:
            return None
        bucket = upsert_bucket.get("bucket")
        if bucket is None:
            # no bucket returned: nothing to record
            return None
        data = QueryResolver._non_null_variables(request_data)
        data.update(bucket)
        # the config arrives JSON-encoded; decode it
        if isinstance(data.get("config"), str):
            data["config"] = json.loads(data["config"])
        return data

    @staticmethod
    def resolve_delete_run(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Merge the request variables with the result of a deleteRun call."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "deleteRun" not in (request_data.get("query") or ""):
            return None
        delete_run = QueryResolver._graphql_data(response_data).get("deleteRun")
        if delete_run is None:
            # no result returned: nothing to record
            return None
        data = QueryResolver._non_null_variables(request_data)
        data.update(delete_run)
        return data

    @staticmethod
    def resolve_upload_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the file chunks of a file_stream request.

        The run name is the third segment of the ``path`` keyword argument; without
        a ``path`` the request cannot be attributed and None is returned.
        """
        if not isinstance(request_data, dict) or request_data.get("files") is None:
            return None
        path = kwargs.get("path")
        if path is None:
            return None

        name = path.split("/")[2]
        files = defaultdict(list)
        for file_name, file_value in request_data["files"].items():
            content = []
            for line in file_value.get("content", []):
                try:
                    content.append(json.loads(line))
                except json.decoder.JSONDecodeError:
                    # non-JSON lines are kept verbatim, wrapped in a one-item list
                    content.append([line])

            files[file_name].append(
                {"offset": file_value.get("offset"), "content": content}
            )

        return {
            "name": name,
            "dropped": [request_data["dropped"]] if "dropped" in request_data else [],
            "files": files,
        }

    @staticmethod
    def resolve_uploaded_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record the file names returned by a CreateRunFiles call ([""] if none)."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "CreateRunFiles" not in (request_data.get("query") or ""):
            return None

        run_name = request_data["variables"]["run"]
        create_run_files = (
            QueryResolver._graphql_data(response_data).get("createRunFiles") or {}
        )
        files = create_run_files.get("files", {})
        return {
            "name": run_name,
            "uploaded": [file["name"] for file in files] if files else [""],
        }

    @staticmethod
    def resolve_uploaded_files_legacy(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        # This is a legacy resolver for uploaded files
        # No longer used by tests but leaving it here in case we need it in the future
        # Please refer to upload_urls() in internal_api.py for more details
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "RunUploadUrls" not in (request_data.get("query") or ""):
            return None

        name = request_data["variables"]["run"]
        model = QueryResolver._graphql_data(response_data).get("model") or {}
        bucket = model.get("bucket") or {}
        files = (bucket.get("files") or {}).get("edges") or []
        # note: we count all attempts to upload files
        return {
            "name": name,
            "uploaded": [(files[0].get("node") or {}).get("name")] if files else [""],
        }

    @staticmethod
    def resolve_preempting(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record the ``preempting`` value of a file_stream request.

        Like ``resolve_upload_files``, the run name is the third segment of ``path``.
        """
        if not isinstance(request_data, dict) or "preempting" not in request_data:
            return None
        path = kwargs.get("path")
        if path is None:
            return None
        return {
            "name": path.split("/")[2],
            "preempting": [request_data["preempting"]],
        }

    @staticmethod
    def resolve_upsert_sweep(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Return the sweep returned by an upsertSweep call."""
        if not isinstance(response_data, dict):
            return None
        upsert_sweep = QueryResolver._graphql_data(response_data).get("upsertSweep")
        if upsert_sweep is None:
            return None
        return upsert_sweep.get("sweep")

    def resolve_create_artifact(
        self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record the variables and returned artifact of a createArtifact call."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        variables = request_data.get("variables")
        if "createArtifact(" not in (request_data.get("query") or "") or variables is None:
            return None
        create_artifact = self._graphql_data(response_data).get("createArtifact")
        if create_artifact is None:
            # no result returned: nothing to record
            return None

        return {
            "name": variables["runName"],
            "create_artifact": [
                {
                    "variables": variables,
                    "response": create_artifact.get("artifact"),
                }
            ],
        }

    def resolve(
        self,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """Run every resolver on the pair and return the non-None results in order."""
        results = []
        for resolver in self.resolvers:
            result = resolver["resolver"](request_data, response_data, **kwargs)
            if result is not None:
                results.append(result)
        return results
|
495
|
-
|
496
|
-
|
497
|
-
class TokenizedCircularPattern:
    """A cyclic token sequence deciding, one request at a time, whether to act.

    ``"1"`` means apply, ``"0"`` means pass, and ``"2"`` means stop. Once the stop
    token reaches the front, the pattern no longer advances and never applies.
    """

    APPLY_TOKEN = "1"
    PASS_TOKEN = "0"
    STOP_TOKEN = "2"

    def __init__(self, pattern: str):
        known_tokens = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}
        if not pattern:
            raise ValueError("Pattern cannot be empty")

        unknown_tokens = set(pattern).difference(known_tokens)
        if unknown_tokens:
            raise ValueError(f"Pattern can only contain {known_tokens}")
        self.pattern: Deque[str] = deque(pattern)

    def next(self):
        """Move on to the following token, unless the pattern has stopped."""
        has_stopped = self.pattern[0] == self.STOP_TOKEN
        if not has_stopped:
            self.pattern.rotate(-1)

    def should_apply(self) -> bool:
        """Return True when the current token is the apply token."""
        current = self.pattern[0]
        return current == self.APPLY_TOKEN
|
518
|
-
|
519
|
-
|
520
|
-
@dataclasses.dataclass
class InjectedResponse:
    """A canned response to return instead of relaying a matching request.

    A request matches when its method and url equal ``method`` and ``url`` and,
    if given, ``custom_match_fn(self, request)`` returns true.
    """

    method: str
    url: str
    body: Union[str, Exception]
    status: int = 200
    content_type: str = "text/plain"
    # todo: add more fields for other types of responses?
    custom_match_fn: Optional[Callable[..., bool]] = None
    # default_factory gives every instance its own pattern: a plain default
    # instance would be one object shared by all responses using the default
    application_pattern: TokenizedCircularPattern = dataclasses.field(
        default_factory=lambda: TokenizedCircularPattern("1")
    )

    # application_pattern defines the pattern of the response injection
    # as the requests come in.
    # 0 == do not inject the response
    # 1 == inject the response
    # 2 == stop using the response (END token)
    #
    # - when no END token is present, the pattern is repeated indefinitely
    # - when END token is present, the pattern is applied until the END token is reached
    # - to replicate the current behavior:
    #  - use application_pattern = "1" if wanting to apply the pattern to all requests
    #  - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
    #
    # Examples of application_pattern:
    # 1. application_pattern = "1012"
    #    - inject the response for the first request
    #    - do not inject the response for the second request
    #    - inject the response for the third request
    #    - stop using the response starting from the fourth request onwards
    # 2. application_pattern = "110"
    #    repeat the following pattern indefinitely:
    #    - inject the response for the first request
    #    - inject the response for the second request
    #    - do not inject the response for the third request

    def __eq__(
        self,
        other: Union["InjectedResponse", requests.Request, requests.PreparedRequest],
    ):
        """Check InjectedResponse object equality.

        We use this to check if this response should be injected as a replacement of
        `other`.

        :param other: an InjectedResponse or a (prepared) requests request.
        :return: True when method and url are equal and ``custom_match_fn``, if set,
            accepts ``other``; False for any other type of ``other``.
        """
        if not isinstance(
            other, (InjectedResponse, requests.Request, requests.PreparedRequest)
        ):
            return False

        # always check the method and url
        ret = self.method == other.method and self.url == other.url
        # use custom_match_fn to check, e.g. the request body content
        if ret and self.custom_match_fn is not None:
            ret = self.custom_match_fn(self, other)
        return ret

    def to_dict(self):
        """Return the public fields, minus matching/pattern configuration, as a dict."""
        excluded_fields = {"application_pattern", "custom_match_fn"}
        return {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_") and key not in excluded_fields
        }
|
586
|
-
|
587
|
-
|
588
|
-
class RelayControlProtocol(Protocol):
    """Structural interface of the optional controller used by ``RelayServer``."""

    def process(self, request: "flask.Request") -> None:  # pragma: no cover
        """Inspect a request that the relay has just handled."""
        ...

    def control(self, request: "flask.Request") -> Mapping[str, str]:  # pragma: no cover
        """Handle a request to the control endpoint; the mapping is the reply."""
        ...
|
594
|
-
|
595
|
-
|
596
|
-
class RelayServer:
    """A Flask server that forwards requests to a real backend and records them.

    Requests to ``/graphql``, ``/files/<path>``, ``/storage`` and
    ``/storage/<path>`` are forwarded to ``base_url``. Each request/response pair
    is then recorded in ``self.context``. A request matching an entry of
    ``inject`` is answered with that injected response instead. When ``control``
    is given, it is shown every recorded request and also serves ``/_control``.
    """

    def __init__(
        self,
        base_url: str,
        inject: Optional[List[InjectedResponse]] = None,
        control: Optional[RelayControlProtocol] = None,
        verbose: bool = False,
    ) -> None:
        # todo for the future:
        #  - consider switching from Flask to Quart
        #  - async app will allow for better failure injection/poor network perf
        self.relay_control = control
        self.app = flask.Flask(__name__)
        self.app.logger.setLevel(logging.INFO)
        self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
        self.app.add_url_rule(
            rule="/graphql",
            endpoint="graphql",
            view_func=self.graphql,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/files/<path:path>",
            endpoint="files",
            view_func=self.file_stream,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/storage",
            endpoint="storage",
            view_func=self.storage,
            methods=["PUT", "GET"],
        )
        self.app.add_url_rule(
            rule="/storage/<path:path>",
            endpoint="storage_file",
            view_func=self.storage_file,
            methods=["PUT", "GET"],
        )
        if control:
            self.app.add_url_rule(
                rule="/_control",
                endpoint="_control",
                view_func=self.control,
                methods=["POST"],
            )
        self.port = self._get_free_port()
        self.base_url = urllib.parse.urlparse(base_url)
        self.session = requests.Session()
        self.relay_url = f"http://127.0.0.1:{self.port}"

        # todo: add an option to add custom resolvers
        self.resolver = QueryResolver()
        # recursively merge-able object to store state
        self.context = Context()

        # injected responses
        self.inject = inject or []

        # useful when debugging:
        # self.after_request_fn = self.app.after_request(self.after_request_fn)
        self.verbose = verbose

    @staticmethod
    def handle_http_exception(e):
        """Flask error handler: turn a DeliberateHTTPError into its HTTP response."""
        response = e.get_response()
        return response

    @staticmethod
    def _get_free_port() -> int:
        """Ask the OS for a currently unused TCP port.

        The probe socket is closed before returning. The port is therefore only
        presumed free; another process could take it before the server binds.
        """
        with socket.socket() as sock:
            sock.bind(("", 0))
            _, port = sock.getsockname()
        return port

    @staticmethod
    def _reset_connection(request: "flask.Request") -> "flask.Response":
        """Simulate "connection reset by peer" by closing the client socket.

        The returned response cannot be delivered over the closed socket. It only
        gives Flask a valid return value instead of failing later on a missing
        relayed response.
        """
        connection = request.environ["werkzeug.socket"]  # Get the socket object
        connection.shutdown(socket.SHUT_RDWR)
        connection.close()
        return flask.Response(status=500)

    def start(self) -> None:
        """Run the Flask app on ``self.port`` in a daemon thread."""
        relay_server_thread = threading.Thread(
            target=self.app.run,
            kwargs={"port": self.port},
            daemon=True,
        )
        relay_server_thread.start()

    def after_request_fn(self, response: "requests.Response") -> "requests.Response":
        # todo: this is useful for debugging, but should be removed in the future
        # flask.request.url = self.relay_url + flask.request.url
        print(flask.request)
        print(flask.request.get_json())
        print(response)
        print(response.json())
        return response

    def relay(
        self,
        request: "flask.Request",
    ) -> Union["responses.Response", "requests.Response", None]:
        """Forward ``request`` to the real backend, or serve a matching injection.

        Returns the backend's (or the injected) response. Returns None when the
        matching injected response has a ``ConnectionResetError`` body, which the
        views treat as a request to drop the client connection.
        """
        # replace the relay url with the real backend url (self.base_url)
        url = (
            urllib.parse.urlparse(request.url)
            ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
            .geturl()
        )
        headers = {key: value for (key, value) in request.headers if key != "Host"}
        # NOTE(review): newer Werkzeug versions make get_json() raise for
        # non-JSON bodies (e.g. binary storage uploads) — confirm for the
        # Flask/Werkzeug version in use.
        prepared_relayed_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=request.get_data(),
            json=request.get_json(),
        ).prepare()

        if self.verbose:
            print("*****************")
            print("RELAY REQUEST:")
            print(prepared_relayed_request.url)
            print(prepared_relayed_request.method)
            print(prepared_relayed_request.headers)
            print(prepared_relayed_request.body)
            print("*****************")

        for injected_response in self.inject:
            # where are we in the application pattern?
            should_apply = injected_response.application_pattern.should_apply()
            # check if an injected response matches the request
            if injected_response != prepared_relayed_request or not should_apply:
                continue

            if self.verbose:
                print("*****************")
                print("INJECTING RESPONSE:")
                print(injected_response.to_dict())
                print("*****************")
            # rotate the injection pattern
            injected_response.application_pattern.next()

            # a ConnectionResetError body means "drop the connection"; the caller
            # performs the actual reset
            if isinstance(injected_response.body, ConnectionResetError):
                return None

            # TODO: allow access to the request object when making the mocked response
            with responses.RequestsMock() as mocked_responses:
                # do the actual injection: the send below is answered by the mock
                mocked_responses.add(**injected_response.to_dict())
                relayed_response = self.session.send(prepared_relayed_request)

            return relayed_response

        # normal case: no injected response matches the request
        relayed_response = self.session.send(prepared_relayed_request)
        return relayed_response

    def snoop_context(
        self,
        request: "flask.Request",
        response: "requests.Response",
        time_elapsed: float,
        **kwargs: Any,
    ) -> None:
        """Record a request/response pair and merge the resolved data into the context.

        The raw pair is appended to ``self.context.raw_data``. Errors while
        resolving are printed and otherwise ignored, so that one unexpected
        payload does not break the relay. ``kwargs`` (e.g. ``path``) are passed
        to the resolvers.
        """
        request_data = request.get_json()
        response_data = response.json() or {}

        if self.relay_control:
            self.relay_control.process(request)

        # store raw data
        raw_data: RawRequestResponse = {
            "url": request.url,
            "request": request_data,
            "response": response_data,
            "time_elapsed": time_elapsed,
        }
        self.context.raw_data.append(raw_data)

        try:
            snooped_context = self.resolver.resolve(
                request_data,
                response_data,
                **kwargs,
            )
            for entry in snooped_context:
                self.context.upsert(entry)
        except Exception as e:
            # best effort: report the failure and keep relaying
            print("Failed to resolve context: ", e)
            traceback.print_exc()

    def graphql(self) -> Mapping[str, str]:
        """Relay a GraphQL request and return the backend's JSON response."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            return self._reset_connection(request)

        if self.verbose:
            print("*****************")
            print("GRAPHQL REQUEST:")
            print(request.get_json())
            print("GRAPHQL RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        # snoop work to extract the context
        self.snoop_context(request, relayed_response, timer.elapsed)
        if self.verbose:
            print("*****************")
            print("SNOOPED CONTEXT:")
            print(self.context.entries)
            print(len(self.context.raw_data))
            print("*****************")

        return relayed_response.json()

    def file_stream(self, path) -> Mapping[str, str]:
        """Relay a file_stream request for ``path`` and return the JSON response."""
        request = flask.request

        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer; there is no response to snoop
        if relayed_response is None:
            return self._reset_connection(request)

        if self.verbose:
            print("*****************")
            print("FILE STREAM REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("FILE STREAM RESPONSE:")
            print(relayed_response)
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def storage(self) -> Mapping[str, str]:
        """Relay a storage request and return the backend's JSON response."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            return self._reset_connection(request)

        if self.verbose:
            print("*****************")
            print("STORAGE REQUEST:")
            print(request.get_json())
            print("STORAGE RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed)

        return relayed_response.json()

    def storage_file(self, path) -> Mapping[str, str]:
        """Relay a storage request for ``path`` and return the JSON response."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            return self._reset_connection(request)

        if self.verbose:
            print("*****************")
            print("STORAGE FILE REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("STORAGE FILE RESPONSE:")
            print(relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def control(self) -> Mapping[str, str]:
        """Serve ``/_control`` by delegating to the configured relay control."""
        # the /_control route is only registered when a control object was given
        assert self.relay_control
        return self.relay_control.control(flask.request)
|