ShovelingLife
A Game Programmer
ShovelingLife
전체 방문자
오늘
어제
  • 분류 전체보기 (1118)
    • 개인 프로젝트 (2)
      • Unity (2)
      • Unreal (0)
    • 그래픽스 (57)
      • 공통 (19)
      • 수학 물리 (22)
      • OpenGL & Vulkan (1)
      • DirectX (14)
    • 게임엔진 (191)
      • Unreal (69)
      • Unity (111)
      • Cocos2D-X (3)
      • 개인 플젝 (8)
    • 코딩테스트 (221)
      • 공통 (7)
      • 프로그래머스 (22)
      • 백준 (162)
      • LeetCode (19)
      • HackerRank (2)
      • 코딩테스트 알고리즘 (8)
    • CS (242)
      • 공통 (21)
      • 네트워크 (45)
      • OS & 하드웨어 (60)
      • 자료구조 & 알고리즘 (99)
      • 디자인패턴 (6)
      • UML (4)
      • 데이터베이스 (7)
    • 프로그래밍 언어 (377)
      • C++ (168)
      • C# (97)
      • Java (11)
      • Python (36)
      • SQL (44)
      • JavaScript (9)
      • React (7)
    • 그 외 (15)
      • Math (7)
      • 일상 (5)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

  • Source Code 좌측 상단에 복사 버튼 추가 완료
  • 언리얼 엔진 C++ 빌드시간 단축 꿀팁
  • 게임 업계 코딩테스트 관련
  • 1인칭 시점으로 써내려가는 글들

인기 글

태그

  • 알고리즘
  • 백준
  • 함수
  • 파이썬
  • 포인터
  • SQL
  • C++
  • Unity
  • 배열
  • 그래픽스
  • 클래스
  • C
  • 언리얼
  • 유니티
  • string
  • 문자열
  • python
  • 티스토리챌린지
  • 오블완
  • c#

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
ShovelingLife

A Game Programmer

CS/자료구조 & 알고리즘

[C++] Trie (트라이) 소스코드

2023. 12. 1. 15:16

trie.h

#ifndef TRIE_H
#define TRIE_H

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace trie
{

/* Empty tag type; ValueHolder is specialized on it to store a plain int count. */
struct SetCounter
{
};

namespace detail
{

/**
 * One node of the trie.
 *
 * The key fragment and the stored value come from PrefixHolderT; this class
 * adds the child table: an array of child pointers indexed by the low bits
 * of each child's first key atom (see atom_hash()). put() grows the table
 * until a colliding pair lands in different slots, so find() is one probe.
 *
 * The node owns only the pointer array, never the children themselves.
 * Copying is disabled, because a copy would share the array and both
 * destructors would delete it.
 */
template <typename AtomT, typename PrefixHolderT>
struct TrieNode : public PrefixHolderT
{
private:
    typedef TrieNode<AtomT, PrefixHolderT> self_type;
    typedef self_type * self_pointer;

    self_pointer * data = nullptr;
    uint32_t size = 0; /* Will still be 64 bit due to alignment */

public:
    typedef self_type * const * map_iterator;

    /** @param hint initial child-table size; values <= 0 allocate nothing. */
    TrieNode(int hint) { if (hint > 0) { resize(hint); } }

    TrieNode(const TrieNode &) = delete;
    TrieNode & operator =(const TrieNode &) = delete;

    /** Releases the child table only; children are not owned. */
    ~TrieNode() { delete[] data; }

    /** Slot index of atom x for a table whose mask is (table size - 1). */
    inline static int atom_hash(AtomT x, uint32_t mask) {
        return (x & mask);
    }

    /** Twice the lowest bit in which a and b differ (0 when a == b). */
    inline static int least_uncolliding_size(AtomT a, AtomT b)
    {
        const unsigned int diff =
            static_cast<unsigned int>(static_cast<int>(a) ^ static_cast<int>(b));
        return static_cast<int>((diff & (0u - diff)) << 1);
    }

    static self_type * value(map_iterator x) { return *x; }

    /**
     * Replaces the child table with one of new_size slots.
     * Existing children are re-hashed only when the table grows; any other
     * size change drops them from the table (0 frees the table).
     */
    void resize(uint32_t new_size)
    {
        /* "()" value-initializes every slot to nullptr. */
        self_pointer * ndata = (new_size == 0) ?
            nullptr : (new self_pointer[new_size]());

        if (new_size > size) {
            for (uint32_t i = 0; i < size; ++i) {
                if (data[i] != nullptr) {
                    ndata[atom_hash(*data[i]->kbegin(), new_size - 1)] = data[i];
                }
            }
        }

        delete[] data;
        data = ndata;
        size = new_size;
    }

    /**
     * @return the slot holding the child whose key starts with x,
     *         or nullptr (== nf()) when there is no such child.
     */
    map_iterator find(AtomT x) const
    {
        if (size != 0)
        {
            map_iterator result = data + atom_hash(x, size-1);

            if (nullptr != *result and (*result)->starts_with(x)) {
                return result;
            }
        }

        return nullptr;
    }

    /** Registers edge as a child, growing the table if its slot is taken. */
    void put(self_type * edge)
    {
        AtomT x = *edge->kbegin();

        if (size == 0) { resize(2); }

        int hash = atom_hash(x, size-1);

        if (data[hash] == nullptr) {
            data[hash] = edge;
            return;
        }

        AtomT occupant = *data[hash]->kbegin();

        if (occupant == x) {
            /* Same leading atom: no table size can separate the two, and
             * least_uncolliding_size() would return 0. Replace in place. */
            data[hash] = edge;
            return;
        }

        resize(least_uncolliding_size(x, occupant));

        data[atom_hash(x, size-1)] = edge;
    }

    map_iterator begin() const { return data; }
    map_iterator end()   const { return data + size; }
    std::nullptr_t nf()  const { return nullptr; }

    /**
     * Splits this node's key at breakIdx. `next` receives the key tail, this
     * node's former children and its value; this node keeps the key head,
     * takes over next's (fresh) child table and gets `next` as a child.
     */
    void split(self_type * next, int breakIdx)
    {
        this->PrefixHolderT::psplit(next, breakIdx);
        std::swap(this->data, next->data);
        std::swap(this->size, next->size);
        this->swap_value(*next);
        put(next);
    }
};

/**
 * Traversal state shared by trie iterators.
 *
 * The position is m_root plus a stack (m_ptrs) of child-table slots, one
 * per level below m_root. base_prefix holds the key atoms that precede
 * m_root's own fragment. m_root == nullptr marks an exhausted traversal.
 */
template <typename AtomT, typename NodeT>
struct TrieIteratorInternal
{
    typedef std::vector< AtomT > key_type;
    typedef typename NodeT::map_iterator traverse_ptr;

    key_type base_prefix;
    const NodeT * m_root;
    std::vector<traverse_ptr> m_ptrs;

    TrieIteratorInternal(const NodeT * a_root) : m_root(a_root) { }

    /**
     * @param i 0 for the current node, -1 for its parent, and so on.
     * @return the requested node; nullptr when i reaches above m_root.
     */
    const NodeT * get(int i = 0) const
    {
        int j = (int) m_ptrs.size() + i;
        return j > 0 ? NodeT::value(m_ptrs[j - 1]) : (j == 0 ? m_root : nullptr);
    }

    /** Value stored at the current node (the node must hold one). */
    typename NodeT::value_type & get_value() const
    {
        NodeT * top = const_cast<NodeT *>(m_ptrs.empty() ? m_root : NodeT::value(m_ptrs.back()));
        return top->get_value();
    }

    /** Full key of the current position as a string. */
    std::basic_string<AtomT> get_key_str()
    {
        std::basic_string<AtomT> result(base_prefix.begin(), base_prefix.end());

        const NodeT * node = m_root;
        std::copy(node->kbegin(), node->kend(), std::back_inserter(result));

        for (auto && edge : m_ptrs)
        {
            node = NodeT::value(edge);
            std::copy(node->kbegin(), node->kend(), std::back_inserter(result));
        }

        return result;
    }

    /** Full key of the current position as a vector of atoms. */
    key_type get_key()
    {
        key_type result(base_prefix);

        const NodeT * node = m_root;
        result.insert(result.end(), node->kbegin(), node->kend());

        for (auto && edge : m_ptrs)
        {
            /* Each stack entry is a slot pointer; NodeT::value yields the node. */
            node = NodeT::value(edge);
            result.insert(result.end(), node->kbegin(), node->kend());
        }

        return result;
    }

    /** Descends to the first non-empty child slot; false if there is none. */
    bool step_down()
    {
        const NodeT * x = get();
        traverse_ptr it = x->begin();

        while (it != x->end()
                and NodeT::value(it) == nullptr) {
            ++it;
        }

        if (it != x->end())
        {
            m_ptrs.push_back(it);
            return true;
        }

        return false;
    }

    /**
     * Moves to the next non-empty sibling slot. Returns false when the
     * current node is m_root or no sibling is left (the top of the stack
     * is then the parent's end slot).
     */
    bool step_fore()
    {
        const NodeT * up = get(-1);

        if (up != nullptr)
        {
            do {
                ++m_ptrs.back();
            } while (m_ptrs.back() != up->end()
                and NodeT::value(m_ptrs.back()) == nullptr);

            return m_ptrs.back() != up->end();
        }

        return false;
    }

    /** Pops one level and tries the parent's next sibling. m_ptrs must not be empty. */
    bool step_up()
    {
        m_ptrs.pop_back();

        return step_fore();
    }

    /** Advances to the next node in depth-first order; clears m_root at the end. */
    void next()
    {
        if (step_down())   { return; }
        if (step_fore())   { return; }

        while (!m_ptrs.empty()) {
            if (step_up()) { return; }
        }

        m_root = nullptr;
    }

    /** Advances to the next node holding a value; false when exhausted. */
    bool next_value()
    {
        do {
            next();
        } while (m_root != nullptr and !get()->has_value());

        return m_root != nullptr;
    }

    /** Appends a slot to the traversal stack. */
    void push(traverse_ptr it)
    {
        m_ptrs.push_back(it);
    }
};

/* Offset type used for the begin/end positions of a key inside its chunk
 * (see PrefixHolder::setkey). */
typedef uint32_t trie_offset_t;

/**
 * Optional heap-allocated value of a trie node.
 *
 * The holder starts empty; set_value() stores a copy of its argument and
 * clr_value() returns the holder to the empty state. get_value() must only
 * be called while has_value() is true.
 */
template <typename ValueT>
struct ValueHolder
{
private:
    std::unique_ptr<ValueT> value;
public:
    typedef ValueT value_type; /* Effective type */

    value_type & get_value()             { return *value; }
    const value_type & get_value() const { return *value; }

    bool     has_value() const noexcept  { return value != nullptr; }
    void     set_value(const ValueT & x) { value.reset(new ValueT(x)); }

    /* Drops the value. Going through set_value(nullptr) would try to
     * build a ValueT from nullptr instead of clearing the holder. */
    void     clr_value()                 { value.reset(); }

    void swap_value(ValueHolder & other) { value.swap(other.value); }
};

/* Counting variant: the "value" is an int stored inline, and a count of
 * zero means "no value". */
template <>
struct ValueHolder<SetCounter>
{
public:
    typedef int value_type; /* Effective type */

    value_type & get_value() noexcept
    {
        return count;
    }

    const value_type & get_value() const noexcept
    {
        return count;
    }

    bool has_value() const noexcept
    {
        return 0 != count;
    }

    void set_value(const value_type & x)
    {
        count = x;
    }

    void clr_value()
    {
        count = 0;
    }

    void swap_value(ValueHolder & other)
    {
        std::swap(count, other.count);
    }

private:
    int count = 0;
};


/**
 * Key fragment of a node, stored as the offset range [begin, end) inside a
 * shared chunk (a std::vector<AtomT> owned elsewhere).
 *
 * Offsets are kept instead of pointers, so the fragment is located through
 * the chunk's current storage on every access.
 */
template <typename AtomT, typename ValueT, size_t CMinChunkSize>
struct PrefixHolder : public ValueHolder<ValueT>
{
private:
    typedef PrefixHolder<AtomT, ValueT, CMinChunkSize> self_type;
    typedef std::vector<AtomT> ChunkT;

    ChunkT * chunk = nullptr;
    trie_offset_t begin = 0, end = 0;
public:
    typedef const AtomT * key_iterator;

    /** True when the fragment is non-empty and its first atom equals x. */
    bool starts_with(AtomT x) const { return begin != end && (*chunk)[begin] == x; }

    /* data() + offset is valid for offset == size(); operator[] is not,
     * and a fragment that ends the chunk has end == size(). */
    key_iterator kbegin() const { return chunk->data() + begin; }
    key_iterator kend()   const { return chunk->data() + end; }

    /** Chunk holding this fragment; offered to children as a storage hint. */
    ChunkT * insertion_hint() { return chunk; }

    /** Points the fragment at [k, kend) of achunk. */
    void setkey(ChunkT * achunk, trie_offset_t k, trie_offset_t kend)
    {
        chunk = achunk;
        begin = k;
        end   = kend;
    }

    /** Keeps the first breakIdx atoms here and hands the rest to next. */
    void psplit(self_type * next, int breakIdx)
    {
        next->chunk = chunk;
        next->begin = this->begin + breakIdx;
        next->end   = this->end;
        this->end   = next->begin;
    }
};

/**
 * Variant selected when CMinChunkSize == 0: the key fragment is stored as a
 * raw pointer plus length into the chunk it was inserted into.
 *
 * The pointer is taken from the chunk's storage in setkey(), so it relies
 * on that chunk not reallocating afterwards.
 */
template <typename AtomT, typename ValueT>
struct PrefixHolder<AtomT, ValueT, 0> : public ValueHolder<ValueT>
{
private:
    typedef PrefixHolder<AtomT, ValueT, 0> self_type;
    typedef std::vector<AtomT>     ChunkT;

    AtomT  * prefix = nullptr;
    size_t prefix_len = 0;
public:
    typedef const AtomT * key_iterator;

    /** True when the fragment is non-empty and its first atom equals x. */
    bool starts_with(AtomT x) const { return prefix_len != 0 && *prefix == x; }
    key_iterator kbegin() const { return prefix; }
    key_iterator kend()   const { return prefix + prefix_len; }

    /** This variant never shares a chunk, so there is no hint to give. */
    std::nullptr_t insertion_hint() { return nullptr; }

    /** Points the fragment at [k, kend) of achunk. */
    void setkey(ChunkT * achunk, trie_offset_t k, trie_offset_t kend)
    {
        /* data() + k is valid for k == size() (empty key); dereferencing
         * begin()[k] there is not. */
        prefix     = achunk->data() + k;
        prefix_len = kend - k;
    }

    /** Keeps the first breakIdx atoms here and hands the rest to next. */
    void psplit(self_type * next, int breakIdx)
    {
        next->prefix     = this->prefix + breakIdx;
        next->prefix_len = this->prefix_len - breakIdx;
        this->prefix_len -= next->prefix_len;
    }
};

/*
 * Maps (AtomT, ValueT, CMinChunkSize) to a node implementation through a
 * nested `type` typedef. The primary template is intentionally empty:
 * a partial specialization has to provide `type` for each supported atom.
 */
template<typename AtomT, typename ValueT, size_t CMinChunkSize,
    typename Spec = void>
struct TrieNodeSelector
{
};

/* Node selection for integral atom types. The prefix holder is built on
 * AtomT itself, so its key iterators match the node's atom type. */
template<typename AtomT, typename ValueT, size_t CMinChunkSize>
struct TrieNodeSelector<AtomT, ValueT, CMinChunkSize,
    typename std::enable_if<std::is_integral<AtomT>::value>::type>
{
    typedef PrefixHolder<AtomT, ValueT, CMinChunkSize>  PrefixHolderType;
    typedef TrieNode<AtomT, PrefixHolderType>    type;
};

};

template <typename AtomT, typename ValueT, size_t CMinChunkSize = 0, 
    typename NodeImpl = typename detail::TrieNodeSelector<AtomT, ValueT, CMinChunkSize>::type >
struct trie_map
{
private:
    typedef NodeImpl NodeT;
    typedef detail::TrieIteratorInternal<AtomT, NodeT> IteratorInternalT;

    /* We use deque in order to make edge storage stable */
    typedef std::deque< NodeT >                    EdgeStorageT;
public:
    typedef typename NodeImpl::value_type          value_type;
    typedef typename IteratorInternalT::key_type   key_type;
    typedef typename NodeImpl::key_iterator        key_iterator;

    typedef value_type mapped_type; /* Defined for the compatibility with map */
private:
    typedef std::vector<AtomT> ChunkT;
    typedef std::deque<ChunkT> InternalStorageT;

    /* The number of elements */
    size_t msize = 0;
    InternalStorageT keys;

    /*
     * Copies the key part [it, end) into chunk storage and points node n at it.
     *
     * The parent's chunk (its insertion_hint()) is used when it exists and
     * the key still fits within CMinChunkSize. Otherwise the last chunk in
     * `keys` is used, and a fresh chunk is appended first when
     * CMinChunkSize is 0, `keys` is empty, or the last chunk would exceed
     * CMinChunkSize.
     */
    template<typename KeyIterator>
    void insert_infix(KeyIterator it, KeyIterator end, NodeT * parent, NodeT * n)
    {
        size_t ksize = std::distance(it, end);
        ChunkT * target = parent == nullptr ? nullptr : parent->insertion_hint();

        if ((target == nullptr) or (target->size() + ksize) > CMinChunkSize)
        {
            if (CMinChunkSize == 0 or keys.empty() or 
                    (keys.back().size() + ksize) > CMinChunkSize)
            {
                keys.emplace_back();
                keys.back().reserve(CMinChunkSize);
            }

            target = std::addressof(keys.back());
        }

        /* WARNING : Here, we rely on the fact, that vector pointers always
         * remain stable if values inserted fit into reserved space, which
         * should work in practice, but std::vector specification
         * does not guarantee that.
         */
        /* The key occupies [kidx, kendidx) of the target chunk. */
        detail::trie_offset_t kidx    = target->size();
        target->insert(target->end(), it, end);
        detail::trie_offset_t kendidx = target->size();

        n->setkey(target, kidx, kendidx);
    }

    EdgeStorageT edges;

    /* Address of the first stored edge, which serves as the trie root.
     * Must not be called while `edges` is empty. */
    NodeT * root()
    {
        return std::addressof(edges.front());
    }

    /* Appends a node built from `hint` to the edge storage and returns
     * its address. */
    NodeT * new_edge(int hint)
    {
        edges.emplace_back(hint);
        NodeT & created = edges.back();
        return std::addressof(created);
    }

    /* Creates a node for the key part [it, end), registers it as a child
     * of `parent` (when there is one) and stores `value` in it. */
    template<typename KeyIterator>
    NodeT * insert_edge(NodeT * parent, KeyIterator it, KeyIterator end, const value_type & value)
    {
        NodeT * const created = new_edge(0);

        insert_infix(it, end, parent, created);

        if (nullptr != parent) {
            parent->put(created);
        }

        created->set_value(value);
        return created;
    }

    /* Stores `value` at node `at`; when the node already has a value the
     * decision is delegated to replace(existing, value). */
    template<typename ReplacePolicy>
    void insert_value(NodeT & at, const value_type & value, const ReplacePolicy & replace)
    {
        if (!at.has_value()) {
            at.set_value(value);
            return;
        }

        replace(at.get_value(), value);
    }

    typedef std::shared_ptr<IteratorInternalT> IteratorPtr;

public:

    /** @brief The iterator with very unfair behaviour
     *
     *  There is no const iterator counterpart, because there is no real
     *  benefit in making this iterator const.
     *
     *  @warning: This operator is not copiable in regular sense.
     *    In order to get a fixed copy of the iterator,
     *    use explicit clone() method!
     */
    struct iterator : public std::forward_iterator_tag
    {
        friend class trie_map;

    private:
        IteratorPtr _impl;

        iterator() : _impl() { };

        void normalize() {
            if (!_impl->get()->has_value()) { ++(*this); }
        }

        explicit iterator(IteratorPtr a_impl)
            : _impl(a_impl) { normalize(); }

        explicit iterator(IteratorInternalT * a_impl)
            : _impl(a_impl) { normalize(); }

    public:
        value_type & value() {
            return _impl->get_value();
        }

        std::basic_string<AtomT> key() {
            return _impl->get_key_str();
        }

        value_type & operator *() { return value(); }

        /*
         * There is only one increment operator, since the postfix one
         * is very heavy if implemented correctly.
         */
        iterator & operator ++()
        {
            if (_impl.get() != nullptr)  {
                if (!_impl->next_value()) { _impl.reset(); }
            }

            return *this;
        }

        /**
         * Returns a "real" copy of the iterator, may be a heavy operation.
         */
        iterator clone() {
            return (_impl.get() == nullptr) ?
                iterator() : iterator(new IteratorInternalT(*_impl));
        }

        bool operator == (const iterator & other) const
        {
            return _impl.get() == other._impl.get()
                || (_impl.get() && other._impl.get()
                        && _impl->get() == other._impl->get());
        }

        bool operator != (const iterator & other) const
        {
            return not (*this == other);
        }
    };

private:
    typedef typename NodeT::map_iterator NodeItr;

    /**
     * Generalized lookup algorithm.
     *
     * Starting at node n, matches the key [it, end) against each node's key
     * fragment and reports the outcome through one of the callbacks:
     *
     *  - exactMatchAction(n): the key ended exactly at the end of n's fragment.
     *  - endInTheMiddleAction(n, k): the key ended while n's fragment
     *    continues at position k.
     *  - splitInTheMiddleAction(n, k, it): key and fragment differ at k / it.
     *  - noNextEdgeAction(n, it): n's fragment matched completely, key atoms
     *    remain, but n->find(*it) reports no child.
     *  - edgeAction(next_edge, it): called every time the search descends
     *    into a child, before `it` is advanced past the child's first atom.
     *
     * n must not be null: its key is read before the loop tests it.
     */
    template<typename KeyIterator, typename A, typename B, typename C, typename D, typename E>
    inline void general_search
    (
        NodeT * n,
        KeyIterator it,
        KeyIterator end,
        A exactMatchAction,
        B noNextEdgeAction,
        C endInTheMiddleAction,
        D splitInTheMiddleAction,
        E edgeAction
    )
    {
        key_iterator kbegin = n->kbegin();

        while (n != nullptr)
        {
            key_iterator kend   = n->kend();
            key_iterator k      = kbegin;

            /* Consume the common run of key and fragment. */
            while ((it != end) and (k != kend) and (*k == *it))
                { ++k; ++it; }

            if (it == end)
            {
                if (k == kend) {
                    exactMatchAction(n);
                } else {
                    endInTheMiddleAction(n, k);
                }

                return;
            }
            else if (k != kend)
            {
                splitInTheMiddleAction(n, k, it);
                return;
            }

            /* Fragment fully matched and key atoms remain: look for a child. */
            NodeItr next_edge = n->find(*it);

            if (next_edge == n->nf())
            {
                noNextEdgeAction(n, it);
                return;
            }

            edgeAction(next_edge, it);

            n = NodeT::value(next_edge);
            /* find() only returns a child whose first atom equals *it,
             * so matching resumes at the child's second atom. */
            kbegin = n->kbegin() + 1;
            ++it; /* Already found the first character */
        }
    }

public:
    template<typename KeyIterator, typename ReplacePolicy>
    void insert(KeyIterator it, KeyIterator end, const value_type & value,
                    const ReplacePolicy & replace)
    {
        if (edges.empty())
        {
            insert_edge(nullptr, it, end, value);
            ++msize;
            return;
        }

        general_search(root(), it, end,
            [this, &value, &replace] (NodeT * n) {
                insert_value(*n, value, replace);
            },

            [this, &value, end] (NodeT * n, KeyIterator kit) {
                insert_edge(n, kit, end, value);
                ++msize;
            },

            [this, &value] (NodeT * n, key_iterator eit) {
                n->split(new_edge(1), eit - n->kbegin());
                n->set_value(value);
                ++msize;
            },

            [this, &value, end] (NodeT * n, key_iterator eit, KeyIterator kit) {
                n->split(new_edge(2), eit - n->kbegin());
                insert_edge(n, kit, end, value);
                ++msize;
            },

            [] (NodeItr x, KeyIterator) { (void)x; }
        );
    }

    /* Number of elements, as counted by insert. */
    size_t size() const noexcept
    {
        return msize;
    }

    /* Inserts [it, end); an existing value is increased by `value` (+=). */
    template<typename KeyIterator>
    void add(KeyIterator it, KeyIterator end, const value_type & value) {
        auto accumulate = [] (value_type & old, const value_type & n) { old += n; };
        insert(it, end, value, accumulate);
    }

    /* Inserts [it, end); an existing value is overwritten with `value`. */
    template<typename KeyIterator>
    void insert(KeyIterator it, KeyIterator end, const value_type & value) {
        auto overwrite = [] (value_type & old, const value_type & n) { old = n; };
        insert(it, end, value, overwrite);
    }

    /* String convenience form of insert(it, end, value, replace). */
    template<typename ReplacePolicy>
    void insert(const std::basic_string<AtomT> & str, const value_type & value,
                    const ReplacePolicy & replace)
    {
        insert(str.begin(), str.end(), value, replace);
    }

    /* String convenience form of add(it, end, value). */
    void add(const std::basic_string<AtomT> & str, const value_type & value)
    {
        add(str.begin(), str.end(), value);
    }

    /* String convenience form of insert(it, end, value). */
    void insert(const std::basic_string<AtomT> & str, const value_type & value)
    {
        insert(str.begin(), str.end(), value);
    }

private:
    /* Constraint helper: SetSpecific<T> can only be named when T is
     * SetCounter; for any other T the enable_if below has no `type`. */
    template<typename CandidateT,
        typename = typename std::enable_if<
            std::is_same<CandidateT, SetCounter>::value>::type>
    struct SetSpecific
    {
    };

public:

    /* Set-only form (ValueT == SetCounter): stores the key with count 1. */
    template<typename KeyIterator, typename CheckedT = ValueT, typename = SetSpecific<CheckedT> >
    void insert(KeyIterator it, KeyIterator end)
    {
        insert(it, end, 1);
    }

    /* Set-only form (ValueT == SetCounter): adds 1 to the key's count. */
    template<typename KeyIterator, typename CheckedT = ValueT, typename = SetSpecific<CheckedT> >
    void add(KeyIterator it, KeyIterator end)
    {
        add(it, end, 1);
    }

    /* Set-only string form: stores the key with count 1. */
    template<typename CheckedT = ValueT, typename = SetSpecific<CheckedT> >
    void insert(const std::basic_string<AtomT> & str)
    {
        insert(str.begin(), str.end(), 1);
    }

    /* Set-only string form: adds 1 to the key's count. */
    template<typename CheckedT = ValueT, typename = SetSpecific<CheckedT> >
    void add(const std::basic_string<AtomT> & str)
    {
        add(str.begin(), str.end(), 1);
    }

    /* True when the exact key [it, end) is present and holds a value. */
    template<typename KeyIterator>
    bool contains(KeyIterator it, KeyIterator end)
    {
        bool found = false;

        if (!edges.empty())
        {
            /* Only an exact match can succeed; every other outcome is a miss. */
            general_search(root(), it, end,
                [&found] (NodeT * n) { found = n->has_value(); },
                [] (NodeT * , KeyIterator) { },
                [] (NodeT * , key_iterator ) { },
                [] (NodeT * , key_iterator , KeyIterator ) { },
                [] (NodeItr, KeyIterator) { }
            );
        }

        return found;
    }

    /* String convenience form of contains(it, end). */
    bool contains(const std::basic_string<AtomT> & str)
    {
        const bool present = contains(str.begin(), str.end());
        return present;
    }

private:
    /**
     * Core of find_prefix(): searches [it, kend) starting at root_node.
     *
     * @param exactMatch called with no arguments when the whole prefix is
     *        itself a key that holds a value.
     * @return an iterator rooted at the node where the prefix ends (moved
     *         forward to the first value), or end() when no key has
     *         this prefix.
     */
    template <typename KeyIterator, typename CallbackType>
    iterator find_prefix_int(NodeT * root_node, KeyIterator it, KeyIterator kend, CallbackType exactMatch)
    {
        iterator output;

        /* End of the key part consumed above the result node. It stays at
         * `it` when the prefix ends inside root_node, which yields an
         * empty base prefix instead of reading an unset iterator. */
        KeyIterator inputEnd = it;

        general_search(root_node, it, kend,
            /* Exact Match */
            [&exactMatch, &output] (NodeT * n)  {
                if (n->has_value()) { exactMatch(); }
                output._impl.reset(new IteratorInternalT(n));
            },

            /* No child for the next atom: nothing has this prefix. */
            [] (NodeT *, KeyIterator) { },

            /* Prefix ends inside n's fragment: everything below n matches. */
            [&output] (NodeT * n, key_iterator) {
                output._impl.reset(new IteratorInternalT(n));
            },

            /* Mismatch inside a fragment: nothing has this prefix. */
            [] (NodeT *, key_iterator, KeyIterator) {  },

            /* Remember where the most recently entered child starts in the key. */
            [&inputEnd] (NodeItr, KeyIterator iend) {
                inputEnd = iend; }
        );

        if (output._impl.get() != nullptr)
        {
            output.normalize();

            /* normalize() may run off the end and drop the state. */
            if (output._impl.get() != nullptr) {
                std::copy(it, inputEnd, std::back_inserter(output._impl->base_prefix));
            }
        }

        return output;
    }

    /* Flag form: exactMatch is cleared first and set to true only when
     * the callback form reports an exact match. */
    template <typename KeyIterator>
    iterator find_prefix_int(NodeT * root_node, KeyIterator it, KeyIterator kend, bool & exactMatch)
    {
        exactMatch = false;

        auto raise_flag = [&exactMatch] () { exactMatch = true; };
        return find_prefix_int(root_node, it, kend, raise_flag);
    }

    /* Form for callers that do not care about exact matches. */
    template <typename KeyIterator>
    iterator find_prefix_int(NodeT * root_node, KeyIterator it, KeyIterator kend, std::nullptr_t)
    {
        auto ignore = [] () {};
        return find_prefix_int(root_node, it, kend, ignore);
    }

public:
    /* Iterator over all keys starting with [it, kend); end() for an
     * empty container. exactMatch is forwarded to find_prefix_int. */
    template <typename KeyIterator, typename CallbackType>
    iterator find_prefix(KeyIterator it, KeyIterator kend, CallbackType exactMatch)
    {
        if (!edges.empty()) {
            return find_prefix_int(root(), it, kend, exactMatch);
        }

        return end();
    }

    /* NOTE : this overload is needed so that a bool argument is taken by
     * reference (as an out-flag) and not by value as a callback. */
    template <typename KeyIterator>
    iterator find_prefix(KeyIterator it, KeyIterator kend, bool & exactMatch)
    {
        if (!edges.empty()) {
            return find_prefix_int(root(), it, kend, exactMatch);
        }

        return end();
    }

    /* Prefix search that starts at the node `base` currently stands on;
     * end() when `base` has no state or no current node. */
    template <typename KeyIterator, typename CallbackType>
    iterator find_prefix(const iterator & base, KeyIterator it, KeyIterator kend, CallbackType exactMatch)
    {
        const bool has_state = base._impl.get() != nullptr;

        if (has_state && base._impl->get() != nullptr) {
            return find_prefix_int(base._impl->get(), it, kend, exactMatch);
        }

        return end();
    }

    /* String convenience form of find_prefix(it, kend, exactMatch). */
    template <typename CallbackType>
    iterator find_prefix(const std::basic_string<AtomT> & str, CallbackType exactMatch)
    {
        return find_prefix(str.begin(), str.end(), exactMatch);
    }

    /* NOTE : this overload is needed so that a bool argument is taken by
     * reference (as an out-flag) and not by value as a callback. */
    iterator find_prefix(const std::basic_string<AtomT> & str, bool & exactMatch)
    {
        return find_prefix(str.begin(), str.end(), exactMatch);
    }

    /* Prefix search without exact-match notification. */
    iterator find_prefix(const std::basic_string<AtomT> & str)
    {
        auto ignore = [] () {};
        return find_prefix(str.begin(), str.end(), ignore);
    }

    /* Iterator positioned on the exact key [it, kend), or end() when the
     * key is absent or holds no value. */
    template <typename KeyIterator>
    iterator find(KeyIterator it, KeyIterator kend)
    {
        if (edges.empty()) {
            return end();
        }

        /* The traversal stack is filled while descending; any outcome
         * other than an exact match on a valued node discards it. */
        IteratorPtr path(new IteratorInternalT(root()));

        general_search(root(), it, kend,
            [&path] (NodeT * n) {
                if (!n->has_value()) { path.reset(); }
            },
            [&path] (NodeT *, KeyIterator) { path.reset(); },
            [&path] (NodeT *, key_iterator) { path.reset(); },
            [&path] (NodeT *, key_iterator, KeyIterator) { path.reset(); },
            [&path] (NodeItr slot, KeyIterator) { path->push(slot); }
        );

        if (path == nullptr) {
            return iterator();
        }

        return iterator(path);
    }

    /* String convenience form of find(it, kend). */
    iterator find(const std::basic_string<AtomT> & str) {
        return find(str.begin(), str.end());
    }

    /* Iterator on the first valued node, or end() for an empty container. */
    iterator begin()
    {
        if (edges.empty()) {
            return end();
        }

        return iterator(IteratorPtr(new IteratorInternalT(root())));
    }

    /* Past-the-end iterator: an iterator without traversal state. */
    iterator end()
    {
        return iterator();
    }

    /* Pointer to the value stored under the exact key [it, end), or
     * nullptr when the key is absent or holds no value. */
    template <typename KeyIterator>
    value_type * get(KeyIterator it, KeyIterator end)
    {
        value_type * found = nullptr;

        if (edges.empty()) {
            return found;
        }

        general_search(root(), it, end,
            [&found] (NodeT * n) {
                if (!n->has_value()) { return; }
                found = std::addressof(n->get_value());
            },
            [] (NodeT * , KeyIterator) { },
            [] (NodeT * , key_iterator ) { },
            [] (NodeT * , key_iterator , KeyIterator ) { },
            [] (NodeItr, KeyIterator) { }
        );

        return found;
    }

    /* String convenience form of get(it, end). */
    value_type * get(const std::basic_string<AtomT> & str) {
        return get(str.begin(), str.end());
    }

    /* Reference to the value under the exact key [it, end).
     * Throws std::out_of_range("trie::at") when get() finds nothing. */
    template <typename KeyIterator>
    value_type & at(KeyIterator it, KeyIterator end)
    {
        value_type * const found = get(it, end);

        if (nullptr != found) {
            return *found;
        }

        throw std::out_of_range("trie::at");
    }

    /* String convenience form of at(it, end). */
    value_type & at(const std::basic_string<AtomT> & str) {
        return at(str.begin(), str.end());
    }

    /* Same as at(str): a missing key throws, it is not inserted. */
    value_type & operator [](const std::basic_string<AtomT> & str) {
        return at(str.begin(), str.end());
    }

    /* Number of nodes currently held in the edge storage. */
    size_t _edges()
    {
        return edges.size();
    }
    /* Number of key chunks currently held in the key storage. */
    size_t _keys()
    {
        return keys.size();
    }

    struct _debug_print
    {
        const trie_map & map;

        _debug_print(const trie_map & amap) : map(amap) {};

        std::ostream & operator ()(std::ostream & stream) const
        {
            if (map.edges.empty())
            {
                return stream << "[ empty ]";
            }

            trie_map::IteratorInternalT it(std::addressof(map.edges[0]));

            while (it.m_root != nullptr) 
            {
                const trie_map::NodeT * n = it.get(0);

                std::copy(n->kbegin(), n->kend(), 
                    std::ostream_iterator<char, char>(stream));

                if (n->has_value())
                {
                    stream << "(=" << n->get_value() << ")";
                }

                if (it.step_down()) { stream << "{"; continue; }
                if (it.step_fore()) { stream << "}{"; continue; }

                while (!it.m_ptrs.empty()) 
                {
                    stream << "}";
                    if (it.step_up()) { break; }
                }

                if (it.m_ptrs.empty()) { it.m_root = nullptr; }
            }

            return stream;
        }

        friend std::ostream & operator << (std::ostream & stream,
            const _debug_print & x)
        {
            return x(stream);
        }
    };
};

/**
 * Forward iterator over a zero-terminated string.
 *
 * The iterator typedefs make it usable with std::iterator_traits (and so
 * with std::distance); operator!= makes it usable in `it != end` loops.
 *
 * @warning: operator== ALWAYS returns \true if
 *      the left operand dereferences to \0, so any iterator works as
 *      the "end" of a range whose left side reaches the terminator.
 */
template <typename AtomT>
struct CStrIterator : std::forward_iterator_tag
{
private:
    AtomT * m_str;
    typedef CStrIterator<AtomT> self_type;
public:
    typedef std::forward_iterator_tag             iterator_category;
    typedef typename std::remove_cv<AtomT>::type  value_type;
    typedef std::ptrdiff_t                        difference_type;
    typedef AtomT *                               pointer;
    typedef AtomT &                               reference;

    CStrIterator(AtomT * a_str) : m_str(a_str) {}
    explicit CStrIterator(AtomT * a_str, size_t offset) : m_str(a_str + offset) {}

    /* Pre-increment: advance, then return this iterator itself. */
    self_type & operator ++() { ++m_str; return *this; }

    /* Post-increment: advance, return the previous position. */
    self_type operator ++(int)
    {
        self_type before(*this);
        ++m_str;
        return before;
    }

    AtomT & operator *() const { return *m_str;  }

    bool operator ==(const self_type & other) const
    {
        return m_str == other.m_str || *m_str == '\0';
    }

    bool operator !=(const self_type & other) const
    {
        return !(*this == other);
    }
};

};

#endif /* TRIE_H */

 

triefunc.cpp

#define BOOST_TEST_MODULE trie functional test set
#define BOOST_TEST_DYN_LINK

#include <boost/algorithm/string/predicate.hpp>
#include <boost/test/unit_test.hpp>

#include <string>
#include <set>
#include <src/trie.h>

namespace utf  = boost::unit_test;

#define ITEMS_TO_TEST (128*1024)
#define MAX_LENGTH 1024

typedef trie::trie_map<char, trie::SetCounter> TestSet;
typedef trie::trie_map<char, std::string> TestMapI;

/*
static const char * test_components_1[] =
{
    "1121",   "1231",     "1313",   "41412",
    "31314",  "1223092",  "01121",  "01231",
    "01313",  "041412",   "031314", "012292",
    "11217",  "12319",    "13139",  "414127",
    "313147", "12230927", "11219",  "12317",
    "13137",  "414129",   "313149", "12230929",
};
*/

typedef std::minstd_rand DefaultGenerator; /* Determinism and uniformness are not really important */

/* Builds a string whose length is g() % MAX_LENGTH and whose bytes are
 * the low 8 bits of successive g() results. */
template<typename Generator>
std::string generate(Generator & g)
{
    std::string result(g() % MAX_LENGTH, '\0');

    for (char & ch : result) {
        ch = (char) (g() & 0xff);
    }

    return result;
}

/* Fills a trie map and a std::set model with the same generated strings
 * (each string is mapped to itself), then checks every model key through
 * find / contains / get / at / find_prefix and every trie value against
 * the model. */
BOOST_AUTO_TEST_CASE(fill_map)
{
    DefaultGenerator g(1);
    TestMapI t;
    std::set<std::string> t_model;

    for (int i = ITEMS_TO_TEST; i > 0; --i)
    {
        std::string x = generate(g);
        t_model.insert(x);
        t.insert(x, x);
    }

    for (const std::string & x : t_model)
    {
        auto it = t.find(x);
        BOOST_CHECK(it != t.end());
        BOOST_CHECK(it.value() == x);
        BOOST_CHECK(it.value() == it.key());
        BOOST_CHECK(t.contains(x) == true);
        BOOST_CHECK(t.get(x) != nullptr);
        BOOST_CHECK(t.at(x) == x);

        /* A key used as a prefix must report an exact match and come first. */
        bool captureMatch = false;
        BOOST_CHECK(t.find_prefix(x, captureMatch).value() == it.key());
        BOOST_CHECK(captureMatch == true);
    }

    /* Iterating the trie yields values, which equal the keys here. */
    for (const std::string & x : t)
    {
        BOOST_CHECK(t_model.find(x) != t_model.end());
    }
}

/* Same as fill_map, but for the SetCounter-based set: every generated
 * key must be reachable through find / contains / get. */
BOOST_AUTO_TEST_CASE(fill_set)
{
    DefaultGenerator g(1);
    TestSet t;
    std::set<std::string> t_model;

    for (int i = ITEMS_TO_TEST; i > 0; --i)
    {
        std::string x = generate(g);
        t_model.insert(x);
        t.insert(x);
    }

    for (const std::string & x : t_model)
    {
        BOOST_CHECK(t.find(x) != t.end());
        BOOST_CHECK(t.contains(x) == true);
        BOOST_CHECK(t.get(x) != nullptr);
    }
}

/* Inserts five path-like keys and checks that iterating from
 * find_prefix("/home/user1") yields exactly the three user1 entries
 * with their values. */
BOOST_AUTO_TEST_CASE(prefix_lookup)
{
    TestMapI tmap;
    std::map<std::string, std::string> t_model;

    tmap.insert("/home/user1/audio", "a1");
    tmap.insert("/home/user1/video/x", "v1x");
    tmap.insert("/home/user1/video", "v1");
    tmap.insert("/home/user2/audio", "a2");
    tmap.insert("/home/user2/video", "v2");

    /* Collect key -> value for everything the prefix iterator visits. */
    for (auto it = tmap.find_prefix("/home/user1"); it != tmap.end(); ++it) {
        t_model[it.key()] = it.value();
    }

    BOOST_CHECK(t_model.size() == 3);
    BOOST_CHECK(t_model["/home/user1/audio"] == std::string("a1"));
    BOOST_CHECK(t_model["/home/user1/video/x"] == std::string("v1x"));
    BOOST_CHECK(t_model["/home/user1/video"] == std::string("v1"));
}

/* Fills container t with the eight fixed keys used by simple_test_1,
 * each inserted with the value 1, in this order. */
template<typename M>
void simple(M & t)
{
    static const char * const fixed_keys[] = {
        "abcabcabc",
        "abcabc",
        "abcvabc",
        "abcxabc",
        "abcyasbc",
        "xabcvabc",
        "xabcxabc",
        "xabcyasbc",
    };

    for (const char * key : fixed_keys) {
        t.insert(key, 1);
    }
}

/* Prefix queries against the eight keys inserted by simple(): checks the
 * number of keys below "abc" and "abcabc", and that the exact-match
 * callback fires only when the queried prefix is itself a stored key. */
BOOST_AUTO_TEST_CASE(simple_test_1)
{
    TestSet t;
    simple(t);

    int count = 0;
    bool found = false;

    /* "abc" is a prefix of five keys but not a key itself. */
    auto it = t.find_prefix("abc", found);
    BOOST_CHECK(found == false);

    count = 0;

    for (; it != t.end(); ++it)
    {
        BOOST_CHECK(boost::starts_with(it.key(), "abc"));
        ++count;
    }

    BOOST_CHECK(count == 5);

    /* "abcabc" is a key and also a prefix of "abcabcabc". */
    it = t.find_prefix("abcabc", found);

    count = 0;

    for (; it != t.end(); ++it)
    {
        BOOST_CHECK(boost::starts_with(it.key(), "abcabc"));
        ++count;
    }

    BOOST_CHECK(count == 2);

    /* The callback counts exact matches: none for "xabc" ... */
    count = 0;
    it = t.find_prefix("xabc", [&count] () { ++count; });
    BOOST_CHECK(count == 0);

    /* ... and one for the stored key "xabcxabc". */
    count = 0;
    it = t.find_prefix("xabcxabc", [&count] () { ++count; });
    BOOST_CHECK(count == 1);
}

/* Lookups on an empty map must report "not found" for both a non-empty
 * and an empty key. */
BOOST_AUTO_TEST_CASE(empty_map)
{
    TestMapI t;

    BOOST_CHECK(t.get("something") == nullptr);
    BOOST_CHECK(t.get("") == nullptr);
    BOOST_CHECK(t.contains("") == false);
    BOOST_CHECK(t.find("") == t.end());
}

/* find and find_prefix on an empty map must return end(). */
BOOST_AUTO_TEST_CASE(empty_map_iterators)
{
    TestMapI t;

    BOOST_CHECK(t.find("") == t.end());
    BOOST_CHECK(t.find_prefix("") == t.end());
    BOOST_CHECK(t.find("something") == t.end());
    BOOST_CHECK(t.find_prefix("something") == t.end());
}

/* Lookups on an empty set must report "not found" for both a non-empty
 * and an empty key. */
BOOST_AUTO_TEST_CASE(empty_set)
{
    TestSet t;

    BOOST_CHECK(t.get("something") == nullptr);
    BOOST_CHECK(t.get("") == nullptr);
    BOOST_CHECK(t.contains("something") == false);
    BOOST_CHECK(t.contains("") == false);
}

/* find and find_prefix on an empty set must return end(). */
BOOST_AUTO_TEST_CASE(empty_set_iterators)
{
    TestSet t;

    BOOST_CHECK(t.find("") == t.end());
    BOOST_CHECK(t.find_prefix("") == t.end());
    BOOST_CHECK(t.find("something") == t.end());
    BOOST_CHECK(t.find_prefix("something") == t.end());
}

 

trietest.cpp

#include "src/trie.h"

#include <map>
#include <chrono>
#include <algorithm>
#include <fstream>
#include <sstream>

/* Trie instantiations keyed by char with int payloads; the second one
 * passes 0 as the third template argument of trie_map. */
using TestCountingSet       = trie::trie_map<char, int>;
using TestCountingSimpleSet = trie::trie_map<char, int, 0>;

/* Word list fed to the key generator, and the std::map baseline. */
using WordSet   = std::vector<std::string>;
using StringMap = std::map<std::string, int>;

/*
 * Write adapter: gives every container under test the same
 * "insert(key, value)" call shape. The generic version forwards to the
 * container's own two-argument insert().
 */
template <typename Container>
struct StringInserter
{
    using value_type = typename Container::mapped_type;

    /* Container being filled; held by reference, never owned. */
    Container & m;

    explicit StringInserter(Container & target)
        : m(target)
    {
    }

    /* Stores x under key in the wrapped container. */
    void operator()(const std::string & key, const value_type & x)
    {
        m.insert(key, x);
    }
};

/*
 * Read adapter: gives every container under test the same
 * "lookup(key) -> pointer" call shape. The generic version forwards to
 * the container's own get().
 */
template <typename Container>
struct StringLookup
{
    using value_type = typename Container::mapped_type;

    /* Container being queried; held by reference, never owned. */
    Container & m;

    explicit StringLookup(Container & source)
        : m(source)
    {
    }

    /* Returns whatever the container's get() yields for key. */
    value_type * operator()(const std::string & key)
    {
        return m.get(key);
    }
};

template <>
struct StringInserter<StringMap>
{
    typedef typename StringMap::mapped_type value_type;
    StringMap & m;
    explicit StringInserter(StringMap & am) : m(am) {}

    void operator()(const std::string & key, const value_type & x)
    {
        /* Explicity copy the string to be fair */
        m.insert(StringMap::value_type(std::string(key.data(), key.length()), x));
    }
};

template <>
struct StringLookup<StringMap>
{
    typedef typename StringMap::mapped_type value_type;
    StringMap & m;
    explicit StringLookup(StringMap & am) : m(am) {}

    value_type * operator()(const std::string & key)
    {
        auto it = m.find(key);
        return it == m.end() ? nullptr : std::addressof(it->second);
    }
};

/*
 * Small reproducible random source built on rand_r(): the whole state is
 * the public `seed`, so resetting it replays the same sequence.
 */
struct StatefulRandom
{
    unsigned int seed = 2345;

    /* Next raw value from rand_r(); advances `seed`. */
    int operator()()
    {
        return rand_r(&seed);
    }

    /* Next raw value reduced modulo n; n must be non-zero. */
    int operator()(int n)
    {
        const int raw = rand_r(&seed);
        return raw % n;
    }
};

/*
 * Produces benchmark keys: seqsz + 1 randomly chosen words from
 * `wordset`, joined with '.'.
 *
 * `wordset` must not be empty when the generator is called: the word
 * index is computed modulo wordset.size().
 */
struct Generator
{
    int seqsz = 0;
    StatefulRandom rnd;
    WordSet        wordset;

    /* Builds the next key, drawing seqsz + 1 values from rnd. */
    std::string operator()()
    {
        std::string key;

        /* seqsz leading words, each followed by a dot... */
        for (int joined = 0; joined < seqsz; ++joined)
        {
            key += wordset[rnd() % wordset.size()];
            key += ".";
        }

        /* ...and one trailing word without a dot. */
        key += wordset[rnd() % wordset.size()];

        return key;
    }
};

/*
 * Minimal stopwatch for the benchmarks below.
 *
 * start() records a reference point; mark() stores the nanoseconds
 * elapsed since that point in `dt` and makes "now" the new reference
 * point; psec() prints the last interval averaged over a number of items.
 */
struct perf_clock
{
    /* A monotonic clock is required here: with a wall clock that can be
     * adjusted backwards, t1 - t0 could be negative and would wrap around
     * when stored in the unsigned `dt`. */
    typedef std::chrono::steady_clock clock;

    clock::time_point t0;

    /* Last measured interval in nanoseconds; 0 until mark() is called. */
    uint64_t dt = 0;

    /* Sets the reference point to the current time. */
    void start() {
        t0 = clock::now();
    }

    /* Stores the time since the reference point in dt and restarts. */
    void mark() {
        const clock::time_point t1 = clock::now();
        dt = static_cast<uint64_t>(
            std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0).count());
        t0 = t1;
    }

    /*
     * Prints "<trial>.avg\t<dt / itemCount>" on std::cout.
     * itemCount must be positive; it is used as a divisor.
     */
    void psec(const std::string & trial, int itemCount)
    {
        std::cout << trial << ".avg\t" << dt / itemCount << std::endl;
    }
};

/*
 * Prints a vector of ints as a bracketed, comma-separated list,
 * e.g. "[1, 2, 3]"; an empty vector prints as "[]".
 */
std::ostream & operator << (std::ostream & stream, const std::vector<int> & x)
{
    stream << "[";

    for (std::vector<int>::size_type i = 0; i < x.size(); ++i)
    {
        /* The separator goes between elements, never after the last. */
        if (i != 0)
        {
            stream << ", ";
        }

        stream << x[i];
    }

    stream << "]";

    return stream;
}

template <typename Container>
struct ContainerTest
{
    Container cont;
    volatile int found = 0;
    std::string prefix;
    std::vector<int> numberOfItems;
    std::vector<int> insertTime;
    std::vector<int> lookupTime;

    ContainerTest(const std::string & aprefix) : prefix(aprefix) { };

    void simple()
    {
        StringInserter<Container> inserter(cont);
        StringLookup<Container>   lookup(cont);

        inserter("abcabcabc", 1);
        inserter("abcabc",    1);
        inserter("abcvabc",   1);
        inserter("abcxabc",   1);
        inserter("abcyasbc",  1);
        inserter("xabcvabc",  1);
        inserter("xabcxabc",  1);
        inserter("xabcyasbc", 1);

        lookup("abcabc");
    }

    void words(Generator & generator)
    {
        std::vector<std::string> wset;
        StringInserter<Container> inserter(cont);
        StringLookup<Container>   lookup(cont);
        int total = 20;

        for (int i = 0; i < total; ++i) {
            wset.push_back(generator());
        }

        for (int i = 0; i < total; ++i) {
            inserter(wset[i], 1);
        }

        int lost = total;

        for (int i = 0; i < total; ++i) {
            if (0 != lookup(wset[i])) {
                --lost;
            }
        }

        std::cout << "Lost : " << lost << std::endl;
    }

    void word_set(Generator & generator)
    {
        std::vector<std::string> wset;
        StringInserter<Container> inserter(cont);
        StringLookup<Container>   lookup(cont);
        int total = 100;

        for (int i = 0; i < total; ++i) {
            wset.push_back(generator());
        }

        for (int i = 0; i < total; ++i) {
            cont.insert(wset[i]);
        }

        int lost = total;

        for (int i = 0; i < total; ++i) {
            if (0 != lookup(wset[i])) {
                --lost;
            }
        }

        std::cout << "Lost : " << lost << std::endl;
    }

    void test(Generator & generator)
    {
        std::vector<std::string> words;

        StringInserter<Container> inserter(cont);
        StringLookup<Container>   lookup(cont);
        perf_clock pc;

        const int itemCount = 10000;

        for (int i = 0; i < itemCount * 20; ++i) {
            words.push_back(generator());
        }

        uint64_t len = 0;
        std::for_each(words.begin(), words.end(), [&len] (const std::string & x) { len += x.length(); });
        len /= words.size();
        std::cout << "Average length : ~" << len << std::endl;

        for (int _total = 20; _total > 0; --_total)
        {
            numberOfItems.push_back(cont.size());

            pc.start();

            for (int i = 0; i < itemCount; ++i) {
                inserter(words[(i + _total * itemCount) % words.size()], i);
            }

            pc.mark();

            insertTime.push_back(pc.dt / itemCount);

            pc.start();
            found = 0;
            for (int i = 0; i < itemCount * 10; ++i) 
            {
                int * x = lookup(words[(i + _total * itemCount) % words.size()]);
                if (x != nullptr) { ++found; }
            }

            pc.mark();

            lookupTime.push_back(pc.dt / itemCount / 10);
        }

        std::cout << "Positive found : " << found << std::endl;

        words.clear();
        for (int i = 0; i < itemCount * 3; ++i) {
            words.push_back(generator());
        }

        pc.start();
        found = 0;
        for (int i = 0; i < itemCount * 10; ++i) 
        {
            int * x = lookup(words[i % words.size()]);
            if (x != nullptr) { ++found; }
        }

        pc.mark();
        pc.psec(prefix + ".random-lookup", itemCount * 10);

        std::cout << "Random found : " << found << std::endl;
    }
};

/*
 * Demo and benchmark driver.
 *
 * Loads the system word list, runs a handful of small trie_map usage
 * examples, then benchmarks std::map against several trie_map
 * configurations for keys made of 1 to 5 dot-joined words.
 *
 * Returns 1 if no words could be loaded, 0 otherwise.
 */
int main()
{
    const char * const dictionaryPath = "/usr/share/dict/words";

    Generator words;

    {
        std::ifstream fin(dictionaryPath);
        std::string   line;

        while (std::getline(fin, line)) {
            words.wordset.push_back(line);
        }
    }

    /* Every Generator call indexes the word list modulo its size, so an
     * empty list (missing or empty dictionary file) must stop here rather
     * than reach a modulo by zero. */
    if (words.wordset.empty()) {
        std::cerr << "trietest: no words could be read from "
                  << dictionaryPath << std::endl;
        return 1;
    }

    std::cout << words.wordset.size() << std::endl;

    /* NOTE: std::random_shuffle was deprecated in C++14 and removed in
     * C++17. It is kept because StatefulRandom has no result_type/min()/
     * max() and so cannot be passed to std::shuffle. */
    std::random_shuffle(words.wordset.begin(), words.wordset.end(), words.rnd);

    /* Example: insert and operator[]. */
    {
        typedef trie::trie_map<char, int> TestMap;
        TestMap tmap;

        tmap.insert("105", 1);
        tmap.insert("104", 2);
        tmap.insert("2093", 3);
        tmap.insert("2097", 4);

        std::cout << tmap["105"] << " ";
        std::cout << tmap["104"] << " ";
        std::cout << tmap["2093"] << " ";
        std::cout << tmap["2097"] << " ";
        std::cout << std::endl;
    }

    /* Example: iterating values, then keys. */
    {
        typedef trie::trie_map<char, int> TestMap;
        TestMap tmap;

        tmap.insert("10.0.0.1",    1);
        tmap.insert("10.0.17.8",   2);
        tmap.insert("192.168.0.1", 3);
        tmap.insert("192.168.0.2", 4);

        for (auto it = tmap.begin(); it != tmap.end(); ++it) {
            std::cout << *it << " ";
        }

        std::cout << std::endl;

        for (auto it = tmap.begin(); it != tmap.end(); ++it) {
            std::cout << it.key() << " ";
        }

        std::cout << std::endl;
    }

    /* Example: set usage with contains(). */
    {
        typedef trie::trie_map<char, trie::SetCounter> TestSet;
        TestSet tset;

        tset.insert("10.0.0.1");
        tset.insert("10.0.17.8");
        tset.insert("192.168.0.1");
        tset.insert("192.168.0.2");

        std::cout << tset.contains("10.0.0.1") << " ";
        std::cout << tset.contains("10.0.17.8") << " ";
        std::cout << tset.contains("10.0.17.2") << " ";
        std::cout << tset.contains("10.0.1.1") << " ";
        std::cout << std::endl;
    }

    /* Example: iterating from a prefix. */
    {
        typedef trie::trie_map<char, int> TestMap;
        TestMap tmap;

        tmap.insert("/home/user1/audio", 10);
        tmap.insert("/home/user1/video", 11);
        tmap.insert("/home/user2/audio", 20);
        tmap.insert("/home/user2/video", 21);

        for (auto it = tmap.find_prefix("/home/user1"); it != tmap.end(); ++it) {
            std::cout << it.key() << " ";
            std::cout << *it << ";\n";
        }

        std::cout << std::endl;
    }

    /* Debug dump of the fixed key set. */
    {
        typedef trie::trie_map<char, int, 16> TestMap;
        ContainerTest<TestMap> simple("trie");
        simple.simple();
        std::cout << TestMap::_debug_print(simple.cont) << std::endl;
    }

    /* Debug dump after inserting generated words. */
    {
        typedef trie::trie_map<char, int, 1024> TestMap;
        ContainerTest<TestMap> simple("trie");
        simple.words(words);
        std::cout << TestMap::_debug_print(simple.cont) << std::endl;
    }

    /* Set walk-through: full iteration, prefix search with the
     * exact-match flag and with a callback, contains() and find(). */
    {
        typedef trie::trie_map<char, trie::SetCounter> TestSet;
        ContainerTest<TestSet> simple("trie_set");
        simple.simple();

        for (auto it = simple.cont.begin(); it != simple.cont.end(); ++it)
        {
            std::cout << it.key() << std::endl;
        }

        bool found;
        auto it = simple.cont.find_prefix("abc", found);

        std::cout << " *** prefix exact match : " << found << std::endl;

        for (; it != simple.cont.end(); ++it)
        {
            std::cout << it.key() << std::endl;
        }

        it = simple.cont.find_prefix("abcabc", found);

        std::cout << " *** prefix exact match : " << found << std::endl;

        for (; it != simple.cont.end(); ++it)
        {
            std::cout << it.key() << std::endl;
        }

        it = simple.cont.find_prefix("xabc",
            [] () { std::cout << "Error!" << std::endl; });

        it = simple.cont.find_prefix("xabcxabc",
            [] () { std::cout << "OK exact prefix found!" << std::endl; });

        std::cout << " *** countains 'abcvabc' : " << simple.cont.contains("abcvabc") << std::endl;

        it = simple.cont.find("xabcxabc");

        for (; it != simple.cont.end(); ++it)
        {
            std::cout << it.key() << std::endl;
        }
    }

    /* Set filled with generated words, dumped, then iterated from find(). */
    {
        typedef trie::trie_map<char, trie::SetCounter> TestSet;
        ContainerTest<TestSet> simple("trie_set");
        simple.word_set(words);
        std::cout << TestSet::_debug_print(simple.cont) << std::endl;

        auto it = simple.cont.find("yaray");

        for (; it != simple.cont.end(); ++it)
        {
            std::cout << it.key() << std::endl;
        }
    }

    /*
     * Benchmarks. Each pass uses keys one word longer and a word list ten
     * times smaller than the previous pass. Note that resize() pads the
     * list with empty strings when fewer than 200000 words were loaded.
     * The generator seed is reset before every container so that all of
     * them see the same key sequence.
     */
    words.wordset.resize(200000);
    words.seqsz = 0;
    while (words.seqsz < 5)
    {
        std::cout << "***\n" 
            << "seq-len=" << (words.seqsz + 1)
            << " words=" << (words.wordset.size())
            << std::endl;

        std::cout << "*** Map : " << std::endl;

        {
            ContainerTest<StringMap> test1("map");
            words.rnd.seed = 9;
            test1.test(words);

            std::cout << "mapX = "      << test1.numberOfItems << std::endl;
            std::cout << "mapInsert = " << test1.insertTime << std::endl;
            std::cout << "mapLookup = " << test1.lookupTime << std::endl;
        }

        std::cout << "*** Trie 0 : " << std::endl;

        {
            ContainerTest<trie::trie_map<char, int, 0> > test1("trie");
            words.rnd.seed = 9;
            test1.test(words);

            std::cout << "trie0X = "      << test1.numberOfItems << std::endl;
            std::cout << "trie0Insert = " << test1.insertTime << std::endl;
            std::cout << "trie0Lookup = " << test1.lookupTime << std::endl;
        }

        std::cout << "*** Trie 1K : " << std::endl;

        {
            ContainerTest<trie::trie_map<char, int, 1024> > test1("trie");
            words.rnd.seed = 9;
            test1.test(words);

            std::cout << "trie1X = "      << test1.numberOfItems << std::endl;
            std::cout << "trie1Insert = " << test1.insertTime << std::endl;
            std::cout << "trie1Lookup = " << test1.lookupTime << std::endl;
        }

        std::cout << "*** Trie 4K : " << std::endl;

        {
            ContainerTest<trie::trie_map<char, int, 4*1024> > test1("trie");
            words.rnd.seed = 9;
            test1.test(words);

            std::cout << "trie4X = "      << test1.numberOfItems << std::endl;
            std::cout << "trie4Insert = " << test1.insertTime << std::endl;
            std::cout << "trie4Lookup = " << test1.lookupTime << std::endl;
        }

        words.seqsz++;
        words.wordset.resize(words.wordset.size() / 10);
    }

    return 0;
}

 

kephir4eg/trie: C++ implementation of Radix tree (github.com)

저작자표시 (새창열림)

'CS > 자료구조 & 알고리즘' 카테고리의 다른 글

sort() 함수에서 쓰여지는 정렬 알고리즘 (Intro 인트로, Tim 팀)  (0) 2023.12.11
[C++] Radix Sort (기수 정렬)  (0) 2023.12.04
[C++] Trie (트라이) 개념과 구현방법  (0) 2023.11.28
[C++] const map 객체에 [key] 접근시 에러  (0) 2023.11.22
[C++] vector resize 시 주의할 점  (0) 2023.11.22
    'CS/자료구조 & 알고리즘' 카테고리의 다른 글
    • sort() 함수에서 쓰여지는 정렬 알고리즘 (Intro 인트로, Tim 팀)
    • [C++] Radix Sort (기수 정렬)
    • [C++] Trie (트라이) 개념과 구현방법
    • [C++] const map 객체에 [key] 접근시 에러
    ShovelingLife
    ShovelingLife
    Main skill stack => Unity C# / Unreal C++ Studying Front / BackEnd, Java Python

    티스토리툴바