-
Notifications
You must be signed in to change notification settings - Fork 0
WaitFreeQueueFast
Pslydhh edited this page Oct 5, 2017
·
1 revision
package org.psly.concurrent;
import java.lang.reflect.Field;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import org.psly.concurrent.WaitFreeQueueFastPath.OpDescPoll.PNode;
import sun.misc.Unsafe;
/**
 * Unbounded multi-producer/multi-consumer FIFO queue with a lock-free fast
 * path and a helping-based slow path.
 *
 * <p>Each {@link #add} / {@link #poll} first runs a Michael–Scott style CAS
 * loop for at most {@code MAXFAILURES} attempts. If every attempt fails, the
 * operation publishes a descriptor ({@link OpDescAdd} or {@link OpDescPoll})
 * in the calling thread's {@code state} slot and completes it on the slow path.
 * Other threads call {@link #helpIfNeeded()} at the start of each operation.
 * Every {@code HELPINGDELAY + 1} operations, that method inspects one
 * published descriptor and helps finish it if it is still pending with the
 * phase that was recorded earlier. NOTE(review): the structure appears to
 * follow Kogan &amp; Petrank's fast-path/slow-path methodology; confirm
 * against the original write-up.
 *
 * <p>Each thread gets a dense per-queue id on first use. Ids are never
 * recycled. A thread whose id is {@code >= Short.MAX_VALUE} fails with
 * {@link ArrayIndexOutOfBoundsException} when it calls {@link #add} or
 * {@link #poll}.
 *
 * <p>The public {@code AtomicInteger} fields are statistics counters that are
 * updated as operations complete.
 *
 * @param <E> element type. {@code null} elements are not rejected, but a
 *            {@code null} returned by {@link #poll()} cannot be distinguished
 *            from an empty queue.
 */
public class WaitFreeQueueFastPath<E> {

    /** Size of the per-thread tables; thread ids must be below this bound. */
    private static final int MAX_THREADS = Short.MAX_VALUE;

    /** Countdown (in own operations) between two helping checks of a thread. */
    private static final int HELPINGDELAY = 8;

    /** Fast-path attempts before an operation falls back to the slow path. */
    private static final int MAXFAILURES = 8;

    /** Current sentinel node, plus a slow-path dequeue request if one is installed. */
    public volatile NodeHead<E> head;

    /** Tail pointer; may lag behind the real last node until a thread advances it. */
    public volatile Node<E> tail;

    /**
     * {@code state[tid]} holds the most recent slow-path descriptor published
     * by thread {@code tid}. It may be an {@link OpDescAdd} or an
     * {@link OpDescPoll}. Accessed only through {@link #getArrayAt} and
     * {@link #setArrayAt}.
     */
    final OpDesc<E>[] state;

    /** Per-thread helping bookkeeping, indexed by thread id; each slot is touched only by its owner. */
    final HelpRecord[] helpRecords;

    /** Source of dense thread ids; also the upper bound of ids handed out so far. */
    private final AtomicInteger tidGenerator = new AtomicInteger();

    /** Lazily assigns each thread its id for this queue instance. */
    final ThreadLocal<Integer> tLocal = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return tidGenerator.getAndIncrement();
        }
    };

    /** Per-thread random seed. NOTE(review): not read anywhere in this class. */
    final ThreadLocal<Integer> threadLocalRandom = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return new Random().nextInt();
        }
    };

    // Statistics counters (not needed for correctness).
    public AtomicInteger enqValue = new AtomicInteger();
    public AtomicInteger enqFastValue = new AtomicInteger();
    public AtomicInteger enqSlowValue = new AtomicInteger();
    public AtomicInteger deqValue = new AtomicInteger();
    public AtomicInteger deqNull = new AtomicInteger();
    public AtomicInteger deqFastValue = new AtomicInteger();
    public AtomicInteger deqSlowValue = new AtomicInteger();
    public AtomicInteger deqFastNull = new AtomicInteger();
    public AtomicInteger deqSlowNull = new AtomicInteger();
    public AtomicInteger headTransfer = new AtomicInteger();
    public AtomicInteger headTransformOne = new AtomicInteger();
    public AtomicInteger headTransformTwo = new AtomicInteger();
    public AtomicInteger headTranDirect = new AtomicInteger();
    public AtomicInteger helpEnqTimes = new AtomicInteger();
    public AtomicInteger helpDeqTimes = new AtomicInteger();

    /** Creates an empty queue consisting of a single sentinel node. */
    public WaitFreeQueueFastPath() {
        // The queue is empty when head.node == tail and that sentinel has no successor.
        tail = new Node<E>(null);
        head = new NodeHead<E>(tail, null);
        @SuppressWarnings("unchecked") // arrays of a generic type cannot be created directly
        OpDesc<E>[] temState = (OpDesc<E>[]) new OpDesc<?>[MAX_THREADS];
        state = temState;
        // HelpRecord is an inner class of a generic type, so create the array
        // through the raw outer type.
        @SuppressWarnings("unchecked")
        HelpRecord[] temHelps = new WaitFreeQueueFastPath.HelpRecord[MAX_THREADS];
        helpRecords = temHelps;
    }

    /**
     * Counts down the calling thread's helping delay. When it expires, helps
     * the descriptor published by the thread currently being watched, provided
     * that descriptor is still pending in the phase recorded at the last reset.
     * Then moves on to the next thread id.
     *
     * @throws ArrayIndexOutOfBoundsException if the calling thread's id is not
     *         below {@code Short.MAX_VALUE}
     */
    void helpIfNeeded() {
        int tid = tLocal.get();
        HelpRecord rec = helpRecords[tid];
        if (rec == null) {
            // Only thread `tid` ever reads or writes this slot, so a plain write suffices.
            helpRecords[tid] = rec = new HelpRecord();
        }
        if (rec.nextCheck-- == 0) {
            OpDesc<E> desc = getArrayAt(rec.curTid);
            if (desc instanceof OpDescAdd) {
                OpDescAdd<E> opDesc = (OpDescAdd<E>) desc;
                if (opDesc.pending == 1 && opDesc.phase == rec.willHelpPhase) {
                    helpEnqTimes.getAndIncrement();
                    helpEnq(opDesc);
                }
            } else if (desc instanceof OpDescPoll) {
                OpDescPoll<E> opDesc = (OpDescPoll<E>) desc;
                if (opDesc.pNode.pending == 1 && opDesc.phase == rec.willHelpPhase) {
                    helpDeqTimes.getAndIncrement();
                    helpDeq(opDesc);
                }
            }
            rec.reset();
        }
    }

    /**
     * Appends {@code item} to the tail of the queue.
     *
     * @param item element to append ({@code null} is accepted)
     * @return always {@code true}
     */
    public boolean add(E item) {
        helpIfNeeded();
        Node<E> node = new Node<E>(item);
        for (int trials = 0; trials < MAXFAILURES; trials++) {
            Node<E> last = tail;
            Node<E> next = last.next;
            if (last == tail) {
                if (next == null) {
                    if (last.casNext(null, node)) {
                        enqValue.getAndIncrement();
                        enqFastValue.getAndIncrement();
                        // Best effort: if this fails, another thread already advanced tail.
                        casTail(last, node);
                        return true;
                    }
                } else {
                    fixTail(last, next);
                }
            }
        }
        return wfEnq(node);
    }

    /**
     * Advances a lagging tail from {@code last} to {@code next}. If
     * {@code next} was linked by a slow-path enqueue, this goes through
     * {@link #helpFinishEnq()} so that the request is marked done first.
     */
    void fixTail(Node<E> last, Node<E> next) {
        if (next.opDesc == null) {
            casTail(last, next);
        } else {
            helpFinishEnq();
        }
    }

    /**
     * Slow-path enqueue: publishes an {@link OpDescAdd} for {@code node} in
     * the caller's state slot and completes it, possibly with help from other
     * threads.
     *
     * @return always {@code true}
     */
    boolean wfEnq(Node<E> node) {
        int tid = tLocal.get();
        // The slot may still hold this thread's previous OpDescPoll, so read it
        // only through the common supertype. Casting it straight to OpDescAdd
        // would throw ClassCastException after a slow-path poll.
        OpDescAdd<E> opDesc = new OpDescAdd<E>(nextPhase(getArrayAt(tid)), 1, true, node);
        setArrayAt(tid, opDesc);
        helpEnq(opDesc);
        return true;
    }

    /**
     * Helps complete the slow-path enqueue described by {@code opDescAdd}.
     * Returns once the request is no longer pending, or once this thread has
     * linked the request's node itself.
     */
    public void helpEnq(OpDescAdd<E> opDescAdd) {
        while (opDescAdd.pending == 1) {
            Node<E> last = tail;
            Node<E> next = last.next;
            if (last == tail) {
                if (next == null) {
                    // Re-check after reading tail. Tail is only moved onto a slow-path
                    // node after that node's pending flag is cleared (helpFinishEnq),
                    // so seeing pending == 1 here means the node is not linked past `last`.
                    if (opDescAdd.pending == 1) {
                        if (last.casNext(null, opDescAdd.node)) {
                            enqValue.getAndIncrement();
                            enqSlowValue.getAndIncrement();
                            helpFinishEnq();
                            return;
                        }
                    }
                } else {
                    helpFinishEnq();
                }
            }
        }
    }

    /**
     * If tail has a successor, advances tail onto it. When the successor was
     * linked by a slow-path enqueue, its descriptor's pending flag is cleared
     * before tail moves.
     */
    private void helpFinishEnq() {
        Node<E> last = tail;
        Node<E> next = last.next;
        if (next != null) {
            OpDesc<E> desc = next.opDesc;
            if (desc instanceof OpDescAdd) {
                OpDescAdd<E> curDesc = (OpDescAdd<E>) desc;
                if (last == tail && curDesc.node == next) {
                    curDesc.casPending(1, 0);
                    casTail(last, next);
                }
            } else {
                casTail(last, next);
            }
        }
    }

    /**
     * Removes and returns the element at the head of the queue.
     *
     * @return the removed element, or {@code null} if the queue was observed empty
     */
    public E poll() {
        helpIfNeeded();
        for (int trials = 0; trials < MAXFAILURES; trials++) {
            NodeHead<E> localHead = head;
            Node<E> first = localHead.node;
            Node<E> last = tail;
            Node<E> next = first.next;
            if (localHead == head) {
                if (first == last) {
                    if (next == null) {
                        deqNull.getAndIncrement();
                        deqFastNull.getAndIncrement();
                        return null;
                    }
                    // Tail lags behind a linked node; fix it before dequeuing.
                    fixTail(first, next);
                } else if (localHead.opDesc == null) {
                    E value = next.value;
                    if (casHead(first, next, null, null)) {
                        deqValue.getAndIncrement();
                        deqFastValue.getAndIncrement();
                        headTransfer.getAndIncrement();
                        headTranDirect.getAndIncrement();
                        return value;
                    }
                } else {
                    // A slow-path dequeue is installed in head; finish it first.
                    helpFinishDeq();
                }
            }
        }
        return wfDeq();
    }

    /**
     * Slow-path dequeue: publishes an {@link OpDescPoll} in the caller's state
     * slot and completes it, possibly with help from other threads.
     *
     * @return the dequeued element, or {@code null} if the queue was observed empty
     */
    E wfDeq() {
        int tid = tLocal.get();
        // The slot may still hold this thread's previous OpDescAdd; see wfEnq.
        OpDescPoll<E> opDesc = new OpDescPoll<E>(nextPhase(getArrayAt(tid)), 1, false, null);
        setArrayAt(tid, opDesc);
        helpDeq(opDesc);
        // A null node means "observed empty". Otherwise it is the sentinel this
        // request removed, and the dequeued value lives in its successor.
        Node<E> node = opDesc.pNode.node;
        return node == null ? null : node.next.value;
    }

    /**
     * Returns the phase for a thread's next slow-path descriptor. This is one
     * more than the previous descriptor's phase, wrapping to 0 after
     * {@code Short.MAX_VALUE}, or 0 if the thread has no previous descriptor.
     */
    private static short nextPhase(OpDesc<?> previous) {
        if (previous == null) {
            return 0;
        }
        return (short) ((previous.phase + 1) & Short.MAX_VALUE);
    }

    /**
     * Counts the elements reachable from the current head sentinel. The walk
     * is not atomic, so under concurrent modification the result is only an
     * estimate.
     */
    public int size() {
        int count = 0;
        for (Node<E> cur = head.node.next; cur != null; cur = cur.next) {
            count++;
        }
        return count;
    }

    /** Helps complete the slow-path dequeue described by {@code opDescPoll}. */
    private void helpDeq(OpDescPoll<E> opDescPoll) {
        OpDescPoll.PNode<E> pNode;
        while ((pNode = opDescPoll.pNode).pending == 1) {
            NodeHead<E> localHead = head;
            Node<E> first = localHead.node;
            Node<E> last = tail;
            Node<E> next = first.next;
            if (localHead == head) {
                if (first == last) {
                    if (next == null) {
                        // Observed empty: complete the request with no node.
                        if (last == tail && pNode.pending == 1) {
                            if (opDescPoll.casPNode(pNode, new OpDescPoll.PNode<E>(null, 0))) {
                                deqNull.getAndIncrement();
                                deqSlowNull.getAndIncrement();
                            }
                        }
                    } else {
                        helpFinishEnq();
                    }
                } else {
                    if (pNode.pending == 0) {
                        break;
                    }
                    if (localHead == head) {
                        if (localHead.opDesc == null) {
                            // Record which sentinel this request will remove...
                            if (first != pNode.node) {
                                opDescPoll.casPNode(pNode, new OpDescPoll.PNode<E>(first, 1));
                                first = (pNode = opDescPoll.pNode).node;
                                if (pNode.pending == 0) {
                                    return;
                                }
                            }
                            // ...then publish the request in head, so that other threads
                            // call helpFinishDeq instead of dequeuing on their own.
                            if (localHead == head) {
                                if (casHead(first, first, null, opDescPoll)) {
                                    deqSlowValue.getAndIncrement();
                                    headTransformOne.getAndIncrement();
                                    headTransfer.getAndIncrement();
                                }
                            }
                        }
                        helpFinishDeq();
                    }
                }
            }
        }
    }

    /**
     * If a slow-path dequeue is installed in head, marks it done (keeping the
     * recorded sentinel) and advances head past that sentinel, clearing the
     * installed request.
     */
    private void helpFinishDeq() {
        NodeHead<E> localHead = head;
        Node<E> first = localHead.node;
        Node<E> next = first.next;
        OpDescPoll<E> op = localHead.opDesc;
        if (op != null) {
            OpDescPoll.PNode<E> pNode = op.pNode;
            if (pNode.pending == 1) {
                if (op.casPNode(pNode, new OpDescPoll.PNode<E>(pNode.node, 0))) {
                    deqValue.getAndIncrement();
                }
            }
            if (localHead == head && next != null) {
                if (casHead(first, next, op, null)) {
                    headTransformTwo.getAndIncrement();
                }
            }
        }
    }

    /** CAS on the tail pointer. */
    boolean casTail(Node<E> cmp, Node<E> val) {
        return UNSAFE.compareAndSwapObject(this, tailOffset, cmp, val);
    }

    /**
     * Replaces head with {@code (val, valOp)} if it currently holds
     * {@code (cmp, cmpOp)}. NodeHead is immutable, so the pair is swapped as a
     * single reference.
     */
    boolean casHead(Node<E> cmp, Node<E> val, OpDescPoll<E> cmpOp, OpDescPoll<E> valOp) {
        NodeHead<E> localHead = head;
        return cmp == localHead.node && cmpOp == localHead.opDesc
                && UNSAFE.compareAndSwapObject(this, headOffset, localHead, new NodeHead<E>(val, valOp));
    }

    /** Volatile write of {@code state[i]}. */
    final void setArrayAt(int i, OpDesc<E> v) {
        UNSAFE.putObjectVolatile(state, stateOffset(i), v);
    }

    /** Volatile read of {@code state[i]}. */
    @SuppressWarnings("unchecked")
    final OpDesc<E> getArrayAt(int i) {
        return (OpDesc<E>) UNSAFE.getObjectVolatile(state, stateOffset(i));
    }

    /**
     * Returns the byte offset of {@code state[i]}. Unsafe performs no bounds
     * checks, so the index is validated here to avoid touching memory outside
     * the array.
     */
    private long stateOffset(int i) {
        if (i < 0 || i >= state.length) {
            throw new ArrayIndexOutOfBoundsException(i);
        }
        return ((long) i << ASHIFT) + ABASE;
    }

    /** Singly linked list node; the node that {@code head} points to is a sentinel. */
    private static class Node<E> {
        final E value;
        volatile Node<E> next;
        /** Descriptor of the slow-path enqueue that owns this node, or null for fast-path nodes. */
        volatile OpDesc<E> opDesc;

        Node(E val) {
            value = val;
        }

        boolean casNext(Node<E> cmp, Node<E> val) {
            return UNSAFE.compareAndSwapObject(this, nextOffset, cmp, val);
        }

        private static final sun.misc.Unsafe UNSAFE;
        private static final long nextOffset;

        static {
            try {
                UNSAFE = UtilUnsafe.getUnsafe();
                nextOffset = UNSAFE.objectFieldOffset(Node.class.getDeclaredField("next"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

    /** Immutable (sentinel, installed dequeue request) pair, swapped as one unit. */
    private static class NodeHead<E> {
        final Node<E> node;
        final OpDescPoll<E> opDesc;

        NodeHead(Node<E> node, OpDescPoll<E> opDesc) {
            this.node = node;
            this.opDesc = opDesc;
        }
    }

    /**
     * Base class of slow-path descriptors. {@code phase} lets helpers tell
     * whether the descriptor they recorded earlier is still the same request.
     */
    static class OpDesc<E> {
        volatile short phase;
        final boolean enqueue;

        OpDesc(short ph, boolean enq) {
            phase = ph;
            enqueue = enq;
        }
    }

    /** Slow-path enqueue request; {@code pending} drops to 0 once the node is in place. */
    static final class OpDescAdd<E> extends OpDesc<E> {
        final Node<E> node;
        volatile int pending;

        OpDescAdd(short ph, int pend, boolean enq, Node<E> n) {
            super(ph, enq);
            pending = pend;
            node = n;
            if (n != null) {
                // Lets helpers recognise n as a slow-path node (fixTail, helpFinishEnq).
                n.opDesc = this;
            }
        }

        boolean casPending(int cmp, int val) {
            return UNSAFE.compareAndSwapInt(this, pendingOffsetAdd, cmp, val);
        }

        private static final sun.misc.Unsafe UNSAFE;
        private static final long pendingOffsetAdd;

        static {
            try {
                UNSAFE = UtilUnsafe.getUnsafe();
                pendingOffsetAdd = UNSAFE.objectFieldOffset(OpDescAdd.class.getDeclaredField("pending"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

    /**
     * Slow-path dequeue request. The chosen sentinel and the pending flag live
     * in one immutable {@link PNode} so that both change together with one CAS.
     */
    static final class OpDescPoll<E> extends OpDesc<E> {
        volatile PNode<E> pNode;

        OpDescPoll(short ph, int pend, boolean enq, Node<E> n) {
            super(ph, enq);
            pNode = new PNode<E>(n, pend);
        }

        /** Immutable (sentinel to remove, pending flag) pair. */
        static final class PNode<E> {
            public PNode(Node<E> node, int pending) {
                this.node = node;
                this.pending = pending;
            }

            final Node<E> node;
            final int pending;
        }

        boolean casPNode(PNode<E> cmp, PNode<E> val) {
            return UNSAFE.compareAndSwapObject(this, pNodeOffset, cmp, val);
        }

        private static final sun.misc.Unsafe UNSAFE;
        private static final long pNodeOffset;

        static {
            try {
                UNSAFE = UtilUnsafe.getUnsafe();
                pNodeOffset = UNSAFE.objectFieldOffset(OpDescPoll.class.getDeclaredField("pNode"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

    /** Round-robin helping cursor owned by a single thread. */
    final class HelpRecord {
        int curTid;
        long willHelpPhase;
        long nextCheck;

        HelpRecord() {
            curTid = -1;
            reset();
        }

        /** Moves to the next thread id and records the phase of its current descriptor. */
        void reset() {
            // Ids are handed out before they are bounds-checked, so the generator
            // can exceed the table size. Never scan beyond the existing slots.
            // The result is >= 1 because a HelpRecord is only created after its
            // owner obtained an id.
            int slots = Math.min(tidGenerator.get(), state.length);
            curTid = (curTid + 1) % slots;
            OpDesc<E> op = getArrayAt(curTid);
            willHelpPhase = op != null ? op.phase : -1;
            nextCheck = HELPINGDELAY;
        }
    }

    private static class UtilUnsafe {
        private UtilUnsafe() {
        }

        /** Fetch the Unsafe. Use With Caution. */
        public static Unsafe getUnsafe() {
            if (UtilUnsafe.class.getClassLoader() == null) {
                return Unsafe.getUnsafe();
            }
            try {
                final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
                fld.setAccessible(true);
                return (Unsafe) fld.get(null); // static field: receiver is ignored
            } catch (Exception e) {
                throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
            }
        }
    }

    private static final sun.misc.Unsafe UNSAFE;
    private static final long headOffset;
    private static final long tailOffset;
    /** Byte offset of element 0 in an {@code OpDesc[]}. */
    private static final long ABASE;
    /** log2 of the element size of an {@code OpDesc[]}. */
    private static final int ASHIFT;

    static {
        try {
            UNSAFE = UtilUnsafe.getUnsafe();
            headOffset = UNSAFE.objectFieldOffset(WaitFreeQueueFastPath.class.getDeclaredField("head"));
            tailOffset = UNSAFE.objectFieldOffset(WaitFreeQueueFastPath.class.getDeclaredField("tail"));
            ABASE = UNSAFE.arrayBaseOffset(OpDesc[].class);
            int scale = UNSAFE.arrayIndexScale(OpDesc[].class);
            if ((scale & (scale - 1)) != 0) {
                throw new Error("data type scale not a power of two");
            }
            ASHIFT = 31 - Integer.numberOfLeadingZeros(scale);
        } catch (Exception e) {
            throw new Error(e);
        }
    }
}