Skip to content

Commit

Permalink
Add online punctuation and casing prediction model for English langua…
Browse files Browse the repository at this point in the history
…ge (#1224)
  • Loading branch information
frankyoujian committed Aug 6, 2024
1 parent 52830cc commit 1414e4d
Show file tree
Hide file tree
Showing 14 changed files with 874 additions and 0 deletions.
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ list(APPEND sources
offline-punctuation-impl.cc
offline-punctuation-model-config.cc
offline-punctuation.cc
online-cnn-bilstm-model.cc
online-punctuation-impl.cc
online-punctuation-model-config.cc
online-punctuation.cc
)

if(SHERPA_ONNX_ENABLE_TTS)
Expand Down Expand Up @@ -243,6 +247,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)

if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
Expand All @@ -256,6 +261,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-online-punctuation
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND main_exes
Expand Down
25 changes: 25 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_

namespace sherpa_onnx {

struct OnlineCNNBiLSTMModelMetaData {
int32_t comma_id;
int32_t period_id;
int32_t quest_id;

int32_t upper_id;
int32_t cap_id;
int32_t mix_case_id;

int32_t num_cases;
int32_t num_punctuations;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
135 changes: 135 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

class OnlineCNNBiLSTMModel::Impl {
public:
explicit Impl(const OnlinePunctuationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.cnn_bilstm);
Init(buf.data(), buf.size());
}

#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.cnn_bilstm);
Init(buf.data(), buf.size());
}
#endif

std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};

auto ans =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return {std::move(ans[0]), std::move(ans[1])};
}

OrtAllocator *Allocator() const { return allocator_; }

const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
return meta_data_;
}

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below

SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA");
SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD");
SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION");

// assert here, because we will use the constant value
assert(meta_data_.comma_id == 1);
assert(meta_data_.period_id == 2);
assert(meta_data_.quest_id == 3);

SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER");
SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP");
SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE");

assert(meta_data_.upper_id == 1);
assert(meta_data_.cap_id == 2);
assert(meta_data_.mix_case_id == 3);

// output shape is (T', num_cases)
meta_data_.num_cases =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
meta_data_.num_punctuations =
sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1];
}

private:
OnlinePunctuationModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;

std::unique_ptr<Ort::Session> sess_;

std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;

std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

OnlineCNNBiLSTMModelMetaData meta_data_;
};

OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;

std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
}

OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
return impl_->Allocator();
}

const OnlineCNNBiLSTMModelMetaData &
OnlineCNNBiLSTMModel::GetModelMetadata() const {
return impl_->GetModelMetadata();
}

} // namespace sherpa_onnx
61 changes: 61 additions & 0 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.h
//
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#include <memory>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"

namespace sherpa_onnx {

/** This class implements
* https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py
*/
class OnlineCNNBiLSTMModel {
public:
explicit OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config);

#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel(AAssetManager *mgr,
const OnlinePunctuationModelConfig &config);
#endif

~OnlineCNNBiLSTMModel();

/** Run the forward method of the model.
*
* @param token_ids A tensor of shape (N, T) of dtype int32.
* @param valid_ids A tensor of shape (N, T) of dtype int32.
* @param label_lens A tensor of shape (N) of dtype int32.
*
* @return Return a pair of tensors
* - case_logits: A 2-D tensor of shape (T', num_cases).
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
*/
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;

/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;

const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const;

private:
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
Loading

0 comments on commit 1414e4d

Please sign in to comment.