wandb 0.18.0-py3-none-win32.whl → 0.18.1-py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +1 -1
- wandb/apis/public/runs.py +2 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +0 -2
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +2 -1
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +2 -1
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +3 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/service/service.py +2 -9
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +10 -3
- wandb/sdk/wandb_run.py +6 -152
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn.py +35 -0
- wandb/util.py +6 -2
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/METADATA +1 -1
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
- wandb/sdk/lib/console.py +0 -39
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
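The scikit-learn integration also moves from wandb/sklearn/ to wandb/integration/sklearn/, with a new 35-line wandb/sklearn.py appearing at the old location. Its contents are not included in this diff; the sketch below is a hypothetical illustration of the kind of re-export shim such a move usually implies, so that `import wandb.sklearn` keeps working after the relocation:

```python
"""Hypothetical sketch of a wandb/sklearn.py compatibility shim.

Illustration only: the real 35-line file added in 0.18.1 is not shown in this
diff. The pattern simply re-exports the relocated package from its old path.
"""

from wandb.integration.sklearn import *  # noqa: F401,F403
```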
wandb/data_types.py
CHANGED
@@ -13,30 +13,10 @@ serialize to JSON, since that is what wandb uses to save the objects locally
|
|
13
13
|
and upload them to the W&B server.
|
14
14
|
"""
|
15
15
|
|
16
|
-
import
|
17
|
-
import binascii
|
18
|
-
import codecs
|
19
|
-
import datetime
|
20
|
-
import hashlib
|
21
|
-
import json
|
22
|
-
import logging
|
23
|
-
import os
|
24
|
-
import pprint
|
25
|
-
from decimal import Decimal
|
26
|
-
from typing import Optional
|
27
|
-
|
28
|
-
import wandb
|
29
|
-
from wandb import util
|
30
|
-
from wandb.sdk.lib import filesystem
|
31
|
-
|
32
|
-
from .sdk.data_types import _dtypes
|
33
|
-
from .sdk.data_types._private import MEDIA_TMP
|
34
|
-
from .sdk.data_types.base_types.media import (
|
35
|
-
BatchableMedia,
|
36
|
-
Media,
|
37
|
-
_numpy_arrays_to_lists,
|
38
|
-
)
|
16
|
+
from .sdk.data_types.audio import Audio
|
39
17
|
from .sdk.data_types.base_types.wb_value import WBValue
|
18
|
+
from .sdk.data_types.bokeh import Bokeh
|
19
|
+
from .sdk.data_types.graph import Graph, Node
|
40
20
|
from .sdk.data_types.helper_types.bounding_boxes_2d import BoundingBoxes2D
|
41
21
|
from .sdk.data_types.helper_types.classes import Classes
|
42
22
|
from .sdk.data_types.helper_types.image_mask import ImageMask
|
@@ -47,9 +27,9 @@ from .sdk.data_types.molecule import Molecule
|
|
47
27
|
from .sdk.data_types.object_3d import Object3D, box3d
|
48
28
|
from .sdk.data_types.plotly import Plotly
|
49
29
|
from .sdk.data_types.saved_model import _SavedModel
|
30
|
+
from .sdk.data_types.table import JoinedTable, PartitionedTable, Table
|
50
31
|
from .sdk.data_types.trace_tree import WBTraceTree
|
51
32
|
from .sdk.data_types.video import Video
|
52
|
-
from .sdk.lib import runid
|
53
33
|
|
54
34
|
# Note: we are importing everything from the sdk/data_types to maintain a namespace for now.
|
55
35
|
# Once we fully type this file and move it all into sdk, then we will need to clean up the
|
@@ -59,7 +39,11 @@ __all__ = [
|
|
59
39
|
# Untyped Exports
|
60
40
|
"Audio",
|
61
41
|
"Table",
|
42
|
+
"JoinedTable",
|
43
|
+
"PartitionedTable",
|
62
44
|
"Bokeh",
|
45
|
+
"Node",
|
46
|
+
"Graph",
|
63
47
|
# Typed Exports
|
64
48
|
"Histogram",
|
65
49
|
"Html",
|
@@ -71,2003 +55,9 @@ __all__ = [
|
|
71
55
|
"Video",
|
72
56
|
"WBTraceTree",
|
73
57
|
"_SavedModel",
|
58
|
+
"WBValue",
|
74
59
|
# Typed Legacy Exports (I'd like to remove these)
|
75
60
|
"ImageMask",
|
76
61
|
"BoundingBoxes2D",
|
77
62
|
"Classes",
|
78
63
|
]
|
79
|
-
|
80
|
-
|
81
|
-
class _TableLinkMixin:
|
82
|
-
def set_table(self, table):
|
83
|
-
self._table = table
|
84
|
-
|
85
|
-
|
86
|
-
class _TableKey(str, _TableLinkMixin):
|
87
|
-
def set_table(self, table, col_name):
|
88
|
-
assert col_name in table.columns
|
89
|
-
self._table = table
|
90
|
-
self._col_name = col_name
|
91
|
-
|
92
|
-
|
93
|
-
class _TableIndex(int, _TableLinkMixin):
|
94
|
-
def get_row(self):
|
95
|
-
row = {}
|
96
|
-
if self._table:
|
97
|
-
row = {
|
98
|
-
c: self._table.data[self][i] for i, c in enumerate(self._table.columns)
|
99
|
-
}
|
100
|
-
|
101
|
-
return row
|
102
|
-
|
103
|
-
|
104
|
-
def _json_helper(val, artifact):
|
105
|
-
if isinstance(val, WBValue):
|
106
|
-
return val.to_json(artifact)
|
107
|
-
elif val.__class__ is dict:
|
108
|
-
res = {}
|
109
|
-
for key in val:
|
110
|
-
res[key] = _json_helper(val[key], artifact)
|
111
|
-
return res
|
112
|
-
|
113
|
-
if hasattr(val, "tolist"):
|
114
|
-
py_val = val.tolist()
|
115
|
-
if val.__class__.__name__ == "datetime64" and isinstance(py_val, int):
|
116
|
-
# when numpy datetime64 .tolist() returns an int, it is nanoseconds.
|
117
|
-
# need to convert to milliseconds
|
118
|
-
return _json_helper(py_val / int(1e6), artifact)
|
119
|
-
return _json_helper(py_val, artifact)
|
120
|
-
elif hasattr(val, "item"):
|
121
|
-
return _json_helper(val.item(), artifact)
|
122
|
-
|
123
|
-
if isinstance(val, datetime.datetime):
|
124
|
-
if val.tzinfo is None:
|
125
|
-
val = datetime.datetime(
|
126
|
-
val.year,
|
127
|
-
val.month,
|
128
|
-
val.day,
|
129
|
-
val.hour,
|
130
|
-
val.minute,
|
131
|
-
val.second,
|
132
|
-
val.microsecond,
|
133
|
-
tzinfo=datetime.timezone.utc,
|
134
|
-
)
|
135
|
-
return int(val.timestamp() * 1000)
|
136
|
-
elif isinstance(val, datetime.date):
|
137
|
-
return int(
|
138
|
-
datetime.datetime(
|
139
|
-
val.year, val.month, val.day, tzinfo=datetime.timezone.utc
|
140
|
-
).timestamp()
|
141
|
-
* 1000
|
142
|
-
)
|
143
|
-
elif isinstance(val, (list, tuple)):
|
144
|
-
return [_json_helper(i, artifact) for i in val]
|
145
|
-
elif isinstance(val, Decimal):
|
146
|
-
return float(val)
|
147
|
-
else:
|
148
|
-
return util.json_friendly(val)[0]
|
149
|
-
|
150
|
-
|
151
|
-
class Table(Media):
|
152
|
-
"""The Table class used to display and analyze tabular data.
|
153
|
-
|
154
|
-
Unlike traditional spreadsheets, Tables support numerous types of data:
|
155
|
-
scalar values, strings, numpy arrays, and most subclasses of `wandb.data_types.Media`.
|
156
|
-
This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
|
157
|
-
directly in Tables, alongside other traditional scalar values.
|
158
|
-
|
159
|
-
This class is the primary class used to generate the Table Visualizer
|
160
|
-
in the UI: https://docs.wandb.ai/guides/data-vis/tables.
|
161
|
-
|
162
|
-
Arguments:
|
163
|
-
columns: (List[str]) Names of the columns in the table.
|
164
|
-
Defaults to ["Input", "Output", "Expected"].
|
165
|
-
data: (List[List[any]]) 2D row-oriented array of values.
|
166
|
-
dataframe: (pandas.DataFrame) DataFrame object used to create the table.
|
167
|
-
When set, `data` and `columns` arguments are ignored.
|
168
|
-
optional: (Union[bool,List[bool]]) Determines if `None` values are allowed. Default to True
|
169
|
-
- If a singular bool value, then the optionality is enforced for all
|
170
|
-
columns specified at construction time
|
171
|
-
- If a list of bool values, then the optionality is applied to each
|
172
|
-
column - should be the same length as `columns`
|
173
|
-
applies to all columns. A list of bool values applies to each respective column.
|
174
|
-
allow_mixed_types: (bool) Determines if columns are allowed to have mixed types
|
175
|
-
(disables type validation). Defaults to False
|
176
|
-
"""
|
177
|
-
|
178
|
-
MAX_ROWS = 10000
|
179
|
-
MAX_ARTIFACT_ROWS = 200000
|
180
|
-
_MAX_EMBEDDING_DIMENSIONS = 150
|
181
|
-
_log_type = "table"
|
182
|
-
|
183
|
-
def __init__(
|
184
|
-
self,
|
185
|
-
columns=None,
|
186
|
-
data=None,
|
187
|
-
rows=None,
|
188
|
-
dataframe=None,
|
189
|
-
dtype=None,
|
190
|
-
optional=True,
|
191
|
-
allow_mixed_types=False,
|
192
|
-
):
|
193
|
-
"""Initializes a Table object.
|
194
|
-
|
195
|
-
The rows is available for legacy reasons and should not be used.
|
196
|
-
The Table class uses data to mimic the Pandas API.
|
197
|
-
"""
|
198
|
-
super().__init__()
|
199
|
-
self._pk_col = None
|
200
|
-
self._fk_cols = set()
|
201
|
-
if allow_mixed_types:
|
202
|
-
dtype = _dtypes.AnyType
|
203
|
-
|
204
|
-
# This is kept for legacy reasons (tss: personally, I think we should remove this)
|
205
|
-
if columns is None:
|
206
|
-
columns = ["Input", "Output", "Expected"]
|
207
|
-
|
208
|
-
# Explicit dataframe option
|
209
|
-
if dataframe is not None:
|
210
|
-
self._init_from_dataframe(dataframe, columns, optional, dtype)
|
211
|
-
else:
|
212
|
-
# Expected pattern
|
213
|
-
if data is not None:
|
214
|
-
if util.is_numpy_array(data):
|
215
|
-
self._init_from_ndarray(data, columns, optional, dtype)
|
216
|
-
elif util.is_pandas_data_frame(data):
|
217
|
-
self._init_from_dataframe(data, columns, optional, dtype)
|
218
|
-
else:
|
219
|
-
self._init_from_list(data, columns, optional, dtype)
|
220
|
-
|
221
|
-
# legacy
|
222
|
-
elif rows is not None:
|
223
|
-
self._init_from_list(rows, columns, optional, dtype)
|
224
|
-
|
225
|
-
# Default empty case
|
226
|
-
else:
|
227
|
-
self._init_from_list([], columns, optional, dtype)
|
228
|
-
|
229
|
-
@staticmethod
|
230
|
-
def _assert_valid_columns(columns):
|
231
|
-
valid_col_types = [str, int]
|
232
|
-
assert isinstance(columns, list), "columns argument expects a `list` object"
|
233
|
-
assert len(columns) == 0 or all(
|
234
|
-
[type(col) in valid_col_types for col in columns]
|
235
|
-
), "columns argument expects list of strings or ints"
|
236
|
-
|
237
|
-
def _init_from_list(self, data, columns, optional=True, dtype=None):
|
238
|
-
assert isinstance(data, list), "data argument expects a `list` object"
|
239
|
-
self.data = []
|
240
|
-
self._assert_valid_columns(columns)
|
241
|
-
self.columns = columns
|
242
|
-
self._make_column_types(dtype, optional)
|
243
|
-
for row in data:
|
244
|
-
self.add_data(*row)
|
245
|
-
|
246
|
-
def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
|
247
|
-
assert util.is_numpy_array(
|
248
|
-
ndarray
|
249
|
-
), "ndarray argument expects a `numpy.ndarray` object"
|
250
|
-
self.data = []
|
251
|
-
self._assert_valid_columns(columns)
|
252
|
-
self.columns = columns
|
253
|
-
self._make_column_types(dtype, optional)
|
254
|
-
for row in ndarray:
|
255
|
-
self.add_data(*row)
|
256
|
-
|
257
|
-
def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
|
258
|
-
assert util.is_pandas_data_frame(
|
259
|
-
dataframe
|
260
|
-
), "dataframe argument expects a `pandas.core.frame.DataFrame` object"
|
261
|
-
self.data = []
|
262
|
-
columns = list(dataframe.columns)
|
263
|
-
self._assert_valid_columns(columns)
|
264
|
-
self.columns = columns
|
265
|
-
self._make_column_types(dtype, optional)
|
266
|
-
for row in range(len(dataframe)):
|
267
|
-
self.add_data(*tuple(dataframe[col].values[row] for col in self.columns))
|
268
|
-
|
269
|
-
def _make_column_types(self, dtype=None, optional=True):
|
270
|
-
if dtype is None:
|
271
|
-
dtype = _dtypes.UnknownType()
|
272
|
-
|
273
|
-
if optional.__class__ is not list:
|
274
|
-
optional = [optional for _ in range(len(self.columns))]
|
275
|
-
|
276
|
-
if dtype.__class__ is not list:
|
277
|
-
dtype = [dtype for _ in range(len(self.columns))]
|
278
|
-
|
279
|
-
self._column_types = _dtypes.TypedDictType({})
|
280
|
-
for col_name, opt, dt in zip(self.columns, optional, dtype):
|
281
|
-
self.cast(col_name, dt, opt)
|
282
|
-
|
283
|
-
def cast(self, col_name, dtype, optional=False):
|
284
|
-
"""Casts a column to a specific data type.
|
285
|
-
|
286
|
-
This can be one of the normal python classes, an internal W&B type, or an
|
287
|
-
example object, like an instance of wandb.Image or wandb.Classes.
|
288
|
-
|
289
|
-
Arguments:
|
290
|
-
col_name: (str) - The name of the column to cast.
|
291
|
-
dtype: (class, wandb.wandb_sdk.interface._dtypes.Type, any) - The target dtype.
|
292
|
-
optional: (bool) - If the column should allow Nones.
|
293
|
-
"""
|
294
|
-
assert col_name in self.columns
|
295
|
-
|
296
|
-
wbtype = _dtypes.TypeRegistry.type_from_dtype(dtype)
|
297
|
-
|
298
|
-
if optional:
|
299
|
-
wbtype = _dtypes.OptionalType(wbtype)
|
300
|
-
|
301
|
-
# Cast each value in the row, raising an error if there are invalid entries.
|
302
|
-
col_ndx = self.columns.index(col_name)
|
303
|
-
for row in self.data:
|
304
|
-
result_type = wbtype.assign(row[col_ndx])
|
305
|
-
if isinstance(result_type, _dtypes.InvalidType):
|
306
|
-
raise TypeError(
|
307
|
-
"Existing data {}, of type {} cannot be cast to {}".format(
|
308
|
-
row[col_ndx],
|
309
|
-
_dtypes.TypeRegistry.type_of(row[col_ndx]),
|
310
|
-
wbtype,
|
311
|
-
)
|
312
|
-
)
|
313
|
-
wbtype = result_type
|
314
|
-
|
315
|
-
# Assert valid options
|
316
|
-
is_pk = isinstance(wbtype, _PrimaryKeyType)
|
317
|
-
is_fk = isinstance(wbtype, _ForeignKeyType)
|
318
|
-
is_fi = isinstance(wbtype, _ForeignIndexType)
|
319
|
-
if is_pk or is_fk or is_fi:
|
320
|
-
assert (
|
321
|
-
not optional
|
322
|
-
), "Primary keys, foreign keys, and foreign indexes cannot be optional."
|
323
|
-
|
324
|
-
if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
|
325
|
-
raise AssertionError("Cannot set a foreign table reference to same table.")
|
326
|
-
|
327
|
-
if is_pk:
|
328
|
-
assert (
|
329
|
-
self._pk_col is None
|
330
|
-
), "Cannot have multiple primary keys - {} is already set as the primary key.".format(
|
331
|
-
self._pk_col
|
332
|
-
)
|
333
|
-
|
334
|
-
# Update the column type
|
335
|
-
self._column_types.params["type_map"][col_name] = wbtype
|
336
|
-
|
337
|
-
# Wrap the data if needed
|
338
|
-
self._update_keys()
|
339
|
-
return wbtype
|
340
|
-
|
341
|
-
def __ne__(self, other):
|
342
|
-
return not self.__eq__(other)
|
343
|
-
|
344
|
-
def _eq_debug(self, other, should_assert=False):
|
345
|
-
eq = isinstance(other, Table)
|
346
|
-
assert not should_assert or eq, "Found type {}, expected {}".format(
|
347
|
-
other.__class__, Table
|
348
|
-
)
|
349
|
-
eq = eq and len(self.data) == len(other.data)
|
350
|
-
assert not should_assert or eq, "Found {} rows, expected {}".format(
|
351
|
-
len(other.data), len(self.data)
|
352
|
-
)
|
353
|
-
eq = eq and self.columns == other.columns
|
354
|
-
assert not should_assert or eq, "Found columns {}, expected {}".format(
|
355
|
-
other.columns, self.columns
|
356
|
-
)
|
357
|
-
eq = eq and self._column_types == other._column_types
|
358
|
-
assert (
|
359
|
-
not should_assert or eq
|
360
|
-
), "Found column type {}, expected column type {}".format(
|
361
|
-
other._column_types, self._column_types
|
362
|
-
)
|
363
|
-
if eq:
|
364
|
-
for row_ndx in range(len(self.data)):
|
365
|
-
for col_ndx in range(len(self.data[row_ndx])):
|
366
|
-
_eq = self.data[row_ndx][col_ndx] == other.data[row_ndx][col_ndx]
|
367
|
-
# equal if all are equal
|
368
|
-
if util.is_numpy_array(_eq):
|
369
|
-
_eq = ((_eq * -1) + 1).sum() == 0
|
370
|
-
eq = eq and _eq
|
371
|
-
assert (
|
372
|
-
not should_assert or eq
|
373
|
-
), "Unequal data at row_ndx {} col_ndx {}: found {}, expected {}".format(
|
374
|
-
row_ndx,
|
375
|
-
col_ndx,
|
376
|
-
other.data[row_ndx][col_ndx],
|
377
|
-
self.data[row_ndx][col_ndx],
|
378
|
-
)
|
379
|
-
if not eq:
|
380
|
-
return eq
|
381
|
-
return eq
|
382
|
-
|
383
|
-
def __eq__(self, other):
|
384
|
-
return self._eq_debug(other)
|
385
|
-
|
386
|
-
def add_row(self, *row):
|
387
|
-
"""Deprecated; use add_data instead."""
|
388
|
-
logging.warning("add_row is deprecated, use add_data")
|
389
|
-
self.add_data(*row)
|
390
|
-
|
391
|
-
def add_data(self, *data):
|
392
|
-
"""Adds a new row of data to the table. The maximum amount of rows in a table is determined by `wandb.Table.MAX_ARTIFACT_ROWS`.
|
393
|
-
|
394
|
-
The length of the data should match the length of the table column.
|
395
|
-
"""
|
396
|
-
if len(data) != len(self.columns):
|
397
|
-
raise ValueError(
|
398
|
-
"This table expects {} columns: {}, found {}".format(
|
399
|
-
len(self.columns), self.columns, len(data)
|
400
|
-
)
|
401
|
-
)
|
402
|
-
|
403
|
-
# Special case to pre-emptively cast a column as a key.
|
404
|
-
# Needed as String.assign(Key) is invalid
|
405
|
-
for ndx, item in enumerate(data):
|
406
|
-
if isinstance(item, _TableLinkMixin):
|
407
|
-
self.cast(
|
408
|
-
self.columns[ndx],
|
409
|
-
_dtypes.TypeRegistry.type_of(item),
|
410
|
-
optional=False,
|
411
|
-
)
|
412
|
-
|
413
|
-
# Update the table's column types
|
414
|
-
result_type = self._get_updated_result_type(data)
|
415
|
-
self._column_types = result_type
|
416
|
-
|
417
|
-
# rows need to be mutable
|
418
|
-
if isinstance(data, tuple):
|
419
|
-
data = list(data)
|
420
|
-
# Add the new data
|
421
|
-
self.data.append(data)
|
422
|
-
|
423
|
-
# Update the wrapper values if needed
|
424
|
-
self._update_keys(force_last=True)
|
425
|
-
|
426
|
-
def _get_updated_result_type(self, row):
|
427
|
-
"""Returns the updated result type based on the inputted row.
|
428
|
-
|
429
|
-
Raises:
|
430
|
-
TypeError: if the assignment is invalid.
|
431
|
-
"""
|
432
|
-
incoming_row_dict = {
|
433
|
-
col_key: row[ndx] for ndx, col_key in enumerate(self.columns)
|
434
|
-
}
|
435
|
-
current_type = self._column_types
|
436
|
-
result_type = current_type.assign(incoming_row_dict)
|
437
|
-
if isinstance(result_type, _dtypes.InvalidType):
|
438
|
-
raise TypeError(
|
439
|
-
"Data row contained incompatible types:\n{}".format(
|
440
|
-
current_type.explain(incoming_row_dict)
|
441
|
-
)
|
442
|
-
)
|
443
|
-
return result_type
|
444
|
-
|
445
|
-
def _to_table_json(self, max_rows=None, warn=True):
|
446
|
-
# separate this method for easier testing
|
447
|
-
if max_rows is None:
|
448
|
-
max_rows = Table.MAX_ROWS
|
449
|
-
n_rows = len(self.data)
|
450
|
-
if n_rows > max_rows and warn:
|
451
|
-
if wandb.run and (
|
452
|
-
wandb.run.settings.table_raise_on_max_row_limit_exceeded
|
453
|
-
or wandb.run.settings.strict
|
454
|
-
):
|
455
|
-
raise ValueError(
|
456
|
-
f"Table row limit exceeded: table has {n_rows} rows, limit is {max_rows}. "
|
457
|
-
f"To increase the maximum number of allowed rows in a wandb.Table, override "
|
458
|
-
f"the limit with `wandb.Table.MAX_ARTIFACT_ROWS = X` and try again. Note: "
|
459
|
-
f"this may cause slower queries in the W&B UI."
|
460
|
-
)
|
461
|
-
logging.warning("Truncating wandb.Table object to %i rows." % max_rows)
|
462
|
-
return {"columns": self.columns, "data": self.data[:max_rows]}
|
463
|
-
|
464
|
-
def bind_to_run(self, *args, **kwargs):
|
465
|
-
# We set `warn=False` since Tables will now always be logged to both
|
466
|
-
# files and artifacts. The file limit will never practically matter and
|
467
|
-
# this code path will be ultimately removed. The 10k limit warning confuses
|
468
|
-
# users given that we publicly say 200k is the limit.
|
469
|
-
data = self._to_table_json(warn=False)
|
470
|
-
tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".table.json")
|
471
|
-
data = _numpy_arrays_to_lists(data)
|
472
|
-
with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
|
473
|
-
util.json_dump_safer(data, fp)
|
474
|
-
self._set_file(tmp_path, is_tmp=True, extension=".table.json")
|
475
|
-
super().bind_to_run(*args, **kwargs)
|
476
|
-
|
477
|
-
@classmethod
|
478
|
-
def get_media_subdir(cls):
|
479
|
-
return os.path.join("media", "table")
|
480
|
-
|
481
|
-
@classmethod
|
482
|
-
def from_json(cls, json_obj, source_artifact):
|
483
|
-
data = []
|
484
|
-
column_types = None
|
485
|
-
np_deserialized_columns = {}
|
486
|
-
timestamp_column_indices = set()
|
487
|
-
if json_obj.get("column_types") is not None:
|
488
|
-
column_types = _dtypes.TypeRegistry.type_from_dict(
|
489
|
-
json_obj["column_types"], source_artifact
|
490
|
-
)
|
491
|
-
for col_name in column_types.params["type_map"]:
|
492
|
-
col_type = column_types.params["type_map"][col_name]
|
493
|
-
ndarray_type = None
|
494
|
-
if isinstance(col_type, _dtypes.NDArrayType):
|
495
|
-
ndarray_type = col_type
|
496
|
-
elif isinstance(col_type, _dtypes.UnionType):
|
497
|
-
for t in col_type.params["allowed_types"]:
|
498
|
-
if isinstance(t, _dtypes.NDArrayType):
|
499
|
-
ndarray_type = t
|
500
|
-
elif isinstance(t, _dtypes.TimestampType):
|
501
|
-
timestamp_column_indices.add(
|
502
|
-
json_obj["columns"].index(col_name)
|
503
|
-
)
|
504
|
-
|
505
|
-
elif isinstance(col_type, _dtypes.TimestampType):
|
506
|
-
timestamp_column_indices.add(json_obj["columns"].index(col_name))
|
507
|
-
|
508
|
-
if (
|
509
|
-
ndarray_type is not None
|
510
|
-
and ndarray_type._get_serialization_path() is not None
|
511
|
-
):
|
512
|
-
serialization_path = ndarray_type._get_serialization_path()
|
513
|
-
np = util.get_module(
|
514
|
-
"numpy",
|
515
|
-
required="Deserializing NumPy columns requires NumPy to be installed.",
|
516
|
-
)
|
517
|
-
deserialized = np.load(
|
518
|
-
source_artifact.get_entry(serialization_path["path"]).download()
|
519
|
-
)
|
520
|
-
np_deserialized_columns[json_obj["columns"].index(col_name)] = (
|
521
|
-
deserialized[serialization_path["key"]]
|
522
|
-
)
|
523
|
-
ndarray_type._clear_serialization_path()
|
524
|
-
|
525
|
-
for r_ndx, row in enumerate(json_obj["data"]):
|
526
|
-
row_data = []
|
527
|
-
for c_ndx, item in enumerate(row):
|
528
|
-
cell = item
|
529
|
-
if c_ndx in timestamp_column_indices and isinstance(item, (int, float)):
|
530
|
-
cell = datetime.datetime.fromtimestamp(
|
531
|
-
item / 1000, tz=datetime.timezone.utc
|
532
|
-
)
|
533
|
-
elif c_ndx in np_deserialized_columns:
|
534
|
-
cell = np_deserialized_columns[c_ndx][r_ndx]
|
535
|
-
elif isinstance(item, dict) and "_type" in item:
|
536
|
-
obj = WBValue.init_from_json(item, source_artifact)
|
537
|
-
if obj is not None:
|
538
|
-
cell = obj
|
539
|
-
row_data.append(cell)
|
540
|
-
data.append(row_data)
|
541
|
-
|
542
|
-
# construct Table with dtypes for each column if type information exists
|
543
|
-
dtypes = None
|
544
|
-
if column_types is not None:
|
545
|
-
dtypes = [
|
546
|
-
column_types.params["type_map"][str(col)] for col in json_obj["columns"]
|
547
|
-
]
|
548
|
-
|
549
|
-
new_obj = cls(columns=json_obj["columns"], data=data, dtype=dtypes)
|
550
|
-
|
551
|
-
if column_types is not None:
|
552
|
-
new_obj._column_types = column_types
|
553
|
-
|
554
|
-
new_obj._update_keys()
|
555
|
-
return new_obj
|
556
|
-
|
557
|
-
def to_json(self, run_or_artifact):
|
558
|
-
json_dict = super().to_json(run_or_artifact)
|
559
|
-
|
560
|
-
if isinstance(run_or_artifact, wandb.wandb_sdk.wandb_run.Run):
|
561
|
-
json_dict.update(
|
562
|
-
{
|
563
|
-
"_type": "table-file",
|
564
|
-
"ncols": len(self.columns),
|
565
|
-
"nrows": len(self.data),
|
566
|
-
}
|
567
|
-
)
|
568
|
-
|
569
|
-
elif isinstance(run_or_artifact, wandb.Artifact):
|
570
|
-
artifact = run_or_artifact
|
571
|
-
mapped_data = []
|
572
|
-
data = self._to_table_json(Table.MAX_ARTIFACT_ROWS)["data"]
|
573
|
-
|
574
|
-
ndarray_col_ndxs = set()
|
575
|
-
for col_ndx, col_name in enumerate(self.columns):
|
576
|
-
col_type = self._column_types.params["type_map"][col_name]
|
577
|
-
ndarray_type = None
|
578
|
-
if isinstance(col_type, _dtypes.NDArrayType):
|
579
|
-
ndarray_type = col_type
|
580
|
-
elif isinstance(col_type, _dtypes.UnionType):
|
581
|
-
for t in col_type.params["allowed_types"]:
|
582
|
-
if isinstance(t, _dtypes.NDArrayType):
|
583
|
-
ndarray_type = t
|
584
|
-
|
585
|
-
# Do not serialize 1d arrays - these are likely embeddings and
|
586
|
-
# will not have the same cost as higher dimensional arrays
|
587
|
-
is_1d_array = (
|
588
|
-
ndarray_type is not None
|
589
|
-
and "shape" in ndarray_type._params
|
590
|
-
and isinstance(ndarray_type._params["shape"], list)
|
591
|
-
and len(ndarray_type._params["shape"]) == 1
|
592
|
-
and ndarray_type._params["shape"][0]
|
593
|
-
<= self._MAX_EMBEDDING_DIMENSIONS
|
594
|
-
)
|
595
|
-
if is_1d_array:
|
596
|
-
self._column_types.params["type_map"][col_name] = _dtypes.ListType(
|
597
|
-
_dtypes.NumberType, ndarray_type._params["shape"][0]
|
598
|
-
)
|
599
|
-
elif ndarray_type is not None:
|
600
|
-
np = util.get_module(
|
601
|
-
"numpy",
|
602
|
-
required="Serializing NumPy requires NumPy to be installed.",
|
603
|
-
)
|
604
|
-
file_name = f"{str(col_name)}_{runid.generate_id()}.npz"
|
605
|
-
npz_file_name = os.path.join(MEDIA_TMP.name, file_name)
|
606
|
-
np.savez_compressed(
|
607
|
-
npz_file_name,
|
608
|
-
**{
|
609
|
-
str(col_name): self.get_column(col_name, convert_to="numpy")
|
610
|
-
},
|
611
|
-
)
|
612
|
-
entry = artifact.add_file(
|
613
|
-
npz_file_name, "media/serialized_data/" + file_name, is_tmp=True
|
614
|
-
)
|
615
|
-
ndarray_type._set_serialization_path(entry.path, str(col_name))
|
616
|
-
ndarray_col_ndxs.add(col_ndx)
|
617
|
-
|
618
|
-
for row in data:
|
619
|
-
mapped_row = []
|
620
|
-
for ndx, v in enumerate(row):
|
621
|
-
if ndx in ndarray_col_ndxs:
|
622
|
-
mapped_row.append(None)
|
623
|
-
else:
|
624
|
-
mapped_row.append(_json_helper(v, artifact))
|
625
|
-
mapped_data.append(mapped_row)
|
626
|
-
|
627
|
-
json_dict.update(
|
628
|
-
{
|
629
|
-
"_type": Table._log_type,
|
630
|
-
"columns": self.columns,
|
631
|
-
"data": mapped_data,
|
632
|
-
"ncols": len(self.columns),
|
633
|
-
"nrows": len(mapped_data),
|
634
|
-
"column_types": self._column_types.to_json(artifact),
|
635
|
-
}
|
636
|
-
)
|
637
|
-
else:
|
638
|
-
raise ValueError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
|
639
|
-
|
640
|
-
return json_dict
|
641
|
-
|
642
|
-
def iterrows(self):
|
643
|
-
"""Returns the table data by row, showing the index of the row and the relevant data.
|
644
|
-
|
645
|
-
Yields:
|
646
|
-
------
|
647
|
-
index : int
|
648
|
-
The index of the row. Using this value in other W&B tables
|
649
|
-
will automatically build a relationship between the tables
|
650
|
-
row : List[any]
|
651
|
-
The data of the row.
|
652
|
-
"""
|
653
|
-
for ndx in range(len(self.data)):
|
654
|
-
index = _TableIndex(ndx)
|
655
|
-
index.set_table(self)
|
656
|
-
yield index, self.data[ndx]
|
657
|
-
|
658
|
-
def set_pk(self, col_name):
|
659
|
-
# TODO: Docs
|
660
|
-
assert col_name in self.columns
|
661
|
-
self.cast(col_name, _PrimaryKeyType())
|
662
|
-
|
663
|
-
def set_fk(self, col_name, table, table_col):
|
664
|
-
# TODO: Docs
|
665
|
-
assert col_name in self.columns
|
666
|
-
assert col_name != self._pk_col
|
667
|
-
self.cast(col_name, _ForeignKeyType(table, table_col))
|
668
|
-
|
669
|
-
def _update_keys(self, force_last=False):
|
670
|
-
"""Updates the known key-like columns based on current column types.
|
671
|
-
|
672
|
-
If the state has been updated since the last update, wraps the data
|
673
|
-
appropriately in the Key classes.
|
674
|
-
|
675
|
-
Arguments:
|
676
|
-
force_last: (bool) Wraps the last column of data even if there
|
677
|
-
are no key updates.
|
678
|
-
"""
|
679
|
-
_pk_col = None
|
680
|
-
_fk_cols = set()
|
681
|
-
|
682
|
-
# Buildup the known keys from column types
|
683
|
-
c_types = self._column_types.params["type_map"]
|
684
|
-
for t in c_types:
|
685
|
-
if isinstance(c_types[t], _PrimaryKeyType):
|
686
|
-
_pk_col = t
|
687
|
-
elif isinstance(c_types[t], _ForeignKeyType) or isinstance(
|
688
|
-
c_types[t], _ForeignIndexType
|
689
|
-
):
|
690
|
-
_fk_cols.add(t)
|
691
|
-
|
692
|
-
# If there are updates to perform, safely update them
|
693
|
-
has_update = _pk_col != self._pk_col or _fk_cols != self._fk_cols
|
694
|
-
if has_update:
|
695
|
-
# If we removed the PK
|
696
|
-
if _pk_col is None and self._pk_col is not None:
|
697
|
-
raise AssertionError(
|
698
|
-
f"Cannot unset primary key (column {self._pk_col})"
|
699
|
-
)
|
700
|
-
# If there is a removed FK
|
701
|
-
if len(self._fk_cols - _fk_cols) > 0:
|
702
|
-
raise AssertionError(
|
703
|
-
"Cannot unset foreign key. Attempted to unset ({})".format(
|
704
|
-
self._fk_cols - _fk_cols
|
705
|
-
)
|
706
|
-
)
|
707
|
-
|
708
|
-
self._pk_col = _pk_col
|
709
|
-
self._fk_cols = _fk_cols
|
710
|
-
|
711
|
-
# Apply updates to data only if there are update or the caller
|
712
|
-
# requested the final row to be updated
|
713
|
-
if has_update or force_last:
|
714
|
-
self._apply_key_updates(not has_update)
|
715
|
-
|
716
|
-
def _apply_key_updates(self, only_last=False):
|
717
|
-
"""Appropriately wraps the underlying data in special Key classes.
|
718
|
-
|
719
|
-
Arguments:
|
720
|
-
only_last: only apply the updates to the last row (used for performance when
|
721
|
-
the caller knows that the only new data is the last row and no updates were
|
722
|
-
applied to the column types)
|
723
|
-
"""
|
724
|
-
c_types = self._column_types.params["type_map"]
|
725
|
-
|
726
|
-
# Define a helper function which will wrap the data of a single row
|
727
|
-
# in the appropriate class wrapper.
|
728
|
-
def update_row(row_ndx):
|
729
|
-
for fk_col in self._fk_cols:
|
730
|
-
col_ndx = self.columns.index(fk_col)
|
731
|
-
|
732
|
-
# Wrap the Foreign Keys
|
733
|
-
if isinstance(c_types[fk_col], _ForeignKeyType) and not isinstance(
|
734
|
-
self.data[row_ndx][col_ndx], _TableKey
|
735
|
-
):
|
736
|
-
self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
|
737
|
-
self.data[row_ndx][col_ndx].set_table(
|
738
|
-
c_types[fk_col].params["table"],
|
739
|
-
c_types[fk_col].params["col_name"],
|
740
|
-
)
|
741
|
-
|
742
|
-
# Wrap the Foreign Indexes
|
743
|
-
elif isinstance(c_types[fk_col], _ForeignIndexType) and not isinstance(
|
744
|
-
self.data[row_ndx][col_ndx], _TableIndex
|
745
|
-
):
|
746
|
-
self.data[row_ndx][col_ndx] = _TableIndex(
|
747
|
-
self.data[row_ndx][col_ndx]
|
748
|
-
)
|
749
|
-
self.data[row_ndx][col_ndx].set_table(
|
750
|
-
c_types[fk_col].params["table"]
|
751
|
-
)
|
752
|
-
|
753
|
-
# Wrap the Primary Key
|
754
|
-
if self._pk_col is not None:
|
755
|
-
col_ndx = self.columns.index(self._pk_col)
|
756
|
-
self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
|
757
|
-
self.data[row_ndx][col_ndx].set_table(self, self._pk_col)
|
758
|
-
|
759
|
-
if only_last:
|
760
|
-
update_row(len(self.data) - 1)
|
761
|
-
else:
|
762
|
-
for row_ndx in range(len(self.data)):
|
763
|
-
update_row(row_ndx)
|
764
|
-
|
765
|
-
def add_column(self, name, data, optional=False):
|
766
|
-
"""Adds a column of data to the table.
|
767
|
-
|
768
|
-
Arguments:
|
769
|
-
name: (str) - the unique name of the column
|
770
|
-
data: (list | np.array) - a column of homogeneous data
|
771
|
-
optional: (bool) - if null-like values are permitted
|
772
|
-
"""
|
773
|
-
assert isinstance(name, str) and name not in self.columns
|
774
|
-
is_np = util.is_numpy_array(data)
|
775
|
-
assert isinstance(data, list) or is_np
|
776
|
-
assert isinstance(optional, bool)
|
777
|
-
is_first_col = len(self.columns) == 0
|
778
|
-
assert is_first_col or len(data) == len(
|
779
|
-
self.data
|
780
|
-
), f"Expected length {len(self.data)}, found {len(data)}"
|
781
|
-
|
782
|
-
# Add the new data
|
783
|
-
for ndx in range(max(len(data), len(self.data))):
|
784
|
-
if is_first_col:
|
785
|
-
self.data.append([])
|
786
|
-
if is_np:
|
787
|
-
self.data[ndx].append(data[ndx])
|
788
|
-
else:
|
789
|
-
self.data[ndx].append(data[ndx])
|
790
|
-
# add the column
|
791
|
-
self.columns.append(name)
|
792
|
-
|
793
|
-
try:
|
794
|
-
self.cast(name, _dtypes.UnknownType(), optional=optional)
|
795
|
-
except TypeError as err:
|
796
|
-
# Undo the changes
|
797
|
-
if is_first_col:
|
798
|
-
self.data = []
|
799
|
-
self.columns = []
|
800
|
-
else:
|
801
|
-
for ndx in range(len(self.data)):
|
802
|
-
self.data[ndx] = self.data[ndx][:-1]
|
803
|
-
self.columns = self.columns[:-1]
|
804
|
-
raise err
|
805
|
-
|
806
|
-
def get_column(self, name, convert_to=None):
|
807
|
-
"""Retrieves a column from the table and optionally converts it to a NumPy object.
|
808
|
-
|
809
|
-
Arguments:
|
810
|
-
name: (str) - the name of the column
|
811
|
-
convert_to: (str, optional)
|
812
|
-
- "numpy": will convert the underlying data to numpy object
|
813
|
-
"""
|
814
|
-
assert name in self.columns
|
815
|
-
assert convert_to is None or convert_to == "numpy"
|
816
|
-
if convert_to == "numpy":
|
817
|
-
np = util.get_module(
|
818
|
-
"numpy", required="Converting to NumPy requires installing NumPy"
|
819
|
-
)
|
820
|
-
col = []
|
821
|
-
col_ndx = self.columns.index(name)
|
822
|
-
for row in self.data:
|
823
|
-
item = row[col_ndx]
|
824
|
-
if convert_to is not None and isinstance(item, WBValue):
|
825
|
-
item = item.to_data_array()
|
826
|
-
col.append(item)
|
827
|
-
if convert_to == "numpy":
|
828
|
-
col = np.array(col)
|
829
|
-
return col
|
830
|
-
|
831
|
-
def get_index(self):
|
832
|
-
"""Returns an array of row indexes for use in other tables to create links."""
|
833
|
-
ndxs = []
|
834
|
-
for ndx in range(len(self.data)):
|
835
|
-
index = _TableIndex(ndx)
|
836
|
-
index.set_table(self)
|
837
|
-
ndxs.append(index)
|
838
|
-
return ndxs
|
839
|
-
|
840
|
-
def get_dataframe(self):
|
841
|
-
"""Returns a `pandas.DataFrame` of the table."""
|
842
|
-
pd = util.get_module(
|
843
|
-
"pandas",
|
844
|
-
required="Converting to pandas.DataFrame requires installing pandas",
|
845
|
-
)
|
846
|
-
return pd.DataFrame.from_records(self.data, columns=self.columns)
|
847
|
-
|
848
|
-
def index_ref(self, index):
|
849
|
-
"""Gets a reference of the index of a row in the table."""
|
850
|
-
assert index < len(self.data)
|
851
|
-
_index = _TableIndex(index)
|
852
|
-
_index.set_table(self)
|
853
|
-
return _index
|
854
|
-
|
855
|
-
def add_computed_columns(self, fn):
|
856
|
-
"""Adds one or more computed columns based on existing data.
|
857
|
-
|
858
|
-
Args:
|
859
|
-
fn: A function which accepts one or two parameters, ndx (int) and row (dict),
|
860
|
-
which is expected to return a dict representing new columns for that row, keyed
|
861
|
-
by the new column names.
|
862
|
-
|
863
|
-
`ndx` is an integer representing the index of the row. Only included if `include_ndx`
|
864
|
-
is set to `True`.
|
865
|
-
|
866
|
-
`row` is a dictionary keyed by existing columns
|
867
|
-
"""
|
868
|
-
new_columns = {}
|
869
|
-
for ndx, row in self.iterrows():
|
870
|
-
row_dict = {self.columns[i]: row[i] for i in range(len(self.columns))}
|
871
|
-
new_row_dict = fn(ndx, row_dict)
|
872
|
-
assert isinstance(new_row_dict, dict)
|
873
|
-
for key in new_row_dict:
|
874
|
-
new_columns[key] = new_columns.get(key, [])
|
875
|
-
new_columns[key].append(new_row_dict[key])
|
876
|
-
for new_col_name in new_columns:
|
877
|
-
self.add_column(new_col_name, new_columns[new_col_name])
|
878
|
-
|
879
|
-
|
880
|
-
class _PartitionTablePartEntry:
|
881
|
-
"""Helper class for PartitionTable to track its parts."""
|
882
|
-
|
883
|
-
def __init__(self, entry, source_artifact):
|
884
|
-
self.entry = entry
|
885
|
-
self.source_artifact = source_artifact
|
886
|
-
self._part = None
|
887
|
-
|
888
|
-
def get_part(self):
|
889
|
-
if self._part is None:
|
890
|
-
self._part = self.source_artifact.get(self.entry.path)
|
891
|
-
return self._part
|
892
|
-
|
893
|
-
def free(self):
|
894
|
-
self._part = None
|
895
|
-
|
896
|
-
|
897
|
-
class PartitionedTable(Media):
|
898
|
-
"""A table which is composed of multiple sub-tables.
|
899
|
-
|
900
|
-
Currently, PartitionedTable is designed to point to a directory within an artifact.
|
901
|
-
"""
|
902
|
-
|
903
|
-
_log_type = "partitioned-table"
|
904
|
-
|
905
|
-
def __init__(self, parts_path):
|
906
|
-
"""Initialize a PartitionedTable.
|
907
|
-
|
908
|
-
Args:
|
909
|
-
parts_path (str): path to a directory of tables in the artifact.
|
910
|
-
"""
|
911
|
-
super().__init__()
|
912
|
-
self.parts_path = parts_path
|
913
|
-
self._loaded_part_entries = {}
|
914
|
-
|
915
|
-
def to_json(self, artifact_or_run):
|
916
|
-
json_obj = {
|
917
|
-
"_type": PartitionedTable._log_type,
|
918
|
-
}
|
919
|
-
if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
|
920
|
-
artifact_entry_url = self._get_artifact_entry_ref_url()
|
921
|
-
if artifact_entry_url is None:
|
922
|
-
raise ValueError(
|
923
|
-
"PartitionedTables must first be added to an Artifact before logging to a Run"
|
924
|
-
)
|
925
|
-
json_obj["artifact_path"] = artifact_entry_url
|
926
|
-
else:
|
927
|
-
json_obj["parts_path"] = self.parts_path
|
928
|
-
return json_obj
|
929
|
-
|
930
|
-
@classmethod
|
931
|
-
def from_json(cls, json_obj, source_artifact):
|
932
|
-
instance = cls(json_obj["parts_path"])
|
933
|
-
entries = source_artifact.manifest.get_entries_in_directory(
|
934
|
-
json_obj["parts_path"]
|
935
|
-
)
|
936
|
-
for entry in entries:
|
937
|
-
instance._add_part_entry(entry, source_artifact)
|
938
|
-
return instance
|
939
|
-
|
940
|
-
def iterrows(self):
|
941
|
-
"""Iterate over rows as (ndx, row).
|
942
|
-
|
943
|
-
Yields:
|
944
|
-
------
|
945
|
-
index : int
|
946
|
-
The index of the row.
|
947
|
-
row : List[any]
|
948
|
-
The data of the row.
|
949
|
-
"""
|
950
|
-
columns = None
|
951
|
-
ndx = 0
|
952
|
-
for entry_path in self._loaded_part_entries:
|
953
|
-
part = self._loaded_part_entries[entry_path].get_part()
|
954
|
-
if columns is None:
|
955
|
-
columns = part.columns
|
956
|
-
elif columns != part.columns:
|
957
|
-
raise ValueError(
|
958
|
-
"Table parts have non-matching columns. {} != {}".format(
|
959
|
-
columns, part.columns
|
960
|
-
)
|
961
|
-
)
|
962
|
-
for _, row in part.iterrows():
|
963
|
-
yield ndx, row
|
964
|
-
ndx += 1
|
965
|
-
|
966
|
-
self._loaded_part_entries[entry_path].free()
|
967
|
-
|
968
|
-
def _add_part_entry(self, entry, source_artifact):
|
969
|
-
self._loaded_part_entries[entry.path] = _PartitionTablePartEntry(
|
970
|
-
entry, source_artifact
|
971
|
-
)
|
972
|
-
|
973
|
-
def __ne__(self, other):
|
974
|
-
return not self.__eq__(other)
|
975
|
-
|
976
|
-
def __eq__(self, other):
|
977
|
-
return isinstance(other, self.__class__) and self.parts_path == other.parts_path
|
978
|
-
|
979
|
-
def bind_to_run(self, *args, **kwargs):
|
980
|
-
raise ValueError("PartitionedTables cannot be bound to runs")
|
981
|
-
|
982
|
-
|
983
|
-
class Audio(BatchableMedia):
|
984
|
-
"""Wandb class for audio clips.
|
985
|
-
|
986
|
-
Arguments:
|
987
|
-
data_or_path: (string or numpy array) A path to an audio file
|
988
|
-
or a numpy array of audio data.
|
989
|
-
sample_rate: (int) Sample rate, required when passing in raw
|
990
|
-
numpy array of audio data.
|
991
|
-
caption: (string) Caption to display with audio.
|
992
|
-
"""
|
993
|
-
|
994
|
-
_log_type = "audio-file"
|
995
|
-
|
996
|
-
def __init__(self, data_or_path, sample_rate=None, caption=None):
|
997
|
-
"""Accept a path to an audio file or a numpy array of audio data."""
|
998
|
-
super().__init__()
|
999
|
-
self._duration = None
|
1000
|
-
self._sample_rate = sample_rate
|
1001
|
-
self._caption = caption
|
1002
|
-
|
1003
|
-
if isinstance(data_or_path, str):
|
1004
|
-
if self.path_is_reference(data_or_path):
|
1005
|
-
self._path = data_or_path
|
1006
|
-
self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
|
1007
|
-
self._is_tmp = False
|
1008
|
-
else:
|
1009
|
-
self._set_file(data_or_path, is_tmp=False)
|
1010
|
-
else:
|
1011
|
-
if sample_rate is None:
|
1012
|
-
raise ValueError(
|
1013
|
-
'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
|
1014
|
-
)
|
1015
|
-
|
1016
|
-
soundfile = util.get_module(
|
1017
|
-
"soundfile",
|
1018
|
-
required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
|
1019
|
-
)
|
1020
|
-
|
1021
|
-
tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
|
1022
|
-
soundfile.write(tmp_path, data_or_path, sample_rate)
|
1023
|
-
self._duration = len(data_or_path) / float(sample_rate)
|
1024
|
-
|
1025
|
-
self._set_file(tmp_path, is_tmp=True)
|
1026
|
-
|
1027
|
-
@classmethod
|
1028
|
-
def get_media_subdir(cls):
|
1029
|
-
return os.path.join("media", "audio")
|
1030
|
-
|
1031
|
-
@classmethod
|
1032
|
-
def from_json(cls, json_obj, source_artifact):
|
1033
|
-
return cls(
|
1034
|
-
source_artifact.get_entry(json_obj["path"]).download(),
|
1035
|
-
caption=json_obj["caption"],
|
1036
|
-
)
|
1037
|
-
|
1038
|
-
def bind_to_run(
|
1039
|
-
self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
|
1040
|
-
):
|
1041
|
-
if self.path_is_reference(self._path):
|
1042
|
-
raise ValueError(
|
1043
|
-
"Audio media created by a reference to external storage cannot currently be added to a run"
|
1044
|
-
)
|
1045
|
-
|
1046
|
-
return super().bind_to_run(run, key, step, id_, ignore_copy_err)
|
1047
|
-
|
1048
|
-
def to_json(self, run):
|
1049
|
-
json_dict = super().to_json(run)
|
1050
|
-
json_dict.update(
|
1051
|
-
{
|
1052
|
-
"_type": self._log_type,
|
1053
|
-
"caption": self._caption,
|
1054
|
-
}
|
1055
|
-
)
|
1056
|
-
return json_dict
|
1057
|
-
|
1058
|
-
@classmethod
|
1059
|
-
def seq_to_json(cls, seq, run, key, step):
|
1060
|
-
audio_list = list(seq)
|
1061
|
-
|
1062
|
-
util.get_module(
|
1063
|
-
"soundfile",
|
1064
|
-
required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
|
1065
|
-
)
|
1066
|
-
base_path = os.path.join(run.dir, "media", "audio")
|
1067
|
-
filesystem.mkdir_exists_ok(base_path)
|
1068
|
-
meta = {
|
1069
|
-
"_type": "audio",
|
1070
|
-
"count": len(audio_list),
|
1071
|
-
"audio": [a.to_json(run) for a in audio_list],
|
1072
|
-
}
|
1073
|
-
sample_rates = cls.sample_rates(audio_list)
|
1074
|
-
if sample_rates:
|
1075
|
-
meta["sampleRates"] = sample_rates
|
1076
|
-
durations = cls.durations(audio_list)
|
1077
|
-
if durations:
|
1078
|
-
meta["durations"] = durations
|
1079
|
-
captions = cls.captions(audio_list)
|
1080
|
-
if captions:
|
1081
|
-
meta["captions"] = captions
|
1082
|
-
|
1083
|
-
return meta
|
1084
|
-
|
1085
|
-
@classmethod
|
1086
|
-
def durations(cls, audio_list):
|
1087
|
-
return [a._duration for a in audio_list]
|
1088
|
-
|
1089
|
-
@classmethod
|
1090
|
-
def sample_rates(cls, audio_list):
|
1091
|
-
return [a._sample_rate for a in audio_list]
|
1092
|
-
|
1093
|
-
@classmethod
|
1094
|
-
def captions(cls, audio_list):
|
1095
|
-
captions = [a._caption for a in audio_list]
|
1096
|
-
if all(c is None for c in captions):
|
1097
|
-
return False
|
1098
|
-
else:
|
1099
|
-
return ["" if c is None else c for c in captions]
|
1100
|
-
|
1101
|
-
def resolve_ref(self):
|
1102
|
-
if self.path_is_reference(self._path):
|
1103
|
-
# this object was already created using a ref:
|
1104
|
-
return self._path
|
1105
|
-
source_artifact = self._artifact_source.artifact
|
1106
|
-
|
1107
|
-
resolved_name = source_artifact._local_path_to_name(self._path)
|
1108
|
-
if resolved_name is not None:
|
1109
|
-
target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
|
1110
|
-
if target_entry is not None:
|
1111
|
-
return target_entry.ref
|
1112
|
-
|
1113
|
-
return None
|
1114
|
-
|
1115
|
-
def __eq__(self, other):
|
1116
|
-
if self.path_is_reference(self._path) or self.path_is_reference(other._path):
|
1117
|
-
# one or more of these objects is an unresolved reference -- we'll compare
|
1118
|
-
# their reference paths instead of their SHAs:
|
1119
|
-
return (
|
1120
|
-
self.resolve_ref() == other.resolve_ref()
|
1121
|
-
and self._caption == other._caption
|
1122
|
-
)
|
1123
|
-
|
1124
|
-
return super().__eq__(other) and self._caption == other._caption
|
1125
|
-
|
1126
|
-
def __ne__(self, other):
|
1127
|
-
return not self.__eq__(other)
|
1128
|
-
|
1129
|
-
|
1130
|
-
class JoinedTable(Media):
|
1131
|
-
"""Join two tables for visualization in the Artifact UI.
|
1132
|
-
|
1133
|
-
Arguments:
|
1134
|
-
table1 (str, wandb.Table, ArtifactManifestEntry):
|
1135
|
-
the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
|
1136
|
-
table2 (str, wandb.Table):
|
1137
|
-
the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
|
1138
|
-
join_key (str, [str, str]):
|
1139
|
-
key or keys to perform the join
|
1140
|
-
"""
|
1141
|
-
|
1142
|
-
_log_type = "joined-table"
|
1143
|
-
|
1144
|
-
def __init__(self, table1, table2, join_key):
|
1145
|
-
super().__init__()
|
1146
|
-
|
1147
|
-
if not isinstance(join_key, str) and (
|
1148
|
-
not isinstance(join_key, list) or len(join_key) != 2
|
1149
|
-
):
|
1150
|
-
raise ValueError(
|
1151
|
-
"JoinedTable join_key should be a string or a list of two strings"
|
1152
|
-
)
|
1153
|
-
|
1154
|
-
if not self._validate_table_input(table1):
|
1155
|
-
raise ValueError(
|
1156
|
-
"JoinedTable table1 should be an artifact path to a table or wandb.Table object"
|
1157
|
-
)
|
1158
|
-
|
1159
|
-
if not self._validate_table_input(table2):
|
1160
|
-
raise ValueError(
|
1161
|
-
"JoinedTable table2 should be an artifact path to a table or wandb.Table object"
|
1162
|
-
)
|
1163
|
-
|
1164
|
-
self._table1 = table1
|
1165
|
-
self._table2 = table2
|
1166
|
-
self._join_key = join_key
|
1167
|
-
|
1168
|
-
@classmethod
|
1169
|
-
def from_json(cls, json_obj, source_artifact):
|
1170
|
-
t1 = source_artifact.get(json_obj["table1"])
|
1171
|
-
if t1 is None:
|
1172
|
-
t1 = json_obj["table1"]
|
1173
|
-
|
1174
|
-
t2 = source_artifact.get(json_obj["table2"])
|
1175
|
-
if t2 is None:
|
1176
|
-
t2 = json_obj["table2"]
|
1177
|
-
|
1178
|
-
return cls(
|
1179
|
-
t1,
|
1180
|
-
t2,
|
1181
|
-
json_obj["join_key"],
|
1182
|
-
)
|
1183
|
-
|
1184
|
-
@staticmethod
|
1185
|
-
def _validate_table_input(table):
|
1186
|
-
"""Helper method to validate that the table input is one of the 3 supported types."""
|
1187
|
-
return (
|
1188
|
-
(isinstance(table, str) and table.endswith(".table.json"))
|
1189
|
-
or isinstance(table, Table)
|
1190
|
-
or isinstance(table, PartitionedTable)
|
1191
|
-
or (hasattr(table, "ref_url") and table.ref_url().endswith(".table.json"))
|
1192
|
-
)
|
1193
|
-
|
1194
|
-
def _ensure_table_in_artifact(self, table, artifact, table_ndx):
|
1195
|
-
"""Helper method to add the table to the incoming artifact. Returns the path."""
|
1196
|
-
if isinstance(table, Table) or isinstance(table, PartitionedTable):
|
1197
|
-
table_name = f"t{table_ndx}_{str(id(self))}"
|
1198
|
-
if (
|
1199
|
-
table._artifact_source is not None
|
1200
|
-
and table._artifact_source.name is not None
|
1201
|
-
):
|
1202
|
-
table_name = os.path.basename(table._artifact_source.name)
|
1203
|
-
entry = artifact.add(table, table_name)
|
1204
|
-
table = entry.path
|
1205
|
-
# Check if this is an ArtifactManifestEntry
|
1206
|
-
elif hasattr(table, "ref_url"):
|
1207
|
-
# Give the new object a unique, yet deterministic name
|
1208
|
-
name = binascii.hexlify(base64.standard_b64decode(table.digest)).decode(
|
1209
|
-
"ascii"
|
1210
|
-
)[:20]
|
1211
|
-
entry = artifact.add_reference(
|
1212
|
-
table.ref_url(), "{}.{}.json".format(name, table.name.split(".")[-2])
|
1213
|
-
)[0]
|
1214
|
-
table = entry.path
|
1215
|
-
|
1216
|
-
err_str = "JoinedTable table:{} not found in artifact. Add a table to the artifact using Artifact#add(<table>, {}) before adding this JoinedTable"
|
1217
|
-
if table not in artifact._manifest.entries:
|
1218
|
-
raise ValueError(err_str.format(table, table))
|
1219
|
-
|
1220
|
-
return table
|
1221
|
-
|
1222
|
-
def to_json(self, artifact_or_run):
|
1223
|
-
json_obj = {
|
1224
|
-
"_type": JoinedTable._log_type,
|
1225
|
-
}
|
1226
|
-
if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
|
1227
|
-
artifact_entry_url = self._get_artifact_entry_ref_url()
|
1228
|
-
if artifact_entry_url is None:
|
1229
|
-
raise ValueError(
|
1230
|
-
"JoinedTables must first be added to an Artifact before logging to a Run"
|
1231
|
-
)
|
1232
|
-
json_obj["artifact_path"] = artifact_entry_url
|
1233
|
-
else:
|
1234
|
-
table1 = self._ensure_table_in_artifact(self._table1, artifact_or_run, 1)
|
1235
|
-
table2 = self._ensure_table_in_artifact(self._table2, artifact_or_run, 2)
|
1236
|
-
json_obj.update(
|
1237
|
-
{
|
1238
|
-
"table1": table1,
|
1239
|
-
"table2": table2,
|
1240
|
-
"join_key": self._join_key,
|
1241
|
-
}
|
1242
|
-
)
|
1243
|
-
return json_obj
|
1244
|
-
|
1245
|
-
def __ne__(self, other):
|
1246
|
-
return not self.__eq__(other)
|
1247
|
-
|
1248
|
-
def _eq_debug(self, other, should_assert=False):
|
1249
|
-
eq = isinstance(other, JoinedTable)
|
1250
|
-
assert not should_assert or eq, "Found type {}, expected {}".format(
|
1251
|
-
other.__class__, JoinedTable
|
1252
|
-
)
|
1253
|
-
eq = eq and self._join_key == other._join_key
|
1254
|
-
assert not should_assert or eq, "Found {} join key, expected {}".format(
|
1255
|
-
other._join_key, self._join_key
|
1256
|
-
)
|
1257
|
-
eq = eq and self._table1._eq_debug(other._table1, should_assert)
|
1258
|
-
eq = eq and self._table2._eq_debug(other._table2, should_assert)
|
1259
|
-
return eq
|
1260
|
-
|
1261
|
-
def __eq__(self, other):
|
1262
|
-
return self._eq_debug(other, False)
|
1263
|
-
|
1264
|
-
def bind_to_run(self, *args, **kwargs):
|
1265
|
-
raise ValueError("JoinedTables cannot be bound to runs")
|
1266
|
-
|
1267
|
-
|
1268
|
-
class Bokeh(Media):
|
1269
|
-
"""Wandb class for Bokeh plots.
|
1270
|
-
|
1271
|
-
Arguments:
|
1272
|
-
val: Bokeh plot
|
1273
|
-
"""
|
1274
|
-
|
1275
|
-
_log_type = "bokeh-file"
|
1276
|
-
|
1277
|
-
def __init__(self, data_or_path):
|
1278
|
-
super().__init__()
|
1279
|
-
bokeh = util.get_module("bokeh", required=True)
|
1280
|
-
if isinstance(data_or_path, str) and os.path.exists(data_or_path):
|
1281
|
-
with open(data_or_path) as file:
|
1282
|
-
b_json = json.load(file)
|
1283
|
-
self.b_obj = bokeh.document.Document.from_json(b_json)
|
1284
|
-
self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
|
1285
|
-
elif isinstance(data_or_path, bokeh.model.Model):
|
1286
|
-
_data = bokeh.document.Document()
|
1287
|
-
_data.add_root(data_or_path)
|
1288
|
-
# serialize/deserialize pairing followed by sorting attributes ensures
|
1289
|
-
# that the file's sha's are equivalent in subsequent calls
|
1290
|
-
self.b_obj = bokeh.document.Document.from_json(_data.to_json())
|
1291
|
-
b_json = self.b_obj.to_json()
|
1292
|
-
if "references" in b_json["roots"]:
|
1293
|
-
b_json["roots"]["references"].sort(key=lambda x: x["id"])
|
1294
|
-
|
1295
|
-
tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
|
1296
|
-
with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
|
1297
|
-
util.json_dump_safer(b_json, fp)
|
1298
|
-
self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
|
1299
|
-
elif not isinstance(data_or_path, bokeh.document.Document):
|
1300
|
-
raise TypeError(
|
1301
|
-
"Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
|
1302
|
-
)
|
1303
|
-
|
1304
|
-
def get_media_subdir(self):
|
1305
|
-
return os.path.join("media", "bokeh")
|
1306
|
-
|
1307
|
-
def to_json(self, run):
|
1308
|
-
# TODO: (tss) this is getting redundant for all the media objects. We can probably
|
1309
|
-
# pull this into Media#to_json and remove this type override for all the media types.
|
1310
|
-
# There are only a few cases where the type is different between artifacts and runs.
|
1311
|
-
json_dict = super().to_json(run)
|
1312
|
-
json_dict["_type"] = self._log_type
|
1313
|
-
return json_dict
|
1314
|
-
|
1315
|
-
@classmethod
|
1316
|
-
def from_json(cls, json_obj, source_artifact):
|
1317
|
-
return cls(source_artifact.get_entry(json_obj["path"]).download())
|
1318
|
-
|
1319
|
-
|
1320
|
-
def _nest(thing):
|
1321
|
-
# Use tensorflows nest function if available, otherwise just wrap object in an array"""
|
1322
|
-
|
1323
|
-
tfutil = util.get_module("tensorflow.python.util")
|
1324
|
-
if tfutil:
|
1325
|
-
return tfutil.nest.flatten(thing)
|
1326
|
-
else:
|
1327
|
-
return [thing]
|
1328
|
-
|
1329
|
-
|
1330
|
-
-class Graph(Media):
-    """Wandb class for graphs.
-
-    This class is typically used for saving and displaying neural net models. It
-    represents the graph as an array of nodes and edges. The nodes can have
-    labels that can be visualized by wandb.
-
-    Examples:
-        Import a keras model:
-        ```
-            Graph.from_keras(keras_model)
-        ```
-
-    Attributes:
-        format (string): Format to help wandb display the graph nicely.
-        nodes ([wandb.Node]): List of wandb.Nodes
-        nodes_by_id (dict): dict of ids -> nodes
-        edges ([(wandb.Node, wandb.Node)]): List of pairs of nodes interpreted as edges
-        loaded (boolean): Flag to tell whether the graph is completely loaded
-        root (wandb.Node): root node of the graph
-    """
-
-    _log_type = "graph-file"
-
-    def __init__(self, format="keras"):
-        super().__init__()
-        # LB: TODO: I think we should factor criterion and criterion_passed out
-        self.format = format
-        self.nodes = []
-        self.nodes_by_id = {}
-        self.edges = []
-        self.loaded = False
-        self.criterion = None
-        self.criterion_passed = False
-        self.root = None  # optional root Node if applicable
-
-    def _to_graph_json(self, run=None):
-        # Needs to be its own function for tests
-        return {
-            "format": self.format,
-            "nodes": [node.to_json() for node in self.nodes],
-            "edges": [edge.to_json() for edge in self.edges],
-        }
-
-    def bind_to_run(self, *args, **kwargs):
-        data = self._to_graph_json()
-        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
-        data = _numpy_arrays_to_lists(data)
-        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
-            util.json_dump_safer(data, fp)
-        self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
-        if self.is_bound():
-            return
-        super().bind_to_run(*args, **kwargs)
-
-    @classmethod
-    def get_media_subdir(cls):
-        return os.path.join("media", "graph")
-
-    def to_json(self, run):
-        json_dict = super().to_json(run)
-        json_dict["_type"] = self._log_type
-        return json_dict
-
-    def __getitem__(self, nid):
-        return self.nodes_by_id[nid]
-
-    def pprint(self):
-        for edge in self.edges:
-            pprint.pprint(edge.attributes)
-        for node in self.nodes:
-            pprint.pprint(node.attributes)
-
-    def add_node(self, node=None, **node_kwargs):
-        if node is None:
-            node = Node(**node_kwargs)
-        elif node_kwargs:
-            raise ValueError(
-                f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
-            )
-        self.nodes.append(node)
-        self.nodes_by_id[node.id] = node
-
-        return node
-
-    def add_edge(self, from_node, to_node):
-        edge = Edge(from_node, to_node)
-        self.edges.append(edge)
-
-        return edge
-
-    @classmethod
-    def from_keras(cls, model):
-        # TODO: this method requires a refactor to work with keras 3.
-        graph = cls()
-        # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
-        sequential_like = cls._is_sequential(model)
-
-        relevant_nodes = None
-        if not sequential_like:
-            relevant_nodes = []
-            for v in model._nodes_by_depth.values():
-                relevant_nodes += v
-
-        layers = model.layers
-        for i in range(len(layers)):
-            node = Node.from_keras(layers[i])
-            if hasattr(layers[i], "_inbound_nodes"):
-                for in_node in layers[i]._inbound_nodes:
-                    if relevant_nodes and in_node not in relevant_nodes:
-                        # node is not part of the current network
-                        continue
-                    for in_layer in _nest(in_node.inbound_layers):
-                        inbound_keras_node = Node.from_keras(in_layer)
-
-                        if inbound_keras_node.id not in graph.nodes_by_id:
-                            graph.add_node(inbound_keras_node)
-                        inbound_node = graph.nodes_by_id[inbound_keras_node.id]
-
-                        graph.add_edge(inbound_node, node)
-            graph.add_node(node)
-        return graph
-
-    @classmethod
-    def _is_sequential(cls, model):
-        sequential_like = True
-
-        if (
-            model.__class__.__name__ != "Sequential"
-            and hasattr(model, "_is_graph_network")
-            and model._is_graph_network
-        ):
-            nodes_by_depth = model._nodes_by_depth.values()
-            nodes = []
-            for v in nodes_by_depth:
-                # TensorFlow2 doesn't ensure inbound is always a list
-                inbound = v[0].inbound_layers
-                if not hasattr(inbound, "__len__"):
-                    inbound = [inbound]
-                if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
-                    # if the model has multiple nodes
-                    # or if the nodes have multiple inbound_layers
-                    # the model is no longer sequential
-                    sequential_like = False
-                    break
-                nodes += v
-            if sequential_like:
-                # search for shared layers
-                for layer in model.layers:
-                    flag = False
-                    if hasattr(layer, "_inbound_nodes"):
-                        for node in layer._inbound_nodes:
-                            if node in nodes:
-                                if flag:
-                                    sequential_like = False
-                                    break
-                                else:
-                                    flag = True
-                    if not sequential_like:
-                        break
-        return sequential_like
-
-
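The `Graph` class removed above now lives in `wandb/sdk/data_types/graph.py`. It is usually populated via `Graph.from_keras`, but the same API can be driven by hand. A hedged sketch, assuming `wandb.Graph` remains exported at the top level as in 0.18.0; the printed dict is what `bind_to_run` serializes under `media/graph/`:

```python
import wandb

# Build a tiny two-node graph by hand; Graph.from_keras(model) does this for Keras models.
graph = wandb.Graph()
a = graph.add_node(id="dense_1", name="dense_1", class_name="Dense", num_parameters=128)
b = graph.add_node(id="dense_2", name="dense_2", class_name="Dense", num_parameters=64)
graph.add_edge(a, b)

# _to_graph_json() is the payload bind_to_run() writes to <run>/media/graph/*.graph.json.
print(graph._to_graph_json())
# -> {'format': 'keras', 'nodes': [{...'id': 'dense_1'...}, {...}], 'edges': [['dense_1', 'dense_2']]}
```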
-class Node(WBValue):
-    """Node used in `Graph`."""
-
-    def __init__(
-        self,
-        id=None,
-        name=None,
-        class_name=None,
-        size=None,
-        parameters=None,
-        output_shape=None,
-        is_output=None,
-        num_parameters=None,
-        node=None,
-    ):
-        self._attributes = {"name": None}
-        self.in_edges = {}  # indexed by source node id
-        self.out_edges = {}  # indexed by dest node id
-        # optional object (e.g. PyTorch Parameter or Module) that this Node represents
-        self.obj = None
-
-        if node is not None:
-            self._attributes.update(node._attributes)
-            del self._attributes["id"]
-            self.obj = node.obj
-
-        if id is not None:
-            self.id = id
-        if name is not None:
-            self.name = name
-        if class_name is not None:
-            self.class_name = class_name
-        if size is not None:
-            self.size = size
-        if parameters is not None:
-            self.parameters = parameters
-        if output_shape is not None:
-            self.output_shape = output_shape
-        if is_output is not None:
-            self.is_output = is_output
-        if num_parameters is not None:
-            self.num_parameters = num_parameters
-
-    def to_json(self, run=None):
-        return self._attributes
-
-    def __repr__(self):
-        return repr(self._attributes)
-
-    @property
-    def id(self):
-        """Must be unique in the graph."""
-        return self._attributes.get("id")
-
-    @id.setter
-    def id(self, val):
-        self._attributes["id"] = val
-        return val
-
-    @property
-    def name(self):
-        """Usually the type of layer or sublayer."""
-        return self._attributes.get("name")
-
-    @name.setter
-    def name(self, val):
-        self._attributes["name"] = val
-        return val
-
-    @property
-    def class_name(self):
-        """Usually the type of layer or sublayer."""
-        return self._attributes.get("class_name")
-
-    @class_name.setter
-    def class_name(self, val):
-        self._attributes["class_name"] = val
-        return val
-
-    @property
-    def functions(self):
-        return self._attributes.get("functions", [])
-
-    @functions.setter
-    def functions(self, val):
-        self._attributes["functions"] = val
-        return val
-
-    @property
-    def parameters(self):
-        return self._attributes.get("parameters", [])
-
-    @parameters.setter
-    def parameters(self, val):
-        self._attributes["parameters"] = val
-        return val
-
-    @property
-    def size(self):
-        return self._attributes.get("size")
-
-    @size.setter
-    def size(self, val):
-        """Tensor size."""
-        self._attributes["size"] = tuple(val)
-        return val
-
-    @property
-    def output_shape(self):
-        return self._attributes.get("output_shape")
-
-    @output_shape.setter
-    def output_shape(self, val):
-        """Tensor output_shape."""
-        self._attributes["output_shape"] = val
-        return val
-
-    @property
-    def is_output(self):
-        return self._attributes.get("is_output")
-
-    @is_output.setter
-    def is_output(self, val):
-        """Tensor is_output."""
-        self._attributes["is_output"] = val
-        return val
-
-    @property
-    def num_parameters(self):
-        return self._attributes.get("num_parameters")
-
-    @num_parameters.setter
-    def num_parameters(self, val):
-        """Tensor num_parameters."""
-        self._attributes["num_parameters"] = val
-        return val
-
-    @property
-    def child_parameters(self):
-        return self._attributes.get("child_parameters")
-
-    @child_parameters.setter
-    def child_parameters(self, val):
-        """Tensor child_parameters."""
-        self._attributes["child_parameters"] = val
-        return val
-
-    @property
-    def is_constant(self):
-        return self._attributes.get("is_constant")
-
-    @is_constant.setter
-    def is_constant(self, val):
-        """Tensor is_constant."""
-        self._attributes["is_constant"] = val
-        return val
-
-    @classmethod
-    def from_keras(cls, layer):
-        node = cls()
-
-        try:
-            output_shape = layer.output_shape
-        except AttributeError:
-            output_shape = ["multiple"]
-
-        node.id = layer.name
-        node.name = layer.name
-        node.class_name = layer.__class__.__name__
-        node.output_shape = output_shape
-        node.num_parameters = layer.count_params()
-
-        return node
-
-
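Every `Node` property above delegates to a single flat `_attributes` dict, which `to_json` returns untouched. A small sketch of that behavior, again assuming the 0.18.0 top-level `wandb.Graph` export:

```python
import wandb

node = wandb.Graph().add_node(id="conv_1", class_name="Conv2D")

# Property setters write straight into the flat _attributes dict that to_json() returns.
node.size = [3, 3, 1, 32]   # the size setter coerces the value to a tuple
node.num_parameters = 320
print(node.to_json())
# -> {'name': None, 'id': 'conv_1', 'class_name': 'Conv2D', 'size': (3, 3, 1, 32), 'num_parameters': 320}
```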
-class Edge(WBValue):
-    """Edge used in `Graph`."""
-
-    def __init__(self, from_node, to_node):
-        self._attributes = {}
-        self.from_node = from_node
-        self.to_node = to_node
-
-    def __repr__(self):
-        temp_attr = dict(self._attributes)
-        del temp_attr["from_node"]
-        del temp_attr["to_node"]
-        temp_attr["from_id"] = self.from_node.id
-        temp_attr["to_id"] = self.to_node.id
-        return str(temp_attr)
-
-    def to_json(self, run=None):
-        return [self.from_node.id, self.to_node.id]
-
-    @property
-    def name(self):
-        """Optional, not necessarily unique."""
-        return self._attributes.get("name")
-
-    @name.setter
-    def name(self, val):
-        self._attributes["name"] = val
-        return val
-
-    @property
-    def from_node(self):
-        return self._attributes.get("from_node")
-
-    @from_node.setter
-    def from_node(self, val):
-        self._attributes["from_node"] = val
-        return val
-
-    @property
-    def to_node(self):
-        return self._attributes.get("to_node")
-
-    @to_node.setter
-    def to_node(self, val):
-        self._attributes["to_node"] = val
-        return val
-
-
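`Edge.to_json` reduces an edge to the pair of node ids, which is what keeps the serialized graph compact. A brief sketch under the same assumptions as the examples above:

```python
import wandb

graph = wandb.Graph()
src = graph.add_node(id="embedding", class_name="Embedding")
dst = graph.add_node(id="lstm", class_name="LSTM")
edge = graph.add_edge(src, dst)

# An edge serializes to just the two node ids.
print(edge.to_json())  # -> ['embedding', 'lstm']
```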
-# Custom dtypes for typing system
-class _ImageFileType(_dtypes.Type):
-    name = "image-file"
-    legacy_names = ["wandb.Image"]
-    types = [Image]
-
-    def __init__(
-        self,
-        box_layers=None,
-        box_score_keys=None,
-        mask_layers=None,
-        class_map=None,
-        **kwargs,
-    ):
-        box_layers = box_layers or {}
-        box_score_keys = box_score_keys or []
-        mask_layers = mask_layers or {}
-        class_map = class_map or {}
-
-        if isinstance(box_layers, _dtypes.ConstType):
-            box_layers = box_layers._params["val"]
-        if not isinstance(box_layers, dict):
-            raise TypeError("box_layers must be a dict")
-        else:
-            box_layers = _dtypes.ConstType(
-                {layer_key: set(box_layers[layer_key]) for layer_key in box_layers}
-            )
-
-        if isinstance(mask_layers, _dtypes.ConstType):
-            mask_layers = mask_layers._params["val"]
-        if not isinstance(mask_layers, dict):
-            raise TypeError("mask_layers must be a dict")
-        else:
-            mask_layers = _dtypes.ConstType(
-                {layer_key: set(mask_layers[layer_key]) for layer_key in mask_layers}
-            )
-
-        if isinstance(box_score_keys, _dtypes.ConstType):
-            box_score_keys = box_score_keys._params["val"]
-        if not isinstance(box_score_keys, list) and not isinstance(box_score_keys, set):
-            raise TypeError("box_score_keys must be a list or a set")
-        else:
-            box_score_keys = _dtypes.ConstType(set(box_score_keys))
-
-        if isinstance(class_map, _dtypes.ConstType):
-            class_map = class_map._params["val"]
-        if not isinstance(class_map, dict):
-            raise TypeError("class_map must be a dict")
-        else:
-            class_map = _dtypes.ConstType(class_map)
-
-        self.params.update(
-            {
-                "box_layers": box_layers,
-                "box_score_keys": box_score_keys,
-                "mask_layers": mask_layers,
-                "class_map": class_map,
-            }
-        )
-
-    def assign_type(self, wb_type=None):
-        if isinstance(wb_type, _ImageFileType):
-            box_layers_self = self.params["box_layers"].params["val"] or {}
-            box_score_keys_self = self.params["box_score_keys"].params["val"] or []
-            mask_layers_self = self.params["mask_layers"].params["val"] or {}
-            class_map_self = self.params["class_map"].params["val"] or {}
-
-            box_layers_other = wb_type.params["box_layers"].params["val"] or {}
-            box_score_keys_other = wb_type.params["box_score_keys"].params["val"] or []
-            mask_layers_other = wb_type.params["mask_layers"].params["val"] or {}
-            class_map_other = wb_type.params["class_map"].params["val"] or {}
-
-            # Merge the class_ids from each set of box_layers
-            box_layers = {
-                str(key): set(
-                    list(box_layers_self.get(key, []))
-                    + list(box_layers_other.get(key, []))
-                )
-                for key in set(
-                    list(box_layers_self.keys()) + list(box_layers_other.keys())
-                )
-            }
-
-            # Merge the class_ids from each set of mask_layers
-            mask_layers = {
-                str(key): set(
-                    list(mask_layers_self.get(key, []))
-                    + list(mask_layers_other.get(key, []))
-                )
-                for key in set(
-                    list(mask_layers_self.keys()) + list(mask_layers_other.keys())
-                )
-            }
-
-            # Merge the box score keys
-            box_score_keys = set(list(box_score_keys_self) + list(box_score_keys_other))
-
-            # Merge the class_map
-            class_map = {
-                str(key): class_map_self.get(key, class_map_other.get(key, None))
-                for key in set(
-                    list(class_map_self.keys()) + list(class_map_other.keys())
-                )
-            }
-
-            return _ImageFileType(box_layers, box_score_keys, mask_layers, class_map)
-
-        return _dtypes.InvalidType()
-
-    @classmethod
-    def from_obj(cls, py_obj):
-        if not isinstance(py_obj, Image):
-            raise TypeError("py_obj must be a wandb.Image")
-        else:
-            if hasattr(py_obj, "_boxes") and py_obj._boxes:
-                box_layers = {
-                    str(key): set(py_obj._boxes[key]._class_labels.keys())
-                    for key in py_obj._boxes.keys()
-                }
-                box_score_keys = {
-                    key
-                    for val in py_obj._boxes.values()
-                    for box in val._val
-                    for key in box.get("scores", {}).keys()
-                }
-
-            else:
-                box_layers = {}
-                box_score_keys = set()
-
-            if hasattr(py_obj, "_masks") and py_obj._masks:
-                mask_layers = {
-                    str(key): set(
-                        py_obj._masks[key]._val["class_labels"].keys()
-                        if hasattr(py_obj._masks[key], "_val")
-                        else []
-                    )
-                    for key in py_obj._masks.keys()
-                }
-            else:
-                mask_layers = {}
-
-            if hasattr(py_obj, "_classes") and py_obj._classes:
-                class_set = {
-                    str(item["id"]): item["name"] for item in py_obj._classes._class_set
-                }
-            else:
-                class_set = {}
-
-            return cls(box_layers, box_score_keys, mask_layers, class_set)
-
-
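`_ImageFileType.assign_type` merges box layers, score keys, mask layers, and class maps by union, so an image column's type widens as rows with new annotation layers are assigned to it. A hedged sketch of that merge, run against 0.18.0 where the class still lives in `wandb.data_types` (in 0.18.1 it moves with `Image` into the SDK data-types package):

```python
# Runs against wandb 0.18.0, where _ImageFileType still lives in wandb.data_types.
from wandb.data_types import _ImageFileType

seen = _ImageFileType(box_layers={"predictions": {1, 2}}, box_score_keys=["acc"])
new = _ImageFileType(box_layers={"predictions": {3}, "ground_truth": {1}}, box_score_keys=["iou"])

merged = seen.assign_type(new)
# Layers and score keys are unioned rather than replaced, so previously seen
# annotations stay valid as more images are assigned to the same column.
print(merged.params["box_layers"].params["val"])      # {'predictions': {1, 2, 3}, 'ground_truth': {1}} (ordering may vary)
print(merged.params["box_score_keys"].params["val"])  # {'acc', 'iou'}
```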
-class _TableType(_dtypes.Type):
-    name = "table"
-    legacy_names = ["wandb.Table"]
-    types = [Table]
-
-    def __init__(self, column_types=None):
-        if column_types is None:
-            column_types = _dtypes.UnknownType()
-        if isinstance(column_types, dict):
-            column_types = _dtypes.TypedDictType(column_types)
-        elif not (
-            isinstance(column_types, _dtypes.TypedDictType)
-            or isinstance(column_types, _dtypes.UnknownType)
-        ):
-            raise TypeError("column_types must be a dict or TypedDictType")
-
-        self.params.update({"column_types": column_types})
-
-    def assign_type(self, wb_type=None):
-        if isinstance(wb_type, _TableType):
-            column_types = self.params["column_types"].assign_type(
-                wb_type.params["column_types"]
-            )
-            if not isinstance(column_types, _dtypes.InvalidType):
-                return _TableType(column_types)
-
-        return _dtypes.InvalidType()
-
-    @classmethod
-    def from_obj(cls, py_obj):
-        if not isinstance(py_obj, Table):
-            raise TypeError("py_obj must be a wandb.Table")
-        else:
-            return cls(py_obj._column_types)
-
-
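`_TableType` wraps the per-column types in a `TypedDictType`, and `assign_type` succeeds only when every column's type is compatible; this is the check behind the type errors raised when mismatched rows are added to a `wandb.Table`. A sketch against the 0.18.0 module layout:

```python
# Runs against wandb 0.18.0, where _TableType still lives in wandb.data_types.
from wandb.data_types import _TableType
from wandb.sdk.data_types import _dtypes

numbers = _TableType({"loss": _dtypes.NumberType(), "note": _dtypes.StringType()})
same = _TableType({"loss": _dtypes.NumberType(), "note": _dtypes.StringType()})
other = _TableType({"loss": _dtypes.StringType(), "note": _dtypes.StringType()})

print(type(numbers.assign_type(same)).__name__)   # _TableType  (column types are compatible)
print(type(numbers.assign_type(other)).__name__)  # InvalidType (loss: number vs. string)
```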
-class _ForeignKeyType(_dtypes.Type):
-    name = "foreignKey"
-    legacy_names = ["wandb.TableForeignKey"]
-    types = [_TableKey]
-
-    def __init__(self, table, col_name):
-        assert isinstance(table, Table)
-        assert isinstance(col_name, str)
-        assert col_name in table.columns
-        self.params.update({"table": table, "col_name": col_name})
-
-    def assign_type(self, wb_type=None):
-        if isinstance(wb_type, _dtypes.StringType):
-            return self
-        elif (
-            isinstance(wb_type, _ForeignKeyType)
-            and id(self.params["table"]) == id(wb_type.params["table"])
-            and self.params["col_name"] == wb_type.params["col_name"]
-        ):
-            return self
-
-        return _dtypes.InvalidType()
-
-    @classmethod
-    def from_obj(cls, py_obj):
-        if not isinstance(py_obj, _TableKey):
-            raise TypeError("py_obj must be a _TableKey")
-        else:
-            return cls(py_obj._table, py_obj._col_name)
-
-    def to_json(self, artifact=None):
-        res = super().to_json(artifact)
-        if artifact is not None:
-            table_name = f"media/tables/t_{runid.generate_id()}"
-            entry = artifact.add(self.params["table"], table_name)
-            res["params"]["table"] = entry.path
-        else:
-            raise AssertionError(
-                "_ForeignKeyType does not support serialization without an artifact"
-            )
-        return res
-
-    @classmethod
-    def from_json(
-        cls,
-        json_dict,
-        artifact,
-    ):
-        table = None
-        col_name = None
-        if artifact is None:
-            raise AssertionError(
-                "_ForeignKeyType does not support deserialization without an artifact"
-            )
-        else:
-            table = artifact.get(json_dict["params"]["table"])
-            col_name = json_dict["params"]["col_name"]
-
-        if table is None:
-            raise AssertionError("Unable to deserialize referenced table")
-
-        return cls(table, col_name)
-
-
-class _ForeignIndexType(_dtypes.Type):
-    name = "foreignIndex"
-    legacy_names = ["wandb.TableForeignIndex"]
-    types = [_TableIndex]
-
-    def __init__(self, table):
-        assert isinstance(table, Table)
-        self.params.update({"table": table})
-
-    def assign_type(self, wb_type=None):
-        if isinstance(wb_type, _dtypes.NumberType):
-            return self
-        elif isinstance(wb_type, _ForeignIndexType) and id(self.params["table"]) == id(
-            wb_type.params["table"]
-        ):
-            return self
-
-        return _dtypes.InvalidType()
-
-    @classmethod
-    def from_obj(cls, py_obj):
-        if not isinstance(py_obj, _TableIndex):
-            raise TypeError("py_obj must be a _TableIndex")
-        else:
-            return cls(py_obj._table)
-
-    def to_json(self, artifact=None):
-        res = super().to_json(artifact)
-        if artifact is not None:
-            table_name = f"media/tables/t_{runid.generate_id()}"
-            entry = artifact.add(self.params["table"], table_name)
-            res["params"]["table"] = entry.path
-        else:
-            raise AssertionError(
-                "_ForeignIndexType does not support serialization without an artifact"
-            )
-        return res
-
-    @classmethod
-    def from_json(
-        cls,
-        json_dict,
-        artifact,
-    ):
-        table = None
-        if artifact is None:
-            raise AssertionError(
-                "_ForeignIndexType does not support deserialization without an artifact"
-            )
-        else:
-            table = artifact.get(json_dict["params"]["table"])
-
-        if table is None:
-            raise AssertionError("Unable to deserialize referenced table")
-
-        return cls(table)
-
-
-class _PrimaryKeyType(_dtypes.Type):
-    name = "primaryKey"
-    legacy_names = ["wandb.TablePrimaryKey"]
-
-    def assign_type(self, wb_type=None):
-        if isinstance(wb_type, _dtypes.StringType) or isinstance(
-            wb_type, _PrimaryKeyType
-        ):
-            return self
-        return _dtypes.InvalidType()
-
-    @classmethod
-    def from_obj(cls, py_obj):
-        if not isinstance(py_obj, _TableKey):
-            raise TypeError("py_obj must be a wandb.Table")
-        else:
-            return cls()
-
-
-class _AudioFileType(_dtypes.Type):
-    name = "audio-file"
-    types = [Audio]
-
-
-class _BokehFileType(_dtypes.Type):
-    name = "bokeh-file"
-    types = [Bokeh]
-
-
-class _JoinedTableType(_dtypes.Type):
-    name = "joined-table"
-    types = [JoinedTable]
-
-
-class _PartitionedTableType(_dtypes.Type):
-    name = "partitioned-table"
-    types = [PartitionedTable]
-
-
-_dtypes.TypeRegistry.add(_AudioFileType)
-_dtypes.TypeRegistry.add(_BokehFileType)
-_dtypes.TypeRegistry.add(_ImageFileType)
-_dtypes.TypeRegistry.add(_TableType)
-_dtypes.TypeRegistry.add(_JoinedTableType)
-_dtypes.TypeRegistry.add(_PartitionedTableType)
-_dtypes.TypeRegistry.add(_ForeignKeyType)
-_dtypes.TypeRegistry.add(_PrimaryKeyType)
-_dtypes.TypeRegistry.add(_ForeignIndexType)