wandb 0.18.3__py3-none-win32.whl → 0.18.4__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +16 -7
- wandb/__init__.pyi +96 -63
- wandb/analytics/sentry.py +91 -88
- wandb/apis/public/api.py +18 -4
- wandb/apis/public/runs.py +53 -2
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +178 -0
- wandb/cli/cli.py +5 -171
- wandb/data_types.py +3 -0
- wandb/env.py +74 -73
- wandb/errors/term.py +300 -43
- wandb/proto/v3/wandb_internal_pb2.py +263 -223
- wandb/proto/v3/wandb_server_pb2.py +57 -37
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +226 -218
- wandb/proto/v4/wandb_server_pb2.py +41 -37
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +226 -218
- wandb/proto/v5/wandb_server_pb2.py +41 -37
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/__init__.py +3 -3
- wandb/sdk/artifacts/_validators.py +41 -8
- wandb/sdk/artifacts/artifact.py +32 -1
- wandb/sdk/artifacts/artifact_file_cache.py +1 -2
- wandb/sdk/data_types/_dtypes.py +7 -3
- wandb/sdk/data_types/video.py +15 -6
- wandb/sdk/interface/interface.py +2 -0
- wandb/sdk/internal/internal_api.py +122 -5
- wandb/sdk/internal/sender.py +16 -3
- wandb/sdk/launch/inputs/internal.py +1 -1
- wandb/sdk/lib/module.py +12 -0
- wandb/sdk/lib/printer.py +291 -105
- wandb/sdk/lib/progress.py +274 -0
- wandb/sdk/service/streams.py +21 -11
- wandb/sdk/wandb_init.py +58 -54
- wandb/sdk/wandb_run.py +380 -454
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_watch.py +17 -11
- wandb/util.py +6 -2
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/METADATA +4 -3
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/RECORD +45 -43
- wandb/bin/nvidia_gpu_stats.exe +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/WHEEL +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/licenses/LICENSE +0 -0
| @@ -0,0 +1,274 @@ | |
| 1 | 
            +
            """Defines an object for printing run progress at the end of a script."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from __future__ import annotations
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import contextlib
         | 
| 6 | 
            +
            from typing import Iterable, Iterator
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import wandb
         | 
| 9 | 
            +
            from wandb.proto import wandb_internal_pb2 as pb
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from . import printer as p
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def print_sync_dedupe_stats(
                printer: p.Printer,
                final_result: pb.PollExitResponse,
            ) -> None:
                """Report what fraction of the upload W&B sync avoided via deduplication.

                Nothing is printed unless both the total and the deduplicated byte
                counts are strictly positive.

                Args:
                    printer: Printer used to display the message.
                    final_result: The final PollExit result; its `pusher_stats`
                        supply the total and deduplicated byte counts.
                """
                stats = final_result.pusher_stats
                saved = stats.deduped_bytes
                total = stats.total_bytes

                if total > 0 and saved > 0:
                    printer.display(f"W&B sync reduced upload amount by {saved / total:.1%}")
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            @contextlib.contextmanager
            def progress_printer(
                printer: p.Printer,
                settings: wandb.Settings | None = None,
            ) -> Iterator[ProgressPrinter]:
                """Context manager providing an object for printing run progress.

                The yielded `ProgressPrinter` renders into the text area obtained from
                `printer.dynamic_text()`.

                `printer.progress_close()` is called whenever the managed block exits,
                including when it exits by raising, so the progress output is always
                closed.

                Args:
                    printer: The printer that progress is written to.
                    settings: Optional run settings, forwarded to `ProgressPrinter`.
                """
                with printer.dynamic_text() as text_area:
                    try:
                        yield ProgressPrinter(printer, text_area, settings)
                    finally:
                        # Previously skipped when the with-body raised, leaving the
                        # progress output unclosed.
                        printer.progress_close()
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class ProgressPrinter:
                """Shows PollExitResponse progress to the user.

                Depending on settings, this renders either per-operation stats or a
                one-line upload summary (for a single run, or aggregated over several
                runs). Output goes to the dynamic text area when one was provided, and
                to the printer's progress line otherwise.
                """

                def __init__(
                    self,
                    printer: p.Printer,
                    progress_text_area: p.DynamicText | None,
                    settings: wandb.Settings | None,
                ) -> None:
                    self._printer = printer
                    self._progress_text_area = progress_text_area

                    # Falsy when no settings object was given.
                    self._show_operation_stats = settings and settings._show_operation_stats

                    # Number of non-empty update() calls; picks the loading-symbol frame.
                    self._tick = 0

                    # Last line shown by the text-area-less operation-stats mode, used
                    # to avoid printing the same line repeatedly.
                    self._last_printed_line = ""

                def update(
                    self,
                    progress: list[pb.PollExitResponse],
                ) -> None:
                    """Update the displayed information.

                    Args:
                        progress: One poll-exit response per run. An empty list is
                            ignored and does not advance the tick counter.
                    """
                    if not progress:
                        return

                    if self._show_operation_stats:
                        all_stats = [response.operation_stats for response in progress]
                        self._update_operation_stats(all_stats)
                    elif len(progress) > 1:
                        self._update_multiple_runs(progress)
                    else:
                        self._update_single_run(progress[0])

                    self._tick += 1

                def _update_operation_stats(self, stats_list: list[pb.OperationStats]) -> None:
                    """Render operation stats, into the text area if there is one."""
                    text_area = self._progress_text_area
                    if text_area:
                        _DynamicOperationStatsPrinter(
                            self._printer,
                            text_area,
                            max_lines=6,
                            loading_symbol=self._printer.loading_symbol(self._tick),
                        ).display(stats_list)
                        return

                    # No text area: print one summary line naming at most five
                    # top-level operations, and only when it changed since last time.
                    descriptions = [op.desc for stats in stats_list for op in stats.operations]
                    hidden = len(descriptions) - 5

                    line = "; ".join(descriptions[:5])
                    if hidden > 0:
                        line += f" (+ {hidden} more)"

                    if line != self._last_printed_line:
                        self._printer.display(line)
                    self._last_printed_line = line

                def _update_single_run(
                    self,
                    progress: pb.PollExitResponse,
                ) -> None:
                    """Show upload progress for a single run."""
                    stats = progress.pusher_stats
                    uploaded = stats.uploaded_bytes
                    total = stats.total_bytes
                    deduped = stats.deduped_bytes

                    line = f"{_megabytes(uploaded):.3f} MB of {_megabytes(total):.3f} MB uploaded"
                    if deduped > 0:
                        line += f" ({_megabytes(deduped):.3f} MB deduped)"

                    fraction = uploaded / total if total > 0 else 1.0
                    self._update_progress_text(line, fraction)

                def _update_multiple_runs(
                    self,
                    progress_list: list[pb.PollExitResponse],
                ) -> None:
                    """Show upload progress aggregated over several runs."""
                    total_files = sum(
                        counts.wandb_count
                        + counts.media_count
                        + counts.artifact_count
                        + counts.other_count
                        for counts in (response.file_counts for response in progress_list)
                    )
                    uploaded_bytes = sum(r.pusher_stats.uploaded_bytes for r in progress_list)
                    total_bytes = sum(r.pusher_stats.total_bytes for r in progress_list)

                    line = (
                        f"Processing {len(progress_list)} runs with {total_files} files"
                        f" ({_megabytes(uploaded_bytes):.2f} MB / {_megabytes(total_bytes):.2f} MB)"
                    )

                    fraction = uploaded_bytes / total_bytes if total_bytes > 0 else 1.0
                    self._update_progress_text(line, fraction)

                def _update_progress_text(self, text: str, progress: float) -> None:
                    """Write `text` to the text area, or else to the printer's progress line."""
                    text_area = self._progress_text_area
                    if not text_area:
                        self._printer.progress_update(text + "\r", progress)
                        return
                    text_area.set_text(text)
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class _DynamicOperationStatsPrinter:
         | 
| 163 | 
            +
                """Single-use object that writes operation stats into a text area."""
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def __init__(
         | 
| 166 | 
            +
                    self,
         | 
| 167 | 
            +
                    printer: p.Printer,
         | 
| 168 | 
            +
                    text_area: p.DynamicText,
         | 
| 169 | 
            +
                    max_lines: int,
         | 
| 170 | 
            +
                    loading_symbol: str,
         | 
| 171 | 
            +
                ) -> None:
         | 
| 172 | 
            +
                    self._printer = printer
         | 
| 173 | 
            +
                    self._text_area = text_area
         | 
| 174 | 
            +
                    self._max_lines = max_lines
         | 
| 175 | 
            +
                    self._loading_symbol = loading_symbol
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    self._lines: list[str] = []
         | 
| 178 | 
            +
                    self._ops_shown = 0
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def display(
         | 
| 181 | 
            +
                    self,
         | 
| 182 | 
            +
                    stats_list: Iterable[pb.OperationStats],
         | 
| 183 | 
            +
                ) -> None:
         | 
| 184 | 
            +
                    """Show the given stats in the text area."""
         | 
| 185 | 
            +
                    total_operations = 0
         | 
| 186 | 
            +
                    for stats in stats_list:
         | 
| 187 | 
            +
                        for op in stats.operations:
         | 
| 188 | 
            +
                            self._add_operation(op, is_subtask=False, indent="")
         | 
| 189 | 
            +
                        total_operations += stats.total_operations
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    if self._ops_shown < total_operations:
         | 
| 192 | 
            +
                        # NOTE: In Python 3.8, we'd use a chained comparison here.
         | 
| 193 | 
            +
                        if 1 <= self._max_lines and self._max_lines <= len(self._lines):
         | 
| 194 | 
            +
                            self._lines.pop()
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                        remaining = total_operations - self._ops_shown
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                        self._lines.append(f"+ {remaining} more task(s)")
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    if len(self._lines) == 0:
         | 
| 201 | 
            +
                        if self._loading_symbol:
         | 
| 202 | 
            +
                            self._text_area.set_text(f"{self._loading_symbol} Finishing up...")
         | 
| 203 | 
            +
                        else:
         | 
| 204 | 
            +
                            self._text_area.set_text("Finishing up...")
         | 
| 205 | 
            +
                    else:
         | 
| 206 | 
            +
                        self._text_area.set_text("\n".join(self._lines))
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def _add_operation(self, op: pb.Operation, is_subtask: bool, indent: str) -> None:
         | 
| 209 | 
            +
                    """Add the operation to `self._lines`."""
         | 
| 210 | 
            +
                    if len(self._lines) >= self._max_lines:
         | 
| 211 | 
            +
                        return
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    if not is_subtask:
         | 
| 214 | 
            +
                        self._ops_shown += 1
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    parts = []
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    # Subtask indicator.
         | 
| 219 | 
            +
                    if is_subtask and self._printer.supports_unicode:
         | 
| 220 | 
            +
                        parts.append("↳")
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    # Loading symbol.
         | 
| 223 | 
            +
                    if self._loading_symbol:
         | 
| 224 | 
            +
                        parts.append(self._loading_symbol)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    # Task name.
         | 
| 227 | 
            +
                    parts.append(op.desc)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # Progress information.
         | 
| 230 | 
            +
                    if op.progress:
         | 
| 231 | 
            +
                        parts.append(f"{op.progress}")
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Task duration.
         | 
| 234 | 
            +
                    parts.append(f"({_time_to_string(seconds=op.runtime_seconds)})")
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # Error status.
         | 
| 237 | 
            +
                    self._lines.append(indent + " ".join(parts))
         | 
| 238 | 
            +
                    if op.error_status:
         | 
| 239 | 
            +
                        error_word = self._printer.error("ERROR")
         | 
| 240 | 
            +
                        error_desc = self._printer.secondary_text(op.error_status)
         | 
| 241 | 
            +
                        subtask_indent = "  " if is_subtask else ""
         | 
| 242 | 
            +
                        self._lines.append(
         | 
| 243 | 
            +
                            f"{indent}{subtask_indent}  {error_word} {error_desc}",
         | 
| 244 | 
            +
                        )
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    # Subtasks.
         | 
| 247 | 
            +
                    if op.subtasks:
         | 
| 248 | 
            +
                        subtask_indent = indent + "  "
         | 
| 249 | 
            +
                        for task in op.subtasks:
         | 
| 250 | 
            +
                            self._add_operation(
         | 
| 251 | 
            +
                                task,
         | 
| 252 | 
            +
                                is_subtask=True,
         | 
| 253 | 
            +
                                indent=subtask_indent,
         | 
| 254 | 
            +
                            )
         | 
| 255 | 
            +
             | 
| 256 | 
            +
             | 
| 257 | 
            +
            def _time_to_string(seconds: float) -> str:
         | 
| 258 | 
            +
                """Returns a short string representing the duration."""
         | 
| 259 | 
            +
                if seconds < 10:
         | 
| 260 | 
            +
                    return f"{seconds:.1f}s"
         | 
| 261 | 
            +
                if seconds < 60:
         | 
| 262 | 
            +
                    return f"{seconds:.0f}s"
         | 
| 263 | 
            +
                if seconds < 60 * 60:
         | 
| 264 | 
            +
                    minutes = seconds / 60
         | 
| 265 | 
            +
                    return f"{minutes:.1f}m"
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                hours = int(seconds / (60 * 60))
         | 
| 268 | 
            +
                minutes = int((seconds / 60) % 60)
         | 
| 269 | 
            +
                return f"{hours}h{minutes}m"
         | 
| 270 | 
            +
             | 
| 271 | 
            +
             | 
| 272 | 
            +
            def _megabytes(bytes: int) -> float:
         | 
| 273 | 
            +
                """Returns the number of megabytes in `bytes`."""
         | 
| 274 | 
            +
                return bytes / (1 << 20)
         | 
    
        wandb/sdk/service/streams.py
    CHANGED
    
    | @@ -20,6 +20,7 @@ import wandb | |
| 20 20 | 
             
            import wandb.util
         | 
| 21 21 | 
             
            from wandb.proto import wandb_internal_pb2 as pb
         | 
| 22 22 | 
             
            from wandb.sdk.internal.settings_static import SettingsStatic
         | 
| 23 | 
            +
            from wandb.sdk.lib import progress
         | 
| 23 24 | 
             
            from wandb.sdk.lib.mailbox import (
         | 
| 24 25 | 
             
                Mailbox,
         | 
| 25 26 | 
             
                MailboxProbe,
         | 
| @@ -257,8 +258,12 @@ class StreamMux: | |
| 257 258 | 
             
                def _on_progress_exit(self, progress_handle: MailboxProgress) -> None:
         | 
| 258 259 | 
             
                    pass
         | 
| 259 260 |  | 
| 260 | 
            -
                def _on_progress_exit_all( | 
| 261 | 
            -
                     | 
| 261 | 
            +
                def _on_progress_exit_all(
         | 
| 262 | 
            +
                    self,
         | 
| 263 | 
            +
                    progress_printer: progress.ProgressPrinter,
         | 
| 264 | 
            +
                    progress_all_handle: MailboxProgressAll,
         | 
| 265 | 
            +
                ) -> None:
         | 
| 266 | 
            +
                    probe_handles: list[MailboxProbe] = []
         | 
| 262 267 | 
             
                    progress_handles = progress_all_handle.get_progress_handles()
         | 
| 263 268 | 
             
                    for progress_handle in progress_handles:
         | 
| 264 269 | 
             
                        probe_handles.extend(progress_handle.get_probe_handles())
         | 
| @@ -268,13 +273,13 @@ class StreamMux: | |
| 268 273 | 
             
                    if self._check_orphaned():
         | 
| 269 274 | 
             
                        self._stopped.set()
         | 
| 270 275 |  | 
| 271 | 
            -
                    poll_exit_responses: List[ | 
| 276 | 
            +
                    poll_exit_responses: List[pb.PollExitResponse] = []
         | 
| 272 277 | 
             
                    for probe_handle in probe_handles:
         | 
| 273 278 | 
             
                        result = probe_handle.get_probe_result()
         | 
| 274 279 | 
             
                        if result:
         | 
| 275 280 | 
             
                            poll_exit_responses.append(result.response.poll_exit_response)
         | 
| 276 281 |  | 
| 277 | 
            -
                     | 
| 282 | 
            +
                    progress_printer.update(poll_exit_responses)
         | 
| 278 283 |  | 
| 279 284 | 
             
                def _finish_all(self, streams: Dict[str, StreamRecord], exit_code: int) -> None:
         | 
| 280 285 | 
             
                    if not streams:
         | 
| @@ -283,7 +288,6 @@ class StreamMux: | |
| 283 288 | 
             
                    printer = get_printer(
         | 
| 284 289 | 
             
                        all(stream._settings._jupyter for stream in streams.values())
         | 
| 285 290 | 
             
                    )
         | 
| 286 | 
            -
                    self._printer = printer
         | 
| 287 291 |  | 
| 288 292 | 
             
                    # fixme: for now we have a single printer for all streams,
         | 
| 289 293 | 
             
                    # and jupyter is disabled if at least single stream's setting set `_jupyter` to false
         | 
| @@ -307,12 +311,18 @@ class StreamMux: | |
| 307 311 | 
             
                        #     exit_code, settings=stream._settings, printer=printer  # type: ignore
         | 
| 308 312 | 
             
                        # )
         | 
| 309 313 |  | 
| 310 | 
            -
                     | 
| 311 | 
            -
             | 
| 312 | 
            -
             | 
| 313 | 
            -
                         | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 314 | 
            +
                    with progress.progress_printer(printer) as progress_printer:
         | 
| 315 | 
            +
                        # todo: should we wait for the max timeout (?) of all exit handles or just wait forever?
         | 
| 316 | 
            +
                        # timeout = max(stream._settings._exit_timeout for stream in streams.values())
         | 
| 317 | 
            +
                        got_result = self._mailbox.wait_all(
         | 
| 318 | 
            +
                            handles=exit_handles,
         | 
| 319 | 
            +
                            timeout=-1,
         | 
| 320 | 
            +
                            on_progress_all=functools.partial(
         | 
| 321 | 
            +
                                self._on_progress_exit_all,
         | 
| 322 | 
            +
                                progress_printer,
         | 
| 323 | 
            +
                            ),
         | 
| 324 | 
            +
                        )
         | 
| 325 | 
            +
                        assert got_result
         | 
| 316 326 |  | 
| 317 327 | 
             
                    # These could be done in parallel in the future
         | 
| 318 328 | 
             
                    for _sid, stream in started_streams.items():
         | 
    
        wandb/sdk/wandb_init.py
    CHANGED
    
    | @@ -8,6 +8,8 @@ For more on using `wandb.init()`, including code snippets, check out our | |
| 8 8 | 
             
            [guide and FAQs](https://docs.wandb.ai/guides/track/launch).
         | 
| 9 9 | 
             
            """
         | 
| 10 10 |  | 
| 11 | 
            +
            from __future__ import annotations
         | 
| 12 | 
            +
             | 
| 11 13 | 
             
            import copy
         | 
| 12 14 | 
             
            import json
         | 
| 13 15 | 
             
            import logging
         | 
| @@ -16,7 +18,7 @@ import platform | |
| 16 18 | 
             
            import sys
         | 
| 17 19 | 
             
            import tempfile
         | 
| 18 20 | 
             
            import time
         | 
| 19 | 
            -
            from typing import TYPE_CHECKING, Any,  | 
| 21 | 
            +
            from typing import TYPE_CHECKING, Any, Sequence
         | 
| 20 22 |  | 
| 21 23 | 
             
            import wandb
         | 
| 22 24 | 
             
            import wandb.env
         | 
| @@ -43,7 +45,7 @@ from .wandb_settings import Settings, Source | |
| 43 45 | 
             
            if TYPE_CHECKING:
         | 
| 44 46 | 
             
                from wandb.proto import wandb_internal_pb2 as pb
         | 
| 45 47 |  | 
| 46 | 
            -
            logger:  | 
| 48 | 
            +
            logger: logging.Logger | None = None  # logger configured during wandb.init()
         | 
| 47 49 |  | 
| 48 50 |  | 
| 49 51 | 
             
            def _set_logger(log_object: logging.Logger) -> None:
         | 
| @@ -52,7 +54,7 @@ def _set_logger(log_object: logging.Logger) -> None: | |
| 52 54 | 
             
                logger = log_object
         | 
| 53 55 |  | 
| 54 56 |  | 
| 55 | 
            -
            def _huggingface_version() ->  | 
| 57 | 
            +
            def _huggingface_version() -> str | None:
         | 
| 56 58 | 
             
                if "transformers" in sys.modules:
         | 
| 57 59 | 
             
                    trans = wandb.util.get_module("transformers")
         | 
| 58 60 | 
             
                    if hasattr(trans, "__version__"):
         | 
| @@ -74,8 +76,8 @@ def _maybe_mp_process(backend: Backend) -> bool: | |
| 74 76 | 
             
                return False
         | 
| 75 77 |  | 
| 76 78 |  | 
| 77 | 
            -
            def _handle_launch_config(settings:  | 
| 78 | 
            -
                launch_run_config:  | 
| 79 | 
            +
            def _handle_launch_config(settings: Settings) -> dict[str, Any]:
         | 
| 80 | 
            +
                launch_run_config: dict[str, Any] = {}
         | 
| 79 81 | 
             
                if not settings.launch:
         | 
| 80 82 | 
             
                    return launch_run_config
         | 
| 81 83 | 
             
                if os.environ.get("WANDB_CONFIG") is not None:
         | 
| @@ -112,22 +114,22 @@ class _WandbInit: | |
| 112 114 |  | 
| 113 115 | 
             
                def __init__(self) -> None:
         | 
| 114 116 | 
             
                    self.kwargs = None
         | 
| 115 | 
            -
                    self. | 
| 116 | 
            -
                    self.sweep_config:  | 
| 117 | 
            -
                    self.launch_config:  | 
| 118 | 
            -
                    self.config:  | 
| 119 | 
            -
                    self.run:  | 
| 120 | 
            -
                    self.backend:  | 
| 121 | 
            -
             | 
| 122 | 
            -
                    self._teardown_hooks:  | 
| 123 | 
            -
                    self._wl:  | 
| 124 | 
            -
                    self._reporter:  | 
| 125 | 
            -
                    self.notebook:  | 
| 126 | 
            -
                    self.printer:  | 
| 117 | 
            +
                    self.setting: Settings | None = None
         | 
| 118 | 
            +
                    self.sweep_config: dict[str, Any] = {}
         | 
| 119 | 
            +
                    self.launch_config: dict[str, Any] = {}
         | 
| 120 | 
            +
                    self.config: dict[str, Any] = {}
         | 
| 121 | 
            +
                    self.run: Run | None = None
         | 
| 122 | 
            +
                    self.backend: Backend | None = None
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    self._teardown_hooks: list[TeardownHook] = []
         | 
| 125 | 
            +
                    self._wl: wandb_setup._WandbSetup | None = None
         | 
| 126 | 
            +
                    self._reporter: wandb.sdk.lib.reporting.Reporter | None = None
         | 
| 127 | 
            +
                    self.notebook: wandb.jupyter.Notebook | None = None  # type: ignore
         | 
| 128 | 
            +
                    self.printer: Printer | None = None
         | 
| 127 129 |  | 
| 128 130 | 
             
                    self._init_telemetry_obj = telemetry.TelemetryRecord()
         | 
| 129 131 |  | 
| 130 | 
            -
                    self.deprecated_features_used:  | 
| 132 | 
            +
                    self.deprecated_features_used: dict[str, str] = dict()
         | 
| 131 133 |  | 
| 132 134 | 
             
                def _setup_printer(self, settings: Settings) -> None:
         | 
| 133 135 | 
             
                    if self.printer:
         | 
| @@ -203,7 +205,7 @@ class _WandbInit: | |
| 203 205 | 
             
                    self._setup_printer(settings)
         | 
| 204 206 | 
             
                    self._reporter = reporting.setup_reporter(settings=settings)
         | 
| 205 207 |  | 
| 206 | 
            -
                    sagemaker_config:  | 
| 208 | 
            +
                    sagemaker_config: dict = (
         | 
| 207 209 | 
             
                        dict() if settings.sagemaker_disable else sagemaker.parse_sm_config()
         | 
| 208 210 | 
             
                    )
         | 
| 209 211 | 
             
                    if sagemaker_config:
         | 
| @@ -254,7 +256,7 @@ class _WandbInit: | |
| 254 256 | 
             
                    self.sweep_config = dict()
         | 
| 255 257 | 
             
                    sweep_config = self._wl._sweep_config or dict()
         | 
| 256 258 | 
             
                    self.config = dict()
         | 
| 257 | 
            -
                    self.init_artifact_config:  | 
| 259 | 
            +
                    self.init_artifact_config: dict[str, Any] = dict()
         | 
| 258 260 | 
             
                    for config_data in (
         | 
| 259 261 | 
             
                        sagemaker_config,
         | 
| 260 262 | 
             
                        self._wl._config,
         | 
| @@ -368,7 +370,7 @@ class _WandbInit: | |
| 368 370 | 
             
                        else:
         | 
| 369 371 | 
             
                            config_target.setdefault(k, v)
         | 
| 370 372 |  | 
| 371 | 
            -
                def _enable_logging(self, log_fname: str, run_id:  | 
| 373 | 
            +
                def _enable_logging(self, log_fname: str, run_id: str | None = None) -> None:
         | 
| 372 374 | 
             
                    """Enable logging to the global debug log.
         | 
| 373 375 |  | 
| 374 376 | 
             
                    This adds a run_id to the log, in case of multiple processes on the same machine.
         | 
| @@ -602,6 +604,8 @@ class _WandbInit: | |
| 602 604 | 
             
                        define_metric=drun.define_metric,
         | 
| 603 605 | 
             
                        plot_table=drun.plot_table,
         | 
| 604 606 | 
             
                        alert=drun.alert,
         | 
| 607 | 
            +
                        watch=drun.watch,
         | 
| 608 | 
            +
                        unwatch=drun.unwatch,
         | 
| 605 609 | 
             
                    )
         | 
| 606 610 | 
             
                    return drun
         | 
| 607 611 |  | 
| @@ -719,7 +723,7 @@ class _WandbInit: | |
| 719 723 | 
             
                            setattr(tel.imports_init, module_name, True)
         | 
| 720 724 |  | 
| 721 725 | 
             
                        # probe the active start method
         | 
| 722 | 
            -
                        active_start_method:  | 
| 726 | 
            +
                        active_start_method: str | None = None
         | 
| 723 727 | 
             
                        if self.settings.start_method == "thread":
         | 
| 724 728 | 
             
                            active_start_method = self.settings.start_method
         | 
| 725 729 | 
             
                        else:
         | 
| @@ -794,7 +798,7 @@ class _WandbInit: | |
| 794 798 | 
             
                    if not self.settings.disable_git:
         | 
| 795 799 | 
             
                        run._populate_git_info()
         | 
| 796 800 |  | 
| 797 | 
            -
                    run_result:  | 
| 801 | 
            +
                    run_result: pb.RunUpdateResult | None = None
         | 
| 798 802 |  | 
| 799 803 | 
             
                    if self.settings._offline:
         | 
| 800 804 | 
             
                        with telemetry.context(run=run) as tel:
         | 
| @@ -805,7 +809,7 @@ class _WandbInit: | |
| 805 809 | 
             
                                "`resume` will be ignored since W&B syncing is set to `offline`. "
         | 
| 806 810 | 
             
                                f"Starting a new run with run id {run.id}."
         | 
| 807 811 | 
             
                            )
         | 
| 808 | 
            -
                    error:  | 
| 812 | 
            +
                    error: wandb.Error | None = None
         | 
| 809 813 |  | 
| 810 814 | 
             
                    timeout = self.settings.init_timeout
         | 
| 811 815 |  | 
| @@ -909,11 +913,11 @@ class _WandbInit: | |
| 909 913 |  | 
| 910 914 |  | 
| 911 915 | 
             
            def _attach(
         | 
| 912 | 
            -
                attach_id:  | 
| 913 | 
            -
                run_id:  | 
| 916 | 
            +
                attach_id: str | None = None,
         | 
| 917 | 
            +
                run_id: str | None = None,
         | 
| 914 918 | 
             
                *,
         | 
| 915 | 
            -
                run:  | 
| 916 | 
            -
            ) ->  | 
| 919 | 
            +
                run: Run | None = None,
         | 
| 920 | 
            +
            ) -> Run | None:
         | 
| 917 921 | 
             
                """Attach to a run currently executing in another process/thread.
         | 
| 918 922 |  | 
| 919 923 | 
             
                Arguments:
         | 
| @@ -995,32 +999,32 @@ def _attach( | |
| 995 999 |  | 
| 996 1000 |  | 
| 997 1001 | 
             
            def init(
         | 
| 998 | 
            -
                job_type:  | 
| 999 | 
            -
                dir:  | 
| 1000 | 
            -
                config:  | 
| 1001 | 
            -
                project:  | 
| 1002 | 
            -
                entity:  | 
| 1003 | 
            -
                reinit:  | 
| 1004 | 
            -
                tags:  | 
| 1005 | 
            -
                group:  | 
| 1006 | 
            -
                name:  | 
| 1007 | 
            -
                notes:  | 
| 1008 | 
            -
                magic:  | 
| 1009 | 
            -
                config_exclude_keys:  | 
| 1010 | 
            -
                config_include_keys:  | 
| 1011 | 
            -
                anonymous:  | 
| 1012 | 
            -
                mode:  | 
| 1013 | 
            -
                allow_val_change:  | 
| 1014 | 
            -
                resume:  | 
| 1015 | 
            -
                force:  | 
| 1016 | 
            -
                tensorboard:  | 
| 1017 | 
            -
                sync_tensorboard:  | 
| 1018 | 
            -
                monitor_gym:  | 
| 1019 | 
            -
                save_code:  | 
| 1020 | 
            -
                id:  | 
| 1021 | 
            -
                fork_from:  | 
| 1022 | 
            -
                resume_from:  | 
| 1023 | 
            -
                settings:  | 
| 1002 | 
            +
                job_type: str | None = None,
         | 
| 1003 | 
            +
                dir: StrPath | None = None,
         | 
| 1004 | 
            +
                config: dict | str | None = None,
         | 
| 1005 | 
            +
                project: str | None = None,
         | 
| 1006 | 
            +
                entity: str | None = None,
         | 
| 1007 | 
            +
                reinit: bool | None = None,
         | 
| 1008 | 
            +
                tags: Sequence | None = None,
         | 
| 1009 | 
            +
                group: str | None = None,
         | 
| 1010 | 
            +
                name: str | None = None,
         | 
| 1011 | 
            +
                notes: str | None = None,
         | 
| 1012 | 
            +
                magic: dict | str | bool | None = None,
         | 
| 1013 | 
            +
                config_exclude_keys: list[str] | None = None,
         | 
| 1014 | 
            +
                config_include_keys: list[str] | None = None,
         | 
| 1015 | 
            +
                anonymous: str | None = None,
         | 
| 1016 | 
            +
                mode: str | None = None,
         | 
| 1017 | 
            +
                allow_val_change: bool | None = None,
         | 
| 1018 | 
            +
                resume: bool | str | None = None,
         | 
| 1019 | 
            +
                force: bool | None = None,
         | 
| 1020 | 
            +
                tensorboard: bool | None = None,  # alias for sync_tensorboard
         | 
| 1021 | 
            +
                sync_tensorboard: bool | None = None,
         | 
| 1022 | 
            +
                monitor_gym: bool | None = None,
         | 
| 1023 | 
            +
                save_code: bool | None = None,
         | 
| 1024 | 
            +
                id: str | None = None,
         | 
| 1025 | 
            +
                fork_from: str | None = None,
         | 
| 1026 | 
            +
                resume_from: str | None = None,
         | 
| 1027 | 
            +
                settings: Settings | dict[str, Any] | None = None,
         | 
| 1024 1028 | 
             
            ) -> Run:
         | 
| 1025 1029 | 
             
                r"""Start a new run to track and log to W&B.
         | 
| 1026 1030 |  |