From fcc883b98bb3bc4693421f3d81c89acd7148f1b4 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Tue, 22 Dec 2020 23:00:37 +0800 Subject: [PATCH 1/9] distributed --- .gitignore | 3 +- Project.toml | 1 + src/DataLoaders.jl | 19 ++++++-- src/loaders.jl | 32 ++++++++----- src/workerpool.jl | 111 ++++++++++++++++++++++++++++++--------------- 5 files changed, 114 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index f19646c..62d4c14 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /docs/build/ Manifest.toml -test/Manifest.toml \ No newline at end of file +test/Manifest.toml +.vscode \ No newline at end of file diff --git a/Project.toml b/Project.toml index 4a30120..d365335 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["lorenzoh "] version = "0.1.2" [deps] +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" diff --git a/src/DataLoaders.jl b/src/DataLoaders.jl index e60806a..f286492 100644 --- a/src/DataLoaders.jl +++ b/src/DataLoaders.jl @@ -1,5 +1,7 @@ module DataLoaders +using Base.Threads +using Distributed using MLDataPattern using ThreadPools using LearnBase @@ -57,15 +59,26 @@ function DataLoader( collate = !isnothing(batchsize), buffered = collate, partial = true, - useprimary = Threads.nthreads() == 1, + usethreads = true, + useprimary = usethreads ? nthreads() == 1 : nprocs() == 1, ) - Threads.nthreads() > 1 || useprimary || error( + !usethreads || nthreads() > 1 || useprimary || error( "Julia is running with one thread only, either pass `useprimary = true` or " * "start Julia with multiple threads by passing " * "the `-t n` option or setting the `JULIA_NUM_THREADS` " * "environment variable before starting Julia.") + usethreads || Distributed.nprocs() > 1 || useprimary || error( + "Julia is running with one procs only, either pass `useprimary = true` or " * + "start Julia with multiple threads by passing " * + "the `-p n` option " * + "environment variable before starting Julia.") + + usethreads || !buffered || error( + "Buffered loading is not compatible with `usethreads = false`" + ) + batchwrapper = if isnothing(batchsize) identity elseif collate @@ -75,7 +88,7 @@ function DataLoader( data -> batchview(data, size = batchsize) end - loadwrapper = data -> eachobsparallel(data; useprimary = useprimary, buffered = buffered) + loadwrapper = data -> eachobsparallel(data; usethreads = usethreads, useprimary = useprimary, buffered = buffered) return loadwrapper(batchwrapper(data)) end diff --git a/src/loaders.jl b/src/loaders.jl index 6684f33..6a4da37 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -2,11 +2,17 @@ struct GetObsParallel{TData} data::TData + usethreads::Bool useprimary::Bool - function GetObsParallel(data::TData; useprimary = false) where {TData} - (useprimary || Threads.nthreads() > 1) || - error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") - return new{TData}(data, useprimary) + function GetObsParallel(data::TData; usethreads = true, useprimary = false) where {TData} + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. 
Pass `useprimary = true` or start Julia with > 1 processes.") + end + return new{TData}(data, usethreads, useprimary) end end @@ -14,10 +20,14 @@ end Base.length(iterparallel::GetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::GetObsParallel) - resultschannel = Channel(Threads.nthreads() - Int(!iterparallel.useprimary)) + resultschannel = if iterparallel.usethreads + Channel(nthreads() - Int(!iterparallel.useprimary)) + else + RemoteChannel(Distributed.nprocs() - Int(!iterparallel.useprimary)) + end workerpool = - WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx + WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx put!(resultschannel, getobs(iterparallel.data, idx)) end @async run(workerpool) @@ -30,7 +40,7 @@ function Base.iterate(iterparallel::GetObsParallel, state) resultschannel, workerpool, index = state # Worker pool failed - if workerpool.state === Failed + if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) @@ -58,7 +68,7 @@ end Base.show(io::IO, bufparallel::BufferGetObsParallel) = print(io, "eachobsparallel($(bufparallel.data))") function BufferGetObsParallel(data; useprimary = false) - nthreads = Threads.nthreads() - Int(!useprimary) + nthreads = nthreads() - Int(!useprimary) nthreads > 0 || error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") @@ -94,7 +104,7 @@ function Base.iterate(iterparallel::BufferGetObsParallel, state) ringbuffer, workerpool, index = state # Worker pool failed - if workerpool.state === Failed + if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) @@ -124,6 +134,6 @@ See also `MLDataPattern.eachobs` are returned in the correct order. """ -eachobsparallel(data; useprimary = false, buffered = true) = +eachobsparallel(data; usethreads = true, useprimary = false, buffered = true) = buffered ? BufferGetObsParallel(data, useprimary = useprimary) : - GetObsParallel(data, useprimary = useprimary) + GetObsParallel(data, usethreads = usethreads, useprimary = useprimary) diff --git a/src/workerpool.jl b/src/workerpool.jl index be21af4..7ed7010 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -1,3 +1,25 @@ +import Base: put!, wait, isready, take!, fetch + +mutable struct ValueChannel{T} <: AbstractChannel{T} + v::T + cond_take::Condition # waiting for data to become available + ValueChannel(v) = new{typeof(v)}(v, Condition()) +end + +function put!(c::ValueChannel, v) + c.v = v + notify(c.cond_take) + return c +end + +take!(c::ValueChannel) = fetch(c) + +isready(c::ValueChannel) = true + +fetch(c::ValueChannel) = c.v + +wait(c::ValueChannel) = wait(c.cond_take) + @enum PoolState begin Done Running @@ -11,65 +33,80 @@ end mutable struct WorkerPool{TArgs} workerfn::Any args::Vector{TArgs} + usethreads::Bool useprimary::Bool - state::PoolState - ntasks::Threads.Atomic{Int} + state::ValueChannel{PoolState} end -function WorkerPool(workerfn, args::AbstractVector{TArgs}; useprimary = false) where {TArgs} - (useprimary || Threads.nthreads() > 1) || - error("Cannot load data off main thread with only one thread available. 
Pass `useprimary = true` or start Julia with >1 threads.") - return WorkerPool{TArgs}(workerfn, collect(args), useprimary, Done, Threads.Atomic{Int}(0)) +function WorkerPool(workerfn, args::AbstractVector{TArgs}; usethreads = true, useprimary = false) where {TArgs} + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") + end + return WorkerPool{TArgs}(workerfn, collect(args), usethreads, useprimary, ValueChannel(Done)) end -function run(pool::WorkerPool) - pool.state = Running - # set remaining tasks counter. - pool.ntasks[] = length(pool.args) +function run(pool::WorkerPool{TArgs}) where TArgs + workerfn = pool.workerfn + tasks = Channel{TArgs}(length(pool.args)) + foreach(a -> put!(tasks, a), pool.args) + usethreads = pool.usethreads + useprimary = pool.useprimary + state = pool.state + put!(state, Running) # watchdog that sends exception to main thread if a worker fails maintask = current_task() @async begin - while pool.state !== Done - if pool.state === Failed + while fetch(state) !== Done + if fetch(state) === Failed Base.throwto( maintask, - PoolFailedException("Failed to process all tasks. $(pool.ntasks[]) unfinished tasks remaining"), + PoolFailedException("Failed to process all tasks. $(length(tasks)) unfinished tasks remaining"), ) end sleep(0.1) end end - - function inloop(args) - #for args in pool.args # uncomment for debugging - # task error handling - pool.state !== Failed || error("Shutting down worker $(Threads.threadid())") - try - # execute task - pool.workerfn(args...) - Threads.atomic_add!(pool.ntasks, -1) - catch e - println(stacktrace()) - @error "Exception while executing task on worker $(Threads.threadid()). Shutting down WorkerPool." e = - e stacktrace = stacktrace() args = args - pool.state = Failed - rethrow() - end - end - - if pool.useprimary - @qthreads for args in pool.args - inloop(args) + + if usethreads + @sync for id in (useprimary ? 1 : 2):nthreads() + Threads.@spawn on_worker(tasks, state, workerfn, usethreads, useprimary) end else - @qbthreads for args in pool.args - inloop(args) + remote_state = RemoteChannel(() -> state) + remote_tasks = RemoteChannel(() -> tasks) + @sync for id in (useprimary ? procs() : workers()) + @spawnat id on_worker(remote_tasks, remote_state, workerfn, usethreads, useprimary) end end # Tasks completed successfully - pool.state = Done + put!(pool.state, Done) end + + +function on_worker(tasks, state, workerfn, usethreads, useprimary) + # task error handling + id = usethreads ? threadid() : myid() + !useprimary && id == 1 && return + while isready(tasks) + fetch(state) !== Failed || error("Shutting down worker $id") + args = take!(tasks) + try + # execute task + workerfn(args...) + catch e + display(stacktrace()) + @error "Exception while executing task on worker $id. Shutting down WorkerPool." 
e = + e stacktrace = stacktrace() args = args + put!(state, Failed) + # rethrow() + end + end +end \ No newline at end of file From 7548a5e4ea284a434993798377bcb00da26352af Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Tue, 22 Dec 2020 23:05:00 +0800 Subject: [PATCH 2/9] usethreads docs --- src/DataLoaders.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DataLoaders.jl b/src/DataLoaders.jl index f286492..65f42c3 100644 --- a/src/DataLoaders.jl +++ b/src/DataLoaders.jl @@ -45,6 +45,7 @@ for PyTorch users. - `parallel::Bool = Threads.nthreads() > 1)`: Whether to load data in parallel, keeping the primary thread is. Default is `true` if more than one thread is available. +- `usethreads::Bool = true`: Whether to use threads or processes - `useprimary::Bool = false`: If `false`, keep the main thread free when loading data in parallel. Is ignored if `parallel` is `false`. From 10ae8128d09b477b862053e6e869cf42b12a0d14 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 23 Dec 2020 01:18:24 +0800 Subject: [PATCH 3/9] reuse ThreadPools --- src/loaders.jl | 2 +- src/workerpool.jl | 38 +++++++++++++++++++------------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/loaders.jl b/src/loaders.jl index 6a4da37..db3b8c1 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -68,7 +68,7 @@ end Base.show(io::IO, bufparallel::BufferGetObsParallel) = print(io, "eachobsparallel($(bufparallel.data))") function BufferGetObsParallel(data; useprimary = false) - nthreads = nthreads() - Int(!useprimary) + nthreads = Threads.nthreads() - Int(!useprimary) nthreads > 0 || error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") diff --git a/src/workerpool.jl b/src/workerpool.jl index 7ed7010..cf6f7f8 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -52,12 +52,7 @@ end function run(pool::WorkerPool{TArgs}) where TArgs - workerfn = pool.workerfn - tasks = Channel{TArgs}(length(pool.args)) - foreach(a -> put!(tasks, a), pool.args) - usethreads = pool.usethreads - useprimary = pool.useprimary - state = pool.state + @unpack workerfn, usethreads, useprimary, state = pool put!(state, Running) # watchdog that sends exception to main thread if a worker fails @@ -75,10 +70,12 @@ function run(pool::WorkerPool{TArgs}) where TArgs end if usethreads - @sync for id in (useprimary ? 1 : 2):nthreads() - Threads.@spawn on_worker(tasks, state, workerfn, usethreads, useprimary) + (useprimary ? qforeach : qbforeach)(pool.args) do args + inloop(state, workerfn, threadid(), args) end else + tasks = Channel{TArgs}(length(pool.args)) + foreach(a -> put!(tasks, a), pool.args) remote_state = RemoteChannel(() -> state) remote_tasks = RemoteChannel(() -> tasks) @sync for id in (useprimary ? procs() : workers()) @@ -87,9 +84,21 @@ function run(pool::WorkerPool{TArgs}) where TArgs end # Tasks completed successfully - put!(pool.state, Done) + put!(state, Done) end +function inloop(state, workerfn, id, args) + try + # execute task + workerfn(args...) + catch e + display(stacktrace()) + @error "Exception while executing task on worker $id. Shutting down WorkerPool." 
e = + e stacktrace = stacktrace() args = args + put!(state, Failed) + rethrow() + end +end function on_worker(tasks, state, workerfn, usethreads, useprimary) # task error handling @@ -98,15 +107,6 @@ function on_worker(tasks, state, workerfn, usethreads, useprimary) while isready(tasks) fetch(state) !== Failed || error("Shutting down worker $id") args = take!(tasks) - try - # execute task - workerfn(args...) - catch e - display(stacktrace()) - @error "Exception while executing task on worker $id. Shutting down WorkerPool." e = - e stacktrace = stacktrace() args = args - put!(state, Failed) - # rethrow() - end + inloop(state, workerfn, id, args) end end \ No newline at end of file From 4e1f853eb454f95b2a595791ed3f8bac1815c1b3 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 23 Dec 2020 01:18:35 +0800 Subject: [PATCH 4/9] distributed test --- test/runtests.jl | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 13777a5..8805de9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,9 @@ +using Distributed + +addprocs(2) + +@everywhere begin + using Test using TestSetExtensions using DataLoaders @@ -21,6 +27,7 @@ LearnBase.getobs!(buf, ds::MockDataset, idx::Int) = ds.inplace ? fill!(buf, 0.3) : getobs(ds, idx) LearnBase.nobs(ds::MockDataset) = ds.n +end @testset ExtendedTestSet "collate" begin @test collate([([1, 2], 3), ([4, 5], 6)]) == ([1 4; 2 5], [3, 6]) @@ -178,7 +185,7 @@ end @testset ExtendedTestSet "DataLoader" begin data = MockDataset(256, (10, 5), true) bs = 8 - + @testset ExtendedTestSet "buffer, collate, parallel" begin dl = DataLoader(data, bs) @test_nowarn for batch in dl end @@ -189,6 +196,11 @@ end @test_nowarn for batch in dl end end + @testset ExtendedTestSet "buffer, collate, parallel, samples, distributed" begin + dl = DataLoader(data, nothing, usethreads = false) + @test_nowarn for batch in dl end + end + @testset ExtendedTestSet "collate, parallel" begin dl = DataLoader(data, bs, buffered = false) @test_nowarn for batch in dl end @@ -199,6 +211,11 @@ end @test_nowarn for batch in dl end end + @testset ExtendedTestSet "collate, distributed" begin + dl = DataLoader(data, bs, buffered = false, usethreads = false) + @test_nowarn for batch in dl end + end + @testset ExtendedTestSet "buffer, collate" begin dl = DataLoader(data, bs, buffered = true) @test_nowarn for batch in dl end From 57c270449db3e05745e6ba426fb538cb9a498340 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 23 Dec 2020 02:14:25 +0800 Subject: [PATCH 5/9] add maxquesize --- src/DataLoaders.jl | 10 +++----- src/loaders.jl | 61 +++++++++++++++++++++++++++++----------------- test/runtests.jl | 19 +++++++++------ 3 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/DataLoaders.jl b/src/DataLoaders.jl index 65f42c3..588bc89 100644 --- a/src/DataLoaders.jl +++ b/src/DataLoaders.jl @@ -48,6 +48,7 @@ for PyTorch users. - `usethreads::Bool = true`: Whether to use threads or processes - `useprimary::Bool = false`: If `false`, keep the main thread free when loading data in parallel. Is ignored if `parallel` is `false`. +- `maxquesize = nothing`: Maximum size of caching queue. ## Examples @@ -62,6 +63,7 @@ function DataLoader( partial = true, usethreads = true, useprimary = usethreads ? 
nthreads() == 1 : nprocs() == 1, + maxquesize = nothing, ) !usethreads || nthreads() > 1 || useprimary || error( @@ -70,16 +72,12 @@ function DataLoader( "the `-t n` option or setting the `JULIA_NUM_THREADS` " * "environment variable before starting Julia.") - usethreads || Distributed.nprocs() > 1 || useprimary || error( + usethreads || nprocs() > 1 || useprimary || error( "Julia is running with one procs only, either pass `useprimary = true` or " * "start Julia with multiple threads by passing " * "the `-p n` option " * "environment variable before starting Julia.") - usethreads || !buffered || error( - "Buffered loading is not compatible with `usethreads = false`" - ) - batchwrapper = if isnothing(batchsize) identity elseif collate @@ -89,7 +87,7 @@ function DataLoader( data -> batchview(data, size = batchsize) end - loadwrapper = data -> eachobsparallel(data; usethreads = usethreads, useprimary = useprimary, buffered = buffered) + loadwrapper = data -> eachobsparallel(data; usethreads = usethreads, useprimary = useprimary, buffered = buffered, maxquesize = maxquesize) return loadwrapper(batchwrapper(data)) end diff --git a/src/loaders.jl b/src/loaders.jl index db3b8c1..7da257a 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -4,7 +4,8 @@ struct GetObsParallel{TData} data::TData usethreads::Bool useprimary::Bool - function GetObsParallel(data::TData; usethreads = true, useprimary = false) where {TData} + maxquesize::Int + function GetObsParallel(data::TData; usethreads = true, useprimary = false, maxquesize = nothing) where {TData} if usethreads (useprimary || nthreads() > 1) || error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") @@ -12,7 +13,8 @@ struct GetObsParallel{TData} (useprimary || nworkers() > 1) || error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") end - return new{TData}(data, usethreads, useprimary) + maxquesize = something(maxquesize, usethreads ? nthreads() : nworkers()) + return new{TData}(data, usethreads, useprimary, maxquesize) end end @@ -21,9 +23,9 @@ Base.length(iterparallel::GetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::GetObsParallel) resultschannel = if iterparallel.usethreads - Channel(nthreads() - Int(!iterparallel.useprimary)) + Channel(iterparallel.maxquesize) else - RemoteChannel(Distributed.nprocs() - Int(!iterparallel.useprimary)) + RemoteChannel(iterparallel.maxquesize) end workerpool = @@ -54,7 +56,7 @@ end # Buffered version """ - BufferGetObsParallel(data; useprimary = false) + BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing) Like `MLDataPattern.BufferGetObs` but preloads observations into a buffer ring with multi-threaded workers. @@ -62,23 +64,30 @@ buffer ring with multi-threaded workers. struct BufferGetObsParallel{TElem,TData} data::TData buffers::Vector{TElem} + usethreads::Bool useprimary::Bool + maxquesize::Int end Base.show(io::IO, bufparallel::BufferGetObsParallel) = print(io, "eachobsparallel($(bufparallel.data))") -function BufferGetObsParallel(data; useprimary = false) - nthreads = Threads.nthreads() - Int(!useprimary) - nthreads > 0 || - error("Cannot load data off main thread with only one thread available. 
Pass `useprimary = true` or start Julia with > 1 threads.") +function BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing) + if usethreads + (useprimary || nthreads() > 1) || + error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.") + else + (useprimary || nworkers() > 1) || + error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.") + end buffer = getobs(data, 1) buffers = [buffer] - for _ ∈ 1:nthreads + maxquesize = something(maxquesize, usethreads ? nthreads() : nworkers()) + for _ ∈ 1:maxquesize push!(buffers, deepcopy(buffer)) end - return BufferGetObsParallel(data, buffers, useprimary) + return BufferGetObsParallel(data, buffers, usethreads, useprimary, maxquesize) end @@ -86,14 +95,22 @@ Base.length(iterparallel::BufferGetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::BufferGetObsParallel) - ringbuffer = RingBuffer(iterparallel.buffers) - - workerpool = - WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx - put!(ringbuffer) do buf - getobs!(buf, iterparallel.data, idx) + if iterparallel.usethreads + ringbuffer = RingBuffer(iterparallel.buffers) + workerpool = + WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx + put!(ringbuffer) do buf + getobs!(buf, iterparallel.data, idx) + end end - end + else + resultschannel = RemoteChannel(iterparallel.maxquesize) + workerpool = + WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx + put!(resultschannel, getobs(iterparallel.data, idx)) + end + end + @async run(workerpool) return iterate(iterparallel, (ringbuffer, workerpool, 0)) @@ -118,7 +135,7 @@ end # functional interface """ - eachobsparallel(data; useprimary = false, buffered = true) + eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) Parallel data iterator for data container `data`. Loads data on all available threads (except the first if `useprimary` is `false`). @@ -134,6 +151,6 @@ See also `MLDataPattern.eachobs` are returned in the correct order. """ -eachobsparallel(data; usethreads = true, useprimary = false, buffered = true) = - buffered ? BufferGetObsParallel(data, useprimary = useprimary) : - GetObsParallel(data, usethreads = usethreads, useprimary = useprimary) +eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) = + buffered ? 
BufferGetObsParallel(data, useprimary = useprimary, maxquesize = maxquesize) : + GetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) diff --git a/test/runtests.jl b/test/runtests.jl index 8805de9..c9610c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -181,43 +181,46 @@ end end +for usethreads in (true, false), maxquesize in (nothing, 4, 8) -@testset ExtendedTestSet "DataLoader" begin +@testset ExtendedTestSet "DataLoader: usethreads=$usethreads, maxquesize=$maxquesize" begin data = MockDataset(256, (10, 5), true) bs = 8 @testset ExtendedTestSet "buffer, collate, parallel" begin - dl = DataLoader(data, bs) + dl = DataLoader(data, bs, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "buffer, collate, parallel, samples" begin - dl = DataLoader(data, nothing) + dl = DataLoader(data, nothing, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "buffer, collate, parallel, samples, distributed" begin - dl = DataLoader(data, nothing, usethreads = false) + dl = DataLoader(data, nothing, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "collate, parallel" begin - dl = DataLoader(data, bs, buffered = false) + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads, maxquesize = maxquesize) @test_nowarn for batch in dl end end @testset ExtendedTestSet "collate" begin - dl = DataLoader(data, bs, buffered = false) + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads) @test_nowarn for batch in dl end end @testset ExtendedTestSet "collate, distributed" begin - dl = DataLoader(data, bs, buffered = false, usethreads = false) + dl = DataLoader(data, bs, buffered = false, usethreads = usethreads) @test_nowarn for batch in dl end end @testset ExtendedTestSet "buffer, collate" begin - dl = DataLoader(data, bs, buffered = true) + dl = DataLoader(data, bs, buffered = true, usethreads = usethreads) @test_nowarn for batch in dl end end end + +end \ No newline at end of file From fc8bed83e4ec14c74dbe448cfb2bb3cfc00ef9aa Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 23 Dec 2020 02:36:45 +0800 Subject: [PATCH 6/9] fix RemoteChannel --- src/loaders.jl | 18 +++++++++--------- test/Project.toml | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/loaders.jl b/src/loaders.jl index 7da257a..ce43cd1 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -25,11 +25,11 @@ function Base.iterate(iterparallel::GetObsParallel) resultschannel = if iterparallel.usethreads Channel(iterparallel.maxquesize) else - RemoteChannel(iterparallel.maxquesize) + RemoteChannel(() -> Channel(iterparallel.maxquesize)) end workerpool = - WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx + WorkerPool(1:nobs(iterparallel.data), usethreads = iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx put!(resultschannel, getobs(iterparallel.data, idx)) end @async run(workerpool) @@ -96,15 +96,15 @@ Base.length(iterparallel::BufferGetObsParallel) = nobs(iterparallel.data) function Base.iterate(iterparallel::BufferGetObsParallel) if iterparallel.usethreads - ringbuffer = RingBuffer(iterparallel.buffers) + resultschannel = RingBuffer(iterparallel.buffers) workerpool = WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx - put!(ringbuffer) do buf + 
put!(resultschannel) do buf getobs!(buf, iterparallel.data, idx) end end else - resultschannel = RemoteChannel(iterparallel.maxquesize) + resultschannel = RemoteChannel(() -> Channel(iterparallel.maxquesize)) workerpool = WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx put!(resultschannel, getobs(iterparallel.data, idx)) @@ -113,12 +113,12 @@ function Base.iterate(iterparallel::BufferGetObsParallel) @async run(workerpool) - return iterate(iterparallel, (ringbuffer, workerpool, 0)) + return iterate(iterparallel, (resultschannel, workerpool, 0)) end function Base.iterate(iterparallel::BufferGetObsParallel, state) - ringbuffer, workerpool, index = state + resultschannel, workerpool, index = state # Worker pool failed if fetch(workerpool.state) === Failed @@ -127,7 +127,7 @@ function Base.iterate(iterparallel::BufferGetObsParallel, state) elseif index >= nobs(iterparallel.data) return nothing else - return take!(ringbuffer), (ringbuffer, workerpool, index + 1) + return take!(resultschannel), (resultschannel, workerpool, index + 1) end end @@ -152,5 +152,5 @@ See also `MLDataPattern.eachobs` """ eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) = - buffered ? BufferGetObsParallel(data, useprimary = useprimary, maxquesize = maxquesize) : + buffered ? BufferGetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) : GetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) diff --git a/test/Project.toml b/test/Project.toml index 1d2c169..4ff1307 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From 84f6cfc2c5b8d0ebd3a00f13787011435267befd Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 27 Dec 2020 02:28:15 +0800 Subject: [PATCH 7/9] close channel and wait task --- src/loaders.jl | 25 ++++++++++++++----------- src/ringbuffer.jl | 4 ++++ src/workerpool.jl | 2 +- test/runtests.jl | 6 +++--- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/loaders.jl b/src/loaders.jl index ce43cd1..d122f8b 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -32,23 +32,25 @@ function Base.iterate(iterparallel::GetObsParallel) WorkerPool(1:nobs(iterparallel.data), usethreads = iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx put!(resultschannel, getobs(iterparallel.data, idx)) end - @async run(workerpool) + task = @async run(workerpool) - return iterate(iterparallel, (resultschannel, workerpool, 0)) + return iterate(iterparallel, (task, resultschannel, workerpool, 0)) end function Base.iterate(iterparallel::GetObsParallel, state) - resultschannel, workerpool, index = state + task, resultschannel, workerpool, index = state # Worker pool failed if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) + close(resultschannel) + wait(task) return nothing else - return take!(resultschannel), (resultschannel, workerpool, index + 1) + return take!(resultschannel), (task, resultschannel, workerpool, index + 1) end end @@ -99,7 +101,7 @@ function Base.iterate(iterparallel::BufferGetObsParallel) resultschannel = RingBuffer(iterparallel.buffers) workerpool = WorkerPool(1:nobs(iterparallel.data), 
useprimary = iterparallel.useprimary) do idx - put!(resultschannel) do buf + isopen(resultschannel) && put!(resultschannel) do buf getobs!(buf, iterparallel.data, idx) end end @@ -107,27 +109,28 @@ function Base.iterate(iterparallel::BufferGetObsParallel) resultschannel = RemoteChannel(() -> Channel(iterparallel.maxquesize)) workerpool = WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx - put!(resultschannel, getobs(iterparallel.data, idx)) + isopen(resultschannel) && put!(resultschannel, getobs(iterparallel.data, idx)) end end + task = @async run(workerpool) - @async run(workerpool) - - return iterate(iterparallel, (resultschannel, workerpool, 0)) + return iterate(iterparallel, (task, resultschannel, workerpool, 0)) end function Base.iterate(iterparallel::BufferGetObsParallel, state) - resultschannel, workerpool, index = state + task, resultschannel, workerpool, index = state # Worker pool failed if fetch(workerpool.state) === Failed error("Worker pool failed.") # Iteration complete elseif index >= nobs(iterparallel.data) + close(resultschannel) + wait(task) return nothing else - return take!(resultschannel), (resultschannel, workerpool, index + 1) + return take!(resultschannel), (task, resultschannel, workerpool, index + 1) end end diff --git a/src/ringbuffer.jl b/src/ringbuffer.jl index a5ef4a2..db48a64 100644 --- a/src/ringbuffer.jl +++ b/src/ringbuffer.jl @@ -82,6 +82,10 @@ function Base.put!(f!, ringbuffer::RingBuffer) put!(ringbuffer.results, buf_) end +function Base.isopen(ringbuffer::RingBuffer) + isopen(ringbuffer.results) + isopen(ringbuffer.buffers) +end function Base.close(ringbuffer::RingBuffer) close(ringbuffer.results) diff --git a/src/workerpool.jl b/src/workerpool.jl index cf6f7f8..f0bbc5e 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -104,7 +104,7 @@ function on_worker(tasks, state, workerfn, usethreads, useprimary) # task error handling id = usethreads ? 
threadid() : myid() !useprimary && id == 1 && return - while isready(tasks) + while isopen(tasks) && isready(tasks) fetch(state) !== Failed || error("Shutting down worker $id") args = take!(tasks) inloop(state, workerfn, id, args) diff --git a/test/runtests.jl b/test/runtests.jl index c9610c4..9a53ccf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Distributed -addprocs(2) +# addprocs(8) @everywhere begin @@ -142,7 +142,7 @@ end @testset ExtendedTestSet "iterate" begin dl = make() - x, (results, workerpool, idx) = iterate(dl) + x, (task, results, workerpool, idx) = iterate(dl) @test idx == 1 @test_nowarn for obs in dl end @@ -155,7 +155,7 @@ end @testset ExtendedTestSet "iterate" begin dl = make() - x, (ringbuffer, workerpool, idx) = iterate(dl) + x, (task, ringbuffer, workerpool, idx) = iterate(dl) @test idx == 1 @test_nowarn for obs in dl end From 61d596fac06a5d74c0966dfbdbbecca83a4d1520 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 27 Dec 2020 11:16:03 +0800 Subject: [PATCH 8/9] wait watchdog --- src/workerpool.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/workerpool.jl b/src/workerpool.jl index f0bbc5e..5c613bf 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -57,7 +57,7 @@ function run(pool::WorkerPool{TArgs}) where TArgs # watchdog that sends exception to main thread if a worker fails maintask = current_task() - @async begin + watchdog = @async begin while fetch(state) !== Done if fetch(state) === Failed Base.throwto( @@ -85,6 +85,7 @@ function run(pool::WorkerPool{TArgs}) where TArgs # Tasks completed successfully put!(state, Done) + wait(watchdog) end function inloop(state, workerfn, id, args) From bdcd5746723924b5ceb6c5f816582b50f3f618ce Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 27 Dec 2020 15:22:10 +0800 Subject: [PATCH 9/9] fix task hung --- src/loaders.jl | 4 ++-- src/ringbuffer.jl | 3 +-- src/workerpool.jl | 7 ++++--- test/runtests.jl | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/loaders.jl b/src/loaders.jl index d122f8b..9e1d44a 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -101,7 +101,7 @@ function Base.iterate(iterparallel::BufferGetObsParallel) resultschannel = RingBuffer(iterparallel.buffers) workerpool = WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx - isopen(resultschannel) && put!(resultschannel) do buf + put!(resultschannel) do buf getobs!(buf, iterparallel.data, idx) end end @@ -109,7 +109,7 @@ function Base.iterate(iterparallel::BufferGetObsParallel) resultschannel = RemoteChannel(() -> Channel(iterparallel.maxquesize)) workerpool = WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx - isopen(resultschannel) && put!(resultschannel, getobs(iterparallel.data, idx)) + put!(resultschannel, getobs(iterparallel.data, idx)) end end task = @async run(workerpool) diff --git a/src/ringbuffer.jl b/src/ringbuffer.jl index db48a64..5014f98 100644 --- a/src/ringbuffer.jl +++ b/src/ringbuffer.jl @@ -83,8 +83,7 @@ function Base.put!(f!, ringbuffer::RingBuffer) end function Base.isopen(ringbuffer::RingBuffer) - isopen(ringbuffer.results) - isopen(ringbuffer.buffers) + isopen(ringbuffer.results) && isopen(ringbuffer.buffers) end function Base.close(ringbuffer::RingBuffer) diff --git a/src/workerpool.jl b/src/workerpool.jl index 5c613bf..c8ea8e5 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -74,8 +74,9 @@ function run(pool::WorkerPool{TArgs}) where TArgs 
inloop(state, workerfn, threadid(), args) end else - tasks = Channel{TArgs}(length(pool.args)) + tasks = Channel{TArgs}(Inf) foreach(a -> put!(tasks, a), pool.args) + close(tasks) remote_state = RemoteChannel(() -> state) remote_tasks = RemoteChannel(() -> tasks) @sync for id in (useprimary ? procs() : workers()) @@ -105,9 +106,9 @@ function on_worker(tasks, state, workerfn, usethreads, useprimary) # task error handling id = usethreads ? threadid() : myid() !useprimary && id == 1 && return - while isopen(tasks) && isready(tasks) + while isready(tasks) + args = try take!(tasks) catch e break end fetch(state) !== Failed || error("Shutting down worker $id") - args = take!(tasks) inloop(state, workerfn, id, args) end end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9a53ccf..fb7598a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Distributed -# addprocs(8) +addprocs(2) @everywhere begin
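
Note (illustrative only, not part of the patches above): a minimal usage sketch of the API these commits introduce. `RandomDataset` is a hypothetical data container; the `usethreads`, `maxquesize`, and `buffered` keywords and the `addprocs`/`@everywhere` setup mirror the patched DataLoader docstring and test/runtests.jl.

    using Distributed
    addprocs(2)                          # start worker processes before defining the dataset

    @everywhere begin
        # the container and loader must be available on every process for usethreads = false
        using DataLoaders
        using LearnBase

        struct RandomDataset end
        LearnBase.nobs(::RandomDataset) = 64
        LearnBase.getobs(::RandomDataset, i::Int) = rand(Float32, 10)
    end

    # Threaded loading (default): batches of 8, at most 4 preloaded results queued.
    dlthreads = DataLoader(RandomDataset(), 8, maxquesize = 4)

    # Process-based loading: observations are produced on the Distributed workers.
    dlprocs = DataLoader(RandomDataset(), 8, usethreads = false, buffered = false)

    for batch in dlprocs
        # each batch arrives on the primary process as a collated 10×8 Array{Float32}
    end

With `usethreads = false`, observations are computed on the Distributed worker processes and handed back through a `RemoteChannel`, while `maxquesize` bounds how many results may be queued ahead of the consuming loop.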