/** @file Table.h Templated hash table class. @maintainer Morgan McGuire, http://graphics.cs.williams.edu @created 2001-04-22 @edited 2013-01-22 Copyright 2000-2013, Morgan McGuire. All rights reserved. */ #ifndef G3D_Table_h #define G3D_Table_h #include <cstddef> #include <string> #include "G3D/platform.h" #include "G3D/Array.h" #include "G3D/debug.h" #include "G3D/System.h" #include "G3D/g3dmath.h" #include "G3D/EqualsTrait.h" #include "G3D/HashTrait.h" #include "G3D/MemoryManager.h" #ifdef _MSC_VER # pragma warning (push) // Debug name too long warning # pragma warning (disable : 4786) #endif namespace G3D { /** An unordered data structure mapping keys to values. There are two ways of definining custom hash functions (G3D provides built-in ones for most classes): <pre> class Foo { public: std::string name; int index; static size_t hashCode(const Foo& key) { return HashTrait<std::string>::hashCode(key.name) + key.index; } }; template<> struct HashTrait<class Foo> { static size_t hashCode(const Foo& key) { return HashTrait<std::string>::hashCode(key.name) + key.index; } }; // Use Foo::hashCode Table<Foo, std::string, Foo> fooTable1; // Use HashTrait<Foo> Table<Foo, std::string> fooTable2; </pre> Key must be a pointer, an int, a std::string or provide overloads for: <PRE> template<> struct HashTrait<class Key> { static size_t hashCode(const Key& key) { return reinterpret_cast<size_t>( ... ); } }; </PRE> and one of <PRE> template<> struct EqualsTrait<class Key>{ static bool equals(const Key& a, const Key& b) { return ... ; } }; bool operator==(const Key&, const Key&); </PRE> G3D pre-defines HashTrait specializations for common types (like <CODE>int</CODE> and <CODE>std::string</CODE>). If you use a Table with a different type you must write those functions yourself. For example, an enum would use: <PRE> template<> struct HashTrait<MyEnum> { static size_t hashCode(const MyEnum& key) const { return reinterpret_cast<size_t>( key ); } }; </PRE> And rely on the default enum operator==. Periodically check that debugGetLoad() is low (> 0.1). When it gets near 1.0 your hash function is badly designed and maps too many inputs to the same output. */ template<class Key, class Value, class HashFunc = HashTrait<Key>, class EqualsFunc = EqualsTrait<Key> > class Table { public: /** The pairs returned by iterator. */ class Entry { public: Key key; Value value; Entry() {} Entry(const Key& k) : key(k) {} Entry(const Key& k, const Value& v) : key(k), value(v) {} bool operator==(const Entry &peer) const { return (key == peer.key && value == peer.value); } bool operator!=(const Entry &peer) const { return !operator==(peer); } }; private: typedef Table<Key, Value, HashFunc, EqualsFunc> ThisType; /** Linked list nodes used internally by HashTable. */ class Node { public: Entry entry; size_t hashCode; Node* next; private: // Private to require use of the allocator Node(const Key& k, const Value& v, size_t h, Node* n) : entry(k, v), hashCode(h), next(n) { debugAssert((next == NULL) || isValidHeapPointer(next)); } Node(const Key& k, size_t h, Node* n) : entry(k), hashCode(h), next(n) { debugAssert((next == NULL) || isValidHeapPointer(next)); } public: static Node* create(const Key& k, const Value& v, size_t h, Node* n, MemoryManager::Ref& mm) { Node* node = (Node*)mm->alloc(sizeof(Node)); return new (node) Node(k, v, h, n); } static Node* create(const Key& k, size_t hashCode, Node* n, MemoryManager::Ref& mm) { Node* node = (Node*)mm->alloc(sizeof(Node)); return new (node) Node(k, hashCode, n); } static void destroy(Node* n, MemoryManager::Ref& mm) { n->~Node(); mm->free(n); } /** Clones a whole chain; */ Node* clone(MemoryManager::Ref& mm) { return create(this->entry.key, this->entry.value, hashCode, (next == NULL) ? NULL : next->clone(mm), mm); } }; void checkIntegrity() const { # ifdef G3D_DEBUG debugAssert(m_bucket == NULL || isValidHeapPointer(m_bucket)); for (size_t b = 0; b < m_numBuckets; ++b) { Node* node = m_bucket[b]; debugAssert(node == NULL || isValidHeapPointer(node)); while (node != NULL) { debugAssert(node == NULL || isValidHeapPointer(node)); node = node->next; } } # endif } /** Number of elements in the table.*/ size_t m_size; /** Array of Node*. We don't use Array<Node*> because Table is lower-level than Array. Some elements may be NULL. */ Node** m_bucket; /** Length of the m_bucket array. */ size_t m_numBuckets; MemoryManager::Ref m_memoryManager; void* alloc(size_t s) const { return m_memoryManager->alloc(s); } void free(void* p) const { return m_memoryManager->free(p); } /** Re-hashes for a larger m_bucket size. */ void resize(size_t newSize) { // Hang onto the old m_bucket array Node** oldBucket = m_bucket; // Allocate a new m_bucket array with the new size m_bucket = (Node**)alloc(sizeof(Node*) * newSize); alwaysAssertM(m_bucket != NULL, "MemoryManager::alloc returned NULL. Out of memory."); // Set all pointers to NULL System::memset(m_bucket, 0, newSize * sizeof(Node*)); // Move each node to its new hash location for (size_t b = 0; b < m_numBuckets; ++b) { Node* node = oldBucket[b]; // There is a linked list of nodes at this m_bucket while (node != NULL) { // Hang onto the old next pointer Node* nextNode = node->next; // Insert at the head of the list for m_bucket[i] size_t i = node->hashCode % newSize; node->next = m_bucket[i]; m_bucket[i] = node; // Move on to the next node node = nextNode; } // Drop the old pointer for cleanliness when debugging oldBucket[b] = NULL; } // Delete the old storage free(oldBucket); this->m_numBuckets = newSize; checkIntegrity(); } void copyFrom(const ThisType& h) { if (&h == this) { return; } debugAssert(m_bucket == NULL); m_size = h.m_size; m_numBuckets = h.m_numBuckets; m_bucket = (Node**)alloc(sizeof(Node*) * m_numBuckets); // No need to NULL elements since we're about to overwrite them for (size_t b = 0; b < m_numBuckets; ++b) { if (h.m_bucket[b] != NULL) { m_bucket[b] = h.m_bucket[b]->clone(m_memoryManager); } else { m_bucket[b] = NULL; } } checkIntegrity(); } /** Frees the heap structures for the nodes. */ void freeMemory() { checkIntegrity(); for (size_t b = 0; b < m_numBuckets; ++b) { Node* node = m_bucket[b]; while (node != NULL) { Node* next = node->next; Node::destroy(node, m_memoryManager); node = next; } m_bucket[b] = NULL; } free(m_bucket); m_bucket = NULL; m_numBuckets = 0; m_size = 0; } public: /** Creates an empty hash table using the default MemoryManager. */ Table() : m_bucket(NULL) { m_memoryManager = MemoryManager::create(); m_numBuckets = 0; m_size = 0; m_bucket = NULL; checkIntegrity(); } /** Changes the internal memory manager to m */ void clearAndSetMemoryManager(const MemoryManager::Ref& m) { clear(); debugAssert(m_bucket == NULL); m_memoryManager = m; } /** Recommends that the table resize to anticipate at least this number of elements. */ void setSizeHint(size_t n) { size_t s = n * 3; if (s > m_numBuckets) { resize(s); } } /** Destroys all of the memory allocated by the table, but does <B>not</B> call delete on keys or values if they are pointers. If you want to deallocate things that the table points at, use getKeys() and Array::deleteAll() to delete them. */ virtual ~Table() { freeMemory(); } /** Uses the default memory manager */ Table(const ThisType& h) { m_memoryManager = MemoryManager::create(); m_numBuckets = 0; m_size = 0; m_bucket = NULL; this->copyFrom(h); checkIntegrity(); } Table& operator=(const ThisType& h) { // No need to copy if the argument is this if (this != &h) { // Free the existing nodes freeMemory(); this->copyFrom(h); checkIntegrity(); } return *this; } /** Returns the length of the deepest m_bucket. */ size_t debugGetDeepestBucketSize() const { size_t deepest = 0; for (size_t b = 0; b < m_numBuckets; ++b) { size_t count = 0; Node* node = m_bucket[b]; while (node != NULL) { node = node->next; ++count; } if (count > deepest) { deepest = count; } } return deepest; } /** Returns the average size of non-empty buckets. */ float debugGetAverageBucketSize() const { uint64 num = 0; for (size_t b = 0; b < m_numBuckets; ++b) { Node* node = m_bucket[b]; if (node != NULL) { ++num; } } return (float)((double)size() / num); } /** A small load (close to zero) means the hash table is acting very efficiently most of the time. A large load (close to 1) means the hash table is acting poorly-- all operations will be very slow. A large load will result from a bad hash function that maps too many keys to the same code. */ double debugGetLoad() const { return (double)size() / m_numBuckets; } /** Returns the number of buckets. */ size_t debugGetNumBuckets() const { return m_numBuckets; } /** C++ STL style iterator variable. See begin(). */ class Iterator { private: friend class Table<Key, Value, HashFunc, EqualsFunc>; /** Bucket index. */ size_t index; /** Linked list node. */ Node* node; size_t m_numBuckets; Node** m_bucket; bool isDone; /** Creates the end iterator. */ Iterator() : index(0), node(NULL), m_bucket(NULL) { isDone = true; } Iterator(size_t numBuckets, Node** m_bucket) : index(0), node(NULL), m_numBuckets(numBuckets), m_bucket(m_bucket) { if (m_numBuckets == 0) { // Empty table isDone = true; return; } # ifdef G3D_DEBUG for (unsigned int i = 0; i < m_numBuckets; ++i) { debugAssert((m_bucket[i] == NULL) || isValidHeapPointer(m_bucket[i])); } # endif index = 0; node = m_bucket[index]; debugAssert((node == NULL) || isValidHeapPointer(node)); isDone = false; findNext(); debugAssert((node == NULL) || isValidHeapPointer(node)); } /** If node is NULL, then finds the next element by searching through the bucket array. Sets isDone if no more nodes are available. */ void findNext() { while (node == NULL) { ++index; if (index >= m_numBuckets) { m_bucket = NULL; index = 0; isDone = true; return; } else { node = m_bucket[index]; debugAssert((node == NULL) || isValidHeapPointer(node)); } } debugAssert(isValidHeapPointer(node)); } public: inline bool operator!=(const Iterator& other) const { return !(*this == other); } bool operator==(const Iterator& other) const { if (other.isDone || isDone) { // Common case; check against isDone. return (isDone == other.isDone); } else { return (node == other.node) && (index == other.index); } } /** Pre increment. */ Iterator& operator++() { debugAssert(! isDone); debugAssert(node != NULL); debugAssert(isValidHeapPointer(node)); debugAssert((node->next == NULL) || isValidHeapPointer(node->next)); node = node->next; findNext(); debugAssert(isDone || isValidHeapPointer(node)); return *this; } /** Post increment (slower than preincrement). */ Iterator operator++(int) { Iterator old = *this; ++(*this); return old; } const Entry& operator*() const { return node->entry; } const Value& value() const { return node->entry.value; } const Key& key() const { return node->entry.key; } Entry* operator->() const { debugAssert(isValidHeapPointer(node)); return &(node->entry); } operator Entry*() const { debugAssert(isValidHeapPointer(node)); return &(node->entry); } bool isValid() const { return ! isDone; } /** @deprecated Use isValid */ bool hasMore() const { return ! isDone; } }; /** C++ STL style iterator method. Returns the first Entry, which contains a key and value. Use preincrement (++entry) to get to the next element. Do not modify the table while iterating. */ Iterator begin() const { return Iterator(m_numBuckets, m_bucket); } /** C++ STL style iterator method. Returns one after the last iterator element. */ const Iterator end() const { return Iterator(); } /** Removes all elements. Guaranteed to free all memory associated with the table. */ void clear() { freeMemory(); m_numBuckets = 0; m_size = 0; m_bucket = NULL; } /** Returns the number of keys. */ size_t size() const { return m_size; } /** If you insert a pointer into the key or value of a table, you are responsible for deallocating the object eventually. Inserting key into a table is O(1), but may cause a potentially slow rehashing. */ void set(const Key& key, const Value& value) { getCreateEntry(key).value = value; } private: /** Helper for remove() and getRemove() */ bool remove(const Key& key, Key& removedKey, Value& removedValue, bool updateRemoved) { if (m_numBuckets == 0) { return false; } const size_t code = HashFunc::hashCode(key); const size_t b = code % m_numBuckets; // Go to the m_bucket Node* n = m_bucket[b]; if (n == NULL) { return false; } Node* previous = NULL; // Try to find the node do { if ((code == n->hashCode) && EqualsFunc::equals(n->entry.key, key)) { // This is the node; remove it // Replace the previous's next pointer if (previous == NULL) { m_bucket[b] = n->next; } else { previous->next = n->next; } if (updateRemoved) { removedKey = n->entry.key; removedValue = n->entry.value; } // Delete the node Node::destroy(n, m_memoryManager); --m_size; //checkIntegrity(); return true; } previous = n; n = n->next; } while (n != NULL); //checkIntegrity(); return false; } public: /** If @a member is present, sets @a removed to the element being removed and returns true. Otherwise returns false and does not write to @a removed. */ bool getRemove(const Key& key, Key& removedKey, Value& removedValue) { return remove(key, removedKey, removedValue, true); } /** Removes an element from the table if it is present. @return true if the element was found and removed, otherwise false */ bool remove(const Key& key) { Key x; Value v; return remove(key, x, v, false); } private: Entry* getEntryPointer(const Key& key) const { if (m_numBuckets == 0) { return NULL; } size_t code = HashFunc::hashCode(key); size_t b = code % m_numBuckets; Node* node = m_bucket[b]; while (node != NULL) { if ((node->hashCode == code) && EqualsFunc::equals(node->entry.key, key)) { return &(node->entry); } node = node->next; } return NULL; } public: /** If a value that is EqualsFunc to @a member is present, returns a pointer to the version stored in the data structure, otherwise returns NULL. */ const Key* getKeyPointer(const Key& key) const { const Entry* e = getEntryPointer(key); if (e == NULL) { return NULL; } else { return &(e->key); } } /** Returns the value associated with key. @deprecated Use get(key, val) or getPointer(key) */ Value& get(const Key& key) const { Entry* e = getEntryPointer(key); debugAssertM(e != NULL, "Key not found"); return e->value; } /** Returns a pointer to the element if it exists, or NULL if it does not. Note that if your value type <i>is</i> a pointer, the return value is a pointer to a pointer. Do not remove the element while holding this pointer. It is easy to accidentally mis-use this method. Consider making a Table<Value*> and using get(key, val) instead, which makes you manage the memory for the values yourself and is less likely to result in pointer errors. */ Value* getPointer(const Key& key) const { if (m_numBuckets == 0) { return NULL; } size_t code = HashFunc::hashCode(key); size_t b = code % m_numBuckets; Node* node = m_bucket[b]; while (node != NULL) { if ((node->hashCode == code) && EqualsFunc::equals(node->entry.key, key)) { // found key return &(node->entry.value); } node = node->next; } // Failed to find key return NULL; } /** If the key is present in the table, val is set to the associated value and returns true. If the key is not present, returns false. */ bool get(const Key& key, Value& val) const { Value* v = getPointer(key); if (v != NULL) { val = *v; return true; } else { return false; } } /** Called by getCreate() and set() \param created Set to true if the entry was created by this method. */ Entry& getCreateEntry(const Key& key, bool& created) { created = false; if (m_numBuckets == 0) { resize(10); } size_t code = HashFunc::hashCode(key); size_t b = code % m_numBuckets; // Go to the m_bucket Node* n = m_bucket[b]; // No m_bucket, so this must be the first if (n == NULL) { m_bucket[b] = Node::create(key, code, NULL, m_memoryManager); ++m_size; created = true; //checkIntegrity(); return m_bucket[b]->entry; } size_t bucketLength = 1; // Sometimes a bad hash code will cause all elements // to collide. Detect this case and don't rehash when // it occurs; nothing good will come from the rehashing. bool allSameCode = true; // Try to find the node do { allSameCode = allSameCode && (code == n->hashCode); if ((code == n->hashCode) && EqualsFunc::equals(n->entry.key, key)) { // This is the a pre-existing node //checkIntegrity(); return n->entry; } n = n->next; ++bucketLength; } while (n != NULL); // Allow the load factor to rise as the table gets huge const int bucketsPerElement = (m_size > 50000) ? 3 : ((m_size > 10000) ? 5 : ((m_size > 5000) ? 10 : 15)); const size_t maxBucketLength = 3; // (Don't bother changing the size of the table if all entries // have the same hashcode--they'll still collide) if ((bucketLength > maxBucketLength) && ! allSameCode && (m_numBuckets < m_size * bucketsPerElement)) { // This m_bucket was really large; rehash if all elements // don't have the same hashcode the number of buckets is // reasonable. // Back off the scale factor as the number of buckets gets // large float f = 3.0f; if (m_numBuckets > 1000000) { f = 1.5f; } else if (m_numBuckets > 100000) { f = 2.0f; } int newSize = iMax((int)(m_numBuckets * f) + 1, (int)(m_size * f)); resize(newSize); } // Not found; insert at the head. b = code % m_numBuckets; m_bucket[b] = Node::create(key, code, m_bucket[b], m_memoryManager); ++m_size; created = true; //checkIntegrity(); return m_bucket[b]->entry; } Entry& getCreateEntry(const Key& key) { bool ignore; return getCreateEntry(key, ignore); } /** Returns the current value that key maps to, creating it if necessary.*/ Value& getCreate(const Key& key) { return getCreateEntry(key).value; } /** \param created True if the element was created. */ Value& getCreate(const Key& key, bool& created) { return getCreateEntry(key, created).value; } /** Returns true if key is in the table. */ bool containsKey(const Key& key) const { if (m_numBuckets == 0) { return false; } size_t code = HashFunc::hashCode(key); size_t b = code % m_numBuckets; Node* node = m_bucket[b]; while (node != NULL) { if ((node->hashCode == code) && EqualsFunc::equals(node->entry.key, key)) { return true; } node = node->next; } return false; } /** Short syntax for get. */ inline Value& operator[](const Key &key) const { return get(key); } /** Returns an array of all of the keys in the table. You can iterate over the keys to get the values. @deprecated */ Array<Key> getKeys() const { Array<Key> keyArray; getKeys(keyArray); return keyArray; } void getKeys(Array<Key>& keyArray) const { keyArray.resize(0, DONT_SHRINK_UNDERLYING_ARRAY); for (size_t i = 0; i < m_numBuckets; ++i) { Node* node = m_bucket[i]; while (node != NULL) { keyArray.append(node->entry.key); node = node->next; } } } /** Will contain duplicate values if they exist in the table. This array is parallel to the one returned by getKeys() if the table has not been modified. */ void getValues(Array<Value>& valueArray) const { valueArray.resize(0, DONT_SHRINK_UNDERLYING_ARRAY); for (size_t i = 0; i < m_numBuckets; ++i) { Node* node = m_bucket[i]; while (node != NULL) { valueArray.append(node->entry.value); node = node->next; } } } /** Calls delete on all of the keys and then clears the table. */ void deleteKeys() { for (size_t i = 0; i < m_numBuckets; ++i) { Node* node = m_bucket[i]; while (node != NULL) { delete node->entry.key; node->entry.key = NULL; node = node->next; } } clear(); } /** Calls delete on all of the values. This is unsafe-- do not call unless you know that each value appears at most once. Does not clear the table, so you are left with a table of NULL pointers. */ void deleteValues() { for (size_t i = 0; i < m_numBuckets; ++i) { Node* node = m_bucket[i]; while (node != NULL) { delete node->entry.value; node->entry.value = NULL; node = node->next; } } } template<class H, class E> bool operator==(const Table<Key, Value, H, E>& other) const { if (size() != other.size()) { return false; } for (Iterator it = begin(); it.hasMore(); ++it) { const Value* v = other.getPointer(it->key); if ((v == NULL) || (*v != it->value)) { // Either the key did not exist or the value was not the same return false; } } // this and other have the same number of keys, so we don't // have to check for extra keys in other. return true; } template<class H, class E> bool operator!=(const Table<Key, Value, H, E>& other) const { return ! (*this == other); } void debugPrintStatus() { debugPrintf("Deepest bucket size = %d\n", (int)debugGetDeepestBucketSize()); debugPrintf("Average bucket size = %g\n", debugGetAverageBucketSize()); debugPrintf("Load factor = %g\n", debugGetLoad()); } }; } // namespace #ifdef _MSC_VER # pragma warning (pop) #endif #endif