Skip to content

Commit

Permalink
Add C API for speaker embedding extractor. (#711)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Mar 28, 2024
1 parent 638f48f commit 2e0bcca
Show file tree
Hide file tree
Showing 23 changed files with 739 additions and 80 deletions.
15 changes: 14 additions & 1 deletion .github/scripts/test-c-api.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash

set -e
set -ex

log() {
# This function is from espnet
Expand All @@ -9,6 +9,7 @@ log() {
}

echo "SLID_EXE is $SLID_EXE"
echo "SID_EXE is $SID_EXE"
echo "PATH: $PATH"


Expand All @@ -24,3 +25,15 @@ rm sherpa-onnx-whisper-tiny.tar.bz2
$SLID_EXE

rm -rf sherpa-onnx-whisper-tiny*

log "------------------------------------------------------------"
log "Download file for speaker identification and verification "
log "------------------------------------------------------------"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
git clone https://github.com/csukuangfj/sr-data

$SID_EXE

rm -fv *.onnx
rm -rf sr-data
3 changes: 2 additions & 1 deletion .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*

- name: Test spoken language identification (C API)
- name: Test C API
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export SLID_EXE=spoken-language-identification-c-api
export SID_EXE=speaker-identification-c-api
.github/scripts/test-c-api.sh
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test spoken language identification (C API)
- name: Test C API
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export SLID_EXE=spoken-language-identification-c-api
export SID_EXE=speaker-identification-c-api
.github/scripts/test-c-api.sh
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test spoken language identification (C API)
- name: Test C API
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export SLID_EXE=spoken-language-identification-c-api.exe
export SID_EXE=speaker-identification-c-api.exe
.github/scripts/test-c-api.sh
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ vits-coqui-*
vits-mms-*
*.tar.bz2
sherpa-onnx-paraformer-trilingual-zh-cantonese-en
sr-data
3 changes: 3 additions & 0 deletions c-api-examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ endif()
add_executable(spoken-language-identification-c-api spoken-language-identification-c-api.c)
target_link_libraries(spoken-language-identification-c-api sherpa-onnx-c-api)

add_executable(speaker-identification-c-api speaker-identification-c-api.c)
target_link_libraries(speaker-identification-c-api sherpa-onnx-c-api)

if(SHERPA_ONNX_HAS_ALSA)
add_subdirectory(./asr-microphone-example)
elseif((UNIX AND NOT APPLE) OR LINUX)
Expand Down
7 changes: 4 additions & 3 deletions c-api-examples/asr-microphone-example/c-api-alsa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,11 @@ int32_t main(int32_t argc, char *argv[]) {
}
}

SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
const SherpaOnnxOnlineRecognizer *recognizer =
CreateOnlineRecognizer(&config);
const SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

SherpaOnnxDisplay *display = CreateDisplay(50);
const SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;

const char *device_name = argv[context.index];
Expand Down
7 changes: 4 additions & 3 deletions c-api-examples/decode-file-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ int32_t main(int32_t argc, char *argv[]) {
}
}

SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
const SherpaOnnxOnlineRecognizer *recognizer =
CreateOnlineRecognizer(&config);
const SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

SherpaOnnxDisplay *display = CreateDisplay(50);
const SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;

const char *wav_filename = argv[context.index];
Expand Down
256 changes: 256 additions & 0 deletions c-api-examples/speaker-identification-c-api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
// c-api-examples/speaker-identification-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation

// We assume you have pre-downloaded the speaker embedding extractor model
// from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
//
// An example command to download
// "3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"
// is given below:
//
// clang-format off
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
//
// clang-format on
//
// Also, please download the test wave files from
//
// https://github.com/csukuangfj/sr-data

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

// Read `wav_filename`, feed all of its samples through the speaker
// embedding extractor `ex`, and return the resulting embedding vector.
//
// The returned pointer has SherpaOnnxSpeakerEmbeddingExtractorDim(ex)
// floats and is owned by the caller, who must release it with
// SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding().
//
// On any failure (unreadable wave file, or a file too short to produce
// an embedding) this prints a diagnostic and terminates the process.
static const float *ComputeEmbedding(
const SherpaOnnxSpeakerEmbeddingExtractor *ex, const char *wav_filename) {
const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
if (wave == NULL) {
fprintf(stderr, "Failed to read %s\n", wav_filename);
exit(-1);
}

const SherpaOnnxOnlineStream *stream =
SherpaOnnxSpeakerEmbeddingExtractorCreateStream(ex);

// Push the whole file at once and signal end-of-input so the extractor
// can consume all available samples.
AcceptWaveform(stream, wave->sample_rate, wave->samples, wave->num_samples);
InputFinished(stream);

if (!SherpaOnnxSpeakerEmbeddingExtractorIsReady(ex, stream)) {
fprintf(stderr, "The input wave file %s is too short!\n", wav_filename);
exit(-1);
}

// we will free `v` outside of this function
const float *v =
SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(ex, stream);

// The stream and the wave buffer are no longer needed once the
// embedding has been computed.
DestroyOnlineStream(stream);
SherpaOnnxFreeWave(wave);

// Remember to free v to avoid memory leak
return v;
}

// Demonstrates the speaker-identification C API end to end:
//   1. create a speaker embedding extractor,
//   2. enroll two speakers ("fangjun", "leijun") from multiple waves each,
//   3. search/verify test utterances against the enrolled speakers,
//   4. remove speakers and confirm searches then fail,
//   5. free every object the C API handed out.
//
// Requires ./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx and the
// ./sr-data directory (see the download instructions at the top of this
// file). Returns 0 on success; exits non-zero on any failure.
int32_t main() {
  SherpaOnnxSpeakerEmbeddingExtractorConfig config;

  // Zero-initialize so every field we do not set explicitly is 0/NULL.
  memset(&config, 0, sizeof(config));

  // please download the model from
  // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  config.model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx";

  config.num_threads = 1;
  config.debug = 0;
  config.provider = "cpu";

  const SherpaOnnxSpeakerEmbeddingExtractor *ex =
      SherpaOnnxCreateSpeakerEmbeddingExtractor(&config);
  if (!ex) {
    fprintf(stderr, "Failed to create speaker embedding extractor");
    return -1;
  }

  // The manager must be created with the same embedding dimension the
  // extractor produces.
  int32_t dim = SherpaOnnxSpeakerEmbeddingExtractorDim(ex);

  const SherpaOnnxSpeakerEmbeddingManager *manager =
      SherpaOnnxCreateSpeakerEmbeddingManager(dim);

  // Please download the test data from
  // https://github.com/csukuangfj/sr-data
  const char *spk1_1 = "./sr-data/enroll/fangjun-sr-1.wav";
  const char *spk1_2 = "./sr-data/enroll/fangjun-sr-2.wav";
  const char *spk1_3 = "./sr-data/enroll/fangjun-sr-3.wav";

  const char *spk2_1 = "./sr-data/enroll/leijun-sr-1.wav";
  const char *spk2_2 = "./sr-data/enroll/leijun-sr-2.wav";

  // AddList expects a NULL-terminated array of embeddings, hence one
  // extra slot left as NULL.
  const float *spk1_vec[4] = {NULL};
  spk1_vec[0] = ComputeEmbedding(ex, spk1_1);
  spk1_vec[1] = ComputeEmbedding(ex, spk1_2);
  spk1_vec[2] = ComputeEmbedding(ex, spk1_3);

  const float *spk2_vec[3] = {NULL};
  spk2_vec[0] = ComputeEmbedding(ex, spk2_1);
  spk2_vec[1] = ComputeEmbedding(ex, spk2_2);

  if (!SherpaOnnxSpeakerEmbeddingManagerAddList(manager, "fangjun", spk1_vec)) {
    fprintf(stderr, "Failed to register fangjun\n");
    exit(-1);
  }

  if (!SherpaOnnxSpeakerEmbeddingManagerContains(manager, "fangjun")) {
    fprintf(stderr, "Failed to find fangjun\n");
    exit(-1);
  }

  if (!SherpaOnnxSpeakerEmbeddingManagerAddList(manager, "leijun", spk2_vec)) {
    fprintf(stderr, "Failed to register leijun\n");
    exit(-1);
  }

  if (!SherpaOnnxSpeakerEmbeddingManagerContains(manager, "leijun")) {
    fprintf(stderr, "Failed to find leijun\n");
    exit(-1);
  }

  if (SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(manager) != 2) {
    fprintf(stderr, "There should be two speakers: fangjun and leijun\n");
    exit(-1);
  }

  // GetAllSpeakers returns a NULL-terminated array of names that must be
  // released with FreeAllSpeakers.
  const char *const *all_speakers =
      SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(manager);
  const char *const *p = all_speakers;
  fprintf(stderr, "list of registered speakers\n-----\n");
  while (p[0]) {
    fprintf(stderr, "speaker: %s\n", p[0]);
    ++p;
  }
  fprintf(stderr, "----\n");

  SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers);

  const char *test1 = "./sr-data/test/fangjun-test-sr-1.wav";
  const char *test2 = "./sr-data/test/leijun-test-sr-1.wav";
  const char *test3 = "./sr-data/test/liudehua-test-sr-1.wav";  // not enrolled

  const float *v1 = ComputeEmbedding(ex, test1);
  const float *v2 = ComputeEmbedding(ex, test2);
  const float *v3 = ComputeEmbedding(ex, test3);

  // Similarity threshold for search/verify; use a float literal to avoid
  // a double-to-float narrowing conversion.
  float threshold = 0.6f;

  // Search returns the matched speaker's name (free it with FreeSearch)
  // or NULL when no enrolled speaker is close enough.
  const char *name1 =
      SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v1, threshold);
  if (name1) {
    fprintf(stderr, "%s: Found %s\n", test1, name1);
    SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name1);
  } else {
    fprintf(stderr, "%s: Not found\n", test1);
  }

  const char *name2 =
      SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v2, threshold);
  if (name2) {
    fprintf(stderr, "%s: Found %s\n", test2, name2);
    SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name2);
  } else {
    fprintf(stderr, "%s: Not found\n", test2);
  }

  const char *name3 =
      SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v3, threshold);
  if (name3) {
    fprintf(stderr, "%s: Found %s\n", test3, name3);
    SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name3);
  } else {
    fprintf(stderr, "%s: Not found\n", test3);
  }

  // 1:1 verification: does embedding v1 match the enrolled "fangjun"?
  int32_t ok = SherpaOnnxSpeakerEmbeddingManagerVerify(manager, "fangjun", v1,
                                                       threshold);
  if (ok) {
    fprintf(stderr, "%s matches fangjun\n", test1);
  } else {
    fprintf(stderr, "%s does NOT match fangjun\n", test1);
  }

  ok = SherpaOnnxSpeakerEmbeddingManagerVerify(manager, "fangjun", v2,
                                               threshold);
  if (ok) {
    fprintf(stderr, "%s matches fangjun\n", test2);
  } else {
    fprintf(stderr, "%s does NOT match fangjun\n", test2);
  }

  fprintf(stderr, "Removing fangjun\n");
  if (!SherpaOnnxSpeakerEmbeddingManagerRemove(manager, "fangjun")) {
    fprintf(stderr, "Failed to remove fangjun\n");
    exit(-1);
  }

  if (SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(manager) != 1) {
    fprintf(stderr, "There should be only 1 speaker left\n");
    exit(-1);
  }

  // fangjun is gone, so this search should now report "Not found".
  name1 = SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v1, threshold);
  if (name1) {
    fprintf(stderr, "%s: Found %s\n", test1, name1);
    SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name1);
  } else {
    fprintf(stderr, "%s: Not found\n", test1);
  }

  fprintf(stderr, "Removing leijun\n");
  if (!SherpaOnnxSpeakerEmbeddingManagerRemove(manager, "leijun")) {
    fprintf(stderr, "Failed to remove leijun\n");
    exit(-1);
  }

  if (SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(manager) != 0) {
    // Fixed message: the check above requires the manager to be empty.
    fprintf(stderr, "There should be 0 speakers left\n");
    exit(-1);
  }

  name2 = SherpaOnnxSpeakerEmbeddingManagerSearch(manager, v2, threshold);
  if (name2) {
    fprintf(stderr, "%s: Found %s\n", test2, name2);
    SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name2);
  } else {
    fprintf(stderr, "%s: Not found\n", test2);
  }

  // The list should now be empty.
  all_speakers = SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(manager);

  p = all_speakers;
  fprintf(stderr, "list of registered speakers\n-----\n");
  while (p[0]) {
    fprintf(stderr, "speaker: %s\n", p[0]);
    ++p;
  }
  fprintf(stderr, "----\n");

  // Release everything the C API allocated: speaker-name list, every
  // embedding, then the manager and the extractor themselves.
  SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(v1);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(v2);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(v3);

  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[0]);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[1]);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[2]);

  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(spk2_vec[0]);
  SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(spk2_vec[1]);

  SherpaOnnxDestroySpeakerEmbeddingManager(manager);
  SherpaOnnxDestroySpeakerEmbeddingExtractor(ex);

  return 0;
}
3 changes: 3 additions & 0 deletions c-api-examples/spoken-language-identification-c-api.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// c-api-examples/spoken-language-identification-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation

// We assume you have pre-downloaded the whisper multi-lingual models
// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
Expand Down
2 changes: 1 addition & 1 deletion ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ViewController: UIViewController {
// Please select one model that is best suitable for you.
//
// You can also modify Model.swift to add new pre-trained models from
// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

// let modelConfig = getBilingualStreamZhEnZipformer20230220()
// let modelConfig = getZhZipformer20230615()
Expand Down
Loading

0 comments on commit 2e0bcca

Please sign in to comment.