Skip to content

Commit

Permalink
Export NeMo FastConformer Hybrid Transducer Large Streaming to ONNX (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed May 8, 2024
1 parent a9f936e commit 68b25ab
Show file tree
Hide file tree
Showing 9 changed files with 611 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: export-nemo-speaker-verification-to-onnx
name: export-nemo-fast-conformer-ctc-to-onnx

on:
workflow_dispatch:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
name: export-nemo-fast-conformer-transducer-to-onnx

on:
workflow_dispatch:

concurrency:
group: export-nemo-fast-conformer-hybrid-transducer-to-onnx-${{ github.ref }}
cancel-in-progress: true

jobs:
export-nemo-fast-conformer-hybrid-transducer-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: NeMo transducer
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install NeMo
shell: bash
run: |
BRANCH='main'
pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
pip install onnxruntime
pip install kaldi-native-fbank
pip install soundfile librosa
- name: Run
shell: bash
run: |
cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
./run-transducer.sh
mv -v sherpa-onnx-nemo* ../../..
- name: Download test waves
shell: bash
run: |
mkdir test_wavs
pushd test_wavs
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
popd
cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms
tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms
tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms
tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
choices=["80", "480", "1040"],
)
return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()

for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)

onnx.save(model, filename)


@torch.no_grad()
def main():
args = get_args()
model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"

asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

with open("./tokens.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(asr_model.joint.vocabulary):
f.write(f"{s} {i}\n")
f.write(f"<blk> {i+1}\n")
print("Saved to tokens.txt")

decoder_type = "rnnt"
asr_model.change_decoding_strategy(decoder_type=decoder_type)
asr_model.eval()

assert asr_model.encoder.streaming_cfg is not None
if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list):
chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1]
else:
chunk_size = asr_model.encoder.streaming_cfg.chunk_size

if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list):
pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
else:
pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size
window_size = chunk_size + pre_encode_cache_size

print("chunk_size", chunk_size)
print("pre_encode_cache_size", pre_encode_cache_size)
print("window_size", window_size)

chunk_shift = chunk_size

# cache_last_channel: (batch_size, dim1, dim2, dim3)
cache_last_channel_dim1 = len(asr_model.encoder.layers)
cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size
cache_last_channel_dim3 = asr_model.encoder.d_model

# cache_last_time: (batch_size, dim1, dim2, dim3)
cache_last_time_dim1 = len(asr_model.encoder.layers)
cache_last_time_dim2 = asr_model.encoder.d_model
cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]

asr_model.set_export_config({"decoder_type": "rnnt", "cache_support": True})

# asr_model.export("model.onnx")
asr_model.encoder.export("encoder.onnx")
asr_model.decoder.export("decoder.onnx")
asr_model.joint.export("joiner.onnx")
# model.onnx is a suffix.
# It will generate two files:
# encoder-model.onnx
# decoder_joint-model.onnx

meta_data = {
"vocab_size": asr_model.decoder.vocab_size,
"window_size": window_size,
"chunk_shift": chunk_shift,
"normalize_type": "None",
"cache_last_channel_dim1": cache_last_channel_dim1,
"cache_last_channel_dim2": cache_last_channel_dim2,
"cache_last_channel_dim3": cache_last_channel_dim3,
"cache_last_time_dim1": cache_last_time_dim1,
"cache_last_time_dim2": cache_last_time_dim2,
"cache_last_time_dim3": cache_last_time_dim3,
"pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
"pred_hidden": asr_model.decoder.pred_hidden,
"subsampling_factor": 8,
"model_type": "EncDecHybridRNNTCTCBPEModel",
"version": "1",
"model_author": "NeMo",
"url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
"comment": "Only the transducer branch is exported",
}
add_meta_data("encoder.onnx", meta_data)

print(meta_data)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

if [ ! -e ./0.wav ]; then
# curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
fi

ms=(
80
480
1040
)

for m in ${ms[@]}; do
./export-onnx-transducer.py --model $m
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
if [ ! -f $d/encoder.onnx ]; then
mkdir -p $d
mv -v encoder.onnx $d/
mv -v decoder.onnx $d/
mv -v joiner.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
fi
done

# Now test the exported models

for m in ${ms[@]}; do
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
python3 ./test-onnx-transducer.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \
--joiner $d/joiner.onnx \
--tokens $d/tokens.txt \
--wav ./0.wav
done
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime


def show(filename):
session_opts = onnxruntime.SessionOptions()
session_opts.log_severity_level = 3
sess = onnxruntime.InferenceSession(filename, session_opts)
for i in sess.get_inputs():
print(i)

print("-----")

for i in sess.get_outputs():
print(i)


def main():
print("=========encoder==========")
show("./encoder.onnx")

print("=========decoder==========")
show("./decoder.onnx")

print("=========joiner==========")
show("./joiner.onnx")


if __name__ == "__main__":
main()

"""
=========encoder==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
NodeArg(name='cache_last_channel', type='tensor(float)', shape=['cache_last_channel_dynamic_axes_1', 17, 'cache_last_channel_dynamic_axes_2', 512])
NodeArg(name='cache_last_time', type='tensor(float)', shape=['cache_last_time_dynamic_axes_1', 17, 512, 'cache_last_time_dynamic_axes_2'])
NodeArg(name='cache_last_channel_len', type='tensor(int64)', shape=['cache_last_channel_len_dynamic_axes_1'])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2'])
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1'])
NodeArg(name='cache_last_channel_next', type='tensor(float)', shape=['cache_last_channel_next_dynamic_axes_1', 17, 'cache_last_channel_next_dynamic_axes_2', 512])
NodeArg(name='cache_last_time_next', type='tensor(float)', shape=['cache_last_time_next_dynamic_axes_1', 17, 512, 'cache_last_time_next_dynamic_axes_2'])
NodeArg(name='cache_last_channel_next_len', type='tensor(int64)', shape=['cache_last_channel_next_len_dynamic_axes_1'])
=========decoder==========
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2'])
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1'])
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640])
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2'])
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1'])
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640])
NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640])
=========joiner==========
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2'])
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2'])
-----
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025])
"""
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
from pathlib import Path
Expand Down
Loading

0 comments on commit 68b25ab

Please sign in to comment.