Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add online punctuation and casing prediction model for English language #1224

Merged
merged 6 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ list(APPEND sources
offline-punctuation-impl.cc
offline-punctuation-model-config.cc
offline-punctuation.cc
online-cnn-bilstm-model.cc
online-punctuation-impl.cc
online-punctuation-model-config.cc
online-punctuation.cc
)

if(SHERPA_ONNX_ENABLE_TTS)
Expand Down Expand Up @@ -244,6 +248,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)

if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
Expand All @@ -257,6 +262,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-online-punctuation
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND main_exes
Expand Down
25 changes: 25 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_

namespace sherpa_onnx {

struct OnlineCNNBiLSTMModelMetaData {
int32_t comma_id;
int32_t period_id;
int32_t quest_id;

int32_t upper_id;
int32_t cap_id;
int32_t mix_case_id;

int32_t num_cases;
int32_t num_punctuations;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
135 changes: 135 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

class OnlineCNNBiLSTMModel::Impl {
public:
explicit Impl(const OnlinePunctuationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.cnn_bilstm);
Init(buf.data(), buf.size());
}

#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.cnn_bilstm);
Init(buf.data(), buf.size());
}
#endif

std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};

auto ans =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return {std::move(ans[0]), std::move(ans[1])};
}

OrtAllocator *Allocator() const { return allocator_; }

const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
return meta_data_;
}

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below

SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA");
SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD");
SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION");

// assert here, because we will use the constant value
assert(meta_data_.comma_id == 1);
assert(meta_data_.period_id == 2);
assert(meta_data_.quest_id == 3);

SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER");
SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP");
SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE");

assert(meta_data_.upper_id == 1);
assert(meta_data_.cap_id == 2);
assert(meta_data_.mix_case_id == 3);

// output shape is (T', num_cases)
meta_data_.num_cases =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
meta_data_.num_punctuations =
sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1];
}

private:
OnlinePunctuationModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;

std::unique_ptr<Ort::Session> sess_;

std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;

std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

OnlineCNNBiLSTMModelMetaData meta_data_;
};

OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;

std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
}

OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
return impl_->Allocator();
}

const OnlineCNNBiLSTMModelMetaData &
OnlineCNNBiLSTMModel::GetModelMetadata() const {
return impl_->GetModelMetadata();
}

} // namespace sherpa_onnx
61 changes: 61 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.h
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#include <memory>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"

namespace sherpa_onnx {

/** This class implements
* https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py
*/
class OnlineCNNBiLSTMModel {
public:
explicit OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config);

#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel(AAssetManager *mgr,
const OnlinePunctuationModelConfig &config);
#endif

~OnlineCNNBiLSTMModel();

/** Run the forward method of the model.
*
* @param token_ids A tensor of shape (N, T) of dtype int32.
* @param valid_ids A tensor of shape (N, T) of dtype int32.
* @param label_lens A tensor of shape (N) of dtype int32.
*
* @return Return a pair of tensors
* - case_logits: A 2-D tensor of shape (T', num_cases).
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
*/
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;

/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;

const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const;

private:
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
Loading
Loading