Skip to content

Commit 2abc36d

Browse files
committed
[GR-13413] Refactor PromiseHelperNode.
PullRequest: fastr/1890
2 parents c9da64b + e0fe5cc commit 2abc36d

File tree

3 files changed

+176
-90
lines changed

3 files changed

+176
-90
lines changed

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/PromiseHelperNode.java

Lines changed: 161 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,20 @@
2222
*/
2323
package com.oracle.truffle.r.nodes.function;
2424

25+
import static com.oracle.truffle.r.runtime.context.FastROptions.PromiseCacheSize;
26+
2527
import com.oracle.truffle.api.Assumption;
2628
import com.oracle.truffle.api.CompilerAsserts;
2729
import com.oracle.truffle.api.CompilerDirectives;
2830
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
2931
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
3032
import com.oracle.truffle.api.dsl.Cached;
33+
import com.oracle.truffle.api.dsl.ImportStatic;
3134
import com.oracle.truffle.api.dsl.ReportPolymorphism;
3235
import com.oracle.truffle.api.dsl.Specialization;
3336
import com.oracle.truffle.api.frame.Frame;
3437
import com.oracle.truffle.api.frame.MaterializedFrame;
3538
import com.oracle.truffle.api.frame.VirtualFrame;
36-
import com.oracle.truffle.api.nodes.ExplodeLoop;
3739
import com.oracle.truffle.api.nodes.Node;
3840
import com.oracle.truffle.api.profiles.BranchProfile;
3941
import com.oracle.truffle.api.profiles.ConditionProfile;
@@ -45,10 +47,10 @@
4547
import com.oracle.truffle.r.nodes.function.visibility.GetVisibilityNode;
4648
import com.oracle.truffle.r.nodes.function.visibility.SetVisibilityNode;
4749
import com.oracle.truffle.r.runtime.DSLConfig;
48-
import static com.oracle.truffle.r.runtime.context.FastROptions.PromiseCacheSize;
4950
import com.oracle.truffle.r.runtime.RArguments;
5051
import com.oracle.truffle.r.runtime.RCaller;
5152
import com.oracle.truffle.r.runtime.RError;
53+
import com.oracle.truffle.r.runtime.RInternalError;
5254
import com.oracle.truffle.r.runtime.VirtualEvalFrame;
5355
import com.oracle.truffle.r.runtime.context.RContext;
5456
import com.oracle.truffle.r.runtime.data.RPromise;
@@ -121,12 +123,21 @@ private boolean deoptimize(RPromise promise) {
121123
}
122124
}
123125

126+
/**
 * Creates a helper node whose recursion counter starts at {@code 0}.
 */
public PromiseHelperNode() {
127+
this((byte) 0);
128+
}
129+
130+
/**
 * Creates a helper node with an explicit recursion counter.
 *
 * @param recursiveCounter value stored in this node; it is handed on to the lazily created
 *            {@code GenerateValueNonDefaultOptimizedNode} when a non-default promise is
 *            evaluated
 */
public PromiseHelperNode(byte recursiveCounter) {
131+
this.recursiveCounter = recursiveCounter;
132+
}
133+
134+
private final byte recursiveCounter;
124135
@Child private InlineCacheNode promiseClosureCache;
125136

126137
@CompilationFinal private PrimitiveValueProfile optStateProfile = PrimitiveValueProfile.createEqualityProfile();
127138
@CompilationFinal private ConditionProfile inOriginProfile = ConditionProfile.createBinaryProfile();
128-
@Child private GenerateValueNonDefaultOptimizedNode generateValueNonDefaultOptimizedNode = GenerateValueNonDefaultOptimizedNodeGen.create();
129139
@Child private SetVisibilityNode setVisibility;
140+
@Child private GenerateValueNonDefaultOptimizedNode generateValueNonDefaultOptimizedNode;
130141
private final ValueProfile promiseFrameProfile = ValueProfile.createClassProfile();
131142

132143
/**
@@ -153,8 +164,13 @@ public Object evaluate(VirtualFrame frame, RPromise promise) {
153164
return evaluateSlowPath(frame, promise);
154165
}
155166
if (isDefaultOptProfile.profile(PromiseState.isDefaultOpt(state))) {
167+
// default values of arguments are evaluated in the frame of the function that takes
168+
// them, we do not need to retrieve the frame of the promise, we already have it
156169
return generateValueDefault(frame, promise);
157170
} else {
171+
// non-default arguments we need to evaluate in the frame of the function that supplied
172+
// them and that would mean frame materialization, we first try to see if the promise
173+
// can be optimized
158174
return generateValueNonDefault(frame, state, (EagerPromise) promise);
159175
}
160176
}
@@ -169,6 +185,7 @@ private Object generateValueDefault(VirtualFrame frame, RPromise promise) {
169185
CompilerDirectives.transferToInterpreterAndInvalidate();
170186
promiseClosureCache = insert(InlineCacheNode.create(DSLConfig.getCacheSize(RContext.getInstance().getNonNegativeIntOption(PromiseCacheSize))));
171187
}
188+
// TODO: no wrapping of arguments here? Why do we not have to set visibility here?
172189
promise.setUnderEvaluation();
173190
boolean inOrigin = inOriginProfile.profile(isInOriginFrame(frame, promise));
174191
Frame execFrame = inOrigin ? frame : wrapPromiseFrame(frame, promiseFrameProfile.profile(promise.getFrame()));
@@ -197,6 +214,10 @@ private static boolean getVisibilitySlowPath(Frame frame) {
197214
private Object generateValueNonDefault(VirtualFrame frame, int state, EagerPromise promise) {
198215
assert !PromiseState.isDefaultOpt(state);
199216
if (!isDeoptimized(promise)) {
217+
if (generateValueNonDefaultOptimizedNode == null) {
218+
CompilerDirectives.transferToInterpreterAndInvalidate();
219+
generateValueNonDefaultOptimizedNode = insert(GenerateValueNonDefaultOptimizedNodeGen.create(recursiveCounter));
220+
}
200221
Object result = generateValueNonDefaultOptimizedNode.execute(frame, state, promise);
201222
if (result != null) {
202223
return result;
@@ -341,16 +362,11 @@ public boolean isEvaluated(RPromise promise) {
341362
private final ConditionProfile isDefaultProfile = ConditionProfile.createBinaryProfile();
342363
private final ConditionProfile isFrameForEnvProfile = ConditionProfile.createBinaryProfile();
343364

344-
private final ValueProfile valueProfile = ValueProfile.createClassProfile();
345-
346365
// Eager
347366
private final ConditionProfile isExplicitProfile = ConditionProfile.createBinaryProfile();
348367
private final ConditionProfile isDefaultOptProfile = ConditionProfile.createBinaryProfile();
349368
private final ConditionProfile isDeoptimizedProfile = ConditionProfile.createBinaryProfile();
350369

351-
public PromiseHelperNode() {
352-
}
353-
354370
/**
355371
* @return The state of the {@link RPromise#isUnderEvaluation()} flag.
356372
*/
@@ -359,20 +375,10 @@ public boolean isUnderEvaluation(RPromise promise) {
359375
}
360376

361377
/**
362-
* Used in case the {@link RPromise} is evaluated outside.
363-
*
364-
* @param newValue
365-
*/
366-
public void setValue(Object newValue, RPromise promise) {
367-
promise.setValue(valueProfile.profile(newValue));
368-
}
369-
370-
/**
371-
* @param frame
372378
* @return Whether the given {@link RPromise} is in its origin context and thus can be resolved
373379
* directly inside the AST.
374380
*/
375-
public boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
381+
private boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
376382
if (isDefaultArgument(promise) && isNullFrame(promise)) {
377383
return true;
378384
}
@@ -382,87 +388,167 @@ public boolean isInOriginFrame(VirtualFrame frame, RPromise promise) {
382388
return isFrameForEnvProfile.profile(frame == promise.getFrame());
383389
}
384390

391+
/**
392+
* Attempts to generate the value of the given promise in an optimized way without having to
393+
* materialize the promise's exec frame. If that's not possible, returns {@code null}.
394+
*
395+
* Note: we have to create a new instance of {@link WrapArgumentNode} for each
396+
* {@link EagerPromise#wrapIndex()} we encounter, but only for {@link EagerPromise#wrapIndex()}
397+
* that are up to {@link ArgumentStatePush#MAX_COUNTED_ARGS}, this is also because the
398+
* {@link WrapArgumentNode} takes the argument index as constructor parameter. Values of R
399+
* arguments have to be channelled through the {@link WrapArgumentNode} so that the reference
400+
* counting can work for the arguments, but the reference counting is only done for up to
401+
* {@link ArgumentStatePush#MAX_COUNTED_ARGS} arguments.
402+
*/
403+
@ImportStatic(DSLConfig.class)
385404
@ReportPolymorphism
386405
protected abstract static class GenerateValueNonDefaultOptimizedNode extends Node {
387406

407+
protected static final int ASSUMPTION_CACHE_SIZE = ArgumentStatePush.MAX_COUNTED_ARGS + 4;
408+
protected static final int CACHE_SIZE = ArgumentStatePush.MAX_COUNTED_ARGS * 2;
409+
protected static final int RECURSIVE_PROMISE_LIMIT = 3;
410+
411+
// If set to -1, then no further recursion should take place
412+
// This avoids having to invoke DSLConfig.getCacheSize in PE'd code
413+
private final byte recursiveCounter;
414+
@Child private PromiseHelperNode nextNode = null;
415+
@Child private SetVisibilityNode visibility;
416+
417+
protected GenerateValueNonDefaultOptimizedNode(byte recursiveCounter) {
418+
if (recursiveCounter > DSLConfig.getCacheSize(RECURSIVE_PROMISE_LIMIT)) {
419+
this.recursiveCounter = -1;
420+
} else {
421+
this.recursiveCounter = recursiveCounter;
422+
}
423+
}
424+
388425
public abstract Object execute(VirtualFrame frame, int state, EagerPromise promise);
389426

390-
@Specialization(guards = {"promise.getIsValidAssumption() == eagerAssumption", "eagerAssumption.isValid()"})
391-
protected Object doCached(VirtualFrame frame, int state, EagerPromise promise,
392-
@SuppressWarnings("unused") @Cached("promise.getIsValidAssumption()") Assumption eagerAssumption) {
427+
// @formatter:off
428+
// data from "rutgen" tests
429+
// column A: # of distinct tuples (assumption, wrapIndex, state) observed per GenerateValueNonDefaultOptimizedNode instance
430+
// column B: # of GenerateValueNonDefaultOptimizedNode instances
431+
// A B
432+
// 1 10555
433+
// 2 1387
434+
// 3 308
435+
// 4 199
436+
// 5 54
437+
// 6 34
438+
// 7 40
439+
// 8 8
440+
// 9 31
441+
// 10 8
442+
// 11 4
443+
// 12 4
444+
// 14 19
445+
// >14 <=4
446+
// @formatter:on
447+
448+
/**
 * Specialization that caches the promise's validity assumption per cache entry, together with
 * a value profile and a wrap node created for the given promise and state. The guards require
 * that the promise still carries the cached assumption, that the assumption is valid, and that
 * the cached profile and wrap node are compatible with the current promise and state. The
 * number of cache entries is limited by {@code getCacheSize(ASSUMPTION_CACHE_SIZE)}.
 */
@Specialization(guards = {
449+
"promise.getIsValidAssumption() == eagerAssumption",
450+
"eagerAssumption.isValid()",
451+
"isCompatibleEagerValueProfile(eagerValueProfile, state)",
452+
"isCompatibleWrapNode(wrapArgumentNode, promise, state)"}, //
453+
limit = "getCacheSize(ASSUMPTION_CACHE_SIZE)")
454+
Object doCachedAssumption(VirtualFrame frame, int state, EagerPromise promise,
455+
@SuppressWarnings("unused") @Cached("promise.getIsValidAssumption()") Assumption eagerAssumption,
456+
@Cached("createEagerValueProfile(state)") ValueProfile eagerValueProfile,
457+
@Cached("createWrapArgumentNode(promise, state)") WrapArgumentNode wrapArgumentNode) {
458+
// validity was already established by the guards, so the value can be generated directly
return generateValue(frame, state, promise, wrapArgumentNode, eagerValueProfile);
459+
}
460+
461+
@Specialization(replaces = "doCachedAssumption", guards = {
462+
"isCompatibleWrapNode(wrapArgumentNode, promise, state)",
463+
"isCompatibleEagerValueProfile(eagerValueProfile, state)"}, //
464+
limit = "CACHE_SIZE")
465+
Object doUncachedAssumption(VirtualFrame frame, int state, EagerPromise promise,
466+
@Cached("createBinaryProfile()") ConditionProfile isValidProfile,
467+
@Cached("createEagerValueProfile(state)") ValueProfile eagerValueProfile,
468+
@Cached("createWrapArgumentNode(promise, state)") WrapArgumentNode wrapArgumentNode) {
469+
// Note: the assumption inside the promise is not constant anymore, so we profile the
470+
// result of isValid
471+
if (isValidProfile.profile(promise.isValid())) {
472+
return generateValue(frame, state, promise, wrapArgumentNode, eagerValueProfile);
473+
} else {
474+
return null;
475+
}
476+
}
477+
478+
/**
 * Last specialization in the chain, replacing {@code doUncachedAssumption}. It performs no
 * evaluation and always throws an internal error.
 */
@Specialization(replaces = "doUncachedAssumption")
479+
Object doFallback(@SuppressWarnings("unused") int state, @SuppressWarnings("unused") EagerPromise promise) {
480+
// NOTE(review): presumably unreachable because the number of distinct (wrap node, profile)
// combinations fits into CACHE_SIZE - confirm
throw RInternalError.shouldNotReachHere("The cache of doUncachedAssumption should never overflow");
481+
}
482+
483+
// If promise evaluates to another promise, we create another RPromiseHelperNode to evaluate
484+
// that, but only up to certain recursion level
485+
private Object evaluateNextNode(VirtualFrame frame, RPromise nextPromise) {
486+
if (recursiveCounter == -1) {
487+
evaluateNextNodeSlowPath(frame.materialize(), nextPromise);
488+
}
489+
if (nextNode == null) {
490+
CompilerDirectives.transferToInterpreterAndInvalidate();
491+
nextNode = insert(new PromiseHelperNode((byte) (recursiveCounter + 1)));
492+
}
493+
return nextNode.evaluate(frame, nextPromise);
494+
}
495+
496+
@TruffleBoundary
497+
private static void evaluateNextNodeSlowPath(MaterializedFrame frame, RPromise nextPromise) {
498+
PromiseHelperNode.evaluateSlowPath(frame, nextPromise);
499+
}
500+
501+
/**
 * Computes the value of the given promise, stores it in the promise via
 * {@code promise.setValue(value)} and returns it. For an eager state the value is obtained
 * through {@code getEagerValue}; otherwise the promise's eager value is itself a promise,
 * which is evaluated through {@code evaluateNextNode}.
 */
private Object generateValue(VirtualFrame frame, int state, EagerPromise promise, WrapArgumentNode wrapArgumentNode, ValueProfile eagerValueProfile) {
393502
Object value;
394503
if (PromiseState.isEager(state)) {
395-
assert PromiseState.isEager(state);
396-
value = getEagerValue(frame, promise);
504+
assert eagerValueProfile != null;
505+
value = getEagerValue(frame, promise, wrapArgumentNode, eagerValueProfile);
397506
} else {
398507
// non-eager state: the stored eager value is another promise
RPromise nextPromise = (RPromise) promise.getEagerValue();
399-
value = checkNextNode().evaluate(frame, nextPromise);
508+
value = evaluateNextNode(frame, nextPromise);
400509
}
401510
// the promise must not have a value yet, and the computed value must not be null
assert promise.getRawValue() == null;
402511
assert value != null;
403512
promise.setValue(value);
404513
return value;
405514
}
406515

407-
@Specialization(replaces = "doCached")
408-
@TruffleBoundary
409-
protected Object switchToSlowPath(@SuppressWarnings("unused") int state, @SuppressWarnings("unused") EagerPromise promise) {
410-
return null;
516+
/**
517+
* for R arguments that need to be wrapped using WrapArgumentNode, creates the
518+
* WrapArgumentNode, otherwise returns null.
519+
*/
520+
static WrapArgumentNode createWrapArgumentNode(EagerPromise promise, int state) {
521+
return needsWrapNode(promise.wrapIndex(), state) ? WrapArgumentNode.create(promise.wrapIndex()) : null;
411522
}
412523

413-
@Child private PromiseHelperNode nextNode = null;
414-
415-
private PromiseHelperNode checkNextNode() {
416-
if (nextNode == null) {
417-
CompilerDirectives.transferToInterpreterAndInvalidate();
418-
nextNode = insert(new PromiseHelperNode());
524+
static boolean isCompatibleWrapNode(WrapArgumentNode wrapNode, EagerPromise promise, int state) {
525+
if (needsWrapNode(promise.wrapIndex(), state)) {
526+
return wrapNode != null && wrapNode.getIndex() == promise.wrapIndex();
527+
} else {
528+
return wrapNode == null;
419529
}
420-
return nextNode;
421530
}
422531

423-
@Children private final WrapArgumentNode[] wrapNodes = new WrapArgumentNode[ArgumentStatePush.MAX_COUNTED_ARGS];
424-
private final ConditionProfile shouldWrap = ConditionProfile.createBinaryProfile();
425-
private final ValueProfile eagerValueProfile = ValueProfile.createClassProfile();
426-
private static final int UNINITIALIZED = -1;
427-
private static final int GENERIC = -2;
428-
@CompilationFinal private int cachedWrapIndex = UNINITIALIZED;
429-
@Child private SetVisibilityNode visibility;
532+
private static boolean needsWrapNode(int wrapIndex, int state) {
533+
return PromiseState.isEager(state) && wrapIndex != ArgumentStatePush.INVALID_INDEX && wrapIndex < ArgumentStatePush.MAX_COUNTED_ARGS;
534+
}
535+
536+
static boolean isCompatibleEagerValueProfile(ValueProfile profile, int state) {
537+
return !PromiseState.isEager(state) || profile != null;
538+
}
539+
540+
static ValueProfile createEagerValueProfile(int state) {
541+
return PromiseState.isEager(state) ? ValueProfile.createClassProfile() : null;
542+
}
430543

431544
/**
432-
* Returns {@link EagerPromise#getEagerValue()} profiled.
545+
* Returns {@link EagerPromise#getEagerValue()} profiled and takes care of wrapping the
546+
* value with {@link WrapArgumentNode}.
433547
*/
434-
@ExplodeLoop
435-
private Object getEagerValue(VirtualFrame frame, EagerPromise promise) {
548+
private Object getEagerValue(VirtualFrame frame, EagerPromise promise, WrapArgumentNode wrapArgumentNode, ValueProfile eagerValueProfile) {
436549
Object o = promise.getEagerValue();
437-
int wrapIndex = promise.wrapIndex();
438-
if (shouldWrap.profile(wrapIndex != ArgumentStatePush.INVALID_INDEX)) {
439-
if (cachedWrapIndex == UNINITIALIZED) {
440-
CompilerDirectives.transferToInterpreterAndInvalidate();
441-
cachedWrapIndex = wrapIndex;
442-
}
443-
if (cachedWrapIndex != GENERIC && wrapIndex != cachedWrapIndex) {
444-
CompilerDirectives.transferToInterpreterAndInvalidate();
445-
cachedWrapIndex = GENERIC;
446-
}
447-
if (cachedWrapIndex != GENERIC) {
448-
if (cachedWrapIndex < ArgumentStatePush.MAX_COUNTED_ARGS) {
449-
if (wrapNodes[cachedWrapIndex] == null) {
450-
CompilerDirectives.transferToInterpreterAndInvalidate();
451-
wrapNodes[cachedWrapIndex] = insert(WrapArgumentNode.create(cachedWrapIndex));
452-
}
453-
wrapNodes[cachedWrapIndex].execute(frame, o);
454-
}
455-
} else {
456-
for (int i = 0; i < ArgumentStatePush.MAX_COUNTED_ARGS; i++) {
457-
if (wrapIndex == i) {
458-
if (wrapNodes[i] == null) {
459-
CompilerDirectives.transferToInterpreterAndInvalidate();
460-
wrapNodes[i] = insert(WrapArgumentNode.create(i));
461-
}
462-
wrapNodes[i].execute(frame, o);
463-
}
464-
}
465-
}
550+
if (wrapArgumentNode != null) {
551+
wrapArgumentNode.execute(frame, o);
466552
}
467553
if (visibility == null) {
468554
CompilerDirectives.transferToInterpreterAndInvalidate();

0 commit comments

Comments
 (0)