wandb 0.16.4__py3-none-any.whl → 0.16.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +1 -1
- wandb/apis/public/api.py +6 -6
- wandb/apis/reports/v2/interface.py +4 -8
- wandb/apis/reports/v2/internal.py +12 -45
- wandb/cli/cli.py +29 -5
- wandb/integration/openai/fine_tuning.py +74 -37
- wandb/integration/ultralytics/callback.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +332 -312
- wandb/proto/v3/wandb_settings_pb2.py +13 -3
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +316 -312
- wandb/proto/v4/wandb_settings_pb2.py +5 -3
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +92 -26
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +16 -36
- wandb/sdk/artifacts/storage_handler.py +2 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
- wandb/sdk/interface/interface.py +60 -15
- wandb/sdk/interface/interface_shared.py +13 -7
- wandb/sdk/internal/file_stream.py +19 -0
- wandb/sdk/internal/handler.py +1 -4
- wandb/sdk/internal/internal_api.py +2 -0
- wandb/sdk/internal/job_builder.py +45 -17
- wandb/sdk/internal/sender.py +53 -28
- wandb/sdk/internal/settings_static.py +9 -0
- wandb/sdk/internal/system/system_info.py +4 -1
- wandb/sdk/launch/_launch.py +5 -0
- wandb/sdk/launch/_project_spec.py +5 -20
- wandb/sdk/launch/agent/agent.py +80 -37
- wandb/sdk/launch/agent/config.py +8 -0
- wandb/sdk/launch/builder/kaniko_builder.py +149 -134
- wandb/sdk/launch/create_job.py +44 -48
- wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
- wandb/sdk/launch/sweeps/scheduler.py +3 -1
- wandb/sdk/launch/utils.py +23 -5
- wandb/sdk/lib/__init__.py +2 -5
- wandb/sdk/lib/_settings_toposort_generated.py +2 -0
- wandb/sdk/lib/filesystem.py +11 -1
- wandb/sdk/lib/run_moment.py +78 -0
- wandb/sdk/service/streams.py +1 -6
- wandb/sdk/wandb_init.py +12 -7
- wandb/sdk/wandb_login.py +43 -26
- wandb/sdk/wandb_run.py +179 -94
- wandb/sdk/wandb_settings.py +55 -16
- wandb/testing/relay.py +5 -6
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
wandb/__init__.py
CHANGED
@@ -11,8 +11,8 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
 
 For reference documentation, see https://docs.wandb.com/ref/python.
 """
-__version__ = "0.16.4"
-_minimum_core_version = "0.17.
+__version__ = "0.16.6"
+_minimum_core_version = "0.17.0b10"
 
 # Used with pypi checks and other messages related to pip
 _wandb_module = "wandb"
wandb/agents/pyagent.py
CHANGED
@@ -347,7 +347,7 @@ def pyagent(sweep_id, function, entity=None, project=None, count=None):
        count (int, optional): the number of trials to run.
    """
    if not callable(function):
-        raise Exception("function
+        raise Exception("function parameter must be callable!")
    agent = Agent(
        sweep_id,
        function=function,
wandb/apis/public/api.py
CHANGED
@@ -313,15 +313,15 @@ class Api:
            entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
            prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
            config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
-            template_variables (dict)
+            template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
                {
                    "var-name": {
                        "schema": {
-                            "type": "
-                            "default":
-                            "minimum":
-                            "maximum":
-                            "enum": [..."
+                            "type": ("string", "number", or "integer"),
+                            "default": (optional value),
+                            "minimum": (optional minimum),
+                            "maximum": (optional maximum),
+                            "enum": [..."(options)"]
                        }
                    }
                }
wandb/apis/reports/v2/interface.py
CHANGED
@@ -4,6 +4,8 @@ from datetime import datetime
 from typing import Dict, Iterable, Optional, Tuple, Union
 from typing import List as LList
 
+from annotated_types import Annotated, Ge, Le
+
 try:
     from typing import Literal
 except ImportError:
@@ -758,14 +760,8 @@ block_mapping = {
 
 @dataclass(config=dataclass_config)
 class GradientPoint(Base):
-    color: str
-    offset: float
-
-    @validator("color")
-    def validate_color(cls, v):  # noqa: N805
-        if not internal.is_valid_color(v):
-            raise ValueError("invalid color, value should be hex, rgb, or rgba")
-        return v
+    color: Annotated[str, internal.ColorStrConstraints]
+    offset: Annotated[float, Ge(0), Le(100)] = 0
 
     def to_model(self):
         return internal.GradientPoint(color=self.color, offset=self.offset)
wandb/apis/reports/v2/internal.py
CHANGED
@@ -1,18 +1,19 @@
 """JSONSchema for internal types. Hopefully this is auto-generated one day!"""
 import json
 import random
-import re
 from copy import deepcopy
 from datetime import datetime
 from typing import Any, Dict, Optional, Tuple, Union
 from typing import List as LList
 
+from annotated_types import Annotated, Ge, Le
+
 try:
     from typing import Literal
 except ImportError:
     from typing_extensions import Literal
 
-from pydantic import BaseModel, ConfigDict, Field, validator
+from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
 from pydantic.alias_generators import to_camel
 
 
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
     return rand36.lower()[:length]
 
 
+hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
+rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
+rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
+ColorStrConstraints = StringConstraints(
+    pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
+)
+
 LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
 BarPlotStyle = Literal["bar", "boxplot", "violin"]
 FontSize = Literal["small", "medium", "large", "auto"]
@@ -609,14 +617,8 @@ class LinePlot(Panel):
 
 
 class GradientPoint(ReportAPIBaseModel):
-    color: str
-    offset: float
-
-    @validator("color")
-    def validate_color(cls, v):  # noqa: N805
-        if not is_valid_color(v):
-            raise ValueError("invalid color, value should be hex, rgb, or rgba")
-        return v
+    color: Annotated[str, ColorStrConstraints]
+    offset: Annotated[float, Ge(0), Le(100)] = 0
 
 
 class ScatterPlotConfig(ReportAPIBaseModel):
@@ -866,38 +868,3 @@ block_type_mapping = {
     "table-of-contents": TableOfContents,
     "block-quote": BlockQuote,
 }
-
-
-def is_valid_color(color_str: str) -> bool:
-    # Regular expression for hex color validation
-    hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
-
-    # Check if it's a valid hex color
-    if re.match(hex_color_pattern, color_str):
-        return True
-
-    # Try parsing it as an RGB or RGBA tuple
-    try:
-        # Strip 'rgb(' or 'rgba(' and the closing ')'
-        if color_str.startswith("rgb(") and color_str.endswith(")"):
-            parts = color_str[4:-1].split(",")
-        elif color_str.startswith("rgba(") and color_str.endswith(")"):
-            parts = color_str[5:-1].split(",")
-        else:
-            return False
-
-        # Convert parts to integers and validate ranges
-        parts = [int(p.strip()) for p in parts]
-        if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
-            return True  # Valid RGB
-        if (
-            len(parts) == 4
-            and all(0 <= p <= 255 for p in parts[:-1])
-            and 0 <= parts[-1] <= 1
-        ):
-            return True  # Valid RGBA
-
-    except ValueError:
-        pass
-
-    return False
wandb/cli/cli.py
CHANGED
@@ -1680,6 +1680,7 @@ def launch(
     hidden=True,
     help="a wandb client registration URL, this is generated in the UI",
 )
+@click.option("--verbose", "-v", count=True, help="Display verbose output")
 @display_error
 def launch_agent(
     ctx,
@@ -1690,6 +1691,7 @@ def launch_agent(
     config=None,
     url=None,
     log_file=None,
+    verbose=0,
 ):
     logger.info(
         f"=== Launch-agent called with kwargs {locals()} CLI Version: {wandb.__version__} ==="
@@ -1707,7 +1709,7 @@ def launch_agent(
     api = _get_cling_api()
     wandb._sentry.configure_scope(process_context="launch_agent")
     agent_config, api = _launch.resolve_agent_config(
-        entity, project, max_jobs, queues, config
+        entity, project, max_jobs, queues, config, verbose
     )
 
     if len(agent_config.get("queues")) == 0:
@@ -1905,7 +1907,7 @@ def describe(job):
     "--entry-point",
     "-E",
     "entrypoint",
-    help="
+    help="Entrypoint to the script, including an executable and an entrypoint file. Required for code or repo jobs",
 )
 @click.option(
     "--git-hash",
@@ -2345,8 +2347,30 @@ def artifact():
     default=None,
     help="Resume the last run from your current directory.",
 )
+@click.option(
+    "--skip_cache",
+    is_flag=True,
+    default=False,
+    help="Skip caching while uploading artifact files.",
+)
+@click.option(
+    "--policy",
+    default="mutable",
+    type=click.Choice(["mutable", "immutable"]),
+    help="Set the storage policy while uploading artifact files.",
+)
 @display_error
-def put(path, name, description, type, alias, run_id, resume):
+def put(
+    path,
+    name,
+    description,
+    type,
+    alias,
+    run_id,
+    resume,
+    skip_cache,
+    policy,
+):
     if name is None:
         name = os.path.basename(path)
     public_api = PublicApi()
@@ -2361,10 +2385,10 @@ def put(path, name, description, type, alias, run_id, resume):
     artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
     if os.path.isdir(path):
         wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
-        artifact.add_dir(path)
+        artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
     elif os.path.isfile(path):
         wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
-        artifact.add_file(path)
+        artifact.add_file(path, skip_cache=skip_cache, policy=policy)
     elif "://" in path:
         wandb.termlog(
             f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
wandb/integration/openai/fine_tuning.py
CHANGED
@@ -1,9 +1,11 @@
 import datetime
 import io
 import json
+import os
 import re
+import tempfile
 import time
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import wandb
 from wandb import util
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
 
 from openai import OpenAI  # noqa: E402
 from openai.types.fine_tuning import FineTuningJob  # noqa: E402
-from openai.types.fine_tuning.fine_tuning_job import
+from openai.types.fine_tuning.fine_tuning_job import (  # noqa: E402
+    Error,
+    Hyperparameters,
+)
 
 np = util.get_module(
     name="numpy",
@@ -59,6 +64,7 @@ class WandbLogger:
         entity: Optional[str] = None,
         overwrite: bool = False,
         wait_for_job_success: bool = True,
+        log_datasets: bool = True,
         **kwargs_wandb_init: Dict[str, Any],
     ) -> str:
         """Sync fine-tunes to Weights & Biases.
@@ -150,6 +156,7 @@ class WandbLogger:
             entity,
             overwrite,
             show_individual_warnings,
+            log_datasets,
             **kwargs_wandb_init,
         )
 
@@ -160,11 +167,14 @@ class WandbLogger:
 
     @classmethod
     def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
-        wandb.termlog("Waiting for the OpenAI fine-tuning job to
+        wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
+        wandb.termlog(
+            "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
+        )
         while True:
             if fine_tune.status == "succeeded":
                 wandb.termlog(
-                    "Fine-tuning finished, logging metrics, model metadata, and
+                    "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
                 )
                 return fine_tune
             if fine_tune.status == "failed":
@@ -190,6 +200,7 @@ class WandbLogger:
         entity: Optional[str],
         overwrite: bool,
         show_individual_warnings: bool,
+        log_datasets: bool,
         **kwargs_wandb_init: Dict[str, Any],
     ):
         fine_tune_id = fine_tune.id
@@ -209,7 +220,7 @@ class WandbLogger:
         # check results are present
         try:
             results_id = fine_tune.result_files[0]
-            results = cls.openai_client.files.
+            results = cls.openai_client.files.content(file_id=results_id).text
         except openai.NotFoundError:
             if show_individual_warnings:
                 wandb.termwarn(
@@ -233,7 +244,7 @@ class WandbLogger:
         cls._run.summary["fine_tuned_model"] = fine_tuned_model
 
         # training/validation files and fine-tune details
-        cls._log_artifacts(fine_tune, project, entity)
+        cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
 
         # mark run as complete
         cls._run.summary["status"] = "succeeded"
@@ -249,7 +260,7 @@ class WandbLogger:
         else:
             raise Exception(
                 "It appears you are not currently logged in to Weights & Biases. "
-                "Please run `wandb login` in your terminal. "
+                "Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
                 "When prompted, you can obtain your API key by visiting wandb.ai/authorize."
             )
 
@@ -286,15 +297,9 @@ class WandbLogger:
                 config["finished_at"]
             ).strftime("%Y-%m-%d %H:%M:%S")
         if config.get("hyperparameters"):
-            hyperparameters =
-
-
-            # If unpacking fails, log the object which will render as string
-            config["hyperparameters"] = hyperparameters
-        else:
-            # nested rendering on hyperparameters
-            config["hyperparameters"] = hyperparams
-
+            config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
+        if config.get("error"):
+            config["error"] = cls.sanitize(config["error"])
         return config
 
     @classmethod
@@ -314,21 +319,44 @@ class WandbLogger:
 
         return hyperparams
 
+    @staticmethod
+    def sanitize(input: Any) -> Union[Dict, List, str]:
+        valid_types = [bool, int, float, str]
+        if isinstance(input, (Hyperparameters, Error)):
+            return dict(input)
+        if isinstance(input, dict):
+            return {
+                k: v if type(v) in valid_types else str(v) for k, v in input.items()
+            }
+        elif isinstance(input, list):
+            return [v if type(v) in valid_types else str(v) for v in input]
+        else:
+            return str(input)
+
     @classmethod
     def _log_artifacts(
-        cls,
+        cls,
+        fine_tune: FineTuningJob,
+        project: str,
+        entity: Optional[str],
+        log_datasets: bool,
+        overwrite: bool,
     ) -> None:
-
-
-
-        fine_tune.
-
-
-
-
-
-
-
+        if log_datasets:
+            wandb.termlog("Logging training/validation files...")
+            # training/validation files
+            training_file = fine_tune.training_file if fine_tune.training_file else None
+            validation_file = (
+                fine_tune.validation_file if fine_tune.validation_file else None
+            )
+            for file, prefix, artifact_type in (
+                (training_file, "train", "training_files"),
+                (validation_file, "valid", "validation_files"),
+            ):
+                if file is not None:
+                    cls._log_artifact_inputs(
+                        file, prefix, artifact_type, project, entity, overwrite
+                    )
 
         # fine-tune details
         fine_tune_id = fine_tune.id
@@ -337,9 +365,14 @@ class WandbLogger:
             type="model",
             metadata=dict(fine_tune),
         )
+
         with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
             dict_fine_tune = dict(fine_tune)
-            dict_fine_tune["hyperparameters"] =
+            dict_fine_tune["hyperparameters"] = cls.sanitize(
+                dict_fine_tune["hyperparameters"]
+            )
+            dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
+            dict_fine_tune = cls.sanitize(dict_fine_tune)
             json.dump(dict_fine_tune, f, indent=2)
         cls._run.log_artifact(
             artifact,
@@ -354,6 +387,7 @@ class WandbLogger:
         artifact_type: str,
         project: str,
         entity: Optional[str],
+        overwrite: bool,
     ) -> None:
         # get input artifact
         artifact_name = f"{prefix}-{file_id}"
@@ -366,23 +400,26 @@ class WandbLogger:
         artifact = cls._get_wandb_artifact(artifact_path)
 
         # create artifact if file not already logged previously
-        if artifact is None:
+        if artifact is None or overwrite:
             # get file content
             try:
-                file_content = cls.openai_client.files.
+                file_content = cls.openai_client.files.content(file_id=file_id)
             except openai.NotFoundError:
                 wandb.termerror(
-                    f"File {file_id} could not be retrieved. Make sure you
+                    f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
                 )
                 return
 
             artifact = wandb.Artifact(artifact_name, type=artifact_type)
-            with
-
+            with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+                tmp_file.write(file_content.content)
+                tmp_file_path = tmp_file.name
+            artifact.add_file(tmp_file_path, file_id)
+            os.unlink(tmp_file_path)
 
             # create a Table
             try:
-                table, n_items = cls._make_table(file_content)
+                table, n_items = cls._make_table(file_content.text)
                 # Add table to the artifact.
                 artifact.add(table, file_id)
                 # Add the same table to the workspace.
@@ -390,9 +427,9 @@ class WandbLogger:
                 # Update the run config and artifact metadata
                 cls._run.config.update({f"n_{prefix}": n_items})
                 artifact.metadata["items"] = n_items
-            except Exception:
+            except Exception as e:
                 wandb.termerror(
-                    f"
+                    f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
                 )
             else:
                 # log number of items
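Assuming the existing WandbLogger.sync entry point (the fine-tune job id and project parameter names are assumptions, not confirmed by this diff), the new log_datasets flag and the non-blocking path mentioned in the termlog above might be used like this sketch:

    from wandb.integration.openai.fine_tuning import WandbLogger

    # Sketch only: the job id is a placeholder; log_datasets=False skips uploading the
    # training/validation file artifacts, and wait_for_job_success=False returns
    # without polling the OpenAI job status.
    WandbLogger.sync(
        fine_tune_job_id="ftjob-abc123",
        project="OpenAI-Fine-Tune",
        log_datasets=False,
        wait_for_job_success=False,
    )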
|