together 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
together/__init__.py CHANGED
@@ -16,14 +16,24 @@ api_base_complete = urllib.parse.urljoin(api_base, "/api/inference")
16
16
  api_base_files = urllib.parse.urljoin(api_base, "/v1/files/")
17
17
  api_base_finetune = urllib.parse.urljoin(api_base, "/v1/fine-tunes/")
18
18
  api_base_instances = urllib.parse.urljoin(api_base, "instances/")
19
+ api_base_embeddings = urllib.parse.urljoin(api_base, "api/v1/embeddings")
19
20
 
20
21
  default_text_model = "togethercomputer/RedPajama-INCITE-7B-Chat"
21
22
  default_image_model = "runwayml/stable-diffusion-v1-5"
23
+ default_embedding_model = "togethercomputer/bert-base-uncased"
22
24
  log_level = "WARNING"
23
25
 
26
+ MISSING_API_KEY_MESSAGE = """TOGETHER_API_KEY not found.
27
+ Please set it as an environment variable or set it as together.api_key
28
+ Find your TOGETHER_API_KEY at https://api.together.xyz/settings/api-keys"""
29
+
30
+ MAX_CONNECTION_RETRIES = 2
31
+ BACKOFF_FACTOR = 0.2
32
+
24
33
  min_samples = 100
25
34
 
26
35
  from .complete import Complete
36
+ from .embeddings import Embeddings
27
37
  from .error import *
28
38
  from .files import Files
29
39
  from .finetune import Finetune
@@ -38,12 +48,18 @@ __all__ = [
38
48
  "api_base_files",
39
49
  "api_base_finetune",
40
50
  "api_base_instances",
51
+ "api_base_embeddings",
41
52
  "default_text_model",
42
53
  "default_image_model",
54
+ "default_embedding_model",
43
55
  "Models",
44
56
  "Complete",
45
57
  "Files",
46
58
  "Finetune",
47
59
  "Image",
60
+ "Embeddings",
61
+ "MAX_CONNECTION_RETRIES",
62
+ "MISSING_API_KEY_MESSAGE",
63
+ "BACKOFF_FACTOR",
48
64
  "min_samples",
49
65
  ]
together/cli/cli.py CHANGED
@@ -2,7 +2,7 @@
2
2
  import argparse
3
3
 
4
4
  import together
5
- from together.commands import chat, complete, files, finetune, image, models
5
+ from together.commands import chat, complete, embeddings, files, finetune, image, models
6
6
  from together.utils import get_logger
7
7
 
8
8
 
@@ -49,6 +49,7 @@ def main() -> None:
49
49
  image.add_parser(subparser)
50
50
  files.add_parser(subparser)
51
51
  finetune.add_parser(subparser)
52
+ embeddings.add_parser(subparser)
52
53
 
53
54
  args = parser.parse_args()
54
55
 
together/commands/chat.py CHANGED
@@ -6,6 +6,10 @@ import cmd
6
6
  import together
7
7
  import together.tools.conversation as convo
8
8
  from together import Complete
9
+ from together.utils import get_logger
10
+
11
+
12
+ logger = get_logger(str(__name__))
9
13
 
10
14
 
11
15
  def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -99,18 +103,22 @@ class OpenChatKitShell(cmd.Cmd):
def do_say(self, arg: str) -> None:
    """Send *arg* as the human turn and stream the model's reply to stdout.

    The prompt is the conversation's raw prompt after *arg* has been pushed.
    Generation settings come from ``self.args``. Each streamed token is printed
    immediately. The concatenated tokens are then stored as the model's
    response in the conversation.

    If streaming raises ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    self._convo.push_human_turn(arg)
    chunks = []  # joined once at the end instead of quadratic ``+=``
    try:
        for token in self.infer.create_streaming(
            prompt=self._convo.get_raw_prompt(),
            model=self.args.model,
            max_tokens=self.args.max_tokens,
            stop=self.args.stop,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            top_k=self.args.top_k,
            repetition_penalty=self.args.repetition_penalty,
        ):
            print(token, end="", flush=True)
            chunks.append(token)
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so shells/scripts see the failure. ``exit`` is a
        # site-module convenience and is not guaranteed to exist.
        raise SystemExit(1)
    print("\n")
    self._convo.push_model_response("".join(chunks))
116
124
 
@@ -131,33 +131,41 @@ def _run_complete(args: argparse.Namespace) -> None:
131
131
  complete = Complete()
132
132
 
133
133
  if args.no_stream:
134
- response = complete.create(
135
- prompt=args.prompt,
136
- model=args.model,
137
- max_tokens=args.max_tokens,
138
- stop=args.stop,
139
- temperature=args.temperature,
140
- top_p=args.top_p,
141
- top_k=args.top_k,
142
- repetition_penalty=args.repetition_penalty,
143
- logprobs=args.logprobs,
144
- )
134
+ try:
135
+ response = complete.create(
136
+ prompt=args.prompt,
137
+ model=args.model,
138
+ max_tokens=args.max_tokens,
139
+ stop=args.stop,
140
+ temperature=args.temperature,
141
+ top_p=args.top_p,
142
+ top_k=args.top_k,
143
+ repetition_penalty=args.repetition_penalty,
144
+ logprobs=args.logprobs,
145
+ )
146
+ except together.AuthenticationError:
147
+ logger.critical(together.MISSING_API_KEY_MESSAGE)
148
+ exit(0)
145
149
  no_streamer(args, response)
146
150
  else:
147
- for text in complete.create_streaming(
148
- prompt=args.prompt,
149
- model=args.model,
150
- max_tokens=args.max_tokens,
151
- stop=args.stop,
152
- temperature=args.temperature,
153
- top_p=args.top_p,
154
- top_k=args.top_k,
155
- repetition_penalty=args.repetition_penalty,
156
- raw=args.raw,
157
- ):
158
- if not args.raw:
159
- print(text, end="", flush=True)
160
- else:
161
- print(text)
151
+ try:
152
+ for text in complete.create_streaming(
153
+ prompt=args.prompt,
154
+ model=args.model,
155
+ max_tokens=args.max_tokens,
156
+ stop=args.stop,
157
+ temperature=args.temperature,
158
+ top_p=args.top_p,
159
+ top_k=args.top_k,
160
+ repetition_penalty=args.repetition_penalty,
161
+ raw=args.raw,
162
+ ):
163
+ if not args.raw:
164
+ print(text, end="", flush=True)
165
+ else:
166
+ print(text)
167
+ except together.AuthenticationError:
168
+ logger.critical(together.MISSING_API_KEY_MESSAGE)
169
+ exit(0)
162
170
  if not args.raw:
163
171
  print("\n")
@@ -0,0 +1,48 @@
from __future__ import annotations

import argparse
import json

import together
from together import Embeddings
from together.utils import get_logger


logger = get_logger(str(__name__))


def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Register the ``embeddings`` subcommand on *subparsers*.

    The subcommand takes a positional ``INPUT`` string and an optional
    ``--model/-m``. The model defaults to ``together.default_embedding_model``.
    """
    COMMAND_NAME = "embeddings"
    subparser = subparsers.add_parser(COMMAND_NAME)

    subparser.add_argument(
        "input",
        metavar="INPUT",
        default=None,
        type=str,
        help="A string providing context for the model to embed",
    )

    subparser.add_argument(
        "--model",
        "-m",
        default=together.default_embedding_model,
        type=str,
        # Report the default that is actually used (the embedding model,
        # not the text-completion model).
        help=f"The name of the model to query. Default={together.default_embedding_model}",
    )
    subparser.set_defaults(func=_run_embeddings)


def _run_embeddings(args: argparse.Namespace) -> None:
    """Request embeddings for ``args.input`` using ``args.model``.

    Prints the response as indented JSON. If the request raises
    ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    embeddings = Embeddings()

    try:
        response = embeddings.create(
            input=args.input,
            model=args.model,
        )
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so callers can detect the failure.
        raise SystemExit(1)

    print(json.dumps(response, indent=4))


# Backward-compatible alias for the handler's original name.
_run_complete = _run_embeddings
@@ -3,8 +3,14 @@ from __future__ import annotations
3
3
  import argparse
4
4
  import json
5
5
 
6
+ from tabulate import tabulate
7
+
8
+ import together
6
9
  from together import Files
7
- from together.utils import extract_time
10
+ from together.utils import bytes_to_human_readable, extract_time, get_logger
11
+
12
+
13
+ logger = get_logger(str(__name__))
8
14
 
9
15
 
10
16
  def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -23,6 +29,12 @@ def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser])
23
29
 
def _add_list(parser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Register the ``list`` subcommand of ``files``.

    Adds an optional ``--raw`` flag (off by default) and dispatches parsed
    arguments to ``_run_list``.
    """
    list_parser = parser.add_parser("list")
    list_parser.add_argument(
        "--raw",
        action="store_true",
        default=False,
        help="Raw JSON dump of response",
    )
    list_parser.set_defaults(func=_run_list)
27
39
 
28
40
 
@@ -82,6 +94,12 @@ def _add_retrieve(parser: argparse._SubParsersAction[argparse.ArgumentParser]) -
82
94
  help="File ID of remote file",
83
95
  type=str,
84
96
  )
97
+ subparser.add_argument(
98
+ "--raw",
99
+ help="Raw JSON dump of response",
100
+ default=False,
101
+ action="store_true",
102
+ )
85
103
  subparser.set_defaults(func=_run_retrieve)
86
104
 
87
105
 
@@ -109,31 +127,79 @@ def _add_retrieve_content(
109
127
 
110
128
 
def _run_list(args: argparse.Namespace) -> None:
    """List the account's files, sorted with ``extract_time`` as the key.

    With ``--raw`` the full response is printed as indented JSON. Otherwise a
    grid table is printed with name, ID, human-readable size, creation time
    and line count. If the request raises ``together.AuthenticationError``,
    this logs ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        response = Files.list()
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)

    response["data"].sort(key=extract_time)

    if args.raw:
        print(json.dumps(response, indent=4))
        return

    display_list = []
    for file_info in response["data"]:
        size = file_info.get("bytes")
        display_list.append(
            {
                "File name": file_info.get("filename"),
                "File ID": file_info.get("id"),
                # A missing size would otherwise reach float("None") and raise
                # ValueError. str() is kept for mypy typing.
                "Size": (
                    bytes_to_human_readable(float(str(size)))
                    if size is not None
                    else None
                ),
                "Created At": file_info.get("created_at"),
                "Line Count": file_info.get("LineCount"),
            }
        )
    table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
    print(table)
115
154
 
116
155
 
def _run_check(args: argparse.Namespace) -> None:
    """Run ``Files.check`` on ``args.file`` and print the result as indented JSON.

    If the call raises ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        response = Files.check(args.file)
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)
    print(json.dumps(response, indent=4))
120
163
 
121
164
 
def _run_upload(args: argparse.Namespace) -> None:
    """Upload ``args.file`` and print the response as indented JSON.

    The pre-upload check runs unless ``--no-check`` was given. ``args.model``
    is forwarded to ``Files.upload``. If the upload raises
    ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        response = Files.upload(
            file=args.file, check=not args.no_check, model=args.model
        )
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)
    print(json.dumps(response, indent=4))
125
174
 
126
175
 
def _run_delete(args: argparse.Namespace) -> None:
    """Delete the remote file ``args.file_id`` and print the response as JSON.

    If the call raises ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        response = Files.delete(args.file_id)
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)
    print(json.dumps(response, indent=4))
130
183
 
131
184
 
def _run_retrieve(args: argparse.Namespace) -> None:
    """Fetch metadata for the remote file ``args.file_id`` and print it.

    With ``--raw`` the response is printed as indented JSON. Otherwise each
    top-level key/value pair becomes one row of a grid table. If the call
    raises ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        response = Files.retrieve(args.file_id)
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)

    if args.raw:
        print(json.dumps(response, indent=4))
    else:
        table_data = [{"Key": key, "Value": value} for key, value in response.items()]
        table = tabulate(table_data, tablefmt="grid")
        print(table)
135
197
 
136
198
 
def _run_retrieve_content(args: argparse.Namespace) -> None:
    """Download the content of remote file ``args.file_id``.

    ``args.output`` is passed through to ``Files.retrieve_content``, and
    whatever that call returns is printed. If the call raises
    ``together.AuthenticationError``, this logs
    ``together.MISSING_API_KEY_MESSAGE`` and exits with status 1.
    """
    try:
        output = Files.retrieve_content(args.file_id, args.output)
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)
    print(output)