Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rosetta/utils/te_pax_t5x_ckpt_converter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ python converter/main.py \
--input-path=/your_path_to_src_ckpt \
--output-path=/your_path_to_output_ckpt \
--fw=pax \
--direction=fw2tw \
--direction=fw2te \
--pax-repeat \
--num-of-layer=8 \
--num-of-head=6 \
Expand Down Expand Up @@ -154,7 +154,7 @@ restoring it to keep training.

#### The folder structure of CKPT by Pax and T5X
If you would like to run the converted CKPTs with frameworks, you may expect the converted CKPTs to have the same folder
structure as CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same structure as the
structure as CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same structure as the
CKPTs from frameworks, and there is no need to pre-generate folders, since they would be generated when needed.
For Pax, you could set `--output-path` to be like `/${your_path_to_output}/checkpoints/checkpoint_${step}`.
For T5X, you could set `--output-path` to be like `/${your_path_to_output}/checkpoint_${step}`.
19 changes: 15 additions & 4 deletions rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse

from paxml_converters import Pax2TEConvertHelper, Pax2TERepeatConvertHelper
Expand Down Expand Up @@ -109,6 +108,17 @@ def parse_args():
default=False,
help="indicate if skip the conversion for LayerNorm.")

parser.add_argument('--gen-fp8-meta',
action="store_true",
default=False,
help="indicate if generate corresponding FP8 meta."
" Only works when --direction=fw2te")
parser.add_argument(
'--amax-history-len',
type=int,
default=1,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set this to the recommended default of 1024?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

help="the length of amax history, which is only used when --gen-fp8-meta is specified.")

parser.add_argument('--pax-repeat',
action="store_true",
default=False,
Expand All @@ -129,7 +139,8 @@ def parse_args():
def get_convert_helper(args):

model_config = ModelConfig(args.num_of_layer, args.embed_dim, args.num_of_head, args.head_dim,
args.mlp_intermediate_dim, args.kernel_chunk_size)
args.mlp_intermediate_dim, args.kernel_chunk_size,
args.amax_history_len)

convert_helper_cls = None

Expand All @@ -140,8 +151,8 @@ def get_convert_helper(args):
convert_helper_cls = T5X_CONVERT_HELPER_DICT[(args.direction, args.t5x_fuse_qkv)]

assert convert_helper_cls is not None, "Not Supported."
return convert_helper_cls(args.input_path, args.output_path, model_config,
args.weight_only, args.skip_ln)
return convert_helper_cls(args.input_path, args.output_path, model_config, args.weight_only,
args.skip_ln, args.gen_fp8_meta)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import jax.numpy as jnp

from utils import ConvertHelper
Expand All @@ -26,6 +25,10 @@ def catagories(self):
return ['mdl_vars.params']
return ['mdl_vars.params', "opt_states_0_2.m.params", "opt_states_0_2.v.params"]

# Returns a one-entry dict mapping the 'mdl_vars.params' category name to
# 'mdl_vars.fp8_metas'.
# NOTE(review): presumably the value is the category under which the generated
# FP8 meta for that parameter category is stored -- confirm against ConvertHelper.
@property
def fp8_meta_catagories(self):
return {'mdl_vars.params': 'mdl_vars.fp8_metas'}


class Pax2TEConvertHelper(PaxConvertHelperBase):

Expand All @@ -46,8 +49,11 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
(hidden_dim, mlp_intermediate_dim),
0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])),
gen_fp8_meta=True,
fp8_meta_postfix='0'),
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.bias.b":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_bias",
Expand All @@ -57,7 +63,10 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel",
(mlp_intermediate_dim, hidden_dim), 1),
(mlp_intermediate_dim, hidden_dim),
1,
gen_fp8_meta=True,
fp8_meta_postfix='1'),
f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.bias":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.ln_bias",
Expand Down Expand Up @@ -90,9 +99,12 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel",
(3, hidden_dim, num_of_head, head_dim), 0,
(3, hidden_dim, num_of_head, head_dim),
0,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (1, 0, 2))),
lambda x: jnp.transpose(x, (1, 0, 2)),
gen_fp8_meta=True,
fp8_meta_postfix='0'),
f"lm.transformer.x_layers_{i}.self_attention.post.b":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.bias",
Expand All @@ -102,9 +114,12 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.self_attention.post.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.kernel",
(hidden_dim, num_of_head, head_dim), 1,
(hidden_dim, num_of_head, head_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (1, 0)))
lambda x: jnp.transpose(x, (1, 0)),
gen_fp8_meta=True,
fp8_meta_postfix='0')
})

return ckpt_map
Expand Down Expand Up @@ -199,6 +214,10 @@ def catagories(self):
f"opt_states_0.p#{num_of_layer}#i-1_2.v.params"
]

# Returns a one-entry dict mapping the 'mdl_vars.params' category name to
# 'mdl_vars.fp8_metas'.
# NOTE(review): presumably the value is the category under which the generated
# FP8 meta for that parameter category is stored -- confirm against ConvertHelper.
@property
def fp8_meta_catagories(self):
return {'mdl_vars.params': 'mdl_vars.fp8_metas'}


class Pax2TERepeatConvertHelper(PaxRepeatConvertHelperBase):

Expand All @@ -220,8 +239,12 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
(num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
(num_of_layer, hidden_dim, mlp_intermediate_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias',
Expand All @@ -231,7 +254,11 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_kernel',
(num_of_layer, mlp_intermediate_dim, hidden_dim), 2),
(num_of_layer, mlp_intermediate_dim, hidden_dim),
2,
gen_fp8_meta=True,
fp8_meta_postfix='1',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.ln_bias',
Expand Down Expand Up @@ -264,9 +291,13 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.qkv.kernel',
(num_of_layer, 3, hidden_dim, num_of_head, head_dim), 1,
(num_of_layer, 3, hidden_dim, num_of_head, head_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (0, 2, 1, 3))),
lambda x: jnp.transpose(x, (0, 2, 1, 3)),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.self_attention.post.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.bias',
Expand All @@ -276,9 +307,13 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.self_attention.post.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.kernel',
(num_of_layer, hidden_dim, num_of_head, head_dim), 2,
(num_of_layer, hidden_dim, num_of_head, head_dim),
2,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (0, 2, 1)))
lambda x: jnp.transpose(x, (0, 2, 1)),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,))
})

return ckpt_map
Expand Down
Loading