-
Notifications
You must be signed in to change notification settings - Fork 370
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Export NeMo FastConformer Hybrid Transducer Large Streaming to ONNX (#844)
- Loading branch information
1 parent
a9f936e
commit 68b25ab
Showing
9 changed files
with
611 additions
and
1 deletion.
There are no files selected for viewing
2 changes: 1 addition & 1 deletion
2
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Export the transducer branch of NeMo streaming FastConformer hybrid
# models (80/480/1040 ms chunk sizes) to ONNX and upload the packaged
# archives to the `asr-models` release of k2-fsa/sherpa-onnx.
name: export-nemo-fast-conformer-transducer-to-onnx

on:
  # Manual trigger only.
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-to-onnx:
    # Restrict to the upstream repo and the maintainer's fork.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: NeMo transducer
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-transducer.sh
          # Move the exported model directories to the repo root for packaging.
          mv -v sherpa-onnx-nemo* ../../..

      - name: Download test waves
        shell: bash
        run: |
          mkdir test_wavs
          pushd test_wavs
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
          popd

          # Bundle the test waves into every exported model directory.
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms

          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
1 change: 1 addition & 0 deletions
1
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
import argparse | ||
from typing import Dict | ||
|
||
import nemo.collections.asr as nemo_asr | ||
import onnx | ||
import torch | ||
|
||
|
||
def get_args():
    """Parse CLI arguments; ``--model`` selects the chunk size in ms."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, required=True, choices=["80", "480", "1040"])
    return ap.parse_args()
|
||
|
||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file in place.

    Args:
      filename:
        Path of the ONNX model to rewrite.
      meta_data:
        Key-value pairs to store; values are stringified.
    """
    model = onnx.load(filename)

    # Discard any pre-existing metadata entries before adding ours.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)
|
||
|
||
@torch.no_grad()
def main():
    """Export the transducer (RNN-T) branch of a streaming NeMo
    FastConformer hybrid model to ONNX.

    Writes encoder.onnx, decoder.onnx, joiner.onnx and tokens.txt to the
    current directory; streaming metadata is attached to encoder.onnx.
    """
    args = get_args()
    model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    # Token table: one "<symbol> <id>" line per token; the blank token gets
    # id len(vocab). Using len() instead of the loop variable makes the
    # blank id explicit and avoids a NameError on an empty vocabulary.
    vocab = asr_model.joint.vocabulary
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(vocab):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {len(vocab)}\n")
    print("Saved to tokens.txt")

    decoder_type = "rnnt"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    assert asr_model.encoder.streaming_cfg is not None

    # streaming_cfg values may be scalars or lists; the code picks index 1
    # of a list — presumably the inference-time value. TODO confirm against
    # NeMo's streaming_cfg documentation.
    if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list):
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1]
    else:
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size

    if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list):
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    else:
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size
    window_size = chunk_size + pre_encode_cache_size

    print("chunk_size", chunk_size)
    print("pre_encode_cache_size", pre_encode_cache_size)
    print("window_size", window_size)

    # Non-overlapping chunks: each decoding step advances by a full chunk.
    chunk_shift = chunk_size

    # cache_last_channel: (batch_size, dim1, dim2, dim3)
    cache_last_channel_dim1 = len(asr_model.encoder.layers)
    cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size
    cache_last_channel_dim3 = asr_model.encoder.d_model

    # cache_last_time: (batch_size, dim1, dim2, dim3)
    cache_last_time_dim1 = len(asr_model.encoder.layers)
    cache_last_time_dim2 = asr_model.encoder.d_model
    cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]

    # Reuse decoder_type so the export config cannot drift from the
    # decoding strategy selected above.
    asr_model.set_export_config({"decoder_type": decoder_type, "cache_support": True})

    asr_model.encoder.export("encoder.onnx")
    asr_model.decoder.export("decoder.onnx")
    asr_model.joint.export("joiner.onnx")
    # NOTE: a whole-model export via asr_model.export("model.onnx") would
    # treat "model.onnx" as a suffix and produce two files instead:
    #   encoder-model.onnx
    #   decoder_joint-model.onnx

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "window_size": window_size,
        "chunk_shift": chunk_shift,
        "normalize_type": "None",
        "cache_last_channel_dim1": cache_last_channel_dim1,
        "cache_last_channel_dim2": cache_last_channel_dim2,
        "cache_last_channel_dim3": cache_last_channel_dim3,
        "cache_last_time_dim1": cache_last_time_dim1,
        "cache_last_time_dim2": cache_last_time_dim2,
        "cache_last_time_dim3": cache_last_time_dim3,
        "pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
        "pred_hidden": asr_model.decoder.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the transducer branch is exported",
    }
    add_meta_data("encoder.onnx", meta_data)

    print(meta_data)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
#!/usr/bin/env bash | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
set -ex | ||
|
||
|
40 changes: 40 additions & 0 deletions
40
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/usr/bin/env bash
# Copyright      2024  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# Export the streaming FastConformer transducer models (80/480/1040 ms
# chunk sizes) to ONNX, collect each into its own directory, then run a
# smoke test on 0.wav.
#
# Fix: all variable expansions are now quoted so paths with unexpected
# characters (or empty values) cannot break word splitting/globbing.

set -ex

if [ ! -e ./0.wav ]; then
  # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
fi

ms=(
80
480
1040
)

for m in "${ms[@]}"; do
  ./export-onnx-transducer.py --model "$m"
  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
  # Only package if this model was not already collected by a previous run.
  if [ ! -f "$d/encoder.onnx" ]; then
    mkdir -p "$d"
    mv -v encoder.onnx "$d/"
    mv -v decoder.onnx "$d/"
    mv -v joiner.onnx "$d/"
    mv -v tokens.txt "$d/"
    ls -lh "$d"
  fi
done

# Now test the exported models

for m in "${ms[@]}"; do
  d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
  python3 ./test-onnx-transducer.py \
    --encoder "$d/encoder.onnx" \
    --decoder "$d/decoder.onnx" \
    --joiner "$d/joiner.onnx" \
    --tokens "$d/tokens.txt" \
    --wav ./0.wav
done
63 changes: 63 additions & 0 deletions
63
scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
import onnxruntime | ||
|
||
|
||
def show(filename):
    """Print the input and output NodeArgs of the ONNX model *filename*."""
    opts = onnxruntime.SessionOptions()
    opts.log_severity_level = 3
    session = onnxruntime.InferenceSession(filename, opts)

    for node in session.get_inputs():
        print(node)

    print("-----")

    for node in session.get_outputs():
        print(node)
|
||
|
||
def main():
    """Dump the I/O signatures of the exported encoder/decoder/joiner."""
    for title, path in (
        ("encoder", "./encoder.onnx"),
        ("decoder", "./decoder.onnx"),
        ("joiner", "./joiner.onnx"),
    ):
        print(f"========={title}==========")
        show(path)


if __name__ == "__main__":
    main()
|
||
""" | ||
=========encoder========== | ||
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2']) | ||
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) | ||
NodeArg(name='cache_last_channel', type='tensor(float)', shape=['cache_last_channel_dynamic_axes_1', 17, 'cache_last_channel_dynamic_axes_2', 512]) | ||
NodeArg(name='cache_last_time', type='tensor(float)', shape=['cache_last_time_dynamic_axes_1', 17, 512, 'cache_last_time_dynamic_axes_2']) | ||
NodeArg(name='cache_last_channel_len', type='tensor(int64)', shape=['cache_last_channel_len_dynamic_axes_1']) | ||
----- | ||
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2']) | ||
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1']) | ||
NodeArg(name='cache_last_channel_next', type='tensor(float)', shape=['cache_last_channel_next_dynamic_axes_1', 17, 'cache_last_channel_next_dynamic_axes_2', 512]) | ||
NodeArg(name='cache_last_time_next', type='tensor(float)', shape=['cache_last_time_next_dynamic_axes_1', 17, 512, 'cache_last_time_next_dynamic_axes_2']) | ||
NodeArg(name='cache_last_channel_next_len', type='tensor(int64)', shape=['cache_last_channel_next_len_dynamic_axes_1']) | ||
=========decoder========== | ||
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2']) | ||
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1']) | ||
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640]) | ||
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640]) | ||
----- | ||
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2']) | ||
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1']) | ||
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640]) | ||
NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640]) | ||
=========joiner========== | ||
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2']) | ||
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2']) | ||
----- | ||
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025]) | ||
""" |
1 change: 1 addition & 0 deletions
1
scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.