gptmed 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gptmed/__init__.py CHANGED
@@ -32,7 +32,7 @@ Advanced Usage:
32
32
  >>> model = GPTTransformer(config)
33
33
  """
34
34
 
35
- __version__ = "0.3.2"
35
+ __version__ = "0.3.3"
36
36
  __author__ = "Sanjog Sigdel"
37
37
  __email__ = "sigdelsanjog@gmail.com"
38
38
 
gptmed/api.py CHANGED
@@ -187,9 +187,9 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
187
187
  betas=args['betas'],
188
188
  eps=args['eps'],
189
189
  max_steps=args['max_steps'],
190
- save_every=args['save_every'],
191
- eval_every=args['eval_every'],
192
- log_every=args['log_every'],
190
+ save_interval=args['save_interval'],
191
+ eval_interval=args['eval_interval'],
192
+ log_interval=args['log_interval'],
193
193
  keep_last_n=args['keep_last_n'],
194
194
  train_data_path=args['train_data'],
195
195
  val_data_path=args['val_data'],
@@ -333,20 +333,32 @@ def generate(
333
333
  model = GPTTransformer(model_config)
334
334
  model.load_state_dict(checkpoint_data['model_state_dict'])
335
335
 
336
+ # Load tokenizer
337
+ import sentencepiece as spm
338
+ from gptmed.inference.generation_config import GenerationConfig
339
+
340
+ tokenizer_sp = spm.SentencePieceProcessor()
341
+ tokenizer_sp.Load(tokenizer)
342
+
336
343
  # Create generator
337
344
  generator = TextGenerator(
338
345
  model=model,
339
- tokenizer_path=tokenizer,
346
+ tokenizer=tokenizer_sp,
340
347
  device=device
341
348
  )
342
349
 
343
- # Generate
344
- output = generator.generate(
345
- prompt=prompt,
350
+ # Create generation config
351
+ gen_config = GenerationConfig(
346
352
  max_length=max_length,
347
353
  temperature=temperature,
348
354
  top_k=top_k,
349
355
  top_p=top_p
350
356
  )
351
357
 
358
+ # Generate
359
+ output = generator.generate(
360
+ prompt=prompt,
361
+ gen_config=gen_config
362
+ )
363
+
352
364
  return output
gptmed/configs/config_loader.py CHANGED
@@ -110,13 +110,13 @@ def config_to_args(config: Dict[str, Any]) -> Dict[str, Any]:
110
110
 
111
111
  # Checkpointing
112
112
  'checkpoint_dir': config['checkpointing']['checkpoint_dir'],
113
- 'save_interval': config['checkpointing']['save_every'],
113
+ 'save_interval': config['checkpointing'].get('save_interval', config['checkpointing'].get('save_every', 1)),
114
114
  'keep_last_n': config['checkpointing']['keep_last_n'],
115
115
 
116
116
  # Logging
117
117
  'log_dir': config['logging']['log_dir'],
118
- 'eval_interval': config['logging']['eval_every'],
119
- 'log_interval': config['logging']['log_every'],
118
+ 'eval_interval': config['logging'].get('eval_interval', config['logging'].get('eval_every', 100)),
119
+ 'log_interval': config['logging'].get('log_interval', config['logging'].get('log_every', 10)),
120
120
 
121
121
  # Device
122
122
  'device': config['device']['device'],
gptmed-0.3.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gptmed
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
5
5
  Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
6
6
  Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
gptmed-0.3.3.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
1
- gptmed/__init__.py,sha256=lFcfEI8k6ct6XadmW8r3oNjKa-JdkqfhONn55Pmoop8,1676
2
- gptmed/api.py,sha256=IU5r9ujg3S-Lem5-FOGDDLdh1UJ_FqCbaQayzyJez5c,10774
1
+ gptmed/__init__.py,sha256=mwzeW2Qc6j1z5f6HOvZ_BNOnFSncWEK2KEkdqq91yYY,1676
2
+ gptmed/api.py,sha256=gUWooWsXDaGb1r22YnzS3w-sU-n-b4gB4-gh0fMsT4A,11109
3
3
  gptmed/configs/__init__.py,sha256=yRa-zgPQ-OCzu8fvCrfWMG-CjF3dru3PZzknzm0oUaQ,23
4
- gptmed/configs/config_loader.py,sha256=aQkyOzu2Jp0jjsBM9Gbza60rfUKX_KwF_3ED_Dcv34o,5851
4
+ gptmed/configs/config_loader.py,sha256=ZWdH63XOOu0T8seWBiJFZtzlyFmzHzKmMxon6ZgZHlg,6000
5
5
  gptmed/configs/train_config.py,sha256=KqfNBh9hdTTd_6gEAlrClU8sVFSlVDmZJOrf3cPwFe8,4657
6
6
  gptmed/configs/training_config.yaml,sha256=EEZZa3kcsZr3g-_fKDPYZt4_NTpmS-3NvJrTYSWNc8g,2874
7
7
  gptmed/data/__init__.py,sha256=iAHeakB5pBAd7MkmarPPY0UKS9bTaO_winLZ23Y2O90,54
@@ -33,9 +33,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
33
33
  gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
34
34
  gptmed/utils/checkpoints.py,sha256=L4q1-_4GbHCoD7QuEKYeQ-xXDTF-6sqZOxKQ_LT8YmQ,7112
35
35
  gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
36
- gptmed-0.3.2.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
37
- gptmed-0.3.2.dist-info/METADATA,sha256=D6l-6CxTFN7UQyfjVZokH1qA50P9szHJezb8qzQdWjg,13605
38
- gptmed-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- gptmed-0.3.2.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
40
- gptmed-0.3.2.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
41
- gptmed-0.3.2.dist-info/RECORD,,
36
+ gptmed-0.3.3.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
37
+ gptmed-0.3.3.dist-info/METADATA,sha256=0ohKwsi3802GMhVUIx2n76i4QHhY0dkzdG4a_g1p_Hw,13605
38
+ gptmed-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ gptmed-0.3.3.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
40
+ gptmed-0.3.3.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
41
+ gptmed-0.3.3.dist-info/RECORD,,
File without changes