Skip to content

Commit 63fea62

Browse files
committed
Refactor PromiseHelperNode
1 parent 8bc3f68 commit 63fea62

File tree

2 files changed

+155
-75
lines changed

2 files changed

+155
-75
lines changed

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/PromiseHelperNode.java

Lines changed: 153 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
2929
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
3030
import com.oracle.truffle.api.dsl.Cached;
31+
import com.oracle.truffle.api.dsl.ImportStatic;
3132
import com.oracle.truffle.api.dsl.ReportPolymorphism;
3233
import com.oracle.truffle.api.dsl.Specialization;
3334
import com.oracle.truffle.api.frame.Frame;
3435
import com.oracle.truffle.api.frame.MaterializedFrame;
3536
import com.oracle.truffle.api.frame.VirtualFrame;
36-
import com.oracle.truffle.api.nodes.ExplodeLoop;
3737
import com.oracle.truffle.api.nodes.Node;
3838
import com.oracle.truffle.api.profiles.BranchProfile;
3939
import com.oracle.truffle.api.profiles.ConditionProfile;
@@ -49,6 +49,7 @@
4949
import com.oracle.truffle.r.runtime.RArguments;
5050
import com.oracle.truffle.r.runtime.RCaller;
5151
import com.oracle.truffle.r.runtime.RError;
52+
import com.oracle.truffle.r.runtime.RInternalError;
5253
import com.oracle.truffle.r.runtime.VirtualEvalFrame;
5354
import com.oracle.truffle.r.runtime.context.RContext;
5455
import com.oracle.truffle.r.runtime.data.RPromise;
@@ -121,12 +122,21 @@ private boolean deoptimize(RPromise promise) {
121122
}
122123
}
123124

125+
public PromiseHelperNode() {
126+
this((byte) 0);
127+
}
128+
129+
public PromiseHelperNode(byte recursiveCounter) {
130+
this.recursiveCounter = recursiveCounter;
131+
}
132+
133+
private final byte recursiveCounter;
124134
@Child private InlineCacheNode promiseClosureCache;
125135

126136
@CompilationFinal private PrimitiveValueProfile optStateProfile = PrimitiveValueProfile.createEqualityProfile();
127137
@CompilationFinal private ConditionProfile inOriginProfile = ConditionProfile.createBinaryProfile();
128-
@Child private GenerateValueNonDefaultOptimizedNode generateValueNonDefaultOptimizedNode = GenerateValueNonDefaultOptimizedNodeGen.create();
129138
@Child private SetVisibilityNode setVisibility;
139+
@Child private GenerateValueNonDefaultOptimizedNode generateValueNonDefaultOptimizedNode;
130140
private final ValueProfile promiseFrameProfile = ValueProfile.createClassProfile();
131141

132142
/**
@@ -153,8 +163,13 @@ public Object evaluate(VirtualFrame frame, RPromise promise) {
153163
return evaluateSlowPath(frame, promise);
154164
}
155165
if (isDefaultOptProfile.profile(PromiseState.isDefaultOpt(state))) {
166+
// default values of arguments are evaluated in the frame of the function that takes
167+
// them, we do not need to retrieve the frame of the promise, we already have it
156168
return generateValueDefault(frame, promise);
157169
} else {
170+
// non-default arguments we need to evaluate in the frame of the function that supplied
171+
// them and that would mean frame materialization, we first try to see if the promise
172+
// can be optimized
158173
return generateValueNonDefault(frame, state, (EagerPromise) promise);
159174
}
160175
}
@@ -169,6 +184,7 @@ private Object generateValueDefault(VirtualFrame frame, RPromise promise) {
169184
CompilerDirectives.transferToInterpreterAndInvalidate();
170185
promiseClosureCache = insert(InlineCacheNode.create(DSLConfig.getCacheSize(RContext.getInstance().getNonNegativeIntOption(PromiseCacheSize))));
171186
}
187+
// TODO: no wrapping of arguments here?, why we do not have to set visibility here?
172188
promise.setUnderEvaluation();
173189
boolean inOrigin = inOriginProfile.profile(isInOriginFrame(frame, promise));
174190
Frame execFrame = inOrigin ? frame : wrapPromiseFrame(frame, promiseFrameProfile.profile(promise.getFrame()));
@@ -197,6 +213,10 @@ private static boolean getVisibilitySlowPath(Frame frame) {
197213
private Object generateValueNonDefault(VirtualFrame frame, int state, EagerPromise promise) {
198214
assert !PromiseState.isDefaultOpt(state);
199215
if (!isDeoptimized(promise)) {
216+
if (generateValueNonDefaultOptimizedNode == null) {
217+
CompilerDirectives.transferToInterpreterAndInvalidate();
218+
generateValueNonDefaultOptimizedNode = insert(GenerateValueNonDefaultOptimizedNodeGen.create(recursiveCounter));
219+
}
200220
Object result = generateValueNonDefaultOptimizedNode.execute(frame, state, promise);
201221
if (result != null) {
202222
return result;
@@ -341,16 +361,11 @@ public boolean isEvaluated(RPromise promise) {
341361
private final ConditionProfile isDefaultProfile = ConditionProfile.createBinaryProfile();
342362
private final ConditionProfile isFrameForEnvProfile = ConditionProfile.createBinaryProfile();
343363

344-
private final ValueProfile valueProfile = ValueProfile.createClassProfile();
345-
346364
// Eager
347365
private final ConditionProfile isExplicitProfile = ConditionProfile.createBinaryProfile();
348366
private final ConditionProfile isDefaultOptProfile = ConditionProfile.createBinaryProfile();
349367
private final ConditionProfile isDeoptimizedProfile = ConditionProfile.createBinaryProfile();
350368

351-
public PromiseHelperNode() {
352-
}
353-
354369
/**
355370
* @return The state of the {@link RPromise#isUnderEvaluation()} flag.
356371
*/
@@ -359,20 +374,10 @@ public boolean isUnderEvaluation(RPromise promise) {
359374
}
360375

361376
/**
362-
* Used in case the {@link RPromise} is evaluated outside.
363-
*
364-
* @param newValue
365-
*/
366-
public void setValue(Object newValue, RPromise promise) {
367-
promise.setValue(valueProfile.profile(newValue));
368-
}
369-
370-
/**
371-
* @param frame
372377
* @return Whether the given {@link RPromise} is in its origin context and thus can be resolved
373378
* directly inside the AST.
374379
*/
375-
public boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
380+
private boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
376381
if (isDefaultArgument(promise) && isNullFrame(promise)) {
377382
return true;
378383
}
@@ -382,87 +387,161 @@ public boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
382387
return isFrameForEnvProfile.profile(frame == promise.getFrame());
383388
}
384389

390+
/**
391+
* Attempts to generate the value of the given promise in optimized way without having to
392+
* materialize the promise's exec frame. If that's not possible, returns {@code null}.
393+
*
394+
* Note: we have to create a new instance of {@link WrapArgumentNode} for each
395+
* {@link EagerPromise#wrapIndex()} we encounter, but only for {@link EagerPromise#wrapIndex()}
396+
* that are up to {@link ArgumentStatePush#MAX_COUNTED_ARGS}, this is also because the
397+
* {@link WrapArgumentNode} takes the argument index as constructor parameter. Values of R
398+
* arguments have to be channelled through the {@link WrapArgumentNode} so that the reference
399+
* counting can work for the arguments, but the reference counting is only done for up to
400+
* {@link ArgumentStatePush#MAX_COUNTED_ARGS} arguments.
401+
*/
402+
@ImportStatic(DSLConfig.class)
385403
@ReportPolymorphism
386404
protected abstract static class GenerateValueNonDefaultOptimizedNode extends Node {
387405

406+
protected static final int ASSUMPTION_CACHE_SIZE = ArgumentStatePush.MAX_COUNTED_ARGS + 4;
407+
protected static final int CACHE_SIZE = ArgumentStatePush.MAX_COUNTED_ARGS * 2;
408+
protected static final int RECURSIVE_PROMISE_LIMIT = 3;
409+
410+
private final byte recursiveCounter;
411+
@Child private PromiseHelperNode nextNode = null;
412+
@Child private SetVisibilityNode visibility;
413+
414+
protected GenerateValueNonDefaultOptimizedNode(byte recursiveCounter) {
415+
this.recursiveCounter = recursiveCounter;
416+
}
417+
388418
public abstract Object execute(VirtualFrame frame, int state, EagerPromise promise);
389419

390-
@Specialization(guards = {"promise.getIsValidAssumption() == eagerAssumption", "eagerAssumption.isValid()"})
391-
protected Object doCached(VirtualFrame frame, int state, EagerPromise promise,
392-
@SuppressWarnings("unused") @Cached("promise.getIsValidAssumption()") Assumption eagerAssumption) {
420+
// @formatter:off
421+
// data from "rutgen" tests
422+
// column A: # of distinct tuples (assumption, wrapIndex, state) observed per GenerateValueNonDefaultOptimizedNode instance
423+
// column B: # of GenerateValueNonDefaultOptimizedNode instances
424+
// A B
425+
// 1 10555
426+
// 2 1387
427+
// 3 308
428+
// 4 199
429+
// 5 54
430+
// 6 34
431+
// 7 40
432+
// 8 8
433+
// 9 31
434+
// 10 8
435+
// 11 4
436+
// 12 4
437+
// 14 19
438+
// >14 <=4
439+
// @formatter:on
440+
441+
@Specialization(guards = {
442+
"promise.getIsValidAssumption() == eagerAssumption",
443+
"eagerAssumption.isValid()",
444+
"isCompatibleEagerValueProfile(eagerValueProfile, state)",
445+
"isCompatibleWrapNode(wrapArgumentNode, promise, state)"}, //
446+
limit = "getCacheSize(ASSUMPTION_CACHE_SIZE)")
447+
Object doCachedAssumption(VirtualFrame frame, int state, EagerPromise promise,
448+
@SuppressWarnings("unused") @Cached("promise.getIsValidAssumption()") Assumption eagerAssumption,
449+
@Cached("createEagerValueProfile(state)") ValueProfile eagerValueProfile,
450+
@Cached("createWrapArgumentNode(promise, state)") WrapArgumentNode wrapArgumentNode) {
451+
return generateValue(frame, state, promise, wrapArgumentNode, eagerValueProfile);
452+
}
453+
454+
@Specialization(replaces = "doCachedAssumption", guards = {
455+
"isCompatibleWrapNode(wrapArgumentNode, promise, state)",
456+
"isCompatibleEagerValueProfile(eagerValueProfile, state)"}, //
457+
limit = "CACHE_SIZE")
458+
Object doUncachedAssumption(VirtualFrame frame, int state, EagerPromise promise,
459+
@Cached("createBinaryProfile()") ConditionProfile isValidProfile,
460+
@Cached("createEagerValueProfile(state)") ValueProfile eagerValueProfile,
461+
@Cached("createWrapArgumentNode(promise, state)") WrapArgumentNode wrapArgumentNode) {
462+
// Note: the assumption inside the promise is not constant anymore, so we profile the
463+
// result of isValid
464+
if (isValidProfile.profile(promise.isValid())) {
465+
return generateValue(frame, state, promise, wrapArgumentNode, eagerValueProfile);
466+
} else {
467+
return null;
468+
}
469+
}
470+
471+
@Specialization(replaces = "doUncachedAssumption")
472+
Object doFallback(@SuppressWarnings("unused") int state, @SuppressWarnings("unused") EagerPromise promise) {
473+
throw RInternalError.shouldNotReachHere("The cache of doUncachedAssumption should never overflow");
474+
}
475+
476+
// If promise evaluates to another promise, we create another RPromiseHelperNode to evaluate
477+
// that, but only up to certain recursion level
478+
private Object evaluateNextNode(VirtualFrame frame, RPromise nextPromise) {
479+
if (recursiveCounter > DSLConfig.getCacheSize(RECURSIVE_PROMISE_LIMIT)) {
480+
evaluateNextNodeSlowPath(frame.materialize(), nextPromise);
481+
}
482+
if (nextNode == null) {
483+
CompilerDirectives.transferToInterpreterAndInvalidate();
484+
nextNode = insert(new PromiseHelperNode((byte) (recursiveCounter + 1)));
485+
}
486+
return nextNode.evaluate(frame, nextPromise);
487+
}
488+
489+
@TruffleBoundary
490+
private static void evaluateNextNodeSlowPath(MaterializedFrame frame, RPromise nextPromise) {
491+
PromiseHelperNode.evaluateSlowPath(frame, nextPromise);
492+
}
493+
494+
private Object generateValue(VirtualFrame frame, int state, EagerPromise promise, WrapArgumentNode wrapArgumentNode, ValueProfile eagerValueProfile) {
393495
Object value;
394496
if (PromiseState.isEager(state)) {
395-
assert PromiseState.isEager(state);
396-
value = getEagerValue(frame, promise);
497+
assert eagerValueProfile != null;
498+
value = getEagerValue(frame, promise, wrapArgumentNode, eagerValueProfile);
397499
} else {
398500
RPromise nextPromise = (RPromise) promise.getEagerValue();
399-
value = checkNextNode().evaluate(frame, nextPromise);
501+
value = evaluateNextNode(frame, nextPromise);
400502
}
401503
assert promise.getRawValue() == null;
402504
assert value != null;
403505
promise.setValue(value);
404506
return value;
405507
}
406508

407-
@Specialization(replaces = "doCached")
408-
@TruffleBoundary
409-
protected Object switchToSlowPath(@SuppressWarnings("unused") int state, @SuppressWarnings("unused") EagerPromise promise) {
410-
return null;
509+
/**
510+
* for R arguments that need to be wrapped using WrapArgumentNode, creates the
511+
* WrapArgumentNode, otherwise returns null.
512+
*/
513+
static WrapArgumentNode createWrapArgumentNode(EagerPromise promise, int state) {
514+
return needsWrapNode(promise.wrapIndex(), state) ? WrapArgumentNode.create(promise.wrapIndex()) : null;
411515
}
412516

413-
@Child private PromiseHelperNode nextNode = null;
414-
415-
private PromiseHelperNode checkNextNode() {
416-
if (nextNode == null) {
417-
CompilerDirectives.transferToInterpreterAndInvalidate();
418-
nextNode = insert(new PromiseHelperNode());
517+
static boolean isCompatibleWrapNode(WrapArgumentNode wrapNode, EagerPromise promise, int state) {
518+
if (needsWrapNode(promise.wrapIndex(), state)) {
519+
return wrapNode != null && wrapNode.getIndex() == promise.wrapIndex();
520+
} else {
521+
return wrapNode == null;
419522
}
420-
return nextNode;
421523
}
422524

423-
@Children private final WrapArgumentNode[] wrapNodes = new WrapArgumentNode[ArgumentStatePush.MAX_COUNTED_ARGS];
424-
private final ConditionProfile shouldWrap = ConditionProfile.createBinaryProfile();
425-
private final ValueProfile eagerValueProfile = ValueProfile.createClassProfile();
426-
private static final int UNINITIALIZED = -1;
427-
private static final int GENERIC = -2;
428-
@CompilationFinal private int cachedWrapIndex = UNINITIALIZED;
429-
@Child private SetVisibilityNode visibility;
525+
private static boolean needsWrapNode(int wrapIndex, int state) {
526+
return PromiseState.isEager(state) && wrapIndex != ArgumentStatePush.INVALID_INDEX && wrapIndex < ArgumentStatePush.MAX_COUNTED_ARGS;
527+
}
528+
529+
static boolean isCompatibleEagerValueProfile(ValueProfile profile, int state) {
530+
return !PromiseState.isEager(state) || profile != null;
531+
}
532+
533+
static ValueProfile createEagerValueProfile(int state) {
534+
return PromiseState.isEager(state) ? ValueProfile.createClassProfile() : null;
535+
}
430536

431537
/**
432-
* Returns {@link EagerPromise#getEagerValue()} profiled.
538+
* Returns {@link EagerPromise#getEagerValue()} profiled and takes care of wrapping the
539+
* value with {@link WrapArgumentNode}.
433540
*/
434-
@ExplodeLoop
435-
private Object getEagerValue(VirtualFrame frame, EagerPromise promise) {
541+
private Object getEagerValue(VirtualFrame frame, EagerPromise promise, WrapArgumentNode wrapArgumentNode, ValueProfile eagerValueProfile) {
436542
Object o = promise.getEagerValue();
437-
int wrapIndex = promise.wrapIndex();
438-
if (shouldWrap.profile(wrapIndex != ArgumentStatePush.INVALID_INDEX)) {
439-
if (cachedWrapIndex == UNINITIALIZED) {
440-
CompilerDirectives.transferToInterpreterAndInvalidate();
441-
cachedWrapIndex = wrapIndex;
442-
}
443-
if (cachedWrapIndex != GENERIC && wrapIndex != cachedWrapIndex) {
444-
CompilerDirectives.transferToInterpreterAndInvalidate();
445-
cachedWrapIndex = GENERIC;
446-
}
447-
if (cachedWrapIndex != GENERIC) {
448-
if (cachedWrapIndex < ArgumentStatePush.MAX_COUNTED_ARGS) {
449-
if (wrapNodes[cachedWrapIndex] == null) {
450-
CompilerDirectives.transferToInterpreterAndInvalidate();
451-
wrapNodes[cachedWrapIndex] = insert(WrapArgumentNode.create(cachedWrapIndex));
452-
}
453-
wrapNodes[cachedWrapIndex].execute(frame, o);
454-
}
455-
} else {
456-
for (int i = 0; i < ArgumentStatePush.MAX_COUNTED_ARGS; i++) {
457-
if (wrapIndex == i) {
458-
if (wrapNodes[i] == null) {
459-
CompilerDirectives.transferToInterpreterAndInvalidate();
460-
wrapNodes[i] = insert(WrapArgumentNode.create(i));
461-
}
462-
wrapNodes[i].execute(frame, o);
463-
}
464-
}
465-
}
543+
if (wrapArgumentNode != null) {
544+
wrapArgumentNode.execute(frame, o);
466545
}
467546
if (visibility == null) {
468547
CompilerDirectives.transferToInterpreterAndInvalidate();

com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/RPromise.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ public static final class EagerPromise extends RPromise {
293293
private final EagerFeedback feedback;
294294

295295
/**
296-
* Index of the argument for which the promise was create.
296+
* Index of the argument for which the promise was created for. {@code -1} for promises
297+
* created for default argument values.
297298
*/
298299
private final int wrapIndex;
299300

0 commit comments

Comments
 (0)