diff --git a/.gitignore b/.gitignore index f19646c..62d4c14 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /docs/build/ Manifest.toml -test/Manifest.toml \ No newline at end of file +test/Manifest.toml +.vscode \ No newline at end of file diff --git a/Project.toml b/Project.toml index 4a30120..d365335 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["lorenzoh "] version = "0.1.2" [deps] +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" diff --git a/src/DataLoaders.jl b/src/DataLoaders.jl index e60806a..588bc89 100644 --- a/src/DataLoaders.jl +++ b/src/DataLoaders.jl @@ -1,5 +1,7 @@ module DataLoaders +using Base.Threads +using Distributed using MLDataPattern using ThreadPools using LearnBase @@ -43,8 +45,10 @@ for PyTorch users. - `parallel::Bool = Threads.nthreads() > 1)`: Whether to load data in parallel, keeping the primary thread is. Default is `true` if more than one thread is available. +- `usethreads::Bool = true`: Whether to use threads or processes - `useprimary::Bool = false`: If `false`, keep the main thread free when loading data in parallel. Is ignored if `parallel` is `false`. +- `maxquesize = nothing`: Maximum size of caching queue. ## Examples @@ -57,15 +61,23 @@ function DataLoader( collate = !isnothing(batchsize), buffered = collate, partial = true, - useprimary = Threads.nthreads() == 1, + usethreads = true, + useprimary = usethreads ? 
nthreads() == 1 : nprocs() == 1, + maxquesize = nothing, ) - Threads.nthreads() > 1 || useprimary || error( + !usethreads || nthreads() > 1 || useprimary || error( "Julia is running with one thread only, either pass `useprimary = true` or " * "start Julia with multiple threads by passing " * "the `-t n` option or setting the `JULIA_NUM_THREADS` " * "environment variable before starting Julia.") + usethreads || nprocs() > 1 || useprimary || error( + "Julia is running with one process only, either pass `useprimary = true` or " * + "start Julia with multiple processes by passing " * + "the `-p n` option " * + "before starting Julia.") + batchwrapper = if isnothing(batchsize) identity elseif collate @@ -75,7 +87,7 @@ function DataLoader( data -> batchview(data, size = batchsize) end - loadwrapper = data -> eachobsparallel(data; useprimary = useprimary, buffered = buffered) + loadwrapper = data -> eachobsparallel(data; usethreads = usethreads, useprimary = useprimary, buffered = buffered, maxquesize = maxquesize) return loadwrapper(batchwrapper(data)) end diff --git a/src/loaders.jl b/src/loaders.jl index 6684f33..9e1d44a 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -2,11 +2,19 @@ struct GetObsParallel{TData} data::TData + usethreads::Bool useprimary::Bool - function GetObsParallel(data::TData; useprimary = false) where {TData} - (useprimary || Threads.nthreads() > 1) || - error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") - return new{TData}(data, useprimary) + maxquesize::Int + function GetObsParallel(data::TData; usethreads = true, useprimary = false, maxquesize = nothing) where {TData} + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. 
Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") + end + maxquesize = something(maxquesize, usethreads ? nthreads() : nworkers()) + return new{TData}(data, usethreads, useprimary, maxquesize) end end @@ -14,29 +22,35 @@ end Base.length(iterparallel::GetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::GetObsParallel) - resultschannel = Channel(Threads.nthreads() - Int(!iterparallel.useprimary)) + resultschannel = if iterparallel.usethreads + Channel(iterparallel.maxquesize) + else + RemoteChannel(() -> Channel(iterparallel.maxquesize)) + end workerpool = - WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx + WorkerPool(1:nobs(iterparallel.data), usethreads = iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx put!(resultschannel, getobs(iterparallel.data, idx)) end - @async run(workerpool) + task = @async run(workerpool) - return iterate(iterparallel, (resultschannel, workerpool, 0)) + return iterate(iterparallel, (task, resultschannel, workerpool, 0)) end function Base.iterate(iterparallel::GetObsParallel, state) - resultschannel, workerpool, index = state + task, resultschannel, workerpool, index = state # Worker pool failed - if workerpool.state === Failed + if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) + close(resultschannel) + wait(task) return nothing else - return take!(resultschannel), (resultschannel, workerpool, index + 1) + return take!(resultschannel), (task, resultschannel, workerpool, index + 1) end end @@ -44,7 +58,7 @@ end # Buffered version """ - BufferGetObsParallel(data; useprimary = false) + BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing) Like 
`MLDataPattern.BufferGetObs` but preloads observations into a buffer ring with multi-threaded workers. @@ -52,23 +66,30 @@ buffer ring with multi-threaded workers. struct BufferGetObsParallel{TElem,TData} data::TData buffers::Vector{TElem} + usethreads::Bool useprimary::Bool + maxquesize::Int end Base.show(io::IO, bufparallel::BufferGetObsParallel) = print(io, "eachobsparallel($(bufparallel.data))") -function BufferGetObsParallel(data; useprimary = false) - nthreads = Threads.nthreads() - Int(!useprimary) - nthreads > 0 || - error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") +function BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing) + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") + end buffer = getobs(data, 1) buffers = [buffer] - for _ ∈ 1:nthreads + maxquesize = something(maxquesize, usethreads ? 
nthreads() : nworkers()) + for _ ∈ 1:maxquesize push!(buffers, deepcopy(buffer)) end - return BufferGetObsParallel(data, buffers, useprimary) + return BufferGetObsParallel(data, buffers, usethreads, useprimary, maxquesize) end @@ -76,31 +97,40 @@ Base.length(iterparallel::BufferGetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::BufferGetObsParallel) - ringbuffer = RingBuffer(iterparallel.buffers) - - workerpool = - WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx - put!(ringbuffer) do buf - getobs!(buf, iterparallel.data, idx) + if iterparallel.usethreads + resultschannel = RingBuffer(iterparallel.buffers) + workerpool = + WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx + put!(resultschannel) do buf + getobs!(buf, iterparallel.data, idx) + end end - end - @async run(workerpool) + else + resultschannel = RemoteChannel(() -> Channel(iterparallel.maxquesize)) + workerpool = + WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx + put!(resultschannel, getobs(iterparallel.data, idx)) + end + end + task = @async run(workerpool) - return iterate(iterparallel, (ringbuffer, workerpool, 0)) + return iterate(iterparallel, (task, resultschannel, workerpool, 0)) end function Base.iterate(iterparallel::BufferGetObsParallel, state) - ringbuffer, workerpool, index = state + task, resultschannel, workerpool, index = state # Worker pool failed - if workerpool.state === Failed + if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) + close(resultschannel) + wait(task) return nothing else - return take!(ringbuffer), (ringbuffer, workerpool, index + 1) + return take!(resultschannel), (task, resultschannel, workerpool, index + 1) end end @@ -108,7 +138,7 @@ end # functional interface """ - eachobsparallel(data; useprimary = false, buffered = true) + 
eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) Parallel data iterator for data container `data`. Loads data on all available threads (except the first if `useprimary` is `false`). @@ -124,6 +154,6 @@ See also `MLDataPattern.eachobs` are returned in the correct order. """ -eachobsparallel(data; useprimary = false, buffered = true) = - buffered ? BufferGetObsParallel(data, useprimary = useprimary) : - GetObsParallel(data, useprimary = useprimary) +eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) = + buffered ? BufferGetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) : + GetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) diff --git a/src/ringbuffer.jl b/src/ringbuffer.jl index a5ef4a2..5014f98 100644 --- a/src/ringbuffer.jl +++ b/src/ringbuffer.jl @@ -82,6 +82,9 @@ function Base.put!(f!, ringbuffer::RingBuffer) put!(ringbuffer.results, buf_) end +function Base.isopen(ringbuffer::RingBuffer) + isopen(ringbuffer.results) && isopen(ringbuffer.buffers) +end function Base.close(ringbuffer::RingBuffer) close(ringbuffer.results) diff --git a/src/workerpool.jl b/src/workerpool.jl index be21af4..c8ea8e5 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -1,3 +1,25 @@ +import Base: put!, wait, isready, take!, fetch + +mutable struct ValueChannel{T} <: AbstractChannel{T} + v::T + cond_take::Condition # waiting for data to become available + ValueChannel(v) = new{typeof(v)}(v, Condition()) +end + +function put!(c::ValueChannel, v) + c.v = v + notify(c.cond_take) + return c +end + +take!(c::ValueChannel) = fetch(c) + +isready(c::ValueChannel) = true + +fetch(c::ValueChannel) = c.v + +wait(c::ValueChannel) = wait(c.cond_take) + @enum PoolState begin Done Running @@ -11,65 +33,82 @@ end mutable struct WorkerPool{TArgs} workerfn::Any args::Vector{TArgs} + usethreads::Bool 
useprimary::Bool - state::PoolState - ntasks::Threads.Atomic{Int} + state::ValueChannel{PoolState} end -function WorkerPool(workerfn, args::AbstractVector{TArgs}; useprimary = false) where {TArgs} - (useprimary || Threads.nthreads() > 1) || - error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with >1 threads.") - return WorkerPool{TArgs}(workerfn, collect(args), useprimary, Done, Threads.Atomic{Int}(0)) +function WorkerPool(workerfn, args::AbstractVector{TArgs}; usethreads = true, useprimary = false) where {TArgs} + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") + end + return WorkerPool{TArgs}(workerfn, collect(args), usethreads, useprimary, ValueChannel(Done)) end -function run(pool::WorkerPool) - pool.state = Running - # set remaining tasks counter. - pool.ntasks[] = length(pool.args) +function run(pool::WorkerPool{TArgs}) where TArgs + @unpack workerfn, usethreads, useprimary, state = pool + put!(state, Running) # watchdog that sends exception to main thread if a worker fails maintask = current_task() - @async begin - while pool.state !== Done - if pool.state === Failed + watchdog = @async begin + while fetch(state) !== Done + if fetch(state) === Failed Base.throwto( maintask, - PoolFailedException("Failed to process all tasks. $(pool.ntasks[]) unfinished tasks remaining"), + PoolFailedException("Failed to process all tasks. 
$(length(tasks)) unfinished tasks remaining"), ) end sleep(0.1) end end - - function inloop(args) - #for args in pool.args # uncomment for debugging - # task error handling - pool.state !== Failed || error("Shutting down worker $(Threads.threadid())") - try - # execute task - pool.workerfn(args...) - Threads.atomic_add!(pool.ntasks, -1) - catch e - println(stacktrace()) - @error "Exception while executing task on worker $(Threads.threadid()). Shutting down WorkerPool." e = - e stacktrace = stacktrace() args = args - pool.state = Failed - rethrow() - end - end - - if pool.useprimary - @qthreads for args in pool.args - inloop(args) + + if usethreads + (useprimary ? qforeach : qbforeach)(pool.args) do args + inloop(state, workerfn, threadid(), args) end else - @qbthreads for args in pool.args - inloop(args) + tasks = Channel{TArgs}(Inf) + foreach(a -> put!(tasks, a), pool.args) + close(tasks) + remote_state = RemoteChannel(() -> state) + remote_tasks = RemoteChannel(() -> tasks) + @sync for id in (useprimary ? procs() : workers()) + @spawnat id on_worker(remote_tasks, remote_state, workerfn, usethreads, useprimary) end end # Tasks completed successfully - pool.state = Done + put!(state, Done) + wait(watchdog) end + +function inloop(state, workerfn, id, args) + try + # execute task + workerfn(args...) + catch e + display(stacktrace()) + @error "Exception while executing task on worker $id. Shutting down WorkerPool." e = + e stacktrace = stacktrace() args = args + put!(state, Failed) + rethrow() + end +end + +function on_worker(tasks, state, workerfn, usethreads, useprimary) + # task error handling + id = usethreads ? 
threadid() : myid() + !useprimary && id == 1 && return + while isready(tasks) + args = try take!(tasks) catch e break end + fetch(state) !== Failed || error("Shutting down worker $id") + inloop(state, workerfn, id, args) + end +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 1d2c169..4ff1307 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index 13777a5..fb7598a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,9 @@ +using Distributed + +addprocs(2) + +@everywhere begin + using Test using TestSetExtensions using DataLoaders @@ -21,6 +27,7 @@ LearnBase.getobs!(buf, ds::MockDataset, idx::Int) = ds.inplace ? fill!(buf, 0.3) : getobs(ds, idx) LearnBase.nobs(ds::MockDataset) = ds.n +end @testset ExtendedTestSet "collate" begin @test collate([([1, 2], 3), ([4, 5], 6)]) == ([1 4; 2 5], [3, 6]) @@ -135,7 +142,7 @@ end @testset ExtendedTestSet "iterate" begin dl = make() - x, (results, workerpool, idx) = iterate(dl) + x, (task, results, workerpool, idx) = iterate(dl) @test idx == 1 @test_nowarn for obs in dl end @@ -148,7 +155,7 @@ end @testset ExtendedTestSet "iterate" begin dl = make() - x, (ringbuffer, workerpool, idx) = iterate(dl) + x, (task, ringbuffer, workerpool, idx) = iterate(dl) @test idx == 1 @test_nowarn for obs in dl end @@ -174,33 +181,46 @@ end end +for usethreads in (true, false), maxquesize in (nothing, 4, 8) -@testset ExtendedTestSet "DataLoader" begin +@testset ExtendedTestSet "DataLoader: usethreads=$usethreads, maxquesize=$maxquesize" begin data = MockDataset(256, (10, 5), true) bs = 8 - + @testset ExtendedTestSet "buffer, collate, parallel" begin - dl = DataLoader(data, bs) + dl = DataLoader(data, bs, usethreads = 
usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "buffer, collate, parallel, samples" begin - dl = DataLoader(data, nothing) + dl = DataLoader(data, nothing, usethreads = usethreads, maxquesize = maxquesize) + @test_nowarn for batch in dl end + end + + @testset ExtendedTestSet "buffer, collate, parallel, samples, distributed" begin + dl = DataLoader(data, nothing, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "collate, parallel" begin - dl = DataLoader(data, bs, buffered = false) + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "collate" begin - dl = DataLoader(data, bs, buffered = false) + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads) + @test_nowarn for batch in dl end + end + + @testset ExtendedTestSet "collate, distributed" begin + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads) @test_nowarn for batch in dl end end @testset ExtendedTestSet "buffer, collate" begin - dl = DataLoader(data, bs, buffered = true) + dl = DataLoader(data, bs, buffered = true, usethreads = usethreads) @test_nowarn for batch in dl end end end + +end \ No newline at end of file