Skip to content

Commit 71f8ee8

Browse files
committed
[GR-22670] Rewrite IsFiniteFunctions to use VectorDataLibrary
(cherry picked from commit 15f6d83)
1 parent e66daaf commit 71f8ee8

File tree

1 file changed

+161
-73
lines changed

1 file changed

+161
-73
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsFiniteFunctions.java

Lines changed: 161 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@
2727
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
2828

2929
import java.util.Arrays;
30-
import java.util.function.DoublePredicate;
31-
import java.util.function.IntPredicate;
3230

3331
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
32+
import com.oracle.truffle.api.dsl.Cached;
3433
import com.oracle.truffle.api.dsl.Fallback;
3534
import com.oracle.truffle.api.dsl.ImportStatic;
3635
import com.oracle.truffle.api.dsl.Specialization;
3736
import com.oracle.truffle.api.interop.TruffleObject;
37+
import com.oracle.truffle.api.library.CachedLibrary;
38+
import com.oracle.truffle.api.profiles.ConditionProfile;
39+
import com.oracle.truffle.r.runtime.RInternalError;
40+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3841
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.InitDimsNamesDimNamesNode;
3942
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
4043
import com.oracle.truffle.r.nodes.unary.TypeofNode;
@@ -44,10 +47,10 @@
4447
import com.oracle.truffle.r.runtime.data.RComplex;
4548
import com.oracle.truffle.r.runtime.data.RDataFactory;
4649
import com.oracle.truffle.r.runtime.data.RDoubleVector;
50+
import com.oracle.truffle.r.runtime.data.RLogicalVector;
4751
import com.oracle.truffle.r.runtime.data.RNull;
4852
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
4953
import com.oracle.truffle.r.runtime.data.RIntVector;
50-
import com.oracle.truffle.r.runtime.data.RLogicalVector;
5154
import com.oracle.truffle.r.runtime.data.RRawVector;
5255
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
5356
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
@@ -58,30 +61,41 @@ public class IsFiniteFunctions {
5861
public abstract static class Adapter extends RBuiltinNode.Arg1 {
5962

6063
@Child private InitDimsNamesDimNamesNode initDimsNamesDimNames = InitDimsNamesDimNamesNode.create();
64+
private final Predicates predicates;
6165

62-
@FunctionalInterface
63-
protected interface ComplexPredicate {
64-
boolean test(RComplex x);
66+
protected Adapter(Predicates predicates) {
67+
this.predicates = predicates;
6568
}
6669

67-
@FunctionalInterface
68-
protected interface LogicalPredicate {
69-
boolean test(byte x);
70+
abstract static class Predicates {
71+
public abstract boolean test(double x);
72+
73+
public abstract boolean test(RComplex x);
74+
75+
public boolean test(@SuppressWarnings("unused") int x) {
76+
throw RInternalError.shouldNotReachHere();
77+
}
78+
79+
public boolean test(@SuppressWarnings("unused") byte x) {
80+
throw RInternalError.shouldNotReachHere();
81+
}
7082
}
7183

7284
@Specialization
7385
public RLogicalVector doNull(@SuppressWarnings("unused") RNull x) {
7486
return RDataFactory.createEmptyLogicalVector();
7587
}
7688

77-
@Specialization
78-
public RLogicalVector doString(RAbstractStringVector x) {
79-
return doFunConstant(x, RRuntime.LOGICAL_FALSE);
89+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
90+
public RLogicalVector doString(RAbstractStringVector x,
91+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
92+
return doFunConstant(dataLib, x.getData(), x, RRuntime.LOGICAL_FALSE);
8093
}
8194

82-
@Specialization
83-
public RLogicalVector doRaw(RRawVector x) {
84-
return doFunConstant(x, RRuntime.LOGICAL_FALSE);
95+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
96+
public RLogicalVector doRaw(RRawVector x,
97+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
98+
return doFunConstant(dataLib, x.getData(), x, RRuntime.LOGICAL_FALSE);
8599
}
86100

87101
@Specialization(guards = "isForeignObject(obj)")
@@ -96,48 +110,48 @@ protected Object doIsFiniteOther(Object x) {
96110
throw error(RError.Message.DEFAULT_METHOD_NOT_IMPLEMENTED_FOR_TYPE, TypeofNode.getTypeof(x).getName());
97111
}
98112

99-
protected RLogicalVector doFunConstant(RAbstractVector x, byte value) {
100-
byte[] b = new byte[x.getLength()];
113+
protected RLogicalVector doFunConstant(VectorDataLibrary dataLib, Object xData, RAbstractVector x, byte value) {
114+
byte[] b = new byte[dataLib.getLength(xData)];
101115
Arrays.fill(b, value);
102116
RLogicalVector result = RDataFactory.createLogicalVector(b, RDataFactory.COMPLETE_VECTOR);
103117
initDimsNamesDimNames.initAttributes(result, x);
104118
return result;
105119
}
106120

107-
protected RLogicalVector doFunDouble(RDoubleVector x, DoublePredicate fun) {
108-
byte[] b = new byte[x.getLength()];
121+
protected RLogicalVector doFunDouble(VectorDataLibrary xDataLib, Object xData, RDoubleVector x) {
122+
byte[] b = new byte[xDataLib.getLength(xData)];
109123
for (int i = 0; i < b.length; i++) {
110-
b[i] = RRuntime.asLogical(fun.test(x.getDataAt(i)));
124+
b[i] = RRuntime.asLogical(predicates.test(xDataLib.getDoubleAt(xData, i)));
111125
}
112126
RLogicalVector result = RDataFactory.createLogicalVector(b, RDataFactory.COMPLETE_VECTOR);
113127
initDimsNamesDimNames.initAttributes(result, x);
114128
return result;
115129
}
116130

117-
protected RLogicalVector doFunLogical(RLogicalVector x, LogicalPredicate fun) {
118-
byte[] b = new byte[x.getLength()];
131+
protected RLogicalVector doFunLogical(VectorDataLibrary xDataLib, Object xData, RLogicalVector x) {
132+
byte[] b = new byte[xDataLib.getLength(xData)];
119133
for (int i = 0; i < b.length; i++) {
120-
b[i] = RRuntime.asLogical(fun.test(x.getDataAt(i)));
134+
b[i] = RRuntime.asLogical(predicates.test(xDataLib.getLogicalAt(xData, i)));
121135
}
122136
RLogicalVector result = RDataFactory.createLogicalVector(b, RDataFactory.COMPLETE_VECTOR);
123137
initDimsNamesDimNames.initAttributes(result, x);
124138
return result;
125139
}
126140

127-
protected RLogicalVector doFunInt(RIntVector x, IntPredicate fun) {
128-
byte[] b = new byte[x.getLength()];
141+
protected RLogicalVector doFunInt(VectorDataLibrary xDataLib, Object xData, RIntVector x) {
142+
byte[] b = new byte[xDataLib.getLength(xData)];
129143
for (int i = 0; i < b.length; i++) {
130-
b[i] = RRuntime.asLogical(fun.test(x.getDataAt(i)));
144+
b[i] = RRuntime.asLogical(predicates.test(xDataLib.getIntAt(xData, i)));
131145
}
132146
RLogicalVector result = RDataFactory.createLogicalVector(b, RDataFactory.COMPLETE_VECTOR);
133147
initDimsNamesDimNames.initAttributes(result, x);
134148
return result;
135149
}
136150

137-
protected RLogicalVector doFunComplex(RAbstractComplexVector x, ComplexPredicate fun) {
138-
byte[] b = new byte[x.getLength()];
151+
protected RLogicalVector doFunComplex(VectorDataLibrary xDataLib, Object xData, RAbstractComplexVector x) {
152+
byte[] b = new byte[xDataLib.getLength(xData)];
139153
for (int i = 0; i < b.length; i++) {
140-
b[i] = RRuntime.asLogical(fun.test(x.getDataAt(i)));
154+
b[i] = RRuntime.asLogical(predicates.test(xDataLib.getComplexAt(xData, i)));
141155
}
142156
RLogicalVector result = RDataFactory.createLogicalVector(b, RDataFactory.COMPLETE_VECTOR);
143157
initDimsNamesDimNames.initAttributes(result, x);
@@ -152,34 +166,68 @@ public abstract static class IsFinite extends Adapter {
152166
Casts.noCasts(IsFinite.class);
153167
}
154168

155-
@Specialization
156-
protected RLogicalVector doIsFinite(RDoubleVector vec) {
157-
return doFunDouble(vec, RRuntime::isFinite);
169+
private static final class FinitePredicates extends Predicates {
170+
private static final FinitePredicates INSTANCE = new FinitePredicates();
171+
172+
@Override
173+
public boolean test(double x) {
174+
return RRuntime.isFinite(x);
175+
}
176+
177+
@Override
178+
public boolean test(int x) {
179+
return !RRuntime.isNA(x);
180+
}
181+
182+
@Override
183+
public boolean test(byte x) {
184+
return !RRuntime.isNA(x);
185+
}
186+
187+
@Override
188+
public boolean test(RComplex x) {
189+
return RRuntime.isFinite(x.getRealPart()) && RRuntime.isFinite(x.getImaginaryPart());
190+
}
158191
}
159192

160-
@Specialization(guards = "vec.isComplete()")
161-
protected RLogicalVector doComplete(RIntVector vec) {
162-
return doFunConstant(vec, RRuntime.LOGICAL_TRUE);
193+
public IsFinite() {
194+
super(FinitePredicates.INSTANCE);
163195
}
164196

165-
@Specialization(guards = "vec.isComplete()")
166-
protected RLogicalVector doComplete(RLogicalVector vec) {
167-
return doFunConstant(vec, RRuntime.LOGICAL_TRUE);
197+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
198+
protected RLogicalVector doIsFinite(RDoubleVector vec,
199+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
200+
return doFunDouble(dataLib, vec.getData(), vec);
168201
}
169202

170-
@Specialization(replaces = "doComplete")
171-
protected RLogicalVector doIsFinite(RIntVector vec) {
172-
return doFunInt(vec, value -> !RRuntime.isNA(value));
203+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
204+
protected RLogicalVector doIsFinite(RIntVector vec,
205+
@Cached("createBinaryProfile()") ConditionProfile isCompleteProfile,
206+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
207+
final Object vecData = vec.getData();
208+
if (isCompleteProfile.profile(dataLib.isComplete(vecData))) {
209+
return doFunConstant(dataLib, vecData, vec, RRuntime.LOGICAL_TRUE);
210+
} else {
211+
return doFunInt(dataLib, vecData, vec);
212+
}
173213
}
174214

175-
@Specialization(replaces = "doComplete")
176-
protected RLogicalVector doIsFinite(RLogicalVector vec) {
177-
return doFunLogical(vec, value -> !RRuntime.isNA(value));
215+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
216+
protected RLogicalVector doIsFinite(RLogicalVector vec,
217+
@Cached("createBinaryProfile()") ConditionProfile isCompleteProfile,
218+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
219+
final Object vecData = vec.getData();
220+
if (isCompleteProfile.profile(dataLib.isComplete(vecData))) {
221+
return doFunConstant(dataLib, vecData, vec, RRuntime.LOGICAL_TRUE);
222+
} else {
223+
return doFunLogical(dataLib, vec.getData(), vec);
224+
}
178225
}
179226

180-
@Specialization
181-
protected RLogicalVector doIsFinite(RAbstractComplexVector vec) {
182-
return doFunComplex(vec, value -> RRuntime.isFinite(value.getRealPart()) && RRuntime.isFinite(value.getImaginaryPart()));
227+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
228+
protected RLogicalVector doIsFinite(RAbstractComplexVector vec,
229+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
230+
return doFunComplex(dataLib, vec.getData(), vec);
183231
}
184232
}
185233

@@ -190,24 +238,46 @@ public abstract static class IsInfinite extends Adapter {
190238
Casts.noCasts(IsInfinite.class);
191239
}
192240

193-
@Specialization
194-
protected RLogicalVector doIsInfinite(RDoubleVector vec) {
195-
return doFunDouble(vec, Double::isInfinite);
241+
private static final class InfinitePredicates extends Predicates {
242+
private static final InfinitePredicates INSTANCE = new InfinitePredicates();
243+
244+
@Override
245+
public boolean test(double x) {
246+
return Double.isInfinite(x);
247+
}
248+
249+
@Override
250+
public boolean test(RComplex x) {
251+
return Double.isInfinite(x.getRealPart()) || Double.isInfinite(x.getImaginaryPart());
252+
}
196253
}
197254

198-
@Specialization
199-
protected RLogicalVector doComplete(RIntVector vec) {
200-
return doFunConstant(vec, RRuntime.LOGICAL_FALSE);
255+
public IsInfinite() {
256+
super(InfinitePredicates.INSTANCE);
201257
}
202258

203-
@Specialization
204-
protected RLogicalVector doComplete(RLogicalVector vec) {
205-
return doFunConstant(vec, RRuntime.LOGICAL_FALSE);
259+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
260+
protected RLogicalVector doIsInfinite(RDoubleVector vec,
261+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
262+
return doFunDouble(dataLib, vec.getData(), vec);
206263
}
207264

208-
@Specialization
209-
protected RLogicalVector doIsInfinite(RAbstractComplexVector vec) {
210-
return doFunComplex(vec, value -> Double.isInfinite(value.getRealPart()) || Double.isInfinite(value.getImaginaryPart()));
265+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
266+
protected RLogicalVector doComplete(RIntVector vec,
267+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
268+
return doFunConstant(dataLib, vec.getData(), vec, RRuntime.LOGICAL_FALSE);
269+
}
270+
271+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
272+
protected RLogicalVector doComplete(RLogicalVector vec,
273+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
274+
return doFunConstant(dataLib, vec.getData(), vec, RRuntime.LOGICAL_FALSE);
275+
}
276+
277+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
278+
protected RLogicalVector doIsInfinite(RAbstractComplexVector vec,
279+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
280+
return doFunComplex(dataLib, vec.getData(), vec);
211281
}
212282
}
213283

@@ -218,28 +288,46 @@ public abstract static class IsNaN extends Adapter {
218288
Casts.noCasts(IsNaN.class);
219289
}
220290

221-
private static boolean isNaN(double value) {
222-
return Double.isNaN(value) && !RRuntime.isNA(value);
291+
private static final class NaNPredicates extends Predicates {
292+
private static final NaNPredicates INSTANCE = new NaNPredicates();
293+
294+
@Override
295+
public boolean test(double x) {
296+
return Double.isNaN(x) && !RRuntime.isNA(x);
297+
}
298+
299+
@Override
300+
public boolean test(RComplex x) {
301+
return test(x.getRealPart()) || test(x.getImaginaryPart());
302+
}
223303
}
224304

225-
@Specialization
226-
protected RLogicalVector doIsNan(RDoubleVector vec) {
227-
return doFunDouble(vec, IsNaN::isNaN);
305+
public IsNaN() {
306+
super(NaNPredicates.INSTANCE);
228307
}
229308

230-
@Specialization
231-
protected RLogicalVector doIsNan(RIntVector vec) {
232-
return doFunConstant(vec, RRuntime.LOGICAL_FALSE);
309+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
310+
protected RLogicalVector doIsNan(RDoubleVector vec,
311+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
312+
return doFunDouble(dataLib, vec.getData(), vec);
233313
}
234314

235-
@Specialization
236-
protected RLogicalVector doIsNan(RLogicalVector vec) {
237-
return doFunConstant(vec, RRuntime.LOGICAL_FALSE);
315+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
316+
protected RLogicalVector doIsNan(RIntVector vec,
317+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
318+
return doFunConstant(dataLib, vec.getData(), vec, RRuntime.LOGICAL_FALSE);
238319
}
239320

240-
@Specialization
241-
protected RLogicalVector doIsNan(RAbstractComplexVector vec) {
242-
return doFunComplex(vec, value -> isNaN(value.getRealPart()) || isNaN(value.getImaginaryPart()));
321+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
322+
protected RLogicalVector doIsNan(RLogicalVector vec,
323+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
324+
return doFunConstant(dataLib, vec.getData(), vec, RRuntime.LOGICAL_FALSE);
325+
}
326+
327+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
328+
protected RLogicalVector doIsNan(RAbstractComplexVector vec,
329+
@CachedLibrary("vec.getData()") VectorDataLibrary dataLib) {
330+
return doFunComplex(dataLib, vec.getData(), vec);
243331
}
244332
}
245333
}

0 commit comments

Comments
 (0)