-- TradeSkillMaster/LibTSM/Util/DatabaseClasses/Query.lua
--
-- 1620 lines
-- 56 KiB
-- Lua

-- ------------------------------------------------------------------------------ --
-- TradeSkillMaster --
-- https://tradeskillmaster.com --
-- All Rights Reserved - Detailed license information included with addon. --
-- ------------------------------------------------------------------------------ --
--- Database Query Class.
-- This class represents a database query which is used for reading data out of a @{Database} in a structured and
-- efficient manner.
-- @classmod DatabaseQuery
local _, TSM = ...
local Query = TSM.Init("Util.DatabaseClasses.Query")
local Constants = TSM.Include("Util.DatabaseClasses.Constants")
local Util = TSM.Include("Util.DatabaseClasses.Util")
local QueryResultRow = TSM.Include("Util.DatabaseClasses.QueryResultRow")
local QueryClause = TSM.Include("Util.DatabaseClasses.QueryClause")
local ObjectPool = TSM.Include("Util.ObjectPool")
local TempTable = TSM.Include("Util.TempTable")
local Table = TSM.Include("Util.Table")
local Math = TSM.Include("Util.Math")
local LibTSMClass = TSM.Include("LibTSMClass")
local DatabaseQuery = LibTSMClass.DefineClass("DatabaseQuery")
-- Module-private state. `objectPool` starts out nil and is assigned in the OnModuleLoad handler.
local private = {}
private.objectPool = nil
-- ============================================================================
-- Module Loading
-- ============================================================================
-- Once the module loads, create the pool that recycles DatabaseQuery objects.
Query:OnModuleLoad(function()
	local pool = ObjectPool.New("DATABASE_QUERIES", DatabaseQuery, 1)
	private.objectPool = pool
end)
-- ============================================================================
-- Module Functions
-- ============================================================================
--- Gets a query object from the pool and binds it to a database.
-- @param db The database to query
-- @treturn DatabaseQuery The acquired database query object
function Query.Get(db)
	local query = private.objectPool:Get()
	query:_Acquire(db)
	return query
end
-- ============================================================================
-- Class Meta Methods
-- ============================================================================
--- Constructor: initializes every per-instance field of a (pooled) query object with an empty default.
-- Per-use state is populated in _Acquire() and cleared again in _Release().
-- @tparam DatabaseQuery self The database query object
function DatabaseQuery.__init(self)
-- database this query reads from (set in _Acquire)
self._db = nil
-- implicit top-level AND clause, and the clause that new filters are currently added under
self._rootClause = nil
self._currentClause = nil
-- parallel lists of OrderBy() fields and their ascending flags
self._orderBy = {}
self._orderByAscending = {}
-- field set by Distinct()
self._distinct = nil
-- SetUpdateCallback() state plus the SetUpdatesPaused() nesting counter and pending-update flag
self._updateCallback = nil
self._updateCallbackContext = nil
self._updatesPaused = 0
self._queuedUpdate = false
-- fields passed to Select()
self._select = {}
-- iterator state machine; values used in this file: "IDLE", "IN_PROGRESS", "IN_PROGRESS_CAN_ABORT",
-- "PENDING_ABORT", "ABORTED"
self._iteratorState = "IDLE"
-- ordered list of result UUIDs, and the result-row objects created for them (keyed by UUID)
self._result = {}
self._resultRowLookup = {}
self._iterDistinctUsed = {}
self._tempResultRow = nil
self._tempVirtualResultRow = nil
-- set by IteratorAndRelease() so the query releases itself once iteration completes
self._autoRelease = false
-- set whenever the query definition changes (presumably checked by _Execute() to recompute results)
self._resultIsStale = false
-- join state (join types, joined databases, join fields)
self._joinTypes = {}
self._joinDBs = {}
self._joinFields = {}
-- VirtualField() definitions, keyed by virtual field name
self._virtualFieldFunc = {}
self._virtualFieldArgField = {}
self._virtualFieldType = {}
-- comparator closures bound to this query; they forward to private sort helpers defined elsewhere in the file
self._genericSortWrapper = function(a, b)
return private.DatabaseQuerySortGeneric(self, a, b)
end
-- single-key comparator using the direction of the first OrderBy()
self._singleSortWrapper = function(a, b)
return private.DatabaseQuerySortSingle(self, a, b, self._orderByAscending[1])
end
-- single-key comparator using the direction of the second OrderBy()
self._secondarySortWrapper = function(a, b)
return private.DatabaseQuerySortSingle(self, a, b, self._orderByAscending[2])
end
self._sortValueCache = {}
self._resultDependencies = {}
end
--- Binds this pooled query object to a database and sets up its initial state.
-- @tparam DatabaseQuery self The database query object
-- @param db The database being queried
function DatabaseQuery._Acquire(self, db)
	self._db = db
	db:_RegisterQuery(self)
	-- every query starts out with an implicit top-level AND clause
	local root = QueryClause.Get(self):And()
	self._rootClause = root
	self._currentClause = root
	local row = QueryResultRow.Get()
	self._tempResultRow = row
	row:_Acquire(db, self)
end
--- Tears down all per-use state so the object can be recycled (called by Release()).
-- @tparam DatabaseQuery self The database query object
function DatabaseQuery._Release(self)
-- must not be released mid-iteration (Release(true) forces the state back to IDLE first)
assert(self._iteratorState == "IDLE")
-- remove from the database
self._db:_RemoveQuery(self)
self._db = nil
-- release the clause tree
self._rootClause:_Release()
self._rootClause = nil
self._currentClause = nil
self._updateCallback = nil
self._updateCallbackContext = nil
self._updatesPaused = 0
self._queuedUpdate = false
wipe(self._iterDistinctUsed)
-- release the temporary result rows (the virtual one only exists sometimes)
self._tempResultRow:Release()
self._tempResultRow = nil
if self._tempVirtualResultRow then
self._tempVirtualResultRow:Release()
self._tempVirtualResultRow = nil
end
self._autoRelease = false
-- clear results and the query definition; ResetJoins() also unregisters from any joined databases
self:_WipeResults()
self:ResetOrderBy()
self:ResetDistinct()
self:ResetSelect()
self:ResetJoins()
self:ResetVirtualFields()
-- the Reset*() calls above mark the results stale, so clear the flag for the next user of this object
self._resultIsStale = false
wipe(self._resultDependencies)
end
-- ============================================================================
-- Public Class Methods
-- ============================================================================
--- Releases the database query back to its object pool.
-- The object is recycled, so it must not be touched after this call.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=false] boolean abortIterator If true, forcibly ends any in-progress iterator before releasing
function DatabaseQuery.Release(self, abortIterator)
	if abortIterator then
		-- force the state back to idle so the idle check in _Release() passes
		self._iteratorState = "IDLE"
	end
	self:_Release()
	local pool = private.objectPool
	pool:Recycle(self)
end
--- Adds a computed (virtual) field to the query.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the new virtual field (must not clash with an existing field)
-- @tparam string fieldType The type of the virtual field ("number", "string" or "boolean")
-- @tparam function func Computes the field's value; it receives `argField`'s value if given, otherwise the row
-- @tparam[opt=nil] string argField The field to pass into the function (otherwise passes the entire row)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.VirtualField(self, field, fieldType, func, argField)
	local alreadyDefined = self:_GetFieldType(field) or self._virtualFieldFunc[field]
	if alreadyDefined then
		error("Field already exists: "..tostring(field))
	end
	if type(func) ~= "function" then
		error("Invalid func: "..tostring(func))
	end
	local isSupportedType = fieldType == "number" or fieldType == "string" or fieldType == "boolean"
	if not isSupportedType then
		error("Field type must be string, number, or boolean")
	end
	if argField and not self:_GetFieldType(argField) then
		error("Arg field doesn't exist: "..tostring(argField))
	end
	self._virtualFieldFunc[field] = func
	self._virtualFieldArgField[field] = argField
	self._virtualFieldType[field] = fieldType
	self._resultIsStale = true
	return self
end
--- Validates the arguments shared by the comparison clauses (Equal, NotEqual, LessThan, etc.).
-- When `value` is Constants.OTHER_FIELD_QUERY_PARAM, `field` must exist and have the same type as `otherField`.
-- Otherwise, unless `value` is Constants.BOUND_QUERY_PARAM (not checked here), the Lua type of `value` must match
-- the type of `field`.
-- @tparam DatabaseQuery query The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to
-- @tparam[opt=nil] string otherField The name of the other field to compare with
function private.AssertComparisonArgs(query, field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = query:_GetFieldType(field)
		if not fieldType or fieldType ~= query:_GetFieldType(otherField) then
			-- level 3 reports the error at the caller of the public comparison method
			error(format("Cannot compare field %s with field %s (missing field or type mismatch)", tostring(field), tostring(otherField)), 3)
		end
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		local fieldType = query:_GetFieldType(field)
		if fieldType ~= type(value) then
			error(format("Invalid value type for field %s (expected %s, got %s)", tostring(field), tostring(fieldType), type(value)), 3)
		end
	end
end
--- Where a field equals a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Equal(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:Equal(field, value, otherField)
	return self
end
--- Where a field does not equal a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.NotEqual(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:NotEqual(field, value, otherField)
	return self
end
--- Where a field is less than a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LessThan(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:LessThan(field, value, otherField)
	return self
end
--- Where a field is less than or equal to a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LessThanOrEqual(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:LessThanOrEqual(field, value, otherField)
	return self
end
--- Where a field is greater than a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.GreaterThan(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:GreaterThan(field, value, otherField)
	return self
end
--- Where a field is greater than or equal to a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (Constants.OTHER_FIELD_QUERY_PARAM to compare with `otherField`, or
-- Constants.BOUND_QUERY_PARAM for a value supplied later via BindParams())
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.GreaterThanOrEqual(self, field, value, otherField)
	private.AssertComparisonArgs(self, field, value, otherField)
	self:_NewClause()
		:GreaterThanOrEqual(field, value, otherField)
	return self
end
--- Lowercases a Lua pattern while leaving `%`-escaped characters untouched.
-- Naively lowercasing a pattern turns escapes such as `%S` (non-space), `%W` or `%D` into `%s`, `%w`, `%d` and
-- inverts their meaning; this keeps every `%x` pair as written and lowercases everything else.
-- @tparam string pattern The pattern to lowercase
-- @treturn string The lowercased pattern
function private.LowerPattern(pattern)
	-- "(%%?)(.)" captures an optional escape character followed by one character
	local lowered = string.gsub(pattern, "(%%?)(.)", private.LowerPatternChar)
	return lowered
end
--- gsub callback for private.LowerPattern(): lowercases a plain character and keeps an escaped one as-is.
-- @tparam string escape Either "%" or ""
-- @tparam string char The character following the (optional) escape
-- @treturn ?string The replacement, or nil to keep the original match
function private.LowerPatternChar(escape, char)
	if escape == "" then
		return strlower(char)
	end
	-- returning nil makes gsub keep the original match (the escape and its character) unchanged
	return nil
end
--- Where a string field matches a Lua pattern.
-- The pattern is lowercased before being stored (presumably matched against lowercased field values; confirm in
-- QueryClause). Characters following a `%` escape are preserved so class escapes keep their meaning.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The pattern to match
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Matches(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	self:_NewClause()
		:Matches(field, private.LowerPattern(value))
	return self
end
--- Where a string field contains a substring.
-- The substring is lowercased before the clause is created.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The substring to look for
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Contains(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	local argsValid = self:_GetFieldType(field) == "string" and type(value) == "string"
	assert(argsValid)
	local needle = strlower(value)
	local clause = self:_NewClause()
	clause:Contains(field, needle)
	return self
end
--- Where a string field begins with a given prefix.
-- The prefix is lowercased before the clause is created.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The prefix to look for
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.StartsWith(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	local argsValid = self:_GetFieldType(field) == "string" and type(value) == "string"
	assert(argsValid)
	local prefix = strlower(value)
	local clause = self:_NewClause()
	clause:StartsWith(field, prefix)
	return self
end
--- Where a field brought in by a left join is nil.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the joined field
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.IsNil(self, field)
	local joinType = self:_GetJoinType(field)
	assert(joinType == "LEFT", "Must be a left join")
	local clause = self:_NewClause()
	clause:IsNil(field)
	return self
end
--- Where a field brought in by a left join is not nil.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the joined field
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.IsNotNil(self, field)
	local joinType = self:_GetJoinType(field)
	assert(joinType == "LEFT", "Must be a left join")
	local clause = self:_NewClause()
	clause:IsNotNil(field)
	return self
end
--- Adds a custom filter clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam function func Receives the row being evaluated and returns true/false for whether the query should include it
-- @param[opt] arg An extra argument passed to the function
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Custom(self, func, arg)
	assert(type(func) == "function")
	local clause = self:_NewClause()
	clause:Custom(func, arg)
	return self
end
--- Where the hash of a row's fields equals a value.
-- Every listed field must exist and be a number or string field.
-- @tparam DatabaseQuery self The database query object
-- @tparam table fields An ordered list of fields to hash
-- @tparam number value The hash value to compare to
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.HashEqual(self, fields, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(fields) == "table")
	for _, fieldName in ipairs(fields) do
		local fieldType = self:_GetFieldType(fieldName)
		if not fieldType then
			error(format("Field %s doesn't exist", tostring(fieldName)))
		end
		local isHashable = fieldType == "number" or fieldType == "string"
		if not isHashable then
			error(format("Cannot hash field of type %s", fieldType))
		end
	end
	local clause = self:_NewClause()
	clause:HashEqual(fields, value)
	return self
end
--- Where a field's value is a key within a table.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam table value The lookup table to check against
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.InTable(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	local clause = self:_NewClause()
	clause:InTable(field, value)
	return self
end
--- Where a field's value does not exist as a key within a table.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam table value The lookup table to check against
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.NotInTable(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	local clause = self:_NewClause()
	clause:NotInTable(field, value)
	return self
end
--- Starts a nested AND clause.
-- All of the clauses following this (until the matching @{DatabaseQuery.End}) must be true for the AND clause to be
-- true.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.And(self)
	-- the new AND clause becomes the current clause until End() returns to its parent
	self._currentClause = self:_NewClause():And()
	return self
end
--- Starts a nested OR clause.
-- The OR clause is true if at least one of the clauses added after it (up to the matching @{DatabaseQuery.End}) is
-- true.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Or(self)
	local orClause = self:_NewClause():Or()
	-- the new OR clause becomes the current clause until End() returns to its parent
	self._currentClause = orClause
	return self
end
--- Closes the innermost nested AND/OR clause, making its parent the current clause again.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.End(self)
	local current = self._currentClause
	assert(current ~= self._rootClause, "No current clause to end")
	local parent = current:_GetParent()
	self._currentClause = parent
	assert(parent)
	return self
end
--- Left joins another database on a field.
-- Use @{DatabaseQuery.IsNil} / @{DatabaseQuery.IsNotNil} to filter on joined fields that may be missing.
-- @tparam DatabaseQuery self The database query object
-- @param db The database to join with
-- @tparam string field The field to join on
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LeftJoin(self, db, field)
	local joinType = "LEFT"
	self:_JoinHelper(db, field, joinType)
	return self
end
--- Inner joins another database on a field.
-- @tparam DatabaseQuery self The database query object
-- @param db The database to join with
-- @tparam string field The field to join on
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.InnerJoin(self, db, field)
	local joinType = "INNER"
	self:_JoinHelper(db, field, joinType)
	return self
end
--- Orders the results by a field.
-- May be called repeatedly to add tie-breakers; earlier calls take priority over later ones.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of a number, string or boolean field to order by
-- @tparam boolean ascending Whether to order in ascending order (descending otherwise)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.OrderBy(self, field, ascending)
	assert(type(ascending) == "boolean")
	local fieldType = self:_GetFieldType(field)
	if not fieldType then
		error(format("Field %s doesn't exist", tostring(field)))
	end
	if fieldType ~= "number" and fieldType ~= "string" and fieldType ~= "boolean" then
		error(format("Cannot order by field of type %s", tostring(fieldType)))
	end
	local orderBy, directions = self._orderBy, self._orderByAscending
	orderBy[#orderBy + 1] = field
	directions[#directions + 1] = ascending
	self._resultIsStale = true
	return self
end
--- Restricts the results to the first row for each distinct value of a field.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field whose values must be distinct in the results
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Distinct(self, field)
	local fieldType = self:_GetFieldType(field)
	assert(fieldType, format("Field %s doesn't exist within local DB", tostring(field)))
	self._distinct = field
	self._resultIsStale = true
	return self
end
--- Selects specific fields to be returned for each result.
-- May only be called once per query (until ResetSelect()); between 1 and 10 fields are allowed.
-- @tparam DatabaseQuery self The database query object
-- @tparam vararg ... The fields to select
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Select(self, ...)
	local selected = self._select
	assert(#selected == 0)
	local count = select("#", ...)
	assert(count > 0, "Must select at least 1 field")
	-- DatabaseRow.GetFields() only supports 10 fields, so we can only support 10 here as well
	assert(count <= 10, "Select() only supports up to 10 fields")
	for i = 1, count do
		selected[#selected + 1] = (select(i, ...))
	end
	self._resultIsStale = true
	return self
end
--- Binds parameters to a prepared query.
-- The number of arguments must match the number of Constants.BOUND_QUERY_PARAM values in the query's clauses.
-- @tparam DatabaseQuery self The database query object
-- @tparam vararg ... The values to bind to the query's bound parameters
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.BindParams(self, ...)
	local numValues = select("#", ...)
	-- _BindParams() reports how many bound parameters it consumed
	local numBound = self._rootClause:_BindParams(...)
	assert(numBound == numValues, "Invalid number of bound parameters")
	self._resultIsStale = true
	return self
end
--- Registers a callback invoked whenever rows in the underlying database change.
-- Passing nil clears the callback.
-- @tparam DatabaseQuery self The database query object
-- @tparam function func The callback function which is called with (self, changedUUID, context)
-- @param[opt=nil] context Passed through as the third argument to the callback
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.SetUpdateCallback(self, func, context)
	self._updateCallbackContext = context
	self._updateCallback = func
	return self
end
--- Pauses or unpauses update callbacks.
-- Calls nest: each pause must be matched by an unpause. When the last pause is lifted and an update was queued, the
-- update callback fires immediately.
-- @tparam DatabaseQuery self The database query object
-- @tparam boolean paused Whether or not updates should be paused
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.SetUpdatesPaused(self, paused)
	local delta = paused and 1 or -1
	local depth = self._updatesPaused + delta
	self._updatesPaused = depth
	assert(depth >= 0)
	if depth == 0 and self._queuedUpdate then
		self:_DoUpdateCallback()
	end
	return self
end
--- Returns an iterator over the query results.
-- The loop must run to completion (no `break`/`return` out of it).
-- @tparam DatabaseQuery self The database query object
-- @tparam boolean canAbort Allow the iterator to be aborted if the underlying data is updated; the caller must then
-- check `IsIteratorAborted()` at the end of each loop iteration (not allowed together with an update callback)
-- @return An iterator for the results of the query
function DatabaseQuery.Iterator(self, canAbort)
	self:_Execute()
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	-- abortable iteration cannot be combined with an update callback
	assert(not (canAbort and self._updateCallback))
	if canAbort then
		self._iteratorState = "IN_PROGRESS_CAN_ABORT"
	else
		self._iteratorState = "IN_PROGRESS"
	end
	self._autoRelease = false
	return private.QueryResultIterator, self, 0
end
--- Returns an iterator over the UUIDs of the query results.
-- @tparam DatabaseQuery self The database query object
-- @return An iterator for the results of the query as UUIDs
function DatabaseQuery.UUIDIterator(self)
	self:_Execute()
	local root = self._rootClause
	assert(root and self._currentClause == root, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self._iteratorState = "IN_PROGRESS"
	self._autoRelease = false
	return private.QueryResultAsUUIDIterator, self, 0
end
--- Returns an iterator over the query results that releases the query once iteration completes.
-- The loop must run to completion (no `break`/`return` out of it).
-- @tparam DatabaseQuery self The database query object
-- @return An iterator for the results of the query
function DatabaseQuery.IteratorAndRelease(self)
	self:_Execute()
	local root = self._rootClause
	assert(root and self._currentClause == root, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self._iteratorState = "IN_PROGRESS"
	-- flag picked up by the iterator function when it finishes
	self._autoRelease = true
	return private.QueryResultIterator, self, 0
end
--- Checks whether an abortable iterator has been aborted.
-- A pending abort is consumed by this call (the state moves to "ABORTED"). Raises an error when called outside an
-- abortable iteration.
-- @tparam DatabaseQuery self The database query object
-- @treturn boolean Whether or not the iterator has been aborted
function DatabaseQuery.IsIteratorAborted(self)
	local state = self._iteratorState
	if state == "PENDING_ABORT" then
		self._iteratorState = "ABORTED"
		return true
	elseif state ~= "IN_PROGRESS_CAN_ABORT" then
		error("Invalid iterator state: "..tostring(state))
	end
	return false
end
--- Populates a table with the results.
-- The query must select exactly one or two fields. With one field, the values are appended to `tbl` as a list. With
-- two fields, the first field is used as the key and the second as the value. Keys must be non-nil and distinct, and
-- must not already exist in `tbl`. Duplicates whose value is nil cannot be detected.
-- @tparam DatabaseQuery self The database query object
-- @tparam table tbl The table to store the result in
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.AsTable(self, tbl)
	self:_Execute()
	local numSelected = #self._select
	if numSelected == 1 then
		local field = self._select[1]
		for _, uuid in ipairs(self._result) do
			tinsert(tbl, self:_GetResultRowData(uuid, field))
		end
	elseif numSelected == 2 then
		local keyField, valueField = self._select[1], self._select[2]
		for _, uuid in ipairs(self._result) do
			local key = self:_GetResultRowData(uuid, keyField)
			-- compare against nil rather than truthiness so a duplicate whose stored value is false is still caught
			if key == nil or tbl[key] ~= nil then
				error("Key is nil or not distinct")
			end
			tbl[key] = self:_GetResultRowData(uuid, valueField)
		end
	else
		error("Invalid select clause")
	end
	return self
end
--- Executes the query (if needed) and returns how many rows it produced.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows
function DatabaseQuery.Count(self)
	self:_Execute()
	local results = self._result
	return #results
end
--- Returns the number of result rows, releasing the query afterwards.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows
function DatabaseQuery.CountAndRelease(self)
	self:_Execute()
	local numRows = #self._result
	self:Release()
	return numRows
end
--- Returns the one and only result of the query.
-- Raises an assertion error unless the query produced exactly one row.
-- @tparam DatabaseQuery self The database query object
-- @return The result row, or the selected field values if the query has a select clause
function DatabaseQuery.GetSingleResult(self)
	self:_Execute()
	local numResults = self:Count()
	assert(numResults == 1)
	return self:GetFirstResult()
end
--- Get a single result and release.
-- This method will assert that there is exactly one result from the query and return it. A select clause is
-- required, and every selected field value is returned (not just the first one).
-- @tparam DatabaseQuery self The database query object
-- @return The selected field values of the single result row
function DatabaseQuery.GetSingleResultAndRelease(self)
	assert(#self._select > 0)
	-- forward all returned values through the release helper; assigning to a single local would keep only the first
	-- selected field (same approach as GetFirstResultAndRelease)
	return self:_PassThroughReleaseHelper(self:GetSingleResult())
end
--- Returns the first result of the query, or nothing if there are no results.
-- Note that the full query is executed first.
-- @tparam DatabaseQuery self The database query object
-- @return The result row, or the selected field values if the query has a select clause
function DatabaseQuery.GetFirstResult(self)
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		return
	end
	local uuid = self._result[1]
	local row = self._resultRowLookup[uuid]
	if not row then
		self:_CreateResultRow(uuid)
		row = self._resultRowLookup[uuid]
	end
	if #self._select == 0 then
		return row
	end
	return row:GetFields(unpack(self._select))
end
--- Returns the first result of the query (or nothing if there is none) and releases the query.
-- Note that the full query is executed first. Without a select clause, a clone of the row is returned.
-- @tparam DatabaseQuery self The database query object
-- @return A clone of the result row, or the selected field values if the query has a select clause
function DatabaseQuery.GetFirstResultAndRelease(self)
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		self:Release()
		return
	end
	local uuid = self._result[1]
	local row = self._resultRowLookup[uuid]
	if not row then
		self:_CreateResultRow(uuid)
		row = self._resultRowLookup[uuid]
	end
	if #self._select == 0 then
		-- clone before releasing; the original row is presumably not usable once the query is released
		local copy = row:Clone()
		self:Release()
		return copy
	end
	return self:_PassThroughReleaseHelper(row:GetFields(unpack(self._select)))
end
--- Returns the smallest value of a field across the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The minimum value or nil if there are no results
function DatabaseQuery.Min(self, field)
	self:_Execute()
	local lowest = nil
	for _, uuid in ipairs(self._result) do
		lowest = min(lowest or math.huge, self:_GetResultRowData(uuid, field))
	end
	return lowest
end
--- Returns the largest value of a field across the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The maximum value or nil if there are no results
function DatabaseQuery.Max(self, field)
	self:_Execute()
	local highest = nil
	for _, uuid in ipairs(self._result) do
		highest = max(highest or -math.huge, self:_GetResultRowData(uuid, field))
	end
	return highest
end
--- Returns the total of a field across the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The summed value or nil if there are no results
function DatabaseQuery.Sum(self, field)
	self:_Execute()
	local total = nil
	for _, uuid in ipairs(self._result) do
		total = (total or 0) + self:_GetResultRowData(uuid, field)
	end
	return total
end
--- Accumulates the total of one field per distinct value of another field.
-- Totals are added onto any values already present in `result`.
-- @tparam DatabaseQuery self The database query object
-- @tparam string groupField The field to group by
-- @tparam string sumField The field to sum
-- @tparam table result Receives group value -> total
function DatabaseQuery.GroupedSum(self, groupField, sumField, result)
	self:_Execute()
	for _, uuid in ipairs(self._result) do
		local groupKey = self:_GetResultRowData(uuid, groupField)
		local amount = self:_GetResultRowData(uuid, sumField)
		local runningTotal = result[groupKey] or 0
		result[groupKey] = runningTotal + amount
	end
end
--- Returns the total of a field across the query results, releasing the query afterwards.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The summed value or nil if there are no results
function DatabaseQuery.SumAndRelease(self, field)
	local total = self:Sum(field)
	self:Release()
	return total
end
--- Returns the mean of a field across the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The average value or nil if there are no results
function DatabaseQuery.Avg(self, field)
	local total = self:Sum(field)
	local numRows = self:Count()
	if not total then
		return nil
	end
	return total / numRows
end
--- Returns the sum over all result rows of `field1 * field2`.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field1 The first field within the results
-- @tparam string field2 The second field within the results
-- @treturn ?number The summed value or nil if there are no results
function DatabaseQuery.SumOfProduct(self, field1, field2)
	self:_Execute()
	local total = nil
	for _, uuid in ipairs(self._result) do
		local lhs = self:_GetResultRowData(uuid, field1)
		local rhs = self:_GetResultRowData(uuid, field2)
		total = (total or 0) + lhs * rhs
	end
	return total
end
--- Concatenates a field's values across the results, in result order.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @tparam string sep The separator (can be any number of characters, including an empty string)
-- @treturn string The joined string
function DatabaseQuery.JoinedString(self, field, sep)
	self:_Execute()
	local pieces = TempTable.Acquire()
	for _, uuid in ipairs(self._result) do
		pieces[#pieces + 1] = self:_GetResultRowData(uuid, field)
	end
	local joined = table.concat(pieces, sep)
	TempTable.Release(pieces)
	return joined
end
--- Calculates the hash of the query results.
-- Note that either `fields` must be specified or the query must have a select column with one or two fields.
-- With `fields`, each row's fields are hashed in result order. Otherwise, the selected values of all rows are
-- collected into one list and sorted before hashing.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=nil] table fields The fields from each row to hash (otherwise uses the selected fields)
-- @treturn ?number The hash value or nil if there are no results
function DatabaseQuery.Hash(self, fields)
self:_Execute()
local result = nil
if fields then
-- rolling hash over each row's fields, in result order
for _, uuid in ipairs(self._result) do
for _, field in ipairs(fields) do
result = Math.CalculateHash(self:_GetResultRowData(uuid, field), result)
end
end
else
local keyField, valueField, extra = unpack(self._select)
assert(keyField and not extra)
-- collect every selected value of every row into one flat list
local hashContext = TempTable.Acquire()
for _, uuid in ipairs(self._result) do
tinsert(hashContext, self:_GetResultRowData(uuid, keyField))
if valueField then
tinsert(hashContext, self:_GetResultRowData(uuid, valueField))
end
end
-- sort (via Table.Sort) so the hash presumably does not depend on row order
-- NOTE(review): keys and values are sorted together, so key/value pairing is not reflected in the hash
Table.Sort(hashContext)
for _, value in ipairs(hashContext) do
result = Math.CalculateHash(value, result)
end
TempTable.Release(hashContext)
end
return result
end
--- Computes a hash per distinct value of a group field.
-- Each row's fields are hashed, and that row hash is folded into its group's running hash in `result`.
-- @tparam DatabaseQuery self The database query object
-- @tparam table fields The fields from each row to hash
-- @tparam string groupField The field to group by
-- @tparam table result Receives group value -> hash (existing entries are folded in)
function DatabaseQuery.GroupedHash(self, fields, groupField, result)
	self:_Execute()
	local uuids = self._result
	local numFields = #fields
	for rowIndex = 1, #uuids do
		local uuid = uuids[rowIndex]
		local group = self:_GetResultRowData(uuid, groupField)
		local rowHash = nil
		for fieldIndex = 1, numFields do
			local value = self:_GetResultRowData(uuid, fields[fieldIndex])
			rowHash = Math.CalculateHash(value, rowHash)
		end
		result[group] = Math.CalculateHash(rowHash, result[group])
	end
end
--- Calculates the hash of the query results and releases the query.
-- Either `fields` must be specified or the query must have a select clause with one or two fields.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=nil] table fields The fields from each row to hash (otherwise uses the selected fields)
-- @treturn ?number The hash value or nil if there are no results
function DatabaseQuery.HashAndRelease(self, fields)
	local hash = self:Hash(fields)
	self:Release()
	return hash
end
--- Deletes every result row from the database, then releases the query.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows deleted (equal to `:Count()`)
function DatabaseQuery.DeleteAndRelease(self)
	local numDeleted = self:Count()
	local db = self._db
	db:BulkDelete(self._result)
	self:Release()
	return numDeleted
end
--- Resets the database query back to a blank state.
-- Clears the distinct, select, order by, join, filter and virtual field clauses and drops the current result.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Reset(self)
	-- each Reset* method returns the query itself, so the calls can be chained (same order as before)
	self:ResetDistinct()
		:ResetSelect()
		:ResetOrderBy()
		:ResetJoins()
		:ResetFilters()
		:ResetVirtualFields()
	self:_WipeResults()
	self._resultIsStale = true
	return self
end
--- Removes every virtual field which was added to the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetVirtualFields(self)
	self._resultIsStale = true
	-- these are parallel maps keyed by virtual field name, so they are always cleared together
	wipe(self._virtualFieldType)
	wipe(self._virtualFieldArgField)
	wipe(self._virtualFieldFunc)
	return self
end
--- Removes every filtering clause from the database query.
-- The old root clause is released and replaced by a fresh, empty AND clause which also becomes the current clause.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetFilters(self)
	self._rootClause:_Release()
	local rootClause = QueryClause.Get(self):And()
	self._rootClause = rootClause
	self._currentClause = rootClause
	self._resultIsStale = true
	return self
end
--- Removes every ordering clause from the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetOrderBy(self)
	self._resultIsStale = true
	-- the field list and the direction list are parallel arrays, so they are cleared together
	wipe(self._orderByAscending)
	wipe(self._orderBy)
	return self
end
--- Removes every join from the database query, unregistering the query from each joined database.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetJoins(self)
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		joinDBs[i]:_RemoveQuery(self)
	end
	wipe(self._joinTypes)
	wipe(joinDBs)
	wipe(self._joinFields)
	self._resultIsStale = true
	return self
end
--- Removes the distinct clause (if any) from the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetDistinct(self)
	self._resultIsStale = true
	self._distinct = nil
	return self
end
--- Removes every selected field from the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetSelect(self)
	self._resultIsStale = true
	wipe(self._select)
	return self
end
--- Gets the field and direction of one order by clause.
-- Raises an assertion error if there is no order by clause at `index`.
-- @tparam DatabaseQuery self The database query object
-- @tparam number index The index of the order by clause
-- @treturn string The field name
-- @treturn ?boolean Whether or not the sort is ascending
function DatabaseQuery.GetOrderBy(self, index)
	local field = self._orderBy[index]
	assert(field)
	return field, self._orderByAscending[index]
end
--- Gets the field and direction of the last order by clause.
-- @tparam DatabaseQuery self The database query object
-- @treturn ?string The field name (nil if there are no order by clauses)
-- @treturn ?boolean Whether or not the sort is ascending
function DatabaseQuery.GetLastOrderBy(self)
	local fields, ascending = self._orderBy, self._orderByAscending
	return fields[#fields], ascending[#ascending]
end
--- Replaces the last order by clause with a new one.
-- There must already be at least one order by clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field to order by
-- @tparam boolean ascending Whether to order in ascending order (descending otherwise)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.UpdateLastOrderBy(self, field, ascending)
	local orderByFields, orderByAscending = self._orderBy, self._orderByAscending
	assert(#orderByFields > 0)
	-- pop the last entry from both parallel lists, then append the replacement through OrderBy()
	tremove(orderByFields)
	tremove(orderByAscending)
	self:OrderBy(field, ascending)
	return self
end
--- Gets the result row object for a UUID which is part of the result, creating it on first access.
-- @tparam DatabaseQuery self The database query object
-- @tparam number uuid The UUID of the row to get
-- @return QueryResultRow The result row
function DatabaseQuery.GetResultRowByUUID(self, uuid)
	-- a lookup entry of `false` means the UUID is in the result but its row object hasn't been created yet
	return self._resultRowLookup[uuid] or self:_CreateResultRow(uuid)
end
-- ============================================================================
-- Private Class Methods
-- ============================================================================
-- Returns the join type (e.g. "INNER") of the first joined DB which has `field`, or nothing if no joined DB has it.
function DatabaseQuery._GetJoinType(self, field)
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		if joinDBs[i]:_GetFieldType(field) then
			return self._joinTypes[i]
		end
	end
end
-- Returns the type of `field`. Lookup order: virtual fields, then the query's own DB, then each joined DB.
-- Returns nothing if the field is unknown.
function DatabaseQuery._GetFieldType(self, field)
	local localType = self._virtualFieldType[field] or self._db:_GetFieldType(field)
	if localType then
		return localType
	end
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		local foreignType = joinDBs[i]:_GetFieldType(field)
		if foreignType then
			return foreignType
		end
	end
end
-- Notifies the query that underlying data changed. `changedFields` is an optional set (keys are field names).
-- The result is marked stale in two cases:
--   * no `changedFields` is given, or the result depends on all fields;
--   * any changed field is one the result depends on.
-- Otherwise only the cached values of the changed fields are dropped from the existing result row objects.
-- Unless the result was already stale, an abortable in-progress iterator is always flagged as pending abort.
function DatabaseQuery._MarkResultStale(self, changedFields)
	local state = self._iteratorState
	assert(state == "IDLE" or state == "IN_PROGRESS_CAN_ABORT" or state == "PENDING_ABORT")
	if self._resultIsStale then
		-- already marked stale
		return
	end
	local dependencies = self._resultDependencies
	local invalidate = dependencies._all or not changedFields
	if not invalidate then
		for changedField in pairs(changedFields) do
			if dependencies[changedField] then
				invalidate = true
				break
			end
		end
	end
	if invalidate then
		self._resultIsStale = true
	else
		-- the set of result rows is still valid, so just clear the cached values of the changed fields
		for _, resultRow in pairs(self._resultRowLookup) do
			if resultRow ~= false then
				for changedField in pairs(changedFields) do
					rawset(resultRow, changedField, nil)
				end
			end
		end
	end
	if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
		self._iteratorState = "PENDING_ABORT"
	end
end
-- Invokes the update callback (if one is set) to report that the query's data changed; `uuid` is the changed row if
-- known. While updates are paused, the notification is only recorded in `_queuedUpdate`. The callback receives:
--   * nil when the result is stale or no UUID was given,
--   * the UUID itself when it belongs to the query's own DB,
--   * otherwise a local UUID translated from the joined DB's row via the join field, when exactly one can be found
--     through a unique lookup or a single-entry index range (nil if it can't be translated unambiguously).
function DatabaseQuery._DoUpdateCallback(self, uuid)
if not self._updateCallback then
assert(self._iteratorState == "IDLE" or self._iteratorState == "PENDING_ABORT")
return
end
-- can't have an update callback on an abortable iterator
assert(self._iteratorState == "IDLE")
if self._updatesPaused > 0 then
self._queuedUpdate = true
else
self._queuedUpdate = false
if self._resultIsStale or not uuid then
self:_updateCallback(nil, self._updateCallbackContext)
elseif self._db:_ContainsUUID(uuid) then
self:_updateCallback(uuid, self._updateCallbackContext)
else
-- the UUID is from a joined DB, so see if we can easily translate it to a local UUID
local localUUID = nil
for i = 1, #self._joinDBs do
local joinDB = self._joinDBs[i]
if joinDB:_ContainsUUID(uuid) then
if localUUID then
-- found more than once, so bail
localUUID = nil
break
end
local joinField = self._joinFields[i]
local joinValue = joinDB:GetRowFieldByUUID(uuid, joinField)
if self._db:_IsUnique(joinField) then
localUUID = self._db:_GetUniqueRow(joinField, joinValue)
elseif self._db:_IsIndex(joinField) then
local lowIndex, highIndex = self._db:_GetIndexListMatchingIndexRange(joinField, Util.ToIndexValue(joinValue))
if not lowIndex or not highIndex or lowIndex ~= highIndex then
-- can't use this index to find a single local UUID
break
end
localUUID = self._db:_GetAllRowsByIndex(joinField)[lowIndex]
end
-- NOTE(review): if the join field is neither unique nor indexed locally, localUUID stays nil and the loop
-- keeps going, so the callback ends up receiving nil
end
end
self:_updateCallback(localUUID, self._updateCallbackContext)
end
end
end
-- Creates a new clause nested under the current clause, marks the result stale and returns the new clause.
function DatabaseQuery._NewClause(self)
	self._resultIsStale = true
	local subClause = QueryClause.Get(self, self._currentClause)
	self._currentClause:_InsertSubClause(subClause)
	return subClause
end
-- Releases every QueryResultRow created for the current result, then empties the result list and the row lookup.
function DatabaseQuery._WipeResults(self)
	-- lookup values are `false` until a row object has been created for that UUID
	for _, resultRow in pairs(self._resultRowLookup) do
		if resultRow then
			resultRow:Release()
		end
	end
	wipe(self._result)
	wipe(self._resultRowLookup)
end
-- Runs the query if its result is stale (or `force` is set), rebuilding `self._result` as an ordered list of matching
-- row UUIDs. Each UUID is also recorded in `self._resultRowLookup` as `false`; its QueryResultRow is created on demand.
-- Steps:
--   1. collect the matching rows using the strategy chosen by _GetQueryIndexInfo (recorded on the result table as
--      `_queryOptimizationResult` / `_queryOptimizationField`),
--   2. sort them, unless iterating the chosen index already produced the requested order,
--   3. recompute `self._resultDependencies`, the set of fields _MarkResultStale checks to decide whether a later
--      change invalidates this result.
function DatabaseQuery._Execute(self, force)
if not self._resultIsStale and not force then
return
end
assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
assert(self._iteratorState == "IDLE")
assert(not next(self._iterDistinctUsed))
-- clear the current result
self:_WipeResults()
-- get all the rows which we need to iterate over
local firstOrderBy = self._orderBy[1]
local skipFirstOrderBy = false
local sortNeeded = firstOrderBy and true or false
local indexType, indexField, indexArg1, indexArg2, indexArg3 = self:_GetQueryIndexInfo()
self._result._queryOptimizationResult = indexType
self._result._queryOptimizationField = indexField
if indexType == "EMPTY" then
sortNeeded = false
elseif indexType == "UNIQUE" then
-- we are looking for a unique row
local indexValue = indexArg1
local uuid = self._db:_GetUniqueRow(indexField, indexValue)
if uuid and self:_ResultShouldIncludeRow(uuid, false, #self._joinDBs, self._distinct) then
tinsert(self._result, uuid)
self._resultRowLookup[uuid] = false
end
sortNeeded = false
elseif indexType == "INDEX" then
-- we're querying on an index, so use that index to populate the result
local firstIndex, lastIndex, isStrict = indexArg1, indexArg2, indexArg3
local isAscending = true
if firstOrderBy and indexField == firstOrderBy then
-- we're also ordering by this field so can skip the first OrderBy field
self._result._queryOptimizationResult = "INDEX_AND_ORDER_BY"
skipFirstOrderBy = true
sortNeeded = #self._orderBy > 1
isAscending = self._orderByAscending[1]
end
local indexList = self._db:_GetAllRowsByIndex(indexField)
self:_AddResultRowsFromIndex(indexList, isStrict, firstIndex, lastIndex, isAscending)
elseif indexType == "NONE" then
if firstOrderBy and self._db:_IsIndex(firstOrderBy) then
-- we're ordering on an index, so use that index to iterate through all the rows in order to skip the first OrderBy field
self._result._queryOptimizationResult = "ORDER_BY"
self._result._queryOptimizationField = firstOrderBy
skipFirstOrderBy = true
sortNeeded = #self._orderBy > 1
local isAscending = self._orderByAscending[1]
local indexList = self._db:_GetAllRowsByIndex(firstOrderBy)
self:_AddResultRowsFromIndex(indexList, false, 1, #indexList, isAscending)
else
-- no optimizations
self:_AddResultRowsCheckAll()
end
elseif indexType == "TRIGRAM" then
local indexValue = indexArg1
local uuids = TempTable.Acquire()
self._db:_GetTrigramIndexMatchingRows(indexValue, uuids)
self:_AddResultRowsFromIndex(uuids, false, 1, #uuids, true)
TempTable.Release(uuids)
else
error("Invalid index type: "..tostring(indexType))
end
-- distinct values only need to be tracked while rows are being collected
wipe(self._iterDistinctUsed)
-- sort the results if necessary
if sortNeeded then
if #self._orderBy == 1 then
assert(not skipFirstOrderBy)
assert(not next(self._sortValueCache))
-- compute each row's sort key once up front so the comparator only does table lookups
for _, uuid in ipairs(self._result) do
self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
end
Table.Sort(self._result, self._singleSortWrapper)
wipe(self._sortValueCache)
elseif skipFirstOrderBy and #self._orderBy == 2 then
-- the result is already ordered by the first orderBy field, so iterate through it
-- and sort each group of results where the first orderBy field is the same
assert(not next(self._sortValueCache))
local group = TempTable.Acquire()
local subsetLen = 0
local currentSortValue = nil
for i = 1, #self._result do
local uuid = self._result[i]
local sortValue = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
-- the cache holds the second orderBy field's key, which is what the secondary comparator uses
self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[2]))
if sortValue ~= currentSortValue then
-- the first sort value changed, so we're now in a new group
if subsetLen > 1 then
-- sort the previous group
Table.Sort(group, self._secondarySortWrapper)
-- update the corresponding results
-- (the previous group occupies result indexes i - subsetLen through i - 1)
local offset = i - subsetLen - 1
for j = 1, subsetLen do
self._result[offset + j] = group[j]
end
end
subsetLen = 0
wipe(group)
currentSortValue = sortValue
end
subsetLen = subsetLen + 1
group[subsetLen] = uuid
end
if subsetLen > 1 then
-- sort the previous group
Table.Sort(group, self._secondarySortWrapper)
-- update the corresponding results
-- (the final group occupies the last subsetLen result entries)
local offset = #self._result - subsetLen
for i = 1, subsetLen do
self._result[offset + i] = group[i]
end
end
TempTable.Release(group)
wipe(self._sortValueCache)
else
-- general case: the comparator checks every orderBy field for each pair of rows
Table.Sort(self._result, self._genericSortWrapper)
end
end
-- update the dependencies
wipe(self._resultDependencies)
if next(self._virtualFieldFunc) then
-- virtual fields may be computed from any field of the row, so any change can affect the result
self._resultDependencies._all = true
else
for i = 1, #self._joinFields do
self._resultDependencies[self._joinFields[i]] = true
end
for i = 1, #self._orderBy do
self._resultDependencies[self._orderBy[i]] = true
end
if self._distinct then
self._resultDependencies[self._distinct] = true
end
for i = 1, #self._select do
self._resultDependencies[self._select[i]] = true
end
-- NOTE(review): only the local DB's fields are checked against the filter clauses here; if filters can
-- reference joined-DB fields, changes to those fields may not mark the result stale. Confirm how joined DBs
-- report changes to _MarkResultStale.
for field in self._db:FieldIterator() do
if self._rootClause:_UsesField(field) then
self._resultDependencies[field] = true
end
end
end
self._resultIsStale = false
end
-- Chooses how _Execute should collect the candidate rows for this query. Returns one of:
--   "UNIQUE", field, value
--       a unique index with an exact value identifies at most one row
--   "EMPTY", field
--       an index range proves that no row can match
--   "INDEX", field, firstIndex, lastIndex, isStrict
--       iterate that range of the field's index list; `isStrict` means the range alone satisfies the filters, so
--       evaluating them can be skipped
--   "TRIGRAM", field, value
--       look up candidate rows through the trigram index
--   "NONE"
--       no index applies
-- Among the usable regular indexes, the one with the smallest estimated range wins.
function DatabaseQuery._GetQueryIndexInfo(self)
	-- try to find the index with the least result rows
	local indexField, indexFirstIndex, indexLastIndex, indexIsStrict = nil, nil, nil, false
	local bestIndexDiff = math.huge
	for _, field in ipairs(self._db:_GetIndexAndUniqueList()) do
		-- multi-field indexes have their field names joined by DB_INDEX_FIELD_SEP
		local valueMin, valueMax = self:_IndexValueHelper(strsplit(Constants.DB_INDEX_FIELD_SEP, field))
		if valueMin == nil and valueMax == nil then
			-- continue
		elseif self._db:_IsUnique(field) and valueMin == valueMax then
			-- unique indexes result in a single row, at which point the benefit of trying to find something better (EMPTY) is negligible
			return "UNIQUE", field, valueMin
		elseif self._db:_IsIndex(field) then
			-- check how many rows this index results in
			local indexList = self._db:_GetAllRowsByIndex(field)
			-- An open lower bound always starts at 1, even for an empty index list. An empty list then gives a
			-- negative range below and is reported as EMPTY. Otherwise the range would start at the nonexistent
			-- indexList[0] and feed a nil UUID into the result.
			local firstIndex = valueMin and self._db:_IndexListBinarySearch(field, valueMin, true) or 1
			local lastIndex = valueMax and self._db:_IndexListBinarySearch(field, valueMax, false) or #indexList
			local indexDiff = lastIndex - firstIndex
			if indexDiff < 0 then
				-- there are no results within this index, so this is as good as it gets
				return "EMPTY", field
			else
				-- NOTE: string indexes can't be strict since they are case-insensitive
				local isStrict = type(valueMin) ~= "string" and type(valueMax) ~= "string" and self._rootClause:_IsStrictIndex(field, valueMin, valueMax)
				if isStrict then
					-- rough estimate that being able to skip the query makes each row cost 1/4 as much
					indexDiff = floor(indexDiff / 4)
				end
				if indexDiff < bestIndexDiff then
					-- this is our new best index
					indexField = field
					indexFirstIndex = firstIndex
					indexLastIndex = lastIndex
					indexIsStrict = isStrict
					bestIndexDiff = indexDiff
				end
			end
		end
	end
	if indexField then
		return "INDEX", indexField, indexFirstIndex, indexLastIndex, indexIsStrict
	end
	-- try the trigram index
	local trigramIndexField = self._db:_GetTrigramIndexField()
	if trigramIndexField then
		local trigramIndexValue = self._rootClause:_GetTrigramIndexValue(trigramIndexField)
		if trigramIndexValue then
			return "TRIGRAM", trigramIndexField, trigramIndexValue
		end
	end
	return "NONE"
end
-- Appends the rows at positions firstIndex..lastIndex of `indexList` that belong in the result.
-- Rows are visited in ascending position order, or in reverse when `isAscending` is falsy.
-- `skipQuery` means the index range already satisfies the filter clauses, so they aren't evaluated per row.
-- Each appended UUID is also added to the row lookup as `false` (its row object is created lazily).
function DatabaseQuery._AddResultRowsFromIndex(self, indexList, skipQuery, firstIndex, lastIndex, isAscending)
	local numJoinDBs = #self._joinDBs
	local distinct = self._distinct
	local result = self._result
	local resultIndex = #result + 1
	local resultRowLookup = self._resultRowLookup
	-- fast path: with no filters to evaluate, no joins and no distinct, every row in the range is included
	-- (this condition doesn't change inside the loop, so it is computed once)
	local includeAll = skipQuery and numJoinDBs == 0 and not distinct
	local startIndex, endIndex, step = firstIndex, lastIndex, 1
	if not isAscending then
		startIndex, endIndex, step = lastIndex, firstIndex, -1
	end
	for i = startIndex, endIndex, step do
		local uuid = indexList[i]
		if includeAll or self:_ResultShouldIncludeRow(uuid, skipQuery, numJoinDBs, distinct) then
			result[resultIndex] = uuid
			resultIndex = resultIndex + 1
			resultRowLookup[uuid] = false
		end
	end
end
-- Appends every row of the query's DB that passes the join, filter and distinct checks to the result.
-- This is the no-optimization path, which evaluates all rows. Each appended UUID is also added to the row lookup
-- as `false`.
function DatabaseQuery._AddResultRowsCheckAll(self)
	local result = self._result
	local lookup = self._resultRowLookup
	local numJoinDBs, distinct = #self._joinDBs, self._distinct
	local nextIndex = #result + 1
	for _, uuid in self._db:_UUIDIterator() do
		if self:_ResultShouldIncludeRow(uuid, false, numJoinDBs, distinct) then
			result[nextIndex] = uuid
			lookup[uuid] = false
			nextIndex = nextIndex + 1
		end
	end
end
-- Returns whether the row with `uuid` belongs in the result. The row must:
--   * have a matching foreign row for every INNER join,
--   * satisfy the filter clauses (not checked when `skipQuery` is set),
--   * have a `distinct` field value that hasn't been seen yet.
-- The first row with each distinct value is kept, and its value is then marked as used.
function DatabaseQuery._ResultShouldIncludeRow(self, uuid, skipQuery, numJoinDBs, distinct)
	for i = 1, numJoinDBs do
		if self._joinTypes[i] == "INNER" then
			local joinField = self._joinFields[i]
			-- Resolve the join value the same way _GetResultRowData does for foreign lookups. This makes joins on a
			-- virtual local field (which _JoinHelper allows) work, instead of reading a field the DB doesn't have.
			local joinValue = self:_GetResultRowData(uuid, joinField)
			if not self._joinDBs[i]:_GetUniqueRow(joinField, joinValue) then
				return false
			end
		end
	end
	if not skipQuery then
		self._tempResultRow:_SetUUID(uuid)
		if not self._rootClause:_IsTrue(self._tempResultRow) then
			return false
		end
	end
	if distinct then
		local distinctValue = self:_GetResultRowData(uuid, distinct)
		if self._iterDistinctUsed[distinctValue] then
			return false
		end
		self._iterDistinctUsed[distinctValue] = true
	end
	return true
end
-- Creates the QueryResultRow for a UUID in the result, caches it in the row lookup and returns it.
function DatabaseQuery._CreateResultRow(self, uuid)
	-- only UUIDs which are in the result and don't have a row object yet are marked `false`
	assert(self._resultRowLookup[uuid] == false)
	local resultRow = QueryResultRow.Get()
	resultRow:_Acquire(self._db, self)
	resultRow:_SetUUID(uuid)
	self._resultRowLookup[uuid] = resultRow
	return resultRow
end
-- Given the component field names of an index, returns the min and max index values implied by the filter clauses.
-- Multi-field values are joined with DB_INDEX_VALUE_SEP. Returns nothing in two cases:
--   * any component has no bound at all;
--   * a multi-field index has a component missing either its min or its max bound.
function DatabaseQuery._IndexValueHelper(self, ...)
	local numParts = select("#", ...)
	local combinedMin, combinedMax = nil, nil
	for partNum = 1, numParts do
		local partField = select(partNum, ...)
		local partMin, partMax = self._rootClause:_GetIndexValue(partField)
		local hasMin, hasMax = partMin ~= nil, partMax ~= nil
		if not hasMin and not hasMax then
			return
		elseif numParts > 1 and not (hasMin and hasMax) then
			-- only use multi-field indexes if there's both a min and max value
			return
		end
		if partNum == 1 then
			combinedMin, combinedMax = partMin, partMax
		else
			combinedMin = combinedMin .. Constants.DB_INDEX_VALUE_SEP .. partMin
			combinedMax = combinedMax .. Constants.DB_INDEX_VALUE_SEP .. partMax
		end
	end
	return combinedMin, combinedMax
end
-- Releases the query, then returns every extra argument unchanged. This lets a caller release the query in the same
-- expression that returns values computed from it.
function DatabaseQuery._PassThroughReleaseHelper(self, ...)
	self.Release(self)
	return ...
end
-- Returns the value of `field` for the result row with the given UUID. There are three cases:
--   * virtual fields: computed by their registered function and type-checked. The function receives another field's
--     value, or a temporary result row when no argument field was registered.
--   * local fields: read from the query's own DB.
--   * other fields: read from the single joined DB which has them, by matching the join field. Returns nothing if
--     there is no matching foreign row.
function DatabaseQuery._GetResultRowData(self, uuid, field)
	local virtualFunc = self._virtualFieldFunc[field]
	if virtualFunc then
		local argField = self._virtualFieldArgField[field]
		local funcArg = nil
		if not argField then
			-- lazily create a reusable row object to hand to virtual field functions which take the whole row
			local tempRow = self._tempVirtualResultRow
			if not tempRow then
				tempRow = QueryResultRow.Get()
				self._tempVirtualResultRow = tempRow
				tempRow:_Acquire(self._db, self)
			end
			tempRow:_SetUUID(uuid)
			funcArg = tempRow
		else
			funcArg = self:_GetResultRowData(uuid, argField)
		end
		local value = virtualFunc(funcArg)
		if type(value) ~= self._virtualFieldType[field] then
			error(format("Virtual field value not the correct type (%s, %s)", tostring(funcArg), tostring(value)))
		end
		return value
	end
	local db = self._db
	local joinDBs = self._joinDBs
	local numJoinDBs = #joinDBs
	if numJoinDBs == 0 or db:_GetFieldType(field) then
		-- this is a local field
		return db:GetRowFieldByUUID(uuid, field)
	end
	-- this is a foreign field, so find the one joined DB which has it
	local foundDB, foundJoinField = nil, nil
	for i = 1, numJoinDBs do
		local candidateDB = joinDBs[i]
		if candidateDB:_GetFieldType(field) then
			if foundDB then
				error("Multiple joined DBs have this field", 2)
			end
			foundDB, foundJoinField = candidateDB, self._joinFields[i]
		end
	end
	if not foundDB then
		error("Invalid field: "..tostring(field), 2)
	end
	local foreignUUID = foundDB:_GetUniqueRow(foundJoinField, self:_GetResultRowData(uuid, foundJoinField))
	if foreignUUID then
		return foundDB:GetRowFieldByUUID(foreignUUID, field)
	end
end
-- Joins `db` to this query on `field`.
-- Preconditions (asserted):
--   * the field exists locally (real or virtual) and in `db`, with the same type;
--   * the field is unique in `db`;
--   * `db` isn't already joined;
--   * apart from the join field, no field name collides between `db` and the local/virtual fields.
-- Registers the query with `db` and records the join.
function DatabaseQuery._JoinHelper(self, db, field, joinType)
	assert(type(field) == "string")
	local localType = self._virtualFieldType[field] or self._db:_GetFieldType(field)
	local foreignType = db:_GetFieldType(field)
	assert(localType, "Local field doesn't exist: "..tostring(field))
	assert(foreignType, "Foreign field doesn't exist: "..tostring(field))
	assert(localType == foreignType, format("Field types don't match (%s, %s)", tostring(localType), tostring(foreignType)))
	assert(db:_IsUnique(field), "Field must be unique in foreign DB")
	assert(not Table.KeyByValue(self._joinDBs, db), "Already joining with this DB")
	-- the join field itself exists on both sides by design; every other field name must be unambiguous
	for foreignField in db:FieldIterator() do
		assert(foreignField == field or not self._db:_GetFieldType(foreignField), "Foreign field conflicts with local DB: "..tostring(foreignField))
	end
	for virtualField in pairs(self._virtualFieldFunc) do
		assert(virtualField == field or not db:_GetFieldType(virtualField), "Virtual field conflicts with foreign DB: "..tostring(virtualField))
	end
	db:_RegisterQuery(self)
	tinsert(self._joinTypes, joinType)
	tinsert(self._joinDBs, db)
	tinsert(self._joinFields, field)
	self._resultIsStale = true
end
-- ============================================================================
-- Private Helper Functions
-- ============================================================================
-- Comparator for sorting by a single field, using the precomputed keys in `_sortValueCache`.
-- Nil keys go last in either direction. Ties are broken by UUID (larger UUID first) so the order is deterministic.
function private.DatabaseQuerySortSingle(self, aUUID, bUUID, isAscending)
	local cache = self._sortValueCache
	local aValue, bValue = cache[aUUID], cache[bUUID]
	if aValue == bValue then
		-- make the sort stable
		return aUUID > bUUID
	end
	if aValue == nil or bValue == nil then
		-- sort nil to the end: `a` comes first only when it's `b` that is nil
		return bValue == nil
	end
	if isAscending then
		return aValue < bValue
	end
	return aValue > bValue
end
-- Comparator for sorting by any number of fields. Each orderBy field is compared in turn until one differs.
-- Nil values go last in either direction. Full ties are broken by UUID (larger UUID first).
function private.DatabaseQuerySortGeneric(self, aUUID, bUUID)
	local orderByFields = self._orderBy
	for i = 1, #orderByFields do
		local field = orderByFields[i]
		local aValue = Util.ToIndexValue(self:_GetResultRowData(aUUID, field))
		local bValue = Util.ToIndexValue(self:_GetResultRowData(bUUID, field))
		if aValue ~= bValue then
			if aValue == nil or bValue == nil then
				-- sort nil to the end
				return bValue == nil
			elseif self._orderByAscending[i] then
				return aValue < bValue
			end
			return aValue > bValue
		end
	end
	-- make the sort stable
	return aUUID > bUUID
end
-- Iterator function yielding (index, uuid) for each row of the result.
-- Past the last row it returns nothing, resets the iterator state to IDLE and releases the query if auto-release
-- was requested.
function private.QueryResultAsUUIDIterator(self, index)
	local nextIndex = index + 1
	local uuid = self._result[nextIndex]
	if uuid then
		return nextIndex, uuid
	end
	assert(self._iteratorState == "IN_PROGRESS")
	self._iteratorState = "IDLE"
	if self._autoRelease then
		self:Release()
	end
end
-- Iterator function over the query result.
-- It stops (returning nothing) when the result is exhausted or the iteration was aborted. In that case it resets the
-- iterator state to IDLE and releases the query if auto-release was requested.
-- Otherwise it returns the index followed by:
--   * the row's QueryResultRow, when no select fields were given;
--   * the values of the selected fields, otherwise.
function private.QueryResultIterator(self, index)
index = index + 1
local uuid = self._result[index]
if self._iteratorState == "ABORTED" then
-- treat an aborted iteration as having reached the end
uuid = nil
elseif self._iteratorState ~= "IN_PROGRESS" and self._iteratorState ~= "IN_PROGRESS_CAN_ABORT" then
error("Invalid iteratorState: "..tostring(self._iteratorState))
end
if not uuid then
assert(self._iteratorState == "IN_PROGRESS" or self._iteratorState == "IN_PROGRESS_CAN_ABORT" or self._iteratorState == "ABORTED")
self._iteratorState = "IDLE"
if self._autoRelease then
self:Release()
end
return
end
local numSelectFields = #self._select
if numSelectFields == 0 then
-- no select clause, so yield the (lazily created) row object itself
local row = self._resultRowLookup[uuid]
if not row then
row = self:_CreateResultRow(uuid)
end
return index, row
elseif #self._joinDBs == 0 and numSelectFields <= 5 then
-- as an optimization, we don't need to create a result row
if numSelectFields == 1 then
return index, self:_GetResultRowData(uuid, self._select[1])
elseif numSelectFields == 2 then
return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2])
elseif numSelectFields == 3 then
return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3])
elseif numSelectFields == 4 then
return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3]), self:_GetResultRowData(uuid, self._select[4])
elseif numSelectFields == 5 then
return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3]), self:_GetResultRowData(uuid, self._select[4]), self:_GetResultRowData(uuid, self._select[5])
else
error("Invalid numSelectFields: "..tostring(numSelectFields))
end
else
-- joins or more than 5 selected fields: read the values through the (lazily created) row object
local row = self._resultRowLookup[uuid]
if not row then
row = self:_CreateResultRow(uuid)
end
return index, row:GetFields(unpack(self._select))
end
end