codebreakparty/markov.hpp

176 lines
4.6 KiB
C++

#pragma once
#include <limits>
#include <map>
#include <random>
#include <stdlib.h>
#include <vector>
/* Markov chain class.
T: Element type
ORDER: Prefix length. A Markov chain of order n uses the n preceding elements
to give the probability of the next element.
*/
template <typename T, size_t ORDER> class Markov {
public:
  // Every level of the chain (ORDER, ORDER-1, ..., 0, leaf) calls the
  // protected helpers of the level below it.
  template <typename, size_t> friend class Markov;
  /* Node type one level deeper. For ORDER == 0, ORDER - 1 wraps around to
     std::numeric_limits<size_t>::max(), which selects the counting leaf
     specialization at the end of this file. */
  using Child = Markov<T, ORDER - 1>;
  using Children = std::map<T, Child>;
  /* Total number of entries for this node.
     If not a leaf, also the sum of the totals of its children.
  */
  size_t total = 0;
  /* A map of a key to child nodes */
  Children children;

  Markov() = default;

  /* Number of preceding elements used as context. */
  constexpr size_t order() const { return ORDER; }

  /* Add the first order() + 1 elements starting at iter to the model.
     size is the number of elements available from iter; sequences with
     size <= order() are ignored. IT must support *iter and iter + 1.
  */
  template <typename IT> void add(IT iter, size_t size) {
    if (size <= ORDER) {
      return;
    }
    total++;
    find(*iter)->second.add(iter + 1, size - 1);
  }

  /* Remove one occurrence of the first order() + 1 elements starting at iter,
     i.e. undo one add() of the same sequence. Nodes whose count drops to zero
     are erased from their parent.
     Returns true if the sequence was found and removed. Returns false, and
     leaves the model unchanged, if size <= order() or the sequence is not in
     the model.
  */
  template <typename IT> bool dec(IT iter, size_t size) {
    // Same length requirement as add(), so add() and dec() are symmetric.
    if (size <= ORDER) {
      return false;
    }
    auto child = children.find(*iter);
    if (child == children.end()) {
      return false;
    }
    // Counts are only changed once the deeper levels confirmed the path
    // exists, so a failed lookup never corrupts `total`.
    if (!child->second.dec(iter + 1, size - 1)) {
      return false;
    }
    total--;
    if (child->second.total == 0) {
      children.erase(child);
    }
    return true;
  }

  /* Probability of the last item of the sequence [iter, iter + size) given
     the items before it. At most the first order() + 1 items are used: for
     longer sequences this is the probability of item order() + 1 given the
     first order() items. An empty sequence has probability 1.
     If an item is not found at some level, the result is 0.5 divided by the
     count of that level (a half-observation fallback).
     NOTE(review): on an empty model (total == 0) that fallback divides by
     zero and yields +inf.
  */
  template <typename IT> double final_probability(IT iter, size_t size) const {
    if (size == 0) {
      return 1.;
    }
    auto child = children.find(*iter);
    if (child == children.cend()) {
      return 0.5 / static_cast<double>(total);
    }
    const auto &val = child->second;
    // ORDER == 0 is the deepest level holding a distribution over items; the
    // leaf below only counts, so stop here even if more items were passed.
    if (size == 1 || ORDER == 0) {
      return static_cast<double>(val.total) / static_cast<double>(total);
    }
    return val.final_probability(iter + 1, size - 1);
  }

  /* Generate one single item that can appear after the prefix given by iter
     and size, using the RNG g. Only the last order() elements of the prefix
     are used. If the prefix is empty or not found in the model, an item is
     drawn from the distribution of the level where the lookup stopped.
  */
  template <typename IT, typename Generator>
  T gen(IT iter, size_t size, Generator &g) const {
    // Keep only the last order() elements of the prefix.
    while (size > order()) {
      ++iter;
      --size;
    }
    if (size == 0) {
      return random_child(g);
    }
    return gen_impl(iter, size, g, std::bool_constant<(ORDER > 0)>{});
  }

  /* Joint probability of the first min(size, order() + 1) items of the
     sequence starting at iter. An empty sequence has probability 1.
     If an item is not found at some level, that factor is replaced by 0.5
     divided by the count of that level and the remaining items are ignored.
     NOTE(review): on an empty model (total == 0) this divides by zero.
  */
  template <typename IT> double probability(IT iter, size_t size) const {
    if (size == 0) {
      return 1.;
    }
    auto child = children.find(*iter);
    if (child == children.cend()) {
      return 0.5 / static_cast<double>(total);
    }
    const auto &val = child->second;
    return (static_cast<double>(val.total) / static_cast<double>(total)) *
           val.probability(iter + 1, size - 1);
  }

  /* Returns a random key of this node, weighted by the totals of the child
     nodes. Returns T{} if the node is empty.
  */
  template <typename Generator> T random_child(Generator &g) const {
    // total - 1 would wrap around for an empty node; there is nothing to draw.
    if (total == 0) {
      return T{};
    }
    std::uniform_int_distribution<size_t> dist(0, total - 1);
    auto val = dist(g);
    for (const auto &[key, node] : children) {
      if (val < node.total) {
        return key;
      }
      val -= node.total;
    }
    return T{};
  }

protected:
  /* Finds the key or adds it (with a default-constructed, empty child). */
  typename Children::iterator find(const T &key) {
    return children.try_emplace(key).first;
  }
  /* Generate an item for Markov chains of order 0: the prefix is irrelevant. */
  template <typename IT, typename Generator>
  T gen_impl(IT, size_t, Generator &g, std::bool_constant<false>) const {
    return random_child(g);
  }
  /* Generate an item for longer chains, descending along the prefix. Falls
     back to this level's distribution if the next prefix element is unknown.
  */
  template <typename IT, typename Generator>
  T gen_impl(IT iter, size_t size, Generator &g,
             std::bool_constant<true>) const {
    auto child = children.find(*iter);
    if (child == children.end()) {
      return random_child(g);
    }
    return child->second.gen(iter + 1, size - 1, g);
  }
};

/* Leaf of the recursion: the Child type of Markov<T, 0>, selected because
   ORDER - 1 wraps around to the maximum size_t. It only counts how often the
   sequence ending at this node was added.
*/
template <typename T> class Markov<T, std::numeric_limits<size_t>::max()> {
public:
  size_t total = 0;
  Markov() = default;
  constexpr size_t order() const { return 0; }
  template <typename IT> void add(IT, size_t) { total++; }
  /* Remove one occurrence. Returns false instead of wrapping `total` around
     when there is nothing left to remove. */
  template <typename IT> bool dec(IT, size_t) {
    if (total == 0) {
      return false;
    }
    total--;
    return true;
  }
  template <typename IT> double final_probability(IT, size_t) const {
    return 1.;
  }
  template <typename IT> double probability(IT, size_t) const { return 1.; }
};