From 5b3974e021035b41af074acb342bcb31745bcc5a Mon Sep 17 00:00:00 2001 From: Weng Xuetian Date: Sun, 19 Apr 2026 13:28:13 -0700 Subject: [PATCH] Add support for store pinyin/table code in user history --- src/libime/core/historybigram.cpp | 402 +++++++++++++++++-------- src/libime/core/historybigram.h | 49 ++- src/libime/core/userlanguagemodel.cpp | 13 +- src/libime/core/userlanguagemodel.h | 3 + src/libime/pinyin/pinyincontext.cpp | 142 ++++++--- src/libime/pinyin/pinyincontext.h | 27 +- src/libime/pinyin/pinyindecoder_p.h | 13 + src/libime/pinyin/pinyinime.cpp | 16 +- src/libime/pinyin/pinyinprediction.cpp | 43 ++- src/libime/pinyin/pinyinprediction.h | 11 + src/libime/table/tablecontext.cpp | 59 +++- test/testhistorybigram.cpp | 61 ++++ test/testpinyinime_unit.cpp | 85 ++++-- test/testtableime.cpp | 2 + test/testtableime_unit.cpp | 71 ++++- 15 files changed, 772 insertions(+), 225 deletions(-) diff --git a/src/libime/core/historybigram.cpp b/src/libime/core/historybigram.cpp index 93deff0..27935c5 100644 --- a/src/libime/core/historybigram.cpp +++ b/src/libime/core/historybigram.cpp @@ -33,8 +33,47 @@ namespace libime { -static constexpr uint32_t historyBinaryFormatMagic = 0x000fc315; -static constexpr uint32_t historyBinaryFormatVersion = 0x3; +namespace { + +using WordWithCode = HistoryBigram::WordWithCode; +using WordWithCodeView = HistoryBigram::WordWithCodeView; + +constexpr uint32_t historyBinaryFormatMagic = 0x000fc315; +constexpr uint32_t historyBinaryFormatVersion = 0x4; +constexpr char bigramSeparator = '\x01'; +constexpr char wordCodeSeparator = '\x02'; + +std::string wordAndCodeToString(WordWithCodeView wordAndCode) { + std::string s{std::get<0>(wordAndCode)}; + if (s.empty()) { + return s; + } + auto code = std::get<1>(wordAndCode); + if (!code.empty()) { + s += wordCodeSeparator; + s += code; + } + return s; +} + +WordWithCode bigramWordWithCode(WordWithCodeView prev, WordWithCodeView cur) { + std::string s; + s.append(std::get<0>(prev)); + s += bigramSeparator; + s.append(std::get<0>(cur)); + + auto code1 = std::get<1>(prev); + auto code2 = std::get<1>(cur); + std::string concatCode; + if (code1.empty() && code2.empty()) { + concatCode = ""; + } else { + concatCode = code1; + concatCode += bigramSeparator; + concatCode += code2; + } + return {s, concatCode}; +} struct WeightedTrie { using TrieType = DATrie; @@ -48,21 +87,68 @@ struct WeightedTrie { int32_t weightedSize() const { return weightedSize_; } - int32_t freq(std::string_view s) const { - auto v = trie_.exactMatchSearch(s.data(), s.size()); - if (TrieType::isNoValue(v)) { - return 0; + int32_t freq(WordWithCodeView wordAndCode) const { + // If query with code, the match will be {word, ""} + {word, code}. + // If query without code, the match will be {word, ""} + {word, + // separator}. 
+ TrieType::position_type pos = 0; + auto result = 0; + auto v = trie_.traverse(wordAndCode.first, pos); + if (TrieType::isValid(v)) { + result += v; + } + const char separator[] = {wordCodeSeparator, '\0'}; + v = trie_.traverse(separator, pos); + if (!TrieType::isNoPath(v)) { + if (!wordAndCode.second.empty() && + wordAndCode.second.front() != bigramSeparator && + wordAndCode.second.back() != bigramSeparator) { + v = trie_.traverse(wordAndCode.second, pos); + if (TrieType::isValid(v)) { + result += v; + } + } else { + trie_.foreach( + [this, &result, &wordAndCode](TrieType::value_type value, + size_t len, + TrieType::position_type pos) { + if (len == 0) { + return true; + } + if (!wordAndCode.second.empty()) { + assert( + wordAndCode.second.front() == bigramSeparator || + wordAndCode.second.back() == bigramSeparator); + std::string codeInTrie; + trie().suffix(codeInTrie, len, pos); + if (wordAndCode.second.front() == bigramSeparator && + !codeInTrie.ends_with(wordAndCode.second)) { + return true; + } + if (wordAndCode.second.back() == bigramSeparator && + !codeInTrie.starts_with(wordAndCode.second)) { + return true; + } + } + result += value; + return true; + }, + pos); + } } - return v; + return result; } - void incFreq(std::string_view s, int32_t delta) { + void incFreq(WordWithCodeView wordAndCode, int32_t delta) { + auto s = wordAndCodeToString(wordAndCode); + trie_.update(s.data(), s.size(), [delta](int32_t v) { return v + delta; }); weightedSize_ += delta; } - void decFreq(std::string_view s, int32_t delta) { + void decFreq(WordWithCodeView wordAndCode, int32_t delta) { + auto s = wordAndCodeToString(wordAndCode); auto v = trie_.exactMatchSearch(s.data(), s.size()); if (TrieType::isNoValue(v)) { return; @@ -77,48 +163,6 @@ struct WeightedTrie { } } - void eraseByKey(std::string_view s) { - auto v = trie_.exactMatchSearch(s.data(), s.size()); - if (TrieType::isNoValue(v)) { - return; - } - trie_.erase(s); - decWeightedSize(v); - } - - void eraseByPrefix(std::string_view s) { - std::vector> values; - trie_.foreach(s, [this, &values](TrieType::value_type value, size_t len, - TrieType::position_type pos) { - std::string buf; - trie().suffix(buf, len, pos); - values.emplace_back(std::move(buf), value); - return true; - }); - for (auto &value : values) { - trie_.erase(value.first); - decWeightedSize(value.second); - } - } - - void eraseBySuffix(std::string_view s) { - std::vector> values; - trie_.foreach(s, - [this, &values, s](TrieType::value_type value, size_t len, - TrieType::position_type pos) { - std::string buf; - trie().suffix(buf, len, pos); - if (buf.ends_with(s)) { - values.emplace_back(std::move(buf), value); - } - return true; - }); - for (auto &value : values) { - trie_.erase(value.first); - decWeightedSize(value.second); - } - } - void fillPredict(std::unordered_set &words, std::string_view word, size_t maxSize) const { trie_.foreach(word, @@ -126,6 +170,10 @@ struct WeightedTrie { TrieType::position_type pos) { std::string buf; trie().suffix(buf, len, pos); + auto separatorPos = buf.find(wordCodeSeparator); + if (separatorPos != std::string::npos) { + buf.erase(separatorPos); + } // Skip special word. 
if (buf == "" || buf == "") { return true; @@ -157,11 +205,19 @@ class HistoryBigramPool { while (count--) { uint32_t size = 0; throw_if_io_fail(unmarshall(in, size)); - std::vector sentence; + std::vector sentence; while (size--) { std::string buffer; throw_if_io_fail(unmarshallString(in, buffer)); - sentence.emplace_back(std::move(buffer)); + std::string_view bufferView{buffer}; + size_t separatorPos = bufferView.find(wordCodeSeparator); + if (separatorPos != std::string_view::npos) { + sentence.emplace_back( + std::string(bufferView.substr(0, separatorPos)), + std::string(bufferView.substr(separatorPos + 1))); + } else { + sentence.emplace_back(std::move(buffer), ""); + } } add(sentence); } @@ -178,9 +234,46 @@ class HistoryBigramPool { } } for (auto &line : lines | std::views::reverse) { - std::vector sentence = - fcitx::stringutils::split(line, " "); - add(sentence); + std::string_view lineView{line}; + std::vector tokens; + bool withCode = false; + while (!lineView.empty()) { + std::string token; + auto consumed = fcitx::stringutils::consumeMaybeEscapedValue( + lineView, FCITX_WHITESPACE, &token); + if (!consumed.empty()) { + tokens.push_back(std::move(token)); + } + if (tokens.size() == 1 && !lineView.empty() && + lineView.front() == '\t') { + withCode = true; + } + } + + if (withCode) { + if (tokens.size() % 2 != 0) { + continue; + } + add(std::views::iota(static_cast(0), + tokens.size() / 2) | + std::views::transform([&tokens](size_t i) { + return WordWithCode{tokens[i * 2], tokens[(i * 2) + 1]}; + })); + + } else { + add(tokens | + std::views::transform([](const auto &word) -> WordWithCode { + std::vector wordWithMaybeCode = + fcitx::stringutils::split( + word, "\t", + fcitx::stringutils::SplitBehavior::KeepEmpty); + if (wordWithMaybeCode.size() == 2) { + return WordWithCode{wordWithMaybeCode[0], + wordWithMaybeCode[1]}; + } + return WordWithCode{word, ""}; + })); + } } } @@ -193,8 +286,8 @@ class HistoryBigramPool { for (auto &sentence : recent_ | std::views::reverse) { uint32_t size = sentence.size(); throw_if_io_fail(marshall(out, size)); - for (auto &s : sentence) { - throw_if_io_fail(marshallString(out, s)); + for (const auto &s : sentence) { + throw_if_io_fail(marshallString(out, wordAndCodeToString(s))); } } } @@ -202,13 +295,20 @@ class HistoryBigramPool { void dump(std::ostream &out) const { for (const auto &sentence : recent_) { bool first = true; + bool hasCode = std::ranges::any_of(sentence, [](const auto &item) { + return !std::get<1>(item).empty(); + }); for (const auto &s : sentence) { if (first) { first = false; } else { out << " "; } - out << s; + out << fcitx::stringutils::escapeForValue(std::get<0>(s)); + if (hasCode) { + out << "\t" + << fcitx::stringutils::escapeForValue(std::get<1>(s)); + } } out << '\n'; } @@ -222,14 +322,15 @@ class HistoryBigramPool { } template - std::list> add(const R &sentence) { - std::list> popedSentence; + std::list> add(const R &sentence) { + std::list> popedSentence; if (sentence.empty()) { return popedSentence; } // Validate data. 
        if (std::any_of(std::begin(sentence), std::end(sentence),
-                        [](const std::string &word) {
+                        [](const auto &item) {
+                            const auto &[word, code] = item;
                             return word.find('\0') != std::string::npos;
                         })) {
             return popedSentence;
@@ -240,43 +341,33 @@ class HistoryBigramPool {
                                  std::prev(recent_.end()));
         }
 
-        std::vector<std::string> newSentence;
+        std::vector<WordWithCode> newSentence;
         auto delta = 1;
         for (auto iter = sentence.begin(), end = sentence.end(); iter != end;
              iter++) {
             unigram_.incFreq(*iter, delta);
-            auto next = std::next(iter);
+            auto next = std::ranges::next(iter);
             if (next != end) {
                 incBigram(*iter, *next, delta);
             }
-            std::string ss;
-            ss += *iter;
-            newSentence.push_back(ss);
+            newSentence.push_back(*iter);
         }
         recent_.push_front(std::move(newSentence));
 
-        unigram_.incFreq("<s>", delta);
-        unigram_.incFreq("</s>", delta);
-        incBigram("<s>", sentence.front(), delta);
-        incBigram(sentence.back(), "</s>", delta);
+        unigram_.incFreq({"<s>", ""}, delta);
+        unigram_.incFreq({"</s>", ""}, delta);
+        incBigram({"<s>", ""}, sentence.front(), delta);
+        incBigram(sentence.back(), {"</s>", ""}, delta);
         return popedSentence;
     }
 
-    float unigramFreq(std::string_view s) const {
-        auto v = unigram_.freq(s);
-        return v;
-    }
+    int32_t unigramFreq(WordWithCodeView s) const { return unigram_.freq(s); }
 
-    float bigramFreq(std::string_view s1, std::string_view s2) const {
-        std::string s;
-        s.append(s1.data(), s1.size());
-        s += '|';
-        s.append(s2.data(), s2.size());
-        auto v = bigram_.freq(s);
-        return v;
+    int32_t bigramFreq(WordWithCodeView s1, WordWithCodeView s2) const {
+        return bigram_.freq(bigramWordWithCode(s1, s2));
     }
 
-    bool isUnknown(std::string_view word) const {
+    bool isUnknown(WordWithCodeView word) const {
         return unigramFreq(word) == 0;
     }
 
@@ -284,10 +375,14 @@ class HistoryBigramPool {
 
     size_t realSize() const { return recent_.size(); }
 
-    void forget(std::string_view word) {
+    void forget(std::string_view word, std::string_view code) {
         auto iter = recent_.begin();
         while (iter != recent_.end()) {
-            if (std::find(iter->begin(), iter->end(), word) != iter->end()) {
+            if (std::find_if(
+                    iter->begin(), iter->end(), [word, code](const auto &item) {
+                        const auto &[w, c] = item;
+                        return w == word && (code.empty() || c == code);
+                    }) != iter->end()) {
                 remove(*iter);
                 iter = recent_.erase(iter);
             } else {
@@ -313,37 +408,31 @@ class HistoryBigramPool {
                 decBigram(*iter, *next, delta);
             }
         }
-        decBigram("<s>", sentence.front(), delta);
-        decBigram(sentence.back(), "</s>", delta);
+        decBigram({"<s>", ""}, sentence.front(), delta);
+        decBigram(sentence.back(), {"</s>", ""}, delta);
     }
 
-    void decBigram(std::string_view s1, std::string_view s2, int32_t delta) {
-        std::string ss;
-        ss.append(s1.data(), s1.size());
-        ss += '|';
-        ss.append(s2.data(), s2.size());
-        bigram_.decFreq(ss, delta);
+    void decBigram(WordWithCodeView s1, WordWithCodeView s2, int32_t delta) {
+        bigram_.decFreq(bigramWordWithCode(s1, s2), delta);
     }
 
-    void incBigram(std::string_view s1, std::string_view s2, int delta) {
-        std::string ss;
-        ss.append(s1.data(), s1.size());
-        ss += '|';
-        ss.append(s2.data(), s2.size());
-        bigram_.incFreq(ss, delta);
+    void incBigram(WordWithCodeView s1, WordWithCodeView s2, int delta) {
+        bigram_.incFreq(bigramWordWithCode(s1, s2), delta);
     }
 
     const size_t maxSize_;
     // Used when maxSize_ != 0.
     size_t size_ = 0;
-    std::list<std::vector<std::string>> recent_;
+    std::list<std::vector<WordWithCode>> recent_;
     // Used for look up
     WeightedTrie unigram_;
    WeightedTrie bigram_;
 };
 
+} // namespace
+
 // We define the frequency as following.
 // (1 - p) the frequency belongs to first pool.
 // p * (1 - p) Second pool
 // And then we define alpha as p = 1 / (1 + alpha).
 class HistoryBigramPrivate {
 public:
-    void populateSentence(std::list<std::vector<std::string>> popedSentence) {
+    void populateSentence(std::list<std::vector<WordWithCode>> popedSentence) {
         for (size_t i = 1; !popedSentence.empty() && i < pools_.size(); i++) {
-            std::list<std::vector<std::string>> nextSentences;
+            std::list<std::vector<WordWithCode>> nextSentences;
             while (!popedSentence.empty()) {
                 auto newPopedSentence = pools_[i].add(popedSentence.front());
                 popedSentence.pop_front();
@@ -366,7 +455,7 @@ class HistoryBigramPrivate {
         }
     }
 
-    float unigramFreq(std::string_view word) const {
+    float unigramFreq(WordWithCodeView word) const {
         assert(pools_.size() == poolWeight_.size());
         float freq = 0;
         for (size_t i = 0; i < pools_.size(); i++) {
@@ -375,7 +464,7 @@ class HistoryBigramPrivate {
         return freq;
     }
 
-    float bigramFreq(std::string_view prev, std::string_view cur) const {
+    float bigramFreq(WordWithCodeView prev, WordWithCodeView cur) const {
         assert(pools_.size() == poolWeight_.size());
         float freq = 0;
         for (size_t i = 0; i < pools_.size(); i++) {
@@ -443,33 +532,57 @@ bool HistoryBigram::useOnlyUnigram() const {
 }
 
 void HistoryBigram::add(const libime::SentenceResult &sentence) {
+    FCITX_D();
+    addWithCode(sentence, nullptr);
+}
+
+void HistoryBigram::addWithCode(
+    const libime::SentenceResult &sentence,
+    const ValidationCodeExtractor &validationCodeExtractor) {
     FCITX_D();
     d->populateSentence(d->pools_[0].add(
         sentence.sentence() |
-        std::views::transform([](const auto &item) -> const std::string & {
-            return item->word();
-        })));
+        std::views::transform(
+            [&validationCodeExtractor](const auto &item) -> WordWithCode {
+                return {item->word(), validationCodeExtractor
+                                          ? validationCodeExtractor(item)
+                                          : ""};
+            })));
 }
 
 void HistoryBigram::add(const std::vector<std::string> &sentence) {
     FCITX_D();
-    d->populateSentence(d->pools_[0].add(sentence));
+    d->populateSentence(d->pools_[0].add(
+        sentence | std::views::transform([](const auto &word) -> WordWithCode {
+            return WordWithCode{word, ""};
+        })));
+}
+
+void HistoryBigram::addWithCode(
+    const std::vector<WordWithCode> &sentenceWithValidationCode) {
+    FCITX_D();
+    d->populateSentence(d->pools_[0].add(sentenceWithValidationCode));
 }
 
 bool HistoryBigram::isUnknown(std::string_view v) const {
     FCITX_D();
     return std::ranges::all_of(d->pools_, [v](const HistoryBigramPool &pool) {
-        return pool.isUnknown(v);
+        return pool.isUnknown({v, ""});
     });
 }
 
 float HistoryBigram::score(std::string_view prev, std::string_view cur) const {
+    return scoreWithCode({prev, ""}, {cur, ""});
+}
+
+float HistoryBigram::scoreWithCode(WordWithCodeView prev,
+                                   WordWithCodeView cur) const {
     FCITX_D();
-    if (prev.empty()) {
-        prev = "<s>";
+    if (prev.first.empty()) {
+        prev.first = "<s>";
     }
-    if (cur.empty()) {
-        cur = "</s>";
+    if (cur.first.empty()) {
+        cur.first = "</s>";
     }
 
     auto uf0 = d->unigramFreq(prev);
@@ -508,7 +621,11 @@ void HistoryBigram::load(std::istream &in) {
     case 2:
         std::ranges::for_each(d->pools_, [&in](auto &pool) { pool.load(in); });
         break;
+    case 3:
     case historyBinaryFormatVersion:
+        // For version 3 and version 4 the format is the same, but version 4
+        // contains additional code data, so bump the version to make it not
+        // backward compatible with version 3.
        readZSTDCompressed(in, [d](std::istream &compressIn) {
             std::ranges::for_each(d->pools_, [&compressIn](auto &pool) {
                 pool.load(compressIn);
@@ -547,9 +664,12 @@ void HistoryBigram::clear() {
     std::ranges::for_each(d->pools_, std::mem_fn(&HistoryBigramPool::clear));
 }
 
-void HistoryBigram::forget(std::string_view word) {
+void HistoryBigram::forget(std::string_view word) { forget(word, ""); }
+
+void HistoryBigram::forget(std::string_view word, std::string_view code) {
     FCITX_D();
-    std::ranges::for_each(d->pools_, [word](auto &pool) { pool.forget(word); });
+    std::ranges::for_each(
+        d->pools_, [word, code](auto &pool) { pool.forget(word, code); });
 }
 
 void HistoryBigram::fillPredict(std::unordered_set<std::string> &words,
@@ -565,7 +685,7 @@
     } else {
         lookup = "<s>";
     }
-    lookup += "|";
+    lookup += bigramSeparator;
     std::ranges::for_each(
         d->pools_, [&words, &lookup, maxSize](const HistoryBigramPool &pool) {
             pool.fillPredict(words, lookup, maxSize);
@@ -575,10 +695,52 @@
 bool HistoryBigram::containsBigram(std::string_view prev,
                                    std::string_view cur) const {
     FCITX_D();
-    return std::ranges::any_of(d->pools_,
-                               [&prev, &cur](const HistoryBigramPool &pool) {
-                                   return pool.bigramFreq(prev, cur) > 0;
-                               });
+    return std::ranges::any_of(
+        d->pools_, [&prev, &cur](const HistoryBigramPool &pool) {
+            return pool.bigramFreq({prev, ""}, {cur, ""}) > 0;
+        });
+}
+
+float HistoryBigram::unigramFrequency(WordWithCodeView word) const {
+    FCITX_D();
+    return d->unigramFreq(word);
+}
+
+float HistoryBigram::bigramFrequency(WordWithCodeView prev,
+                                     WordWithCodeView cur) const {
+    FCITX_D();
+    return d->bigramFreq(prev, cur);
+}
+
+int32_t HistoryBigram::rawUnigramFrequency(WordWithCodeView word) const {
+    FCITX_D();
+    int32_t freq = 0;
+    for (const auto &pool : d->pools_) {
+        freq += pool.unigramFreq(word);
+    }
+    return freq;
+}
+
+int32_t HistoryBigram::rawBigramFrequency(WordWithCodeView prev,
+                                          WordWithCodeView cur) const {
+    FCITX_D();
+    int32_t freq = 0;
+    for (const auto &pool : d->pools_) {
+        freq += pool.bigramFreq(prev, cur);
+    }
+    return freq;
+}
+
+float HistoryBigram::score(const WordNode *prev, const WordNode *cur) const {
+    return scoreWithCode(prev, cur, nullptr);
+}
+
+float HistoryBigram::scoreWithCode(
+    const WordNode *prev, const WordNode *cur,
+    const ValidationCodeExtractor &extractor) const {
+    return scoreWithCode(
+        {prev ? prev->word() : "", extractor && prev ? extractor(prev) : ""},
+        {cur ? cur->word() : "", extractor && cur ?
extractor(cur) : ""}); } } // namespace libime diff --git a/src/libime/core/historybigram.h b/src/libime/core/historybigram.h index 4c29bdf..82f22ec 100644 --- a/src/libime/core/historybigram.h +++ b/src/libime/core/historybigram.h @@ -7,12 +7,15 @@ #define _FCITX_LIBIME_CORE_HISTORYBIGRAM_H_ #include +#include +#include #include #include #include #include #include #include +#include #include #include #include @@ -22,8 +25,13 @@ namespace libime { class HistoryBigramPrivate; +using ValidationCodeExtractor = std::function; + class LIBIMECORE_EXPORT HistoryBigram { public: + using WordWithCode = std::pair; + using WordWithCodeView = std::pair; + HistoryBigram(); FCITX_DECLARE_VIRTUAL_DTOR_MOVE(HistoryBigram); @@ -43,14 +51,20 @@ class LIBIMECORE_EXPORT HistoryBigram { bool useOnlyUnigram() const; void forget(std::string_view word); + void forget(std::string_view word, std::string_view code); bool isUnknown(std::string_view v) const; - float score(const WordNode *prev, const WordNode *cur) const { - return score(prev ? prev->word() : "", cur ? cur->word() : ""); - } + float score(const WordNode *prev, const WordNode *cur) const; float score(std::string_view prev, std::string_view cur) const; + float scoreWithCode(WordWithCodeView prev, WordWithCodeView cur) const; + float scoreWithCode(const WordNode *prev, const WordNode *cur, + const ValidationCodeExtractor &extractor) const; void add(const SentenceResult &sentence); void add(const std::vector &sentence); + void addWithCode(const SentenceResult &sentence, + const ValidationCodeExtractor &validationCodeExtractor); + void + addWithCode(const std::vector &sentenceWithValidationCode); /// Fill the prediction based on current sentence. void fillPredict(std::unordered_set &words, @@ -59,6 +73,35 @@ class LIBIMECORE_EXPORT HistoryBigram { bool containsBigram(std::string_view prev, std::string_view cur) const; + /** + * Query the weighted frequency of the unigram. + * + * @since 1.1.14 + */ + float unigramFrequency(WordWithCodeView word) const; + + /** + * Query the weighted frequency of the bigram. + * + * @since 1.1.14 + */ + float bigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const; + + /** + * Query the raw frequency of the unigram. + * + * @since 1.1.14 + */ + int32_t rawUnigramFrequency(WordWithCodeView word) const; + + /** + * Query the raw frequency of the bigram. 
+ * + * @since 1.1.14 + */ + int32_t rawBigramFrequency(WordWithCodeView prev, + WordWithCodeView cur) const; + private: std::unique_ptr d_ptr; FCITX_DECLARE_PRIVATE(HistoryBigram); diff --git a/src/libime/core/userlanguagemodel.cpp b/src/libime/core/userlanguagemodel.cpp index a25814f..64fd30c 100644 --- a/src/libime/core/userlanguagemodel.cpp +++ b/src/libime/core/userlanguagemodel.cpp @@ -32,6 +32,7 @@ class UserLanguageModelPrivate { bool useOnlyUnigram_ = false; HistoryBigram history_; + ValidationCodeExtractor extractor_; float weight_ = DEFAULT_USER_LANGUAGE_MODEL_USER_WEIGHT; // log(wa * exp(a) + wb * exp(b)) // log(exp(log(wa) + a) + exp(b + log(wb)) @@ -128,7 +129,12 @@ float UserLanguageModel::score(const State &state, const WordNode &word, score = LanguageModel::score(state, word, out); } const auto *prev = d->wordFromState(state); - float userScore = d->history_.score(prev, &word); + float userScore; + if (d->extractor_) { + userScore = d->history_.scoreWithCode(prev, &word, d->extractor_); + } else { + userScore = d->history_.score(prev, &word); + } d->setWordToState(out, &word); return std::max(score, sum_log_prob(score + d->wa_, userScore + d->wb_)); } @@ -170,4 +176,9 @@ bool UserLanguageModel::containsNonUnigram( return LanguageModel::maxNgramLength(words) > 1; } +void UserLanguageModel::setCodeExtractor(ValidationCodeExtractor extractor) { + FCITX_D(); + d->extractor_ = std::move(extractor); +} + } // namespace libime diff --git a/src/libime/core/userlanguagemodel.h b/src/libime/core/userlanguagemodel.h index 1c767b3..9681377 100644 --- a/src/libime/core/userlanguagemodel.h +++ b/src/libime/core/userlanguagemodel.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -48,6 +49,8 @@ class LIBIMECORE_EXPORT UserLanguageModel : public LanguageModel { bool containsNonUnigram(const std::vector &words) const; + void setCodeExtractor(ValidationCodeExtractor extractor); + private: std::unique_ptr d_ptr; FCITX_DECLARE_PRIVATE(UserLanguageModel); diff --git a/src/libime/pinyin/pinyincontext.cpp b/src/libime/pinyin/pinyincontext.cpp index 0fe8859..8e7a7d8 100644 --- a/src/libime/pinyin/pinyincontext.cpp +++ b/src/libime/pinyin/pinyincontext.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,7 @@ #include "libime/core/userlanguagemodel.h" #include "libime/pinyin/constants.h" #include "pinyindecoder.h" +#include "pinyindecoder_p.h" #include "pinyinencoder.h" #include "pinyinime.h" #include "pinyinmatchstate.h" @@ -41,9 +43,9 @@ namespace libime { enum class LearnWordResult { - Normal, - Custom, - Ignored, + Normal, /// word is consisted all from regular word from dict. + Custom, /// word is consisted with custom word (e.g. symbol replacement). + Ignored, /// not learned as word. 
}; enum class SelectedPinyinType { @@ -53,13 +55,13 @@ enum class SelectedPinyinType { }; struct SelectedPinyin { - SelectedPinyin(size_t s, WordNode word, std::string encodedPinyin, - SelectedPinyinType type) - : offset_(s), word_(std::move(word)), - encodedPinyin_(std::move(encodedPinyin)), type_(type) {} + SelectedPinyin(size_t s, PinyinWordNode word, SelectedPinyinType type) + : offset_(s), word_(std::move(word)), type_(type) {} + + const std::string &encodedPinyin() const { return word_.encodedPinyin(); } + size_t offset_; - WordNode word_; - std::string encodedPinyin_; + PinyinWordNode word_; SelectedPinyinType type_; }; @@ -82,7 +84,7 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { mutable std::vector candidatesToCursor_; mutable std::unordered_set candidatesToCursorSet_; std::vector conn_; - std::list contextWords_; + std::list contextWords_; size_t alignCursorToNextSegment() const { FCITX_Q(); @@ -179,7 +181,7 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { if (!remain.empty()) { if (std::all_of(remain.begin(), remain.end(), [](char c) { return c == '\''; })) { - selection.emplace_back(q->size(), WordNode("", 0), "", + selection.emplace_back(q->size(), PinyinWordNode({}, 0), SelectedPinyinType::Separator); } } @@ -190,16 +192,18 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { void select(const SentenceResult &sentence) { FCITX_Q(); auto offset = q->selectedLength(); - selectHelper( - [offset, &sentence, this](std::vector &selection) { - for (const auto &p : sentence.sentence()) { - selection.emplace_back( - offset + p->to()->index(), - WordNode{p->word(), ime_->model()->index(p->word())}, - p->as().encodedPinyin(), - SelectedPinyinType::Normal); - } - }); + selectHelper([offset, &sentence, + this](std::vector &selection) { + for (const auto &p : sentence.sentence()) { + selection.emplace_back( + offset + p->to()->index(), + PinyinWordNode{ + {p->word(), p->as().encodedPinyin()}, + ime_->model()->index(p->word())}, + + SelectedPinyinType::Normal); + } + }); } void selectCustom(size_t inputLength, std::string_view segment, @@ -210,20 +214,21 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { &encodedPinyin](std::vector &selection) { auto index = ime_->model()->index(segment); selection.emplace_back( - offset + inputLength, WordNode{segment, index}, - std::string(encodedPinyin), SelectedPinyinType::Custom); + offset + inputLength, + PinyinWordNode{{segment, encodedPinyin}, index}, + SelectedPinyinType::Custom); }); } - LearnWordResult learnWord() { + std::tuple learnWord() { std::string ss; std::string pinyin; if (selected_.empty()) { - return LearnWordResult::Ignored; + return {LearnWordResult::Ignored, ""}; } // don't learn existing word. if (selected_.size() == 1 && selected_[0].size() == 1) { - return LearnWordResult::Ignored; + return {LearnWordResult::Ignored, ""}; } // Validate the learning word. // All single || custom || length <= 4 @@ -235,7 +240,7 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { isAllSingleWord && (s.empty() || (s.size() == 1 && (s[0].type_ == SelectedPinyinType::Separator || - s[0].encodedPinyin_.size() == 2))); + s[0].encodedPinyin().size() == 2))); for (auto &item : s) { if (item.type_ == SelectedPinyinType::Separator) { continue; @@ -244,21 +249,21 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { hasCustom = true; } // We can't learn non pinyin word. 
- if (item.encodedPinyin_.empty() || - item.encodedPinyin_.size() % 2 != 0) { - return LearnWordResult::Ignored; + if (item.encodedPinyin().empty() || + item.encodedPinyin().size() % 2 != 0) { + return {LearnWordResult::Ignored, ""}; } - totalPinyinLength += item.encodedPinyin_.size() / 2; + totalPinyinLength += item.encodedPinyin().size() / 2; } } FCITX_Q(); if (!hasCustom) { if ((!isAllSingleWord && totalPinyinLength > 4)) { - return LearnWordResult::Ignored; + return {LearnWordResult::Ignored, ""}; } if (ime_->model()->containsNonUnigram(q->selectedWords())) { - return LearnWordResult::Ignored; + return {LearnWordResult::Ignored, ""}; } } @@ -267,25 +272,26 @@ class PinyinContextPrivate : public fcitx::QPtrHolder { if (item.type_ == SelectedPinyinType::Separator) { continue; } - assert(!item.encodedPinyin_.empty()); - assert(item.encodedPinyin_.size() % 2 == 0); + assert(!item.encodedPinyin().empty()); + assert(item.encodedPinyin().size() % 2 == 0); ss += item.word_.word(); if (!pinyin.empty()) { pinyin.push_back('\''); } - pinyin += PinyinEncoder::decodeFullPinyin(item.encodedPinyin_); + pinyin += PinyinEncoder::decodeFullPinyin(item.encodedPinyin()); } } if (auto opt = ime_->dict()->lookupWord(PinyinDictionary::UserDict, pinyin, ss)) { - return LearnWordResult::Normal; + return {LearnWordResult::Ignored, ""}; } ime_->dict()->addWord(PinyinDictionary::UserDict, pinyin, ss, hasCustom ? -1 : 0); - return hasCustom ? LearnWordResult::Custom : LearnWordResult::Normal; + return {hasCustom ? LearnWordResult::Custom : LearnWordResult::Normal, + pinyin}; } }; @@ -921,7 +927,7 @@ PinyinContext::selectedWordsWithPinyin() const { for (const auto &item : s) { if (item.type_ != SelectedPinyinType::Separator) { newSentence.emplace_back(item.word_.word(), - item.encodedPinyin_); + item.encodedPinyin()); } } } @@ -933,11 +939,11 @@ std::string PinyinContext::selectedFullPinyin() const { std::string pinyin; for (const auto &s : d->selected_) { for (const auto &item : s) { - if (!item.encodedPinyin_.empty()) { + if (!item.encodedPinyin().empty()) { if (!pinyin.empty()) { pinyin.push_back('\''); } - pinyin += PinyinEncoder::decodeFullPinyin(item.encodedPinyin_); + pinyin += PinyinEncoder::decodeFullPinyin(item.encodedPinyin()); } } } @@ -970,26 +976,30 @@ void PinyinContext::learn() { return; } - if (auto result = d->learnWord(); result != LearnWordResult::Ignored) { + if (auto [result, encodedWordPinyin] = d->learnWord(); + result != LearnWordResult::Ignored) { // Do not insert custom to history for the first time. if (result == LearnWordResult::Normal) { - std::vector newSentence{sentence()}; - d->ime_->model()->history().add(newSentence); + // Create new sentence with the whole new learned word. + std::vector newSentence{ + {sentence(), encodedWordPinyin}}; + d->ime_->model()->history().addWithCode(newSentence); } } else { - std::vector newSentence; + std::vector newSentence; for (auto &s : d->selected_) { for (auto &item : s) { if (item.type_ != SelectedPinyinType::Separator) { // Non pinyin word. Skip it. 
- if (item.encodedPinyin_.empty()) { + if (item.encodedPinyin().empty()) { return; } - newSentence.push_back(item.word_.word()); + newSentence.push_back( + {item.word_.word(), item.encodedPinyin()}); } } } - d->ime_->model()->history().add(newSentence); + d->ime_->model()->history().addWithCode(newSentence); } } @@ -1014,7 +1024,7 @@ void PinyinContext::appendContextWords( for (const auto &word : std::span{contextWords}.last(std::min(contextWords.size(), needed))) { d->contextWords_.push_back( - WordNode(word, d->ime_->model()->index(word))); + PinyinWordNode({word, ""}, d->ime_->model()->index(word))); } while (d->contextWords_.size() > needed) { d->contextWords_.pop_front(); @@ -1031,6 +1041,40 @@ std::vector PinyinContext::contextWords() const { return words; } +void PinyinContext::setContextWordsWithPinyin( + const std::vector &contextWordsWithPinyin) { + FCITX_D(); + d->contextWords_.clear(); + appendContextWordsWithPinyin(contextWordsWithPinyin); +} + +void PinyinContext::appendContextWordsWithPinyin( + const std::vector &contextWordsWithPinyin) { + FCITX_D(); + + size_t needed = LanguageModel::maxOrder() - 1; + + for (const auto &word : std::span{contextWordsWithPinyin}.last( + std::min(contextWordsWithPinyin.size(), needed))) { + d->contextWords_.push_back( + PinyinWordNode(word, d->ime_->model()->index(word.first))); + } + while (d->contextWords_.size() > needed) { + d->contextWords_.pop_front(); + } +} + +std::vector +PinyinContext::contextWordsWithPinyin() const { + FCITX_D(); + std::vector words; + words.reserve(d->contextWords_.size()); + for (const auto &word : d->contextWords_) { + words.push_back({word.word(), word.encodedPinyin()}); + } + return words; +} + bool PinyinContext::learnWord() { return false; } PinyinIME *PinyinContext::ime() const { diff --git a/src/libime/pinyin/pinyincontext.h b/src/libime/pinyin/pinyincontext.h index c877ac2..d16c94e 100644 --- a/src/libime/pinyin/pinyincontext.h +++ b/src/libime/pinyin/pinyincontext.h @@ -19,6 +19,7 @@ #include #include #include +#include "libime/core/historybigram.h" namespace libime { class PinyinIME; @@ -109,8 +110,7 @@ class LIBIMEPINYIN_EXPORT PinyinContext : public InputBuffer { std::vector selectedWords() const; /// Selected hanzi with encoded pinyin - std::vector> - selectedWordsWithPinyin() const; + std::vector selectedWordsWithPinyin() const; /// Get the full pinyin string of the selected part. std::string selectedFullPinyin() const; @@ -164,6 +164,29 @@ class LIBIMEPINYIN_EXPORT PinyinContext : public InputBuffer { */ std::vector contextWords() const; + /** + * Set context words with pinyin for better prediction. + * @param contextWordsWithPinyin The context words with encoded pinyin. + * @since 1.1.14 + */ + void setContextWordsWithPinyin( + const std::vector &contextWordsWithPinyin); + + /** + * Append context words with pinyin for better prediction. + * @param contextWordsWithPinyin The context words with pinyin. + * @since 1.1.14 + */ + void appendContextWordsWithPinyin( + const std::vector &contextWordsWithPinyin); + + /** + * Get context words with pinyin for better prediction. 
+ * @return current context words with pinyin + * @since 1.1.14 + */ + std::vector contextWordsWithPinyin() const; + protected: bool typeImpl(const char *s, size_t length) override; diff --git a/src/libime/pinyin/pinyindecoder_p.h b/src/libime/pinyin/pinyindecoder_p.h index 42bb621..5529967 100644 --- a/src/libime/pinyin/pinyindecoder_p.h +++ b/src/libime/pinyin/pinyindecoder_p.h @@ -8,6 +8,8 @@ #include #include +#include +#include #include namespace libime { @@ -20,6 +22,17 @@ class PinyinLatticeNodePrivate : public LatticeNodeData { std::string encodedPinyin_; bool isCorrection_ = false; }; + +class PinyinWordNode : public WordNode { +public: + PinyinWordNode(const HistoryBigram::WordWithCodeView &word, WordIndex idx) + : WordNode(word.first, idx), encodedPinyin_(word.second) {} + const std::string &encodedPinyin() const { return encodedPinyin_; } + +private: + std::string encodedPinyin_; +}; + } // namespace libime #endif // _FCITX_LIBIME_PINYIN_PINYINDECODER_P_H_ diff --git a/src/libime/pinyin/pinyinime.cpp b/src/libime/pinyin/pinyinime.cpp index d169c79..7058258 100644 --- a/src/libime/pinyin/pinyinime.cpp +++ b/src/libime/pinyin/pinyinime.cpp @@ -7,14 +7,17 @@ #include #include #include +#include #include #include #include #include "libime/core/decoder.h" +#include "libime/core/lattice.h" #include "libime/core/userlanguagemodel.h" #include "libime/pinyin/pinyincorrectionprofile.h" +#include "libime/pinyin/pinyindecoder.h" +#include "libime/pinyin/pinyindecoder_p.h" #include "libime/pinyin/pinyinencoder.h" -#include "pinyindecoder.h" namespace libime { @@ -25,6 +28,17 @@ class PinyinIMEPrivate : fcitx::QPtrHolder { : fcitx::QPtrHolder(q), dict_(std::move(dict)), model_(std::move(model)), decoder_(std::make_unique(dict_.get(), model_.get())) { + model_->setCodeExtractor([](const WordNode *node) -> std::string { + if (const auto *pinyinNode = + dynamic_cast(node)) { + return pinyinNode->encodedPinyin(); + } + if (const auto *wordNode = + dynamic_cast(node)) { + return wordNode->encodedPinyin(); + } + return ""; + }); } FCITX_DEFINE_SIGNAL_PRIVATE(PinyinIME, optionChanged); diff --git a/src/libime/pinyin/pinyinprediction.cpp b/src/libime/pinyin/pinyinprediction.cpp index a1b5f96..c1f4f9e 100644 --- a/src/libime/pinyin/pinyinprediction.cpp +++ b/src/libime/pinyin/pinyinprediction.cpp @@ -17,6 +17,7 @@ #include #include #include +#include "libime/core/historybigram.h" #include "libime/core/languagemodel.h" #include "libime/core/prediction.h" #include "libime/pinyin/pinyindictionary.h" @@ -66,13 +67,13 @@ PinyinPrediction::predict(const State &state, auto result = Prediction::predictWithScore(state, sentence, maxSize); std::vector> intermedidateResult; - std::transform( - result.begin(), result.end(), std::back_inserter(intermedidateResult), - [](std::pair &value) { - return std::make_tuple(std::move(value.first), value.second, + std::ranges::transform(result, std::back_inserter(intermedidateResult), + [](std::pair &value) { + return std::make_tuple( + std::move(value.first), value.second, PinyinPredictionSource::Model); - }); - std::make_heap(intermedidateResult.begin(), intermedidateResult.end(), cmp); + }); + std::ranges::make_heap(intermedidateResult, cmp); State prevState = model()->nullState(); State outState; @@ -131,17 +132,33 @@ PinyinPrediction::predict(const State &state, return true; }); - std::sort_heap(intermedidateResult.begin(), intermedidateResult.end(), cmp); - std::transform(intermedidateResult.begin(), intermedidateResult.end(), - std::back_inserter(finalResult), 
[](auto &value) { - return std::make_pair( - std::move(std::get(value)), - std::get(value)); - }); + std::ranges::sort_heap(intermedidateResult, cmp); + std::ranges::transform( + intermedidateResult, std::back_inserter(finalResult), [](auto &value) { + return std::make_pair(std::move(std::get(value)), + std::get(value)); + }); return finalResult; } +std::vector> +PinyinPrediction::predict( + const State &state, + const std::vector &sentence, + size_t maxSize) { + std::vector words; + words.reserve(sentence.size()); + for (const auto &[word, code] : sentence) { + words.push_back(word); + } + std::string_view lastPinyin; + if (!sentence.empty()) { + lastPinyin = sentence.back().second; + } + return predict(state, words, lastPinyin, maxSize); +} + std::vector PinyinPrediction::predict(const std::vector &sentence, size_t maxSize) { diff --git a/src/libime/pinyin/pinyinprediction.h b/src/libime/pinyin/pinyinprediction.h index 9b702cb..498d991 100644 --- a/src/libime/pinyin/pinyinprediction.h +++ b/src/libime/pinyin/pinyinprediction.h @@ -17,6 +17,7 @@ #include #include #include +#include "libime/core/historybigram.h" namespace libime { @@ -45,6 +46,16 @@ class LIBIMEPINYIN_EXPORT PinyinPrediction : public Prediction { predict(const State &state, const std::vector &sentence, std::string_view lastEncodedPinyin, size_t maxSize = 0); + /** + * Predict from model and pinyin dictionary for the last sentnce being type. + * + * @since 1.1.14 + */ + std::vector> + predict(const State &state, + const std::vector &sentence, + size_t maxSize = 0); + /** * Same as the Prediction::predict with the same signature. */ diff --git a/src/libime/table/tablecontext.cpp b/src/libime/table/tablecontext.cpp index 1f61f74..0a113ab 100644 --- a/src/libime/table/tablecontext.cpp +++ b/src/libime/table/tablecontext.cpp @@ -181,6 +181,13 @@ class TableContextPrivate : public fcitx::QPtrHolder { : QPtrHolder(q), dict_(dict), model_(model), decoder_(&dict, &model) { // Maybe use a better heuristics? candidates_.reserve(2048); + model_.setCodeExtractor([](const WordNode *word) -> std::string { + if (const auto *node = + dynamic_cast(word)) { + return node->code(); + } + return ""; + }); } // sort should already happened at this point. 
@@ -745,22 +752,34 @@ void TableContext::learn() { return; } } - std::vector newSentence; + std::vector newSentence; for (auto &s : d->selected_) { + if (s.empty()) { + continue; + } + if (std::ranges::any_of( + s, [](const auto &item) { return !item.commit_; })) { + continue; + } std::string word; - for (auto &item : s) { - if (!item.commit_) { - word.clear(); - break; + std::string code; + if (s.size() == 1) { + word = s[0].word_.word(); + code = s[0].code_; + } else { + for (auto &item : s) { + word += item.word_.word(); + } + if (!d->dict_.generate(word, code)) { + return; } - word += item.word_.word(); } if (!word.empty()) { - newSentence.emplace_back(std::move(word)); + newSentence.emplace_back(std::move(word), std::move(code)); } } if (!newSentence.empty()) { - d->model_.history().add(newSentence); + d->model_.history().addWithCode(newSentence); } } @@ -774,20 +793,30 @@ void TableContext::learnLast() { return; } - std::vector newSentence; + std::vector newSentence; + const auto &s = d->selected_.back(); + if (std::ranges::any_of(s, + [](const auto &item) { return !item.commit_; })) { + return; + } std::string word; - for (auto &item : d->selected_.back()) { - if (!item.commit_) { - word.clear(); + std::string code; + if (s.size() == 1) { + word = s[0].word_.word(); + code = s[0].code_; + } else { + for (const auto &item : s) { + word += item.word_.word(); + } + if (!d->dict_.generate(word, code)) { return; } - word += item.word_.word(); } if (!word.empty()) { - newSentence.emplace_back(std::move(word)); + newSentence.emplace_back(std::move(word), std::move(code)); } if (!newSentence.empty()) { - d->model_.history().add(newSentence); + d->model_.history().addWithCode(newSentence); } } diff --git a/test/testhistorybigram.cpp b/test/testhistorybigram.cpp index 199ace6..8a18912 100644 --- a/test/testhistorybigram.cpp +++ b/test/testhistorybigram.cpp @@ -14,6 +14,8 @@ #include #include "libime/core/historybigram.h" +namespace { + void testBasic() { using namespace libime; HistoryBigram history; @@ -209,11 +211,70 @@ void testSaveAndLoadText() { FCITX_ASSERT(dump1.str() == dump2.str()); } +void testWithCode() { + using namespace libime; + HistoryBigram history; + history.addWithCode({{"你", "code1"}, + {"是", "code2"}, + {"一个", "code3"}, + {"好人", "code4"}}); + + auto score = history.scoreWithCode({"你", "code1"}, {"是", "code2"}); + auto scoreWithoutCode = history.score("你", "是"); + auto scoreWithEmptyCode = history.scoreWithCode({"你", ""}, {"是", ""}); + auto scoreWithMismatchCode = + history.scoreWithCode({"你", "code1"}, {"是", "code5"}); + FCITX_ASSERT(score == scoreWithoutCode) << score << " " << scoreWithoutCode; + FCITX_ASSERT(score == scoreWithEmptyCode) + << score << " " << scoreWithEmptyCode; + FCITX_ASSERT(score > scoreWithMismatchCode) + << score << " " << scoreWithMismatchCode; + FCITX_ASSERT(history.rawUnigramFrequency({"你", ""}) == 1); + FCITX_ASSERT(history.rawUnigramFrequency({"你", "code1"}) == 1); + FCITX_ASSERT(history.rawUnigramFrequency({"你", "code2"}) == 0); + FCITX_ASSERT(history.rawBigramFrequency({"你", ""}, {"是", ""}) == 1); + FCITX_ASSERT(history.rawBigramFrequency({"你", "code1"}, {"是", "code2"}) == + 1); + FCITX_ASSERT(history.rawBigramFrequency({"你", "code2"}, {"是", "code2"}) == + 0); + FCITX_ASSERT(history.rawBigramFrequency({"你", ""}, {"是", "code2"}) == 1); + FCITX_ASSERT(history.rawBigramFrequency({"你", "code1"}, {"是", ""}) == 1); +} + +void testWithCodePredict() { + using namespace libime; + HistoryBigram history; + history.addWithCode({{"你", "code1"}, + {"是", 
"code2"}, + {"一个", "code3"}, + {"好人", "code4"}}); + + { + std::unordered_set result; + history.fillPredict(result, {"你"}, 10); + FCITX_ASSERT(result == std::unordered_set{"是"}); + } + + { + std::unordered_set result; + history.addWithCode({{"你", "code1"}, {"是", "code5"}}); + history.addWithCode({{"你", "code1"}, {"是", "code6"}}); + history.addWithCode({{"你", "code1"}, {"是", "code7"}}); + history.addWithCode({{"你", "code1"}, {"是", "code8"}}); + history.fillPredict(result, {"你"}, 0); + FCITX_ASSERT(result == std::unordered_set{"是"}) << result; + } +} + +} // namespace + int main() { testBasic(); testOverflow(); testPredict(); testSaveAndLoad(); testSaveAndLoadText(); + testWithCode(); + testWithCodePredict(); return 0; } diff --git a/test/testpinyinime_unit.cpp b/test/testpinyinime_unit.cpp index 473d5a4..040e027 100644 --- a/test/testpinyinime_unit.cpp +++ b/test/testpinyinime_unit.cpp @@ -5,10 +5,13 @@ */ #include +#include #include #include +#include #include #include "libime/core/historybigram.h" +#include "libime/core/lattice.h" #include "libime/core/userlanguagemodel.h" #include "libime/pinyin/pinyincontext.h" #include "libime/pinyin/pinyincorrectionprofile.h" @@ -20,20 +23,20 @@ using namespace libime; -int main() { - fcitx::Log::setLogRule("libime=5"); - libime::PinyinIME ime( - std::make_unique(), - std::make_unique(LIBIME_BINARY_DIR "/data/sc.lm")); - ime.setNBest(2); - ime.dict()->load(PinyinDictionary::SystemDict, - LIBIME_BINARY_DIR "/data/sc.dict", - PinyinDictFormat::Binary); - PinyinFuzzyFlags flags = PinyinFuzzyFlag::Inner; - ime.setFuzzyFlags(flags); - ime.setScoreFilter(1.0F); - ime.setShuangpinProfile( - std::make_shared(ShuangpinBuiltinProfile::Xiaohe)); +namespace { + +size_t candidateIndex(PinyinContext &c, const std::string &candidate) { + auto iter = + std::ranges::find(c.candidates(), candidate, &SentenceResult::toString); + FCITX_ASSERT(iter != c.candidates().end()); + return std::distance(c.candidates().begin(), iter); +} + +void selectCandidate(PinyinContext &c, const std::string &candidate) { + c.select(candidateIndex(c, candidate)); +} + +void testPinyin(PinyinIME &ime) { PinyinContext c(&ime); c.type("nihaozhongguo"); @@ -48,9 +51,8 @@ int main() { FCITX_ASSERT(!c.candidatesToCursorSet().count("你好中国")); FCITX_ASSERT(c.candidatesToCursorSet().count("你好")); c.setCursor(0); - auto iter = std::find_if( - c.candidates().begin(), c.candidates().end(), - [](const auto &cand) { return cand.toString() == "你好中国"; }); + auto iter = std::ranges::find(c.candidates(), "你好中国", + &SentenceResult::toString); FCITX_ASSERT(iter != c.candidates().end()); FCITX_ASSERT(!ime.dict()->lookupWord(PinyinDictionary::UserDict, "ni'hao'zhong'guo", "你好中国")); @@ -58,6 +60,10 @@ int main() { c.learn(); FCITX_ASSERT(ime.model()->history().containsBigram("你", "好")); FCITX_ASSERT(ime.model()->history().containsBigram("好", "中国")); +} + +void testShuangpin(PinyinIME &ime) { + PinyinContext c(&ime); c.setUseShuangpin(true); @@ -68,19 +74,58 @@ int main() { c.type("bkqiln"); FCITX_ASSERT(c.candidates().size() == c.candidateSet().size()); - FCITX_ASSERT(!c.candidateSet().count("冰淇淋")); + FCITX_ASSERT(!c.candidateSet().contains("冰淇淋")); c.clear(); ime.setCorrectionProfile(std::make_shared( BuiltinPinyinCorrectionProfile::Qwerty)); ime.setShuangpinProfile(std::make_shared( ShuangpinBuiltinProfile::Xiaohe, ime.correctionProfile().get())); - ime.setFuzzyFlags(flags | PinyinFuzzyFlag::Correction); + ime.setFuzzyFlags({PinyinFuzzyFlag::Inner, PinyinFuzzyFlag::Correction}); c.type("bkqiln"); 
FCITX_ASSERT(c.candidates().size() == c.candidateSet().size()); - FCITX_ASSERT(c.candidateSet().count("冰淇淋")); + FCITX_ASSERT(c.candidateSet().contains("冰淇淋")); c.clear(); +} + +void testHistory(PinyinIME &ime) { + PinyinContext c(&ime); + c.type("kuai"); + FCITX_ASSERT(c.candidateSet().contains("会")); + auto kuaiIndex = candidateIndex(c, "会"); + c.clear(); + c.type("hui"); + FCITX_ASSERT(c.candidateSet().contains("会")); + selectCandidate(c, "会"); + FCITX_ASSERT(c.selected()); + c.learn(); + c.clear(); + c.type("kuai"); + auto kuaiIndexNew = candidateIndex(c, "会"); + FCITX_ASSERT(kuaiIndexNew == kuaiIndex); +} + +} // namespace + +int main() { + fcitx::Log::setLogRule("libime=5"); + libime::PinyinIME ime( + std::make_unique(), + std::make_unique(LIBIME_BINARY_DIR "/data/sc.lm")); + ime.setNBest(2); + ime.dict()->load(PinyinDictionary::SystemDict, + LIBIME_BINARY_DIR "/data/sc.dict", + PinyinDictFormat::Binary); + PinyinFuzzyFlags flags = PinyinFuzzyFlag::Inner; + ime.setFuzzyFlags(flags); + ime.setScoreFilter(1.0F); + ime.setShuangpinProfile( + std::make_shared(ShuangpinBuiltinProfile::Xiaohe)); + + testPinyin(ime); + testShuangpin(ime); + testHistory(ime); return 0; } diff --git a/test/testtableime.cpp b/test/testtableime.cpp index 113fab7..c22975a 100644 --- a/test/testtableime.cpp +++ b/test/testtableime.cpp @@ -86,6 +86,8 @@ int main() { c.autoSelect(); c.learn(); c.clear(); + } else if (word == "history") { + model.history().dump(std::cout); } size_t count = 1; diff --git a/test/testtableime_unit.cpp b/test/testtableime_unit.cpp index ee2d2c0..1e5ac7a 100644 --- a/test/testtableime_unit.cpp +++ b/test/testtableime_unit.cpp @@ -4,10 +4,14 @@ * SPDX-License-Identifier: LGPL-2.1-or-later */ +#include +#include +#include #include #include #include #include "libime/core/languagemodel.h" +#include "libime/core/lattice.h" #include "libime/core/userlanguagemodel.h" #include "libime/table/tablebaseddictionary.h" #include "libime/table/tablecontext.h" @@ -16,6 +20,8 @@ using namespace libime; +namespace { + class TestLmResolver : public LanguageModelResolver { public: TestLmResolver(std::string_view path) : path_(path) {} @@ -33,7 +39,22 @@ class TestLmResolver : public LanguageModelResolver { std::string path_; }; -int main() { +size_t candidateIndex(TableContext &c, const std::string &candidate) { + auto candidates = c.candidates(); + auto iter = + std::ranges::find(candidates, candidate, &SentenceResult::toString); + std::ranges::for_each(candidates, [](const auto &candidate) { + FCITX_INFO() << candidate.toString() << " " << candidate.score(); + }); + FCITX_ASSERT(iter != candidates.end()); + return std::distance(candidates.begin(), iter); +} + +void selectCandidate(TableContext &c, const std::string &candidate) { + c.select(candidateIndex(c, candidate)); +} + +void testBasic() { fcitx::Log::setLogRule("*=5"); TestLmResolver lmresolver(LIBIME_BINARY_DIR "/data/sc.lm"); auto lm = lmresolver.languageModelFileForLanguage("zh_CN"); @@ -103,6 +124,54 @@ int main() { c.clear(); FCITX_INFO() << "========================"; } +} + +void testHistory() { + fcitx::Log::setLogRule("*=5"); + TestLmResolver lmresolver(LIBIME_BINARY_DIR "/data/sc.lm"); + auto lm = lmresolver.languageModelFileForLanguage("zh_CN"); + TableBasedDictionary dict; + UserLanguageModel model(lm); + dict.load(LIBIME_BINARY_DIR "/data/wbx.main.dict"); + TableOptions options; + options.setLanguageCode("zh_CN"); + options.setLearning(true); + options.setAutoPhraseLength(-1); + options.setAutoSelect(true); + 
    options.setAutoSelectLength(-1);
+    options.setNoMatchAutoSelectLength(-1);
+    options.setNoSortInputLength(0);
+    options.setAutoRuleSet({});
+    options.setMatchingKey('z');
+    options.setOrderPolicy(OrderPolicy::Freq);
+    dict.setTableOptions(options);
+    TableContext c(dict, model);
+    c.type("a");
+    auto index = candidateIndex(c, "其");
+    c.clear();
+
+    c.type("adw");
+    selectCandidate(c, "其");
+    c.learn();
+    c.clear();
+
+    c.type("a");
+    auto index2 = candidateIndex(c, "其");
+    FCITX_ASSERT(index == index2);
+    c.select(index2);
+    c.learn();
+    c.clear();
+
+    c.type("a");
+    auto index3 = candidateIndex(c, "其");
+    FCITX_ASSERT(index3 < index2);
+}
+
+} // namespace
+
+int main() {
+    testBasic();
+    testHistory();
     return 0;
 }