diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e9177b9a9..ad94a7b6d 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -91,7 +91,7 @@ steps: codecov: true - label: Julia 1.11 (CUDA) - timeout_in_minutes: 20 + timeout_in_minutes: 30 <<: *gputest plugins: - JuliaCI/julia#v1: diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 6b8c61f9a..9f9b8df4d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -48,6 +48,13 @@ function Dagger.memory_space(x::CuArray) device_uuid = CUDA.uuid(dev) return CUDAVRAMMemorySpace(myid(), device_id, device_uuid) end +function Dagger.aliasing(x::CuArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + cuptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(cuptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)]) Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)]) @@ -75,6 +82,8 @@ function with_context!(space::CUDAVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CuArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CUDAVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_stream = stream() diff --git a/ext/IntelExt.jl b/ext/IntelExt.jl index 74253007d..08d54ee81 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -46,6 +46,13 @@ function Dagger.memory_space(x::oneArray) return IntelVRAMMemorySpace(myid(), device_id) end _device_id(dev::ZeDevice) = findfirst(other_dev->other_dev === dev, collect(oneAPI.devices())) +function Dagger.aliasing(x::oneArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::oneArrayDeviceProc) = Set([IntelVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::IntelVRAMMemorySpace) = Set([oneArrayDeviceProc(space.owner, space.device_id)]) @@ -68,6 +75,8 @@ function with_context!(space::IntelVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::oneArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::IntelVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_drv = driver() old_dev = device() diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 50cfc8905..21cea360a 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -43,6 +43,13 @@ function Dagger.memory_space(x::MtlArray) return MetalVRAMMemorySpace(myid(), device_id) end _device_id(dev::MtlDevice) = findfirst(other_dev->other_dev === dev, Metal.devices()) +function Dagger.aliasing(x::MtlArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::MtlArrayDeviceProc) = Set([MetalVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::MetalVRAMMemorySpace) = Set([MtlArrayDeviceProc(space.owner, space.device_id)]) @@ -66,6 +73,8 @@ end function with_context!(space::MetalVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() 
end +Dagger.with_context!(proc::MtlArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::MetalVRAMMemorySpace) = with_context!(space) function with_context(f, x) with_context!(x) return f() diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index fbf73de72..f8eac930c 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -44,6 +44,13 @@ function Dagger.memory_space(x::CLArray) idx = findfirst(==(queue), QUEUES) return CLMemorySpace(myid(), idx) end +function Dagger.aliasing(x::CLArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)]) Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)]) @@ -71,6 +78,8 @@ function with_context!(space::CLMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CLArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CLMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = cl.context() old_queue = cl.queue() diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 288c4744f..773c2bb95 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -39,6 +39,13 @@ end Dagger.root_worker_id(space::ROCVRAMMemorySpace) = space.owner Dagger.memory_space(x::ROCArray) = ROCVRAMMemorySpace(myid(), AMDGPU.device(x).device_id) +function Dagger.aliasing(x::ROCArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::ROCArrayDeviceProc) = Set([ROCVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::ROCVRAMMemorySpace) = Set([ROCArrayDeviceProc(space.owner, space.device_id)]) @@ -67,6 +74,8 @@ function with_context!(space::ROCVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::ROCArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::ROCVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_device = AMDGPU.device() diff --git a/src/Dagger.jl b/src/Dagger.jl index fa30c7c1a..102a76149 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -73,6 +73,9 @@ include("utils/fetch.jl") include("utils/chunks.jl") include("utils/logging.jl") include("submission.jl") +abstract type MemorySpace end +include("utils/memory-span.jl") +include("utils/interval_tree.jl") include("memory-spaces.jl") # Task scheduling @@ -83,7 +86,12 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Data dependency task queue -include("datadeps.jl") +include("datadeps/aliasing.jl") +include("datadeps/chunkview.jl") +include("datadeps/remainders.jl") +include("datadeps/queue.jl") + +# Stencils include("utils/haloarray.jl") include("stencil.jl") diff --git a/src/argument.jl b/src/argument.jl index 94246a75e..849486e03 100644 --- a/src/argument.jl +++ b/src/argument.jl @@ -20,6 +20,7 @@ function pos_kw(pos::ArgPosition) @assert pos.kw != :NULL return pos.kw end + mutable struct Argument pos::ArgPosition value @@ -41,6 +42,35 @@ function Base.iterate(arg::Argument, state::Bool) return nothing end end - Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), 
arg.value) chunktype(arg::Argument) = chunktype(value(arg)) + +mutable struct TypedArgument{T} + pos::ArgPosition + value::T +end +TypedArgument(pos::Integer, value::T) where T = TypedArgument{T}(ArgPosition(true, pos, :NULL), value) +TypedArgument(kw::Symbol, value::T) where T = TypedArgument{T}(ArgPosition(false, 0, kw), value) +Base.setproperty!(arg::TypedArgument, name::Symbol, value::T) where T = + throw(ArgumentError("Cannot set properties of TypedArgument")) +ispositional(arg::TypedArgument) = ispositional(arg.pos) +iskw(arg::TypedArgument) = iskw(arg.pos) +pos_idx(arg::TypedArgument) = pos_idx(arg.pos) +pos_kw(arg::TypedArgument) = pos_kw(arg.pos) +raw_position(arg::TypedArgument) = raw_position(arg.pos) +value(arg::TypedArgument) = arg.value +valuetype(arg::TypedArgument{T}) where T = T +Base.iterate(arg::TypedArgument) = (arg.pos, true) +function Base.iterate(arg::TypedArgument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end +Base.copy(arg::TypedArgument{T}) where T = TypedArgument{T}(ArgPosition(arg.pos), arg.value) +chunktype(arg::TypedArgument) = chunktype(value(arg)) + +Argument(arg::TypedArgument) = Argument(arg.pos, arg.value) + +const AnyArgument = Union{Argument, TypedArgument} \ No newline at end of file diff --git a/src/datadeps.jl b/src/datadeps.jl deleted file mode 100644 index d20bda647..000000000 --- a/src/datadeps.jl +++ /dev/null @@ -1,1082 +0,0 @@ -import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv - -export In, Out, InOut, Deps, spawn_datadeps - -"Specifies a read-only dependency." -struct In{T} - x::T -end -"Specifies a write-only dependency." -struct Out{T} - x::T -end -"Specifies a read-write dependency." -struct InOut{T} - x::T -end -"Specifies one or more dependencies." -struct Deps{T,DT<:Tuple} - x::T - deps::DT -end -Deps(x, deps...) 
= Deps(x, deps) - -struct DataDepsTaskQueue <: AbstractTaskQueue - # The queue above us - upper_queue::AbstractTaskQueue - # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} - # The data-dependency graph of all tasks - g::Union{SimpleDiGraph{Int},Nothing} - # The mapping from task to graph ID - task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol - # Which scheduler to use to assign tasks to processors - scheduler::Symbol - - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] - g = SimpleDiGraph() - task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) - end -end - -function unwrap_inout(arg) - readdep = false - writedep = false - if arg isa In - readdep = true - arg = arg.x - elseif arg isa Out - writedep = true - arg = arg.x - elseif arg isa InOut - readdep = true - writedep = true - arg = arg.x - elseif arg isa Deps - alldeps = Tuple[] - for dep in arg.deps - dep_mod, inner_deps = unwrap_inout(dep) - for (_, readdep, writedep) in inner_deps - push!(alldeps, (dep_mod, readdep, writedep)) - end - end - arg = arg.x - return arg, alldeps - else - readdep = true - end - return arg, Tuple[(identity, readdep, writedep)] -end - -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) -end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) -end - -_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? 
objectid(arg) : hash(arg, h) -_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) -_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) - -struct ArgumentWrapper - arg - dep_mod - hash::UInt - - function ArgumentWrapper(arg, dep_mod) - h = hash(dep_mod) - h = _identity_hash(arg, h) - return new(arg, dep_mod, h) - end -end -Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) -Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = - aw1.hash == aw2.hash - -struct DataDepsAliasingState - # Track original and current data locations - # We track data => space - data_origin::Dict{AliasingWrapper,MemorySpace} - data_locality::Dict{AliasingWrapper,MemorySpace} - - # Track writers ("owners") and readers - ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} - ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} - - # Cache ainfo lookups - ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - - function DataDepsAliasingState() - data_origin = Dict{AliasingWrapper,MemorySpace}() - data_locality = Dict{AliasingWrapper,MemorySpace}() - - ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() - ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() - - ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - - return new(data_origin, data_locality, - ainfos_owner, ainfos_readers, ainfos_overlaps, - ainfo_cache) - end -end -struct DataDepsNonAliasingState - # Track original and current data locations - # We track data => space - data_origin::IdDict{Any,MemorySpace} - data_locality::IdDict{Any,MemorySpace} - - # Track writers ("owners") and readers - args_owner::IdDict{Any,Union{Pair{DTask,Int},Nothing}} - args_readers::IdDict{Any,Vector{Pair{DTask,Int}}} - - function DataDepsNonAliasingState() - data_origin = IdDict{Any,MemorySpace}() - data_locality = IdDict{Any,MemorySpace}() - - args_owner = IdDict{Any,Union{Pair{DTask,Int},Nothing}}() - args_readers = IdDict{Any,Vector{Pair{DTask,Int}}}() - - return new(data_origin, data_locality, - args_owner, args_readers) - end -end -struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}} - # Whether aliasing is being analyzed - aliasing::Bool - - # The ordered list of tasks and their read/write dependencies - dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}} - - # The mapping of memory space to remote argument copies - remote_args::Dict{MemorySpace,IdDict{Any,Any}} - - # Cache of whether arguments supports in-place move - supports_inplace_cache::IdDict{Any,Bool} - - # The aliasing analysis state - alias_state::State - - function DataDepsState(aliasing::Bool) - dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[] - remote_args = Dict{MemorySpace,IdDict{Any,Any}}() - supports_inplace_cache = IdDict{Any,Bool}() - if aliasing - state = DataDepsAliasingState() - else - state = DataDepsNonAliasingState() - end - return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) - end -end - -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) - aw = ArgumentWrapper(arg, dep_mod) - get!(astate.ainfo_cache, aw) do - return AliasingWrapper(aliasing(arg, dep_mod)) - end -end - -function supports_inplace_move(state::DataDepsState, arg) 
- return get!(state.supports_inplace_cache, arg) do - return supports_inplace_move(arg) - end -end - -# Determine which arguments could be written to, and thus need tracking - -"Whether `arg` has any writedep in this datadeps region." -function has_writedep(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - # Check if we are writing to this memory - writedep = any(dep->dep[3], deps) - if writedep - arg_has_writedep[arg] = true - return true - end - - # Check if another task is writing to this memory - for (_, taskdeps) in state.dependencies - for (_, other_arg_writedep, _, _, other_arg) in taskdeps - other_arg_writedep || continue - if arg === other_arg - return true - end - end - end - - return false -end -""" -Whether `arg` has any writedep at or before executing `task` in this -datadeps region. -""" -function has_writedep(state::DataDepsState, arg, deps, task::DTask) - is_writedep(arg, deps, task) && return true - if state.aliasing - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, other_ainfo, _, _) in other_taskdeps - writedep || continue - for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) - if will_alias(ainfo, other_ainfo) - return true - end - end - end - if task === other_task - return false - end - end - else - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, _, _, other_arg) in other_taskdeps - writedep || continue - if arg === other_arg - return true - end - end - if task === other_task - return false - end - end - end - error("Task isn't in argdeps set") -end -"Whether `arg` is written to by `task`." -function is_writedep(arg, deps, task::DTask) - return any(dep->dep[3], deps) -end - -# Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Populate task dependencies - dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() - - # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(value(_arg)) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Add all aliasing dependencies - for (dep_mod, readdep, writedep) in deps - if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) - else - ainfo = AliasingWrapper(UnknownAliasing()) - end - push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg)) - end - - # Populate argument write info - populate_argument_info!(state, arg, deps) - end - - # Track the task result too - # N.B. 
We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this - push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task)) - - # Record argument/result dependencies - push!(state.dependencies, task => dependencies_to_add) -end -function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) - astate = state.alias_state - for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - - # Initialize owner and readers - if !haskey(astate.ainfos_owner, ainfo) - overlaps = Set{AliasingWrapper}() - push!(overlaps, ainfo) - for other_ainfo in keys(astate.ainfos_owner) - ainfo == other_ainfo && continue - if will_alias(ainfo, other_ainfo) - push!(overlaps, other_ainfo) - push!(astate.ainfos_overlaps[other_ainfo], ainfo) - end - end - astate.ainfos_overlaps[ainfo] = overlaps - astate.ainfos_owner[ainfo] = nothing - astate.ainfos_readers[ainfo] = Pair{DTask,Int}[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, ainfo) - astate.data_locality[ainfo] = memory_space(arg) - astate.data_origin[ainfo] = memory_space(arg) - end - end -end -function populate_argument_info!(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - astate = state.alias_state - # Initialize owner and readers - if !haskey(astate.args_owner, arg) - astate.args_owner[arg] = nothing - astate.args_readers[arg] = DTask[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, arg) - astate.data_locality[arg] = memory_space(arg) - astate.data_origin[arg] = memory_space(arg) - end -end -function populate_return_info!(state::DataDepsState{DataDepsAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - # FIXME: We don't yet know about ainfos for this task -end -function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - astate.data_locality[task] = space - astate.data_origin[task] = space -end - -""" - supports_inplace_move(x) -> Bool - -Returns `false` if `x` doesn't support being copied into from another object -like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting -to copy between values which don't support mutation or otherwise don't have an -implemented `move!` and want to skip in-place copies. When this returns -`false`, datadeps will instead perform out-of-place copies for each non-local -use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` -region returns. 
-""" -supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) -function supports_inplace_move(c::Chunk) - # FIXME: Use MemPool.access_ref - pid = root_worker_id(c.processor) - if pid == myid() - return supports_inplace_move(poolget(c.handle)) - else - return remotecall_fetch(supports_inplace_move, pid, c) - end -end -supports_inplace_move(::Function) = false - -# Read/write dependency management -function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) - _get_read_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end -function get_read_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end - -function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - other_task_write_num = astate.ainfos_owner[other_ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo" - other_task_write_num === nothing && continue - other_task, other_write_num = other_task_write_num - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end -end -function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo" - other_tasks = astate.ainfos_readers[other_ainfo] - for (other_task, other_write_num) in other_tasks - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - state.alias_state.ainfos_owner[ainfo] = task=>write_num - empty!(state.alias_state.ainfos_readers[ainfo]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, ainfo, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - push!(state.alias_state.ainfos_readers[ainfo], task=>write_num) -end - -function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - other_task_write_num = state.alias_state.args_owner[arg] - if other_task_write_num !== nothing - other_task, other_write_num = other_task_write_num - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - for (other_task, other_write_num) in state.alias_state.args_readers[arg] - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - state.alias_state.args_owner[arg] = task=>write_num - empty!(state.alias_state.args_readers[arg]) - # Not necessary to assert a read, but conceptually it's true - 
add_reader!(state, arg, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - push!(state.alias_state.args_readers[arg], task=>write_num) -end - -# Make a copy of each piece of data on each worker -# memory_space => {arg => copy_of_arg} -isremotehandle(x) = false -isremotehandle(x::DTask) = true -isremotehandle(x::Chunk) = true -function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; raw=true) - end - orig_space = memory_space(data) - to_proc = first(processors(dest_space)) - from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data or data already in a Chunk - data_chunk = tochunk(data, from_proc) - dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc - @assert memory_space(data_chunk) == orig_space - else - to_w = root_worker_id(dest_space) - ctx = Sch.eager_context() - id = rand(Int) - dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - if orig_space != dest_space - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - end - return data_chunk - end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) - end - return dest_space_args[data] -end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) - end -end - -function distribute_tasks!(queue::DataDepsTaskQueue) - #= TODO: Improvements to be made: - # - Support for copying non-AbstractArray arguments - # - Parallelize read copies - # - Unreference unused slots - # - Reuse memory when possible - # - Account for differently-sized data - =# - - # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) - end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) - if isempty(all_procs) - throw(Sch.SchedulingException("No processors available, try widening scope")) - end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end - - # Round-robin assign tasks to processors - upper_queue = get_options(:task_queue) - - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. 
next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - astate = state.alias_state - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - - # Start launching tasks and necessary copies - write_num = 1 - proc_idx = 1 - pressures = Dict{Processor,Int}() - proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(last(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(f)) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(f)) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = get(Set{Any}, spec.options, :syncdeps) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break - end - - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end - end - else - error("Invalid scheduler: $sched") - end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - - # Find the scope for this task (and its copies) - if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = 
filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end - else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) - end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) - end - - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) - - # Copy args from local to remote - for (idx, _arg) in enumerate(task_args) - # Is the data writeable? - arg, deps = unwrap_inout(value(_arg)) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" - spec.fargs[idx].value = arg - continue - end - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end - - # Is the source of truth elsewhere? - arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) - end - if queue.aliasing - for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) - data_space = astate.data_locality[ainfo] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) - add_writer!(state, ainfo, copy_to, write_num) - - astate.data_locality[ainfo] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" - end - end - else - data_space = astate.data_locality[arg] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) - add_writer!(state, arg, copy_to, write_num) - - astate.data_locality[arg] = our_space - else - @dagdebug nothing :spawn_datadeps 
"($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - arg = value(_arg) - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end - end - - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() - end - syncdeps = spec.options.syncdeps - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" - get_write_deps!(state, ainfo, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" - get_read_deps!(state, ainfo, task, write_num, syncdeps) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" - get_write_deps!(state, arg, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" - get_read_deps!(state, arg, task, write_num, syncdeps) - end - end - end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" - add_writer!(state, ainfo, task, write_num) - else - add_reader!(state, ainfo, task, write_num) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" - add_writer!(state, arg, task, write_num) - else - add_reader!(state, arg, task, write_num) - end - end - end - - # Update tracking for return value - populate_return_info!(state, task, our_space) - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - end - - # Copy args from remote to local - if queue.aliasing - # We need to replay the writes from all tasks in-order (skipping any - # outdated write owners), to ensure that overlapping writes are applied - # in the correct order - - # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}() - for (task, taskdeps) in state.dependencies - for (_, writedep, ainfo, dep_mod, arg) in taskdeps - writedep || continue - haskey(astate.data_locality, ainfo) || continue - @assert haskey(astate.ainfos_owner, ainfo) "Missing ainfo: $ainfo ($dep_mod($(typeof(arg))))" - - # Skip virtual writes from task result aliasing - # FIXME: Make this less bad - if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing - continue - end - - # Skip non-writeable arguments - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - continue - end - - # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg) - - #= FIXME: If we fully overlap any writer, evict them - idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) - deleteat!(ainfo_writes, idxs) - =# - - # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) - end - end - - # Then, replay the writes from each owner in-order - # FIXME: write_num should advance across overlapping ainfo's, as - # writes must be ordered sequentially - for (arg, ainfo_writes) in arg_writes - if length(ainfo_writes) > 1 - # FIXME: Remove me - deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) - end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes - # Is the source of truth elsewhere? - data_local_space = astate.data_origin[ainfo] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) - end - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) 
- copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" - end - end - end - else - for arg in keys(astate.data_origin) - # Is the data previously written? - arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" - end - - # Can the data be written back to? - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - end - - # Is the source of truth elsewhere? - data_remote_space = astate.data_locality[arg] - data_local_space = astate.data_origin[arg] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = state.remote_args[data_local_space][arg] - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" - end - end - end -end - -""" - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) - -Constructs a "datadeps" (data dependencies) region and calls `f` within it. -Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or -`InOut` to indicate whether the task will read, write, or read+write that -argument, respectively. These argument dependencies will be used to specify -which tasks depend on each other based on the following rules: - -- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other -- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects -- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel -- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies -- An `In` dependency synchronizes with any previous `Out` dependencies -- If unspecified, an `In` dependency is assumed - -In general, the result of executing tasks following the above rules will be -equivalent to simply executing tasks sequentially and in order of submission. -Of course, if dependencies are incorrectly specified, undefined behavior (and -unexpected results) may occur. - -Unlike other Dagger tasks, tasks executed within a datadeps region are allowed -to write to their arguments when annotated with `Out` or `InOut` -appropriately. - -At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks -to complete, rethrowing the first error, if any. 
The result of `f` will be -returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. -""" -function spawn_datadeps(f::Base.Callable; static::Bool=true, - traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, - aliasing::Bool=true, - launch_wait::Union{Bool,Nothing}=nothing) - if !static - throw(ArgumentError("Dynamic scheduling is no longer available")) - end - wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol - launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool - if launch_wait - result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - result = with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - return result - end -end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) -const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl new file mode 100644 index 000000000..ef83006b9 --- /dev/null +++ b/src/datadeps/aliasing.jl @@ -0,0 +1,817 @@ +import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv + +export In, Out, InOut, Deps, spawn_datadeps + +#= +============================================================================== + DATADEPS ALIASING AND DATA MOVEMENT SYSTEM +============================================================================== + +This file implements the data dependencies system for Dagger tasks, which allows +tasks to write to their arguments in a controlled manner. The system maintains +data coherency across distributed workers by tracking aliasing relationships +and orchestrating data movement operations. + +OVERVIEW: +--------- +The datadeps system enables parallel execution of tasks that modify shared data +by analyzing memory aliasing relationships and scheduling appropriate data +transfers. The core challenge is maintaining coherency when aliased data (e.g., +an array and its views) needs to be accessed by tasks running on different workers. + +KEY CONCEPTS: +------------- + +1. ALIASING ANALYSIS: + - Every mutable argument is analyzed for its memory access pattern + - Memory spans are computed to determine which bytes in memory are accessed + - Arguments that access overlapping memory spans are considered "aliasing" + - Examples: An array A and view(A, 2:3, 2:3) alias each other + +2. DATA LOCALITY TRACKING: + - The system tracks where the "source of truth" for each piece of data lives + - As tasks execute and modify data, the source of truth may move between workers + - Each argument can have its own independent source of truth location + +3. 
ALIASED OBJECT MANAGEMENT: + - When copying arguments between workers, the system tracks "aliased objects" + - This ensures that if both an array and its view need to be copied to a worker, + only one copy of the underlying array is made, with the view pointing to it + - The aliased_object!() and move_rewrap() functions manage this sharing + +THE DISTRIBUTED ALIASING PROBLEM: +--------------------------------- + +In a multithreaded environment, aliasing "just works" because all tasks operate +on the same memory. However, in a distributed environment, arguments must be +copied between workers, which breaks aliasing relationships. + +Consider this scenario: +```julia +A = rand(4, 4) +vA = view(A, 2:3, 2:3) + +Dagger.spawn_datadeps() do + Dagger.@spawn inc!(InOut(A), 1) # Task 1: increment all of A + Dagger.@spawn inc!(InOut(vA), 2) # Task 2: increment view of A +end +``` + +MULTITHREADED BEHAVIOR (WORKS): +- Both tasks run on the same worker +- They operate on the same memory, with proper dependency tracking +- Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) + +DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Each argument must be copied to the destination worker +- Without special handling, we would copy A and vA independently to another worker +- This creates two separate arrays, breaking the aliasing relationship between A and vA + +THE SOLUTION - PARTIAL DATA MOVEMENT: +------------------------------------- + +The datadeps system solves this by: + +1. UNIFIED ALLOCATION: + - When copying aliased objects, ensure only one underlying array exists per worker + - Use aliased_object!() to detect and reuse existing allocations + - Views on the destination worker point to the shared underlying array + +2. PARTIAL DATA TRANSFER: + - Instead of copying entire objects, only transfer the "dirty" regions + - This minimizes network traffic and maximizes parallelism + - Uses the move!(dep_mod, ...) function with dependency modifiers + +3. REMAINDER TRACKING: + - When a partial region is updated, track what parts still need updating + - Before a task needs the full object, copy the remaining "clean" regions + - This preserves all updates while avoiding overwrites + +EXAMPLE EXECUTION FLOW: +----------------------- + +Given: A = 4x4 array, vA = view(A, 2:3, 2:3) +Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) + +1. INITIAL STATE: + - A and vA both exist on worker0 (main worker) + - A's data_locality = worker0, vA's data_locality = worker0 + +2. T1 SCHEDULED ON WORKER1: + - Copy A from worker0 to worker1 + - T1 executes, modifying all of A on worker1 + - Update: A's data_locality = worker1, A is now "dirty" on worker1 + +3. T2 SCHEDULED ON WORKER2: + - T2 needs vA, but vA aliases with A (which was modified by T1) + - Copy vA-region of A from worker1 to worker2 + - This is a PARTIAL copy - only the 2:3, 2:3 region + - Create vA on worker2 pointing to the appropriate region + - T2 executes, modifying vA region on worker2 + - Update: vA's data_locality = worker2 + +4. 
FINAL SYNCHRONIZATION:
+   - Some future task needs the complete A
+   - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region)
+   - REMAINDER COPY: Copy non-vA regions from worker1 to worker2
+   - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A
+
+MEMORY SPAN COMPUTATION:
+------------------------
+
+The system uses memory spans to determine aliasing and compute remainders:
+
+- ContiguousAliasing: Single contiguous memory region (e.g., full array)
+- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray)
+- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A))
+- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A))
+
+Remainder computation involves:
+1. Computing memory spans for all overlapping aliasing objects
+2. Finding the set difference: full_object_spans - updated_spans
+3. Creating a "remainder aliasing" object representing the not-yet-updated regions
+4. Performing move! with this remainder object to copy only needed data
+
+DATA MOVEMENT FUNCTIONS:
+------------------------
+
+move!(dep_mod, to_space, from_space, to, from):
+- The core in-place data movement function
+- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.)
+- Supports partial copies via dependency modifiers
+
+move_rewrap():
+- Handles copying of wrapped objects (SubArrays, ChunkViews)
+- Ensures aliased objects are reused on destination worker
+
+enqueue_copy_to!():
+- Schedules data movement tasks before user tasks
+- Ensures data is up-to-date on the worker where a task will run
+
+CURRENT LIMITATIONS AND TODOS:
+-------------------------------
+
+1. REMAINDER COMPUTATION:
+   - The system currently handles simple overlaps but needs sophisticated
+     remainder calculation for complex aliasing patterns
+   - Need functions to compute span set differences
+
+2. ORDERING DEPENDENCIES:
+   - Need to ensure remainder copies happen in correct order
+   - Must not overwrite more recent updates with stale data
+
+3. COMPLEX ALIASING PATTERNS:
+   - Multiple overlapping views of the same array
+   - Nested aliasing structures (views of views)
+   - Mixed aliasing types (diagonal + triangular regions)
+
+4. PERFORMANCE OPTIMIZATION:
+   - Minimize number of copy operations
+   - Batch compatible transfers
+   - Optimize for common access patterns
+=#
+
+"Specifies a read-only dependency."
+struct In{T}
+    x::T
+end
+"Specifies a write-only dependency."
+struct Out{T}
+    x::T
+end
+"Specifies a read-write dependency."
+struct InOut{T}
+    x::T
+end
+"Specifies one or more dependencies."
+struct Deps{T,DT<:Tuple}
+    x::T
+    deps::DT
+end
+Deps(x, deps...) = Deps(x, deps)
+
+chunktype(::In{T}) where T = T
+chunktype(::Out{T}) where T = T
+chunktype(::InOut{T}) where T = T
+chunktype(::Deps{T,DT}) where {T,DT} = T
+
+function unwrap_inout(arg)
+    readdep = false
+    writedep = false
+    if arg isa In
+        readdep = true
+        arg = arg.x
+    elseif arg isa Out
+        writedep = true
+        arg = arg.x
+    elseif arg isa InOut
+        readdep = true
+        writedep = true
+        arg = arg.x
+    elseif arg isa Deps
+        alldeps = Tuple[]
+        for dep in arg.deps
+            dep_mod, inner_deps = unwrap_inout(dep)
+            for (_, readdep, writedep) in inner_deps
+                push!(alldeps, (dep_mod, readdep, writedep))
+            end
+        end
+        arg = arg.x
+        return arg, alldeps
+    else
+        readdep = true
+    end
+    return arg, Tuple[(identity, readdep, writedep)]
+end
+
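+# ----------------------------------------------------------------------------
+# Editorial sketch (hypothetical, not part of the original patch): how the
+# wrappers above unwrap into (dep_mod, readdep, writedep) tuples, assuming a
+# mutable array argument `A`:
+#
+#   unwrap_inout(A)        # (A, [(identity, true,  false)])  bare args default to In
+#   unwrap_inout(In(A))    # (A, [(identity, true,  false)])
+#   unwrap_inout(Out(A))   # (A, [(identity, false, true)])
+#   unwrap_inout(InOut(A)) # (A, [(identity, true,  true)])
+#
+# `Deps` attaches dependency modifiers, so only parts of `A` are tracked:
+#
+#   unwrap_inout(Deps(A, In(UpperTriangular), Out(Diagonal)))
+#   # (A, [(UpperTriangular, true, false), (Diagonal, false, true)])
+# ----------------------------------------------------------------------------
+
+_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ?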
objectid(arg) : hash(arg, h) +_identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) +_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) +_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) + +struct ArgumentWrapper + arg + dep_mod + hash::UInt + + function ArgumentWrapper(arg, dep_mod) + h = hash(dep_mod) + h = _identity_hash(arg, h) + return new(arg, dep_mod, h) + end +end +Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) +Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash +Base.isequal(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash + +struct HistoryEntry + ainfo::AliasingWrapper + space::MemorySpace + write_num::Int +end + +struct DataDepsState + # The mapping of original raw argument to its Chunk + raw_arg_to_chunk::IdDict{Any,Chunk} + + # The origin memory space of each argument + # Used to track the original location of an argument, for final copy-from + arg_origin::IdDict{Any,MemorySpace} + + # The mapping of memory space to argument to remote argument copies + # Used to replace an argument with its remote copy + remote_args::Dict{MemorySpace,IdDict{Any,Chunk}} + + # The mapping of remote argument to original argument + remote_arg_to_original::IdDict{Any,Any} + + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + + # The mapping of ainfo to argument and dep_mod + # Used to lookup which argument and dep_mod a given ainfo is generated from + # N.B. This is a mapping for remote argument copies + ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} + + # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to + # Updated when a new write happens on an overlapping ainfo + # Used by remainder copies to track which portions of an argument and dep_mod were written to elsewhere, through another argument + arg_history::Dict{ArgumentWrapper,Vector{HistoryEntry}} + + # The mapping of memory space and argument to the memory space of the last direct write + # Used by remainder copies to lookup the "backstop" if any portion of the target ainfo is not updated by the remainder + arg_owner::Dict{ArgumentWrapper,MemorySpace} + + # The overlap of each argument with every other argument, based on the ainfo overlaps + # Incrementally updated as new ainfos are created + # Used for fast history updates + arg_overlaps::Dict{ArgumentWrapper,Set{ArgumentWrapper}} + + # The mapping of, for a given memory space, the backing Chunks that an ainfo references + # Used by slot generation to replace the backing Chunks during move + ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + + # Cache of argument's supports_inplace_move query result + supports_inplace_cache::IdDict{Any,Bool} + + # Cache of argument and dep_mod to ainfo + # N.B. 
This is a mapping for remote argument copies + ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + + # The overlapping ainfos for each ainfo + # Incrementally updated as new ainfos are created + # Used for fast will_alias lookups + ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} + + # Track writers ("owners") and readers + # Updated as new writer and reader tasks are launched + # Used by task dependency tracking to calculate syncdeps and ensure correct launch ordering + ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} + ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} + + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + + arg_to_chunk = IdDict{Any,Chunk}() + arg_origin = IdDict{Any,MemorySpace}() + remote_args = Dict{MemorySpace,IdDict{Any,Any}}() + remote_arg_to_original = IdDict{Any,Any}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() + ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() + arg_owner = Dict{ArgumentWrapper,MemorySpace}() + arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() + ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + + supports_inplace_cache = IdDict{Any,Bool}() + ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + + ainfos_lookup = AliasingLookup() + ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() + + ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() + ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() + + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) + end +end + +function supports_inplace_move(state::DataDepsState, arg) + return get!(state.supports_inplace_cache, arg) do + return supports_inplace_move(arg) + end +end + +# Determine which arguments could be written to, and thus need tracking +"Whether `arg` is written to by `task`." +function is_writedep(arg, deps, task::DTask) + return any(dep->dep[3], deps) +end + +# Aliasing state setup +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos + + # Unwrap In/InOut/Out wrappers and record dependencies + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) + + # Unwrap the Chunk underlying any DTask arguments + arg = arg_pre_unwrap isa DTask ? 
fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end + end + + # Generate a Chunk for the argument if necessary + if haskey(state.raw_arg_to_chunk, arg) + arg_chunk = state.raw_arg_to_chunk[arg] + else + if !(arg isa Chunk) + arg_chunk = tochunk(arg) + state.raw_arg_to_chunk[arg] = arg_chunk + else + state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg + end + end + + # Track the origin space of the argument + origin_space = memory_space(arg_chunk) + state.arg_origin[arg_chunk] = origin_space + state.remote_arg_to_original[arg_chunk] = arg_chunk + + # Populate argument info for all aliasing dependencies + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + end + end +end +function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) + # Initialize ownership and history + if !haskey(state.arg_owner, arg_w) + # N.B. This is valid (even if the backing data is up-to-date elsewhere), + # because we only use this to track the "backstop" if any portion of the + # target ainfo is not updated by the remainder (at which point, this + # is thus the correct owner). + state.arg_owner[arg_w] = origin_space + + # Initialize the overlap set + state.arg_overlaps[arg_w] = Set{ArgumentWrapper}() + end + if !haskey(state.arg_history, arg_w) + state.arg_history[arg_w] = Vector{HistoryEntry}() + end + + # Calculate the ainfo (which will populate ainfo structures and merge history) + aliasing!(state, origin_space, arg_w) +end +# N.B. 
arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) + remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] + remote_arg = remote_arg_w.arg + else + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w + end + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + if !haskey(state.ainfo_arg, ainfo) + state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) + end + push!(state.ainfo_arg[ainfo], remote_arg_w) + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end +function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) + if !haskey(state.ainfos_owner, target_ainfo) + # Add ourselves to the lookup oracle + ainfo_idx = push!(state.ainfos_lookup, target_ainfo) + + # Find overlapping ainfos + overlaps = Set{AliasingWrapper}() + push!(overlaps, target_ainfo) + for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) + target_ainfo == other_ainfo && continue + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + + # Add overlapping history to our own + for other_remote_arg_w in state.ainfo_arg[other_ainfo] + other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) + end + end + state.ainfos_overlaps[target_ainfo] = overlaps + + # Initialize owner and readers + state.ainfos_owner[target_ainfo] = nothing + state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] + end +end +function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper) + history = state.arg_history[arg_w] + @opcounter :merge_history + @opcounter :merge_history_complexity length(history) + origin_space = state.arg_origin[other_arg_w.arg] + for other_entry in state.arg_history[other_arg_w] + write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) + range = searchsorted(history, write_num_tuple; by=x->x.write_num) + if !isempty(range) + # Find and skip duplicates + match = false + for source_idx in range + source_entry = history[source_idx] + if source_entry.ainfo == other_entry.ainfo && + source_entry.space == other_entry.space && + source_entry.write_num == other_entry.write_num + match = true + break + end + end + match && continue + + # Insert at the first position + idx = first(range) + else + # Insert at the last position + idx = length(history) + 1 + end + insert!(history, idx, other_entry) + end +end +function truncate_history!(state::DataDepsState, 
arg_w::ArgumentWrapper) + # FIXME: Do this continuously if possible + if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 + origin_space = state.arg_origin[arg_w.arg] + @opcounter :truncate_history + _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false) + if last_idx > 0 + @opcounter :truncate_history_removed last_idx + deleteat!(state.arg_history[arg_w], 1:last_idx) + end + end +end + +""" + supports_inplace_move(x) -> Bool + +Returns `false` if `x` doesn't support being copied into from another object +like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting +to copy between values which don't support mutation or otherwise don't have an +implemented `move!` and want to skip in-place copies. When this returns +`false`, datadeps will instead perform out-of-place copies for each non-local +use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` +region returns. +""" +supports_inplace_move(x) = true +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +function supports_inplace_move(c::Chunk) + # FIXME: Use MemPool.access_ref + pid = root_worker_id(c.processor) + if pid == myid() + return supports_inplace_move(poolget(c.handle)) + else + return remotecall_fetch(supports_inplace_move, pid, c) + end +end +supports_inplace_move(::Function) = false + +# Read/write dependency management +function get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We need to sync with both writers and readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) + _get_read_deps!(state, dest_space, ainfo, write_num, syncdeps) +end +function get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We only need to sync with writers, not readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) +end + +function _get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + other_task_write_num = state.ainfos_owner[other_ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with writer via $ainfo -> $other_ainfo" + other_task_write_num === nothing && continue + other_task, other_write_num = other_task_write_num + write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with writer via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end +end +function _get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with reader via $ainfo -> $other_ainfo" + other_tasks = state.ainfos_readers[other_ainfo] + for (other_task, other_write_num) in other_tasks + write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with reader via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end + end +end +function add_writer!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + state.ainfos_owner[ainfo] = task=>write_num + empty!(state.ainfos_readers[ainfo]) + + # Clear the history for this target, since this is a new write event + 
empty!(state.arg_history[arg_w]) + + # Add our own history + push!(state.arg_history[arg_w], HistoryEntry(ainfo, dest_space, write_num)) + + # Find overlapping arguments and update their history + for other_arg_w in state.arg_overlaps[arg_w] + other_arg_w == arg_w && continue + push!(state.arg_history[other_arg_w], HistoryEntry(ainfo, dest_space, write_num)) + end + + # Record the last place we were fully written to + state.arg_owner[arg_w] = dest_space + + # Not necessary to assert a read, but conceptually it's true + add_reader!(state, arg_w, dest_space, ainfo, task, write_num) +end +function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + push!(state.ainfos_readers[ainfo], task=>write_num) +end + +# Make a copy of each piece of data on each worker +# memory_space => {arg => copy_of_arg} +isremotehandle(x) = false +isremotehandle(x::DTask) = true +isremotehandle(x::Chunk) = true +function generate_slot!(state::DataDepsState, dest_space, data) + # N.B. We do not perform any sync/copy with the current owner of the data, + # because all we want here is to make a copy of some version of the data, + # even if the data is not up to date. + orig_space = memory_space(data) + to_proc = first(processors(dest_space)) + from_proc = first(processors(orig_space)) + dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + if !haskey(state.ainfo_backing_chunk, dest_space) + state.ainfo_backing_chunk[dest_space] = Dict{AbstractAliasing,Chunk}() + end + # FIXME: tochunk the cache just once per space + aliased_object_cache = AliasedObjectCache(tochunk(state.ainfo_backing_chunk[dest_space])) + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. 
$(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" + dest_space_args[data] = data_chunk + state.remote_arg_to_original[data_chunk] = data + + return dest_space_args[data] +end +function get_or_generate_slot!(state, dest_space, data) + @assert !(data isa ArgumentWrapper) + if !haskey(state.remote_args, dest_space) + state.remote_args[dest_space] = IdDict{Any,Any}() + end + if !haskey(state.remote_args[dest_space], data) + return generate_slot!(state, dest_space, data) + end + return state.remote_args[dest_space][data] +end +struct AliasedObjectCache + chunk::Chunk +end +@warn "Document these public methods" maxlog=1 +function Base.haskey(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(haskey, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return haskey(cache_raw, ainfo) +end +function Base.getindex(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(getindex, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return getindex(cache_raw, ainfo) +end +function Base.setindex!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(setindex!, wid, cache, value, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + cache_raw[ainfo] = value + return +end +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + if haskey(cache, ainfo) + return cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + return y + end +end +function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) + to_w = root_worker_id(to_proc) + if to_w == myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc) + end + return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc) + end +end +function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) + return aliased_object!(cache, x) do x + return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) + end +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) + # Unwrap so that we hit the right dispatch + wid = root_worker_id(data) + if wid != myid() + return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) + end + data_raw = unwrap(data) + return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # For generic data + return aliased_object!(cache, data) do data + return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + end +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, 
from_space, to_space, parent(v)) + inds = parentindices(v) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end + end +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, v[]) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = Ref(p_new) + return tochunk(v_new, to_proc) + end +end +#= +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x + else + @warn "Cannot move-rewrap object of type $T" + return x + end +end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) + end +end \ No newline at end of file diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl new file mode 100644 index 000000000..42f32cca9 --- /dev/null +++ b/src/datadeps/chunkview.jl @@ -0,0 +1,58 @@ +struct ChunkView{N} + chunk::Chunk + slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} +end + +function Base.view(c::Chunk, slices...) 
+    # Validate slice bounds eagerly when the Chunk's domain is a known ArrayDomain
+    if c.domain isa ArrayDomain
+        nd, sz = ndims(c.domain), size(c.domain)
+        nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))"))
+
+        for (i, s) in enumerate(slices)
+            if s isa Int
+                1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))"))
+            elseif s isa AbstractRange
+                isempty(s) && continue
+                1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))"))
+            elseif s === Colon()
+                continue
+            else
+                throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i; expected Int, AbstractRange, or Colon"))
+            end
+        end
+    end
+
+    return ChunkView(c, slices)
+end
+
+Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...)
+
+function aliasing(x::ChunkView{N}) where N
+    return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices
+        x = unwrap(x)
+        v = view(x, slices...)
+        return aliasing(v)
+    end
+end
+memory_space(x::ChunkView) = memory_space(x.chunk)
+isremotehandle(x::ChunkView) = true
+
+function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView)
+    to_w = root_worker_id(to_proc)
+    p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, slice.chunk)
+    return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds
+        p_new = move(from_proc, to_proc, p_chunk)
+        v_new = view(p_new, inds...)
+        return tochunk(v_new, to_proc)
+    end
+end
+function move(from_proc::Processor, to_proc::Processor, slice::ChunkView)
+    to_w = root_worker_id(to_proc)
+    return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices
+        chunk_new = move(from_proc, to_proc, chunk)
+        v_new = view(chunk_new, slices...)
+        return tochunk(v_new, to_proc)
+    end
+end
+
+Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...)
\ No newline at end of file diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl new file mode 100644 index 000000000..c7b5e2bc1 --- /dev/null +++ b/src/datadeps/queue.jl @@ -0,0 +1,557 @@ +struct DataDepsTaskQueue <: AbstractTaskQueue + # The queue above us + upper_queue::AbstractTaskQueue + # The set of tasks that have already been seen + seen_tasks::Union{Vector{DTaskPair},Nothing} + # The data-dependency graph of all tasks + g::Union{SimpleDiGraph{Int},Nothing} + # The mapping from task to graph ID + task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol + # Which scheduler to use to assign tasks to processors + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool + + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) + seen_tasks = DTaskPair[] + g = SimpleDiGraph() + task_to_id = Dict{DTask,Int}() + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + aliasing) + end +end + +function enqueue!(queue::DataDepsTaskQueue, pair::DTaskPair) + push!(queue.seen_tasks, pair) +end +function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.seen_tasks, pairs) +end + +""" + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) + +Constructs a "datadeps" (data dependencies) region and calls `f` within it. +Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or +`InOut` to indicate whether the task will read, write, or read+write that +argument, respectively. These argument dependencies will be used to specify +which tasks depend on each other based on the following rules: + +- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other +- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects +- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel +- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies +- An `In` dependency synchronizes with any previous `Out` dependencies +- If unspecified, an `In` dependency is assumed + +In general, the result of executing tasks following the above rules will be +equivalent to simply executing tasks sequentially and in order of submission. +Of course, if dependencies are incorrectly specified, undefined behavior (and +unexpected results) may occur. + +Unlike other Dagger tasks, tasks executed within a datadeps region are allowed +to write to their arguments when annotated with `Out` or `InOut` +appropriately. + +At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks +to complete, rethrowing the first error, if any. The result of `f` will be +returned from `spawn_datadeps`. + +The keyword argument `traversal` controls the order that tasks are launched by +the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling +or Depth-First Scheduling, respectively. All traversal orders respect the +dependencies and ordering of the launched tasks, but may provide better or +worse performance for a given set of datadeps tasks. This argument is +experimental and subject to change. 
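+
+# Examples
+
+A minimal sketch of a datadeps region; `inc!` and `scaled_add!` are
+illustrative user functions, not Dagger APIs:
+
+```julia
+inc!(x) = (x .+= 1; nothing)                    # writes `x` in-place
+scaled_add!(y, a, x) = (y .+= a .* x; nothing)  # reads `x`, writes `y`
+
+A, B = ones(100), zeros(100)
+Dagger.spawn_datadeps() do
+    Dagger.@spawn inc!(InOut(A))                    # writer of A
+    Dagger.@spawn scaled_add!(InOut(B), 2.0, In(A)) # reads A, so runs after inc!
+end
+# After the region returns, all(B .== 4.0), since the region waits for all tasks
+```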
+""" +function spawn_datadeps(f::Base.Callable; static::Bool=true, + traversal::Symbol=:inorder, + scheduler::Union{Symbol,Nothing}=nothing, + aliasing::Bool=true, + launch_wait::Union{Bool,Nothing}=nothing) + if !static + throw(ArgumentError("Dynamic scheduling is no longer available")) + end + wait_all(; check_errors=true) do + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol + launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool + if launch_wait + result = spawn_bulk() do + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + else + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + result = with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + return result + end +end +const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) +const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) + +function distribute_tasks!(queue::DataDepsTaskQueue) + #= TODO: Improvements to be made: + # - Support for copying non-AbstractArray arguments + # - Parallelize read copies + # - Unreference unused slots + # - Reuse memory when possible + # - Account for differently-sized data + =# + + # Get the set of all processors to be scheduled on + all_procs = Processor[] + scope = get_compute_scope() + for w in procs() + append!(all_procs, get_processors(OSProc(w))) + end + filter!(proc->proc_in_scope(proc, scope), all_procs) + if isempty(all_procs) + throw(Sch.SchedulingException("No processors available, try widening scope")) + end + scope = UnionScope(map(ExactScope, all_procs)) + exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) + if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end + + # Round-robin assign tasks to processors + upper_queue = get_options(:task_queue) + + traversal = queue.traversal + if traversal == :inorder + # As-is + task_order = Colon() + elseif traversal == :bfs + # BFS + task_order = Int[1] + to_walk = Int[1] + seen = Set{Int}([1]) + while !isempty(to_walk) + # N.B. 
next_root has already been seen + next_root = popfirst!(to_walk) + for v in outneighbors(queue.g, next_root) + if !(v in seen) + push!(task_order, v) + push!(seen, v) + push!(to_walk, v) + end + end + end + elseif traversal == :dfs + # DFS (modified with backtracking) + task_order = Int[] + to_walk = Int[1] + seen = Set{Int}() + while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) + next_root = popfirst!(to_walk) + if !(next_root in seen) + iv = inneighbors(queue.g, next_root) + if all(v->v in seen, iv) + push!(task_order, next_root) + push!(seen, next_root) + ov = outneighbors(queue.g, next_root) + prepend!(to_walk, ov) + else + push!(to_walk, next_root) + end + end + end + else + throw(ArgumentError("Invalid traversal mode: $traversal")) + end + + state = DataDepsState(queue.aliasing) + sstate = DataDepsSchedulerState() + for proc in all_procs + space = only(memory_spaces(proc)) + get!(()->0, sstate.capacities, space) + sstate.capacities[space] += 1 + end + + # Start launching tasks and necessary copies + write_num = 1 + proc_idx = 1 + #pressures = Dict{Processor,Int}() + proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) + for pair in queue.seen_tasks[task_order] + spec = pair.spec + task = pair.task + write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) + end + + # Copy args from remote to local + # N.B. We sort the keys to ensure a deterministic order for uniformity + for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + arg = arg_w.arg + origin_space = state.arg_origin[arg] + remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) + elseif remainder isa FullCopy + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) 
+ enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) + @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) + end + end +end +struct DataDepsTaskDependency + arg_w::ArgumentWrapper + readdep::Bool + writedep::Bool +end +DataDepsTaskDependency(arg, dep) = + DataDepsTaskDependency(ArgumentWrapper(arg, dep[1]), dep[2], dep[3]) +struct DataDepsTaskArgument + arg + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::Vector{DataDepsTaskDependency} +end +struct TypedDataDepsTaskArgument{T,N} + arg::T + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::NTuple{N,DataDepsTaskDependency} +end +map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) +@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed + @specialize spec fargs + + if typed + fargs::Tuple + else + fargs::Vector{Argument} + end + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = rand(our_space_procs) + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + if task_scope == scope + # all_procs is already limited to scope + else + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end + end + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # 
Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end + + f = spec.fargs[1] + tid = task.uid + # FIXME: May not be correct to move this under uniformity + #f.value = move(default_processor(), our_proc, value(f)) + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + + # Copy raw task arguments for analysis + # N.B. Used later for checking dependencies + task_args = map_or_ntuple(idx->copy(spec.fargs[idx]), spec.fargs) + + # Populate all task dependencies + task_arg_ws = populate_task_info!(state, task_args, spec, task) + + # Truncate the history for each argument + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + truncate_history!(state, dep.arg_w) + end + return + end + + # Copy args from local to remote + remote_args = map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + pos = raw_position(arg_ws.pos) + + # Is the data written previously or now? + if !arg_ws.may_alias + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + return arg + end + + # Is the data writeable? + if !arg_ws.inplace_move + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + return arg + end + + # Is the source of truth elsewhere? + arg_remote = get_or_generate_slot!(state, our_space, arg) + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + end + end + return arg_remote + end + write_num += 1 + + # Validate that we're not accidentally performing a copy + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = remote_args[idx] + + # Get the dependencies again as (dep_mod, readdep, writedep) + deps = map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + (dep.arg_w.dep_mod, dep.readdep, dep.writedep) + end + + # Check that any mutable and written arguments are already in the correct space + # N.B. 
We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{ThunkSyncdep}() + end + syncdeps = spec.options.syncdeps + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) + end + end + return + end + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) + else + return Argument(task_arg_ws[idx].pos, remote_args[idx]) + end + end + new_spec = DTaskSpec(new_fargs, spec.options) + new_spec.options.scope = our_scope + new_spec.options.exec_scope = our_scope + new_spec.options.occupancy = Dict(Any=>0) + ctx = Sch.eager_context() + @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) + enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) + + # Update read/write tracking for arguments + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + for dep in arg_ws.deps + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + return + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + + return write_num, proc_idx +end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl new file mode 100644 index 000000000..67fdd2588 --- /dev/null +++ b/src/datadeps/remainders.jl @@ -0,0 +1,503 @@ +# Remainder tracking and computation functions + +""" + RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + +Represents the memory spans that remain after subtracting some regions from a base aliasing object. +This is used to perform partial data copies that only update the "remainder" regions. 
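+
+Each entry in `spans` pairs a source span (resident in `space`) with the
+destination span that it will be copied into, and `syncdeps` collects the
+tasks which must complete before the source data may be read.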
+""" +struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + space::S + spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + syncdeps::Set{ThunkSyncdep} +end +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) + +memory_spans(ra::RemainderAliasing) = ra.spans + +Base.hash(ra::RemainderAliasing, h::UInt) = hash(ra.spans, hash(RemainderAliasing, h)) +Base.:(==)(ra1::RemainderAliasing, ra2::RemainderAliasing) = ra1.spans == ra2.spans + +# Add will_alias support for RemainderAliasing +function will_alias(x::RemainderAliasing, y::AbstractAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::AbstractAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::RemainderAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +struct MultiRemainderAliasing <: AbstractAliasing + remainders::Vector{<:RemainderAliasing} +end +MultiRemainderAliasing() = MultiRemainderAliasing(RemainderAliasing[]) + +memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders)...) + +Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) +Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders + +#= FIXME: Integrate with main documentation +Problem statement: + +Remainder copy calculation needs to ensure that, for a given argument and +dependency modifier, and for a given target memory space, any data not yet +updated (whether through this arg or through another that aliases) is added to +the remainder, while any data that has been updated is not in the remainder. +Remainder copies may be multi-part, as data may be spread across multiple other +memory spaces. + +Ainfo is not alone sufficient to identify the combination of argument and +dependency modifier, as ainfo is specific to an allocation in a given memory +space. Thus, this combination needs to be tracked together, and separately from +memory space. However, information may span multiple memory spaces (and thus +multiple ainfos), so we should try to make queries of cross-memory space +information fast, as they will need to be performed for every task, for every +combination. 
+
+Game Plan:
+
+- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once
+- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies
+- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered
+- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps
+- Remainder copies will need to separately consider the source memory space and the destination memory space when acquiring spans to copy to/from
+- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address
+  - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us
+- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector).
+- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written.
+  - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed.
+  - Consideration of updates should start at the most recent first, walking backwards in time, as the most recent updates contain the up-to-date data.
+  - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data.
+  - We must account for that span portion (removing it from the remainder) no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be.
+  - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo was fully out-of-date).
+=#
+
+struct FullCopy end
+
+"""
+    compute_remainder_for_arg!(state::DataDepsState,
+                               target_space::MemorySpace,
+                               arg_w::ArgumentWrapper,
+                               write_num::Int; compute_syncdeps::Bool=true)
+
+Computes what remainder regions need to be copied to `target_space` before a task can access `arg_w`.
+Returns a tuple of the remainder and the index of the last history entry considered.
+The remainder is a `MultiRemainderAliasing` object representing the regions to copy,
+`FullCopy()` if the argument has no write history and must be copied wholesale,
+or `NoAliasing()` if no copy is needed.
+
+The algorithm starts by collecting the memory spans of `arg_w` in `target_space` - this is the "remainder".
+Once this remainder becomes empty, the algorithm is finished.
+Additionally, a dictionary is created to store the source and destination
+memory spans (for each source memory space) that will be used to create the
+`MultiRemainderAliasing` object - this is the "tracker".
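+
+As a small worked illustration of the walk described below (hypothetical
+history): suppose `arg_w` covers bytes 0-99 in `target_space`, and its
+history records a write to bytes 0-49 on another space `w2`, followed by a
+later write to bytes 25-99 on `target_space` itself. Walking backwards, the
+25-99 write is already local and is simply subtracted, leaving bytes 0-24 in
+the remainder; the older `w2` write then schedules a copy of bytes 0-24 from
+`w2`, emptying the remainder and ending the walk early.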
+
+The algorithm walks backwards through the `arg_history` vector for `arg_w`
+(which is an ordered list of all overlapping ainfos that were directly written to (potentially in a different memory space than `target_space`)
+since the last time this `arg_w` was written to). If this ainfo is in `target_space`,
+then it is not under consideration; it is simply subtracted from the remainder with `subtract_spans!`,
+and the algorithm goes to the next ainfo. Otherwise, the algorithm will consider this ainfo for tracking.
+
+For each overlapping ainfo (which lives in a different memory space than `target_space`) to be tracked, there exists a corresponding "mirror" ainfo in
+`target_space`, which is the equivalent of the overlapping ainfo, but in
+`target_space`. This mirror ainfo is assumed to have an identical number of memory spans to the overlapping ainfo,
+and each memory span is assumed to be identical in size, but not necessarily identical in address.
+
+These three sets of memory spans (from the remainder, the overlapping ainfo, and the mirror ainfo) are then passed to `schedule_remainder!`.
+This call will subtract the spans of the mirror ainfo from the remainder (as the two live in the same memory space and thus can be directly compared),
+and will update the remainder accordingly.
+Additionally, it will also use this subtraction to update the tracker, by adding the equivalent spans (mapped from mirror ainfo to overlapping ainfo) to the tracker as the source,
+and the subtracted spans of the remainder as the destination.
+
+If the history is exhausted without the remainder becoming empty, then the
+remaining data in `target_space` is assumed to be up-to-date (as the latest write
+to `arg_w` is the furthest back we need to consider).
+
+Finally, the tracker is converted into a `MultiRemainderAliasing` object,
+and returned.
+"""
+function compute_remainder_for_arg!(state::DataDepsState,
+                                    target_space::MemorySpace,
+                                    arg_w::ArgumentWrapper,
+                                    write_num::Int; compute_syncdeps::Bool=true)
+    @label restart
+
+    # Determine all memory spaces of the history
+    spaces_set = Set{MemorySpace}()
+    push!(spaces_set, target_space)
+    owner_space = state.arg_owner[arg_w]
+    push!(spaces_set, owner_space)
+    for entry in state.arg_history[arg_w]
+        push!(spaces_set, entry.space)
+    end
+    spaces = collect(spaces_set)
+    N = length(spaces)
+
+    # Lookup all memory spans for arg_w in these spaces
+    target_ainfos = Vector{Vector{LocalMemorySpan}}()
+    for space in spaces
+        target_space_ainfo = aliasing!(state, space, arg_w)
+        spans = memory_spans(target_space_ainfo)
+        push!(target_ainfos, LocalMemorySpan.(spans))
+    end
+    nspans = length(first(target_ainfos))
+    @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces"
+
+    # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...)
+ for entry in state.arg_history[arg_w] + if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart + @goto restart + end + end + + # We may only need to schedule a full copy from the origin space to the + # target space if this is the first time we've written to `arg_w` + if isempty(state.arg_history[arg_w]) + if owner_space != target_space + return FullCopy(), 0 + else + return NoAliasing(), 0 + end + end + + # Create our remainder as an interval tree over all target ainfos + VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg + remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + + # Create our tracker + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + + # Walk backwards through the history of writes to this target + # other_ainfo is the overlapping ainfo that was written to + # other_space is the memory space of the overlapping ainfo + last_idx = length(state.arg_history[arg_w]) + for idx in length(state.arg_history[arg_w]):-1:0 + if isempty(remainder) + # All done! + last_idx = idx + break + end + + if idx > 0 + other_entry = state.arg_history[arg_w][idx] + other_ainfo = other_entry.ainfo + other_space = other_entry.space + else + # If we've reached the end of the history, evaluate ourselves + other_ainfo = aliasing!(state, owner_space, arg_w) + other_space = owner_space + end + + # Lookup all memory spans for arg_w in these spaces + other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) + other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) + other_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + other_space_ainfo = aliasing!(state, space, other_arg_w) + spans = memory_spans(other_space_ainfo) + push!(other_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(other_ainfos)) + other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + foreach(other_many_spans) do span + verify_span(span) + end + + if other_space == target_space + # Only subtract, this data is already up-to-date in target_space + # N.B. We don't add to syncdeps here, because we'll see this ainfo + # in get_write_deps! 
+            @opcounter :compute_remainder_for_arg_subtract
+            subtract_spans!(remainder, other_many_spans)
+            continue
+        end
+
+        # Subtract from remainder and schedule copy in tracker
+        other_space_idx = something(findfirst(==(other_space), spaces))
+        target_space_idx = something(findfirst(==(target_space), spaces))
+        tracker_other_space = get!(tracker, other_space) do
+            (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}())
+        end
+        @opcounter :compute_remainder_for_arg_schedule
+        has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans)
+        if compute_syncdeps && has_overlap
+            @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner"
+            get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2])
+        end
+    end
+    VERIFY_SPAN_CURRENT_OBJECT[] = nothing
+
+    if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker))
+        return NoAliasing(), 0
+    end
+
+    # Return scheduled copies and the index of the last ainfo we considered
+    mra = MultiRemainderAliasing()
+    for space in spaces
+        if haskey(tracker, space)
+            spans, syncdeps = tracker[space]
+            if !isempty(spans)
+                push!(mra.remainders, RemainderAliasing(space, spans, syncdeps))
+            end
+        end
+    end
+    @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))"
+    return mra, last_idx
+end
+
+### Memory Span Set Operations for Remainder Computation
+
+"""
+    schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}})
+
+Computes the overlap between `remainder` and `other_many_spans`, subtracts
+it from `remainder`, and then adds each subtracted span pair to `tracker` as a
+scheduled copy from the source space's portion of `other_many_spans` to the
+subtracted portion of `remainder`.
+"""
+function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N
+    diff = Vector{ManyMemorySpan{N}}()
+    subtract_spans!(remainder, other_many_spans, diff)
+    for span in diff
+        source_span = span.spans[source_space_idx]
+        dest_span = span.spans[dest_space_idx]
+        @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))"
+        push!(tracker, (source_span, dest_span))
+    end
+    return !isempty(diff)
+end
+
+### Remainder copy functions
+
+"""
+    enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing,
+                               f, idx, dest_scope, task, write_num::Int)
+
+Enqueues a copy operation to update the remainder regions of an object before a task runs.
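+
+The copy itself is launched as a separate `meta=true` task which calls
+`Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source)`,
+and that task is then registered as a writer of the target region, so the
+user task (and any later tasks) will synchronize against it.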
+""" +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) + end +end +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +""" + enqueue_remainder_copy_from!(state::DataDepsState, target_ainfo::AliasingWrapper, arg, remainder_aliasing, + origin_space::MemorySpace, origin_scope, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object back to the original space. 
+""" +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + dest_scope, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) + end +end +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing remainder copy-from for: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# FIXME: Document me +function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope 
exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source)
+    @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest))
+
+    # This copy task becomes a new writer for the target region
+    add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num)
+end
+function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper,
+                            dest_scope, write_num::Int)
+    dep_mod = arg_w.dep_mod
+    source_space = state.arg_owner[arg_w]
+    target_ainfo = aliasing!(state, dest_space, arg_w)
+
+    @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing full copy-from: $source_space => $dest_space"
+
+    # Get the source and destination arguments
+    arg_dest = state.remote_args[dest_space][arg_w.arg]
+    arg_source = get_or_generate_slot!(state, source_space, arg_w.arg)
+
+    # Create a copy task for the full copy
+    copy_syncdeps = Set{Any}()
+    source_ainfo = aliasing!(state, source_space, arg_w)
+    get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps)
+    get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps)
+
+    @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps"
+
+    # Launch the copy task
+    ctx = Sch.eager_context()
+    id = rand(UInt)
+    @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;))
+    copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source)
+    @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest))
+
+    # This copy task becomes a new writer for the target region
+    add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num)
+end
+
+# Main copy function for RemainderAliasing
+function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S
+    # TODO: Support direct copy between GPU memory spaces
+
+    @assert sizeof(eltype(chunktype(from))) == sizeof(eltype(chunktype(to))) "Source and destination chunks have different element sizes: $(sizeof(eltype(chunktype(from)))) != $(sizeof(eltype(chunktype(to))))"
+
+    # Copy the data from the source object
+    copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from
+        len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans)
+        copies = Vector{UInt8}(undef, len)
+        from_raw = unwrap(from)
+        offset = UInt64(1)
+        with_context!(from_space)
+        GC.@preserve copies begin
+            for (from_span, _) in dep_mod.spans
+                elsize = sizeof(eltype(from_raw))
+                # Use integer division: Float64 rounding would corrupt large offsets
+                offset_n = div(offset - UInt64(1), elsize) + UInt64(1)
+                n = div(from_span.len, elsize)
+                read_remainder!(copies, offset_n, from_raw, from_span.ptr, n)
+                offset += from_span.len
+            end
+        end
+        @assert offset == len+UInt64(1)
+        return copies
+    end
+
+    # Copy the data into the destination object
+    offset = UInt64(1)
+    to_raw = unwrap(to)
+    GC.@preserve copies begin
+        for (_, to_span) in dep_mod.spans
+            elsize = sizeof(eltype(to_raw))
+            # Use integer division: Float64 rounding would corrupt large offsets
+            offset_n = div(offset - UInt64(1), elsize) + UInt64(1)
+            n = div(to_span.len, elsize)
+            write_remainder!(copies, offset_n, to_raw, to_span.ptr, n)
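+            # Advance by this span's byte length; matching source and
+            # destination spans were asserted equal-sized in schedule_remainder!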
+            offset += to_span.len
+        end
+        @assert offset == length(copies)+UInt64(1)
+    end
+
+    # Ensure that the data is visible
+    Core.Intrinsics.atomic_fence(:release)
+
+    return
+end
+
+function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, n::UInt64)
+    elsize = sizeof(eltype(from))
+    # Use integer division: Float64 rounding would corrupt large addresses
+    from_offset = div(from_ptr - UInt64(pointer(from)), elsize) + UInt64(1)
+    from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)}
+    copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n)
+    copyto!(copies_typed, 1, from_vec, Int(from_offset), Int(n))
+end
+function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::SubArray, from_ptr::UInt64, n::UInt64)
+    read_remainder!(copies, copies_offset, parent(from), from_ptr, n)
+end
+
+function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, n::UInt64)
+    elsize = sizeof(eltype(to))
+    # Use integer division: Float64 rounding would corrupt large addresses
+    to_offset = div(to_ptr - UInt64(pointer(to)), elsize) + UInt64(1)
+    to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)}
+    copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n)
+    copyto!(to_vec, Int(to_offset), copies_typed, 1, Int(n))
+end
+function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::SubArray, to_ptr::UInt64, n::UInt64)
+    write_remainder!(copies, copies_offset, parent(to), to_ptr, n)
+end
diff --git a/src/gpu.jl b/src/gpu.jl
index 06d749543..fa93f8076 100644
--- a/src/gpu.jl
+++ b/src/gpu.jl
@@ -100,4 +100,7 @@ function gpu_synchronize(kind::Symbol)
         gpu_synchronize(Val(kind))
     end
 end
-gpu_synchronize(::Val{:CPU}) = nothing
\ No newline at end of file
+gpu_synchronize(::Val{:CPU}) = nothing
+
+with_context!(proc::Processor) = nothing
+with_context!(space::MemorySpace) = nothing
\ No newline at end of file
diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl
index 9f65a1a21..d39e665cc 100644
--- a/src/memory-spaces.jl
+++ b/src/memory-spaces.jl
@@ -1,8 +1,7 @@
-abstract type MemorySpace end
-
 struct CPURAMMemorySpace <: MemorySpace
     owner::Int
 end
+CPURAMMemorySpace() = CPURAMMemorySpace(myid())
 
 root_worker_id(space::CPURAMMemorySpace) = space.owner
 
 memory_space(x) = CPURAMMemorySpace(myid())
@@ -30,7 +29,7 @@ processors(space::CPURAMMemorySpace) =
 
 ### In-place Data Movement
 
 function unwrap(x::Chunk)
-    @assert root_worker_id(x.processor) == myid()
+    @assert x.handle.owner == myid()
     MemPool.poolget(x.handle)
 end
 
 move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} =
@@ -89,44 +88,20 @@ function type_may_alias(::Type{T}) where T
     return false
 end
 
-may_alias(::MemorySpace, ::MemorySpace) = true
+may_alias(::MemorySpace, ::MemorySpace) = false
+may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2
 may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner
 
-struct RemotePtr{T,S<:MemorySpace} <: Ref{T}
-    addr::UInt
-    space::S
-end
-RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space)
-RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space)
-RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid()))
-Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T =
-    RemotePtr(UInt(x), CPURAMMemorySpace(myid()))
-Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} =
-    RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid()))
-Base.:+(ptr::RemotePtr{T}, offset::Integer)
where T = RemotePtr{T}(ptr.addr + offset, ptr.space) -Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) -function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) - @assert ptr1.space == ptr2.space - return ptr1.addr < ptr2.addr -end - -struct MemorySpan{S} - ptr::RemotePtr{Cvoid,S} - len::UInt -end -MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = - MemorySpan{S}(ptr, UInt(len)) - abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) +### Type-generic aliasing info wrapper -struct AliasingWrapper <: AbstractAliasing +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -135,8 +110,204 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end +SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? 
nothing : (d.keys[state] => d.vals[state], state+1)
+
+### Type-stable lookup structure for AliasingWrappers
+
+struct AliasingLookup
+    # The set of memory spaces that are being tracked
+    spaces::Vector{MemorySpace}
+    # The set of AliasingWrappers that are being tracked
+    # One entry for each AliasingWrapper
+    ainfos::Vector{AliasingWrapper}
+    # The memory spaces for each AliasingWrapper
+    # One entry for each AliasingWrapper
+    ainfos_spaces::Vector{Vector{Int}}
+    # The spans for each AliasingWrapper in each memory space
+    # One entry for each AliasingWrapper
+    spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}}
+    # For each AliasingWrapper, the index of its sole memory space
+    # (0 if it spans multiple memory spaces)
+    ainfos_only_space::Vector{Int}
+    # The bounding span for each AliasingWrapper in each memory space
+    # One entry for each AliasingWrapper
+    bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}}
+    # The interval tree of the bounding spans for each AliasingWrapper
+    # One entry for each MemorySpace
+    bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}}
+
+    AliasingLookup() = new(MemorySpace[],
+                           AliasingWrapper[],
+                           Vector{Int}[],
+                           SmallDict{Int,Vector{LocalMemorySpan}}[],
+                           Int[],
+                           SmallDict{Int,LocalMemorySpan}[],
+                           IntervalTree{LocatorMemorySpan{Int},UInt64}[])
+end
+function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper)
+    # Update the set of memory spaces and spans,
+    # and find the bounding spans for this AliasingWrapper
+    spaces_set = Set{MemorySpace}(lookup.spaces)
+    self_spaces_set = Set{Int}()
+    spans = SmallDict{Int,Vector{LocalMemorySpan}}()
+    for span in memory_spans(ainfo)
+        space = span.ptr.space
+        if !in(space, spaces_set)
+            push!(spaces_set, space)
+            push!(lookup.spaces, space)
+            push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}())
+        end
+        space_idx = findfirst(==(space), lookup.spaces)
+        push!(self_spaces_set, space_idx)
+        spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx)
+        push!(spans_in_space, LocalMemorySpan(span))
+    end
+    push!(lookup.ainfos_spaces, collect(self_spaces_set))
+    push!(lookup.spans, spans)
+
+    # Update the set of AliasingWrappers
+    push!(lookup.ainfos, ainfo)
+    ainfo_idx = length(lookup.ainfos)
+
+    # Check if the AliasingWrapper only exists in a single memory space
+    if length(self_spaces_set) == 1
+        space_idx = only(self_spaces_set)
+        push!(lookup.ainfos_only_space, space_idx)
+    else
+        push!(lookup.ainfos_only_space, 0)
+    end
+
+    # Add the bounding spans for this AliasingWrapper
+    bounding_spans = SmallDict{Int,LocalMemorySpan}()
+    for space_idx in keys(spans)
+        space_spans = spans[space_idx]
+        bound_start = minimum(span_start, space_spans)
+        bound_end = maximum(span_end, space_spans)
+        bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start)
+        bounding_spans[space_idx] = bounding_span
+        insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx))
+    end
+    push!(lookup.bounding_spans, bounding_spans)
+
+    return ainfo_idx
+end
+struct AliasingLookupFinder
+    lookup::AliasingLookup
+    ainfo::AliasingWrapper
+    ainfo_idx::Int
+    spaces_idx::Vector{Int}
+    to_consider::Vector{Int}
+end
+Base.eltype(::AliasingLookupFinder) = AliasingWrapper
+Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown()
+# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search
+function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing)
+    if ainfo_idx === nothing
+        ainfo_idx =
something(findfirst(==(ainfo), lookup.ainfos))
+    end
+    spaces_idx = lookup.ainfos_spaces[ainfo_idx]
+    to_consider_spans = LocatorMemorySpan{Int}[]
+    for space_idx in spaces_idx
+        bounding_spans_tree = lookup.bounding_spans_tree[space_idx]
+        self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0)
+        find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false)
+    end
+    to_consider = Int[locator.owner for locator in to_consider_spans]
+    @assert all(to_consider .> 0)
+    return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider)
+end
+Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1)
+function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx)
+    ainfo_spaces = nothing
+    cursor_space_idx = 1
+
+    # New ainfos enter here
+    @label ainfo_restart
+
+    # Check if we've exhausted all ainfos
+    if cursor_ainfo_idx > length(finder.to_consider)
+        return nothing
+    end
+    ainfo_idx = finder.to_consider[cursor_ainfo_idx]
+
+    # Find the appropriate memory spaces for this ainfo
+    if ainfo_spaces === nothing
+        ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx]
+    end
+
+    # New memory spaces (for the same ainfo) enter here
+    @label space_restart
+
+    # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo
+    if cursor_space_idx > length(ainfo_spaces)
+        cursor_ainfo_idx += 1
+        ainfo_spaces = nothing
+        cursor_space_idx = 1
+        @goto ainfo_restart
+    end
+
+    # Find the currently considered memory space for this ainfo
+    space_idx = ainfo_spaces[cursor_space_idx]
+
+    # Check if this memory space is part of our target ainfo's spaces
+    if !(space_idx in finder.spaces_idx)
+        cursor_space_idx += 1
+        @goto space_restart
+    end
+
+    # Check if this ainfo's bounding span overlaps our target ainfo's bounding span in this memory space
+    other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx]
+    self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx]
+    if !spans_overlap(other_ainfo_bounding_span, self_bounding_span)
+        cursor_space_idx += 1
+        @goto space_restart
+    end
+
+    # We have overlapping bounds in the same memory space, so check if the ainfos actually alias
+    # This is the slow path!
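+    # e.g. (hypothetical) two StridedAliasing infos whose bounding spans overlap
+    # may still cover disjoint strided regions; only will_alias can rule that out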
+ other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, cursor_ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -151,8 +322,11 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - all_spans = MemorySpan{CPURAMMemorySpace}[] - for sub_a in ca.sub_ainfos + if length(ca.sub_ainfos) == 0 + return MemorySpan{CPURAMMemorySpace}[] + end + all_spans = memory_spans(ca.sub_ainfos[1]) + for sub_a in ca.sub_ainfos[2:end] append!(all_spans, memory_spans(sub_a)) end return all_spans @@ -213,8 +387,14 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T - aliasing(unwrap(x), T) +function aliasing(x::Chunk, T) + @assert x.handle isa DRef + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end + return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T + aliasing(unwrap(x), T) + end end aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x aliasing(unwrap(x)) @@ -273,13 +453,22 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} +function aliasing(x::SubArray{T,N}) where {T,N} if isbitstype(T) - S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), - RemotePtr{Cvoid}(pointer(x)), - parentindices(x), - size(x), strides(parent(x))) + p = parent(x) + space = memory_space(p) + S = typeof(space) + parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) + ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) + NA = ndims(p) + raw_inds = parentindices(x) + inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) + sz = ntuple(i->length(inds[i]), NA) + return StridedAliasing{T,NA,S}(parent_ptr, + ptr, + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -396,76 +585,8 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space + @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)" x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end - -struct ChunkView{N} - chunk::Chunk - slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} -end - -function Base.view(c::Chunk, slices...) 
- if c.domain isa ArrayDomain - nd, sz = ndims(c.domain), size(c.domain) - nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) - - for (i, s) in enumerate(slices) - if s isa Int - 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s isa AbstractRange - isempty(s) && continue - 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s === Colon() - continue - else - throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) - end - end - end - - return ChunkView(c, slices) -end - -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) - -function aliasing(x::ChunkView{N}) where N - remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end -end -memory_space(x::ChunkView) = memory_space(x.chunk) -isremotehandle(x::ChunkView) = true - -#= -function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView) - to_w = root_worker_id(to_space) - @assert to_w == myid() - to_raw = unwrap(to.chunk) - from_w = root_worker_id(from_space) - from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk) - from_view = view(from_raw, from.slices...) - to_view = view(to_raw, to.slices...) - move!(dep_mod, to_space, from_space, to_view, from_view) - return -end -=# - -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - if from_proc == to_proc - return view(unwrap(slice.chunk), slice.slices...) - else - # Need to copy the underlying data, so collapse the view - from_w = root_worker_id(from_proc) - data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices - copy(view(unwrap(chunk), slices...)) - end - return move(from_proc, to_proc, data) - end -end - -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/queue.jl b/src/queue.jl index c8c6007ec..37947a0ac 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,32 +1,63 @@ -mutable struct DTaskSpec - fargs::Vector{Argument} +mutable struct DTaskSpec{typed,FA<:Tuple} + _fargs::Vector{Argument} + _typed_fargs::FA options::Options end +DTaskSpec(fargs::Vector{Argument}, options::Options) = + DTaskSpec{false, Tuple{}}(fargs, (), options) +DTaskSpec(fargs::FA, options::Options) where FA = + DTaskSpec{true, FA}(Argument[], fargs, options) +is_typed(spec::DTaskSpec{typed}) where typed = typed +function Base.getproperty(spec::DTaskSpec{typed}, field::Symbol) where typed + if field === :fargs + if typed + return getfield(spec, :_typed_fargs) + else + return getfield(spec, :_fargs) + end + else + return getfield(spec, field) + end +end + +struct DTaskPair + spec::DTaskSpec + task::DTask +end +is_typed(pair::DTaskPair) = is_typed(pair.spec) +Base.iterate(pair::DTaskPair) = (pair.spec, true) +function Base.iterate(pair::DTaskPair, state::Bool) + if state + return (pair.task, false) + else + return nothing + end +end abstract type AbstractTaskQueue end function enqueue! 
end struct DefaultTaskQueue <: AbstractTaskQueue end -enqueue!(::DefaultTaskQueue, spec::Pair{DTaskSpec,DTask}) = - eager_launch!(spec) -enqueue!(::DefaultTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) = - eager_launch!(specs) +enqueue!(::DefaultTaskQueue, pair::DTaskPair) = + eager_launch!(pair) +enqueue!(::DefaultTaskQueue, pairs::Vector{DTaskPair}) = + eager_launch!(pairs) -enqueue!(spec::Pair{DTaskSpec,DTask}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), spec) -enqueue!(specs::Vector{Pair{DTaskSpec,DTask}}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), specs) +enqueue!(pair::DTaskPair) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pair) +enqueue!(pairs::Vector{DTaskPair}) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pairs) struct LazyTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} - LazyTaskQueue() = new(Pair{DTaskSpec,DTask}[]) + tasks::Vector{DTaskPair} + LazyTaskQueue() = new(DTaskPair[]) end -function enqueue!(queue::LazyTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) +function enqueue!(queue::LazyTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) end -function enqueue!(queue::LazyTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) +function enqueue!(queue::LazyTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) end function spawn_bulk(f::Base.Callable) queue = LazyTaskQueue() @@ -50,25 +81,25 @@ function _add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) push!(syncdeps, ThunkSyncdep(task)) end end -function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) +function enqueue!(queue::InOrderTaskQueue, pair::DTaskPair) if length(queue.prev_tasks) > 0 - _add_prev_deps!(queue, first(spec)) + _add_prev_deps!(queue, pair.spec) empty!(queue.prev_tasks) end - push!(queue.prev_tasks, last(spec)) - enqueue!(queue.upper_queue, spec) + push!(queue.prev_tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::InOrderTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) +function enqueue!(queue::InOrderTaskQueue, pairs::Vector{DTaskPair}) if length(queue.prev_tasks) > 0 - for (spec, task) in specs - _add_prev_deps!(queue, spec) + for pair in pairs + _add_prev_deps!(queue, pair.spec) end empty!(queue.prev_tasks) end - for (spec, task) in specs - push!(queue.prev_tasks, task) + for pair in pairs + push!(queue.prev_tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function spawn_sequential(f::Base.Callable) queue = InOrderTaskQueue(get_options(:task_queue, DefaultTaskQueue())) @@ -79,15 +110,15 @@ struct WaitAllQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue tasks::Vector{DTask} end -function enqueue!(queue::WaitAllQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec[2]) - enqueue!(queue.upper_queue, spec) +function enqueue!(queue::WaitAllQueue, pair::DTaskPair) + push!(queue.tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - for (_, task) in specs - push!(queue.tasks, task) +function enqueue!(queue::WaitAllQueue, pairs::Vector{DTaskPair}) + for pair in pairs + push!(queue.tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function wait_all(f; check_errors::Bool=false) queue = WaitAllQueue(get_options(:task_queue, DefaultTaskQueue()), DTask[]) diff --git a/src/sch/util.jl b/src/sch/util.jl index 3f9d7b2f6..d3b7a4804 
100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -238,7 +238,7 @@ function set_failed!(state, origin, thunk=origin) @dagdebug thunk :finish "Setting as failed" filter!(x -> x !== thunk, state.ready) # N.B. If origin === thunk, we assume that the caller has already set the error - if origin !== thunk + if origin !== thunk && !has_result(state, thunk) origin_ex = load_result(state, origin) if origin_ex isa RemoteException origin_ex = origin_ex.captured diff --git a/src/scopes.jl b/src/scopes.jl index ba291bc2b..79190c292 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -40,6 +40,9 @@ struct UnionScope <: AbstractScope push!(scope_set, scope) end end + if isempty(scope_set) + throw(ArgumentError("Cannot construct UnionScope with no inner scopes")) + end return new((collect(scope_set)...,)) end end diff --git a/src/submission.jl b/src/submission.jl index 2e7b1c836..4ff4f2294 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -268,24 +268,29 @@ function eager_process_elem_submission_to_local!(id_map, arg::Argument) arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) - spec, task = spec_pair +function eager_process_elem_submission_to_local(id_map, arg::TypedArgument{T}) where T + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + return Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) + end + return arg +end +function eager_process_args_submission_to_local!(id_map, spec::DTaskSpec{false}) for arg in spec.fargs eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) - for spec_pair in spec_pairs - eager_process_args_submission_to_local!(id_map, spec_pair) - end +function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) + return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -function DTaskMetadata(spec::DTaskSpec) - f = value(spec.fargs[1]) +DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function eager_metadata(fargs) + f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f - arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) - return_type = Base.promote_op(f, arg_types...) - return DTaskMetadata(return_type) + arg_types = ntuple(i->chunktype(value(fargs[i+1])), length(fargs)-1) + return Base.promote_op(f, arg_types...) end function eager_spawn(spec::DTaskSpec) @@ -298,48 +303,64 @@ end chunktype(t::DTask) = t.metadata.return_type -function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) +function eager_launch!(pair::DTaskPair) + spec = pair.spec + task = pair.task + # Assign a name, if specified eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - lock(Sch.EAGER_ID_MAP) do id_map - eager_process_args_submission_to_local!(id_map, spec=>task) + fargs = lock(Sch.EAGER_ID_MAP) do id_map + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - spec.fargs, spec.options, true)) + fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end -function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) - ntasks = length(specs) +# FIXME: Don't convert Tuple to Vector{Argument} +function eager_launch!(pairs::Vector{DTaskPair}) + ntasks = length(pairs) # Assign a name, if specified - for (spec, task) in specs - eager_assign_name!(spec, task) + for pair in pairs + eager_assign_name!(pair.spec, pair.task) end #=FIXME:REALLOC_N=# - uids = [task.uid for (_, task) in specs] - futures = [task.future for (_, task) in specs] + uids = [pair.task.uid for pair in pairs] + futures = [pair.task.future for pair in pairs] # Get all functions, args/kwargs, and options #=FIXME:REALLOC_N=# all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local!(id_map, specs) - [spec.fargs for (spec, _) in specs] + return map(pairs) do pair + spec = pair.spec + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] + else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end + end end - all_options = Options[spec.options for (spec, _) in specs] + all_options = Options[pair.spec.options for pair in pairs] # Submit the tasks #=FIXME:REALLOC=# thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, all_fargs, all_options, true)) for i in 1:ntasks - task = specs[i][2] + task = pairs[i].task task.thunk_ref = thunk_ids[i].ref end end diff --git a/src/thunk.jl b/src/thunk.jl index 482d66209..e13e299f0 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -17,8 +17,6 @@ function unset!(spec::ThunkSpec, _) spec.id = 0 spec.cache_ref = nothing spec.affinity = nothing - compute_scope = DefaultScope() - result_scope = AnyScope() spec.options = nothing end @@ -186,21 +184,19 @@ function args_kwargs_to_arguments(f, args, kwargs) end return args_kwargs end -function args_kwargs_to_arguments(f, args) - @nospecialize f args - args_kwargs = Argument[] - push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) - pos_ctr = 1 - for idx in 1:length(args) - pos, arg = args[idx]::Pair - if pos === nothing - push!(args_kwargs, Argument(pos_ctr, arg)) - pos_ctr += 1 +function args_kwargs_to_typedarguments(f, args, kwargs) + nargs = 1 + length(args) + length(kwargs) + return ntuple(nargs) do idx + if idx == 1 + return TypedArgument(ArgPosition(true, 0, :NULL), f) + elseif idx in 2:(1+length(args)) + arg = args[idx-1] + return TypedArgument(idx, arg) else - push!(args_kwargs, Argument(pos, arg)) + kw, value = kwargs[idx-length(args)-1] + return TypedArgument(kw, value) end end - return args_kwargs end """ @@ -491,7 +487,11 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) @gensym result return quote let - $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + $result = if $get_task_typed() + $typed_spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + else + $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + end if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->fetch($result; raw=true)))) end @@ -516,6 +516,9 @@ function _setindex!_return_value(A, value, idxs...) return value end +const TASK_TYPED = ScopedValue{Bool}(false) +get_task_typed() = TASK_TYPED[] + """ Dagger.spawn(f, args...; kwargs...) 
-> DTask @@ -526,6 +529,36 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) @nospecialize f args kwargs + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Argument form + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function typed_spawn(f, args...; kwargs...) + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Tuple of TypedArgument form + args_kwargs = args_kwargs_to_typedarguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function _spawn(args_kwargs, task_options) # Get all scoped options and determine which propagate beyond this task scoped_options = get_options()::NamedTuple if haskey(scoped_options, :propagates) @@ -539,20 +572,9 @@ function spawn(f, args...; kwargs...) end append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) - # Merge all passed options - if length(args) >= 1 && first(args) isa Options - # N.B. Make a defensive copy in case user aliases Options struct - task_options = copy(first(args)::Options) - args = args[2:end] - else - task_options = Options() - end # N.B. Merges into task_options options_merge!(task_options, scoped_options; override=false) - # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_arguments(f, args, kwargs) - # Get task queue, and don't let it propagate task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue filter!(prop -> prop != :task_queue, propagates) @@ -568,7 +590,7 @@ function spawn(f, args...; kwargs...) task = eager_spawn(spec) # Enqueue the task into the task queue - enqueue!(task_queue, spec=>task) + enqueue!(task_queue, DTaskPair(spec, task)) return task end diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 400b49332..9f0c3b487 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -174,6 +174,9 @@ function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) end tochunk(x::Thunk, proc=nothing, scope=nothing; kwargs...) = x +root_worker_id(chunk::Chunk) = root_worker_id(chunk.handle) +root_worker_id(dref::DRef) = dref.owner # FIXME: Migration + function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io serialize(io, MemPool.MMWrap(data)) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 615030400..873e47e79 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -35,3 +35,28 @@ macro dagdebug(thunk, category, msg, args...) 
end end) end + +# FIXME: Calculate fast-growth based on clock time, not iteration +const OPCOUNTER_CATEGORIES = Symbol[] +const OPCOUNTER_FAST_GROWTH_THRESHOLD = Ref(10_000_000) +struct OpCounter + value::Threads.Atomic{Int} +end +OpCounter() = OpCounter(Threads.Atomic{Int}(0)) +macro opcounter(category, count=1) + cat_sym = category.value + @gensym old + opcounter_sym = Symbol(:OPCOUNTER_, cat_sym) + if !isdefined(__module__, opcounter_sym) + __module__.eval(:(#=const=# $opcounter_sym = OpCounter())) + end + esc(quote + if $(QuoteNode(cat_sym)) in $OPCOUNTER_CATEGORIES + $old = Threads.atomic_add!($opcounter_sym.value, Int($count)) + if $old > 1 && (mod1($old, $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) == 1 || $count > $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) + println("Fast-growing counter: $($(QuoteNode(cat_sym))) = $($old)") + end + end + end) +end +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index 1fadbeeb6..c5990099e 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -99,4 +99,22 @@ Adapt.adapt_structure(to, H::Dagger.HaloArray) = HaloArray(Adapt.adapt(to, H.center), Adapt.adapt.(Ref(to), H.edges), Adapt.adapt.(Ref(to), H.corners), - H.halo_width) \ No newline at end of file + H.halo_width) + +function aliasing(A::HaloArray) + return CombinedAliasing([aliasing(A.center), aliasing(A.edges), aliasing(A.corners)]) +end +memory_space(A::HaloArray) = memory_space(A.center) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, A::HaloArray) + center_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.center) + edge_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.edges[i]), length(A.edges)) + corner_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.corners[i]), length(A.corners)) + halo_width = A.halo_width + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width + center_new = move(from_proc, to_proc, center_chunk) + edges_new = ntuple(i->move(from_proc, to_proc, edge_chunks[i]), length(edge_chunks)) + corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) + return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) + end +end \ No newline at end of file diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl new file mode 100644 index 000000000..dda355756 --- /dev/null +++ b/src/utils/interval_tree.jl @@ -0,0 +1,453 @@ +mutable struct IntervalNode{M,E} + span::M + max_end::E # Maximum end value in this subtree + left::Union{IntervalNode{M,E}, Nothing} + right::Union{IntervalNode{M,E}, Nothing} + + IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::ManyMemorySpan{N}) where N = new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocatorMemorySpan{T}) where T = new{LocatorMemorySpan{T},UInt64}(span, span_end(span), nothing, nothing) +end + +mutable struct 
IntervalTree{M,E}
+    root::Union{IntervalNode{M,E}, Nothing}
+
+    IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing)
+    IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing)
+    IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing)
+    IntervalTree{LocatorMemorySpan{T}}() where T = new{LocatorMemorySpan{T},UInt64}(nothing)
+end
+
+# Construct interval tree from unsorted set of spans
+function IntervalTree{M}(spans) where M
+    tree = IntervalTree{M}()
+    for span in spans
+        insert!(tree, span)
+    end
+    verify_spans(tree)
+    return tree
+end
+IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans)
+
+function Base.show(io::IO, tree::IntervalTree)
+    println(io, "$(typeof(tree)) (with $(length(tree)) spans):")
+    for (i, span) in enumerate(tree)
+        println(io, " $i: [$(span_start(span)), $(span_end(span))) (len=$(span_len(span)))")
+    end
+end
+
+function Base.collect(tree::IntervalTree{M}) where M
+    result = M[]
+    for span in tree
+        push!(result, span)
+    end
+    return result
+end
+
+# Useful for debugging when spans get misaligned
+function verify_spans(tree::IntervalTree{ManyMemorySpan{N}}) where N
+    for span in tree
+        verify_span(span)
+    end
+end
+# Other span types carry no cross-space invariant, so there's nothing to check
+verify_spans(tree::IntervalTree) = nothing
+
+function Base.iterate(tree::IntervalTree{M}) where M
+    if tree.root === nothing
+        return nothing
+    end
+    return iterate(tree.root)
+end
+function Base.iterate(tree::IntervalTree, state)
+    return iterate(tree.root, state)
+end
+function Base.iterate(root::IntervalNode{M,E}) where {M,E}
+    state = Vector{IntervalNode{M,E}}()
+    push!(state, root)
+    return iterate(root, state)
+end
+function Base.iterate(root::IntervalNode, state)
+    if isempty(state)
+        return nothing
+    end
+    current = popfirst!(state)
+    if current.right !== nothing
+        pushfirst!(state, current.right)
+    end
+    if current.left !== nothing
+        pushfirst!(state, current.left)
+    end
+    return current.span, state
+end
+
+function Base.length(tree::IntervalTree)
+    result = 0
+    for _ in tree
+        result += 1
+    end
+    return result
+end
+
+# Update max_end value for a node based on its children
+function update_max_end!(node::IntervalNode)
+    max_end = span_end(node.span)
+    if node.left !== nothing
+        max_end = max(max_end, node.left.max_end)
+    end
+    if node.right !== nothing
+        max_end = max(max_end, node.right.max_end)
+    end
+    node.max_end = max_end
+end
+
+# Insert a span into the interval tree
+function Base.insert!(tree::IntervalTree{M,E}, span::M) where {M,E}
+    if !isempty(span)
+        if tree.root === nothing
+            tree.root = IntervalNode(span)
+            update_max_end!(tree.root)
+            return span
+        end
+        #tree.root = insert_node!(tree.root, span)
+        # Walk down to the insertion point, recording the path (including the
+        # root) so that max_end can be refreshed afterwards
+        to_update = Vector{IntervalNode{M,E}}()
+        prev_node = tree.root
+        cur_node = tree.root
+        while cur_node !== nothing
+            prev_node = cur_node
+            push!(to_update, cur_node)
+            if span_start(span) <= span_start(cur_node.span)
+                cur_node = cur_node.left
+            else
+                cur_node = cur_node.right
+            end
+        end
+        # Attach on the side that matches the BST ordering on span_start
+        if span_start(span) <= span_start(prev_node.span)
+            prev_node.left = IntervalNode(span)
+        else
+            prev_node.right = IntervalNode(span)
+        end
+        # Refresh max_end bottom-up, so parents see updated child values
+        for node_idx in reverse(eachindex(to_update))
+            update_max_end!(to_update[node_idx])
+        end
+    end
+    return span
+end
+
+function insert_node!(::Nothing, span::M) where M
+    return IntervalNode(span)
+end
+function insert_node!(root::IntervalNode{M,E}, span::M) where {M,E}
+    # Track the path to the insertion point, for updating max_end afterwards
+    path = Vector{IntervalNode{M,E}}()
+    current = root
+
+    # Traverse to find the insertion point
+    while
current !== nothing + push!(path, current) + if span_start(span) <= span_start(current.span) + if current.left === nothing + current.left = IntervalNode(span) + break + end + current = current.left + else + if current.right === nothing + current.right = IntervalNode(span) + break + end + current = current.right + end + end + + # Update max_end for all ancestors (process in reverse order) + while !isempty(path) + node = pop!(path) + update_max_end!(node) + end + + return root +end + +# Remove a specific span from the tree (split as needed) +function Base.delete!(tree::IntervalTree{M}, span::M) where M + if !isempty(span) + tree.root = delete_node!(tree.root, span) + end + return span +end + +function delete_node!(::Nothing, span::M) where M + return nothing +end +function delete_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Track the path to the target node: (node, direction_to_child) + path = Vector{Tuple{IntervalNode{M,E}, Symbol}}() + current = root + target = nothing + target_type = :none # :exact or :overlap + + # Phase 1: Search for target node + while current !== nothing + is_exact = span_start(current.span) == span_start(span) && span_len(current.span) == span_len(span) + is_overlap = !is_exact && spans_overlap(current.span, span) + + if is_exact + target = current + target_type = :exact + break + elseif is_overlap + target = current + target_type = :overlap + break + elseif span_start(span) <= span_start(current.span) + push!(path, (current, :left)) + current = current.left + else + push!(path, (current, :right)) + current = current.right + end + end + + if target === nothing + return root + end + + # Phase 2: Compute replacement for target node + original_span = target.span + succ_path = Vector{IntervalNode{M,E}}() # Path to successor (for max_end updates) + local replacement::Union{IntervalNode{M,E}, Nothing} + + if target.left === nothing && target.right === nothing + # Leaf node + replacement = nothing + elseif target.left === nothing + # Only right child + replacement = target.right + elseif target.right === nothing + # Only left child + replacement = target.left + else + # Two children - find and remove inorder successor + successor = find_min(target.right) + + if target.right === successor + # Successor is direct right child + target.right = successor.right + else + # Track path to successor for max_end updates + succ_parent = target.right + push!(succ_path, succ_parent) + while succ_parent.left !== successor + succ_parent = succ_parent.left + push!(succ_path, succ_parent) + end + # Remove successor by replacing with its right child + succ_parent.left = successor.right + end + + target.span = successor.span + replacement = target + end + + # Phase 3: Handle overlap case - add remaining portions + if target_type == :overlap + original_start = span_start(original_span) + original_end = span_end(original_span) + del_start = span_start(span) + del_end = span_end(span) + verify_span(span) + + # Left portion: exists if original starts before deleted span + if original_start < del_start + left_end = min(original_end, del_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + replacement = insert_node!(replacement, left_span) + end + end + end + + # Right portion: exists if original extends beyond deleted span + if original_end > del_end + right_start = max(original_start, del_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + 
replacement = insert_node!(replacement, right_span)
+            end
+        end
+    end
+
+    # Phase 4: Update parent's child pointer
+    if isempty(path)
+        root = replacement
+    else
+        parent, dir = path[end]
+        if dir == :left
+            parent.left = replacement
+        else
+            parent.right = replacement
+        end
+    end
+
+    # Phase 5: Update max_end in correct order (bottom-up)
+    # First: successor path (if any)
+    for i in length(succ_path):-1:1
+        update_max_end!(succ_path[i])
+    end
+    # Second: target node (if it wasn't removed)
+    if replacement === target
+        update_max_end!(target)
+    end
+    # Third: main path (ancestors of target)
+    for i in length(path):-1:1
+        update_max_end!(path[i][1])
+    end
+
+    return root
+end
+
+function find_min(node::IntervalNode)
+    while node.left !== nothing
+        node = node.left
+    end
+    return node
+end
+
+# Find all spans that overlap with the given query span
+function find_overlapping(tree::IntervalTree{M}, query::M; exact::Bool=true) where M
+    result = M[]
+    find_overlapping!(tree.root, query, result; exact)
+    return result
+end
+function find_overlapping!(tree::IntervalTree{M}, query::M, result::Vector{M}; exact::Bool=true) where M
+    find_overlapping!(tree.root, query, result; exact)
+    return result
+end
+
+function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=true) where M
+    return
+end
+function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E}
+    # Use a queue for breadth-first traversal
+    queue = Vector{IntervalNode{M,E}}()
+    push!(queue, node)
+
+    while !isempty(queue)
+        current = popfirst!(queue)
+
+        # Check if current node overlaps with query
+        if spans_overlap(current.span, query)
+            if exact
+                # Get the overlapping portion of the span
+                overlap = span_diff(current.span, query)
+                verify_span(overlap)
+                if !isempty(overlap)
+                    push!(result, overlap)
+                end
+            else
+                push!(result, current.span)
+            end
+        end
+
+        # Enqueue left subtree if it might contain overlapping intervals
+        if current.left !== nothing && current.left.max_end > span_start(query)
+            push!(queue, current.left)
+        end
+
+        # Enqueue right subtree if query extends beyond current node's start
+        if current.right !== nothing && span_end(query) > span_start(current.span)
+            push!(queue, current.right)
+        end
+    end
+end
+
+# ============================================================================
+# MAIN SUBTRACTION ALGORITHM
+# ============================================================================
+
+"""
+    subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M
+
+Subtract all spans in subtrahend_spans from the minuend_tree in-place.
+The minuend_tree is modified to contain only the portions that remain after subtraction.
+
+Time Complexity: O(M log N + M*K) where M = |subtrahend_spans|, N = |minuend nodes|,
+                 K = average overlaps per subtrahend span
+Space Complexity: O(K) temporary space per subtrahend span (for the list of
+                  overlaps); the tree itself is modified in-place
+
+If `diff` is provided, add the overlapping spans to `diff`.
+"""
+function subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M
+    for sub_span in subtrahend_spans
+        subtract_single_span!(minuend_tree, sub_span, diff)
+    end
+end
+
+"""
+    subtract_single_span!(tree::IntervalTree, sub_span::MemorySpan, diff=nothing)
+
+Subtract a single span from the interval tree. This function:
+1. Finds all overlapping spans in the tree
+2. Removes each overlapping span
+3. Adds back the non-overlapping portions (left and/or right remnants)
+4.
If diff is provided, add the overlapping span to diff +""" +function subtract_single_span!(tree::IntervalTree{M}, sub_span::M, diff=nothing) where M + # Find all spans that overlap with the subtrahend + overlapping_spans = find_overlapping(tree, sub_span) + + # Process each overlapping span + for overlap_span in overlapping_spans + # Remove the overlapping span from the tree + delete!(tree, overlap_span) + + # Calculate and add back the portions that should remain + add_remaining_portions!(tree, overlap_span, sub_span) + + if diff !== nothing && !isempty(overlap_span) + push!(diff, overlap_span) + end + end +end + +""" + add_remaining_portions!(tree::IntervalTree, original::MemorySpan, subtracted::MemorySpan) + +After removing an overlapping span, add back the portions that don't overlap with the subtracted span. +There can be up to two remaining portions: left and right of the subtracted region. +""" +function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted::M) where M + original_start = span_start(original) + original_end = span_end(original) + sub_start = span_start(subtracted) + sub_end = span_end(subtracted) + + # Left portion: exists if original starts before subtracted + if original_start < sub_start + left_end = min(original_end, sub_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + insert!(tree, left_span) + end + end + end + + # Right portion: exists if original extends beyond subtracted + if original_end > sub_end + right_start = max(original_start, sub_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + insert!(tree, right_span) + end + end + end +end \ No newline at end of file diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl new file mode 100644 index 000000000..c00d16c36 --- /dev/null +++ b/src/utils/memory-span.jl @@ -0,0 +1,140 @@ +### Remote pointer type + +struct RemotePtr{T,S<:MemorySpace} <: Ref{T} + addr::UInt + space::S +end +RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) +RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) +RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) +Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = + RemotePtr(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = + RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr +Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) +Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) +function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) + @assert ptr1.space == ptr2.space + return ptr1.addr < ptr2.addr +end + +### Generic memory spans + +struct MemorySpan{S} + ptr::RemotePtr{Cvoid,S} + len::UInt +end +MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = + MemorySpan{S}(ptr, UInt(len)) +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 +span_start(span::MemorySpan) = span.ptr.addr +span_len(span::MemorySpan) = span.len +span_end(span::MemorySpan) = span.ptr.addr 
diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl
new file mode 100644
index 000000000..c00d16c36
--- /dev/null
+++ b/src/utils/memory-span.jl
@@ -0,0 +1,140 @@
+### Remote pointer type
+
+struct RemotePtr{T,S<:MemorySpace} <: Ref{T}
+    addr::UInt
+    space::S
+end
+RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space)
+RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space)
+RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid()))
+# FIXME: Don't hardcode CPURAMMemorySpace
+RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid()))
+Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T =
+    RemotePtr(UInt(x), CPURAMMemorySpace(myid()))
+Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} =
+    RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid()))
+Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr
+Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space)
+Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space)
+function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr)
+    @assert ptr1.space == ptr2.space
+    return ptr1.addr < ptr2.addr
+end
+
+### Generic memory spans
+
+struct MemorySpan{S}
+    ptr::RemotePtr{Cvoid,S}
+    len::UInt
+end
+MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S =
+    MemorySpan{S}(ptr, UInt(len))
+MemorySpan{S}(addr::UInt, len::Integer) where S =
+    MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len))
+Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr
+Base.isempty(x::MemorySpan) = x.len == 0
+span_start(span::MemorySpan) = span.ptr.addr
+span_len(span::MemorySpan) = span.len
+span_end(span::MemorySpan) = span.ptr.addr + span.len
+spans_overlap(span1::MemorySpan, span2::MemorySpan) =
+    span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1)
+# N.B. Despite the name, this returns the *intersection* of the two spans
+# (an empty span at the intersection start when they don't overlap)
+function span_diff(span1::MemorySpan, span2::MemorySpan)
+    @assert span1.ptr.space == span2.ptr.space
+    start = max(span_start(span1), span_start(span2))
+    stop = min(span_end(span1), span_end(span2))
+    start_ptr = RemotePtr{Cvoid}(start, span1.ptr.space)
+    if start < stop
+        len = stop - start
+        return MemorySpan(start_ptr, len)
+    else
+        return MemorySpan(start_ptr, 0)
+    end
+end
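As a quick sanity check of spans_overlap and the intersection computed by span_diff, here is a hypothetical REPL session. The addresses are made up, and it assumes these internal helpers are reachable through the Dagger module on a single process:

using Distributed, Dagger

space = Dagger.CPURAMMemorySpace(myid())
a = Dagger.MemorySpan(Dagger.RemotePtr{Cvoid}(UInt(0x1000), space), 64)  # [0x1000, 0x1040)
b = Dagger.MemorySpan(Dagger.RemotePtr{Cvoid}(UInt(0x1020), space), 64)  # [0x1020, 0x1060)

Dagger.spans_overlap(a, b)               # true: the byte ranges intersect
Dagger.span_len(Dagger.span_diff(a, b))  # 0x20: the intersection covers 32 bytes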
+### More space-efficient memory spans
+
+struct LocalMemorySpan
+    ptr::UInt
+    len::UInt
+end
+LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len)
+Base.isempty(x::LocalMemorySpan) = x.len == 0
+span_start(span::LocalMemorySpan) = span.ptr
+span_len(span::LocalMemorySpan) = span.len
+span_end(span::LocalMemorySpan) = span.ptr + span.len
+spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) =
+    span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1)
+function span_diff(span1::LocalMemorySpan, span2::LocalMemorySpan)
+    start = max(span_start(span1), span_start(span2))
+    stop = min(span_end(span1), span_end(span2))
+    if start < stop
+        len = stop - start
+        return LocalMemorySpan(start, len)
+    else
+        return LocalMemorySpan(start, 0)
+    end
+end
+
+# FIXME: Store the length separately, since it's shared by all spans
+struct ManyMemorySpan{N}
+    spans::NTuple{N,LocalMemorySpan}
+end
+Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans)
+span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N))
+span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N))
+span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N))
+spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N =
+    # N.B. All sub-spans are assumed to have the same length and relative offset
+    spans_overlap(span1.spans[1], span2.spans[1])
+function span_diff(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N
+    verify_span(span1)
+    verify_span(span2)
+    span = ManyMemorySpan(ntuple(i -> span_diff(span1.spans[i], span2.spans[i]), N))
+    verify_span(span)
+    return span
+end
+const VERIFY_SPAN_CURRENT_OBJECT = TaskLocalValue{Any}(()->nothing)
+function verify_span(span::ManyMemorySpan{N}) where N
+    @assert allequal(span_len, span.spans) "All spans must have the same length: $(map(span_len, span.spans))\nWhile processing $(typeof(VERIFY_SPAN_CURRENT_OBJECT[]))"
+end
+
+struct ManyPair{N} <: Unsigned
+    pairs::NTuple{N,UInt}
+end
+Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair
+Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N))
+Base.convert(::Type{ManyPair}, x::ManyPair) = x
+Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N))
+Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N))
+Base.:-(x::ManyPair) = error("Can't negate a ManyPair")
+Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs
+Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1]
+Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1]
+Base.string(x::ManyPair) = "ManyPair($(x.pairs))"
+Base.show(io::IO, x::ManyPair) = print(io, string(x))
+
+ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N =
+    ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N))
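The lock-step design above means a single ManyMemorySpan describes the same relative byte range in several buffers at once, with overlap decided by the first sub-span alone and span_diff applied elementwise. A hypothetical usage sketch (these are internal names, and verify_span's allequal(f, itr) form assumes a recent Julia):

# The same 0x40-byte window mirrored in two buffers
a = Dagger.ManyMemorySpan((Dagger.LocalMemorySpan(0x100, 0x40),
                           Dagger.LocalMemorySpan(0x900, 0x40)))
b = Dagger.ManyMemorySpan((Dagger.LocalMemorySpan(0x120, 0x40),
                           Dagger.LocalMemorySpan(0x920, 0x40)))

Dagger.spans_overlap(a, b)  # true, decided by the first sub-span pair
c = Dagger.span_diff(a, b)  # elementwise intersection: 0x20 bytes in each buffer
Dagger.span_len(c)          # ManyPair((0x20, 0x20))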
+
+### Memory spans with ownership info
+
+struct LocatorMemorySpan{T}
+    span::LocalMemorySpan
+    owner::T
+end
+LocatorMemorySpan{T}(start::UInt64, len::UInt64) where T = # For interval tree
+    LocatorMemorySpan{T}(LocalMemorySpan(start, len), 0)
+Base.isempty(x::LocatorMemorySpan) = span_len(x.span) == 0
+span_start(x::LocatorMemorySpan) = span_start(x.span)
+span_end(x::LocatorMemorySpan) = span_end(x.span)
+span_len(x::LocatorMemorySpan) = span_len(x.span)
+spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T =
+    spans_overlap(span1.span, span2.span)
+function span_diff(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T
+    span = LocatorMemorySpan(span_diff(span1.span, span2.span), 0)
+    verify_span(span)
+    return span
+end
+function verify_span(span::LocatorMemorySpan{T}) where T
+    # Nothing to check for a single local span
+end
\ No newline at end of file
diff --git a/test/datadeps.jl b/test/datadeps.jl
index cd83be95f..05a3091af 100644
--- a/test/datadeps.jl
+++ b/test/datadeps.jl
@@ -1,4 +1,5 @@
-import Dagger: ChunkView, Chunk
+import Dagger: ChunkView, Chunk, AbstractAliasing, MemorySpace, ArgumentWrapper
+import Dagger: aliasing, memory_space
 using LinearAlgebra, Graphs
 
 @testset "Memory Aliasing" begin
@@ -82,7 +83,7 @@ end
 end
 
 function with_logs(f)
-    Dagger.enable_logging!(;taskdeps=true, taskargs=true)
+    Dagger.enable_logging!(;taskdeps=true, taskargs=true, timeline=true)
     try
         f()
         return Dagger.fetch_logs!()
@@ -108,68 +109,296 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int)
     end
     error("Task $tid not found in logs")
 end
-function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false)
-    g = SimpleDiGraph()
-    tid_to_v = Dict{Int,Int}()
+function all_tasks_in_logs(logs::Dict)
+    all_tids = Int[]
+    for w in keys(logs)
+        _logs = logs[w]
+        for idx in 1:length(_logs[:core])
+            core_log = _logs[:core][idx]
+            id_log = _logs[:id][idx]
+            if core_log.category == :add_thunk && core_log.kind == :finish
+                tid = id_log.thunk_id::Int
+                push!(all_tids, tid)
+            end
+        end
+    end
+    return all_tids
+end
+mutable struct FlowEntry
+    kind::Symbol
+    tid::Int
+    ainfo::AbstractAliasing
+    to_ainfo::AbstractAliasing
+    from_space::MemorySpace
+    to_space::MemorySpace
+    read::Bool
+    write::Bool
+end
+struct FlowCheck
+    read::Bool
+    write::Bool
+    arg_w::ArgumentWrapper
+    orig_ainfo::AbstractAliasing
+    orig_space::MemorySpace
+    function FlowCheck(kind, arg, dep_mod=identity)
+        if kind == :read
+            read = true
+            write = false
+        elseif kind == :write
+            read = false
+            write = true
+        elseif kind == :readwrite
+            read = true
+            write = true
+        else
+            error("Invalid kind: $kind")
+        end
+        arg_w = maybe_rewrap_arg_w(ArgumentWrapper(arg, dep_mod))
+        return new(read, write, arg_w, aliasing(arg, dep_mod), memory_space(arg))
+    end
+end
+struct FlowGraph
+    g::SimpleDiGraph
+    tid_to_v::Dict{Int,Int}
+    FlowGraph() = new(SimpleDiGraph(), Dict{Int,Int}())
+end
+struct FlowState
+    flows::Dict{ArgumentWrapper,Vector{FlowEntry}}
+    graph::FlowGraph
+    FlowState() = new(Dict{ArgumentWrapper,Vector{FlowEntry}}(), FlowGraph())
+end
+function maybe_rewrap_arg_w(arg_w::ArgumentWrapper)
+    arg = arg_w.arg
+    if arg isa DTask
+        arg = fetch(arg; raw=true)
+    end
+    if arg isa Chunk && Dagger.root_worker_id(arg) == myid()
+        arg = Dagger.unwrap(arg)
+    end
+    return ArgumentWrapper(arg, arg_w.dep_mod)
+end
+function build_dataflow(logs::Dict; verbose::Bool=false)
+    state = FlowState()
+    orig_ainfos = Dict{AbstractAliasing,AbstractAliasing}()
+    ainfo_arg_w = Dict{AbstractAliasing,ArgumentWrapper}()
+
+    function add_execute!(arg_w, orig_ainfo, ainfo, tid, space, read, write)
+        ainfo_flows = get!(Vector{FlowEntry}, state.flows, arg_w)
+        # Skip duplicates (same arg passed 2+ times to the same task)
+        dup_idx = findfirst(flow->flow.tid == tid, ainfo_flows)
+        if dup_idx === nothing
+            if !haskey(orig_ainfos, ainfo)
+                orig_ainfos[ainfo] = orig_ainfo
+            end
+            if !haskey(ainfo_arg_w, ainfo)
+                ainfo_arg_w[ainfo] = arg_w
+            end
+            verbose && println("Adding execute flow (tid $tid, space $space, read $read, write $write):\n  $orig_ainfo ->\n  $ainfo")
+            verbose && println("  $(arg_w.dep_mod), $(arg_w.arg)")
+            push!(ainfo_flows, FlowEntry(:execute, tid, ainfo, ainfo, space, space, read, write))
+        else
+            # Union the read and write fields
+            ainfo_flows[dup_idx].read |= read
+            ainfo_flows[dup_idx].write |= write
+        end
+    end
+    function add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space)
+        dep_mod = arg_w.dep_mod
+        from_ainfo = aliasing(from_arg, dep_mod)
+        to_ainfo = aliasing(to_arg, dep_mod)
+        if !haskey(orig_ainfos, from_ainfo)
+            orig_ainfos[from_ainfo] = from_ainfo
+        end
+        if !haskey(ainfo_arg_w, from_ainfo)
+            ainfo_arg_w[from_ainfo] = arg_w
+        end
+        if !haskey(ainfo_arg_w, to_ainfo)
+            ainfo_arg_w[to_ainfo] = arg_w
+        end
+        orig_ainfo = orig_ainfos[from_ainfo]
+        orig_ainfos[to_ainfo] = orig_ainfo
+        arg_flows = get!(Vector{FlowEntry}, state.flows, arg_w)
+        verbose && println("Adding copy flow (tid $tid, from_space $from_space, to_space $to_space):\n  $orig_ainfo ->\n  $to_ainfo")
+        verbose && println("  $(arg_w.dep_mod), $(arg_w.arg)")
+        push!(arg_flows, FlowEntry(:copy, tid, from_ainfo, to_ainfo, from_space, to_space, true, true))
+    end
+
+    # Populate graph from syncdeps
     seen = Set{Int}()
-    to_visit = copy(all_tids)
+    to_visit = all_tasks_in_logs(logs)
     while !isempty(to_visit)
         this_tid = popfirst!(to_visit)
         this_tid in seen && continue
         push!(seen, this_tid)
-        if !(this_tid in keys(tid_to_v))
-            add_vertex!(g); tid_to_v[this_tid] = nv(g)
+        if !(this_tid in keys(state.graph.tid_to_v))
+            add_vertex!(state.graph.g); state.graph.tid_to_v[this_tid] = nv(state.graph.g)
         end
 
         # Add syncdeps
         deps = taskdeps_for_task(logs, this_tid)
         for dep in deps
-            if !(dep in keys(tid_to_v))
-                add_vertex!(g); tid_to_v[dep] = nv(g)
+            if !(dep in keys(state.graph.tid_to_v))
+                add_vertex!(state.graph.g); state.graph.tid_to_v[dep] = nv(state.graph.g)
             end
-            add_edge!(g, tid_to_v[this_tid], tid_to_v[dep])
+            add_edge!(state.graph.g, state.graph.tid_to_v[this_tid], state.graph.tid_to_v[dep])
             push!(to_visit, dep)
         end
     end
-    state = dijkstra_shortest_paths(g, tid_to_v[tid])
-    any_failed = false
-    @test !has_edge(g, tid_to_v[tid], tid_to_v[tid])
-    any_failed |= has_edge(g, tid_to_v[tid], tid_to_v[tid])
-    for dom in doms
-        @test state.pathcounts[tid_to_v[dom]] > 0
-        if state.pathcounts[tid_to_v[dom]] == 0
-            println("Expected dominance for $dom of $tid")
-            any_failed = true
-        end
-    end
-    if nondom_check
-        for nondom in all_tids
-            nondom == tid && continue
-            nondom in doms && continue
-            @test state.pathcounts[tid_to_v[nondom]] == 0
-            if state.pathcounts[tid_to_v[nondom]] > 0
-                println("Expected non-dominance for $nondom of $tid")
-                any_failed = true
+
+    # Populate flows and graphs from datadeps logs
+    for w in keys(logs)
+        _logs = logs[w]
+        for idx in 1:length(_logs[:core])
+            core_log = _logs[:core][idx]
+            id_log = _logs[:id][idx]
+            tl_log = _logs[:timeline][idx]
+            if core_log.category == :datadeps_execute && core_log.kind == :finish
+                tid = id_log.thunk_id
+                for (remote_arg, depset) in zip(tl_log.args, tl_log.deps)
+                    for dep in depset.deps
+                        arg_w = maybe_rewrap_arg_w(dep.arg_w)
+                        orig_ainfo = aliasing(arg_w.arg, arg_w.dep_mod)
+                        remote_ainfo = aliasing(remote_arg, arg_w.dep_mod)
+                        space = memory_space(remote_arg)
+                        add_execute!(arg_w, orig_ainfo, remote_ainfo, tid, space, dep.readdep, dep.writedep)
+                    end
+                end
+            elseif (core_log.category == :datadeps_copy || core_log.category == :datadeps_copy_skip) && core_log.kind == :finish
+                tid = tl_log.thunk_id
+                from_space = tl_log.from_space
+                to_space = tl_log.to_space
+                from_arg = tl_log.from_arg
+                to_arg = tl_log.to_arg
+                arg_w = maybe_rewrap_arg_w(tl_log.arg_w)
+                add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space)
+            end
+        end
+    end
+
+    return state
+end
+function test_dataflow(state::FlowState, checks...; verbose::Bool=true)
+    # Check that each ainfo starts and ends in the same space
+    for arg_w in keys(state.flows)
+        ainfo = aliasing(arg_w.arg, arg_w.dep_mod)
+        arg_flows = state.flows[arg_w]
+        orig_space = memory_space(arg_w.arg) #arg_flows[1].from_space
+        #=if ainfo != arg_flows[1].ainfo
+            verbose && println("Ainfo key $(ainfo) is not the same as the first flow's ainfo $(ainfo_flows[1].ainfo)")
+            return false
+        end=#
+        final_space = arg_flows[end].to_space
+        # FIXME: will_alias doesn't check across spaces
+        any_writes = any(flows->Dagger.will_alias(flows[1], ainfo) && any(flow->flow.write, flows[2]), state.flows)
+        if orig_space != final_space
+            if verbose
+                println("Arg ($(arg_w.dep_mod), $(arg_w.arg)) starts in $(orig_space) but ends in $(final_space)")
+                for flow in arg_flows
+                    println("  $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space)")
+                end
+            end
+            return false
+        end
+    end
+
+    # Check each flow against the previous flow, ensuring that the previous flow is a dominator of the current flow
+    # FIXME: Validate non-dominance when unnecessary?
+    for arg_w in keys(state.flows)
+        arg_flows = state.flows[arg_w]
+        for (idx, flow) in enumerate(arg_flows)
+            if idx > 1
+                prev_flow = arg_flows[idx-1]
+                if !prev_flow.write && !flow.write
+                    # R->R flows don't depend on each other
+                    continue
+                end
+                if !prev_flow.write && flow.write && prev_flow.kind == :execute && flow.kind == :copy && prev_flow.ainfo != flow.to_ainfo
+                    # The copy only writes to a different ainfo, so these don't depend on each other
+                    continue
+                end
+                if flow.tid == 0
+                    # Ignore copy-skip flows
+                    continue
+                end
+                v = state.graph.tid_to_v[flow.tid]
+                prev_v = state.graph.tid_to_v[prev_flow.tid]
+                path_state = dijkstra_shortest_paths(state.graph.g, v; allpaths=true)
+                if path_state.pathcounts[prev_v] == 0
+                    if verbose
+                        println("Flow $(idx-1) (tid $(prev_flow.tid), $(prev_flow.kind), R:$(prev_flow.read), W:$(prev_flow.write)) is not a dominator of flow $(idx) (tid $(flow.tid), $(flow.kind), R:$(flow.read), W:$(flow.write))")
+                        @show length(state.flows[arg_w])
+                        for flow in state.flows[arg_w]
+                            println("  $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space) (R:$(flow.read), W:$(flow.write))")
+                        end
+                        for flow in state.flows[arg_w]
+                            println("  May write to: $(flow.to_ainfo)")
+                        end
+                        e_vs = collect(edges(state.graph.g))
+                        e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), state.graph.tid_to_v))[1],
+                                             only(filter(tv->tv[2]==dst(e), state.graph.tid_to_v))[1]),
+                                     e_vs)
+                        sort!(e_tids)
+                        for e in e_tids
+                            s_tid, d_tid = src(e), dst(e)
+                            println("Edge: $s_tid -(up)> $d_tid")
+                        end
+                    end
+                    return false
+                end
             end
         end
     end
-    # For debugging purposes
-    if any_failed
-        println("Failure detected!")
-        println("Root: $tid")
-        println("Exp. doms: $doms")
-        println("All: $all_tids")
-        e_vs = collect(edges(g))
-        e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), tid_to_v))[1],
-                             only(filter(tv->tv[2]==dst(e), tid_to_v))[1]),
-                     e_vs)
-        sort!(e_tids)
-        for e in e_tids
-            s_tid, d_tid = src(e), dst(e)
-            println("Edge: $s_tid -(up)> $d_tid")
+    # Walk through each check, ensuring that the current state of the flow matches the check
+    arg_locations = Dict{ArgumentWrapper,MemorySpace}()
+    flow_idxs = Dict{ArgumentWrapper,Int}(arg_w=>1 for arg_w in keys(state.flows))
+    for (idx, check) in enumerate(checks)
+        # Record the original location of the ainfo
+        if !haskey(arg_locations, check.arg_w)
+            arg_locations[check.arg_w] = check.orig_space
+        end
+
+        # Try to advance a flow
+        if !haskey(flow_idxs, check.arg_w)
+            if verbose
+                @warn "Didn't encounter argument ($(check.arg_w.dep_mod), $(check.arg_w.arg))"
+                println("Seen arguments:")
+                for arg_w in keys(state.flows)
+                    println("  ($(arg_w.dep_mod), $(arg_w.arg))")
+                end
+            end
+            return false
+        end
+        flow_idx = flow_idxs[check.arg_w]
+        while true
+            if flow_idx > length(state.flows[check.arg_w])
+                verbose && println("Exhausted all tasks while trying to find $(check.arg_w)")
+                return false
+            end
+            flow = state.flows[check.arg_w][flow_idx]
+            if flow.kind == :execute
+                # The current flow state must match the check
+                if flow.read == check.read && flow.write == check.write
+                    # Match, move on to the next check
+                    flow_idx += 1
+                    break
+                else
+                    verbose && println("Expected ($(check.read), $(check.write)), got ($(flow.read), $(flow.write))")
+                    return false
+                end
+            elseif flow.kind == :copy
+                # We need to advance our ainfo location
+                # FIXME: Assert proper data progression (requires more complex tracking of other arguments)
+                #@assert flow.from_space == arg_locations[check.arg_w]
+                arg_locations[check.arg_w] = flow.to_space
+                flow_idx += 1
+            end
+        end
+
+        flow_idxs[check.arg_w] = flow_idx
     end
+
+    return true
 end
 
 @everywhere do_nothing(Xs...) = nothing
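test_dataflow's dominator check reduces to graph reachability: edges point from a task to its syncdeps (i.e. "upward"), so a task is ordered after another exactly when the latter is reachable from the former, which dijkstra_shortest_paths reports through a nonzero pathcounts entry. A standalone sketch of that idiom using only Graphs.jl (a toy three-task chain with hypothetical numbering):

using Graphs

g = SimpleDiGraph(3)
add_edge!(g, 3, 2)  # task 3 syncs on task 2
add_edge!(g, 2, 1)  # task 2 syncs on task 1

state = dijkstra_shortest_paths(g, 3; allpaths=true)
state.pathcounts[1] > 0  # true: task 1 dominates task 3 through the chain

state = dijkstra_shortest_paths(g, 1; allpaths=true)
state.pathcounts[3] > 0  # false: task 3 is unreachable from task 1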
@@ -177,16 +406,15 @@ end
 @everywhere mut_V!(V) = (V .= 1;)
 
 function test_datadeps(;args_chunks::Bool,
                         args_thunks::Bool,
-                        args_loc::Int,
-                        aliasing::Bool)
+                        args_loc::Int)
     # Returns last value
-    @test Dagger.spawn_datadeps(;aliasing) do
+    @test Dagger.spawn_datadeps() do
         42
     end == 42
 
     # Tasks are started and finished as spawn_datadeps returns
     ts = []
-    Dagger.spawn_datadeps(;aliasing) do
+    Dagger.spawn_datadeps() do
         for i in 1:5
             t = Dagger.@spawn sleep(0.1)
             @test !istaskstarted(t)
@@ -195,7 +423,7 @@ function test_datadeps(;args_chunks::Bool,
     @test all(istaskdone, ts)
 
     # Rethrows any task exceptions
-    @test_throws Exception Dagger.spawn_datadeps(;aliasing) do
+    @test_throws Exception Dagger.spawn_datadeps() do
         Dagger.@spawn error("Test")
     end
 
@@ -206,10 +434,13 @@ function test_datadeps(;args_chunks::Bool,
         A = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A)
     end
 
+    @warn "Negative-test the test_dataflow helper"
+
     # Task return values can be tracked
     ts = []
+    local t1
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             t1 = Dagger.@spawn fill(42, 1)
             push!(ts, t1)
             push!(ts, Dagger.@spawn copyto!(Out(A), In(t1)))
@@ -217,273 +448,280 @@ end
     tid_1, tid_2 = task_id.(ts)
     @test fetch(A)[1] == 42.0
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
+    state = build_dataflow(logs)
+
     # FIXME: We don't record the task as a syncdep, but instead internally `fetch` the chunk
-    test_task_dominators(logs, tid_2, [#=tid_1=#]; all_tids=[tid_1, tid_2])
+    # We don't see the :readwrite because we don't see the use of t1
+    #@test test_dataflow(state, FlowCheck(:readwrite, t1))
+    @test test_dataflow(state, FlowCheck(:read, t1), FlowCheck(:write, A))
 
     # R->R Non-Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(In(A)))
             push!(ts, Dagger.@spawn do_nothing(In(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false)
+    state = build_dataflow(logs)
+    test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A))
 
     # R->W Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(In(A)))
             push!(ts, Dagger.@spawn do_nothing(Out(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2])
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A))
 
     # W->W Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(Out(A)))
             push!(ts, Dagger.@spawn do_nothing(Out(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2])
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A))
 
     # R->R Non-Self-Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(In(A), In(A)))
             push!(ts, Dagger.@spawn do_nothing(In(A), In(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2])
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A))
 
     # R->W Self-Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(In(A), In(A)))
             push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2])
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A))
 
     # W->W Self-Aliasing
     ts = []
     logs = with_logs() do
-        Dagger.spawn_datadeps(;aliasing) do
+        Dagger.spawn_datadeps() do
             push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A)))
             push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A)))
         end
     end
     tid_1, tid_2 = task_id.(ts)
-    test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
-    test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2])
-
-    if aliasing
-        function wrap_chunk_thunk(f, args...)
-            if args_thunks || args_chunks
-                result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...)
-                if args_thunks
-                    return result
-                elseif args_chunks
-                    return fetch(result; raw=true)
-                end
-            else
-                # N.B. We don't allocate remotely for raw data
-                return f(args...)
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A))
+
+    function wrap_chunk_thunk(f, args...)
+        if args_thunks || args_chunks
+            result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...)
+            if args_thunks
+                return result
+            elseif args_chunks
+                return fetch(result; raw=true)
+            end
+        else
+            # N.B. We don't allocate remotely for raw data
+            return f(args...)
+        end
-        B = wrap_chunk_thunk(rand, 4, 4)
-
-        # Views
-        B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2)
-        B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4)
-        B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2)
-        B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4)
-        B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3)
-        for (B_name, B_view) in (
-            (:B_ul, B_ul),
-            (:B_ur, B_ur),
-            (:B_ll, B_ll),
-            (:B_lr, B_lr),
-            (:B_mid, B_mid))
-            @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view))
-            B_view === B_mid && continue
-            @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view))
-        end
-        local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid
-        local t_ul2, t_ur2, t_ll2, t_lr2
-        logs = with_logs() do
-            Dagger.spawn_datadeps(;aliasing) do
-                t_A = Dagger.@spawn do_nothing(InOut(A))
-                t_B = Dagger.@spawn do_nothing(InOut(B))
-                t_ul = Dagger.@spawn do_nothing(InOut(B_ul))
-                t_ur = Dagger.@spawn do_nothing(InOut(B_ur))
-                t_ll = Dagger.@spawn do_nothing(InOut(B_ll))
-                t_lr = Dagger.@spawn do_nothing(InOut(B_lr))
-                t_mid = Dagger.@spawn do_nothing(InOut(B_mid))
-                t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul))
-                t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur))
-                t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll))
-                t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr))
-            end
+    end
+    B = wrap_chunk_thunk(rand, 4, 4)
+
+    # Views
+    B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2)
+    B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4)
+    B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2)
+    B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4)
+    B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3)
+    for (B_name, B_view) in (
+        (:B_ul, B_ul),
+        (:B_ur, B_ur),
+        (:B_ll, B_ll),
+        (:B_lr, B_lr),
+        (:B_mid, B_mid))
+        @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view))
+        B_view === B_mid && continue
+        @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view))
+    end
+    local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid
+    local t_ul2, t_ur2, t_ll2, t_lr2
+    logs = with_logs() do
+        Dagger.spawn_datadeps() do
+            t_A = Dagger.@spawn do_nothing(InOut(A))
+            t_B = Dagger.@spawn do_nothing(InOut(B))
+            t_ul = Dagger.@spawn do_nothing(InOut(B_ul))
+            t_ur = Dagger.@spawn do_nothing(InOut(B_ur))
+            t_ll = Dagger.@spawn do_nothing(InOut(B_ll))
+            t_lr = Dagger.@spawn do_nothing(InOut(B_lr))
+            t_mid = Dagger.@spawn do_nothing(InOut(B_mid))
+            t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul))
+            t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur))
+            t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll))
+            t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr))
         end
-        tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid =
-            task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid])
-        tid_ul2, tid_ur2, tid_ll2, tid_lr2 =
-            task_id.([t_ul2, t_ur2, t_ll2, t_lr2])
-        tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid,
-                    tid_ul2, tid_ur2, tid_ll2, tid_lr2]
-        test_task_dominators(logs, tid_A, []; all_tids=tids_all)
-        test_task_dominators(logs, tid_B, []; all_tids=tids_all)
-        test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all)
-        test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all)
-        test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all)
-        test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all)
-        test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all)
-        test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false)
-
-        # (Unit)Upper/LowerTriangular and Diagonal
-        B_upper = wrap_chunk_thunk(UpperTriangular, B)
-        B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B)
-        B_lower = wrap_chunk_thunk(LowerTriangular, B)
-        B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B)
-        for (B_name, B_view) in (
-            (:B_upper, B_upper),
-            (:B_unitupper, B_unitupper),
-            (:B_lower, B_lower),
-            (:B_unitlower, B_unitlower))
-            @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view))
-        end
-        @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower))
-        @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower))
-        @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper))
-        @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower))
-
-        @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal))
-        @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal))
-        @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal))
-        @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal))
-
-        local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag
-        local t_upper2, t_unitupper2, t_lower2, t_unitlower2
-        logs = with_logs() do
-            Dagger.spawn_datadeps(;aliasing) do
-                t_A = Dagger.@spawn do_nothing(InOut(A))
-                t_B = Dagger.@spawn do_nothing(InOut(B))
-                t_upper = Dagger.@spawn do_nothing(InOut(B_upper))
-                t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper))
-                t_lower = Dagger.@spawn do_nothing(InOut(B_lower))
-                t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower))
-                t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal)))
-                t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower))
-                t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower))
-                t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper))
-                t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper))
-            end
+    end
+    tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid =
+        task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid])
+    tid_ul2, tid_ur2, tid_ll2, tid_lr2 =
+        task_id.([t_ul2, t_ur2, t_ll2, t_lr2])
+    tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid,
+                tid_ul2, tid_ur2, tid_ll2, tid_lr2]
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:readwrite, A))
+    @test test_dataflow(state, FlowCheck(:readwrite, B))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ul))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ur))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ll))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lr))
+    for arg in [B_ul, B_ur, B_ll, B_lr]
+        @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, arg), FlowCheck(:readwrite, B_mid), FlowCheck(:readwrite, arg))
+    end
+
+    # (Unit)Upper/LowerTriangular and Diagonal
+    B_upper = wrap_chunk_thunk(UpperTriangular, B)
+    B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B)
+    B_lower = wrap_chunk_thunk(LowerTriangular, B)
+    B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B)
+    for (B_name, B_view) in (
+        (:B_upper, B_upper),
+        (:B_unitupper, B_unitupper),
+        (:B_lower, B_lower),
+        (:B_unitlower, B_unitlower))
+        @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view))
+    end
+    @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower))
+    @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower))
+    @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper))
+    @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower))
+
+    @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal))
+    @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal))
+    @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal))
+    @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal))
+
+    local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag
+    local t_upper2, t_unitupper2, t_lower2, t_unitlower2
+    logs = with_logs() do
+        Dagger.spawn_datadeps() do
+            t_A = Dagger.@spawn do_nothing(InOut(A))
+            t_B = Dagger.@spawn do_nothing(InOut(B))
+            t_upper = Dagger.@spawn do_nothing(InOut(B_upper))
+            t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper))
+            t_lower = Dagger.@spawn do_nothing(InOut(B_lower))
+            t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower))
+            t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal)))
+            t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower))
+            t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower))
+            t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper))
+            t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper))
        end
-        tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag =
-            task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag])
-        tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 =
-            task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2])
-        tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag,
-                    tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2]
-        test_task_dominators(logs, tid_A, []; all_tids=tids_all)
-        test_task_dominators(logs, tid_B, []; all_tids=tids_all)
-        # FIXME: Proper non-dominance checks
-        test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false)
-        test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false)
-
-        # Additional aliasing tests
-        views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y))
-
-        A = wrap_chunk_thunk(identity, B)
-
-        A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4)
-        A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4)
-        B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4)
-        B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4)
-
-        A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1)
-        A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2)
-        B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1)
-        B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2)
-
-        A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3)
-        B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3)
-
-        @test views_overlap(A_r1, A_r1)
-        @test views_overlap(B_r1, B_r1)
-        @test views_overlap(A_c1, A_c1)
-        @test views_overlap(B_c1, B_c1)
-
-        @test views_overlap(A_r1, B_r1)
-        @test views_overlap(A_r2, B_r2)
-        @test views_overlap(A_c1, B_c1)
-        @test views_overlap(A_c2, B_c2)
-
-        @test !views_overlap(A_r1, A_r2)
-        @test !views_overlap(B_r1, B_r2)
-        @test !views_overlap(A_c1, A_c2)
-        @test !views_overlap(B_c1, B_c2)
-
-        @test views_overlap(A_r1, A_c1)
-        @test views_overlap(A_r1, B_c1)
-        @test views_overlap(A_r2, A_c2)
-        @test views_overlap(A_r2, B_c2)
-
-        for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid))
-            @test !views_overlap(A_r1, mid)
-            @test !views_overlap(B_r1, mid)
-            @test !views_overlap(A_c1, mid)
-            @test !views_overlap(B_c1, mid)
-
-            @test views_overlap(A_r2, mid)
-            @test views_overlap(B_r2, mid)
-            @test views_overlap(A_c2, mid)
-            @test views_overlap(B_c2, mid)
-        end
-
-        @test views_overlap(A_mid, A_mid)
-        @test views_overlap(A_mid, B_mid)
-
-        # SubArray hashing
-        V = zeros(3)
-        Dagger.spawn_datadeps(;aliasing) do
-            Dagger.@spawn mut_V!(InOut(view(V, 1:2)))
-            Dagger.@spawn mut_V!(InOut(view(V, 2:3)))
-        end
-        @test fetch(V) == [1, 1, 1]
-    end
+    tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag =
+        task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag])
+    tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 =
+        task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2])
+    tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag,
+                tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2]
+    state = build_dataflow(logs)
+    @test test_dataflow(state, FlowCheck(:readwrite, A))
+    @test test_dataflow(state, FlowCheck(:readwrite, B))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower),
+                        FlowCheck(:readwrite, B, Diagonal))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower),
+                        FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower),
+                        FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower), FlowCheck(:readwrite, B_lower))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper),
+                        FlowCheck(:readwrite, B_unitupper))
+    @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper),
+                        FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitupper), FlowCheck(:readwrite, B_upper))
+
+    # Additional aliasing tests
+    views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y))
+
+    A = wrap_chunk_thunk(identity, B)
+
+    A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4)
+    A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4)
+    B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4)
+    B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4)
+
+    A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1)
+    A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2)
+    B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1)
+    B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2)
+
+    A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3)
+    B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3)
+
+    @test views_overlap(A_r1, A_r1)
+    @test views_overlap(B_r1, B_r1)
+    @test views_overlap(A_c1, A_c1)
+    @test views_overlap(B_c1, B_c1)
+
+    @test views_overlap(A_r1, B_r1)
+    @test views_overlap(A_r2, B_r2)
+    @test views_overlap(A_c1, B_c1)
+    @test views_overlap(A_c2, B_c2)
+
+    @test !views_overlap(A_r1, A_r2)
+    @test !views_overlap(B_r1, B_r2)
+    @test !views_overlap(A_c1, A_c2)
+    @test !views_overlap(B_c1, B_c2)
+
+    @test views_overlap(A_r1, A_c1)
+    @test views_overlap(A_r1, B_c1)
+    @test views_overlap(A_r2, A_c2)
+    @test views_overlap(A_r2, B_c2)
+
+    for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid))
+        @test !views_overlap(A_r1, mid)
+        @test !views_overlap(B_r1, mid)
+        @test !views_overlap(A_c1, mid)
+        @test !views_overlap(B_c1, mid)
+
+        @test views_overlap(A_r2, mid)
+        @test views_overlap(B_r2, mid)
+        @test views_overlap(A_c2, mid)
+        @test views_overlap(B_c2, mid)
+    end
+
+    @test views_overlap(A_mid, A_mid)
+    @test views_overlap(A_mid, B_mid)
+
+    # SubArray hashing
+    V = zeros(3)
+    Dagger.spawn_datadeps() do
+        Dagger.@spawn mut_V!(InOut(view(V, 1:2)))
+        Dagger.@spawn mut_V!(InOut(view(V, 2:3)))
+    end
+    @test fetch(V) == [1, 1, 1]
 
     # FIXME: Deps
 
     # Outer Scope
-    exec_procs = fetch.(Dagger.spawn_datadeps(;aliasing) do
+    exec_procs = fetch.(Dagger.spawn_datadeps() do
         [Dagger.@spawn Dagger.task_processor() for i in 1:10]
     end)
     unique!(exec_procs)
@@ -499,7 +737,7 @@ end
 
     # Inner Scope
-    @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do
+    @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do
         Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1
     end
@@ -528,7 +766,7 @@
         C = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(C)
         D = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(D)
     end
-    Dagger.spawn_datadeps(;aliasing) do
+    Dagger.spawn_datadeps() do
         Dagger.@spawn add!(InOut(B), In(A))
         Dagger.@spawn add!(InOut(C), In(A))
         Dagger.@spawn add!(InOut(C), In(B))
@@ -545,7 +783,7 @@
     elseif args_thunks
         As = map(A->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A)), As)
     end
-    Dagger.spawn_datadeps(;aliasing) do
+    Dagger.spawn_datadeps() do
         to_reduce = Vector[]
        push!(to_reduce, As)
         while !isempty(to_reduce)
@@ -576,7 +814,7 @@
     elseif args_thunks
         M = map(m->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(m)), M)
     end
-    Dagger.spawn_datadeps(;aliasing) do
+    Dagger.spawn_datadeps() do
         for k in range(1, mt)
             Dagger.@spawn LAPACK.potrf!('L', InOut(M[k, k]))
             for _m in range(k+1, mt)
@@ -596,18 +834,16 @@
     @test isapprox(M_dense, expected)
 end
 
-@testset "$(aliasing ? "With" : "Without") Aliasing Support" for aliasing in (true, false)
-    @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk)
-        args_chunks = args_mode == :Chunk
-        args_thunks = args_mode == :Thunk
-        for nw in (1, 2)
-            args_loc = nw == 2 ? 2 : 1
-            for nt in (1, 2)
-                if nprocs() >= nw && Threads.nthreads() >= nt
-                    @testset "$nw Workers, $nt Threads" begin
-                        Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do
-                            test_datadeps(;args_chunks, args_thunks, args_loc, aliasing)
-                        end
+@testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk)
+    args_chunks = args_mode == :Chunk
+    args_thunks = args_mode == :Thunk
+    for nw in (1, 2)
+        args_loc = nw == 2 ? 2 : 1
+        for nt in (1, 2)
+            if nprocs() >= nw && Threads.nthreads() >= nt
+                @testset "$nw Workers, $nt Threads" begin
+                    Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do
+                        test_datadeps(;args_chunks, args_thunks, args_loc)
                     end
                 end
            end
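The views_overlap helper is just Dagger.will_alias applied to the two arguments' aliasing metadata, so the row/column assertions above can be reproduced directly at the REPL. A hypothetical session mirroring the A_r1/A_c1/A_r2 cases from the tests:

using Dagger

A = rand(4, 4)
r1 = view(A, 1:1, 1:4)  # row 1
r2 = view(A, 2:2, 1:4)  # row 2
c1 = view(A, 1:4, 1:1)  # column 1

Dagger.will_alias(Dagger.aliasing(r1), Dagger.aliasing(c1))  # true: both contain A[1,1]
Dagger.will_alias(Dagger.aliasing(r1), Dagger.aliasing(r2))  # false: disjoint rows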
diff --git a/test/scopes.jl b/test/scopes.jl
index fa5bf1135..55d15b349 100644
--- a/test/scopes.jl
+++ b/test/scopes.jl
@@ -123,8 +123,8 @@
     us_es1_multi_ch = Dagger.tochunk(nothing, OSProc(), UnionScope(es1, es1))
     @test fetch(Dagger.@spawn exact_scope_test(us_es1_multi_ch)) == es1.processor
 
-    # No inner scopes
-    @test UnionScope() isa UnionScope
+    # No inner scopes (disallowed)
+    @test_throws ArgumentError UnionScope()
 
     # Same inner scope
     @test fetch(Dagger.@spawn exact_scope_test(us_es1_ch, us_es1_ch)) == es1.processor
@@ -165,7 +165,7 @@
     @test Dagger.scope(:any) isa AnyScope
     @test Dagger.scope(:default) == DefaultScope()
     @test_throws ArgumentError Dagger.scope(:blah)
-    @test Dagger.scope(()) == UnionScope()
+    @test_throws ArgumentError Dagger.scope(())
 
     @test Dagger.scope(worker=wid1) == Dagger.scope(workers=[wid1])
 
diff --git a/test/task-affinity.jl b/test/task-affinity.jl
index f1e26295a..ce898b476 100644
--- a/test/task-affinity.jl
+++ b/test/task-affinity.jl
@@ -135,7 +135,7 @@
     @testset "Chunk function, scope, compute_scope and result_scope" begin
         @everywhere g(x, y) = x * 2 + y * 3
 
-        n = cld(numscopes, 3)
+        n = fld(numscopes, 3)
         shuffle!(availscopes)
 
         scope_a = availscopes[1:n]
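The cld -> fld change in task-affinity.jl matters whenever numscopes is not a multiple of 3: with ceiling division the three length-n slices of availscopes (presumably availscopes[1:n], [n+1:2n], [2n+1:3n]) can reach past the end of the vector, while floor division guarantees 3n <= numscopes. A quick arithmetic check with a hypothetical numscopes = 7:

cld(7, 3)  # 3 -> three slices would need 9 elements, out of bounds for 7
fld(7, 3)  # 2 -> three disjoint length-2 slices fit within 7 elements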