Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/docs/build/
Manifest.toml
test/Manifest.toml
test/Manifest.toml
.vscode
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["lorenzoh <lorenz.ohly@gmail.com>"]
version = "0.1.2"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b"
Expand Down
18 changes: 15 additions & 3 deletions src/DataLoaders.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module DataLoaders

using Base.Threads
using Distributed
using MLDataPattern
using ThreadPools
using LearnBase
Expand Down Expand Up @@ -43,8 +45,10 @@ for PyTorch users.
- `parallel::Bool = Threads.nthreads() > 1)`: Whether to load data
in parallel, keeping the primary thread is. Default is `true` if
more than one thread is available.
- `usethreads::Bool = true`: Whether to use threads or processes
- `useprimary::Bool = false`: If `false`, keep the main thread free when loading
data in parallel. Is ignored if `parallel` is `false`.
- `maxquesize = nothing`: Maximum size of caching queue.

## Examples

Expand All @@ -57,15 +61,23 @@ function DataLoader(
collate = !isnothing(batchsize),
buffered = collate,
partial = true,
useprimary = Threads.nthreads() == 1,
usethreads = true,
useprimary = usethreads ? nthreads() == 1 : nprocs() == 1,
maxquesize = nothing,
)

Threads.nthreads() > 1 || useprimary || error(
!usethreads || nthreads() > 1 || useprimary || error(
"Julia is running with one thread only, either pass `useprimary = true` or " *
"start Julia with multiple threads by passing " *
"the `-t n` option or setting the `JULIA_NUM_THREADS` " *
"environment variable before starting Julia.")

usethreads || nprocs() > 1 || useprimary || error(
"Julia is running with one procs only, either pass `useprimary = true` or " *
"start Julia with multiple threads by passing " *
"the `-p n` option " *
"environment variable before starting Julia.")

batchwrapper = if isnothing(batchsize)
identity
elseif collate
Expand All @@ -75,7 +87,7 @@ function DataLoader(
data -> batchview(data, size = batchsize)
end

loadwrapper = data -> eachobsparallel(data; useprimary = useprimary, buffered = buffered)
loadwrapper = data -> eachobsparallel(data; usethreads = usethreads, useprimary = useprimary, buffered = buffered, maxquesize = maxquesize)

return loadwrapper(batchwrapper(data))
end
Expand Down
98 changes: 64 additions & 34 deletions src/loaders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,113 +2,143 @@

struct GetObsParallel{TData}
data::TData
usethreads::Bool
useprimary::Bool
function GetObsParallel(data::TData; useprimary = false) where {TData}
(useprimary || Threads.nthreads() > 1) ||
error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.")
return new{TData}(data, useprimary)
maxquesize::Int
function GetObsParallel(data::TData; usethreads = true, useprimary = false, maxquesize = nothing) where {TData}
if usethreads
(useprimary || nthreads() > 1) ||
error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.")
else
(useprimary || nworkers() > 1) ||
error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.")
end
maxquesize = something(maxquesize, usethreads ? nthreads() : nworkers())
return new{TData}(data, usethreads, useprimary, maxquesize)
end
end


Base.length(iterparallel::GetObsParallel) = nobs(iterparallel.data)

function Base.iterate(iterparallel::GetObsParallel)
resultschannel = Channel(Threads.nthreads() - Int(!iterparallel.useprimary))
resultschannel = if iterparallel.usethreads
Channel(iterparallel.maxquesize)
else
RemoteChannel(() -> Channel(iterparallel.maxquesize))
end

workerpool =
WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx
WorkerPool(1:nobs(iterparallel.data), usethreads = iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx
put!(resultschannel, getobs(iterparallel.data, idx))
end
@async run(workerpool)
task = @async run(workerpool)

return iterate(iterparallel, (resultschannel, workerpool, 0))
return iterate(iterparallel, (task, resultschannel, workerpool, 0))
end


function Base.iterate(iterparallel::GetObsParallel, state)
resultschannel, workerpool, index = state
task, resultschannel, workerpool, index = state

# Worker pool failed
if workerpool.state === Failed
if fetch(workerpool.state) === Failed
error("Worker pool failed.")
# Iteration complete
elseif index >= nobs(iterparallel.data)
close(resultschannel)
wait(task)
return nothing
else
return take!(resultschannel), (resultschannel, workerpool, index + 1)
return take!(resultschannel), (task, resultschannel, workerpool, index + 1)
end
end


# Buffered version

"""
BufferGetObsParallel(data; useprimary = false)
BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing)

Like `MLDataPattern.BufferGetObs` but preloads observations into a
buffer ring with multi-threaded workers.
"""
struct BufferGetObsParallel{TElem,TData}
data::TData
buffers::Vector{TElem}
usethreads::Bool
useprimary::Bool
maxquesize::Int
end

Base.show(io::IO, bufparallel::BufferGetObsParallel) = print(io, "eachobsparallel($(bufparallel.data))")

function BufferGetObsParallel(data; useprimary = false)
nthreads = Threads.nthreads() - Int(!useprimary)
nthreads > 0 ||
error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.")
function BufferGetObsParallel(data; usethreads = true, useprimary = false, maxquesize = nothing)
if usethreads
(useprimary || nthreads() > 1) ||
error("Cannot load data off main thread with only one thread available. Pass `useprimary = true` or start Julia with > 1 threads.")
else
(useprimary || nworkers() > 1) ||
error("Cannot load data off main thread with only one process available. Pass `useprimary = true` or start Julia with > 1 processes.")
end

buffer = getobs(data, 1)
buffers = [buffer]
for _ ∈ 1:nthreads
maxquesize = something(maxquesize, usethreads ? nthreads() : nworkers())
for _ ∈ 1:maxquesize
push!(buffers, deepcopy(buffer))
end

return BufferGetObsParallel(data, buffers, useprimary)
return BufferGetObsParallel(data, buffers, usethreads, useprimary, maxquesize)
end


Base.length(iterparallel::BufferGetObsParallel) = nobs(iterparallel.data)


function Base.iterate(iterparallel::BufferGetObsParallel)
ringbuffer = RingBuffer(iterparallel.buffers)

workerpool =
WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx
put!(ringbuffer) do buf
getobs!(buf, iterparallel.data, idx)
if iterparallel.usethreads
resultschannel = RingBuffer(iterparallel.buffers)
workerpool =
WorkerPool(1:nobs(iterparallel.data), useprimary = iterparallel.useprimary) do idx
put!(resultschannel) do buf
getobs!(buf, iterparallel.data, idx)
end
end
end
@async run(workerpool)
else
resultschannel = RemoteChannel(() -> Channel(iterparallel.maxquesize))
workerpool =
WorkerPool(1:nobs(iterparallel.data), usethreads=iterparallel.usethreads, useprimary = iterparallel.useprimary) do idx
put!(resultschannel, getobs(iterparallel.data, idx))
end
end
task = @async run(workerpool)

return iterate(iterparallel, (ringbuffer, workerpool, 0))
return iterate(iterparallel, (task, resultschannel, workerpool, 0))
end


function Base.iterate(iterparallel::BufferGetObsParallel, state)
ringbuffer, workerpool, index = state
task, resultschannel, workerpool, index = state

# Worker pool failed
if workerpool.state === Failed
if fetch(workerpool.state) === Failed
error("Worker pool failed.")
# Iteration complete
elseif index >= nobs(iterparallel.data)
close(resultschannel)
wait(task)
return nothing
else
return take!(ringbuffer), (ringbuffer, workerpool, index + 1)
return take!(resultschannel), (task, resultschannel, workerpool, index + 1)
end
end


# functional interface

"""
eachobsparallel(data; useprimary = false, buffered = true)
eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing)

Parallel data iterator for data container `data`. Loads data on all
available threads (except the first if `useprimary` is `false`).
Expand All @@ -124,6 +154,6 @@ See also `MLDataPattern.eachobs`
are returned in the correct order.

"""
eachobsparallel(data; useprimary = false, buffered = true) =
buffered ? BufferGetObsParallel(data, useprimary = useprimary) :
GetObsParallel(data, useprimary = useprimary)
eachobsparallel(data; usethreads = true, useprimary = false, buffered = true, maxquesize = nothing) =
buffered ? BufferGetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize) :
GetObsParallel(data, usethreads = usethreads, useprimary = useprimary, maxquesize = maxquesize)
3 changes: 3 additions & 0 deletions src/ringbuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ function Base.put!(f!, ringbuffer::RingBuffer)
put!(ringbuffer.results, buf_)
end

function Base.isopen(ringbuffer::RingBuffer)
isopen(ringbuffer.results) && isopen(ringbuffer.buffers)
end

function Base.close(ringbuffer::RingBuffer)
close(ringbuffer.results)
Expand Down
Loading