google-tunix 0.1.0.dev0__tar.gz → 0.1.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/PKG-INFO +1 -1
  2. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/google_tunix.egg-info/PKG-INFO +1 -1
  3. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/pyproject.toml +1 -1
  4. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/cli/utils/model.py +3 -3
  5. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma/model.py +6 -6
  6. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma/params_safetensors.py +6 -5
  7. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma3/model.py +6 -6
  8. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma3/params.py +1 -1
  9. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma3/params_safetensors.py +3 -3
  10. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/LICENSE +0 -0
  11. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/README.md +0 -0
  12. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/google_tunix.egg-info/SOURCES.txt +0 -0
  13. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/google_tunix.egg-info/dependency_links.txt +0 -0
  14. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/google_tunix.egg-info/requires.txt +0 -0
  15. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/google_tunix.egg-info/top_level.txt +0 -0
  16. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/setup.cfg +0 -0
  17. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/__init__.py +0 -0
  18. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/cli/config.py +0 -0
  19. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/cli/grpo_main.py +0 -0
  20. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/cli/peft_main.py +0 -0
  21. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/cli/utils/reward.py +0 -0
  22. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/__init__.py +0 -0
  23. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/distillation_trainer.py +0 -0
  24. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/feature_extraction/__init__.py +0 -0
  25. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/feature_extraction/pooling.py +0 -0
  26. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/feature_extraction/projection.py +0 -0
  27. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/feature_extraction/sowed_module.py +0 -0
  28. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/__init__.py +0 -0
  29. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/attention.py +0 -0
  30. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/base_strategy.py +0 -0
  31. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/feature_pooling.py +0 -0
  32. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/feature_projection.py +0 -0
  33. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/distillation/strategies/logit.py +0 -0
  34. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/examples/data/math_dataset.py +0 -0
  35. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/examples/data/translation_dataset.py +0 -0
  36. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/base_sampler.py +0 -0
  37. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/beam_search.py +0 -0
  38. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/sampler.py +0 -0
  39. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/tokenizer_adapter.py +0 -0
  40. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/utils.py +0 -0
  41. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/generate/vllm_sampler.py +0 -0
  42. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma/params.py +0 -0
  43. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/gemma/sampler.py +0 -0
  44. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/llama3/model.py +0 -0
  45. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/llama3/params.py +0 -0
  46. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/qwen2/model.py +0 -0
  47. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/qwen2/params.py +0 -0
  48. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/qwen3/model.py +0 -0
  49. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/qwen3/params.py +0 -0
  50. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/models/safetensors_loader.py +0 -0
  51. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/oss/utils.py +0 -0
  52. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/common.py +0 -0
  53. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/agents/agent_types.py +0 -0
  54. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/agents/base_agent.py +0 -0
  55. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/agents/tool_agent.py +0 -0
  56. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/environments/base_environment.py +0 -0
  57. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/environments/tool_environment.py +0 -0
  58. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/parser/tool_parser/gemini_parser.py +0 -0
  59. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/parser/tool_parser/qwen_parser.py +0 -0
  60. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/parser/tool_parser/tool_parser_base.py +0 -0
  61. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/parser/tool_parser/tool_parser_registry.py +0 -0
  62. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/prompts/prompt_template.py +0 -0
  63. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/rewards/reward.py +0 -0
  64. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/rewards/reward_types.py +0 -0
  65. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/tools/base_tool.py +0 -0
  66. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/tools/calculator_tool.py +0 -0
  67. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/tools/tool_manager.py +0 -0
  68. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/experimental/agentic/trajectory/trajectory_collect_engine.py +0 -0
  69. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/grpo/grpo_helpers.py +0 -0
  70. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/grpo/grpo_learner.py +0 -0
  71. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/inference/inference_worker.py +0 -0
  72. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/ppo/ppo_helpers.py +0 -0
  73. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/ppo/ppo_learner.py +0 -0
  74. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/queue/data_queue.py +0 -0
  75. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/reshard.py +0 -0
  76. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/rl_cluster.py +0 -0
  77. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/rl_learner.py +0 -0
  78. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/rollout/base_rollout.py +0 -0
  79. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/rollout/vanilla_rollout.py +0 -0
  80. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/rollout/vllm_rollout.py +0 -0
  81. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/trainer.py +0 -0
  82. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/rl/utils.py +0 -0
  83. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/checkpoint_manager.py +0 -0
  84. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/dpo/dpo_trainer.py +0 -0
  85. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/hooks.py +0 -0
  86. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/inflight_throttler.py +0 -0
  87. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/metrics_logger.py +0 -0
  88. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/peft_trainer.py +0 -0
  89. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/profiler.py +0 -0
  90. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/progress_bar.py +0 -0
  91. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/sharding_utils.py +0 -0
  92. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/system_metrics_calculator.py +0 -0
  93. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/sft/utils.py +0 -0
  94. {google_tunix-0.1.0.dev0 → google_tunix-0.1.0.dev1}/tunix/tests/test_common.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-tunix
3
- Version: 0.1.0.dev0
3
+ Version: 0.1.0.dev1
4
4
  Summary: A lightweight JAX-native LLM post-training framework.
5
5
  Author-email: Tunix Developers <tunix-dev@google.com>
6
6
  License-Expression: Apache-2.0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-tunix
3
- Version: 0.1.0.dev0
3
+ Version: 0.1.0.dev1
4
4
  Summary: A lightweight JAX-native LLM post-training framework.
5
5
  Author-email: Tunix Developers <tunix-dev@google.com>
6
6
  License-Expression: Apache-2.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "google-tunix"
3
- version = "0.1.0.dev0"
3
+ version = "0.1.0.dev1"
4
4
  authors = [
5
5
  { name = "Tunix Developers", email = "tunix-dev@google.com" },
6
6
  ]
@@ -37,9 +37,9 @@ from tunix.models.qwen3 import model as qwen3_lib
37
37
 
38
38
  # Map prefixes to the target object containing the methods.
39
39
  CONFIG_MAP = {
40
- 'gemma': gemma_lib.TransformerConfig,
41
- 'gemma2': gemma_lib.TransformerConfig,
42
- 'gemma3': gemma3_lib.Gemma3Config,
40
+ 'gemma': gemma_lib.ModelConfig,
41
+ 'gemma2': gemma_lib.ModelConfig,
42
+ 'gemma3': gemma3_lib.ModelConfig,
43
43
  'llama3.1': llama3_lib.ModelConfig,
44
44
  'llama3.2': llama3_lib.ModelConfig,
45
45
  'qwen2.5': qwen2_lib.ModelConfig,
@@ -555,7 +555,7 @@ class RMSNorm(nnx.Module):
555
555
 
556
556
 
557
557
  @dataclasses.dataclass(frozen=True)
558
- class TransformerConfig:
558
+ class ModelConfig:
559
559
  """Configuration for the gemma transformer."""
560
560
 
561
561
  num_layers: int
@@ -781,13 +781,13 @@ class Transformer(nnx.Module, pytree=False):
781
781
  ) -> 'Transformer':
782
782
 
783
783
  if version in ['2b', '2b-it', '1.1-2b-it']:
784
- config = TransformerConfig.gemma_2b()
784
+ config = ModelConfig.gemma_2b()
785
785
  elif version in ['7b', '7b-it', '1.1-7b-it']:
786
- config = TransformerConfig.gemma_7b()
786
+ config = ModelConfig.gemma_7b()
787
787
  elif version in ['2-2b', '2-2b-it']:
788
- config = TransformerConfig.gemma2_2b()
788
+ config = ModelConfig.gemma2_2b()
789
789
  elif version in ['2-9b', '2-9b-it']:
790
- config = TransformerConfig.gemma2_9b()
790
+ config = ModelConfig.gemma2_9b()
791
791
  else:
792
792
  raise ValueError(f'Unsupported version: {version}')
793
793
 
@@ -800,7 +800,7 @@ class Transformer(nnx.Module, pytree=False):
800
800
 
801
801
  def __init__(
802
802
  self,
803
- config: TransformerConfig,
803
+ config: ModelConfig,
804
804
  *,
805
805
  rngs: nnx.Rngs,
806
806
  shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
@@ -14,7 +14,7 @@ from tunix.models import safetensors_loader
14
14
  from tunix.models.gemma import model as model_lib
15
15
 
16
16
 
17
- def _get_key_and_transform_mapping(cfg: model_lib.TransformerConfig):
17
+ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig):
18
18
  """Mapping of torch_keys to (nnx_keys, (permute_rule, reshape_rule))."""
19
19
  mapping = {
20
20
  r"model\.embed_tokens\.weight": ("embedder.input_embedding", None),
@@ -74,7 +74,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.TransformerConfig):
74
74
  return mapping
75
75
 
76
76
 
77
- def _make_preprocess_fn(cfg: model_lib.TransformerConfig):
77
+ def _make_preprocess_fn(cfg: model_lib.ModelConfig):
78
78
  """Creates a preprocess function to reshape and stack Q, K, and V tensors for Gemma safetensors."""
79
79
  q_pat = re.compile(r"tmp\.layers\.([0-9]+)\.attn\.q$")
80
80
  k_pat = re.compile(r"tmp\.layers\.([0-9]+)\.attn\.k$")
@@ -160,14 +160,15 @@ def _peek_vocab_size_from_safetensors(file_dir: str) -> int:
160
160
  if fn.endswith(".safetensors"):
161
161
  path = os.path.join(file_dir, fn)
162
162
  with safe_open(path, framework="jax") as f:
163
- shape = f.get_tensor("model.embed_tokens.weight").shape
164
- return shape[0]
163
+ if "model.embed_tokens.weight" in f.keys():
164
+ shape = f.get_tensor("model.embed_tokens.weight").shape
165
+ return shape[0]
165
166
  raise FileNotFoundError("No .safetensors found to peek vocab size")
166
167
 
167
168
 
168
169
  def create_model_from_safe_tensors(
169
170
  file_dir: str,
170
- config: model_lib.TransformerConfig,
171
+ config: model_lib.ModelConfig,
171
172
  mesh: jax.sharding.Mesh | None = None,
172
173
  dtype: jnp.dtype | None = None,
173
174
  ) -> model_lib.Transformer:
@@ -81,7 +81,7 @@ class QueryPreAttentionNormalisation(enum.Enum):
81
81
 
82
82
 
83
83
  @dataclasses.dataclass(frozen=True, kw_only=True)
84
- class Gemma3Config:
84
+ class ModelConfig:
85
85
  """Transformer config."""
86
86
 
87
87
  num_layers: int
@@ -105,7 +105,7 @@ class Gemma3Config:
105
105
  def gemma3_1b(
106
106
  cls,
107
107
  sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(),
108
- ) -> 'Gemma3Config':
108
+ ) -> 'ModelConfig':
109
109
  return cls(
110
110
  num_layers=26,
111
111
  num_embed=262144,
@@ -124,7 +124,7 @@ class Gemma3Config:
124
124
  def gemma3_4b(
125
125
  cls,
126
126
  sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(),
127
- ) -> 'Gemma3Config':
127
+ ) -> 'ModelConfig':
128
128
  """Gemma3-4B text-only config."""
129
129
  return cls(
130
130
  num_layers=34,
@@ -145,7 +145,7 @@ class Gemma3Config:
145
145
  def gemma3_12b(
146
146
  cls,
147
147
  sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(),
148
- ) -> 'Gemma3Config':
148
+ ) -> 'ModelConfig':
149
149
  """Gemma3-12B text-only config."""
150
150
  return cls(
151
151
  num_layers=48,
@@ -167,7 +167,7 @@ class Gemma3Config:
167
167
  def gemma3_27b(
168
168
  cls,
169
169
  sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(),
170
- ) -> 'Gemma3Config':
170
+ ) -> 'ModelConfig':
171
171
  """Gemma3-27B text-only config."""
172
172
  return cls(
173
173
  num_layers=62,
@@ -705,7 +705,7 @@ class RMSNorm(nnx.Module):
705
705
  class Gemma3(nnx.Module, pytree=False):
706
706
  """Gemma transformer."""
707
707
 
708
- def __init__(self, config: Gemma3Config, *, rngs: nnx.Rngs):
708
+ def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs):
709
709
  self.config = config
710
710
  self.embedder = Embedder(
711
711
  vocab_size=config.num_embed,
@@ -47,7 +47,7 @@ GEMMA3_TOKENIZER = 'gs://gemma-data/tokenizers/tokenizer_gemma3.model'
47
47
 
48
48
  def create_model_from_checkpoint(
49
49
  checkpoint_path: str,
50
- model_config: model_lib.Gemma3Config,
50
+ model_config: model_lib.ModelConfig,
51
51
  mesh: jax.sharding.Mesh | None = None,
52
52
  ) -> model_lib.Gemma3:
53
53
  """Load a Gemma3 model from a checkpoint."""
@@ -11,7 +11,7 @@ from tunix.models import safetensors_loader
11
11
  from tunix.models.gemma3 import model as model_lib
12
12
 
13
13
 
14
- def _get_key_and_transform_mapping(cfg: model_lib.Gemma3Config):
14
+ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig):
15
15
  """Mapping of torch_keys to (nnx_keys, (permute_rule, reshape_rule))."""
16
16
  return {
17
17
  r"model\.embed_tokens\.weight": ("embedder.input_embedding", None),
@@ -98,7 +98,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.Gemma3Config):
98
98
  }
99
99
 
100
100
 
101
- def _make_preprocess_fn(cfg: model_lib.Gemma3Config):
101
+ def _make_preprocess_fn(cfg: model_lib.ModelConfig):
102
102
  """Creates a tensor preprocessing function for Gemma3 safetensors, fusing q, k, and v projections."""
103
103
  q_pat = re.compile(r"tmp\.layers\.([0-9]+)\.attn\.q$")
104
104
  k_pat = re.compile(r"tmp\.layers\.([0-9]+)\.attn\.k$")
@@ -193,7 +193,7 @@ def _make_preprocess_fn(cfg: model_lib.Gemma3Config):
193
193
 
194
194
  def create_model_from_safe_tensors(
195
195
  file_dir: str,
196
- config: model_lib.Gemma3Config,
196
+ config: model_lib.ModelConfig,
197
197
  mesh: jax.sharding.Mesh | None = None,
198
198
  dtype: jnp.dtype | None = None,
199
199
  ):