wandb 0.17.8rc1__py3-none-win_amd64.whl → 0.17.9__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package_readme.md +47 -53
  2. wandb/__init__.py +12 -6
  3. wandb/__init__.pyi +112 -2
  4. wandb/bin/wandb-core +0 -0
  5. wandb/data_types.py +1 -0
  6. wandb/env.py +13 -0
  7. wandb/integration/keras/__init__.py +2 -5
  8. wandb/integration/keras/callbacks/metrics_logger.py +10 -4
  9. wandb/integration/keras/callbacks/model_checkpoint.py +0 -5
  10. wandb/integration/keras/keras.py +12 -1
  11. wandb/integration/openai/fine_tuning.py +5 -5
  12. wandb/integration/tensorboard/log.py +1 -1
  13. wandb/proto/v3/wandb_internal_pb2.py +31 -21
  14. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  15. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  16. wandb/proto/v4/wandb_internal_pb2.py +23 -21
  17. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  18. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  19. wandb/proto/v5/wandb_internal_pb2.py +23 -21
  20. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  21. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  22. wandb/proto/wandb_deprecated.py +4 -0
  23. wandb/sdk/__init__.py +1 -1
  24. wandb/sdk/artifacts/artifact.py +9 -11
  25. wandb/sdk/artifacts/artifact_manifest_entry.py +10 -2
  26. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +31 -0
  27. wandb/sdk/internal/system/assets/trainium.py +2 -1
  28. wandb/sdk/internal/tb_watcher.py +1 -1
  29. wandb/sdk/lib/_settings_toposort_generated.py +5 -3
  30. wandb/sdk/service/service.py +7 -2
  31. wandb/sdk/wandb_init.py +5 -1
  32. wandb/sdk/wandb_manager.py +0 -3
  33. wandb/sdk/wandb_require.py +22 -1
  34. wandb/sdk/wandb_run.py +14 -4
  35. wandb/sdk/wandb_settings.py +32 -10
  36. wandb/sdk/wandb_setup.py +3 -0
  37. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/METADATA +48 -54
  38. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/RECORD +42 -43
  39. wandb/testing/relay.py +0 -874
  40. /wandb/{viz.py → sdk/lib/viz.py} +0 -0
  41. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/WHEEL +0 -0
  42. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/entry_points.txt +0 -0
  43. {wandb-0.17.8rc1.dist-info → wandb-0.17.9.dist-info}/licenses/LICENSE +0 -0
wandb/testing/relay.py DELETED
@@ -1,874 +0,0 @@
1
- import dataclasses
2
- import json
3
- import logging
4
- import socket
5
- import sys
6
- import threading
7
- import traceback
8
- import urllib.parse
9
- from collections import defaultdict, deque
10
- from copy import deepcopy
11
- from typing import (
12
- TYPE_CHECKING,
13
- Any,
14
- Callable,
15
- Dict,
16
- Iterable,
17
- List,
18
- Mapping,
19
- Optional,
20
- Union,
21
- )
22
-
23
- import flask
24
- import pandas as pd
25
- import requests
26
- import responses
27
-
28
- import wandb
29
- import wandb.util
30
- from wandb.sdk.lib.timer import Timer
31
-
32
- try:
33
- from typing import Literal, TypedDict
34
- except ImportError:
35
- from typing_extensions import Literal, TypedDict
36
-
37
- if sys.version_info >= (3, 8):
38
- from typing import Protocol
39
- else:
40
- from typing_extensions import Protocol
41
-
42
if TYPE_CHECKING:
    # Everything in this block exists for static type checkers only; none of
    # these names are defined at runtime.
    from typing import Deque

    class RawRequestResponse(TypedDict):
        """Shape of one item appended to ``Context.raw_data``."""

        url: str
        request: Optional[Any]
        response: Dict[str, Any]
        time_elapsed: float  # seconds

    # Must list every "name" registered in ``QueryResolver.__init__``.
    ResolverName = Literal[
        "upsert_bucket",
        "upload_files",
        "uploaded_files",
        "uploaded_files_legacy",
        "preempting",
        "upsert_sweep",
        "create_artifact",
        "delete_run",
    ]

    class Resolver(TypedDict):
        """One entry of ``QueryResolver.resolvers``."""

        name: ResolverName
        # Resolvers are called as ``fn(request_data, response_data, **kwargs)``.
        resolver: Callable[..., Optional[Dict[str, Any]]]
62
-
63
-
64
class DeliberateHTTPError(Exception):
    """Exception raised on purpose to make the relay answer with an HTTP error.

    Args:
        message: Body of the error response.
        status_code: HTTP status of the error response (500 by default).
    """

    def __init__(self, message, status_code: int = 500):
        # Hand the message to the base class so ``str(exc)`` and tracebacks
        # show it instead of an empty string.
        super().__init__(message)
        self.message = message
        self.status_code = status_code

    def get_response(self):
        """Return a ``flask.Response`` carrying the message and status code."""
        return flask.Response(self.message, status=self.status_code)

    def __repr__(self):
        return f"DeliberateHTTPError({self.message!r}, {self.status_code!r})"
75
-
76
-
77
@dataclasses.dataclass
class RunAttrs:
    """Plain record of the attributes captured for a single run."""

    # Always present.
    name: str
    display_name: str
    description: str
    sweep_name: str
    project: Dict[str, Any]
    config: Dict[str, Any]

    # Git information; left as None when it was not captured.
    remote: Optional[str] = None
    commit: Optional[str] = None
89
-
90
-
91
class Context:
    """A container used to store the snooped state/data of a test.

    Includes raw requests and responses, parsed and processed data, and a number of
    convenience methods and properties for accessing the data.
    """

    def __init__(self) -> None:
        # parsed/merged data. keys are the individual wandb run id's.
        self._entries = defaultdict(dict)
        # container for raw requests and responses:
        self.raw_data: List[RawRequestResponse] = []
        # concatenated file contents for all runs (built on first access, then cached):
        self._history: Optional[pd.DataFrame] = None
        self._events: Optional[pd.DataFrame] = None
        self._summary: Optional[pd.DataFrame] = None
        self._config: Optional[Dict[str, Any]] = None
        self._output: Optional[pd.DataFrame] = None

    def upsert(self, entry: Dict[str, Any]) -> None:
        """Merge ``entry`` into the stored entry keyed by its "name" (or "id")."""
        try:
            entry_id: str = entry["name"]
        except KeyError:
            entry_id = entry["id"]
        self._entries[entry_id] = wandb.util.merge_dicts(entry, self._entries[entry_id])

    # mapping interface
    def __getitem__(self, key: str) -> Any:
        return self._entries[key]

    def keys(self) -> Iterable[str]:
        return self._entries.keys()

    def get_file_contents(self, file_name: str) -> pd.DataFrame:
        """Return the recorded contents of ``file_name`` for all stored entries.

        Each entry's chunks are ordered by "offset" and flattened into records;
        a ``__run_id`` column holds the entry id. When nothing has been stored
        yet, an empty frame that still has the ``__run_id`` column is returned.
        """
        frames = []

        for entry_id, entry in self._entries.items():
            # - extract the content from `file_name`
            # - sort by offset (will be useful when relay server goes async);
            #   `sorted` leaves the stored chunk list untouched
            # - flatten the chunks into one list of records
            chunks = sorted(
                entry.get("files", {}).get(file_name, []),
                key=lambda chunk: chunk["offset"],
            )
            records = [record for chunk in chunks for record in chunk["content"]]
            frame = pd.DataFrame.from_records(records)
            frame["__run_id"] = entry_id
            frames.append(frame)

        if not frames:
            # pd.concat() raises ValueError on an empty list.
            return pd.DataFrame(columns=["__run_id"])
        return pd.concat(frames)

    # attributes to use in assertions
    @property
    def entries(self) -> Dict[str, Any]:
        return deepcopy(self._entries)

    @property
    def history(self) -> pd.DataFrame:
        # todo: caveat: this assumes that all assertions happen at the end of a test
        if self._history is None:
            self._history = self.get_file_contents("wandb-history.jsonl")
        return deepcopy(self._history)

    @property
    def events(self) -> pd.DataFrame:
        if self._events is None:
            self._events = self.get_file_contents("wandb-events.jsonl")
        return deepcopy(self._events)

    @property
    def summary(self) -> pd.DataFrame:
        if self._summary is None:
            _summary = self.get_file_contents("wandb-summary.json")

            # run summary may be updated multiple times,
            # but we are only interested in the last one.
            # we can have multiple runs saved to context,
            # so we need to group by run id and take the
            # last one for each run.
            self._summary = (
                _summary.groupby("__run_id").last().reset_index(level=["__run_id"])
            )

        return deepcopy(self._summary)

    @property
    def output(self) -> pd.DataFrame:
        if self._output is None:
            self._output = self.get_file_contents("output.log")
        return deepcopy(self._output)

    @property
    def config(self) -> Dict[str, Any]:
        if self._config is None:
            self._config = {k: v["config"] for (k, v) in self._entries.items()}
        return deepcopy(self._config)

    # @property
    # def telemetry(self) -> pd.DataFrame:
    #     telemetry = pd.DataFrame.from_records(
    #         [
    #             {
    #                 "__run_id": run_id,
    #                 "telemetry": config.get("_wandb", {}).get("value", {}).get("t")
    #             }
    #             for (run_id, config) in self.config.items()
    #         ]
    #     )
    #     return telemetry

    # convenience data access methods
    def get_run_telemetry(self, run_id: str) -> Dict[str, Any]:
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("t")

    def get_run_metrics(self, run_id: str) -> Dict[str, Any]:
        return self.config.get(run_id, {}).get("_wandb", {}).get("value", {}).get("m")

    def get_run_summary(
        self, run_id: str, include_private: bool = False
    ) -> Dict[str, Any]:
        """Return the last summary of ``run_id`` as a dict ({} when unknown).

        Columns whose name starts with "_" are dropped unless
        ``include_private`` is true.
        """
        # Bind once: every access to the property deep-copies the frame.
        summary = self.summary
        run_summary = summary[summary["__run_id"] == run_id]
        if not include_private:
            run_summary = run_summary.filter(regex="^[^_]", axis=1)
        # run summary dataframe must have only one row for the given run id,
        # so we convert it to records and extract the first (and only) row.
        records = run_summary.to_dict(orient="records")
        return records[0] if records else {}

    def get_run_history(
        self, run_id: str, include_private: bool = False
    ) -> pd.DataFrame:
        """Return the history rows of ``run_id``.

        Columns whose name starts with "_" are dropped unless
        ``include_private`` is true.
        """
        history = self.history
        run_history = history[history["__run_id"] == run_id]
        if include_private:
            return run_history
        return run_history.filter(regex="^[^_]", axis=1)

    def get_run_uploaded_files(self, run_id: str) -> List[str]:
        """Return a copy of the "uploaded" list recorded for ``run_id``."""
        # Copy only the list that is needed instead of deep-copying every entry.
        return list(self._entries.get(run_id, {}).get("uploaded", []))

    def get_run_stats(self, run_id: str) -> pd.DataFrame:
        """Return the events rows of ``run_id``."""
        events = self.events
        return events[events["__run_id"] == run_id]

    def get_run_attrs(self, run_id: str) -> "Optional[RunAttrs]":
        """Return the stored attributes of ``run_id``, or None when unknown."""
        run_entry = self._entries.get(run_id)
        if not run_entry:
            return None

        return RunAttrs(
            name=run_entry["name"],
            display_name=run_entry["displayName"],
            description=run_entry["description"],
            sweep_name=run_entry["sweepName"],
            project=run_entry["project"],
            config=run_entry["config"],
            remote=run_entry.get("repo"),
            commit=run_entry.get("commit"),
        )

    def get_run(self, run_id: str) -> Dict[str, Any]:
        return self._entries.get(run_id, {})

    # todo: add getter (by run_id) utilities for other properties
272
-
273
-
274
class QueryResolver:
    """Resolve request/response pairs against a set of known patterns.

    This extracts and processes useful data to be later stored in a Context object.
    Each resolver returns a dict when it recognizes the pair and None otherwise.
    """

    def __init__(self):
        self.resolvers: List[Resolver] = [
            {"name": "upsert_bucket", "resolver": self.resolve_upsert_bucket},
            {"name": "upload_files", "resolver": self.resolve_upload_files},
            {"name": "uploaded_files", "resolver": self.resolve_uploaded_files},
            {
                "name": "uploaded_files_legacy",
                "resolver": self.resolve_uploaded_files_legacy,
            },
            {"name": "preempting", "resolver": self.resolve_preempting},
            {"name": "upsert_sweep", "resolver": self.resolve_upsert_sweep},
            {"name": "create_artifact", "resolver": self.resolve_create_artifact},
            {"name": "delete_run", "resolver": self.resolve_delete_run},
        ]

    @staticmethod
    def _response_payload(response_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the "data" member of a response, or {} when missing or null."""
        return response_data.get("data") or {}

    @staticmethod
    def resolve_upsert_bucket(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Merge the non-null request variables with the returned bucket."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        upserted = QueryResolver._response_payload(response_data).get("upsertBucket")
        if upserted is None:
            return None
        data = {k: v for (k, v) in request_data["variables"].items() if v is not None}
        data.update(upserted.get("bucket") or {})
        if "config" in data:
            # the config arrives as a JSON string
            data["config"] = json.loads(data["config"])
        return data

    @staticmethod
    def resolve_delete_run(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Merge the non-null request variables with the "deleteRun" result."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "deleteRun" not in request_data.get("query", ""):
            return None
        data = {k: v for (k, v) in request_data["variables"].items() if v is not None}
        data.update(
            QueryResolver._response_payload(response_data).get("deleteRun") or {}
        )
        return data

    @staticmethod
    def resolve_upload_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the file chunks posted in a request that has a "files" member.

        The entry name is the third "/"-separated component of the ``path``
        keyword argument.
        """
        if not isinstance(request_data, dict):
            return None
        if request_data.get("files") is None:
            return None

        name = kwargs.get("path").split("/")[2]
        files = defaultdict(list)
        for file_name, file_value in request_data["files"].items():
            content = []
            for line in file_value.get("content", []):
                try:
                    content.append(json.loads(line))
                except json.decoder.JSONDecodeError:
                    # not JSON (e.g. plain text): keep the raw line
                    content.append([line])

            files[file_name].append(
                {"offset": file_value.get("offset"), "content": content}
            )

        return {
            "name": name,
            "dropped": [request_data["dropped"]] if "dropped" in request_data else [],
            "files": files,
        }

    @staticmethod
    def resolve_uploaded_files(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the file names returned by a "CreateRunFiles" query."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "CreateRunFiles" not in request_data.get("query", ""):
            return None

        run_name = request_data["variables"]["run"]
        payload = QueryResolver._response_payload(response_data)
        files = (payload.get("createRunFiles") or {}).get("files", {})
        return {
            "name": run_name,
            "uploaded": [file["name"] for file in files] if files else [""],
        }

    @staticmethod
    def resolve_uploaded_files_legacy(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the first file name returned by a "RunUploadUrls" query."""
        # This is a legacy resolver for uploaded files
        # No longer used by tests but leaving it here in case we need it in the future
        # Please refer to upload_urls() in internal_api.py for more details
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "RunUploadUrls" not in request_data.get("query", ""):
            return None

        name = request_data["variables"]["run"]
        # walk data.model.bucket.files, tolerating missing or null levels
        node = QueryResolver._response_payload(response_data)
        for key in ("model", "bucket", "files"):
            node = node.get(key) or {}
        edges = node.get("edges") or []
        # note: we count all attempts to upload files
        return {
            "name": name,
            "uploaded": [edges[0].get("node", {}).get("name")] if edges else [""],
        }

    @staticmethod
    def resolve_preempting(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Extract the "preempting" member of a request.

        The entry name is the third "/"-separated component of the ``path``
        keyword argument.
        """
        if not isinstance(request_data, dict):
            return None
        if "preempting" not in request_data:
            return None
        name = kwargs.get("path").split("/")[2]
        return {
            "name": name,
            "preempting": [request_data["preempting"]],
        }

    @staticmethod
    def resolve_upsert_sweep(
        request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Return the sweep of an "upsertSweep" response."""
        if not isinstance(response_data, dict):
            return None
        upserted = QueryResolver._response_payload(response_data).get("upsertSweep")
        if upserted is None:
            return None
        return upserted.get("sweep")

    def resolve_create_artifact(
        self, request_data: Dict[str, Any], response_data: Dict[str, Any], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        """Record the variables and returned artifact of a createArtifact call."""
        if not isinstance(request_data, dict) or not isinstance(response_data, dict):
            return None
        if "createArtifact(" not in request_data.get("query", ""):
            return None
        if request_data.get("variables") is None:
            return None
        created = self._response_payload(response_data).get("createArtifact")
        if created is None:
            return None

        return {
            "name": request_data["variables"]["runName"],
            "create_artifact": [
                {
                    "variables": request_data["variables"],
                    "response": created.get("artifact"),
                }
            ],
        }

    def resolve(
        self,
        request_data: Dict[str, Any],
        response_data: Dict[str, Any],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """Run every resolver and return the non-None results, in resolver order."""
        results = []
        for resolver in self.resolvers:
            result = resolver["resolver"](request_data, response_data, **kwargs)
            if result is not None:
                results.append(result)
        return results
495
-
496
-
497
class TokenizedCircularPattern:
    """Rotating sequence of "0"/"1"/"2" tokens; the front token is the current one.

    ``next()`` rotates the sequence by one position unless the current token
    is the STOP token, in which case the sequence stays where it is.
    """

    APPLY_TOKEN = "1"
    PASS_TOKEN = "0"
    STOP_TOKEN = "2"

    def __init__(self, pattern: str):
        allowed = {self.APPLY_TOKEN, self.PASS_TOKEN, self.STOP_TOKEN}

        if not pattern:
            raise ValueError("Pattern cannot be empty")
        if not allowed.issuperset(pattern):
            raise ValueError(f"Pattern can only contain {allowed}")

        self.pattern: Deque[str] = deque(pattern)

    def next(self):
        """Advance to the following token; a no-op once STOP is current."""
        if self.pattern[0] != self.STOP_TOKEN:
            self.pattern.rotate(-1)

    def should_apply(self) -> bool:
        """Tell whether the current token is the APPLY token."""
        return self.pattern[0] == self.APPLY_TOKEN
518
-
519
-
520
@dataclasses.dataclass
class InjectedResponse:
    """Description of a canned response to substitute for a matching request.

    ``method`` and ``url`` select the requests to match, ``custom_match_fn``
    (if given) refines the match, and ``application_pattern`` decides for which
    of the matching requests the response is actually used.
    """

    method: str
    url: str
    body: Union[str, Exception]
    status: int = 200
    content_type: str = "text/plain"
    # todo: add more fields for other types of responses?
    custom_match_fn: Optional[Callable[..., bool]] = None
    # Built per instance: the pattern is mutable (it rotates), so a single
    # default object must not be shared between responses.
    application_pattern: "TokenizedCircularPattern" = dataclasses.field(
        default_factory=lambda: TokenizedCircularPattern("1")
    )

    # application_pattern defines the pattern of the response injection
    # as the requests come in.
    # 0 == do not inject the response
    # 1 == inject the response
    # 2 == stop using the response (END token)
    #
    # - when no END token is present, the pattern is repeated indefinitely
    # - when END token is present, the pattern is applied until the END token is reached
    # - to replicate the current behavior:
    #  - use application_pattern = "1" if wanting to apply the pattern to all requests
    #  - use application_pattern = "1" * COUNTER + "2" to apply the pattern to the first COUNTER requests
    #
    # Examples of application_pattern:
    # 1. application_pattern = "1012"
    #    - inject the response for the first request
    #    - do not inject the response for the second request
    #    - inject the response for the third request
    #    - stop using the response starting from the fourth request onwards
    # 2. application_pattern = "110"
    #    repeat the following pattern indefinitely:
    #    - inject the response for the first request
    #    - inject the response for the second request
    #    - do not inject the response for the third request

    def __eq__(
        self,
        other: "Union[InjectedResponse, requests.Request, requests.PreparedRequest]",
    ):
        """Check InjectedResponse object equality.

        We use this to check if this response should be injected as a replacement of
        `other`.

        :param other: another InjectedResponse or a (prepared) request
        :return: True when method and url match and ``custom_match_fn``, if
            set, also accepts ``other``; False for any other type
        """
        if not isinstance(
            other, (InjectedResponse, requests.Request, requests.PreparedRequest)
        ):
            return False

        # always check the method and url
        ret = self.method == other.method and self.url == other.url
        # use custom_match_fn to check, e.g. the request body content
        if ret and self.custom_match_fn is not None:
            ret = self.custom_match_fn(self, other)
        return ret

    def to_dict(self):
        """Return the public instance attributes, minus the matching machinery."""
        excluded_fields = {"application_pattern", "custom_match_fn"}
        return {
            k: v
            for k, v in vars(self).items()
            if not k.startswith("_") and k not in excluded_fields
        }
586
-
587
-
588
class RelayControlProtocol(Protocol):
    """Structural interface of the optional ``control`` object of RelayServer."""

    def process(self, request: "flask.Request") -> None:  # pragma: no cover
        """Called with each request whose context is snooped."""
        ...

    def control(self, request: "flask.Request") -> Mapping[str, str]:  # pragma: no cover
        """Called for the ``/_control`` endpoint; its result is the response."""
        ...
594
-
595
-
596
class RelayServer:
    """Flask app that forwards requests to ``base_url`` and records the traffic.

    Requests can be answered with an ``InjectedResponse`` instead of the
    backend's reply; request/response pairs are parsed by a ``QueryResolver``
    and accumulated in ``self.context``.

    Args:
        base_url: URL of the real backend to relay to.
        inject: Responses to inject in place of matching requests.
        control: Optional object implementing ``RelayControlProtocol``; when
            given, a ``/_control`` endpoint is registered.
        verbose: Print requests, responses and the snooped context.
    """

    def __init__(
        self,
        base_url: str,
        inject: Optional[List["InjectedResponse"]] = None,
        control: Optional["RelayControlProtocol"] = None,
        verbose: bool = False,
    ) -> None:
        # todo for the future:
        #  - consider switching from Flask to Quart
        #  - async app will allow for better failure injection/poor network perf
        self.relay_control = control
        self.app = flask.Flask(__name__)
        self.app.logger.setLevel(logging.INFO)
        self.app.register_error_handler(DeliberateHTTPError, self.handle_http_exception)
        self.app.add_url_rule(
            rule="/graphql",
            endpoint="graphql",
            view_func=self.graphql,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/files/<path:path>",
            endpoint="files",
            view_func=self.file_stream,
            methods=["POST"],
        )
        self.app.add_url_rule(
            rule="/storage",
            endpoint="storage",
            view_func=self.storage,
            methods=["PUT", "GET"],
        )
        self.app.add_url_rule(
            rule="/storage/<path:path>",
            endpoint="storage_file",
            view_func=self.storage_file,
            methods=["PUT", "GET"],
        )
        if control:
            self.app.add_url_rule(
                rule="/_control",
                endpoint="_control",
                view_func=self.control,
                methods=["POST"],
            )
        # @app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])
        self.port = self._get_free_port()
        self.base_url = urllib.parse.urlparse(base_url)
        self.session = requests.Session()
        self.relay_url = f"http://127.0.0.1:{self.port}"

        # todo: add an option to add custom resolvers
        self.resolver = QueryResolver()
        # recursively merge-able object to store state
        self.context = Context()

        # injected responses
        self.inject = inject or []

        # useful when debugging:
        # self.after_request_fn = self.app.after_request(self.after_request_fn)
        self.verbose = verbose

    @staticmethod
    def handle_http_exception(e):
        """Turn a DeliberateHTTPError into the response it describes."""
        return e.get_response()

    @staticmethod
    def _get_free_port() -> int:
        """Return a port number the OS reported as free."""
        # Bind to port 0 so the OS picks a port, then close the probe socket
        # explicitly instead of leaving it to garbage collection.
        with socket.socket() as sock:
            sock.bind(("", 0))
            _, port = sock.getsockname()
        return port

    def start(self) -> None:
        """Run the Flask app on ``self.port`` in a daemon thread."""
        relay_server_thread = threading.Thread(
            target=self.app.run,
            kwargs={"port": self.port},
            daemon=True,
        )
        relay_server_thread.start()

    def after_request_fn(self, response: "requests.Response") -> "requests.Response":
        # todo: this is useful for debugging, but should be removed in the future
        # flask.request.url = self.relay_url + flask.request.url
        print(flask.request)
        print(flask.request.get_json())
        print(response)
        print(response.json())
        return response

    def relay(
        self,
        request: "flask.Request",
    ) -> Union["responses.Response", "requests.Response", None]:
        """Send ``request`` to the backend, or answer it with an injected response.

        Returns:
            The backend's (or mocked) response, or None when the matching
            injected response has a ``ConnectionResetError`` body.
        """
        # replace the relay url with the real backend url (self.base_url)
        url = (
            urllib.parse.urlparse(request.url)
            ._replace(netloc=self.base_url.netloc, scheme=self.base_url.scheme)
            .geturl()
        )
        headers = {key: value for (key, value) in request.headers if key != "Host"}
        prepared_relayed_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=request.get_data(),
            json=request.get_json(),
        ).prepare()

        if self.verbose:
            print("*****************")
            print("RELAY REQUEST:")
            print(prepared_relayed_request.url)
            print(prepared_relayed_request.method)
            print(prepared_relayed_request.headers)
            print(prepared_relayed_request.body)
            print("*****************")

        for injected_response in self.inject:
            # check if an injected response matches the request
            if injected_response != prepared_relayed_request:
                continue

            pattern = injected_response.application_pattern
            # where are we in the application pattern?
            should_apply = pattern.should_apply()
            # Rotate on every matching request, not only on injected ones;
            # otherwise a "0" (pass) token would never be consumed and the
            # pattern would get stuck on it.
            pattern.next()
            if not should_apply:
                continue

            if self.verbose:
                print("*****************")
                print("INJECTING RESPONSE:")
                print(injected_response.to_dict())
                print("*****************")

            # TODO: allow access to the request object when making the mocked response
            with responses.RequestsMock() as mocked_responses:
                # do the actual injection
                resp = injected_response.to_dict()

                if isinstance(resp["body"], ConnectionResetError):
                    # the caller simulates the reset
                    return None

                mocked_responses.add(**resp)
                relayed_response = self.session.send(prepared_relayed_request)

                return relayed_response

        # normal case: no injected response matches the request
        return self.session.send(prepared_relayed_request)

    def snoop_context(
        self,
        request: "flask.Request",
        response: "requests.Response",
        time_elapsed: float,
        **kwargs: Any,
    ) -> None:
        """Store the raw request/response pair and the data resolved from it.

        Resolver failures are printed and swallowed; the raw pair is kept in
        either case.
        """
        request_data = request.get_json()
        response_data = response.json() or {}

        if self.relay_control:
            self.relay_control.process(request)

        # store raw data
        raw_data: RawRequestResponse = {
            "url": request.url,
            "request": request_data,
            "response": response_data,
            "time_elapsed": time_elapsed,
        }
        self.context.raw_data.append(raw_data)

        try:
            snooped_context = self.resolver.resolve(
                request_data,
                response_data,
                **kwargs,
            )
            for entry in snooped_context:
                self.context.upsert(entry)
        except Exception as e:
            # best effort: a resolver failure must not break the relayed request
            print("Failed to resolve context: ", e)
            traceback.print_exc()

        return None

    def graphql(self) -> Mapping[str, str]:
        """View for POST /graphql: relay, snoop, return the response JSON."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("GRAPHQL REQUEST:")
            print(request.get_json())
            print("GRAPHQL RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        # snoop work to extract the context
        self.snoop_context(request, relayed_response, timer.elapsed)
        if self.verbose:
            print("*****************")
            print("SNOOPED CONTEXT:")
            print(self.context.entries)
            print(len(self.context.raw_data))
            print("*****************")

        return relayed_response.json()

    def file_stream(self, path) -> Mapping[str, str]:
        """View for POST /files/<path>: relay, snoop, return the response JSON."""
        request = flask.request

        with Timer() as timer:
            relayed_response = self.relay(request)

        # simulate connection reset by peer
        if relayed_response is None:
            connection = request.environ["werkzeug.socket"]  # Get the socket object
            connection.shutdown(socket.SHUT_RDWR)
            connection.close()
            # There is no response to print or snoop; stop here rather than
            # dereferencing None below.
            return {}

        if self.verbose:
            print("*****************")
            print("FILE STREAM REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("FILE STREAM RESPONSE:")
            print(relayed_response)
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")
        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def storage(self) -> Mapping[str, str]:
        """View for PUT/GET /storage: relay, snoop, return the response JSON."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE REQUEST:")
            print(request.get_json())
            print("STORAGE RESPONSE:")
            print(relayed_response.status_code, relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed)

        return relayed_response.json()

    def storage_file(self, path) -> Mapping[str, str]:
        """View for PUT/GET /storage/<path>: relay, snoop, return the response JSON."""
        request = flask.request
        with Timer() as timer:
            relayed_response = self.relay(request)
        if self.verbose:
            print("*****************")
            print("STORAGE FILE REQUEST:")
            print("********PATH*********")
            print(path)
            print("********ENDPATH*********")
            print(request.get_json())
            print("STORAGE FILE RESPONSE:")
            print(relayed_response.json())
            print("*****************")

        self.snoop_context(request, relayed_response, timer.elapsed, path=path)

        return relayed_response.json()

    def control(self) -> Mapping[str, str]:
        """View for POST /_control: delegate to the control object."""
        assert self.relay_control
        return self.relay_control.control(flask.request)