TradeSkillMaster/LibTSM/Util/SmartMap.lua

218 lines
6.0 KiB
Lua

-- ------------------------------------------------------------------------------ --
-- TradeSkillMaster --
-- https://tradeskillmaster.com --
-- All Rights Reserved - Detailed license information included with addon. --
-- ------------------------------------------------------------------------------ --
--- Smart Map.
-- @module SmartMap
local _, TSM = ...
local SmartMap = TSM.Init("Util.SmartMap")
local private = {
mapContext = {},
readerContext = {},
}
local VALID_FIELD_TYPES = {
string = true,
number = true,
boolean = true,
}
-- ============================================================================
-- Metatable Methods
-- ============================================================================
local SMART_MAP_MT = {
-- getter
__index = function(self, key)
if key == nil then
error("Attempt to get nil key")
end
if key == "ValueChanged" then
return private.MapValueChanged
elseif key == "SetCallbacksPaused" then
return private.MapSetCallbacksPaused
elseif key == "CreateReader" then
return private.MapCreateReader
elseif key == "GetKeyType" then
return private.MapGetKeyType
elseif key == "GetValueType" then
return private.MapGetValueType
elseif key == "Iterator" then
return private.MapIterator
else
error("Invalid map method: "..tostring(key), 2)
end
end,
-- setter
__newindex = function(self, key, value)
error("Map cannot be written to directly", 2)
end,
__tostring = function(self)
return "SmartMap:"..strmatch(tostring(private.mapContext[self]), "table:[^0-9a-fA-F]*([0-9a-fA-F]+)")
end,
__metatable = false,
}
local READER_MT = {
-- getter
__index = function(self, key)
-- check if the map already has the value for this key cached
local readerContext = private.readerContext[self]
local map = readerContext.map
local mapContext = private.mapContext[map]
if mapContext.data[key] ~= nil then
return mapContext.data[key]
end
-- get the value for this key
local value = mapContext.func(key)
if value == nil then
error(format("No value for key (%s)", tostring(key)))
elseif type(value) ~= mapContext.valueType then
error(format("Invalid type of value (got %s, expected %s): %s", type(value), mapContext.valueType, tostring(value)))
end
-- cache the value both on the map and on this reader
mapContext.data[key] = value
rawset(self, key, value)
return value
end,
-- setter
__newindex = function(self, key, value)
error("Reader is read-only", 2)
end,
__tostring = function(self)
return "SmartMapReader:"..strmatch(tostring(private.readerContext[self]), "table:[^0-9a-fA-F]*([0-9a-fA-F]+)")
end,
__metatable = false,
}
-- ============================================================================
-- Module Functions
-- ============================================================================
function SmartMap.New(keyType, valueType, callable)
assert(VALID_FIELD_TYPES[keyType] and VALID_FIELD_TYPES[valueType])
local map = setmetatable({}, SMART_MAP_MT)
private.mapContext[map] = {
keyType = keyType,
valueType = valueType,
func = callable,
data = {},
readers = {},
callbacksPaused = 0,
hasReaderCallback = false,
}
return map
end
-- ============================================================================
-- Private Helper Functions
-- ============================================================================
function private.MapValueChanged(self, key)
local mapContext = private.mapContext[self]
local oldValue = mapContext.data[key]
if oldValue == nil then
-- nobody cares about this value
return
end
if not mapContext.hasReaderCallback then
-- no reader has registered a callback, so just clear the value
mapContext.data[key] = nil
for _, reader in ipairs(mapContext.readers) do
rawset(reader, key, nil)
end
return
end
-- get the new value
local newValue = mapContext.func(key)
if type(newValue) ~= mapContext.valueType then
error(format("Invalid type (got %s, expected %s)", type(newValue), mapContext.valueType))
end
if oldValue == newValue then
-- the value didn't change
return
end
-- update the data
mapContext.data[key] = newValue
for _, reader in ipairs(mapContext.readers) do
local readerContext = private.readerContext[reader]
local prevValue = rawget(reader, key)
if prevValue ~= nil then
rawset(reader, key, newValue)
if readerContext.callback then
readerContext.pendingChanges[key] = prevValue
if mapContext.callbacksPaused == 0 then
readerContext.callback(reader, readerContext.pendingChanges)
wipe(readerContext.pendingChanges)
end
end
end
end
end
function private.MapSetCallbacksPaused(self, paused)
local mapContext = private.mapContext[self]
if paused then
mapContext.callbacksPaused = mapContext.callbacksPaused + 1
else
mapContext.callbacksPaused = mapContext.callbacksPaused - 1
assert(mapContext.callbacksPaused >= 0)
if mapContext.callbacksPaused == 0 then
for _, reader in ipairs(mapContext.readers) do
local readerContext = private.readerContext[reader]
if readerContext.callback and next(readerContext.pendingChanges) then
readerContext.callback(reader, readerContext.pendingChanges)
wipe(readerContext.pendingChanges)
end
end
end
end
end
function private.MapCreateReader(self, callback)
assert(callback == nil or type(callback) == "function")
local reader = setmetatable({}, READER_MT)
local mapContext = private.mapContext[self]
tinsert(mapContext.readers, reader)
mapContext.hasReaderCallback = mapContext.hasReaderCallback or (callback and true or false)
private.readerContext[reader] = {
map = self,
callback = callback,
pendingChanges = {},
}
return reader
end
function private.MapGetKeyType(self)
return private.mapContext[self].keyType
end
function private.MapGetValueType(self)
return private.mapContext[self].valueType
end
function private.MapIterator(self)
return pairs(private.mapContext[self].data)
end