wandb-0.15.3-py3-none-any.whl → wandb-0.15.5-py3-none-any.whl

Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
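Most of this release is a reorganization of the artifacts code: wandb/sdk/wandb_artifacts.py and the wandb/sdk/interface/artifacts/ modules are deleted, and their contents now live under the new wandb/sdk/artifacts/ package (artifact.py, the artifacts cache, storage handlers, and storage policies). A minimal sketch of the new internal import paths implied by the file list above; the renamed cache module comes from entry 43, and the top-level wandb.Artifact entry point is assumed unchanged:

# Sketch only: new internal locations in 0.15.5, per the file list above.
from wandb.sdk.artifacts.artifact import Artifact                     # replaces wandb/sdk/wandb_artifacts.py
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache   # was wandb/sdk/interface/artifacts/artifact_cache.py
from wandb.sdk.artifacts.storage_handler import StorageHandler        # split out of the old interface/artifacts modules

# The public entry point is assumed to be unchanged:
import wandb
artifact = wandb.Artifact("example-dataset", type="dataset")  # hypothetical name

The four largest new files are the per-backend storage handlers, shown in full below.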
wandb/sdk/artifacts/storage_handlers/azure_handler.py
@@ -0,0 +1,192 @@
+ """Azure storage handler."""
+ from pathlib import PurePosixPath
+ from types import ModuleType
+ from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
+ from urllib.parse import ParseResult, parse_qsl, urlparse
+
+ import wandb
+ from wandb import util
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
+ from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
+ from wandb.sdk.lib.hashutil import ETag
+ from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
+
+ if TYPE_CHECKING:
+     import azure.identity  # type: ignore
+     import azure.storage.blob  # type: ignore
+
+     from wandb.sdk.artifacts.artifact import Artifact
+
+
+ class AzureHandler(StorageHandler):
+     def can_handle(self, parsed_url: "ParseResult") -> bool:
+         return parsed_url.scheme == "https" and parsed_url.netloc.endswith(
+             ".blob.core.windows.net"
+         )
+
+     def load_path(
+         self,
+         manifest_entry: "ArtifactManifestEntry",
+         local: bool = False,
+     ) -> Union[URIStr, FilePathStr]:
+         assert manifest_entry.ref is not None
+         if not local:
+             return manifest_entry.ref
+
+         path, hit, cache_open = get_artifacts_cache().check_etag_obj_path(
+             URIStr(manifest_entry.ref),
+             ETag(manifest_entry.digest),
+             manifest_entry.size or 0,
+         )
+         if hit:
+             return path
+
+         account_url, container_name, blob_name, query = self._parse_uri(
+             manifest_entry.ref
+         )
+         version_id = manifest_entry.extra.get("versionID")
+         blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
+             account_url, credential=self._get_credential(account_url)
+         )
+         blob_client = blob_service_client.get_blob_client(
+             container=container_name, blob=blob_name
+         )
+         if version_id is None:
+             # Try current version, then all versions.
+             try:
+                 downloader = blob_client.download_blob(
+                     etag=manifest_entry.digest,
+                     match_condition=self._get_module(
+                         "azure.core"
+                     ).MatchConditions.IfNotModified,
+                 )
+             except self._get_module("azure.core.exceptions").ResourceModifiedError:
+                 container_client = blob_service_client.get_container_client(
+                     container_name
+                 )
+                 for blob_properties in container_client.walk_blobs(
+                     name_starts_with=blob_name, include=["versions"]
+                 ):
+                     if (
+                         blob_properties.name == blob_name
+                         and blob_properties.etag == manifest_entry.digest
+                         and blob_properties.version_id is not None
+                     ):
+                         downloader = blob_client.download_blob(
+                             version_id=blob_properties.version_id
+                         )
+                         break
+                 else:  # didn't break
+                     raise ValueError(
+                         f"Couldn't find blob version for {manifest_entry.ref} matching "
+                         f"etag {manifest_entry.digest}."
+                     )
+         else:
+             downloader = blob_client.download_blob(version_id=version_id)
+         with cache_open(mode="wb") as f:
+             downloader.readinto(f)
+         return path
+
+     def store_path(
+         self,
+         artifact: "Artifact",
+         path: Union[URIStr, FilePathStr],
+         name: Optional[StrPath] = None,
+         checksum: bool = True,
+         max_objects: Optional[int] = None,
+     ) -> Sequence["ArtifactManifestEntry"]:
+         account_url, container_name, blob_name, query = self._parse_uri(path)
+         path = URIStr(f"{account_url}/{container_name}/{blob_name}")
+
+         if not checksum:
+             return [
+                 ArtifactManifestEntry(path=name or blob_name, digest=path, ref=path)
+             ]
+
+         blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
+             account_url, credential=self._get_credential(account_url)
+         )
+         blob_client = blob_service_client.get_blob_client(
+             container=container_name, blob=blob_name
+         )
+         if blob_client.exists(version_id=query.get("versionId")):
+             blob_properties = blob_client.get_blob_properties(
+                 version_id=query.get("versionId")
+             )
+             return [
+                 self._create_entry(
+                     blob_properties,
+                     path=name or PurePosixPath(blob_name).name,
+                     ref=URIStr(
+                         f"{account_url}/{container_name}/{blob_properties.name}"
+                     ),
+                 )
+             ]
+
+         entries = []
+         container_client = blob_service_client.get_container_client(container_name)
+         max_objects = max_objects or DEFAULT_MAX_OBJECTS
+         for i, blob_properties in enumerate(
+             container_client.list_blobs(name_starts_with=f"{blob_name}/")
+         ):
+             if i >= max_objects:
+                 raise ValueError(
+                     f"Exceeded {max_objects} objects tracked, pass max_objects to "
+                     f"add_reference"
+                 )
+             suffix = PurePosixPath(blob_properties.name).relative_to(blob_name)
+             entries.append(
+                 self._create_entry(
+                     blob_properties,
+                     path=LogicalPath(name) / suffix if name else suffix,
+                     ref=URIStr(
+                         f"{account_url}/{container_name}/{blob_properties.name}"
+                     ),
+                 )
+             )
+         return entries
+
+     def _get_module(self, name: str) -> ModuleType:
+         module = util.get_module(
+             name,
+             lazy=False,
+             required="Azure references require the azure library, run "
+             "pip install wandb[azure]",
+         )
+         assert isinstance(module, ModuleType)
+         return module
+
+     def _get_credential(
+         self, account_url: str
+     ) -> Union["azure.identity.DefaultAzureCredential", str]:
+         if (
+             wandb.run
+             and account_url in wandb.run.settings.azure_account_url_to_access_key
+         ):
+             return wandb.run.settings.azure_account_url_to_access_key[account_url]
+         return self._get_module("azure.identity").DefaultAzureCredential()
+
+     def _parse_uri(self, uri: str) -> Tuple[str, str, str, Dict[str, str]]:
+         parsed_url = urlparse(uri)
+         query = dict(parse_qsl(parsed_url.query))
+         account_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+         _, container_name, blob_name = parsed_url.path.split("/", 2)
+         return account_url, container_name, blob_name, query
+
+     def _create_entry(
+         self,
+         blob_properties: "azure.storage.blob.BlobProperties",
+         path: StrPath,
+         ref: URIStr,
+     ) -> ArtifactManifestEntry:
+         extra = {"etag": blob_properties.etag.strip('"')}
+         if blob_properties.version_id:
+             extra["versionID"] = blob_properties.version_id
+         return ArtifactManifestEntry(
+             path=path,
+             ref=ref,
+             digest=blob_properties.etag.strip('"'),
+             size=blob_properties.size,
+             extra=extra,
+         )
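A minimal usage sketch (not part of the diff) of how this handler gets exercised: Artifact.add_reference dispatches on the URL, an https://<account>.blob.core.windows.net URL matches AzureHandler.can_handle, and store_path records the blob's ETag and versionID instead of uploading the bytes. Account, container, and blob names below are hypothetical:

import wandb

with wandb.init(project="demo") as run:  # hypothetical project name
    artifact = wandb.Artifact("azure-refs", type="dataset")
    # Routed to AzureHandler by scheme/netloc; only reference metadata is stored.
    artifact.add_reference(
        "https://myaccount.blob.core.windows.net/mycontainer/data/train.csv"
    )
    run.log_artifact(artifact)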
wandb/sdk/artifacts/storage_handlers/gcs_handler.py
@@ -0,0 +1,224 @@
+ """GCS storage handler."""
+ import base64
+ import time
+ from pathlib import PurePosixPath
+ from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
+ from urllib.parse import ParseResult, urlparse
+
+ from wandb import util
+ from wandb.errors.term import termlog
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
+ from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
+ from wandb.sdk.lib.hashutil import B64MD5
+ from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
+
+ if TYPE_CHECKING:
+     import google.cloud.storage as gcs_module  # type: ignore
+
+     from wandb.sdk.artifacts.artifact import Artifact
+
+
+ class GCSHandler(StorageHandler):
+     _client: Optional["gcs_module.client.Client"]
+     _versioning_enabled: Optional[bool]
+
+     def __init__(self, scheme: Optional[str] = None) -> None:
+         self._scheme = scheme or "gs"
+         self._client = None
+         self._versioning_enabled = None
+         self._cache = get_artifacts_cache()
+
+     def versioning_enabled(self, bucket_path: str) -> bool:
+         if self._versioning_enabled is not None:
+             return self._versioning_enabled
+         self.init_gcs()
+         assert self._client is not None  # mypy: unwraps optionality
+         bucket = self._client.bucket(bucket_path)
+         bucket.reload()
+         self._versioning_enabled = bucket.versioning_enabled
+         return self._versioning_enabled
+
+     def can_handle(self, parsed_url: "ParseResult") -> bool:
+         return parsed_url.scheme == self._scheme
+
+     def init_gcs(self) -> "gcs_module.client.Client":
+         if self._client is not None:
+             return self._client
+         storage = util.get_module(
+             "google.cloud.storage",
+             required="gs:// references requires the google-cloud-storage library, run pip install wandb[gcp]",
+         )
+         self._client = storage.Client()
+         return self._client
+
+     def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
+         url = urlparse(uri)
+         bucket = url.netloc
+         key = url.path[1:]
+         version = url.fragment if url.fragment else None
+         return bucket, key, version
+
+     def load_path(
+         self,
+         manifest_entry: ArtifactManifestEntry,
+         local: bool = False,
+     ) -> Union[URIStr, FilePathStr]:
+         if not local:
+             assert manifest_entry.ref is not None
+             return manifest_entry.ref
+
+         path, hit, cache_open = self._cache.check_md5_obj_path(
+             B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
+             manifest_entry.size if manifest_entry.size is not None else 0,
+         )
+         if hit:
+             return path
+
+         self.init_gcs()
+         assert self._client is not None  # mypy: unwraps optionality
+         assert manifest_entry.ref is not None
+         bucket, key, _ = self._parse_uri(manifest_entry.ref)
+         version = manifest_entry.extra.get("versionID")
+
+         obj = None
+         # First attempt to get the generation specified, this will return None if versioning is not enabled
+         if version is not None:
+             obj = self._client.bucket(bucket).get_blob(key, generation=version)
+
+         if obj is None:
+             # Object versioning is disabled on the bucket, so just get
+             # the latest version and make sure the MD5 matches.
+             obj = self._client.bucket(bucket).get_blob(key)
+             if obj is None:
+                 raise ValueError(
+                     f"Unable to download object {manifest_entry.ref} with generation {version}"
+                 )
+             md5 = obj.md5_hash
+             if md5 != manifest_entry.digest:
+                 raise ValueError(
+                     f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {md5}"
+                 )
+
+         with cache_open(mode="wb") as f:
+             obj.download_to_file(f)
+         return path
+
+     def store_path(
+         self,
+         artifact: "Artifact",
+         path: Union[URIStr, FilePathStr],
+         name: Optional[StrPath] = None,
+         checksum: bool = True,
+         max_objects: Optional[int] = None,
+     ) -> Sequence[ArtifactManifestEntry]:
+         self.init_gcs()
+         assert self._client is not None  # mypy: unwraps optionality
+
+         # After parsing any query params / fragments for additional context,
+         # such as version identifiers, pare down the path to just the bucket
+         # and key.
+         bucket, key, version = self._parse_uri(path)
+         path = URIStr(f"{self._scheme}://{bucket}/{key}")
+         max_objects = max_objects or DEFAULT_MAX_OBJECTS
+         if not self.versioning_enabled(bucket) and version:
+             raise ValueError(
+                 f"Specifying a versionId is not valid for s3://{bucket} as it does not have versioning enabled."
+             )
+
+         if not checksum:
+             return [ArtifactManifestEntry(path=name or key, ref=path, digest=path)]
+
+         start_time = None
+         obj = self._client.bucket(bucket).get_blob(key, generation=version)
+         multi = obj is None
+         if multi:
+             start_time = time.time()
+             termlog(
+                 'Generating checksum for up to %i objects with prefix "%s"... '
+                 % (max_objects, key),
+                 newline=False,
+             )
+             objects = self._client.bucket(bucket).list_blobs(
+                 prefix=key, max_results=max_objects
+             )
+         else:
+             objects = [obj]
+
+         entries = [
+             self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
+             for obj in objects
+         ]
+         if start_time is not None:
+             termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
+         if len(entries) > max_objects:
+             raise ValueError(
+                 "Exceeded %i objects tracked, pass max_objects to add_reference"
+                 % max_objects
+             )
+         return entries
+
+     def _entry_from_obj(
+         self,
+         obj: "gcs_module.blob.Blob",
+         path: str,
+         name: Optional[StrPath] = None,
+         prefix: str = "",
+         multi: bool = False,
+     ) -> ArtifactManifestEntry:
+         """Create an ArtifactManifestEntry from a GCS object.
+
+         Arguments:
+             obj: The GCS object
+             path: The GCS-style path (e.g.: "gs://bucket/file.txt")
+             name: The user assigned name, or None if not specified
+             prefix: The prefix to add (will be the same as `path` for directories)
+             multi: Whether or not this is a multi-object add.
+         """
+         bucket, key, _ = self._parse_uri(path)
+
+         # Always use posix paths, since that's what S3 uses.
+         posix_key = PurePosixPath(obj.name)  # the bucket key
+         posix_path = PurePosixPath(bucket) / PurePosixPath(
+             key
+         )  # the path, with the scheme stripped
+         posix_prefix = PurePosixPath(prefix)  # the prefix, if adding a prefix
+         posix_name = PurePosixPath(name or "")
+         posix_ref = posix_path
+
+         if name is None:
+             # We're adding a directory (prefix), so calculate a relative path.
+             if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
+                 posix_name = posix_key.relative_to(posix_prefix)
+                 posix_ref = posix_path / posix_name
+             else:
+                 posix_name = PurePosixPath(posix_key.name)
+                 posix_ref = posix_path
+         elif multi:
+             # We're adding a directory with a name override.
+             relpath = posix_key.relative_to(posix_prefix)
+             posix_name = posix_name / relpath
+             posix_ref = posix_path / relpath
+         return ArtifactManifestEntry(
+             path=posix_name,
+             ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
+             digest=obj.md5_hash,
+             size=obj.size,
+             extra=self._extra_from_obj(obj),
+         )
+
+     @staticmethod
+     def _extra_from_obj(obj: "gcs_module.blob.Blob") -> Dict[str, str]:
+         return {
+             "etag": obj.etag,
+             "versionID": obj.generation,
+         }
+
+     @staticmethod
+     def _content_addressed_path(md5: str) -> FilePathStr:
+         # TODO: is this the structure we want? not at all human
+         # readable, but that's probably OK. don't want people
+         # poking around in the bucket
+         return FilePathStr(
+             "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
+         )
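Similarly, a hedged sketch (not from the diff) of a gs:// reference: _parse_uri above treats the URL fragment as the object generation, store_path rejects a version on buckets without versioning, and a prefix expands to one entry per object up to max_objects. Bucket and object names are hypothetical:

import wandb

artifact = wandb.Artifact("gcs-refs", type="dataset")  # hypothetical name
# A "#<generation>" fragment is parsed by GCSHandler._parse_uri as the version
# and ends up in extra["versionID"] on the manifest entry.
artifact.add_reference("gs://my-bucket/datasets/train.csv#1687200000000000")
# Referencing a prefix creates one entry per object, capped by max_objects.
artifact.add_reference("gs://my-bucket/datasets/", max_objects=10_000)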
wandb/sdk/artifacts/storage_handlers/http_handler.py
@@ -0,0 +1,112 @@
+ """HTTP storage handler."""
+ import os
+ from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
+ from urllib.parse import ParseResult
+
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
+ from wandb.sdk.artifacts.storage_handler import StorageHandler
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
+ from wandb.sdk.lib.hashutil import ETag
+ from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
+
+ if TYPE_CHECKING:
+     import requests
+
+     from wandb.sdk.artifacts.artifact import Artifact
+
+
+ class HTTPHandler(StorageHandler):
+     def __init__(
+         self, session: "requests.Session", scheme: Optional[str] = None
+     ) -> None:
+         self._scheme = scheme or "http"
+         self._cache = get_artifacts_cache()
+         self._session = session
+
+     def can_handle(self, parsed_url: "ParseResult") -> bool:
+         return parsed_url.scheme == self._scheme
+
+     def load_path(
+         self,
+         manifest_entry: ArtifactManifestEntry,
+         local: bool = False,
+     ) -> Union[URIStr, FilePathStr]:
+         if not local:
+             assert manifest_entry.ref is not None
+             return manifest_entry.ref
+
+         assert manifest_entry.ref is not None
+
+         path, hit, cache_open = self._cache.check_etag_obj_path(
+             URIStr(manifest_entry.ref),
+             ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
+             manifest_entry.size if manifest_entry.size is not None else 0,
+         )
+         if hit:
+             return path
+
+         response = self._session.get(
+             manifest_entry.ref,
+             stream=True,
+             cookies=_thread_local_api_settings.cookies,
+             headers=_thread_local_api_settings.headers,
+         )
+         response.raise_for_status()
+
+         digest: Optional[Union[ETag, FilePathStr, URIStr]]
+         digest, size, extra = self._entry_from_headers(response.headers)
+         digest = digest or manifest_entry.ref
+         if manifest_entry.digest != digest:
+             raise ValueError(
+                 f"Digest mismatch for url {manifest_entry.ref}: expected {manifest_entry.digest} but found {digest}"
+             )
+
+         with cache_open(mode="wb") as file:
+             for data in response.iter_content(chunk_size=16 * 1024):
+                 file.write(data)
+         return path
+
+     def store_path(
+         self,
+         artifact: "Artifact",
+         path: Union[URIStr, FilePathStr],
+         name: Optional[StrPath] = None,
+         checksum: bool = True,
+         max_objects: Optional[int] = None,
+     ) -> Sequence[ArtifactManifestEntry]:
+         name = name or os.path.basename(path)
+         if not checksum:
+             return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
+
+         with self._session.get(
+             path,
+             stream=True,
+             cookies=_thread_local_api_settings.cookies,
+             headers=_thread_local_api_settings.headers,
+         ) as response:
+             response.raise_for_status()
+             digest: Optional[Union[ETag, FilePathStr, URIStr]]
+             digest, size, extra = self._entry_from_headers(response.headers)
+             digest = digest or path
+         return [
+             ArtifactManifestEntry(
+                 path=name, ref=path, digest=digest, size=size, extra=extra
+             )
+         ]
+
+     def _entry_from_headers(
+         self, headers: "requests.structures.CaseInsensitiveDict"
+     ) -> Tuple[Optional[ETag], Optional[int], Dict[str, str]]:
+         response_headers = {k.lower(): v for k, v in headers.items()}
+         size = None
+         if response_headers.get("content-length", None):
+             size = int(response_headers["content-length"])
+
+         digest = response_headers.get("etag", None)
+         extra = {}
+         if digest:
+             extra["etag"] = digest
+         if digest and digest[:1] == '"' and digest[-1:] == '"':
+             digest = digest[1:-1]  # trim leading and trailing quotes around etag
+         return digest, size, extra
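For plain web URLs, a sketch (not from the diff) of the path through HTTPHandler: the handler is constructed with a requests.Session, store_path issues a streaming GET and records the ETag and Content-Length headers when present, and load_path later verifies the stored digest against the server's ETag. The URL below is hypothetical:

import wandb

artifact = wandb.Artifact("http-refs", type="dataset")  # hypothetical name
# http(s) URLs that no cloud-specific handler claims fall through to HTTPHandler;
# if the server sends no ETag header, the URL itself is used as the digest.
artifact.add_reference("https://example.com/files/data.json")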
wandb/sdk/artifacts/storage_handlers/local_file_handler.py
@@ -0,0 +1,134 @@
+ """Local file storage handler."""
+ import os
+ import shutil
+ import time
+ from typing import TYPE_CHECKING, Optional, Sequence, Union
+ from urllib.parse import ParseResult
+
+ from wandb import util
+ from wandb.errors.term import termlog
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
+ from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
+ from wandb.sdk.lib import filesystem
+ from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64, md5_string
+ from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
+
+ if TYPE_CHECKING:
+     from wandb.sdk.artifacts.artifact import Artifact
+
+
+ class LocalFileHandler(StorageHandler):
+     """Handles file:// references."""
+
+     def __init__(self, scheme: Optional[str] = None) -> None:
+         """Track files or directories on a local filesystem.
+
+         Expand directories to create an entry for each file contained.
+         """
+         self._scheme = scheme or "file"
+         self._cache = get_artifacts_cache()
+
+     def can_handle(self, parsed_url: "ParseResult") -> bool:
+         return parsed_url.scheme == self._scheme
+
+     def load_path(
+         self,
+         manifest_entry: ArtifactManifestEntry,
+         local: bool = False,
+     ) -> Union[URIStr, FilePathStr]:
+         if manifest_entry.ref is None:
+             raise ValueError(f"Cannot add path with no ref: {manifest_entry.path}")
+         local_path = util.local_file_uri_to_path(str(manifest_entry.ref))
+         if not os.path.exists(local_path):
+             raise ValueError(
+                 "Local file reference: Failed to find file at path %s" % local_path
+             )
+
+         path, hit, cache_open = self._cache.check_md5_obj_path(
+             B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
+             manifest_entry.size if manifest_entry.size is not None else 0,
+         )
+         if hit:
+             return path
+
+         md5 = md5_file_b64(local_path)
+         if md5 != manifest_entry.digest:
+             raise ValueError(
+                 f"Local file reference: Digest mismatch for path {local_path}: expected {manifest_entry.digest} but found {md5}"
+             )
+
+         filesystem.mkdir_exists_ok(os.path.dirname(path))
+
+         with cache_open() as f:
+             shutil.copy(local_path, f.name)
+         return path
+
+     def store_path(
+         self,
+         artifact: "Artifact",
+         path: Union[URIStr, FilePathStr],
+         name: Optional[StrPath] = None,
+         checksum: bool = True,
+         max_objects: Optional[int] = None,
+     ) -> Sequence[ArtifactManifestEntry]:
+         local_path = util.local_file_uri_to_path(path)
+         max_objects = max_objects or DEFAULT_MAX_OBJECTS
+         # We have a single file or directory
+         # Note, we follow symlinks for files contained within the directory
+         entries = []
+
+         def md5(path: str) -> B64MD5:
+             return (
+                 md5_file_b64(path)
+                 if checksum
+                 else md5_string(str(os.stat(path).st_size))
+             )
+
+         if os.path.isdir(local_path):
+             i = 0
+             start_time = time.time()
+             if checksum:
+                 termlog(
+                     'Generating checksum for up to %i files in "%s"...\n'
+                     % (max_objects, local_path),
+                     newline=False,
+                 )
+             for root, _, files in os.walk(local_path):
+                 for sub_path in files:
+                     i += 1
+                     if i > max_objects:
+                         raise ValueError(
+                             "Exceeded %i objects tracked, pass max_objects to add_reference"
+                             % max_objects
+                         )
+                     physical_path = os.path.join(root, sub_path)
+                     # TODO(spencerpearson): this is not a "logical path" in the sense that
+                     # `LogicalPath` returns a "logical path"; it's a relative path
+                     # **on the local filesystem**.
+                     logical_path = os.path.relpath(physical_path, start=local_path)
+                     if name is not None:
+                         logical_path = os.path.join(name, logical_path)
+
+                     entry = ArtifactManifestEntry(
+                         path=logical_path,
+                         ref=FilePathStr(os.path.join(path, logical_path)),
+                         size=os.path.getsize(physical_path),
+                         digest=md5(physical_path),
+                     )
+                     entries.append(entry)
+             if checksum:
+                 termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
+         elif os.path.isfile(local_path):
+             name = name or os.path.basename(local_path)
+             entry = ArtifactManifestEntry(
+                 path=name,
+                 ref=path,
+                 size=os.path.getsize(local_path),
+                 digest=md5(local_path),
+             )
+             entries.append(entry)
+         else:
+             # TODO: update error message if we don't allow directories.
+             raise ValueError('Path "%s" must be a valid file or directory path' % path)
+         return entries
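Finally, a sketch (not from the diff) of a file:// reference through LocalFileHandler: a directory reference expands to one manifest entry per contained file, each hashed with MD5 unless checksum=False, in which case the digest is derived from the file size only. The path below is hypothetical:

import wandb

artifact = wandb.Artifact("local-refs", type="dataset")  # hypothetical name
# Tracked in place rather than copied; with checksum=False only sizes are
# recorded, so same-size modifications would go undetected.
artifact.add_reference("file:///data/corpus/", checksum=True)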