Preserve previous result as context for next segment #1335

Merged · 1 commit · Sep 11, 2024
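Roughly, the idea of this change: when a stream is reset at a segment boundary, the last context_size decoded tokens of the finished segment are carried over as the decoder's left context for the next segment, instead of starting the new segment from an empty history. A minimal standalone C++ sketch of that idea, using an illustrative Result struct and ResetKeepingContext helper (not the sherpa-onnx types):

#include <cstdint>
#include <vector>

// Illustrative stand-in; the real result type lives in sherpa-onnx/csrc.
struct Result {
  std::vector<int64_t> tokens;  // decoded token IDs of the current segment
};

// On a segment boundary, seed the new result with the last `context_size`
// tokens of the previous one (the transducer decoder's left-context window),
// provided the previous result is long enough.
Result ResetKeepingContext(const Result &last, int32_t context_size) {
  Result next;
  if (static_cast<int32_t>(last.tokens.size()) > context_size) {
    next.tokens.assign(last.tokens.end() - context_size, last.tokens.end());
  }
  return next;
}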
23 changes: 16 additions & 7 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -360,31 +360,40 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
  }

  void Reset(OnlineStream *s) const override {
    int32_t context_size = model_->ContextSize();

    {
      // segment is incremented only when the last result is not empty,
      // ends in a non-blank token, and is longer than context_size
      const auto &r = s->GetResult();
      if (!r.tokens.empty() && r.tokens.back() != 0 &&
          r.tokens.size() > context_size) {
        s->GetCurrentSegment() += 1;
      }
    }

    // reset encoder states
    // s->SetStates(model_->GetEncoderInitStates());

    // we keep the decoder_out
    decoder_->UpdateDecoderOut(&s->GetResult());
    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);

    auto r = decoder_->GetEmptyResult();
    if (config_.decoding_method == "modified_beam_search" &&
        nullptr != s->GetContextGraph()) {
      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
        it->second.context_state = s->GetContextGraph()->Root();
      }
    }

    auto last_result = s->GetResult();
    // if the last result is long enough, preserve its last context_size
    // tokens as the context for the next segment
    if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
      std::vector<int64_t> context(last_result.tokens.end() - context_size,
                                   last_result.tokens.end());

      Hypotheses context_hyp({{context, 0}});
      r.hyps = std::move(context_hyp);
      r.tokens = std::move(context);
    }

    s->SetResult(r);
    s->GetResult().decoder_out = std::move(decoder_out);

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
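As a usage-style check of the guard and tail window used in Reset() above, a small self-contained sketch with hypothetical token IDs (blank is token 0, matching the r.tokens.back() != 0 test in the diff):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int32_t context_size = 2;            // stands in for model_->ContextSize()
  std::vector<int64_t> tokens = {7, 42, 9};  // previous segment, hypothetical IDs

  // The segment counter is bumped only for a non-empty result that does not
  // end in a blank and is longer than context_size.
  bool bump = !tokens.empty() && tokens.back() != 0 &&
              static_cast<int32_t>(tokens.size()) > context_size;
  assert(bump);

  // The preserved context is the last context_size tokens: {42, 9}.
  std::vector<int64_t> context(tokens.end() - context_size, tokens.end());
  assert(context.size() == 2 && context[0] == 42 && context[1] == 9);

  return 0;
}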