sentimentizer 0.6.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,20 @@
1
+ Copyright (c) 2018-Present Edward Yang
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining
4
+ a copy of this software and associated documentation files (the
5
+ "Software"), to deal in the Software without restriction, including
6
+ without limitation the rights to use, copy, modify, merge, publish,
7
+ distribute, sublicense, and/or sell copies of the Software, and to
8
+ permit persons to whom the Software is furnished to do so, subject to
9
+ the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be
12
+ included in all copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,3 @@
1
+ include README.md
2
+ include sentimentizer/data/yelp.dictionary
3
+ include sentimentizer/data/weights.pth
@@ -0,0 +1,80 @@
1
+ Metadata-Version: 2.1
2
+ Name: sentimentizer
3
+ Version: 0.6.5
4
+ Summary: straight forward rnn model
5
+ Author-email: Edward Yang <edwardpyang@gmail.com>
6
+ License: Copyright (c) 2018-Present Edward Yang
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining
9
+ a copy of this software and associated documentation files (the
10
+ "Software"), to deal in the Software without restriction, including
11
+ without limitation the rights to use, copy, modify, merge, publish,
12
+ distribute, sublicense, and/or sell copies of the Software, and to
13
+ permit persons to whom the Software is furnished to do so, subject to
14
+ the following conditions:
15
+
16
+ The above copyright notice and this permission notice shall be
17
+ included in all copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
20
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
21
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
22
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
23
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
24
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
25
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
26
+
27
+ Keywords: rnn,pytorch,nlp,sentiment
28
+ Classifier: License :: OSI Approved :: MIT License
29
+ Classifier: Programming Language :: Python
30
+ Requires-Python: >=3.10
31
+ Description-Content-Type: text/markdown
32
+ Provides-Extra: dev
33
+ License-File: LICENSE
34
+
35
+ # Introduction
36
+
37
+ [![PyPI Latest Release](https://img.shields.io/pypi/v/sentimentizer.svg)](https://pypi.org/project/sentimentizer/)
38
+ ![GitHub CI](https://github.com/eddiepyang/sentimentizer/actions/workflows/ci.yaml/badge.svg)
39
+
40
+ Beta release, api subject to change. Install with:
41
+
42
+ ```
43
+ pip install sentimentizer
44
+ ```
45
+
46
+ This repo contains Neural Nets written with the pytorch framework for sentiment analysis.
47
+ A LSTM based torch model can be found in the rnn folder. In spite of large language models (GPT3.5 as of 2023)
48
+ dominating the conversation, small models can be pretty effective and are nice to learn from. This model focuses on sentiment analysis and was trained on
49
+ a single gpu in minutes and requires less than 1GB of memory.
50
+
51
+
52
+ ## Usage
53
+ ```
54
+ # where 0 is very negative and 1 is very positive
55
+ from sentimentizer.tokenizer import get_trained_tokenizer
56
+ from sentimentizer.rnn.model import get_trained_model
57
+
58
+ model = get_trained_model(64, 'cpu')
59
+ tokenizer = get_trained_tokenizer()
60
+ review_text = "greatest pie ever, best in town!"
61
+ positive_ids = tokenizer.tokenize_text(review_text)
62
+ model.predict(positive_ids)
63
+
64
+ >> tensor(0.9701)
65
+ ```
66
+
67
+ ## Install for development with miniconda:
68
+ ```
69
+ conda create -n {env}
70
+ conda install pip
71
+ pip install -e .
72
+ ```
73
+
74
+ ## Retrain model
75
+ To rerun the model:
76
+ * get the yelp [dataset](https://www.yelp.com/dataset),
77
+ * get the glove 6B 100D [dataset](https://nlp.stanford.edu/projects/glove/)
78
+ * place both files in the package data directory
79
+ * run the training script in workflows
80
+
@@ -0,0 +1,46 @@
1
+ # Introduction
2
+
3
+ [![PyPI Latest Release](https://img.shields.io/pypi/v/sentimentizer.svg)](https://pypi.org/project/sentimentizer/)
4
+ ![GitHub CI](https://github.com/eddiepyang/sentimentizer/actions/workflows/ci.yaml/badge.svg)
5
+
6
+ Beta release; the API is subject to change. Install with:
7
+
8
+ ```
9
+ pip install sentimentizer
10
+ ```
11
+
12
+ This repo contains neural networks written with the PyTorch framework for sentiment analysis.
13
+ An LSTM-based torch model can be found in the rnn folder. Although large language models (GPT-3.5 as of 2023)
14
+ dominate the conversation, small models can be pretty effective and are nice to learn from. This model focuses on sentiment analysis, was trained on
15
+ a single GPU in minutes, and requires less than 1GB of memory.
16
+
17
+
18
+ ## Usage
19
+ ```
20
+ # where 0 is very negative and 1 is very positive
21
+ from sentimentizer.tokenizer import get_trained_tokenizer
22
+ from sentimentizer.rnn.model import get_trained_model
23
+
24
+ model = get_trained_model(64, 'cpu')
25
+ tokenizer = get_trained_tokenizer()
26
+ review_text = "greatest pie ever, best in town!"
27
+ positive_ids = tokenizer.tokenize_text(review_text)
28
+ model.predict(positive_ids)
29
+
30
+ >> tensor(0.9701)
31
+ ```
32
+
33
+ ## Install for development with miniconda:
34
+ ```
35
+ conda create -n {env}
36
+ conda install pip
37
+ pip install -e .
38
+ ```
39
+
40
+ ## Retrain model
41
+ To rerun the model:
42
+ * get the yelp [dataset](https://www.yelp.com/dataset)
43
+ * get the glove 6B 100D [dataset](https://nlp.stanford.edu/projects/glove/)
44
+ * place both files in the package data directory
45
+ * run the training script in workflows
46
+
@@ -0,0 +1,34 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0.1", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sentimentizer"
7
+ version = "0.6.5"
8
+ description = "straight forward rnn model"
9
+ readme = "README.md"
10
+ authors = [{ name = "Edward Yang", email = "edwardpyang@gmail.com" }]
11
+ license = { file = "LICENSE" }
12
+ classifiers = [
13
+ "License :: OSI Approved :: MIT License",
14
+ "Programming Language :: Python"
15
+ ]
16
+
17
+ keywords = ["rnn", "pytorch", "nlp", "sentiment"]
18
+ dependencies = [
19
+ "numpy >= 1.24.0",
20
+ "polars >= 0.15.16",
21
+ "pandas >= 1.5.2",
22
+ "torch >= 1.13.1",
23
+ "orjson >= 3.8.5",
24
+ "pyarrow >= 10.0.1",
25
+ "attrs >= 22.2.0",
26
+ "scikit-learn",
27
+ "gensim",
28
+ "structlog",
29
+ "psutil",
30
+ ]
31
+ requires-python = ">=3.10"
32
+
33
+ [project.optional-dependencies]
34
+ dev = ["black", "bumpver", "isort", "pip-tools", "pytest"]
@@ -0,0 +1,57 @@
1
+ from functools import wraps
2
+ import logging
3
+ from pathlib import Path
4
+ import sys
5
+ import time
6
+ from typing import TextIO
7
+
8
+ import psutil
9
+ import structlog
10
+
11
+
12
# Path of this module, and the absolute directory two levels above it.
# The config module joins `root` with the package name to locate data files.
file_path = Path(__file__)
_package_dir = file_path.parent
root = _package_dir.parent.absolute()
14
+
15
+
16
def new_logger(level: int = 20, output: TextIO = sys.stderr) -> structlog.PrintLogger:
    """
    configures structlog globally and returns a logger named after this module

    level: minimum numeric level that is emitted (20 is logging.INFO)
    output: stream the rendered JSON lines are printed to
    """
    # Order matters: context and level are merged in, exceptions formatted,
    # the timestamp added, and only then is the event rendered to JSON.
    processor_chain = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.format_exc_info,
        structlog.processors.TimeStamper(fmt="iso", utc=True),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=processor_chain,
        wrapper_class=structlog.make_filtering_bound_logger(level),
        logger_factory=structlog.PrintLoggerFactory(file=output),
        cache_logger_on_first_use=True,
    )
    return structlog.getLogger(__name__)
33
+
34
+
35
+ logger = new_logger(logging.INFO)
36
+
37
+
38
def time_decorator(func):
    """
    logs run time and system memory stats each time the wrapped function returns

    Nothing is logged when the wrapped function raises. For a function that
    returns a generator, only the call itself is timed, not the iteration.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - started

        # One snapshot, so the three memory figures describe the same moment.
        memory = psutil.virtual_memory()
        gib = 1024**3
        logger.info(
            "function completed successfully",
            function=func.__name__,
            run_time=f"{elapsed: 2.4f} seconds",
            available_memory=f"{memory.available/gib: .2f} GBs",
            free_memory=f"{memory.free/gib: .2f} GBs",
            used_memory=f"{memory.used/gib: .2f} GBs",
        )
        return result

    return wrapper
@@ -0,0 +1,91 @@
1
+ import os
2
+ import enum
3
+
4
+ from dataclasses import dataclass
5
+ from logging import NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL
6
+
7
+ from typing import Tuple, Callable
8
+ from sentimentizer import root
9
+
10
# Package directory under the project root; the data-file defaults in this
# module are built from it.
data_path = os.path.join(root, "sentimentizer")

DEFAULT_LOG_LEVEL = INFO

BATCH_SIZE = 100_000
WRITE_BYTES = "wb"
READ_BYTES = "rb"
TEXT_COLUMN = "text"

# Device names accepted when loading a model.
Devices = {"cpu", "cuda", "mps"}
20
+
21
class Device(enum.Enum):
    """Compute device identifiers, numbered from zero in declaration order."""

    CPU = 0
    CUDA = 1
    MPS = 2
25
+
26
class FitModes(enum.Enum):
    """Fit modes, numbered from zero in declaration order."""

    fitting = 0
    training = 1
    evaluation = 2
30
+
31
+
32
@dataclass
class OptimizationParams:
    """
    optimizer hyperparameters

    The field names match the arguments of torch's Adam-family optimizers;
    presumably they are forwarded to one -- confirm against the trainer.
    """

    lr: float = 0.005
    betas: Tuple[float, float] = (0.7, 0.99)
    weight_decay: float = 1e-4
37
+
38
+
39
@dataclass
class SchedulerParams:
    """
    learning-rate scheduler settings

    The field names mirror the arguments of torch's CosineAnnealingLR;
    presumably that is the consumer -- confirm against the trainer.
    """

    T_max: int = 100
    # A minimum learning rate is a real number, so the field is typed as float.
    eta_min: float = 0.0
    last_epoch: int = -1
44
+
45
+
46
@dataclass(frozen=True)
class TokenizerConfig:
    """
    immutable tokenizer settings

    The grouping comments below are inferred from field names and defaults
    only -- confirm each against the tokenizer module before relying on it.
    """

    # column names in the raw data and in the processed output (assumed)
    text_col: str = "text"
    label_col: str = "stars"
    inputs: str = "data"
    labels: str = "target"
    # record limit and token-sequence length (assumed)
    stop: int = 10000
    max_len: int = 200
    # dictionary filtering thresholds (assumed)
    dict_min: int = 3
    dict_keep: int = 20000
    no_above: float = 0.99999
    save_dictionary: bool = True
58
+
59
+
60
@dataclass(frozen=True)
class FileConfig:
    """
    immutable default file locations

    Every default except raw_file_path is rooted at "<data_path>/data".
    """

    archive_file_path: str = f"{data_path}/data/archive.zip"
    # a bare file name; presumably the member read from the archive -- confirm
    raw_file_path: str = "yelp_academic_dataset_review.json"
    dictionary_file_path: str = f"{data_path}/data/yelp.dictionary"
    raw_reviews_file_path: str = f"{data_path}/data/review_data.arrow"
    processed_reviews_file_path: str = f"{data_path}/data/review_data.parquet"
    weights_file_path: str = f"{data_path}/data/weights.pth"
68
+
69
+
70
@dataclass
class TrainerConfig:
    """
    training run settings

    Field meanings are inferred from their names -- confirm against the trainer.
    """

    batch_size: int = 64
    epochs: int = 4
    # presumably the data-loader worker count -- confirm
    workers: int = 10
    device: str = "cuda"
    # presumably toggles pinned memory in the data loader -- confirm
    memory: bool = True
77
+
78
+
79
@dataclass
class EmbeddingsConfig:
    """
    location and size of the pre-trained glove vectors

    file_path: zip archive holding the vectors
    sub_file_path: member of that archive that is read
    emb_length: length of one embedding vector
    """

    file_path: str = f"{data_path}/data/glove.6B.zip"
    sub_file_path: str = "glove.6B.100d.txt"
    emb_length: int = 100
84
+
85
+
86
@dataclass
class DriverConfig:
    """
    bundles the individual config classes

    Each default is the config class itself, not an instance; callers
    call the attribute to build a config object.
    """

    files: Callable = FileConfig
    embeddings: Callable = EmbeddingsConfig
    tokenizer: Callable = TokenizerConfig
    trainer: Callable = TrainerConfig
File without changes
@@ -0,0 +1,122 @@
1
+ import zipfile
2
+ from itertools import islice
3
+ from typing import IO, Generator
4
+
5
+ import numpy as np
6
+ import orjson as json
7
+ import pyarrow as pa
8
+ from gensim import corpora
9
+
10
+ from sentimentizer import new_logger, time_decorator
11
+ from sentimentizer.config import DEFAULT_LOG_LEVEL, EmbeddingsConfig
12
+ from sentimentizer.tokenizer import tokenize
13
+
14
+ logger = new_logger(DEFAULT_LOG_LEVEL)
15
+
16
+ BATCH_SIZE = 100000
17
+ WRITE_BYTES = "wb"
18
+
19
+
20
def generate_batch(
    generator_input: Generator[dict, None, None], iter_size: int
) -> Generator[tuple[list[dict], int, int], None, None]:
    """
    groups records from generator_input into lists of at most BATCH_SIZE items

    generator_input: iterable of record dicts, consumed lazily
    iter_size: maximum number of records to take in total

    yields (records, start, end), where start and end are the half-open
    positions of the batch within the stream. Iteration stops as soon as the
    input is exhausted, so no empty batch is ever yielded.
    """
    # iter() makes plain iterables (e.g. lists) advance between batches too.
    records_iter = iter(generator_input)

    for start in range(0, iter_size, BATCH_SIZE):
        # Cap the final batch so that no more than iter_size records are read.
        wanted = min(BATCH_SIZE, iter_size - start)
        review_dicts = list(islice(records_iter, wanted))
        if not review_dicts:
            return
        yield review_dicts, start, start + len(review_dicts)
29
+
30
+
31
def process_json(json_file: IO[bytes], stop: int = 0) -> Generator:
    """
    parses one json document per line and tokenizes its "text" field

    json_file: iterable of lines, each holding one json document
    stop: number of leading lines to process; 0 means no limit

    yields each parsed dict with "text" replaced by tokenize(text)
    """
    for i, line in enumerate(json_file):
        # Check the limit first so the line past it is never parsed or tokenized.
        if stop != 0 and i >= stop:
            break
        if i % 100000 == 0:
            logger.debug(f"processing line {i}")
        dc = json.loads(line)
        dc["text"] = tokenize(dc.get("text"))
        yield dc
40
+
41
+
42
@time_decorator
def extract_data(file_path: str, compressed_file_name: str, stop: int = 0) -> Generator:
    """
    reads from zipped yelp data file

    file_path: path of the zip archive
    compressed_file_name: member of the archive to read
    stop: forwarded to process_json; 0 means no limit

    Returns a lazy generator of processed records. The archive and the member
    are opened right away, so a missing file or member fails at call time.
    Both stay open for as long as the generator is being consumed and are
    closed when it finishes or is closed.
    """
    zfile = zipfile.ZipFile(file_path)
    try:
        inf = zfile.open(compressed_file_name)
    except Exception:
        zfile.close()
        raise

    def _records() -> Generator:
        # The handles are owned by the generator, so they outlive this call.
        with zfile, inf:
            yield from process_json(inf, stop)

    return _records()
49
+
50
+
51
def write_arrow(
    generator_input: Generator,
    iter_size: int,
    write_path: str,
    schema: pa.Schema = None,
) -> None:
    """
    writes the records from generator_input to an arrow ipc file in batches

    generator_input: iterable of record dicts
    iter_size: maximum number of records to write
    write_path: destination file
    schema: schema of the file; inferred from the first batch when None

    Every batch is converted with the writer's schema so that batches whose
    inferred types would differ still match the file. A batch that cannot be
    converted or written is skipped and logged as a warning.
    """
    batches = generate_batch(generator_input, iter_size)

    first_batch = None
    if schema is None:
        # Infer the file schema from the first batch; it is written below.
        first_records, _, _ = next(batches)
        first_batch = pa.RecordBatch.from_pylist(first_records)
        schema = first_batch.schema

    with pa.OSFile(write_path, WRITE_BYTES) as sink, pa.ipc.RecordBatchFileWriter(
        sink, schema
    ) as writer:
        if first_batch is not None:
            writer.write(first_batch)

        for records, _, end in batches:
            if not records:
                # An empty batch means the input ran out before iter_size.
                logger.info(f"file completed, last item count was {end}")
                break
            try:
                writer.write(pa.RecordBatch.from_pylist(records, schema=schema))
            except (pa.ArrowInvalid, pa.ArrowTypeError) as err:
                logger.warning(
                    "skipping batch that could not be written",
                    end=end,
                    error=str(err),
                )
77
+
78
+
79
@time_decorator
def extract_embeddings(
    dictionary: corpora.Dictionary, cfg: EmbeddingsConfig
) -> dict[str, np.ndarray]:
    """
    load glove vectors for the tokens that appear in the dictionary

    dictionary: gensim dictionary whose token2id mapping selects the tokens
    cfg: names the zip archive and the member holding the vectors

    Returns a dict keyed by token id + 1 (ints, despite the annotation); the
    first vector seen for a token is the one that is kept.
    """
    vectors: dict = {}

    with zipfile.ZipFile(cfg.file_path, "r") as archive:
        with archive.open(cfg.sub_file_path, "r") as glove_file:
            for raw_line in glove_file:
                fields = raw_line.split()
                # first field is the token, the rest are its coefficients
                word = fields[0].decode()
                if word not in dictionary.token2id:
                    continue
                vectors.setdefault(
                    dictionary.token2id[word] + 1,
                    np.asarray(fields[1:], dtype=np.float32),
                )

    return vectors
99
+
100
+
101
@time_decorator
def new_embedding_weights(
    dictionary: corpora.Dictionary, cfg: EmbeddingsConfig
) -> np.ndarray:
    """
    converts local dictionary to embeddings from glove

    dictionary: gensim dictionary of the corpus
    cfg: glove location and embedding length

    Returns a matrix with one row per token plus two extra rows: row 0 is all
    zeros, the last row is random, and the row for a token sits at its
    token id + 1. Tokens without a glove vector get a random normal vector.
    This alignment assumes token ids are contiguous from 0 -- TODO confirm.
    """
    embeddings_dict: dict = extract_embeddings(dictionary, cfg)

    # embeddings_dict is keyed by token id + 1, so look tokens up by that key.
    for token_id in dictionary.token2id.values():
        row = token_id + 1
        if row not in embeddings_dict:
            embeddings_dict[row] = np.random.normal(0, 0.32, cfg.emb_length)

    # Stack in key order so the row index matches token id + 1; insertion
    # order would follow the glove file instead.
    ordered_rows = [embeddings_dict[row] for row in sorted(embeddings_dict)]

    return np.vstack(
        [
            np.zeros(cfg.emb_length),
            *ordered_rows,
            np.random.randn(cfg.emb_length),
        ]
    )
@@ -0,0 +1,42 @@
1
+ from attr import define
2
+ import pandas as pd
3
+
4
+ import torch
5
+ from typing import Tuple
6
+ from torch.utils.data import Dataset
7
+ from sklearn.model_selection import train_test_split
8
+ from sentimentizer import new_logger
9
+ from sentimentizer.config import DEFAULT_LOG_LEVEL
10
+
11
+ logger = new_logger(DEFAULT_LOG_LEVEL)
12
+
13
+
14
@define
class CorpusDataset(Dataset):
    """
    Dataset class required for pytorch to output items by index

    data: frame with one example per row
    x_labels: name of the column holding the model inputs
    y_labels: name of the column holding the targets
    """

    data: pd.DataFrame
    x_labels: str = "data"
    y_labels: str = "target"

    def __attrs_pre_init__(self):
        # attrs only calls a hook with exactly this name; it runs before the
        # fields are assigned and is where the base class gets initialised.
        super().__init__()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """returns the (input, target) tensors of row i"""
        features = torch.tensor(self.data[self.x_labels].iat[i])
        target = torch.tensor(self.data[self.y_labels].iat[i])
        return features, target
33
+
34
+
35
def load_train_val_corpus_datasets(
    data_path: str, test_size=0.2
) -> Tuple[CorpusDataset, CorpusDataset]:
    """
    reads the parquet file at data_path and splits it into two datasets

    data_path: parquet file to read
    test_size: forwarded to train_test_split for the validation part

    Returns (training dataset, validation dataset).
    """
    # The full frame is only a temporary; nothing keeps it after the split.
    train_df, val_df = train_test_split(
        pd.read_parquet(data_path), test_size=test_size
    )
    return CorpusDataset(data=train_df), CorpusDataset(data=val_df)
File without changes
@@ -0,0 +1,116 @@
1
+ from importlib.resources import files
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from gensim import corpora
7
+ from torch import nn
8
+
9
+ from sentimentizer import new_logger
10
+ from sentimentizer.config import DEFAULT_LOG_LEVEL, EmbeddingsConfig, Devices
11
+ from sentimentizer.extractor import new_embedding_weights
12
+
13
+ logger = new_logger(DEFAULT_LOG_LEVEL)
14
+
15
+
16
class Decoder(nn.Module):
    """
    model class: embedding -> linear -> transformer encoder -> two linear heads

    The transformer stack is built from nn.TransformerEncoderLayer and
    nn.TransformerEncoder: the `encoder_layer` keyword and the single-input
    call in forward() are that class's interface, not nn.TransformerDecoder's.
    """

    def __init__(
        self,
        batch_size: int,
        input_len: int,
        d_model: int,
        n_heads: int,
        emb_weights: torch.Tensor,  # weights are vocabsize x embedding length
        verbose: bool = True,
        dropout: float = 0.2,
    ):
        """
        batch_size: stored on the instance only
        input_len: size of the last dimension that fc1 reduces to 1
        d_model: feature size of the transformer layer; forward() feeds it the
            permuted tensor, so it has to equal that tensor's last dimension
        n_heads: number of attention heads
        emb_weights: initial embedding matrix, copied in by load_weights()
        verbose: stored on the instance only
        dropout: probability used by the dropout layer
        """
        super().__init__()
        # vocab size in, hidden size out
        self.batch_size = batch_size
        self.emb_weights = emb_weights

        self.embed_layer = nn.Embedding(emb_weights.shape[0], emb_weights.shape[1])
        self.fc0 = nn.Linear(emb_weights.shape[1], emb_weights.shape[1])

        self.dropout = dropout
        self.dropout_layer = nn.Dropout1d(p=self.dropout, inplace=True)
        # https://pytorch.org/docs/stable/nn.html

        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads)
        layer_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=1, norm=layer_norm
        )

        self.fc1 = nn.Linear(input_len, 1)
        self.fc2 = nn.Linear(emb_weights.shape[1], 1)
        self.verbose = verbose

    def load_weights(self):
        """copies emb_weights into the embedding layer and returns self"""
        self.embed_layer.load_state_dict({"weight": self.emb_weights})  # type: ignore
        return self

    def forward(self, inputs: torch.Tensor):
        """
        returns the squeezed, un-normalised score for inputs

        assumes inputs holds token ids shaped (batch, input_len) -- TODO confirm
        """
        embeds = self.embed_layer(inputs)
        # The layer works in place; binding the result just makes the step explicit.
        embeds = self.dropout_layer(embeds)

        logger.debug("embedding shape %s" % (embeds.shape,))
        embeds = F.relu(self.fc0(embeds))
        # swap the last two axes so the transformer and fc1 see input_len last
        encoded_out = self.encoder(embeds.permute(0, 2, 1))

        logger.debug("lstm out shape %s" % (encoded_out.shape,))
        out = self.fc1(encoded_out)
        logger.debug("fc1 out shape %s" % (out.shape,))
        fout = self.fc2(out.permute(0, 2, 1))
        logger.debug("final %s" % (fout.shape,))

        return torch.squeeze(fout)

    def predict(self, converted_text: np.ndarray) -> torch.Tensor:
        """switches to eval mode and returns sigmoid(forward(converted_text))"""
        with torch.no_grad():
            self.eval()
            output = torch.from_numpy(converted_text)
            return torch.sigmoid(self.forward(output))


# The factory functions in this module construct the model under this name.
Encoder = Decoder
77
+
78
+
79
def new_model(
    dict_path: str, embeddings_config: EmbeddingsConfig, batch_size: int, input_len: int
) -> Decoder:
    """
    builds an untrained model whose embedding layer is initialised from glove

    dict_path: path of the saved gensim dictionary
    embeddings_config: glove location and embedding length
    batch_size: forwarded to the model
    input_len: forwarded to the model
    """
    dict_yelp = corpora.Dictionary.load(dict_path)
    embedding_matrix = new_embedding_weights(dict_yelp, embeddings_config)
    emb_t = torch.from_numpy(embedding_matrix)
    # Decoder is the model class defined in this module.
    # NOTE(review): d_model is fixed at 200 while input_len is a parameter;
    # the model applies the transformer to the input_len axis, so other
    # lengths look unsupported -- confirm.
    model = Decoder(
        batch_size=batch_size,
        d_model=200,
        n_heads=4,
        input_len=input_len,
        emb_weights=emb_t,
    )
    model.load_weights()
    return model
94
+
95
+
96
def get_trained_model(batch_size: int, device: str) -> Decoder:
    """
    loads pre-trained model

    batch_size: forwarded to the model
    device: one of the names in Devices; used as map_location for the weights

    Raises ValueError when device is not a known device name.
    """
    if device not in Devices:
        raise ValueError("device must be cpu, cuda, or mps")

    # NOTE(review): confirm this file name is shipped with the package data.
    weights = torch.load(
        str(files("sentimentizer.data").joinpath("embed_weights.pth")),
        map_location=torch.device(device=device),
    )
    # Placeholder of the right shape; load_state_dict below overwrites it.
    empty_embeddings = torch.zeros(weights["embed_layer.weight"].shape)
    # Decoder is the model class defined in this module.
    model = Decoder(
        batch_size=batch_size,
        d_model=200,
        n_heads=4,
        input_len=200,
        emb_weights=empty_embeddings,
    )

    model.load_state_dict(weights)

    return model