Skip to content

Commit

Permalink
Support specifying max speech duration for VAD.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 14, 2024
1 parent 5d76171 commit 71536f6
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 7 deletions.
9 changes: 8 additions & 1 deletion python-api-examples/generate-subtitles.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,14 @@ def main():

config = sherpa_onnx.VadModelConfig()
config.silero_vad.model = args.silero_vad_model
config.silero_vad.min_silence_duration = 0.25
config.silero_vad.threshold = 0.5
config.silero_vad.min_silence_duration = 0.25 # seconds
config.silero_vad.min_speech_duration = 0.25 # seconds

# If the current segment is larger than this value, then it increases
# the threshold to 0.9 internally. After detecting this segment,
# it resets the threshold to its original value.
config.silero_vad.max_speech_duration = 5 # seconds
config.sample_rate = args.sample_rate

window_size = config.silero_vad.window_size
Expand Down
34 changes: 34 additions & 0 deletions sherpa-onnx/csrc/silero-vad-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
"In seconds. In the end of each silence chunk wait for "
"--silero-vad-min-speech-duration seconds before separating it");

po->Register(
"silero-vad-max-speech-duration", &max_speech_duration,
"In seconds. If a speech segment is longer than this value, then we "
"increase the threshold to 0.9. After finishing detecting the segment, "
"the threshold value is reset to its original value.");

po->Register(
"silero-vad-window-size", &window_size,
"In samples. Audio chunks of --silero-vad-window-size samples are fed "
Expand Down Expand Up @@ -63,6 +69,33 @@ bool SileroVadModelConfig::Validate() const {
return false;
}

if (min_silence_duration <= 0) {
SHERPA_ONNX_LOGE(
"Please use a larger value for --silero-vad-min-silence-duration. "
"Given: "
"%f",
min_silence_duration);
return false;
}

if (min_speech_duration <= 0) {
SHERPA_ONNX_LOGE(
"Please use a larger value for --silero-vad-min-speech-duration. "
"Given: "
"%f",
min_speech_duration);
return false;
}

if (max_speech_duration <= 0) {
SHERPA_ONNX_LOGE(
"Please use a larger value for --silero-vad-max-speech-duration. "
"Given: "
"%f",
max_speech_duration);
return false;
}

return true;
}

Expand All @@ -74,6 +107,7 @@ std::string SileroVadModelConfig::ToString() const {
os << "threshold=" << threshold << ", ";
os << "min_silence_duration=" << min_silence_duration << ", ";
os << "min_speech_duration=" << min_speech_duration << ", ";
os << "max_speech_duration=" << max_speech_duration << ", ";
os << "window_size=" << window_size << ")";

return os.str();
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/silero-vad-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ struct SileroVadModelConfig {
// 256, 512, 768 samples for 800 Hz
int32_t window_size = 512; // in samples

// If a speech segment is longer than this value, then we increase
// the threshold to 0.9. After finishing detecting the segment,
// the threshold value is reset to its original value.
float max_speech_duration = 20; // in seconds

SileroVadModelConfig() = default;

void Register(ParseOptions *po);
Expand Down
21 changes: 17 additions & 4 deletions sherpa-onnx/csrc/voice-activity-detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ class VoiceActivityDetector::Impl {
explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60)
: model_(VadModel::Create(config)),
config_(config),
buffer_(buffer_size_in_seconds * config.sample_rate) {}
buffer_(buffer_size_in_seconds * config.sample_rate) {
Init();
}

#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const VadModelConfig &config,
float buffer_size_in_seconds = 60)
: model_(VadModel::Create(mgr, config)),
config_(config),
buffer_(buffer_size_in_seconds * config.sample_rate) {}
buffer_(buffer_size_in_seconds * config.sample_rate) {
Init();
}
#endif

void AcceptWaveform(const float *samples, int32_t n) {
Expand Down Expand Up @@ -145,6 +149,15 @@ class VoiceActivityDetector::Impl {

const VadModelConfig &GetConfig() const { return config_; }

private:
void Init() {
// TODO(fangjun): Currently, we support only one vad model.
// If a new vad model is added, we need to change the place
// where max_speech_duration is placed.
max_utterance_length_ =
config_.sample_rate * config_.silero_vad.max_speech_duration;
}

private:
std::queue<SpeechSegment> segments_;

Expand All @@ -153,9 +166,9 @@ class VoiceActivityDetector::Impl {
CircularBuffer buffer_;
std::vector<float> last_;

int max_utterance_length_ = 16000 * 20; // in samples
int max_utterance_length_ = -1; // in samples
float new_min_silence_duration_s_ = 0.1;
float new_threshold_ = 1.10;
float new_threshold_ = 0.90;

int32_t start_ = -1;
};
Expand Down
8 changes: 6 additions & 2 deletions sherpa-onnx/python/csrc/silero-vad-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,29 @@ void PybindSileroVadModelConfig(py::module *m) {
.def(py::init<>())
.def(py::init([](const std::string &model, float threshold,
float min_silence_duration, float min_speech_duration,
int32_t window_size) -> std::unique_ptr<PyClass> {
int32_t window_size,
float max_speech_duration) -> std::unique_ptr<PyClass> {
auto ans = std::make_unique<PyClass>();

ans->model = model;
ans->threshold = threshold;
ans->min_silence_duration = min_silence_duration;
ans->min_speech_duration = min_speech_duration;
ans->window_size = window_size;
ans->max_speech_duration = max_speech_duration;

return ans;
}),
py::arg("model"), py::arg("threshold") = 0.5,
py::arg("min_silence_duration") = 0.5,
py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512)
py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512,
py::arg("max_speech_duration") = 20)
.def_readwrite("model", &PyClass::model)
.def_readwrite("threshold", &PyClass::threshold)
.def_readwrite("min_silence_duration", &PyClass::min_silence_duration)
.def_readwrite("min_speech_duration", &PyClass::min_speech_duration)
.def_readwrite("window_size", &PyClass::window_size)
.def_readwrite("max_speech_duration", &PyClass::max_speech_duration)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
Expand Down

0 comments on commit 71536f6

Please sign in to comment.