wandb 0.17.8__py3-none-any.whl → 0.17.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wandb/testing/relay.py DELETED
@@ -1,880 +0,0 @@
- import dataclasses
- import json
- import logging
- import socket
- import sys
- import threading
- import traceback
- import urllib.parse
- from collections import defaultdict, deque
- from copy import deepcopy
- from typing import (
-     TYPE_CHECKING,
-     Any,
-     Callable,
-     Dict,
-     Iterable,
-     List,
-     Mapping,
-     Optional,
-     Union,
- )
-
- import flask
- import pandas as pd
- import requests
- import responses
-
- import wandb
- import wandb.util
- from wandb.sdk.lib.timer import Timer
-
- try:
-     from typing import Literal, TypedDict
- except ImportError:
-     from typing_extensions import Literal, TypedDict
-
- if sys.version_info >= (3, 8):
-     from typing import Protocol
- else:
-     from typing_extensions import Protocol
-
- if TYPE_CHECKING:
-     from typing import Deque
-
-     class RawRequestResponse(TypedDict):
-         url: str
-         request: Optional[Any]
-         response: Dict[str, Any]
-         time_elapsed: float  # seconds
-
-     ResolverName = Literal[
-         "upsert_bucket",
-         "upload_files",
-         "uploaded_files",
-         "preempting",
-         "upsert_sweep",
-     ]
-
-     class Resolver(TypedDict):
-         name: ResolverName
-         resolver: Callable[[Any], Optional[Dict[str, Any]]]
-
-
- class DeliberateHTTPError(Exception):
-     def __init__(self, message, status_code: int = 500):
-         Exception.__init__(self)
-         self.message = message
-         self.status_code = status_code
-
-     def get_response(self):
-         return flask.Response(self.message, status=self.status_code)
-
-     def __repr__(self):
-         return f"DeliberateHTTPError({self.message!r}, {self.status_code!r})"
-
-
- @dataclasses.dataclass
- class RunAttrs:
-     """Simple data class for run attributes."""
-
-     name: str
-     display_name: str
-     description: str
-     sweep_name: str
-     project: Dict[str, Any]
-     config: Dict[str, Any]
-     remote: Optional[str] = None
-     commit: Optional[str] = None
-
-
- class Context:
-     """A container used to store the snooped state/data of a test.
-
-     Includes raw requests and responses, parsed and processed data, and a number of
-     convenience methods and properties for accessing the data.
-     """
-
-     def __init__(self) -> None:
-         # parsed/merged data. keys are the individual wandb run id's.
-         self._entries = defaultdict(dict)
-         # container for raw requests and responses:
-         self.raw_data: List[RawRequestResponse] = []
-         # concatenated file contents for all runs:
-         self._history: Optional[pd.DataFrame] = None
-         self._events: Optional[pd.DataFrame] = None
-         self._summary: Optional[pd.DataFrame] = None
-         self._config: Optional[Dict[str, Any]] = None
-         self._output: Optional[Any] = None
-
-     def upsert(self, entry: Dict[str, Any]) -> None:
-         try:
-             entry_id: str = entry["name"]
-         except KeyError:
-             entry_id = entry["id"]
-         self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])
-
-     # mapping interface
-     def __getitem__(self, key: str) -> Any:
-         return self._entries[key]
-
-     def keys(self) -> Iterable[str]:
-         return self._entries.keys()
-
-     def get_file_contents(self, file_name: str) -> pd.DataFrame:
-         dfs = []
-
-         for entry_id in self._entries:
-             # - extract the content from `file_name`
-             # - sort by offset (will be useful when relay server goes async)
-             # - extract data, merge into a list of dicts and convert to a pandas dataframe
-             content_list = self._entries[entry_id].get("files", {}).get(file_name, [])
-             content_list.sort(key=lambda x: x["offset"])
-             content_list = [item["content"] for item in content_list]
-             # merge list of lists content_list:
-             content_list = [item for sublist in content_list for item in sublist]
-             df = pd.DataFrame.from_records(content_list)
-             df["__run_id"] = entry_id
-             dfs.append(df)
-
-         return pd.concat(dfs)
-
-     # attributes to use in assertions
-     @property
-     def entries(self) -> Dict[str, Any]:
-         return deepcopy(self._entries)
-
-     @property
-     def history(self) -> pd.DataFrame:
-         # todo: caveat: this assumes that all assertions happen at the end of a test
-         if self._history is not None:
-             return deepcopy(self._history)
-
-         self._history = self.get_file_contents("wandb-history.jsonl")
-         return deepcopy(self._history)
-
-     @property
-     def events(self) -> pd.DataFrame:
-         if self._events is not None:
-             return deepcopy(self._events)
-
-         self._events = self.get_file_contents("wandb-events.jsonl")
-         return deepcopy(self._events)
-
-     @property
-     def summary(self) -> pd.DataFrame:
-         if self._summary is not None:
-             return deepcopy(self._summary)
-
-         _summary = self.get_file_contents("wandb-summary.json")
-
-         # run summary may be updated multiple times,
-         # but we are only interested in the last one.
-         # we can have multiple runs saved to context,
-         # so we need to group by run id and take the
-         # last one for each run.
-         self._summary = (
-             _summary.groupby("__run_id").last().reset_index(level=["__run_id"])
-         )
-
-         return deepcopy(self._summary)
-
-     @property
-     def output(self) -> pd.DataFrame:
-         if self._output is not None:
-             return deepcopy(self._output)
-
-         self._output = self.get_file_contents("output.log")
-         return deepcopy(self._output)
-
-     @property
-     def config(self) -> Dict[str, Any]:
-         if self._config is not None:
-             return deepcopy(self._config)
-
-         self._config = {k: v["config"] for (k, v) in self._entries.items() if k}
-         return deepcopy(self._config)
-
-     # @property
-     # def telemetry(self) -> pd.DataFrame:
-     #     telemetry = pd.DataFrame.from_records(
-     #         [
-     #             {
-     #                 "__run_id": run_id,
-     #                 "telemetry": config.get("_wandb", {}).get("value", {}).get("t")
-     #             }
-     #             for (run_id, config) in self.config.items()
-     #         ]
-     #     )
-     #     return telemetry
-
-     # convenience data access methods
-     def get_run_config(self, run_id: str) -> Dict[str, Any]:
-         return self.config.get(run_id, {})
-
-     def get_run_telemetry(self, run_id: str) -> Dict[str, Any]:
-         return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")
-
-     def get_run_metrics(self, run_id: str) -> Dict[str, Any]:
-         return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")
-
-     def get_run_summary(
-         self, run_id: str, include_private: bool = False
-     ) -> Dict[str, Any]:
-         # run summary dataframe must have only one row
-         # for the given run id, so we convert it to dict
-         # and extract the first (and only) row.
-         mask_run = self.summary["__run_id"] == run_id
-         run_summary = self.summary[mask_run]
-         ret = (
-             run_summary.filter(regex="^[^_]", axis=1)
-             if not include_private
-             else run_summary
-         ).to_dict(orient="records")
-         return ret[0] if len(ret) > 0 else {}
-
-     def get_run_history(
-         self, run_id: str, include_private: bool = False
-     ) -> pd.DataFrame:
-         mask_run = self.history["__run_id"] == run_id
-         run_history = self.history[mask_run]
-         return (
-             run_history.filter(regex="^[^_]", axis=1)
-             if not include_private
-             else run_history
-         )
-
-     def get_run_uploaded_files(self, run_id: str) -> Dict[str, Any]:
-         return self.entries.get(run_id, {}).get("uploaded", [])
-
-     def get_run_stats(self, run_id: str) -> pd.DataFrame:
-         mask_run = self.events["__run_id"] == run_id
-         run_stats = self.events[mask_run]
-         return run_stats
-
-     def get_run_attrs(self, run_id: str) -> Optional[RunAttrs]:
-         run_entry = self._entries.get(run_id)
-         if not run_entry:
-             return None
-
-         return RunAttrs(
-             name=run_entry["name"],
-             display_name=run_entry["displayName"],
-             description=run_entry["description"],
-             sweep_name=run_entry["sweepName"],
-             project=run_entry["project"],
-             config=run_entry["config"],
-             remote=run_entry.get("repo"),
-             commit=run_entry.get("commit"),
-         )
-
-     def get_run(self, run_id: str) -> Dict[str, Any]:
-         return self._entries.get(run_id, {})
-
-     def get_run_ids(self) -> List[str]:
-         return [k for k in self._entries.keys() if k]
-
-     # todo: add getter (by run_id) utilities for other properties
-
-
- class QueryResolver:
-     """Resolve request/response pairs against a set of known patterns.
-
-     This extracts and processes useful data to be later stored in a Context object.
-     """
-
-     def __init__(self):
-         self.resolvers: List[Resolver] = [
-             {
-                 "name": "upsert_bucket",
-                 "resolver": self.resolve_upsert_bucket,
-             },
-             {
-                 "name": "upload_files",
-                 "resolver": self.resolve_upload_files,
-             },
-             {
-                 "name": "uploaded_files",
-                 "resolver": self.resolve_uploaded_files,
-             },
-             {
-                 "name": "uploaded_files_legacy",
-                 "resolver": self.resolve_uploaded_files_legacy,
-             },
-             {
-                 "name": "preempting",
-                 "resolver": self.resolve_preempting,
-             },
-             {
-                 "name": "upsert_sweep",
-                 "resolver": self.resolve_upsert_sweep,
-             },
-             {
-                 "name": "create_artifact",
-                 "resolver": self.resolve_create_artifact,
-             },
-             {
-                 "name": "delete_run",
-                 "resolver": self.resolve_delete_run,
-             },
-         ]
-
-     @staticmethod
-     def resolve_upsert_bucket(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-             return None
-         query = response_data.get("data", {}).get("upsertBucket") is not None
-         if query:
-             data = {
-                 k: v for (k, v) in request_data["variables"].items() if v is not None
-             }
-             data.update(response_data["data"]["upsertBucket"].get("bucket"))
-             if "config" in data:
-                 data["config"] = json.loads(data["config"])
-             return data
-         return None
-
-     @staticmethod
-     def resolve_delete_run(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-             return None
-         query = "query" in request_data and "deleteRun" in request_data["query"]
-         if query:
-             data = {
-                 k: v for (k, v) in request_data["variables"].items() if v is not None
-             }
-             data.update(response_data["data"]["deleteRun"])
-             return data
-         return None
-
-     @staticmethod
-     def resolve_upload_files(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict):
-             return None
-
-         query = request_data.get("files") is not None
-         if query:
-             # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
-             name = kwargs.get("path").split("/")[2]
-             files = defaultdict(list)
-             for file_name, file_value in request_data["files"].items():
-                 content = []
-                 for k in file_value.get("content", []):
-                     try:
-                         content.append(json.loads(k))
-                     except json.decoder.JSONDecodeError:
-                         content.append([k])
-
-                 files[file_name].append(
-                     {"offset": file_value.get("offset"), "content": content}
-                 )
-
-             post_processed_data = {
-                 "name": name,
-                 "dropped": [request_data["dropped"]]
-                 if "dropped" in request_data
-                 else [],
-                 "files": files,
-             }
-             return post_processed_data
-         return None
-
-     @staticmethod
-     def resolve_uploaded_files(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-             return None
-
-         query = "CreateRunFiles" in request_data.get("query", "")
-         if query:
-             run_name = request_data["variables"]["run"]
-             files = ((response_data.get("data") or {}).get("createRunFiles") or {}).get(
-                 "files", {}
-             )
-             post_processed_data = {
-                 "name": run_name,
-                 "uploaded": [file["name"] for file in files] if files else [""],
-             }
-             return post_processed_data
-         return None
-
-     @staticmethod
-     def resolve_uploaded_files_legacy(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         # This is a legacy resolver for uploaded files
-         # No longer used by tests but leaving it here in case we need it in the future
-         # Please refer to upload_urls() in internal_api.py for more details
-         if not isinstance(request_data, dict) or not isinstance(response_data, dict):
-             return None
-
-         query = "RunUploadUrls" in request_data.get("query", "")
-         if query:
-             # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
-             name = request_data["variables"]["run"]
-             files = (
-                 response_data.get("data", {})
-                 .get("model", {})
-                 .get("bucket", {})
-                 .get("files", {})
-                 .get("edges", [])
-             )
-             # note: we count all attempts to upload files
-             post_processed_data = {
-                 "name": name,
-                 "uploaded": [files[0].get("node", {}).get("name")] if files else [""],
-             }
-             return post_processed_data
-         return None
-
-     @staticmethod
-     def resolve_preempting(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict):
-             return None
-         query = "preempting" in request_data
-         if query:
-             name = kwargs.get("path").split("/")[2]
-             post_processed_data = {
-                 "name": name,
-                 "preempting": [request_data["preempting"]],
-             }
-             return post_processed_data
-         return None
-
-     @staticmethod
-     def resolve_upsert_sweep(
-         request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(response_data, dict):
-             return None
-         query = response_data.get("data", {}).get("upsertSweep") is not None
-         if query:
-             data = response_data["data"]["upsertSweep"].get("sweep")
-             return data
-         return None
-
-     def resolve_create_artifact(
-         self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
-     ) -> Optional[Dict[str, Any]]:
-         if not isinstance(request_data, dict):
-             return None
-         query = (
-             "createArtifact(" in request_data.get("query", "")
-             and request_data.get("variables") is not None
-             and response_data is not None
-         )
-         if query:
-             name = request_data["variables"]["runName"]
-             post_processed_data = {
-                 "name": name,
-                 "create_artifact": [
-                     {
-                         "variables": request_data["variables"],
-                         "response": response_data["data"]["createArtifact"]["artifact"],
-                     }
-                 ],
-             }
-             return post_processed_data
-         return None
-
-     def resolve(
-         self,
-         request_data: Dict[str, Any],
-         response_data: Dict[str, Any],
-         **kwargs: Any,
-     ) -> Optional[Dict[str, Any]]:
-         results = []
-         for resolver in self.resolvers:
-             result = resolver.get("resolver")(request_data, response_data, **kwargs)
-             if result is not None:
-                 results.append(result)
-         return results
-
-
- class TokenizedCircularPattern:
-     APPLY_TOKEN = "1"
-     PASS_TOKEN = "0"
-     STOP_TOKEN = "2"
-
-     def __init__(self, pattern: str):
-         known_tokens = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}
-         if not pattern:
-             raise ValueError("Pattern cannot be empty")
-
-         if set(pattern) - known_tokens:
-             raise ValueError(f"Pattern can only contain {known_tokens}")
-         self.pattern: Deque[str] = deque(pattern)
-
-     def next(self):
-         if self.pattern[0] == self.STOP_TOKEN:
-             return
-         self.pattern.rotate(-1)
-
-     def should_apply(self) -> bool:
-         return self.pattern[0] == self.APPLY_TOKEN
-
-
- @dataclasses.dataclass
- class InjectedResponse:
-     method: str
-     url: str
-     body: Union[str, Exception]
-     status: int = 200
-     content_type: str = "text/plain"
-     # todo: add more fields for other types of responses?
-     custom_match_fn: Optional[Callable[..., bool]] = None
-     application_pattern: TokenizedCircularPattern = TokenizedCircularPattern("1")
-
-     # application_pattern defines the pattern of the response injection
-     # as the requests come in.
-     # 0 == do not inject the response
-     # 1 == inject the response
-     # 2 == stop using the response (END token)
-     #
-     # - when no END token is present, the pattern is repeated indefinitely
-     # - when END token is present, the pattern is applied until the END token is reached
-     # - to replicate the current behavior:
-     #   - use application_pattern = "1" if wanting to apply the pattern to all requests
-     #   - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
-     #
-     # Examples of application_pattern:
-     # 1. application_pattern = "1012"
-     #    - inject the response for the first request
-     #    - do not inject the response for the second request
-     #    - inject the response for the third request
-     #    - stop using the response starting from the fourth request onwards
-     # 2. application_pattern = "110"
-     #    repeat the following pattern indefinitely:
-     #    - inject the response for the first request
-     #    - inject the response for the second request
-     #    - stop using the response for the third request
-
-     def __eq__(
-         self,
-         other: Union["InjectedResponse", requests.Request, requests.PreparedRequest],
-     ):
-         """Check InjectedResponse object equality.
-
-         We use this to check if this response should be injected as a replacement of
-         `other`.
-
-         :param other:
-         :return:
-         """
-         if not isinstance(
-             other, (InjectedResponse, requests.Request, requests.PreparedRequest)
-         ):
-             return False
-
-         # always check the method and url
-         ret = self.method == other.method and self.url == other.url
-         # use custom_match_fn to check, e.g. the request body content
-         if ret and self.custom_match_fn is not None:
-             ret = self.custom_match_fn(self, other)
-         return ret
-
-     def to_dict(self):
-         excluded_fields = {"application_pattern", "custom_match_fn"}
-         return {
-             k: self.__getattribute__(k)
-             for k in self.__dict__
-             if (not k.startswith("_") and k not in excluded_fields)
-         }
-
-
- class RelayControlProtocol(Protocol):
-     def process(self, request: "flask.Request") -> None: ...  # pragma: no cover
-
-     def control(
-         self, request: "flask.Request"
-     ) -> Mapping[str, str]: ...  # pragma: no cover
-
-
- class RelayServer:
-     def __init__(
-         self,
-         base_url: str,
-         inject: Optional[List[InjectedResponse]] = None,
-         control: Optional[RelayControlProtocol] = None,
-         verbose: bool = False,
-     ) -> None:
-         # todo for the future:
-         # - consider switching from Flask to Quart
-         # - async app will allow for better failure injection/poor network perf
-         self.relay_control = control
-         self.app = flask.Flask(__name__)
-         self.app.logger.setLevel(logging.INFO)
-         self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
-         self.app.add_url_rule(
-             rule="/graphql",
-             endpoint="graphql",
-             view_func=self.graphql,
-             methods=["POST"],
-         )
-         self.app.add_url_rule(
-             rule="/files/<path:path>",
-             endpoint="files",
-             view_func=self.file_stream,
-             methods=["POST"],
-         )
-         self.app.add_url_rule(
-             rule="/storage",
-             endpoint="storage",
-             view_func=self.storage,
-             methods=["PUT", "GET"],
-         )
-         self.app.add_url_rule(
-             rule="/storage/<path:path>",
-             endpoint="storage_file",
-             view_func=self.storage_file,
-             methods=["PUT", "GET"],
-         )
-         if control:
-             self.app.add_url_rule(
-                 rule="/_control",
-                 endpoint="_control",
-                 view_func=self.control,
-                 methods=["POST"],
-             )
-         # @app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])
-         self.port = self._get_free_port()
-         self.base_url = urllib.parse.urlparse(base_url)
-         self.session = requests.Session()
-         self.relay_url = f"http://127.0.0.1:{self.port}"
-
-         # todo: add an option to add custom resolvers
-         self.resolver = QueryResolver()
-         # recursively merge-able object to store state
-         self.context = Context()
-
-         # injected responses
-         self.inject = inject or []
-
-         # useful when debugging:
-         # self.after_request_fn = self.app.after_request(self.after_request_fn)
-         self.verbose = verbose
-
-     @staticmethod
-     def handle_http_exception(e):
-         response = e.get_response()
-         return response
-
-     @staticmethod
-     def _get_free_port() -> int:
-         sock = socket.socket()
-         sock.bind(("", 0))
-
-         _, port = sock.getsockname()
-         return port
-
-     def start(self) -> None:
-         # run server in a separate thread
-         relay_server_thread = threading.Thread(
-             target=self.app.run,
-             kwargs={"port": self.port},
-             daemon=True,
-         )
-         relay_server_thread.start()
-
-     def after_request_fn(self, response: "requests.Response") -> "requests.Response":
-         # todo: this is useful for debugging, but should be removed in the future
-         # flask.request.url = self.relay_url + flask.request.url
-         print(flask.request)
-         print(flask.request.get_json())
-         print(response)
-         print(response.json())
-         return response
-
-     def relay(
-         self,
-         request: "flask.Request",
-     ) -> Union["responses.Response", "requests.Response", None]:
-         # replace the relay url with the real backend url (self.base_url)
-         url = (
-             urllib.parse.urlparse(request.url)
-             ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
-             .geturl()
-         )
-         headers = {key: value for (key, value) in request.headers if key != "Host"}
-         prepared_relayed_request = requests.Request(
-             method=request.method,
-             url=url,
-             headers=headers,
-             data=request.get_data(),
-             json=request.get_json(),
-         ).prepare()
-
-         if self.verbose:
-             print("*****************")
-             print("RELAY REQUEST:")
-             print(prepared_relayed_request.url)
-             print(prepared_relayed_request.method)
-             print(prepared_relayed_request.headers)
-             print(prepared_relayed_request.body)
-             print("*****************")
-
-         for injected_response in self.inject:
-             # where are we in the application pattern?
-             should_apply = injected_response.application_pattern.should_apply()
-             # check if an injected response matches the request
-             if injected_response != prepared_relayed_request or not should_apply:
-                 continue
-
-             if self.verbose:
-                 print("*****************")
-                 print("INJECTING RESPONSE:")
-                 print(injected_response.to_dict())
-                 print("*****************")
-             # rotate the injection pattern
-             injected_response.application_pattern.next()
-
-             # TODO: allow access to the request object when making the mocked response
-             with responses.RequestsMock() as mocked_responses:
-                 # do the actual injection
-                 resp = injected_response.to_dict()
-
-                 if isinstance(resp["body"], ConnectionResetError):
-                     return None
-
-                 mocked_responses.add(**resp)
-                 relayed_response = self.session.send(prepared_relayed_request)
-
-                 return relayed_response
-
-         # normal case: no injected response matches the request
-         relayed_response = self.session.send(prepared_relayed_request)
-         return relayed_response
-
-     def snoop_context(
-         self,
-         request: "flask.Request",
-         response: "requests.Response",
-         time_elapsed: float,
-         **kwargs: Any,
-     ) -> None:
-         request_data = request.get_json()
-         response_data = response.json() or {}
-
-         if self.relay_control:
-             self.relay_control.process(request)
-
-         # store raw data
-         raw_data: RawRequestResponse = {
-             "url": request.url,
-             "request": request_data,
-             "response": response_data,
-             "time_elapsed": time_elapsed,
-         }
-         self.context.raw_data.append(raw_data)
-
-         try:
-             snooped_context = self.resolver.resolve(
-                 request_data,
-                 response_data,
-                 **kwargs,
-             )
-             for entry in snooped_context:
-                 self.context.upsert(entry)
-         except Exception as e:
-             print("Failed to resolve context: ", e)
-             traceback.print_exc()
-             snooped_context = None
-
-         return None
-
-     def graphql(self) -> Mapping[str, str]:
-         request = flask.request
-         with Timer() as timer:
-             relayed_response = self.relay(request)
-         if self.verbose:
-             print("*****************")
-             print("GRAPHQL REQUEST:")
-             print(request.get_json())
-             print("GRAPHQL RESPONSE:")
-             print(relayed_response.status_code, relayed_response.json())
-             print("*****************")
-         # snoop work to extract the context
-         self.snoop_context(request, relayed_response, timer.elapsed)
-         if self.verbose:
-             print("*****************")
-             print("SNOOPED CONTEXT:")
-             print(self.context.entries)
-             print(len(self.context.raw_data))
-             print("*****************")
-
-         return relayed_response.json()
-
-     def file_stream(self, path) -> Mapping[str, str]:
-         request = flask.request
-
-         with Timer() as timer:
-             relayed_response = self.relay(request)
-
-         # simulate connection reset by peer
-         if relayed_response is None:
-             connection = request.environ["werkzeug.socket"]  # Get the socket object
-             connection.shutdown(socket.SHUT_RDWR)
-             connection.close()
-
-         if self.verbose:
-             print("*****************")
-             print("FILE STREAM REQUEST:")
-             print("********PATH*********")
-             print(path)
-             print("********ENDPATH*********")
-             print(request.get_json())
-             print("FILE STREAM RESPONSE:")
-             print(relayed_response)
-             print(relayed_response.status_code, relayed_response.json())
-             print("*****************")
-         self.snoop_context(request, relayed_response, timer.elapsed, path=path)
-
-         return relayed_response.json()
-
-     def storage(self) -> Mapping[str, str]:
-         request = flask.request
-         with Timer() as timer:
-             relayed_response = self.relay(request)
-         if self.verbose:
-             print("*****************")
-             print("STORAGE REQUEST:")
-             print(request.get_json())
-             print("STORAGE RESPONSE:")
-             print(relayed_response.status_code, relayed_response.json())
-             print("*****************")
-
-         self.snoop_context(request, relayed_response, timer.elapsed)
-
-         return relayed_response.json()
-
-     def storage_file(self, path) -> Mapping[str, str]:
-         request = flask.request
-         with Timer() as timer:
-             relayed_response = self.relay(request)
-         if self.verbose:
-             print("*****************")
-             print("STORAGE FILE REQUEST:")
-             print("********PATH*********")
-             print(path)
-             print("********ENDPATH*********")
-             print(request.get_json())
-             print("STORAGE FILE RESPONSE:")
-             print(relayed_response.json())
-             print("*****************")
-
-         self.snoop_context(request, relayed_response, timer.elapsed, path=path)
-
-         return relayed_response.json()
-
-     def control(self) -> Mapping[str, str]:
-         assert self.relay_control
-         return self.relay_control.control(flask.request)
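
The comment block on InjectedResponse.application_pattern in the removed file describes a token-driven injection schedule ("1" = inject, "0" = pass through, "2" = stop). A minimal sketch of how that schedule plays out, using a trimmed copy of the removed TokenizedCircularPattern; the driver loop and printed values below are illustrative, not part of the package:

from collections import deque

class TokenizedCircularPattern:
    # trimmed copy of the deleted class, for illustration only
    APPLY_TOKEN, PASS_TOKEN, STOP_TOKEN = "1", "0", "2"

    def __init__(self, pattern: str):
        self.pattern = deque(pattern)

    def next(self):
        # freeze on the stop token; otherwise rotate to the next token
        if self.pattern[0] != self.STOP_TOKEN:
            self.pattern.rotate(-1)

    def should_apply(self) -> bool:
        return self.pattern[0] == self.APPLY_TOKEN

schedule = TokenizedCircularPattern("1012")
decisions = []
for _ in range(6):
    decisions.append(schedule.should_apply())
    schedule.next()
print(decisions)  # [True, False, True, False, False, False]

This matches the documented semantics: inject on the first and third requests, skip the second, and stop injecting from the fourth request onwards.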
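For readers wondering how the removed RelayServer was driven, here is a hypothetical wiring sketch based only on the constructor and methods shown above; the backend URL, status code, and environment-variable wiring are assumptions, not taken from the deleted file:

# Assumes a local copy of the deleted module importable as `relay`.
from relay import InjectedResponse, RelayServer, TokenizedCircularPattern

relay_server = RelayServer(
    base_url="https://api.wandb.ai",  # assumed backend to proxy
    inject=[
        InjectedResponse(
            method="POST",
            url="https://api.wandb.ai/graphql",  # compared against the relayed request URL
            body="rate limit exceeded",
            status=429,
            # fail the first two matching requests, then stop injecting
            application_pattern=TokenizedCircularPattern("112"),
        )
    ],
)
relay_server.start()  # serves on relay_server.relay_url in a daemon thread

# A test would then point the SDK at relay_server.relay_url (for example via
# the WANDB_BASE_URL environment variable) and, after the run finishes,
# assert against the snooped state, e.g.:
#   relay_server.context.get_run_summary(run_id)
#   relay_server.context.get_run_telemetry(run_id)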