wandb 0.17.8__py3-none-any.whl → 0.17.9__py3-none-any.whl

wandb/testing/relay.py DELETED
@@ -1,880 +0,0 @@
-import dataclasses
-import json
-import logging
-import socket
-import sys
-import threading
-import traceback
-import urllib.parse
-from collections import defaultdict, deque
-from copy import deepcopy
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Union,
-)
-
-import flask
-import pandas as pd
-import requests
-import responses
-
-import wandb
-import wandb.util
-from wandb.sdk.lib.timer import Timer
-
-try:
-    from typing import Literal, TypedDict
-except ImportError:
-    from typing_extensions import Literal, TypedDict
-
-if sys.version_info >= (3, 8):
-    from typing import Protocol
-else:
-    from typing_extensions import Protocol
-
-if TYPE_CHECKING:
-    from typing import Deque
-
-    class RawRequestResponse(TypedDict):
-        url: str
-        request: Optional[Any]
-        response: Dict[str, Any]
-        time_elapsed: float  # seconds
-
-    ResolverName = Literal[
-        "upsert_bucket",
-        "upload_files",
-        "uploaded_files",
-        "preempting",
-        "upsert_sweep",
-    ]
-
-    class Resolver(TypedDict):
-        name: ResolverName
-        resolver: Callable[[Any], Optional[Dict[str, Any]]]
-
-
-class DeliberateHTTPError(Exception):
-    def __init__(self, message, status_code: int = 500):
-        Exception.__init__(self)
-        self.message = message
-        self.status_code = status_code
-
-    def get_response(self):
-        return flask.Response(self.message, status=self.status_code)
-
-    def __repr__(self):
-        return f"DeliberateHTTPError({self.message!r}, {self.status_code!r})"
-
-
-@dataclasses.dataclass
-class RunAttrs:
-    """Simple data class for run attributes."""
-
-    name: str
-    display_name: str
-    description: str
-    sweep_name: str
-    project: Dict[str, Any]
-    config: Dict[str, Any]
-    remote: Optional[str] = None
-    commit: Optional[str] = None
-
-
-class Context:
-    """A container used to store the snooped state/data of a test.
-
-    Includes raw requests and responses, parsed and processed data, and a number of
-    convenience methods and properties for accessing the data.
-    """
-
-    def __init__(self) -> None:
-        # parsed/merged data. keys are the individual wandb run id's.
-        self._entries = defaultdict(dict)
-        # container for raw requests and responses:
-        self.raw_data: List[RawRequestResponse] = []
-        # concatenated file contents for all runs:
-        self._history: Optional[pd.DataFrame] = None
-        self._events: Optional[pd.DataFrame] = None
-        self._summary: Optional[pd.DataFrame] = None
-        self._config: Optional[Dict[str, Any]] = None
-        self._output: Optional[Any] = None
-
-    def upsert(self, entry: Dict[str, Any]) -> None:
-        try:
-            entry_id: str = entry["name"]
-        except KeyError:
-            entry_id = entry["id"]
-        self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])
-
-    # mapping interface
-    def __getitem__(self, key: str) -> Any:
-        return self._entries[key]
-
-    def keys(self) -> Iterable[str]:
-        return self._entries.keys()
-
-    def get_file_contents(self, file_name: str) -> pd.DataFrame:
-        dfs = []
-
-        for entry_id in self._entries:
-            # - extract the content from `file_name`
-            # - sort by offset (will be useful when relay server goes async)
-            # - extract data, merge into a list of dicts and convert to a pandas dataframe
-            content_list = self._entries[entry_id].get("files", {}).get(file_name, [])
-            content_list.sort(key=lambda x: x["offset"])
-            content_list = [item["content"] for item in content_list]
-            # merge list of lists content_list:
-            content_list = [item for sublist in content_list for item in sublist]
-            df = pd.DataFrame.from_records(content_list)
-            df["__run_id"] = entry_id
-            dfs.append(df)
-
-        return pd.concat(dfs)
-
-    # attributes to use in assertions
-    @property
-    def entries(self) -> Dict[str, Any]:
-        return deepcopy(self._entries)
-
-    @property
-    def history(self) -> pd.DataFrame:
-        # todo: caveat: this assumes that all assertions happen at the end of a test
-        if self._history is not None:
-            return deepcopy(self._history)
-
-        self._history = self.get_file_contents("wandb-history.jsonl")
-        return deepcopy(self._history)
-
-    @property
-    def events(self) -> pd.DataFrame:
-        if self._events is not None:
-            return deepcopy(self._events)
-
-        self._events = self.get_file_contents("wandb-events.jsonl")
-        return deepcopy(self._events)
-
-    @property
-    def summary(self) -> pd.DataFrame:
-        if self._summary is not None:
-            return deepcopy(self._summary)
-
-        _summary = self.get_file_contents("wandb-summary.json")
-
-        # run summary may be updated multiple times,
-        # but we are only interested in the last one.
-        # we can have multiple runs saved to context,
-        # so we need to group by run id and take the
-        # last one for each run.
-        self._summary = (
-            _summary.groupby("__run_id").last().reset_index(level=["__run_id"])
-        )
-
-        return deepcopy(self._summary)
-
-    @property
-    def output(self) -> pd.DataFrame:
-        if self._output is not None:
-            return deepcopy(self._output)
-
-        self._output = self.get_file_contents("output.log")
-        return deepcopy(self._output)
-
-    @property
-    def config(self) -> Dict[str, Any]:
-        if self._config is not None:
-            return deepcopy(self._config)
-
-        self._config = {k: v["config"] for (k, v) in self._entries.items() if k}
-        return deepcopy(self._config)
-
-    # @property
-    # def telemetry(self) -> pd.DataFrame:
-    #     telemetry = pd.DataFrame.from_records(
-    #         [
-    #             {
-    #                 "__run_id": run_id,
-    #                 "telemetry": config.get("_wandb", {}).get("value", {}).get("t")
-    #             }
-    #             for (run_id, config) in self.config.items()
-    #         ]
-    #     )
-    #     return telemetry
-
-    # convenience data access methods
-    def get_run_config(self, run_id: str) -> Dict[str, Any]:
-        return self.config.get(run_id, {})
-
-    def get_run_telemetry(self, run_id: str) -> Dict[str, Any]:
-        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")
-
-    def get_run_metrics(self, run_id: str) -> Dict[str, Any]:
-        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")
-
-    def get_run_summary(
-        self, run_id: str, include_private: bool = False
-    ) -> Dict[str, Any]:
-        # run summary dataframe must have only one row
-        # for the given run id, so we convert it to dict
-        # and extract the first (and only) row.
-        mask_run = self.summary["__run_id"] == run_id
-        run_summary = self.summary[mask_run]
-        ret = (
-            run_summary.filter(regex="^[^_]", axis=1)
-            if not include_private
-            else run_summary
-        ).to_dict(orient="records")
-        return ret[0] if len(ret) > 0 else {}
-
-    def get_run_history(
-        self, run_id: str, include_private: bool = False
-    ) -> pd.DataFrame:
-        mask_run = self.history["__run_id"] == run_id
-        run_history = self.history[mask_run]
-        return (
-            run_history.filter(regex="^[^_]", axis=1)
-            if not include_private
-            else run_history
-        )
-
-    def get_run_uploaded_files(self, run_id: str) -> Dict[str, Any]:
-        return self.entries.get(run_id, {}).get("uploaded", [])
-
-    def get_run_stats(self, run_id: str) -> pd.DataFrame:
-        mask_run = self.events["__run_id"] == run_id
-        run_stats = self.events[mask_run]
-        return run_stats
-
-    def get_run_attrs(self, run_id: str) -> Optional[RunAttrs]:
-        run_entry = self._entries.get(run_id)
-        if not run_entry:
-            return None
-
-        return RunAttrs(
-            name=run_entry["name"],
-            display_name=run_entry["displayName"],
-            description=run_entry["description"],
-            sweep_name=run_entry["sweepName"],
-            project=run_entry["project"],
-            config=run_entry["config"],
-            remote=run_entry.get("repo"),
-            commit=run_entry.get("commit"),
-        )
-
-    def get_run(self, run_id: str) -> Dict[str, Any]:
-        return self._entries.get(run_id, {})
-
-    def get_run_ids(self) -> List[str]:
-        return [k for k in self._entries.keys() if k]
-
-    # todo: add getter (by run_id) utilities for other properties
-
-
-class QueryResolver:
-    """Resolve request/response pairs against a set of known patterns.
-
-    This extracts and processes useful data to be later stored in a Context object.
-    """
-
-    def __init__(self):
-        self.resolvers: List[Resolver] = [
-            {
-                "name": "upsert_bucket",
-                "resolver": self.resolve_upsert_bucket,
-            },
-            {
-                "name": "upload_files",
-                "resolver": self.resolve_upload_files,
-            },
-            {
-                "name": "uploaded_files",
-                "resolver": self.resolve_uploaded_files,
-            },
-            {
-                "name": "uploaded_files_legacy",
-                "resolver": self.resolve_uploaded_files_legacy,
-            },
-            {
-                "name": "preempting",
-                "resolver": self.resolve_preempting,
-            },
-            {
-                "name": "upsert_sweep",
-                "resolver": self.resolve_upsert_sweep,
-            },
-            {
-                "name": "create_artifact",
-                "resolver": self.resolve_create_artifact,
-            },
-            {
-                "name": "delete_run",
-                "resolver": self.resolve_delete_run,
-            },
-        ]
-
-    @staticmethod
-    def resolve_upsert_bucket(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-            return None
-        query = response_data.get("data", {}).get("upsertBucket") is not None
-        if query:
-            data = {
-                k: v for (k, v) in request_data["variables"].items() if v is not None
-            }
-            data.update(response_data["data"]["upsertBucket"].get("bucket"))
-            if "config" in data:
-                data["config"] = json.loads(data["config"])
-            return data
-        return None
-
-    @staticmethod
-    def resolve_delete_run(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-            return None
-        query = "query" in request_data and "deleteRun" in request_data["query"]
-        if query:
-            data = {
-                k: v for (k, v) in request_data["variables"].items() if v is not None
-            }
-            data.update(response_data["data"]["deleteRun"])
-            return data
-        return None
-
-    @staticmethod
-    def resolve_upload_files(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict):
-            return None
-
-        query = request_data.get("files") is not None
-        if query:
-            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
-            name = kwargs.get("path").split("/")[2]
-            files = defaultdict(list)
-            for file_name, file_value in request_data["files"].items():
-                content = []
-                for k in file_value.get("content", []):
-                    try:
-                        content.append(json.loads(k))
-                    except json.decoder.JSONDecodeError:
-                        content.append([k])
-
-                files[file_name].append(
-                    {"offset": file_value.get("offset"), "content": content}
-                )
-
-            post_processed_data = {
-                "name": name,
-                "dropped": [request_data["dropped"]]
-                if "dropped" in request_data
-                else [],
-                "files": files,
-            }
-            return post_processed_data
-        return None
-
-    @staticmethod
-    def resolve_uploaded_files(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-            return None
-
-        query = "CreateRunFiles" in request_data.get("query", "")
-        if query:
-            run_name = request_data["variables"]["run"]
-            files = ((response_data.get("data") or {}).get("createRunFiles") or {}).get(
-                "files", {}
-            )
-            post_processed_data = {
-                "name": run_name,
-                "uploaded": [file["name"] for file in files] if files else [""],
-            }
-            return post_processed_data
-        return None
-
-    @staticmethod
-    def resolve_uploaded_files_legacy(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        # This is a legacy resolver for uploaded files
-        # No longer used by tests but leaving it here in case we need it in the future
-        # Please refer to upload_urls() in internal_api.py for more details
-        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-            return None
-
-        query = "RunUploadUrls" in request_data.get("query", "")
-        if query:
-            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
-            name = request_data["variables"]["run"]
-            files = (
-                response_data.get("data", {})
-                .get("model", {})
-                .get("bucket", {})
-                .get("files", {})
-                .get("edges", [])
-            )
-            # note: we count all attempts to upload files
-            post_processed_data = {
-                "name": name,
-                "uploaded": [files[0].get("node", {}).get("name")] if files else [""],
-            }
-            return post_processed_data
-        return None
-
-    @staticmethod
-    def resolve_preempting(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict):
-            return None
-        query = "preempting" in request_data
-        if query:
-            name = kwargs.get("path").split("/")[2]
-            post_processed_data = {
-                "name": name,
-                "preempting": [request_data["preempting"]],
-            }
-            return post_processed_data
-        return None
-
-    @staticmethod
-    def resolve_upsert_sweep(
-        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(response_data, dict):
-            return None
-        query = response_data.get("data", {}).get("upsertSweep") is not None
-        if query:
-            data = response_data["data"]["upsertSweep"].get("sweep")
-            return data
-        return None
-
-    def resolve_create_artifact(
-        self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        if not isinstance(request_data, dict):
-            return None
-        query = (
-            "createArtifact(" in request_data.get("query", "")
-            and request_data.get("variables") is not None
-            and response_data is not None
-        )
-        if query:
-            name = request_data["variables"]["runName"]
-            post_processed_data = {
-                "name": name,
-                "create_artifact": [
-                    {
-                        "variables": request_data["variables"],
-                        "response": response_data["data"]["createArtifact"]["artifact"],
-                    }
-                ],
-            }
-            return post_processed_data
-        return None
-
-    def resolve(
-        self,
-        request_data: Dict[str, Any],
-        response_data: Dict[str, Any],
-        **kwargs: Any,
-    ) -> Optional[Dict[str, Any]]:
-        results = []
-        for resolver in self.resolvers:
-            result = resolver.get("resolver")(request_data, response_data, **kwargs)
-            if result is not None:
-                results.append(result)
-        return results
-
-
-class TokenizedCircularPattern:
-    APPLY_TOKEN = "1"
-    PASS_TOKEN = "0"
-    STOP_TOKEN = "2"
-
-    def __init__(self, pattern: str):
-        known_tokens = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}
-        if not pattern:
-            raise ValueError("Pattern cannot be empty")
-
-        if set(pattern) - known_tokens:
-            raise ValueError(f"Pattern can only contain {known_tokens}")
-        self.pattern: Deque[str] = deque(pattern)
-
-    def next(self):
-        if self.pattern[0] == self.STOP_TOKEN:
-            return
-        self.pattern.rotate(-1)
-
-    def should_apply(self) -> bool:
-        return self.pattern[0] == self.APPLY_TOKEN
-
-
-@dataclasses.dataclass
-class InjectedResponse:
-    method: str
-    url: str
-    body: Union[str, Exception]
-    status: int = 200
-    content_type: str = "text/plain"
-    # todo: add more fields for other types of responses?
-    custom_match_fn: Optional[Callable[..., bool]] = None
-    application_pattern: TokenizedCircularPattern = TokenizedCircularPattern("1")
-
-    # application_pattern defines the pattern of the response injection
-    # as the requests come in.
-    # 0 == do not inject the response
-    # 1 == inject the response
-    # 2 == stop using the response (END token)
-    #
-    # - when no END token is present, the pattern is repeated indefinitely
-    # - when END token is present, the pattern is applied until the END token is reached
-    # - to replicate the current behavior:
-    #   - use application_pattern = "1" if wanting to apply the pattern to all requests
-    #   - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
-    #
-    # Examples of application_pattern:
-    # 1. application_pattern = "1012"
-    #    - inject the response for the first request
-    #    - do not inject the response for the second request
-    #    - inject the response for the third request
-    #    - stop using the response starting from the fourth request onwards
-    # 2. application_pattern = "110"
-    #    repeat the following pattern indefinitely:
-    #    - inject the response for the first request
-    #    - inject the response for the second request
-    #    - stop using the response for the third request
-
-    def __eq__(
-        self,
-        other: Union["InjectedResponse", requests.Request, requests.PreparedRequest],
-    ):
-        """Check InjectedResponse object equality.
-
-        We use this to check if this response should be injected as a replacement of
-        `other`.
-
-        :param other:
-        :return:
-        """
-        if not isinstance(
-            other, (InjectedResponse, requests.Request, requests.PreparedRequest)
-        ):
-            return False
-
-        # always check the method and url
-        ret = self.method == other.method and self.url == other.url
-        # use custom_match_fn to check, e.g. the request body content
-        if ret and self.custom_match_fn is not None:
-            ret = self.custom_match_fn(self, other)
-        return ret
-
-    def to_dict(self):
-        excluded_fields = {"application_pattern", "custom_match_fn"}
-        return {
-            k: self.__getattribute__(k)
-            for k in self.__dict__
-            if (not k.startswith("_") and k not in excluded_fields)
-        }
-
-
-class RelayControlProtocol(Protocol):
-    def process(self, request: "flask.Request") -> None: ...  # pragma: no cover
-
-    def control(
-        self, request: "flask.Request"
-    ) -> Mapping[str, str]: ...  # pragma: no cover
-
-
-class RelayServer:
-    def __init__(
-        self,
-        base_url: str,
-        inject: Optional[List[InjectedResponse]] = None,
-        control: Optional[RelayControlProtocol] = None,
-        verbose: bool = False,
-    ) -> None:
-        # todo for the future:
-        # - consider switching from Flask to Quart
-        # - async app will allow for better failure injection/poor network perf
-        self.relay_control = control
-        self.app = flask.Flask(__name__)
-        self.app.logger.setLevel(logging.INFO)
-        self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
-        self.app.add_url_rule(
-            rule="/graphql",
-            endpoint="graphql",
-            view_func=self.graphql,
-            methods=["POST"],
-        )
-        self.app.add_url_rule(
-            rule="/files/<path:path>",
-            endpoint="files",
-            view_func=self.file_stream,
-            methods=["POST"],
-        )
-        self.app.add_url_rule(
-            rule="/storage",
-            endpoint="storage",
-            view_func=self.storage,
-            methods=["PUT", "GET"],
-        )
-        self.app.add_url_rule(
-            rule="/storage/<path:path>",
-            endpoint="storage_file",
-            view_func=self.storage_file,
-            methods=["PUT", "GET"],
-        )
-        if control:
-            self.app.add_url_rule(
-                rule="/_control",
-                endpoint="_control",
-                view_func=self.control,
-                methods=["POST"],
-            )
-        # @app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])
-        self.port = self._get_free_port()
-        self.base_url = urllib.parse.urlparse(base_url)
-        self.session = requests.Session()
-        self.relay_url = f"http://127.0.0.1:{self.port}"
-
-        # todo: add an option to add custom resolvers
-        self.resolver = QueryResolver()
-        # recursively merge-able object to store state
-        self.context = Context()
-
-        # injected responses
-        self.inject = inject or []
-
-        # useful when debugging:
-        # self.after_request_fn = self.app.after_request(self.after_request_fn)
-        self.verbose = verbose
-
-    @staticmethod
-    def handle_http_exception(e):
-        response = e.get_response()
-        return response
-
-    @staticmethod
-    def _get_free_port() -> int:
-        sock = socket.socket()
-        sock.bind(("", 0))
-
-        _, port = sock.getsockname()
-        return port
-
-    def start(self) -> None:
-        # run server in a separate thread
-        relay_server_thread = threading.Thread(
-            target=self.app.run,
-            kwargs={"port": self.port},
-            daemon=True,
-        )
-        relay_server_thread.start()
-
-    def after_request_fn(self, response: "requests.Response") -> "requests.Response":
-        # todo: this is useful for debugging, but should be removed in the future
-        # flask.request.url = self.relay_url + flask.request.url
-        print(flask.request)
-        print(flask.request.get_json())
-        print(response)
-        print(response.json())
-        return response
-
-    def relay(
-        self,
-        request: "flask.Request",
-    ) -> Union["responses.Response", "requests.Response", None]:
-        # replace the relay url with the real backend url (self.base_url)
-        url = (
-            urllib.parse.urlparse(request.url)
-            ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
-            .geturl()
-        )
-        headers = {key: value for (key, value) in request.headers if key != "Host"}
-        prepared_relayed_request = requests.Request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            data=request.get_data(),
-            json=request.get_json(),
-        ).prepare()
-
-        if self.verbose:
-            print("*****************")
-            print("RELAY REQUEST:")
-            print(prepared_relayed_request.url)
-            print(prepared_relayed_request.method)
-            print(prepared_relayed_request.headers)
-            print(prepared_relayed_request.body)
-            print("*****************")
-
-        for injected_response in self.inject:
-            # where are we in the application pattern?
-            should_apply = injected_response.application_pattern.should_apply()
-            # check if an injected response matches the request
-            if injected_response != prepared_relayed_request or not should_apply:
-                continue
-
-            if self.verbose:
-                print("*****************")
-                print("INJECTING RESPONSE:")
-                print(injected_response.to_dict())
-                print("*****************")
-            # rotate the injection pattern
-            injected_response.application_pattern.next()
-
-            # TODO: allow access to the request object when making the mocked response
-            with responses.RequestsMock() as mocked_responses:
-                # do the actual injection
-                resp = injected_response.to_dict()
-
-                if isinstance(resp["body"], ConnectionResetError):
-                    return None
-
-                mocked_responses.add(**resp)
-                relayed_response = self.session.send(prepared_relayed_request)
-
-                return relayed_response
-
-        # normal case: no injected response matches the request
-        relayed_response = self.session.send(prepared_relayed_request)
-        return relayed_response
-
-    def snoop_context(
-        self,
-        request: "flask.Request",
-        response: "requests.Response",
-        time_elapsed: float,
-        **kwargs: Any,
-    ) -> None:
-        request_data = request.get_json()
-        response_data = response.json() or {}
-
-        if self.relay_control:
-            self.relay_control.process(request)
-
-        # store raw data
-        raw_data: RawRequestResponse = {
-            "url": request.url,
-            "request": request_data,
-            "response": response_data,
-            "time_elapsed": time_elapsed,
-        }
-        self.context.raw_data.append(raw_data)
-
-        try:
-            snooped_context = self.resolver.resolve(
-                request_data,
-                response_data,
-                **kwargs,
-            )
-            for entry in snooped_context:
-                self.context.upsert(entry)
-        except Exception as e:
-            print("Failed to resolve context: ", e)
-            traceback.print_exc()
-            snooped_context = None
-
-        return None
-
-    def graphql(self) -> Mapping[str, str]:
-        request = flask.request
-        with Timer() as timer:
-            relayed_response = self.relay(request)
-        if self.verbose:
-            print("*****************")
-            print("GRAPHQL REQUEST:")
-            print(request.get_json())
-            print("GRAPHQL RESPONSE:")
-            print(relayed_response.status_code, relayed_response.json())
-            print("*****************")
-        # snoop work to extract the context
-        self.snoop_context(request, relayed_response, timer.elapsed)
-        if self.verbose:
-            print("*****************")
-            print("SNOOPED CONTEXT:")
-            print(self.context.entries)
-            print(len(self.context.raw_data))
-            print("*****************")
-
-        return relayed_response.json()
-
-    def file_stream(self, path) -> Mapping[str, str]:
-        request = flask.request
-
-        with Timer() as timer:
-            relayed_response = self.relay(request)
-
-        # simulate connection reset by peer
-        if relayed_response is None:
-            connection = request.environ["werkzeug.socket"]  # Get the socket object
-            connection.shutdown(socket.SHUT_RDWR)
-            connection.close()
-
-        if self.verbose:
-            print("*****************")
-            print("FILE STREAM REQUEST:")
-            print("********PATH*********")
-            print(path)
-            print("********ENDPATH*********")
-            print(request.get_json())
-            print("FILE STREAM RESPONSE:")
-            print(relayed_response)
-            print(relayed_response.status_code, relayed_response.json())
-            print("*****************")
-        self.snoop_context(request, relayed_response, timer.elapsed, path=path)
-
-        return relayed_response.json()
-
-    def storage(self) -> Mapping[str, str]:
-        request = flask.request
-        with Timer() as timer:
-            relayed_response = self.relay(request)
-        if self.verbose:
-            print("*****************")
-            print("STORAGE REQUEST:")
-            print(request.get_json())
-            print("STORAGE RESPONSE:")
-            print(relayed_response.status_code, relayed_response.json())
-            print("*****************")
-
-        self.snoop_context(request, relayed_response, timer.elapsed)
-
-        return relayed_response.json()
-
-    def storage_file(self, path) -> Mapping[str, str]:
-        request = flask.request
-        with Timer() as timer:
-            relayed_response = self.relay(request)
-        if self.verbose:
-            print("*****************")
-            print("STORAGE FILE REQUEST:")
-            print("********PATH*********")
-            print(path)
-            print("********ENDPATH*********")
-            print(request.get_json())
-            print("STORAGE FILE RESPONSE:")
-            print(relayed_response.json())
-            print("*****************")
-
-        self.snoop_context(request, relayed_response, timer.elapsed, path=path)
-
-        return relayed_response.json()
-
-    def control(self) -> Mapping[str, str]:
-        assert self.relay_control
-        return self.relay_control.control(flask.request)
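
For reference, the sketch below shows one way the deleted relay helper was typically driven from a test: start a RelayServer pointed at a backend, register an InjectedResponse whose TokenizedCircularPattern decides which requests receive the fault, and assert against the snooped Context afterwards. This is a minimal, hypothetical usage sketch, not code from the wandb test suite; the backend URL, the use of the WANDB_BASE_URL environment variable, and the project/metric names are illustrative assumptions.

# Hypothetical usage sketch for the removed relay helper; not code from the wandb repo.
import os

import wandb
from wandb.testing.relay import InjectedResponse, RelayServer, TokenizedCircularPattern

backend = "https://api.wandb.ai"  # assumed real backend that the relay forwards to

# Fail the first two matching GraphQL requests with a 500, then stop injecting ("112").
flaky_graphql = InjectedResponse(
    method="POST",
    url=f"{backend}/graphql",
    body="transient backend failure",
    status=500,
    application_pattern=TokenizedCircularPattern("112"),
)

relay = RelayServer(base_url=backend, inject=[flaky_graphql])
relay.start()

# Route the SDK through the relay so every request/response pair is snooped.
os.environ["WANDB_BASE_URL"] = relay.relay_url

with wandb.init(project="relay-demo") as run:
    run.log({"loss": 0.5})

# Assertions are made against the snooped Context after the run finishes.
history = relay.context.get_run_history(run.id)
assert "loss" in history.columns

The application pattern rotates one token per matching request, so "112" injects the failure twice and then lets traffic pass through untouched; a pattern without the "2" stop token repeats indefinitely.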