wandb 0.17.0rc2__py3-none-any.whl → 0.17.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (163)
  1. wandb/__init__.py +4 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/cli/cli.py +151 -59
  19. wandb/docker/__init__.py +1 -1
  20. wandb/errors/term.py +10 -2
  21. wandb/filesync/step_checksum.py +1 -4
  22. wandb/filesync/step_prepare.py +4 -24
  23. wandb/filesync/step_upload.py +5 -107
  24. wandb/filesync/upload_job.py +0 -76
  25. wandb/integration/gym/__init__.py +35 -15
  26. wandb/integration/openai/fine_tuning.py +21 -3
  27. wandb/integration/prodigy/prodigy.py +1 -1
  28. wandb/jupyter.py +16 -17
  29. wandb/old/summary.py +5 -0
  30. wandb/plot/pr_curve.py +2 -1
  31. wandb/plot/roc_curve.py +2 -1
  32. wandb/{plots → plot}/utils.py +13 -25
  33. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  34. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  35. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  36. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  37. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  38. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  39. wandb/proto/v5/wandb_base_pb2.py +30 -0
  40. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  41. wandb/proto/v5/wandb_server_pb2.py +63 -0
  42. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  43. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  44. wandb/proto/wandb_base_pb2.py +2 -0
  45. wandb/proto/wandb_deprecated.py +9 -1
  46. wandb/proto/wandb_generate_deprecated.py +34 -0
  47. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  48. wandb/proto/wandb_internal_pb2.py +2 -0
  49. wandb/proto/wandb_server_pb2.py +2 -0
  50. wandb/proto/wandb_settings_pb2.py +2 -0
  51. wandb/proto/wandb_telemetry_pb2.py +2 -0
  52. wandb/sdk/artifacts/artifact.py +76 -23
  53. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  54. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  55. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  56. wandb/sdk/artifacts/artifact_saver.py +1 -10
  57. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  58. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  59. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  60. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  61. wandb/sdk/artifacts/storage_policy.py +1 -12
  62. wandb/sdk/data_types/_dtypes.py +5 -2
  63. wandb/sdk/data_types/html.py +1 -1
  64. wandb/sdk/data_types/image.py +1 -1
  65. wandb/sdk/data_types/object_3d.py +1 -1
  66. wandb/sdk/data_types/video.py +4 -2
  67. wandb/sdk/interface/interface.py +13 -0
  68. wandb/sdk/interface/interface_shared.py +1 -1
  69. wandb/sdk/internal/file_pusher.py +2 -5
  70. wandb/sdk/internal/file_stream.py +6 -19
  71. wandb/sdk/internal/internal_api.py +160 -138
  72. wandb/sdk/internal/job_builder.py +207 -135
  73. wandb/sdk/internal/progress.py +0 -28
  74. wandb/sdk/internal/sender.py +105 -42
  75. wandb/sdk/internal/settings_static.py +8 -1
  76. wandb/sdk/internal/system/assets/gpu.py +2 -0
  77. wandb/sdk/internal/system/assets/trainium.py +3 -3
  78. wandb/sdk/internal/system/system_info.py +4 -2
  79. wandb/sdk/internal/update.py +1 -1
  80. wandb/sdk/launch/__init__.py +9 -1
  81. wandb/sdk/launch/_launch.py +4 -24
  82. wandb/sdk/launch/_launch_add.py +1 -3
  83. wandb/sdk/launch/_project_spec.py +184 -224
  84. wandb/sdk/launch/agent/agent.py +58 -18
  85. wandb/sdk/launch/agent/config.py +0 -3
  86. wandb/sdk/launch/builder/abstract.py +67 -0
  87. wandb/sdk/launch/builder/build.py +165 -576
  88. wandb/sdk/launch/builder/context_manager.py +235 -0
  89. wandb/sdk/launch/builder/docker_builder.py +7 -23
  90. wandb/sdk/launch/builder/kaniko_builder.py +10 -23
  91. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  92. wandb/sdk/launch/create_job.py +51 -45
  93. wandb/sdk/launch/environment/aws_environment.py +26 -1
  94. wandb/sdk/launch/inputs/files.py +148 -0
  95. wandb/sdk/launch/inputs/internal.py +224 -0
  96. wandb/sdk/launch/inputs/manage.py +95 -0
  97. wandb/sdk/launch/runner/abstract.py +2 -2
  98. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  99. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  100. wandb/sdk/launch/runner/local_container.py +2 -3
  101. wandb/sdk/launch/runner/local_process.py +8 -29
  102. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  103. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  104. wandb/sdk/launch/sweeps/scheduler.py +2 -0
  105. wandb/sdk/launch/sweeps/utils.py +2 -2
  106. wandb/sdk/launch/utils.py +16 -138
  107. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  108. wandb/sdk/lib/apikey.py +4 -2
  109. wandb/sdk/lib/config_util.py +3 -3
  110. wandb/sdk/lib/proto_util.py +22 -1
  111. wandb/sdk/lib/redirect.py +1 -1
  112. wandb/sdk/service/service.py +2 -1
  113. wandb/sdk/service/streams.py +5 -5
  114. wandb/sdk/wandb_init.py +25 -59
  115. wandb/sdk/wandb_login.py +28 -25
  116. wandb/sdk/wandb_run.py +135 -70
  117. wandb/sdk/wandb_settings.py +33 -64
  118. wandb/sdk/wandb_watch.py +1 -1
  119. wandb/sklearn/plot/classifier.py +4 -6
  120. wandb/sync/sync.py +2 -2
  121. wandb/testing/relay.py +32 -17
  122. wandb/util.py +39 -37
  123. wandb/wandb_agent.py +3 -3
  124. wandb/wandb_controller.py +3 -2
  125. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/METADATA +7 -9
  126. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/RECORD +129 -151
  127. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/WHEEL +1 -1
  128. wandb/apis/reports/v1/_blocks.py +0 -1406
  129. wandb/apis/reports/v1/_helpers.py +0 -70
  130. wandb/apis/reports/v1/_panels.py +0 -1282
  131. wandb/apis/reports/v1/_templates.py +0 -478
  132. wandb/apis/reports/v1/blocks.py +0 -27
  133. wandb/apis/reports/v1/helpers.py +0 -2
  134. wandb/apis/reports/v1/mutations.py +0 -66
  135. wandb/apis/reports/v1/panels.py +0 -17
  136. wandb/apis/reports/v1/report.py +0 -268
  137. wandb/apis/reports/v1/runset.py +0 -144
  138. wandb/apis/reports/v1/templates.py +0 -7
  139. wandb/apis/reports/v1/util.py +0 -406
  140. wandb/apis/reports/v1/validators.py +0 -131
  141. wandb/apis/reports/v2/blocks.py +0 -25
  142. wandb/apis/reports/v2/expr_parsing.py +0 -257
  143. wandb/apis/reports/v2/gql.py +0 -68
  144. wandb/apis/reports/v2/interface.py +0 -1911
  145. wandb/apis/reports/v2/internal.py +0 -867
  146. wandb/apis/reports/v2/metrics.py +0 -6
  147. wandb/apis/reports/v2/panels.py +0 -15
  148. wandb/catboost/__init__.py +0 -9
  149. wandb/fastai/__init__.py +0 -9
  150. wandb/keras/__init__.py +0 -19
  151. wandb/lightgbm/__init__.py +0 -9
  152. wandb/plots/__init__.py +0 -6
  153. wandb/plots/explain_text.py +0 -36
  154. wandb/plots/heatmap.py +0 -81
  155. wandb/plots/named_entity.py +0 -43
  156. wandb/plots/part_of_speech.py +0 -50
  157. wandb/plots/plot_definitions.py +0 -768
  158. wandb/plots/precision_recall.py +0 -121
  159. wandb/plots/roc.py +0 -103
  160. wandb/sacred/__init__.py +0 -3
  161. wandb/xgboost/__init__.py +0 -9
  162. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/entry_points.txt +0 -0
  163. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/licenses/LICENSE +0 -0
@@ -7,7 +7,7 @@ from typing import Dict, Optional
7
7
  from wandb.sdk.launch.errors import LaunchError
8
8
  from wandb.util import get_module
9
9
 
10
- from ..utils import S3_URI_RE, event_loop_thread_exec
10
+ from ..utils import ARN_PARTITION_RE, S3_URI_RE, event_loop_thread_exec
11
11
  from .abstract import AbstractEnvironment
12
12
 
13
13
  boto3 = get_module(
@@ -49,6 +49,7 @@ class AwsEnvironment(AbstractEnvironment):
49
49
  self._secret_key = secret_key
50
50
  self._session_token = session_token
51
51
  self._account = None
52
+ self._partition = None
52
53
 
53
54
  @classmethod
54
55
  def from_default(cls, region: Optional[str] = None) -> "AwsEnvironment":
@@ -122,6 +123,30 @@ class AwsEnvironment(AbstractEnvironment):
122
123
  def region(self, region: str) -> None:
123
124
  self._region = region
124
125
 
126
+ async def get_partition(self) -> str:
127
+ """Return the partition parsed from the ARN of the current AWS caller identity."""
128
+ try:
129
+ session = await self.get_session()
130
+ client = await event_loop_thread_exec(session.client)("sts")
131
+ get_caller_identity = event_loop_thread_exec(client.get_caller_identity)
132
+ identity = await get_caller_identity()
133
+ arn = identity.get("Arn")
134
+ if not arn:
135
+ raise LaunchError(
136
+ "Could not set partition for AWS environment. ARN not found."
137
+ )
138
+ matched_partition = ARN_PARTITION_RE.match(arn)
139
+ if not matched_partition:
140
+ raise LaunchError(
141
+ f"Could not set partition for AWS environment. ARN {arn} is not valid."
142
+ )
143
+ partition = matched_partition.group(1)
144
+ return partition
145
+ except botocore.exceptions.ClientError as e:
146
+ raise LaunchError(
147
+ f"Could not set partition for AWS environment. {e}"
148
+ ) from e
149
+
125
150
  async def verify(self) -> None:
126
151
  """Verify that the AWS environment is configured correctly.
127
152
 
@@ -0,0 +1,148 @@
1
+ import json
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+ import yaml
6
+
7
+ from ..errors import LaunchError
8
+
9
+ FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"
10
+
11
+
12
+ class FileOverrides:
13
+ """Singleton that reads file overrides JSON from environment variables."""
14
+
15
+ _instance = None
16
+
17
+ def __new__(cls):
18
+ if cls._instance is None:
19
+ cls._instance = object.__new__(cls)
20
+ cls._instance.overrides = {}
21
+ cls._instance.load()
22
+ return cls._instance
23
+
24
+ def load(self) -> None:
25
+ """Load overrides from an environment variable."""
26
+ overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
27
+ if overrides is None:
28
+ if f"{FILE_OVERRIDE_ENV_VAR}_0" in os.environ:
29
+ overrides = ""
30
+ idx = 0
31
+ while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
32
+ overrides += os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"]
33
+ idx += 1
34
+ if overrides:
35
+ try:
36
+ contents = json.loads(overrides)
37
+ if not isinstance(contents, dict):
38
+ raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
39
+ self.overrides = contents
40
+ except json.JSONDecodeError:
41
+ raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
42
+
43
+
44
+ def config_path_is_valid(path: str) -> None:
45
+ """Validate a config file path.
46
+
47
+ This function checks if a given config file path is valid. A valid path
48
+ should meet the following criteria:
49
+
50
+ - The path must be expressed as a relative path without any upwards path
51
+ traversal (a path such as `../config.json` is rejected).
52
+ - The file specified by the path must exist.
53
+ - The file must have a supported extension (`.json`, `.yaml`, or `.yml`).
54
+
55
+ Args:
56
+ path (str): The path to validate.
57
+
58
+ Raises:
59
+ LaunchError: If the path is not valid.
60
+ """
61
+ if os.path.isabs(path):
62
+ raise LaunchError(
63
+ f"Invalid config path: {path}. Please provide a relative path."
64
+ )
65
+ if ".." in path:
66
+ raise LaunchError(
67
+ f"Invalid config path: {path}. Please provide a relative path "
68
+ "without any upward path traversal, e.g. `../config.json`."
69
+ )
70
+ path = os.path.normpath(path)
71
+ if not os.path.exists(path):
72
+ raise LaunchError(f"Invalid config path: {path}. File does not exist.")
73
+ if not any(path.endswith(ext) for ext in [".json", ".yaml", ".yml"]):
74
+ raise LaunchError(
75
+ f"Invalid config path: {path}. Only JSON and YAML files are supported."
76
+ )
77
+
78
+
79
+ def override_file(path: str) -> None:
80
+ """Check for file overrides in the environment and apply them if found."""
81
+ file_overrides = FileOverrides()
82
+ if path in file_overrides.overrides:
83
+ overrides = file_overrides.overrides.get(path)
84
+ if overrides is not None:
85
+ config = _read_config_file(path)
86
+ _update_dict(config, overrides)
87
+ _write_config_file(path, config)
88
+
89
+
90
+ def _write_config_file(path: str, config: Any) -> None:
91
+ """Write a config file to disk.
92
+
93
+ Args:
94
+ path (str): The path to the config file.
95
+ config (Any): The contents of the config file as a Python object.
96
+
97
+ Raises:
98
+ LaunchError: If the file extension is not supported.
99
+ """
100
+ _, ext = os.path.splitext(path)
101
+ if ext == ".json":
102
+ with open(path, "w") as f:
103
+ json.dump(config, f, indent=2)
104
+ elif ext in [".yaml", ".yml"]:
105
+ with open(path, "w") as f:
106
+ yaml.safe_dump(config, f)
107
+ else:
108
+ raise LaunchError(f"Unsupported file extension: {ext}")
109
+
110
+
111
+ def _read_config_file(path: str) -> Any:
112
+ """Read a config file from disk.
113
+
114
+ Args:
115
+ path (str): The path to the config file.
116
+
117
+ Returns:
118
+ Any: The contents of the config file as a Python object.
119
+ """
120
+ _, ext = os.path.splitext(path)
121
+ if ext == ".json":
122
+ with open(
123
+ path,
124
+ ) as f:
125
+ return json.load(f)
126
+ elif ext in [".yaml", ".yml"]:
127
+ with open(
128
+ path,
129
+ ) as f:
130
+ return yaml.safe_load(f)
131
+ else:
132
+ raise LaunchError(f"Unsupported file extension: {ext}")
133
+
134
+
135
+ def _update_dict(target: Dict, source: Dict) -> None:
136
+ """Update a dictionary with the contents of another dictionary.
137
+
138
+ Args:
139
+ target (Dict): The dictionary to update.
140
+ source (Dict): The dictionary to update from.
141
+ """
142
+ for key, value in source.items():
143
+ if isinstance(value, dict):
144
+ if key not in target:
145
+ target[key] = {}
146
+ _update_dict(target[key], value)
147
+ else:
148
+ target[key] = value
@@ -0,0 +1,224 @@
1
+ """The layer between launch sdk user code and the wandb internal process.
2
+
3
+ If there is an active run this communication is done through the wandb run's
4
+ backend interface.
5
+
6
+ If there is no active run, the messages are staged on the StagedLaunchInputs
7
+ singleton and sent when a run is created.
8
+ """
9
+
10
+ import os
11
+ import pathlib
12
+ import shutil
13
+ import tempfile
14
+ from typing import List, Optional
15
+
16
+ import wandb
17
+ import wandb.data_types
18
+ from wandb.sdk.launch.errors import LaunchError
19
+ from wandb.sdk.wandb_run import Run
20
+
21
+ from .files import config_path_is_valid, override_file
22
+
23
+ PERIOD = "."
24
+ BACKSLASH = "\\"
25
+ LAUNCH_MANAGED_CONFIGS_DIR = "_wandb_configs"
26
+
27
+
28
+ class ConfigTmpDir:
29
+ """Singleton for managing temporary directories for configuration files.
30
+
31
+ Any configuration files designated as inputs to a launch job are copied to
32
+ a temporary directory. This singleton manages the temporary directory and
33
+ provides paths to the configuration files.
34
+ """
35
+
36
+ _instance = None
37
+
38
+ def __new__(cls):
39
+ if cls._instance is None:
40
+ cls._instance = object.__new__(cls)
41
+ return cls._instance
42
+
43
+ def __init__(self):
44
+ if not hasattr(self, "_tmp_dir"):
45
+ self._tmp_dir = tempfile.mkdtemp()
46
+ self._configs_dir = os.path.join(self._tmp_dir, LAUNCH_MANAGED_CONFIGS_DIR)
47
+ os.mkdir(self._configs_dir)
48
+
49
+ @property
50
+ def tmp_dir(self):
51
+ return pathlib.Path(self._tmp_dir)
52
+
53
+ @property
54
+ def configs_dir(self):
55
+ return pathlib.Path(self._configs_dir)
56
+
57
+
58
+ class JobInputArguments:
59
+ """Arguments for the publish_job_input of Interface."""
60
+
61
+ def __init__(
62
+ self,
63
+ include: Optional[List[str]] = None,
64
+ exclude: Optional[List[str]] = None,
65
+ file_path: Optional[str] = None,
66
+ run_config: Optional[bool] = None,
67
+ ):
68
+ self.include = include
69
+ self.exclude = exclude
70
+ self.file_path = file_path
71
+ self.run_config = run_config
72
+
73
+
74
+ class StagedLaunchInputs:
75
+ _instance = None
76
+
77
+ def __new__(cls):
78
+ if cls._instance is None:
79
+ cls._instance = object.__new__(cls)
80
+ return cls._instance
81
+
82
+ def __init__(self) -> None:
83
+ if not hasattr(self, "_staged_inputs"):
84
+ self._staged_inputs: List[JobInputArguments] = []
85
+
86
+ def add_staged_input(
87
+ self,
88
+ input_arguments: JobInputArguments,
89
+ ):
90
+ self._staged_inputs.append(input_arguments)
91
+
92
+ def apply(self, run: Run):
93
+ """Apply the staged inputs to the given run."""
94
+ for input in self._staged_inputs:
95
+ _publish_job_input(input, run)
96
+
97
+
98
+ def _publish_job_input(
99
+ input: JobInputArguments,
100
+ run: Run,
101
+ ) -> None:
102
+ """Publish a job input to the backend interface of the given run.
103
+
104
+ Arguments:
105
+ input (JobInputArguments): The arguments for the job input.
106
+ run (Run): The run to publish the job input to.
107
+ """
108
+ assert run._backend is not None
109
+ assert run._backend.interface is not None
110
+ assert input.run_config is not None
111
+
112
+ interface = run._backend.interface
113
+ if input.file_path:
114
+ config_dir = ConfigTmpDir()
115
+ dest = os.path.join(config_dir.configs_dir, input.file_path)
116
+ run.save(dest, base_path=config_dir.tmp_dir)
117
+ interface.publish_job_input(
118
+ include_paths=[_split_on_unesc_dot(path) for path in input.include]
119
+ if input.include
120
+ else [],
121
+ exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
122
+ if input.exclude
123
+ else [],
124
+ run_config=input.run_config,
125
+ file_path=input.file_path or "",
126
+ )
127
+
128
+
129
+ def handle_config_file_input(
130
+ path: str,
131
+ include: Optional[List[str]] = None,
132
+ exclude: Optional[List[str]] = None,
133
+ ):
134
+ """Declare an overridable configuration file for a launch job.
135
+
136
+ The configuration file is copied to a temporary directory and the path to
137
+ the copy is sent to the backend interface of the active run and used to
138
+ configure the job builder.
139
+
140
+ If there is no active run, the configuration file is staged and sent when a
141
+ run is created.
142
+ """
143
+ config_path_is_valid(path)
144
+ override_file(path)
145
+ tmp_dir = ConfigTmpDir()
146
+ dest = os.path.join(tmp_dir.configs_dir, path)
147
+ dest_dir = os.path.dirname(dest)
148
+ if not os.path.exists(dest_dir):
149
+ os.makedirs(dest_dir)
150
+ shutil.copy(
151
+ path,
152
+ dest,
153
+ )
154
+ arguments = JobInputArguments(
155
+ include=include,
156
+ exclude=exclude,
157
+ file_path=path,
158
+ run_config=False,
159
+ )
160
+ if wandb.run is not None:
161
+ _publish_job_input(arguments, wandb.run)
162
+ else:
163
+ staged_inputs = StagedLaunchInputs()
164
+ staged_inputs.add_staged_input(arguments)
165
+
166
+
167
+ def handle_run_config_input(
168
+ include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
169
+ ):
170
+ """Declare wandb.config as an overridable configuration for a launch job.
171
+
172
+ The include and exclude paths are sent to the backend interface of the
173
+ active run and used to configure the job builder.
174
+
175
+ If there is no active run, the include and exclude paths are staged and sent
176
+ when a run is created.
177
+ """
178
+ arguments = JobInputArguments(
179
+ include=include,
180
+ exclude=exclude,
181
+ run_config=True,
182
+ file_path=None,
183
+ )
184
+ if wandb.run is not None:
185
+ _publish_job_input(arguments, wandb.run)
186
+ else:
187
+ stage_inputs = StagedLaunchInputs()
188
+ stage_inputs.add_staged_input(arguments)
189
+
190
+
191
+ def _split_on_unesc_dot(path: str) -> List[str]:
192
+ r"""Split a string on unescaped dots.
193
+
194
+ Arguments:
195
+ path (str): The string to split.
196
+
197
+ Raises:
198
+ LaunchError: If the path has a trailing escape character.
199
+
200
+ Returns:
201
+ List[str]: The split string.
202
+ """
203
+ parts = []
204
+ part = ""
205
+ i = 0
206
+ while i < len(path):
207
+ if path[i] == BACKSLASH:
208
+ if i == len(path) - 1:
209
+ raise LaunchError(
210
+ f"Invalid config path {path}: trailing {BACKSLASH}.",
211
+ )
212
+ if path[i + 1] == PERIOD:
213
+ part += PERIOD
214
+ i += 2
215
+ elif path[i] == PERIOD:
216
+ parts.append(part)
217
+ part = ""
218
+ i += 1
219
+ else:
220
+ part += path[i]
221
+ i += 1
222
+ if part:
223
+ parts.append(part)
224
+ return parts
@@ -0,0 +1,95 @@
1
+ """Functions for declaring overridable configuration for launch jobs."""
2
+
3
+ from typing import List, Optional
4
+
5
+
6
+ def manage_config_file(
7
+ path: str,
8
+ include: Optional[List[str]] = None,
9
+ exclude: Optional[List[str]] = None,
10
+ ):
11
+ r"""Declare an overridable configuration file for a launch job.
12
+
13
+ If a new job version is created from the active run, the configuration file
14
+ will be added to the job's inputs. If the job is launched and overrides
15
+ have been provided for the configuration file, this function will detect
16
+ the overrides from the environment and update the configuration file on disk.
17
+ Note that these overrides will only be applied in ephemeral containers.
18
+ `include` and `exclude` are lists of dot-separated paths within the config.
19
+ The paths are used to filter subtrees of the configuration file out of the
20
+ job's inputs.
21
+
22
+ For example, given the following configuration file:
23
+ ```yaml
24
+ model:
25
+ name: resnet
26
+ layers: 18
27
+ training:
28
+ epochs: 10
29
+ batch_size: 32
30
+ ```
31
+
32
+ Passing `include=['model']` will only include the `model` subtree in the
33
+ job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
34
+ key from the `model` subtree. Note that `exclude` takes precedence over
35
+ `include`.
36
+
37
+ `.` is used as a separator for nested keys. If a key contains a `.`, it
38
+ should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
39
+ the use of `r` to denote a raw string when using escape chars.
40
+
41
+ Args:
42
+ path (str): The path to the configuration file. This path must be
43
+ relative and must not contain backwards traversal, i.e. `..`.
44
+ include (List[str]): A list of keys to include in the configuration file.
45
+ exclude (List[str]): A list of keys to exclude from the configuration file.
46
+
47
+ Raises:
48
+ LaunchError: If the path is not valid, or if the file overrides read from the environment are not valid JSON.
49
+ """
50
+ from .internal import handle_config_file_input
51
+
52
+ return handle_config_file_input(path, include, exclude)
53
+
54
+
55
+ def manage_wandb_config(
56
+ include: Optional[List[str]] = None,
57
+ exclude: Optional[List[str]] = None,
58
+ ):
59
+ r"""Declare wandb.config as an overridable configuration for a launch job.
60
+
61
+ If a new job version is created from the active run, the run config
62
+ (wandb.config) will become an overridable input of the job. If the job is
63
+ launched and overrides have been provided for the run config, the overrides
64
+ will be applied to the run config when `wandb.init` is called.
65
+ `include` and `exclude` are lists of dot-separated paths within the config.
66
+ The paths are used to filter subtrees of the run config out of the
67
+ job's inputs.
68
+
69
+ For example, given the following run config contents:
70
+ ```yaml
71
+ model:
72
+ name: resnet
73
+ layers: 18
74
+ training:
75
+ epochs: 10
76
+ batch_size: 32
77
+ ```
78
+ Passing `include=['model']` will only include the `model` subtree in the
79
+ job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
80
+ key from the `model` subtree. Note that `exclude` takes precedence over
81
+ `include`.
82
+ `.` is used as a separator for nested keys. If a key contains a `.`, it
83
+ should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
84
+ the use of `r` to denote a raw string when using escape chars.
85
+
86
+ Args:
87
+ include (List[str]): A list of subtrees to include in the configuration.
88
+ exclude (List[str]): A list of subtrees to exclude from the configuration.
89
+
90
+ Raises:
91
+ LaunchError: If an `include` or `exclude` path is invalid, e.g. it has a trailing backslash.
92
+ """
93
+ from .internal import handle_run_config_input
94
+
95
+ handle_run_config_input(include, exclude)
@@ -40,9 +40,9 @@ State = Literal[
40
40
 
41
41
 
42
42
  class Status:
43
- def __init__(self, state: "State" = "unknown", data=None): # type: ignore
43
+ def __init__(self, state: "State" = "unknown", messages: List[str] = None): # type: ignore
44
44
  self.state = state
45
- self.data = data or {}
45
+ self.messages = messages or []
46
46
 
47
47
  def __repr__(self) -> "State":
48
48
  return self.state
@@ -14,6 +14,7 @@ from kubernetes_asyncio.client import ( # type: ignore # noqa: F401
14
14
  BatchV1Api,
15
15
  CoreV1Api,
16
16
  CustomObjectsApi,
17
+ V1Pod,
17
18
  V1PodStatus,
18
19
  )
19
20
 
@@ -118,6 +119,27 @@ def _is_container_creating(status: "V1PodStatus") -> bool:
118
119
  return False
119
120
 
120
121
 
122
+ def _is_pod_unschedulable(status: "V1PodStatus") -> Tuple[bool, str]:
123
+ """Return whether the pod is unschedulable along with the reason message."""
124
+ if not status.conditions:
125
+ return False, ""
126
+ for condition in status.conditions:
127
+ if (
128
+ condition.type == "PodScheduled"
129
+ and condition.status == "False"
130
+ and condition.reason == "Unschedulable"
131
+ ):
132
+ return True, condition.message
133
+ return False, ""
134
+
135
+
136
+ def _get_crd_job_name(object: "V1Pod") -> Optional[str]:
137
+ refs = object.metadata.owner_references
138
+ if refs:
139
+ return refs[0].name
140
+ return None
141
+
142
+
121
143
  def _state_from_conditions(conditions: List[Dict[str, Any]]) -> Optional[State]:
122
144
  """Get the status from the pod conditions."""
123
145
  true_conditions = [
@@ -298,10 +320,18 @@ class LaunchKubernetesMonitor:
298
320
  counts[state] += 1
299
321
  return counts
300
322
 
301
- def _set_status(self, job_name: str, status: Status) -> None:
323
+ def _set_status_state(self, job_name: str, state: State) -> None:
302
324
  """Set the status of the run."""
303
- if self._job_states.get(job_name) != status:
304
- self._job_states[job_name] = status
325
+ if job_name not in self._job_states:
326
+ self._job_states[job_name] = Status(state)
327
+ elif self._job_states[job_name].state != state:
328
+ self._job_states[job_name].state = state
329
+
330
+ def _add_status_message(self, job_name: str, message: str) -> None:
331
+ if job_name not in self._job_states:
332
+ self._job_states[job_name] = Status("unknown")
333
+ wandb.termwarn(f"Warning from Kubernetes for job {job_name}: {message}")
334
+ self._job_states[job_name].messages.append(message)
305
335
 
306
336
  async def _monitor_pods(self, namespace: str) -> None:
307
337
  """Monitor a namespace for changes."""
@@ -312,15 +342,19 @@ class LaunchKubernetesMonitor:
312
342
  label_selector=self._label_selector,
313
343
  ):
314
344
  obj = event.get("object")
315
- job_name = obj.metadata.labels.get("job-name")
345
+ job_name = obj.metadata.labels.get("job-name") or _get_crd_job_name(obj)
316
346
  if job_name is None or not hasattr(obj, "status"):
317
347
  continue
318
348
  if self.__get_status(job_name) in ["finished", "failed"]:
319
349
  continue
350
+
351
+ is_unschedulable, reason = _is_pod_unschedulable(obj.status)
352
+ if is_unschedulable:
353
+ self._add_status_message(job_name, reason)
320
354
  if obj.status.phase == "Running" or _is_container_creating(obj.status):
321
- self._set_status(job_name, Status("running"))
355
+ self._set_status_state(job_name, "running")
322
356
  elif _is_preempted(obj.status):
323
- self._set_status(job_name, Status("preempted"))
357
+ self._set_status_state(job_name, "preempted")
324
358
 
325
359
  async def _monitor_jobs(self, namespace: str) -> None:
326
360
  """Monitor a namespace for changes."""
@@ -334,15 +368,15 @@ class LaunchKubernetesMonitor:
334
368
  job_name = obj.metadata.name
335
369
 
336
370
  if obj.status.succeeded == 1:
337
- self._set_status(job_name, Status("finished"))
371
+ self._set_status_state(job_name, "finished")
338
372
  elif obj.status.failed is not None and obj.status.failed >= 1:
339
- self._set_status(job_name, Status("failed"))
373
+ self._set_status_state(job_name, "failed")
340
374
 
341
375
  # If the job is deleted and we haven't seen a terminal state
342
376
  # then we will consider the job failed.
343
377
  if event.get("type") == "DELETED":
344
378
  if self._job_states.get(job_name) != Status("finished"):
345
- self._set_status(job_name, Status("failed"))
379
+ self._set_status_state(job_name, "failed")
346
380
 
347
381
  async def _monitor_crd(
348
382
  self, namespace: str, custom_resource: CustomResource
@@ -355,7 +389,7 @@ class LaunchKubernetesMonitor:
355
389
  plural=custom_resource.plural,
356
390
  group=custom_resource.group,
357
391
  version=custom_resource.version,
358
- label_selector=self._label_selector, # TODO: Label selector doesn't work for CRDs.
392
+ label_selector=self._label_selector,
359
393
  ):
360
394
  object = event.get("object")
361
395
  name = object.get("metadata", dict()).get("name")
@@ -383,8 +417,7 @@ class LaunchKubernetesMonitor:
383
417
  )
384
418
  if state is None:
385
419
  continue
386
- status = Status(state)
387
- self._set_status(name, status)
420
+ self._set_status_state(name, state)
388
421
 
389
422
 
390
423
  class SafeWatch:
@@ -29,7 +29,6 @@ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
29
29
  from wandb.util import get_module
30
30
 
31
31
  from .._project_spec import EntryPoint, LaunchProject
32
- from ..builder.build import get_env_vars_dict
33
32
  from ..errors import LaunchError
34
33
  from ..utils import (
35
34
  LOG_PREFIX,
@@ -374,8 +373,7 @@ class KubernetesRunner(AbstractRunner):
374
373
  }
375
374
 
376
375
  entry_point = (
377
- launch_project.override_entrypoint
378
- or launch_project.get_single_entry_point()
376
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
379
377
  )
380
378
  if launch_project.docker_image:
381
379
  # dont specify run id if user provided image, could have multiple runs
@@ -401,8 +399,8 @@ class KubernetesRunner(AbstractRunner):
401
399
  launch_project.override_entrypoint is not None,
402
400
  )
403
401
 
404
- env_vars = get_env_vars_dict(
405
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
402
+ env_vars = launch_project.get_env_vars_dict(
403
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
406
404
  )
407
405
  api_key_secret = None
408
406
  for cont in containers:
@@ -511,8 +509,8 @@ class KubernetesRunner(AbstractRunner):
511
509
  api_version = resource_args.get("apiVersion", "batch/v1")
512
510
 
513
511
  if api_version not in ["batch/v1", "batch/v1beta1"]:
514
- env_vars = get_env_vars_dict(
515
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
512
+ env_vars = launch_project.get_env_vars_dict(
513
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
516
514
  )
517
515
  # Crawl the resource args and add our env vars to the containers.
518
516
  add_wandb_env(resource_args, env_vars)
@@ -537,7 +535,7 @@ class KubernetesRunner(AbstractRunner):
537
535
  if LaunchAgent.initialized():
538
536
  add_label_to_pods(
539
537
  resource_args,
540
- WANDB_K8S_LABEL_MONITOR,
538
+ WANDB_K8S_LABEL_AGENT,
541
539
  LaunchAgent.name(),
542
540
  )
543
541
  resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (