wandb 0.18.0__py3-none-macosx_11_0_arm64.whl → 0.18.1__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +1 -1
- wandb/apis/public/runs.py +2 -0
- wandb/bin/apple_gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +0 -2
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +2 -1
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +2 -1
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +3 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/service/service.py +2 -9
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +10 -3
- wandb/sdk/wandb_run.py +6 -152
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn.py +35 -0
- wandb/util.py +6 -2
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/METADATA +1 -1
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/RECORD +62 -58
- wandb/sdk/lib/console.py +0 -39
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
| @@ -0,0 +1,165 @@ | |
| 1 | 
            +
            import hashlib
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from typing import Optional
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from wandb import util
         | 
| 6 | 
            +
            from wandb.sdk.lib import filesystem, runid
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from . import _dtypes
         | 
| 9 | 
            +
            from ._private import MEDIA_TMP
         | 
| 10 | 
            +
            from .base_types.media import BatchableMedia
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class Audio(BatchableMedia):
                """Wandb class for audio clips.

                Arguments:
                    data_or_path: (string or numpy array) A path to an audio file
                        or a numpy array of audio data.
                    sample_rate: (int) Sample rate, required when passing in raw
                        numpy array of audio data.
                    caption: (string) Caption to display with audio.
                """

                _log_type = "audio-file"

                def __init__(self, data_or_path, sample_rate=None, caption=None):
                    """Accept a path to an audio file or a numpy array of audio data.

                    Raises:
                        ValueError: If `data_or_path` is not a string and `sample_rate`
                            is None.
                    """
                    super().__init__()
                    self._duration = None
                    self._sample_rate = sample_rate
                    self._caption = caption

                    if isinstance(data_or_path, str):
                        if self.path_is_reference(data_or_path):
                            # A reference is never read here, so the digest is taken
                            # over the reference string itself.
                            self._path = data_or_path
                            self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
                            self._is_tmp = False
                        else:
                            self._set_file(data_or_path, is_tmp=False)
                    else:
                        if sample_rate is None:
                            raise ValueError(
                                'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                            )

                        soundfile = util.get_module(
                            "soundfile",
                            required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
                        )

                        # Raw samples are written to a temporary .wav file, which then
                        # becomes the backing file of this media object.
                        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
                        soundfile.write(tmp_path, data_or_path, sample_rate)
                        self._duration = len(data_or_path) / float(sample_rate)

                        self._set_file(tmp_path, is_tmp=True)

                @classmethod
                def get_media_subdir(cls):
                    """Return the relative directory ("media/audio") used for audio files."""
                    return os.path.join("media", "audio")

                @classmethod
                def from_json(cls, json_obj, source_artifact):
                    """Build an Audio from its JSON form by downloading the artifact entry.

                    Only the "path" and "caption" keys of `json_obj` are read.
                    """
                    return cls(
                        source_artifact.get_entry(json_obj["path"]).download(),
                        caption=json_obj["caption"],
                    )

                def bind_to_run(
                    self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
                ):
                    """Bind this audio clip to a run.

                    Raises:
                        ValueError: If this object was created from a reference path.
                    """
                    if self.path_is_reference(self._path):
                        raise ValueError(
                            "Audio media created by a reference to external storage cannot currently be added to a run"
                        )

                    return super().bind_to_run(run, key, step, id_, ignore_copy_err)

                def to_json(self, run):
                    """Return the parent's JSON dict extended with "_type" and "caption"."""
                    json_dict = super().to_json(run)
                    json_dict.update(
                        {
                            "_type": self._log_type,
                            "caption": self._caption,
                        }
                    )
                    return json_dict

                @classmethod
                def seq_to_json(cls, seq, run, key, step):
                    """Serialise a sequence of Audio objects into a single "audio" dict.

                    The optional "sampleRates", "durations" and "captions" entries are
                    added only when the corresponding helper returns a truthy value.
                    """
                    audio_list = list(seq)

                    util.get_module(
                        "soundfile",
                        required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
                    )
                    base_path = os.path.join(run.dir, "media", "audio")
                    filesystem.mkdir_exists_ok(base_path)
                    meta = {
                        "_type": "audio",
                        "count": len(audio_list),
                        "audio": [a.to_json(run) for a in audio_list],
                    }
                    sample_rates = cls.sample_rates(audio_list)
                    if sample_rates:
                        meta["sampleRates"] = sample_rates
                    durations = cls.durations(audio_list)
                    if durations:
                        meta["durations"] = durations
                    captions = cls.captions(audio_list)
                    if captions:
                        meta["captions"] = captions

                    return meta

                @classmethod
                def durations(cls, audio_list):
                    """Return each clip's `_duration` (None when it was not computed)."""
                    return [a._duration for a in audio_list]

                @classmethod
                def sample_rates(cls, audio_list):
                    """Return each clip's `_sample_rate` (None when it was not provided)."""
                    return [a._sample_rate for a in audio_list]

                @classmethod
                def captions(cls, audio_list):
                    """Return the clips' captions, or False when none of them has one.

                    Missing captions are replaced by empty strings in the returned list.
                    """
                    captions = [a._caption for a in audio_list]
                    if all(c is None for c in captions):
                        return False
                    else:
                        return ["" if c is None else c for c in captions]

                def resolve_ref(self):
                    """Return the reference for this clip, or None if it cannot be resolved.

                    A clip created from a reference path returns that path. Otherwise the
                    local path is looked up in the source artifact's manifest.
                    """
                    if self.path_is_reference(self._path):
                        # this object was already created using a ref:
                        return self._path

                    # Without a source artifact there is nothing to resolve against;
                    # report "no reference" instead of failing on a None attribute.
                    artifact_source = getattr(self, "_artifact_source", None)
                    if artifact_source is None:
                        return None
                    source_artifact = artifact_source.artifact

                    resolved_name = source_artifact._local_path_to_name(self._path)
                    if resolved_name is not None:
                        target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
                        if target_entry is not None:
                            return target_entry.ref

                    return None

                def __eq__(self, other):
                    """Compare two clips by reference or by the parent's equality, plus caption.

                    Objects that are not Audio instances are never equal to an Audio.
                    """
                    # Guard first: the comparisons below read Audio-only attributes
                    # (`_path`, `_caption`) and would fail on any other type.
                    if not isinstance(other, Audio):
                        return False

                    if self.path_is_reference(self._path) or self.path_is_reference(other._path):
                        # one or more of these objects is an unresolved reference -- we'll compare
                        # their reference paths instead of their SHAs:
                        return (
                            self.resolve_ref() == other.resolve_ref()
                            and self._caption == other._caption
                        )

                    return super().__eq__(other) and self._caption == other._caption

                def __ne__(self, other):
                    """Return the negation of `__eq__`."""
                    return not self.__eq__(other)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            class _AudioFileType(_dtypes.Type):
                """Type entry named "audio-file" whose handled types list contains `Audio`."""

                name = "audio-file"
                types = [Audio]
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            _dtypes.TypeRegistry.add(_AudioFileType)
         | 
| @@ -0,0 +1,70 @@ | |
| 1 | 
            +
            import codecs
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from wandb import util
         | 
| 6 | 
            +
            from wandb.sdk.lib import runid
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from . import _dtypes
         | 
| 9 | 
            +
            from ._private import MEDIA_TMP
         | 
| 10 | 
            +
            from .base_types.media import Media
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class Bokeh(Media):
                """Wandb class for Bokeh plots.

                Arguments:
                    data_or_path: A Bokeh model, a Bokeh document, or a path to an
                        existing Bokeh JSON file.
                """

                _log_type = "bokeh-file"

                def __init__(self, data_or_path):
                    """Load or serialise a Bokeh plot and attach it as this object's file.

                    Raises:
                        TypeError: If `data_or_path` is neither an existing file path,
                            a Bokeh model, nor a Bokeh document.
                    """
                    super().__init__()
                    bokeh = util.get_module("bokeh", required=True)
                    if isinstance(data_or_path, str) and os.path.exists(data_or_path):
                        # Read with the same encoding this class uses when writing.
                        with open(data_or_path, encoding="utf-8") as file:
                            b_json = json.load(file)
                        self.b_obj = bokeh.document.Document.from_json(b_json)
                        self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
                    elif isinstance(data_or_path, (bokeh.model.Model, bokeh.document.Document)):
                        if isinstance(data_or_path, bokeh.document.Document):
                            # A document is already the container a model gets wrapped in.
                            _data = data_or_path
                        else:
                            _data = bokeh.document.Document()
                            _data.add_root(data_or_path)
                        # serialize/deserialize pairing followed by sorting attributes ensures
                        # that the file's sha's are equivalent in subsequent calls
                        self.b_obj = bokeh.document.Document.from_json(_data.to_json())
                        b_json = self.b_obj.to_json()
                        if "references" in b_json["roots"]:
                            b_json["roots"]["references"].sort(key=lambda x: x["id"])

                        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
                        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
                            util.json_dump_safer(b_json, fp)
                        self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
                    else:
                        raise TypeError(
                            "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
                        )

                @classmethod
                def get_media_subdir(cls):
                    """Return the relative directory ("media/bokeh") used for Bokeh files."""
                    return os.path.join("media", "bokeh")

                def to_json(self, run):
                    """Return the parent's JSON dict with "_type" set to this class's log type."""
                    # TODO: (tss) this is getting redundant for all the media objects. We can probably
                    # pull this into Media#to_json and remove this type override for all the media types.
                    # There are only a few cases where the type is different between artifacts and runs.
                    json_dict = super().to_json(run)
                    json_dict["_type"] = self._log_type
                    return json_dict

                @classmethod
                def from_json(cls, json_obj, source_artifact):
                    """Build a Bokeh object from the downloaded artifact entry at `json_obj["path"]`."""
                    return cls(source_artifact.get_entry(json_obj["path"]).download())
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            class _BokehFileType(_dtypes.Type):
                """Type entry named "bokeh-file" whose handled types list contains `Bokeh`."""

                name = "bokeh-file"
                types = [Bokeh]
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            _dtypes.TypeRegistry.add(_BokehFileType)
         | 
| @@ -0,0 +1,405 @@ | |
| 1 | 
            +
            import codecs
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import pprint
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from wandb import util
         | 
| 6 | 
            +
            from wandb.sdk.data_types._private import MEDIA_TMP
         | 
| 7 | 
            +
            from wandb.sdk.data_types.base_types.media import Media, _numpy_arrays_to_lists
         | 
| 8 | 
            +
            from wandb.sdk.data_types.base_types.wb_value import WBValue
         | 
| 9 | 
            +
            from wandb.sdk.lib import runid
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def _nest(thing):
                """Flatten `thing` with TensorFlow's nest utility, or wrap it in a list.

                When the `tensorflow.python.util` module cannot be obtained, the
                object is returned as the only element of a new list.
                """
                tf_util = util.get_module("tensorflow.python.util")
                if not tf_util:
                    return [thing]
                return tf_util.nest.flatten(thing)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class Edge(WBValue):
                """Edge used in `Graph`, linking a source node to a destination node."""

                def __init__(self, from_node, to_node):
                    # All edge state lives in one dict; the properties below read and
                    # write its entries.
                    self._attributes = {}
                    self.from_node = from_node
                    self.to_node = to_node

                def __repr__(self):
                    """Return the attributes as a string, with node objects replaced by their ids."""
                    summary = dict(self._attributes)
                    summary.pop("from_node")
                    summary.pop("to_node")
                    summary.update(from_id=self.from_node.id, to_id=self.to_node.id)
                    return str(summary)

                def to_json(self, run=None):
                    """Return the edge as a two-element list: [source id, destination id]."""
                    return [node.id for node in (self.from_node, self.to_node)]

                @property
                def name(self):
                    """Optional, not necessarily unique."""
                    return self._attributes.get("name")

                @name.setter
                def name(self, val):
                    self._attributes["name"] = val
                    return val

                @property
                def from_node(self):
                    """Node stored under "from_node", or None when unset."""
                    return self._attributes.get("from_node")

                @from_node.setter
                def from_node(self, val):
                    self._attributes["from_node"] = val
                    return val

                @property
                def to_node(self):
                    """Node stored under "to_node", or None when unset."""
                    return self._attributes.get("to_node")

                @to_node.setter
                def to_node(self, val):
                    self._attributes["to_node"] = val
                    return val
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            class Node(WBValue):
         | 
| 71 | 
            +
                """Node used in `Graph`."""
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def __init__(
         | 
| 74 | 
            +
                    self,
         | 
| 75 | 
            +
                    id=None,
         | 
| 76 | 
            +
                    name=None,
         | 
| 77 | 
            +
                    class_name=None,
         | 
| 78 | 
            +
                    size=None,
         | 
| 79 | 
            +
                    parameters=None,
         | 
| 80 | 
            +
                    output_shape=None,
         | 
| 81 | 
            +
                    is_output=None,
         | 
| 82 | 
            +
                    num_parameters=None,
         | 
| 83 | 
            +
                    node=None,
         | 
| 84 | 
            +
                ):
         | 
| 85 | 
            +
                    self._attributes = {"name": None}
         | 
| 86 | 
            +
                    self.in_edges = {}  # indexed by source node id
         | 
| 87 | 
            +
                    self.out_edges = {}  # indexed by dest node id
         | 
| 88 | 
            +
                    # optional object (e.g. PyTorch Parameter or Module) that this Node represents
         | 
| 89 | 
            +
                    self.obj = None
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    if node is not None:
         | 
| 92 | 
            +
                        self._attributes.update(node._attributes)
         | 
| 93 | 
            +
                        del self._attributes["id"]
         | 
| 94 | 
            +
                        self.obj = node.obj
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    if id is not None:
         | 
| 97 | 
            +
                        self.id = id
         | 
| 98 | 
            +
                    if name is not None:
         | 
| 99 | 
            +
                        self.name = name
         | 
| 100 | 
            +
                    if class_name is not None:
         | 
| 101 | 
            +
                        self.class_name = class_name
         | 
| 102 | 
            +
                    if size is not None:
         | 
| 103 | 
            +
                        self.size = size
         | 
| 104 | 
            +
                    if parameters is not None:
         | 
| 105 | 
            +
                        self.parameters = parameters
         | 
| 106 | 
            +
                    if output_shape is not None:
         | 
| 107 | 
            +
                        self.output_shape = output_shape
         | 
| 108 | 
            +
                    if is_output is not None:
         | 
| 109 | 
            +
                        self.is_output = is_output
         | 
| 110 | 
            +
                    if num_parameters is not None:
         | 
| 111 | 
            +
                        self.num_parameters = num_parameters
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                def to_json(self, run=None):
         | 
| 114 | 
            +
                    return self._attributes
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def __repr__(self):
         | 
| 117 | 
            +
                    return repr(self._attributes)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                @property
         | 
| 120 | 
            +
                def id(self):
         | 
| 121 | 
            +
                    """Must be unique in the graph."""
         | 
| 122 | 
            +
                    return self._attributes.get("id")
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                @id.setter
         | 
| 125 | 
            +
                def id(self, val):
         | 
| 126 | 
            +
                    self._attributes["id"] = val
         | 
| 127 | 
            +
                    return val
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                @property
         | 
| 130 | 
            +
                def name(self):
         | 
| 131 | 
            +
                    """Usually the type of layer or sublayer."""
         | 
| 132 | 
            +
                    return self._attributes.get("name")
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                @name.setter
         | 
| 135 | 
            +
                def name(self, val):
         | 
| 136 | 
            +
                    self._attributes["name"] = val
         | 
| 137 | 
            +
                    return val
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                @property
                def class_name(self):
                    """Class name recorded for the node.

                    ``from_keras`` stores ``layer.__class__.__name__`` here.
                    Returns ``None`` when no class name has been stored.
                    """
                    return self._attributes.get("class_name", None)

                @class_name.setter
                def class_name(self, val):
                    # Kept under the "class_name" key of the attribute mapping.
                    attributes = self._attributes
                    attributes["class_name"] = val
                    return val
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                @property
                def functions(self):
                    """Value stored under "functions", or a new empty list if unset.

                    The fallback list is not stored on the node, so appending to it
                    has no effect on later reads.
                    """
                    try:
                        return self._attributes["functions"]
                    except KeyError:
                        return []

                @functions.setter
                def functions(self, val):
                    # Kept under the "functions" key of the attribute mapping.
                    attributes = self._attributes
                    attributes["functions"] = val
                    return val
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                @property
                def parameters(self):
                    """Value stored under "parameters", or a new empty list if unset.

                    The fallback list is not stored on the node, so appending to it
                    has no effect on later reads.
                    """
                    try:
                        return self._attributes["parameters"]
                    except KeyError:
                        return []

                @parameters.setter
                def parameters(self, val):
                    # Kept under the "parameters" key of the attribute mapping.
                    attributes = self._attributes
                    attributes["parameters"] = val
                    return val
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                @property
                def size(self):
                    """Tensor size; the setter stores it as a tuple.

                    Returns ``None`` when no size has been stored.
                    """
                    return self._attributes.get("size", None)

                @size.setter
                def size(self, val):
                    # Any iterable is converted to a tuple before being stored;
                    # the caller's original object is what gets returned.
                    as_tuple = tuple(val)
                    self._attributes["size"] = as_tuple
                    return val
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                @property
                def output_shape(self):
                    """Tensor output_shape, or ``None`` when none has been stored."""
                    return self._attributes.get("output_shape", None)

                @output_shape.setter
                def output_shape(self, val):
                    # Stored as given under the "output_shape" key.
                    attributes = self._attributes
                    attributes["output_shape"] = val
                    return val
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                @property
                def is_output(self):
                    """Tensor is_output flag, or ``None`` when none has been stored."""
                    return self._attributes.get("is_output", None)

                @is_output.setter
                def is_output(self, val):
                    # Stored as given under the "is_output" key.
                    attributes = self._attributes
                    attributes["is_output"] = val
                    return val
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                @property
                def num_parameters(self):
                    """Tensor num_parameters, or ``None`` when none has been stored.

                    ``from_keras`` stores the result of ``layer.count_params()`` here.
                    """
                    return self._attributes.get("num_parameters", None)

                @num_parameters.setter
                def num_parameters(self, val):
                    # Stored as given under the "num_parameters" key.
                    attributes = self._attributes
                    attributes["num_parameters"] = val
                    return val
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                @property
                def child_parameters(self):
                    """Tensor child_parameters, or ``None`` when none has been stored."""
                    return self._attributes.get("child_parameters", None)

                @child_parameters.setter
                def child_parameters(self, val):
                    # Stored as given under the "child_parameters" key.
                    attributes = self._attributes
                    attributes["child_parameters"] = val
                    return val
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                @property
                def is_constant(self):
                    """Tensor is_constant flag, or ``None`` when none has been stored."""
                    return self._attributes.get("is_constant", None)

                @is_constant.setter
                def is_constant(self, val):
                    # Stored as given under the "is_constant" key.
                    attributes = self._attributes
                    attributes["is_constant"] = val
                    return val
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                @classmethod
                def from_keras(cls, layer):
                    """Build a node describing a single Keras layer.

                    Args:
                        layer: Object exposing ``name`` and ``count_params()``. Its
                            ``output_shape`` attribute is used when present.

                    Returns:
                        A new ``cls`` instance whose ``id`` and ``name`` are both
                        ``layer.name``, whose ``class_name`` is the layer's class
                        name, and whose ``num_parameters`` is ``layer.count_params()``.
                    """
                    keras_node = cls()

                    # Layers without an ``output_shape`` attribute get a placeholder.
                    shape = getattr(layer, "output_shape", ["multiple"])

                    # Assignment order is kept so the attribute mapping is filled
                    # in the same key order as before.
                    keras_node.id = layer.name
                    keras_node.name = layer.name
                    keras_node.class_name = type(layer).__name__
                    keras_node.output_shape = shape
                    keras_node.num_parameters = layer.count_params()

                    return keras_node
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            class Graph(Media):
         | 
| 246 | 
            +
                """Wandb class for graphs.
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                This class is typically used for saving and displaying neural net models.  It
         | 
| 249 | 
            +
                represents the graph as an array of nodes and edges.  The nodes can have
         | 
| 250 | 
            +
                labels that can be visualized by wandb.
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                Examples:
         | 
| 253 | 
            +
                    Import a keras model:
         | 
| 254 | 
            +
                    ```
         | 
| 255 | 
            +
                    Graph.from_keras(keras_model)
         | 
| 256 | 
            +
                    ```
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                Attributes:
         | 
| 259 | 
            +
                    format (string): Format to help wandb display the graph nicely.
         | 
| 260 | 
            +
                    nodes ([wandb.Node]): List of wandb.Nodes
         | 
| 261 | 
            +
                    nodes_by_id (dict): dict of ids -> nodes
         | 
| 262 | 
            +
                    edges ([(wandb.Node, wandb.Node)]): List of pairs of nodes interpreted as edges
         | 
| 263 | 
            +
                    loaded (boolean): Flag to tell whether the graph is completely loaded
         | 
| 264 | 
            +
                    root (wandb.Node): root node of the graph
         | 
| 265 | 
            +
                """
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                _log_type = "graph-file"
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def __init__(self, format="keras"):
                    """Create an empty graph.

                    Args:
                        format: Stored on ``self.format``; ``_to_graph_json`` emits it
                            under the "format" key.
                    """
                    super().__init__()
                    self.format = format
                    # Containers filled by add_node / add_edge.
                    self.nodes, self.nodes_by_id, self.edges = [], {}, []
                    self.loaded = False
                    # LB: TODO: I think we should factor criterion and criterion_passed out
                    self.criterion, self.criterion_passed = None, False
                    self.root = None  # optional root Node if applicable
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                def _to_graph_json(self, run=None):
                    """Serialize the graph's format, nodes and edges into a dict.

                    Kept as its own method so tests can call it directly.

                    Args:
                        run: Accepted but not used here.

                    Returns:
                        A dict with the keys "format", "nodes" and "edges"; the last
                        two hold the ``to_json()`` result of every node / edge.
                    """
                    serialized_nodes = [node.to_json() for node in self.nodes]
                    serialized_edges = [edge.to_json() for edge in self.edges]
                    return dict(
                        format=self.format,
                        nodes=serialized_nodes,
                        edges=serialized_edges,
                    )
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                def bind_to_run(self, *args, **kwargs):
                    """Dump the graph to a temporary ``.graph.json`` file and bind it.

                    The file is written and registered through ``_set_file`` on every
                    call; the parent ``bind_to_run`` is only invoked when the object
                    is not already bound. All arguments are forwarded unchanged.
                    """
                    payload = _numpy_arrays_to_lists(self._to_graph_json())
                    file_name = runid.generate_id() + ".graph.json"
                    tmp_path = os.path.join(MEDIA_TMP.name, file_name)
                    with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
                        util.json_dump_safer(payload, fp)
                    self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
                    if not self.is_bound():
                        super().bind_to_run(*args, **kwargs)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                @classmethod
                def get_media_subdir(cls):
                    """Return the relative "media"/"graph" path joined for this OS."""
                    parts = ("media", "graph")
                    return os.path.join(*parts)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                def to_json(self, run):
                    """Return the parent class's JSON dict tagged with this log type.

                    Args:
                        run: Forwarded unchanged to the parent ``to_json``.

                    Returns:
                        The parent's result with "_type" set to ``self._log_type``.
                    """
                    payload = super().to_json(run)
                    log_type = self._log_type
                    payload["_type"] = log_type
                    return payload
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                def __getitem__(self, nid):
                    """Return the node registered under ``nid`` in ``nodes_by_id``.

                    This is a plain mapping lookup, so an unknown id raises whatever
                    ``nodes_by_id`` raises (``KeyError`` for the dict set in ``__init__``).
                    """
                    registry = self.nodes_by_id
                    return registry[nid]
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                def pprint(self):
                    """Pretty-print the ``attributes`` of every edge, then of every node."""
                    # NOTE(review): the node accessors in this file read and write
                    # ``_attributes``; confirm that a public ``attributes`` attribute
                    # exists on nodes and edges, otherwise this raises AttributeError.
                    for collection in (self.edges, self.nodes):
                        for item in collection:
                            pprint.pprint(item.attributes)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                def add_node(self, node=None, **node_kwargs):
                    """Register a node on the graph and return it.

                    Args:
                        node: An existing node object. When ``None``, a new ``Node``
                            is built from ``node_kwargs``.
                        **node_kwargs: Constructor arguments for ``Node``; only
                            allowed when ``node`` is not given.

                    Returns:
                        The node that was appended to ``nodes`` and indexed by its
                        ``id`` in ``nodes_by_id``.

                    Raises:
                        ValueError: If both ``node`` and keyword arguments are passed.
                    """
                    # Reject the ambiguous call before touching any state.
                    if node is not None and node_kwargs:
                        raise ValueError(
                            f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
                        )
                    if node is None:
                        node = Node(**node_kwargs)

                    self.nodes.append(node)
                    self.nodes_by_id[node.id] = node
                    return node
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                def add_edge(self, from_node, to_node):
                    """Create an ``Edge`` between two nodes, record it and return it.

                    Args:
                        from_node: First argument passed to ``Edge``.
                        to_node: Second argument passed to ``Edge``.

                    Returns:
                        The new edge, which has also been appended to ``self.edges``.
                    """
                    new_edge = Edge(from_node, to_node)
                    self.edges.append(new_edge)
                    return new_edge
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                @classmethod
                def from_keras(cls, model):
                    """Build a graph from a Keras model's layers and their inbound links.

                    One node is created per entry of ``model.layers``. For every
                    inbound node of a layer, each inbound layer is looked up (or
                    added) in the graph and an edge from it to the layer's node is
                    recorded.

                    Args:
                        model: Object exposing ``layers``; for models that are not
                            sequential-like, ``_nodes_by_depth`` is read as well.

                    Returns:
                        A new ``cls`` instance populated with nodes and edges.
                    """
                    # TODO: this method requires a refactor to work with Keras 3.
                    graph = cls()
                    # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
                    sequential_like = cls._is_sequential(model)

                    relevant_nodes = None
                    if not sequential_like:
                        relevant_nodes = []
                        for depth_nodes in model._nodes_by_depth.values():
                            relevant_nodes += depth_nodes

                    for layer in model.layers:
                        node = Node.from_keras(layer)
                        # Layers without ``_inbound_nodes`` contribute a node only.
                        for in_node in getattr(layer, "_inbound_nodes", ()):
                            # Filtering only happens when ``relevant_nodes`` is a
                            # non-empty list (i.e. the model is not sequential-like).
                            if relevant_nodes and in_node not in relevant_nodes:
                                # node is not part of the current network
                                continue
                            for in_layer in _nest(in_node.inbound_layers):
                                inbound_keras_node = Node.from_keras(in_layer)

                                # Reuse the node already registered under this id so
                                # edges point at a single object per layer.
                                if inbound_keras_node.id not in graph.nodes_by_id:
                                    graph.add_node(inbound_keras_node)
                                inbound_node = graph.nodes_by_id[inbound_keras_node.id]

                                graph.add_edge(inbound_node, node)
                        # The layer's own node is registered after its edges.
                        graph.add_node(node)
                    return graph
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                @classmethod
                def _is_sequential(cls, model):
                    """Report whether ``model`` can be treated as a linear layer stack.

                    A model is sequential-like when its class is named "Sequential",
                    or when it has no truthy ``_is_graph_network`` attribute.
                    Otherwise it must satisfy both of the following:

                    * every entry of ``model._nodes_by_depth`` holds exactly one
                      node, and that node has at most one inbound layer;
                    * no layer in ``model.layers`` has more than one of those nodes
                      among its ``_inbound_nodes`` (i.e. no shared layers).

                    Returns:
                        ``True`` if the model is sequential-like, else ``False``.
                    """
                    if model.__class__.__name__ == "Sequential":
                        return True
                    if not getattr(model, "_is_graph_network", False):
                        return True

                    seen_nodes = []
                    for depth_nodes in model._nodes_by_depth.values():
                        # TensorFlow2 doesn't insure inbound is always a list
                        inbound = depth_nodes[0].inbound_layers
                        if not hasattr(inbound, "__len__"):
                            inbound = [inbound]
                        several_nodes = len(depth_nodes) > 1
                        several_inputs = len(depth_nodes) == 1 and len(inbound) > 1
                        if several_nodes or several_inputs:
                            # Multiple nodes at one depth, or a node fed by multiple
                            # inbound layers, means the model is not sequential.
                            return False
                        seen_nodes += depth_nodes

                    # search for shared layers
                    for layer in model.layers:
                        if not hasattr(layer, "_inbound_nodes"):
                            continue
                        hits = 0
                        for inbound_node in layer._inbound_nodes:
                            if inbound_node in seen_nodes:
                                hits += 1
                                if hits > 1:
                                    # The layer is used by two nodes of this network.
                                    return False
                    return True
         |