Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongukjae committed Sep 16, 2023
1 parent cda0751 commit f8bcc8f
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 35 deletions.
Binary file modified dictionary/legacy-dictionary.nori
Binary file not shown.
1 change: 0 additions & 1 deletion nori/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ cc_library(
"//nori/lib/dictionary",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@darts_clone",
"@icu//:common",
],
)
Expand Down
4 changes: 2 additions & 2 deletions nori/lib/dictionary/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@darts_clone",
"@darts_ac//darts_ac",
"@icu//:common",
],
)
Expand Down Expand Up @@ -51,7 +51,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@darts_clone",
"@darts_ac//darts_ac",
],
)

Expand Down
22 changes: 12 additions & 10 deletions nori/lib/dictionary/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <regex>
#include <sstream>

#include "darts_ac/darts_ac.h"
#include "absl/log/log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
Expand Down Expand Up @@ -158,17 +159,17 @@ absl::Status DictionaryBuilder::buildTokenInfos(absl::string_view input) {
});

std::vector<const char*> keys;
std::vector<size_t> lengths;
keys.reserve(entries.size());

std::string lastValue = "";
int entryValue = -1;
nori::protos::MorphemeList* lastMorphemeList;

for (int i = 0; i < entries.size(); i++) {
if (entries[i][0] != lastValue) {
keys.push_back(entries[i][0].c_str());
lengths.push_back(entries[i][0].length());
lastValue = entries[i][0];
entryValue++;
lastMorphemeList = noriDictionary.mutable_tokens()->add_morphemes_list();
}

Expand All @@ -179,20 +180,21 @@ absl::Status DictionaryBuilder::buildTokenInfos(absl::string_view input) {

// 4. Build tries
LOG(INFO) << "Build trie. keys[0]: " << keys[0] << ", keys[10]: " << keys[10];
std::unique_ptr<Darts::DoubleArray> trie =
std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray);
if (trie->build(keys.size(), const_cast<char**>(&keys[0])) != 0)
std::unique_ptr<darts_ac::DoubleArrayAhoCorasick> trieAC =
std::unique_ptr<darts_ac::DoubleArrayAhoCorasick>(new darts_ac::DoubleArrayAhoCorasick);
if (trieAC->buildAhoCorasick(keys.size(), keys.data(), lengths.data()) != 0)
return absl::InternalError("Cannot build trie.");

noriDictionary.mutable_darts_array()->assign(
static_cast<const char*>(trie->array()), trie->total_size());

trie->set_array(noriDictionary.darts_array().data(),
noriDictionary.darts_array().size());
static_cast<const char*>(trieAC->array()), trieAC->total_size());
noriDictionary.mutable_darts_ac_failure()->assign(
static_cast<const char*>(trieAC->failure()), trieAC->failure_size());
noriDictionary.mutable_darts_ac_depth()->assign(
static_cast<const char*>(trieAC->depth()), trieAC->depth_size());

int searchResult;
for (int i = 0; i < keys.size(); i++) {
trie->exactMatchSearch(keys[i], searchResult);
trieAC->exactMatchSearch(keys[i], searchResult);
if (searchResult != i)
return absl::InternalError("Trie isn't built properly.");
}
Expand Down
2 changes: 0 additions & 2 deletions nori/lib/dictionary/builder.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#ifndef __NORI_DICTIONARY_BUILDER_H__
#define __NORI_DICTIONARY_BUILDER_H__

#include <darts.h>

#include <functional>
#include <memory>
#include <vector>
Expand Down
17 changes: 10 additions & 7 deletions nori/lib/dictionary/dictionary.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "nori/lib/dictionary/dictionary.h"

#include <darts.h>

#include <fstream>
#include <sstream>

Expand All @@ -11,6 +9,7 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "darts_ac/darts_ac.h"
#include "icu4c/source/common/unicode/unistr.h"
#include "nori/lib/utils.h"
#include "re2/re2.h"
Expand Down Expand Up @@ -59,8 +58,10 @@ absl::Status Dictionary::loadPrebuilt(std::string input) {
auto status = internal::deserializeProtobuf(input, dictionary);
if (!status.ok()) return status;

trie.set_array(dictionary.darts_array().data(),
dictionary.darts_array().size());
trieAC.set_array(dictionary.darts_array().data(),
dictionary.darts_array().size());
trieAC.set_failure(dictionary.darts_ac_failure().data());
trieAC.set_depth(dictionary.darts_ac_depth().data());

// backwardSize
backwardSize = dictionary.connection_cost().backward_size();
Expand Down Expand Up @@ -108,7 +109,7 @@ const nori::protos::CharacterClass Dictionary::getCharClass(

absl::Status UserDictionary::load(std::string filename, int leftId, int rightId,
int rightId_T, int rightId_F) {
trie.clear();
trieAC.clear();
morphemes.clear();

std::ifstream ifs(filename);
Expand Down Expand Up @@ -155,8 +156,10 @@ absl::Status UserDictionary::load(std::string filename, int leftId, int rightId,
});

std::vector<const char*> keys;
std::vector<size_t> lengths;
for (const auto& term : terms) {
keys.push_back(term[0].data());
lengths.push_back(term[0].size());

nori::protos::Morpheme morpheme;

Expand Down Expand Up @@ -185,12 +188,12 @@ absl::Status UserDictionary::load(std::string filename, int leftId, int rightId,
morphemes.push_back(morpheme);
}

if (trie.build(keys.size(), const_cast<char**>(&keys[0])) != 0)
if (trieAC.buildAhoCorasick(keys.size(), keys.data(), lengths.data()) != 0)
return absl::InternalError("Cannot build trie.");

// search second item to check Trie is built properly
int searchResult;
trie.exactMatchSearch(keys[1], searchResult);
trieAC.exactMatchSearch(keys[1], searchResult);
if (searchResult != 1)
return absl::InternalError("Trie isn't built properly.");

Expand Down
10 changes: 5 additions & 5 deletions nori/lib/dictionary/dictionary.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef __NORI_DICTIONARY_H__
#define __NORI_DICTIONARY_H__

#include <darts.h>
#include "darts_ac/darts_ac.h"

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
Expand All @@ -27,15 +27,15 @@ class UserDictionary {
int rightId_T, int rightId_F);

// return trie dictionary.
const Darts::DoubleArray* getTrie() const { return &trie; }
const darts_ac::DoubleArrayAhoCorasick* getTrieAC() const { return &trieAC; }

// get all morphemes for the trie.
const std::vector<nori::protos::Morpheme>* getMorphemes() const {
return &morphemes;
}

private:
Darts::DoubleArray trie;
darts_ac::DoubleArrayAhoCorasick trieAC;
std::vector<nori::protos::Morpheme> morphemes;
};

Expand Down Expand Up @@ -76,7 +76,7 @@ class Dictionary {
bool isUserInitialized() const { return userInitialized; }

// return trie dictionary
const Darts::DoubleArray* getTrie() const { return &trie; }
const darts_ac::DoubleArrayAhoCorasick* getTrieAC() const { return &trieAC; }

// return user dictionary
const UserDictionary* getUserDict() const { return &userDictionary; }
Expand Down Expand Up @@ -122,7 +122,7 @@ class Dictionary {
bool initialized = false;
bool userInitialized = false;

Darts::DoubleArray trie;
darts_ac::DoubleArrayAhoCorasick trieAC;
nori::protos::Dictionary dictionary;
Normalizer normalizer;
UserDictionary userDictionary;
Expand Down
2 changes: 2 additions & 0 deletions nori/lib/protos/dictionary.proto
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ message ConnectionCost {
// Final message to contain all dictionary information
message Dictionary {
bytes darts_array = 1;
bytes darts_ac_failure = 5;
bytes darts_ac_depth = 6;
Tokens tokens = 2;
UnknownTokens unknown_tokens = 3;
ConnectionCost connection_cost = 4;
Expand Down
46 changes: 39 additions & 7 deletions nori/lib/tokenizer.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "nori/lib/tokenizer.h"

#include <darts.h>
#include <google/protobuf/repeated_field.h>

#include <map>
Expand All @@ -9,6 +8,7 @@
#include <vector>

#include "absl/log/log.h"
#include "darts_ac/darts_ac.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/uscript.h"
#include "icu4c/source/common/unicode/utf.h"
Expand Down Expand Up @@ -171,7 +171,7 @@ TrieNode* selectParent(std::vector<internal::TrieNode>& candidates,
} // namespace internal

// NoriTokenizer class
typedef Darts::DoubleArray::result_pair_type DartsResults;
typedef darts_ac::DoubleArrayAhoCorasick::result_pair_type DartsResults;

absl::Status NoriTokenizer::tokenize(Lattice& lattice,
GraphvizVisualizer* visualizer) const {
Expand Down Expand Up @@ -215,10 +215,9 @@ absl::Status NoriTokenizer::tokenize(Lattice& lattice,

// find user dictionary
if (dictionary->isUserInitialized()) {
const int numNodes =
dictionary->getUserDict()->getTrie()->commonPrefixSearch(
current, trieResults.data(), maxTrieResults,
static_cast<int>(end - current));
const int numNodes = dictionary->getUserDict()->getTrieAC()->find(
current, trieResults.data(), maxTrieResults,
static_cast<int>(end - current));

if (numNodes != 0) {
int index = 0;
Expand Down Expand Up @@ -264,7 +263,7 @@ absl::Status NoriTokenizer::tokenize(Lattice& lattice,
}

// pre-built dictionary
const int numNodes = dictionary->getTrie()->commonPrefixSearch(
const int numNodes = dictionary->getTrieAC()->find(
current, trieResults.data(), maxTrieResults,
static_cast<int>(end - current));
if (numNodes > maxTrieResults)
Expand Down Expand Up @@ -398,4 +397,37 @@ absl::Status NoriTokenizer::tokenize(Lattice& lattice,
return absl::OkStatus();
}

// Searches the pre-built dictionary trie for the sentence stored in `lattice`
// and appends one lattice node per matching morpheme to `nodesByPos`.
//
// Args:
//   lattice:    supplies the sentence to search (via getSentence()).
//   nodesByPos: per-position node buckets. Candidate parents are read from
//               nodesByPos[0]; new nodes are appended to the bucket indexed by
//               the match length reported by the trie.
//   nodeId:     running node id counter; incremented once per node emitted.
//
// Returns:
//   InternalError("Cannot search trie") when the trie reports more matches
//   than `maxTrieResults`; OkStatus otherwise.
absl::Status NoriTokenizer::findPreBuiltTokens(
    Lattice& lattice, std::vector<std::vector<internal::TrieNode>>& nodesByPos,
    int& nodeId) const {
  std::vector<DartsResults> trieResults(maxTrieResults + 1);
  const int numNodes = dictionary->getTrieAC()->find(
      lattice.getSentence().data(), trieResults.data(), maxTrieResults,
      lattice.getSentence().length());

  // More matches than the result buffer was told it could hold: the result
  // set is unusable, so report a failure instead of processing it.
  if (numNodes > maxTrieResults) return absl::InternalError("Cannot search trie");

  for (int k = 0; k < numNodes; ++k) {
    const auto& trieResult = trieResults[k];
    // The trie value is used as an index into the dictionary's morpheme lists.
    auto morphemeList =
        &this->dictionary->getTokens()->morphemes_list(trieResult.value);
    auto morphemeSize = morphemeList->morphemes_size();

    for (int j = 0; j < morphemeSize; j++) {
      const auto* morpheme = &morphemeList->morphemes(j);

      int wordCost = morpheme->word_cost();
      int spaceCost = 0;
      int connectionCost = 0;
      // NOTE(review): connectionCost is presumably filled in by selectParent
      // as an output argument -- confirm against its definition.
      // NOTE(review): parents are always taken from position 0 and the match
      // length doubles as the end position; verify this is intended for
      // matches that do not start at the beginning of the sentence.
      internal::TrieNode* parent = internal::selectParent(
          nodesByPos[0], morpheme, this->dictionary, connectionCost);

      int lastPositionIndex = trieResult.length;
      int cost = parent->cost + wordCost + connectionCost + spaceCost;
      nodesByPos[lastPositionIndex].emplace_back(
          nodeId++, cost, lastPositionIndex, trieResult.length, morpheme,
          parent);
    }
  }

  // The original fell off the end of a function returning absl::Status,
  // which is undefined behaviour; report success explicitly.
  return absl::OkStatus();
}

} // namespace nori
6 changes: 5 additions & 1 deletion nori/lib/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct Lattice {
class NoriTokenizer {
public:
NoriTokenizer(const nori::dictionary::Dictionary* dictionary,
size_t maxTrieResults = 1024)
size_t maxTrieResults = 4096)
: dictionary(dictionary), maxTrieResults(maxTrieResults) {}

// Tokenize input text and save tokenized information to lattice
Expand All @@ -88,6 +88,10 @@ class NoriTokenizer {
private:
const nori::dictionary::Dictionary* dictionary;
const size_t maxTrieResults;

absl::Status findPreBuiltTokens(Lattice& lattice, std::vector<std::vector<internal::TrieNode>> &nodesByPos, int& nodeId) const;
absl::Status findUserDictionaryTokens(Lattice& lattice, std::vector<std::vector<internal::TrieNode>> &nodesByPos) const;
absl::Status findUnknownTokens(Lattice& lattice, std::vector<std::vector<internal::TrieNode>> &nodesByPos) const;
};

} // namespace nori
Expand Down

0 comments on commit f8bcc8f

Please sign in to comment.