wandb 0.17.8rc1__py3-none-macosx_11_0_arm64.whl → 0.17.9__py3-none-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44) hide show
  1. package_readme.md +47 -53
  2. wandb/__init__.py +12 -6
  3. wandb/__init__.pyi +112 -2
  4. wandb/bin/apple_gpu_stats +0 -0
  5. wandb/bin/wandb-core +0 -0
  6. wandb/data_types.py +1 -0
  7. wandb/env.py +13 -0
  8. wandb/integration/keras/__init__.py +2 -5
  9. wandb/integration/keras/callbacks/metrics_logger.py +10 -4
  10. wandb/integration/keras/callbacks/model_checkpoint.py +0 -5
  11. wandb/integration/keras/keras.py +12 -1
  12. wandb/integration/openai/fine_tuning.py +5 -5
  13. wandb/integration/tensorboard/log.py +1 -1
  14. wandb/proto/v3/wandb_internal_pb2.py +31 -21
  15. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  16. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  17. wandb/proto/v4/wandb_internal_pb2.py +23 -21
  18. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  20. wandb/proto/v5/wandb_internal_pb2.py +23 -21
  21. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  22. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  23. wandb/proto/wandb_deprecated.py +4 -0
  24. wandb/sdk/__init__.py +1 -1
  25. wandb/sdk/artifacts/artifact.py +9 -11
  26. wandb/sdk/artifacts/artifact_manifest_entry.py +10 -2
  27. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +31 -0
  28. wandb/sdk/internal/system/assets/trainium.py +2 -1
  29. wandb/sdk/internal/tb_watcher.py +1 -1
  30. wandb/sdk/lib/_settings_toposort_generated.py +5 -3
  31. wandb/sdk/service/service.py +7 -2
  32. wandb/sdk/wandb_init.py +5 -1
  33. wandb/sdk/wandb_manager.py +0 -3
  34. wandb/sdk/wandb_require.py +22 -1
  35. wandb/sdk/wandb_run.py +14 -4
  36. wandb/sdk/wandb_settings.py +32 -10
  37. wandb/sdk/wandb_setup.py +3 -0
  38. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/METADATA +48 -54
  39. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/RECORD +43 -44
  40. wandb/testing/relay.py +0 -874
  41. /wandb/{viz.py → sdk/lib/viz.py} +0 -0
  42. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/WHEEL +0 -0
  43. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/entry_points.txt +0 -0
  44. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/licenses/LICENSE +0 -0
wandb/testing/relay.py DELETED
@@ -1,874 +0,0 @@
1
- import dataclasses
2
- import json
3
- import logging
4
- import socket
5
- import sys
6
- import threading
7
- import traceback
8
- import urllib.parse
9
- from collections import defaultdict, deque
10
- from copy import deepcopy
11
- from typing import (
12
- TYPE_CHECKING,
13
- Any,
14
- Callable,
15
- Dict,
16
- Iterable,
17
- List,
18
- Mapping,
19
- Optional,
20
- Union,
21
- )
22
-
23
- import flask
24
- import pandas as pd
25
- import requests
26
- import responses
27
-
28
- import wandb
29
- import wandb.util
30
- from wandb.sdk.lib.timer import Timer
31
-
32
- try:
33
- from typing import Literal, TypedDict
34
- except ImportError:
35
- from typing_extensions import Literal, TypedDict
36
-
37
- if sys.version_info >= (3, 8):
38
- from typing import Protocol
39
- else:
40
- from typing_extensions import Protocol
41
-
42
if TYPE_CHECKING:
    from typing import Deque

    class RawRequestResponse(TypedDict):
        """Schema of one raw request/response pair captured by the relay."""

        url: str
        request: Optional[Any]
        response: Dict[str, Any]
        time_elapsed: float  # seconds

    # Fix: include every resolver name that QueryResolver actually registers.
    # "uploaded_files_legacy", "create_artifact" and "delete_run" were missing,
    # so the Resolver TypedDict below mis-typed three of the registered entries.
    ResolverName = Literal[
        "upsert_bucket",
        "upload_files",
        "uploaded_files",
        "uploaded_files_legacy",
        "preempting",
        "upsert_sweep",
        "create_artifact",
        "delete_run",
    ]

    class Resolver(TypedDict):
        """A named request/response resolver as stored in QueryResolver.resolvers."""

        name: ResolverName
        resolver: Callable[[Any], Optional[Dict[str, Any]]]
62
-
63
-
64
class DeliberateHTTPError(Exception):
    """An HTTP error raised on purpose to exercise failure paths in tests."""

    def __init__(self, message, status_code: int = 500):
        super().__init__()
        self.message = message
        self.status_code = status_code

    def get_response(self):
        """Render this error as a Flask response carrying the configured status."""
        return flask.Response(self.message, status=self.status_code)

    def __repr__(self):
        return "DeliberateHTTPError({!r}, {!r})".format(self.message, self.status_code)
75
-
76
-
77
@dataclasses.dataclass
class RunAttrs:
    """Simple data class for run attributes.

    Populated by Context.get_run_attrs from a snooped run entry; field values
    map 1:1 onto the entry keys "name", "displayName", "description",
    "sweepName", "project", "config", "repo" and "commit".
    """

    # run id/name as stored under the entry's "name" key
    name: str
    # UI display name ("displayName" in the entry)
    display_name: str
    description: str
    sweep_name: str
    project: Dict[str, Any]
    config: Dict[str, Any]
    # git remote ("repo" key) — absent entries yield None
    remote: Optional[str] = None
    # git commit — absent entries yield None
    commit: Optional[str] = None
89
-
90
-
91
- class Context:
92
- """A container used to store the snooped state/data of a test.
93
-
94
- Includes raw requests and responses, parsed and processed data, and a number of
95
- convenience methods and properties for accessing the data.
96
- """
97
-
98
- def __init__(self) -> None:
99
- # parsed/merged data. keys are the individual wandb run id's.
100
- self._entries = defaultdict(dict)
101
- # container for raw requests and responses:
102
- self.raw_data: List[RawRequestResponse] = []
103
- # concatenated file contents for all runs:
104
- self._history: Optional[pd.DataFrame] = None
105
- self._events: Optional[pd.DataFrame] = None
106
- self._summary: Optional[pd.DataFrame] = None
107
- self._config: Optional[Dict[str, Any]] = None
108
- self._output: Optional[Any] = None
109
-
110
- def upsert(self, entry: Dict[str, Any]) -> None:
111
- try:
112
- entry_id: str = entry["name"]
113
- except KeyError:
114
- entry_id = entry["id"]
115
- self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])
116
-
117
- # mapping interface
118
- def __getitem__(self, key: str) -> Any:
119
- return self._entries[key]
120
-
121
- def keys(self) -> Iterable[str]:
122
- return self._entries.keys()
123
-
124
- def get_file_contents(self, file_name: str) -> pd.DataFrame:
125
- dfs = []
126
-
127
- for entry_id in self._entries:
128
- # - extract the content from `file_name`
129
- # - sort by offset (will be useful when relay server goes async)
130
- # - extract data, merge into a list of dicts and convert to a pandas dataframe
131
- content_list = self._entries[entry_id].get("files", {}).get(file_name, [])
132
- content_list.sort(key=lambda x: x["offset"])
133
- content_list = [item["content"] for item in content_list]
134
- # merge list of lists content_list:
135
- content_list = [item for sublist in content_list for item in sublist]
136
- df = pd.DataFrame.from_records(content_list)
137
- df["__run_id"] = entry_id
138
- dfs.append(df)
139
-
140
- return pd.concat(dfs)
141
-
142
- # attributes to use in assertions
143
- @property
144
- def entries(self) -> Dict[str, Any]:
145
- return deepcopy(self._entries)
146
-
147
- @property
148
- def history(self) -> pd.DataFrame:
149
- # todo: caveat: this assumes that all assertions happen at the end of a test
150
- if self._history is not None:
151
- return deepcopy(self._history)
152
-
153
- self._history = self.get_file_contents("wandb-history.jsonl")
154
- return deepcopy(self._history)
155
-
156
- @property
157
- def events(self) -> pd.DataFrame:
158
- if self._events is not None:
159
- return deepcopy(self._events)
160
-
161
- self._events = self.get_file_contents("wandb-events.jsonl")
162
- return deepcopy(self._events)
163
-
164
- @property
165
- def summary(self) -> pd.DataFrame:
166
- if self._summary is not None:
167
- return deepcopy(self._summary)
168
-
169
- _summary = self.get_file_contents("wandb-summary.json")
170
-
171
- # run summary may be updated multiple times,
172
- # but we are only interested in the last one.
173
- # we can have multiple runs saved to context,
174
- # so we need to group by run id and take the
175
- # last one for each run.
176
- self._summary = (
177
- _summary.groupby("__run_id").last().reset_index(level=["__run_id"])
178
- )
179
-
180
- return deepcopy(self._summary)
181
-
182
- @property
183
- def output(self) -> pd.DataFrame:
184
- if self._output is not None:
185
- return deepcopy(self._output)
186
-
187
- self._output = self.get_file_contents("output.log")
188
- return deepcopy(self._output)
189
-
190
- @property
191
- def config(self) -> Dict[str, Any]:
192
- if self._config is not None:
193
- return deepcopy(self._config)
194
-
195
- self._config = {k: v["config"] for (k, v) in self._entries.items()}
196
- return deepcopy(self._config)
197
-
198
- # @property
199
- # def telemetry(self) -> pd.DataFrame:
200
- # telemetry = pd.DataFrame.from_records(
201
- # [
202
- # {
203
- # "__run_id": run_id,
204
- # "telemetry": config.get("_wandb", {}).get("value", {}).get("t")
205
- # }
206
- # for (run_id, config) in self.config.items()
207
- # ]
208
- # )
209
- # return telemetry
210
-
211
- # convenience data access methods
212
- def get_run_telemetry(self, run_id: str) -> Dict[str, Any]:
213
- return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")
214
-
215
- def get_run_metrics(self, run_id: str) -> Dict[str, Any]:
216
- return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")
217
-
218
- def get_run_summary(
219
- self, run_id: str, include_private: bool = False
220
- ) -> Dict[str, Any]:
221
- # run summary dataframe must have only one row
222
- # for the given run id, so we convert it to dict
223
- # and extract the first (and only) row.
224
- mask_run = self.summary["__run_id"] == run_id
225
- run_summary = self.summary[mask_run]
226
- ret = (
227
- run_summary.filter(regex="^[^_]", axis=1)
228
- if not include_private
229
- else run_summary
230
- ).to_dict(orient="records")
231
- return ret[0] if len(ret) > 0 else {}
232
-
233
- def get_run_history(
234
- self, run_id: str, include_private: bool = False
235
- ) -> pd.DataFrame:
236
- mask_run = self.history["__run_id"] == run_id
237
- run_history = self.history[mask_run]
238
- return (
239
- run_history.filter(regex="^[^_]", axis=1)
240
- if not include_private
241
- else run_history
242
- )
243
-
244
- def get_run_uploaded_files(self, run_id: str) -> Dict[str, Any]:
245
- return self.entries.get(run_id, {}).get("uploaded", [])
246
-
247
- def get_run_stats(self, run_id: str) -> pd.DataFrame:
248
- mask_run = self.events["__run_id"] == run_id
249
- run_stats = self.events[mask_run]
250
- return run_stats
251
-
252
- def get_run_attrs(self, run_id: str) -> Optional[RunAttrs]:
253
- run_entry = self._entries.get(run_id)
254
- if not run_entry:
255
- return None
256
-
257
- return RunAttrs(
258
- name=run_entry["name"],
259
- display_name=run_entry["displayName"],
260
- description=run_entry["description"],
261
- sweep_name=run_entry["sweepName"],
262
- project=run_entry["project"],
263
- config=run_entry["config"],
264
- remote=run_entry.get("repo"),
265
- commit=run_entry.get("commit"),
266
- )
267
-
268
- def get_run(self, run_id: str) -> Dict[str, Any]:
269
- return self._entries.get(run_id, {})
270
-
271
- # todo: add getter (by run_id) utilities for other properties
272
-
273
-
274
class QueryResolver:
    """Resolve request/response pairs against a set of known patterns.

    This extracts and processes useful data to be later stored in a Context object.
    Each resolver inspects one request/response pair and returns a dict to be
    upserted into the Context, or None when the pair does not match its pattern.
    """

    def __init__(self):
        # Registered resolvers, tried in order by resolve().
        self.resolvers: List[Resolver] = [
            {
                "name": "upsert_bucket",
                "resolver": self.resolve_upsert_bucket,
            },
            {
                "name": "upload_files",
                "resolver": self.resolve_upload_files,
            },
            {
                "name": "uploaded_files",
                "resolver": self.resolve_uploaded_files,
            },
            {
                "name": "uploaded_files_legacy",
                "resolver": self.resolve_uploaded_files_legacy,
            },
            {
                "name": "preempting",
                "resolver": self.resolve_preempting,
            },
            {
                "name": "upsert_sweep",
                "resolver": self.resolve_upsert_sweep,
            },
            {
                "name": "create_artifact",
                "resolver": self.resolve_create_artifact,
            },
            {
                "name": "delete_run",
                "resolver": self.resolve_delete_run,
            },
        ]

    @staticmethod
    def resolve_upsert_bucket(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match an upsertBucket GraphQL response; merge request variables with the bucket."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        query = response_data.get("data", {}).get("upsertBucket") is not None
        if query:
            # keep only the non-None request variables
            data = {
                k: v for (k, v) in request_data["variables"].items() if v is not None
            }
            data.update(response_data["data"]["upsertBucket"].get("bucket"))
            if "config" in data:
                # config arrives as a JSON-encoded string; store it parsed
                data["config"] = json.loads(data["config"])
            return data
        return None

    @staticmethod
    def resolve_delete_run(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match a deleteRun GraphQL mutation; merge variables with the response payload."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        query = "query" in request_data and "deleteRun" in request_data["query"]
        if query:
            data = {
                k: v for (k, v) in request_data["variables"].items() if v is not None
            }
            data.update(response_data["data"]["deleteRun"])
            return data
        return None

    @staticmethod
    def resolve_upload_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match a file-stream upload; group file contents (with offsets) by file name."""
        if not isinstance(request_data, dict):
            return None

        query = request_data.get("files") is not None
        if query:
            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
            # run id is the third segment of the file-stream path
            # (kwargs["path"] is set by RelayServer.file_stream)
            name = kwargs.get("path").split("/")[2]
            files = defaultdict(list)
            for file_name, file_value in request_data["files"].items():
                content = []
                for k in file_value.get("content", []):
                    try:
                        content.append(json.loads(k))
                    except json.decoder.JSONDecodeError:
                        # non-JSON lines (e.g. console output) are kept verbatim
                        content.append([k])

                files[file_name].append(
                    {"offset": file_value.get("offset"), "content": content}
                )

            post_processed_data = {
                "name": name,
                "dropped": [request_data["dropped"]]
                if "dropped" in request_data
                else [],
                "files": files,
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_uploaded_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match a CreateRunFiles GraphQL query; record the uploaded file names."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None

        query = "CreateRunFiles" in request_data.get("query", "")
        if query:
            run_name = request_data["variables"]["run"]
            files = ((response_data.get("data") or {}).get("createRunFiles") or {}).get(
                "files", {}
            )
            post_processed_data = {
                "name": run_name,
                "uploaded": [file["name"] for file in files] if files else [""],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_uploaded_files_legacy(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        # This is a legacy resolver for uploaded files
        # No longer used by tests but leaving it here in case we need it in the future
        # Please refer to upload_urls() in internal_api.py for more details
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None

        query = "RunUploadUrls" in request_data.get("query", "")
        if query:
            # todo: refactor this 🤮🤮🤮🤮🤮 eventually?
            name = request_data["variables"]["run"]
            files = (
                response_data.get("data", {})
                .get("model", {})
                .get("bucket", {})
                .get("files", {})
                .get("edges", [])
            )
            # note: we count all attempts to upload files
            post_processed_data = {
                "name": name,
                "uploaded": [files[0].get("node", {}).get("name")] if files else [""],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_preempting(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match a preempting file-stream notification; record the flag per run."""
        if not isinstance(request_data, dict):
            return None
        query = "preempting" in request_data
        if query:
            # run id from the file-stream path (see resolve_upload_files)
            name = kwargs.get("path").split("/")[2]
            post_processed_data = {
                "name": name,
                "preempting": [request_data["preempting"]],
            }
            return post_processed_data
        return None

    @staticmethod
    def resolve_upsert_sweep(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match an upsertSweep GraphQL response; return the sweep payload as-is."""
        if not isinstance(response_data, dict):
            return None
        query = response_data.get("data", {}).get("upsertSweep") is not None
        if query:
            data = response_data["data"]["upsertSweep"].get("sweep")
            return data
        return None

    def resolve_create_artifact(
        self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Match a createArtifact GraphQL mutation; record variables and artifact response."""
        if not isinstance(request_data, dict):
            return None
        query = (
            "createArtifact(" in request_data.get("query", "")
            and request_data.get("variables") is not None
            and response_data is not None
        )
        if query:
            name = request_data["variables"]["runName"]
            post_processed_data = {
                "name": name,
                "create_artifact": [
                    {
                        "variables": request_data["variables"],
                        "response": response_data["data"]["createArtifact"]["artifact"],
                    }
                ],
            }
            return post_processed_data
        return None

    def resolve(
        self,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
        **kwargs: Any,
    ) -> Optional[Dict[str, Any]]:
        """Run every registered resolver and collect the non-None results.

        NOTE(review): despite the Optional[Dict] annotation, this returns a
        list (possibly empty) of resolver results — callers iterate over it.
        """
        results = []
        for resolver in self.resolvers:
            result = resolver.get("resolver")(request_data, response_data, **kwargs)
            if result is not None:
                results.append(result)
        return results
495
-
496
-
497
class TokenizedCircularPattern:
    """A circular sequence of single-character tokens driving response injection.

    Tokens: "1" = apply, "0" = pass, "2" = stop. The pattern rotates one token
    per request; once a stop token reaches the front, the pattern freezes.
    """

    APPLY_TOKEN = "1"
    PASS_TOKEN = "0"
    STOP_TOKEN = "2"

    def __init__(self, pattern: str):
        known_tokens = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}
        if not pattern:
            raise ValueError("Pattern cannot be empty")

        if set(pattern) - known_tokens:
            raise ValueError(f"Pattern can only contain {known_tokens}")
        self.pattern: Deque[str] = deque(pattern)

    def next(self):
        """Rotate to the next token, unless the pattern is frozen on a stop token."""
        if self.pattern[0] != self.STOP_TOKEN:
            self.pattern.rotate(-1)

    def should_apply(self) -> bool:
        """Report whether the token currently at the front requests injection."""
        return self.pattern[0] == self.APPLY_TOKEN
518
-
519
-
520
- @dataclasses.dataclass
521
- class InjectedResponse:
522
- method: str
523
- url: str
524
- body: Union[str, Exception]
525
- status: int = 200
526
- content_type: str = "text/plain"
527
- # todo: add more fields for other types of responses?
528
- custom_match_fn: Optional[Callable[..., bool]] = None
529
- application_pattern: TokenizedCircularPattern = TokenizedCircularPattern("1")
530
-
531
- # application_pattern defines the pattern of the response injection
532
- # as the requests come in.
533
- # 0 == do not inject the response
534
- # 1 == inject the response
535
- # 2 == stop using the response (END token)
536
- #
537
- # - when no END token is present, the pattern is repeated indefinitely
538
- # - when END token is present, the pattern is applied until the END token is reached
539
- # - to replicate the current behavior:
540
- # - use application_pattern = "1" if wanting to apply the pattern to all requests
541
- # - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
542
- #
543
- # Examples of application_pattern:
544
- # 1. application_pattern = "1012"
545
- # - inject the response for the first request
546
- # - do not inject the response for the second request
547
- # - inject the response for the third request
548
- # - stop using the response starting from the fourth request onwards
549
- # 2. application_pattern = "110"
550
- # repeat the following pattern indefinitely:
551
- # - inject the response for the first request
552
- # - inject the response for the second request
553
- # - stop using the response for the third request
554
-
555
- def __eq__(
556
- self,
557
- other: Union["InjectedResponse", requests.Request, requests.PreparedRequest],
558
- ):
559
- """Check InjectedResponse object equality.
560
-
561
- We use this to check if this response should be injected as a replacement of
562
- `other`.
563
-
564
- :param other:
565
- :return:
566
- """
567
- if not isinstance(
568
- other, (InjectedResponse, requests.Request, requests.PreparedRequest)
569
- ):
570
- return False
571
-
572
- # always check the method and url
573
- ret = self.method == other.method and self.url == other.url
574
- # use custom_match_fn to check, e.g. the request body content
575
- if ret and self.custom_match_fn is not None:
576
- ret = self.custom_match_fn(self, other)
577
- return ret
578
-
579
- def to_dict(self):
580
- excluded_fields = {"application_pattern", "custom_match_fn"}
581
- return {
582
- k: self.__getattribute__(k)
583
- for k in self.__dict__
584
- if (not k.startswith("_") and k not in excluded_fields)
585
- }
586
-
587
-
588
class RelayControlProtocol(Protocol):
    """Structural interface for a test-side controller plugged into RelayServer.

    `process` is called on every snooped request; `control` handles POSTs to
    the relay's /_control endpoint.
    """

    def process(self, request: "flask.Request") -> None: ...  # pragma: no cover

    def control(
        self, request: "flask.Request"
    ) -> Mapping[str, str]: ...  # pragma: no cover
594
-
595
-
596
class RelayServer:
    """A local Flask proxy that forwards wandb traffic to a real backend.

    Snoops every request/response pair into a Context, and can inject canned
    responses (InjectedResponse) in place of real backend replies.
    """

    def __init__(
        self,
        base_url: str,
        inject: Optional[List[InjectedResponse]] = None,
        control: Optional[RelayControlProtocol] = None,
        verbose: bool = False,
    ) -> None:
        # todo for the future:
        #  - consider switching from Flask to Quart
        #  - async app will allow for better failure injection/poor network perf
        self.relay_control = control
        self.app = flask.Flask(__name__)
        self.app.logger.setLevel(logging.INFO)
        self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
        self.app.add_url_rule(
            rule="/graphql",
            endpoint="graphql",
            view_func=self.graphql,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/files/<path:path>",
            endpoint="files",
            view_func=self.file_stream,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/storage",
            endpoint="storage",
            view_func=self.storage,
            methods=["PUT", "GET"],
        )
        self.app.add_url_rule(
            rule="/storage/<path:path>",
            endpoint="storage_file",
            view_func=self.storage_file,
            methods=["PUT", "GET"],
        )
        if control:
            self.app.add_url_rule(
                rule="/_control",
                endpoint="_control",
                view_func=self.control,
                methods=["POST"],
            )
        # @app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])
        self.port = self._get_free_port()
        # real backend the relay forwards to
        self.base_url = urllib.parse.urlparse(base_url)
        self.session = requests.Session()
        self.relay_url = f"http://127.0.0.1:{self.port}"

        # todo: add an option to add custom resolvers
        self.resolver = QueryResolver()
        # recursively merge-able object to store state
        self.context = Context()

        # injected responses
        self.inject = inject or []

        # useful when debugging:
        # self.after_request_fn = self.app.after_request(self.after_request_fn)
        self.verbose = verbose

    @staticmethod
    def handle_http_exception(e):
        # Flask error handler for DeliberateHTTPError: pass the canned response through.
        response = e.get_response()
        return response

    @staticmethod
    def _get_free_port() -> int:
        # Bind to port 0 so the OS picks a free port, then read it back.
        # NOTE(review): the socket is not closed/reused here, so there is a
        # small race before Flask binds the port — acceptable for tests.
        sock = socket.socket()
        sock.bind(("", 0))

        _, port = sock.getsockname()
        return port

    def start(self) -> None:
        # run server in a separate thread
        relay_server_thread = threading.Thread(
            target=self.app.run,
            kwargs={"port": self.port},
            daemon=True,
        )
        relay_server_thread.start()

    def after_request_fn(self, response: "requests.Response") -> "requests.Response":
        # todo: this is useful for debugging, but should be removed in the future
        # flask.request.url = self.relay_url + flask.request.url
        print(flask.request)
        print(flask.request.get_json())
        print(response)
        print(response.json())
        return response

    def relay(
        self,
        request: "flask.Request",
    ) -> Union["responses.Response", "requests.Response", None]:
        """Forward `request` to the real backend, possibly injecting a canned response.

        Returns None when an injected body is a ConnectionResetError, which the
        caller uses to simulate a dropped connection.
        """
        # replace the relay url with the real backend url (self.base_url)
        url = (
            urllib.parse.urlparse(request.url)
            ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
            .geturl()
        )
        headers = {key: value for (key, value) in request.headers if key != "Host"}
        prepared_relayed_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=request.get_data(),
            json=request.get_json(),
        ).prepare()

        if self.verbose:
            print("*****************")
            print("RELAY REQUEST:")
            print(prepared_relayed_request.url)
            print(prepared_relayed_request.method)
            print(prepared_relayed_request.headers)
            print(prepared_relayed_request.body)
            print("*****************")

        for injected_response in self.inject:
            # where are we in the application pattern?
            should_apply = injected_response.application_pattern.should_apply()
            # check if an injected response matches the request
            if injected_response != prepared_relayed_request or not should_apply:
                continue

            if self.verbose:
                print("*****************")
                print("INJECTING RESPONSE:")
                print(injected_response.to_dict())
                print("*****************")
            # rotate the injection pattern
            injected_response.application_pattern.next()

            # TODO: allow access to the request object when making the mocked response
            with responses.RequestsMock() as mocked_responses:
                # do the actual injection
                resp = injected_response.to_dict()

                # connection-reset injection: signal the caller with None
                if isinstance(resp["body"], ConnectionResetError):
                    return None

                mocked_responses.add(**resp)
                relayed_response = self.session.send(prepared_relayed_request)

                return relayed_response

        # normal case: no injected response matches the request
        relayed_response = self.session.send(prepared_relayed_request)
        return relayed_response

    def snoop_context(
        self,
        request: "flask.Request",
        response: "requests.Response",
        time_elapsed: float,
        **kwargs: Any,
    ) -> None:
        """Record the raw pair and feed it through the resolvers into the Context."""
        request_data = request.get_json()
        response_data = response.json() or {}

        if self.relay_control:
            self.relay_control.process(request)

        # store raw data
        raw_data: RawRequestResponse = {
            "url": request.url,
            "request": request_data,
            "response": response_data,
            "time_elapsed": time_elapsed,
        }
        self.context.raw_data.append(raw_data)

        try:
            snooped_context = self.resolver.resolve(
                request_data,
                response_data,
                **kwargs,
            )
            for entry in snooped_context:
                self.context.upsert(entry)
        except Exception as e:
            # best-effort: resolution failures must not break the relayed request
            print("Failed to resolve context: ", e)
            traceback.print_exc()
            snooped_context = None

        return None

    def graphql(self) -> Mapping[str, str]:
        # view for POST /graphql: relay, optionally log, snoop, return body
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("GRAPHQL REQUEST:")
            print(request.get_json())
            print("GRAPHQL RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        # snoop work to extract the context
        self.snoop_context(request, relayed_response, timer.elapsed)
        if self.verbose:
            print("*****************")
            print("SNOOPED CONTEXT:")
            print(self.context.entries)
            print(len(self.context.raw_data))
            print("*****************")

        return relayed_response.json()

    def file_stream(self, path) -> Mapping[str, str]:
        # view for POST /files/<path>: relay, handle injected connection reset,
        # snoop (passing `path` so resolvers can extract the run id)
        request = flask.request

        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            connection = request.environ["werkzeug.socket"]  # Get the socket object
            connection.shutdown(socket.SHUT_RDWR)
            connection.close()
            # NOTE(review): execution continues with relayed_response == None,
            # so the snoop/json calls below would fail; presumably the closed
            # socket aborts the request first — verify.

        if self.verbose:
            print("*****************")
            print("FILE STREAM REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("FILE STREAM RESPONSE:")
            print(relayed_response)
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def storage(self) -> Mapping[str, str]:
        # view for PUT/GET /storage
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE REQUEST:")
            print(request.get_json())
            print("STORAGE RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed)

        return relayed_response.json()

    def storage_file(self, path) -> Mapping[str, str]:
        # view for PUT/GET /storage/<path>
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE FILE REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("STORAGE FILE RESPONSE:")
            print(relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def control(self) -> Mapping[str, str]:
        # view for POST /_control: delegate to the registered controller
        assert self.relay_control
        return self.relay_control.control(flask.request)