From 794dcdab8b30127928630122669c3937517de93e Mon Sep 17 00:00:00 2001
From: Hao Li <1127478+lihaosky@users.noreply.github.com>
Date: Tue, 2 Dec 2025 09:23:59 -0800
Subject: [PATCH] [FLINK-37902] batch support for ml_predict

---
 .../strategies/MLPredictTypeStrategy.java     |   2 +-
 .../BatchExecMLPredictTableFunction.java      | 101 ++
 .../CommonExecMLPredictTableFunction.java     | 246 +++++
 .../StreamExecMLPredictTableFunction.java     | 211 +---
 .../BatchPhysicalMLPredictTableFunction.java  |  70 ++
 .../CommonPhysicalMLPredictTableFunction.java | 209 ++++
 .../StreamPhysicalMLPredictTableFunction.java | 168 +---
 ...treamPhysicalProcessTableFunctionRule.java |   3 +-
 .../PhysicalMLPredictTableFunctionRule.java}  |  58 +-
 .../plan/utils/ExecNodeMetadataUtil.java      |   2 +
 .../plan/rules/FlinkBatchRuleSets.scala       |   3 +-
 .../plan/rules/FlinkStreamRuleSets.scala      |   6 +-
 .../batch/sql/MLPredictTableFunctionTest.java |  32 +
 .../MLPredictTableFunctionTestBase.java       | 549 ++++++++++
 .../exec/batch/MLPredictBatchRestoreTest.java |  72 ++
 .../exec/testutils/BatchRestoreTestBase.java  |  18 +-
 .../sql/MLPredictTableFunctionTest.java       | 451 +--------
 .../batch/table/AsyncMLPredictITCase.java     |  53 +
 .../runtime/batch/table/MLPredictITCase.java  |  42 +
 .../stream/table/AsyncMLPredictITCase.java    | 275 +----
 .../runtime/stream/table/MLPredictITCase.java | 179 +---
 .../runtime/utils/MLPredictITCaseBase.java    | 258 +++++
 .../batch/sql/MLPredictTableFunctionTest.xml  | 939 ++++++++++++++++++
 .../plan/async-unordered-ml-predict.json      | 134 +++
 .../sync-ml-predict-with-runtime-options.json | 132 +++
 .../sync-ml-predict/plan/sync-ml-predict.json | 130 +++
 26 files changed, 3096 insertions(+), 1247 deletions(-)
 create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecMLPredictTableFunction.java
 create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMLPredictTableFunction.java
 create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalMLPredictTableFunction.java
 create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalMLPredictTableFunction.java
 rename flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/{nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java => rules/physical/common/PhysicalMLPredictTableFunctionRule.java} (77%)
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.java
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/common/MLPredictTableFunctionTestBase.java
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/batch/MLPredictBatchRestoreTest.java
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/AsyncMLPredictITCase.java
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/MLPredictITCase.java
 create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/MLPredictITCaseBase.java
 create mode 100644 flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.xml
 create
mode 100644 flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/async-unordered-ml-predict/plan/async-unordered-ml-predict.json create mode 100644 flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict-with-runtime-options/plan/sync-ml-predict-with-runtime-options.json create mode 100644 flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict/plan/sync-ml-predict.json diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MLPredictTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MLPredictTypeStrategy.java index 4ea8a7b9a8a4a..5867d55378059 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MLPredictTypeStrategy.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MLPredictTypeStrategy.java @@ -157,7 +157,7 @@ private static Optional> inferMLPredictInputTypes( return Optional.empty(); } - // Config map validation is done in StreamPhysicalMLPredictTableFunctionRule since + // Config map validation is done in PhysicalMLPredictTableFunctionRule since // we are not able to get map literal here. return Optional.of(callContext.getArgumentDataTypes()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecMLPredictTableFunction.java new file mode 100644 index 0000000000000..2cac964a18203 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecMLPredictTableFunction.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
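For orientation before the new exec node below: the user-facing feature this patch enables is ML_PREDICT in batch pipelines. A minimal usage sketch, assuming a hypothetical source table `features` with a column `f0` and a registered model `classifier` — both names are placeholders, not part of this patch:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class BatchMLPredictExample {
    public static void main(String[] args) {
        // Batch mode: the BatchExec/BatchPhysical nodes added by this patch
        // translate the resulting plan.
        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());
        // 'features' and 'classifier' must have been created beforehand (DDL omitted).
        tEnv.executeSql(
                        "SELECT * FROM ML_PREDICT(TABLE features, MODEL classifier, DESCRIPTOR(f0))")
                .print();
    }
}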
+ */ + +package org.apache.flink.table.planner.plan.nodes.exec.batch; + +import org.apache.flink.FlinkVersion; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata; +import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; +import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator; +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecMLPredictTableFunction; +import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec; +import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +/** Batch {@link ExecNode} for {@code ML_PREDICT}. */ +@ExecNodeMetadata( + name = "batch-exec-ml-predict-table-function", + version = 1, + consumedOptions = { + "table.exec.async-ml-predict.max-concurrent-operations", + "table.exec.async-ml-predict.timeout", + "table.exec.async-ml-predict.output-mode" + }, + producedTransformations = CommonExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION, + minPlanVersion = FlinkVersion.v2_3, + minStateVersion = FlinkVersion.v2_3) +public class BatchExecMLPredictTableFunction extends CommonExecMLPredictTableFunction + implements SingleTransformationTranslator, BatchExecNode { + + public BatchExecMLPredictTableFunction( + ReadableConfig persistedConfig, + MLPredictSpec mlPredictSpec, + ModelSpec modelSpec, + @Nullable FunctionCallUtil.AsyncOptions asyncOptions, + InputProperty inputProperty, + RowType outputType, + String description) { + this( + ExecNodeContext.newNodeId(), + ExecNodeContext.newContext(BatchExecMLPredictTableFunction.class), + ExecNodeContext.newPersistedConfig( + BatchExecMLPredictTableFunction.class, persistedConfig), + mlPredictSpec, + modelSpec, + asyncOptions, + Collections.singletonList(inputProperty), + outputType, + description); + } + + @JsonCreator + public BatchExecMLPredictTableFunction( + @JsonProperty(FIELD_NAME_ID) int id, + @JsonProperty(FIELD_NAME_TYPE) ExecNodeContext context, + @JsonProperty(FIELD_NAME_CONFIGURATION) ReadableConfig persistedConfig, + @JsonProperty(FIELD_NAME_ML_PREDICT_SPEC) MLPredictSpec mlPredictSpec, + @JsonProperty(FIELD_NAME_MODEL_SPEC) ModelSpec modelSpec, + @JsonProperty(FIELD_NAME_ASYNC_OPTIONS) @Nullable + FunctionCallUtil.AsyncOptions asyncOptions, + @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List inputProperties, + @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType, + @JsonProperty(FIELD_NAME_DESCRIPTION) String description) { + super( + id, + context, + persistedConfig, + mlPredictSpec, + modelSpec, + asyncOptions, + inputProperties, + outputType, + description); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMLPredictTableFunction.java new file mode 100644 index 0000000000000..8a53c169f19c7 --- /dev/null +++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMLPredictTableFunction.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.common; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.PipelineOptions; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.operators.ProcessOperator; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.functions.PredictFunction; +import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; +import org.apache.flink.table.ml.ModelProvider; +import org.apache.flink.table.ml.PredictRuntimeProvider; +import org.apache.flink.table.planner.calcite.FlinkContext; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator; +import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext; +import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; +import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec; +import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; +import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; +import org.apache.flink.table.runtime.collector.ListenableCollector; +import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext; +import org.apache.flink.table.runtime.generated.GeneratedCollector; +import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner; +import org.apache.flink.table.runtime.operators.ml.MLPredictRunner; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import 
org.apache.flink.table.types.logical.RowType; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.List; + +/** Common ExecNode for {@code ML_PREDICT}. */ +public abstract class CommonExecMLPredictTableFunction extends ExecNodeBase { + + public static final String ML_PREDICT_TRANSFORMATION = "ml-predict-table-function"; + + protected static final String FIELD_NAME_ML_PREDICT_SPEC = "mlPredictSpec"; + protected static final String FIELD_NAME_MODEL_SPEC = "modelSpec"; + protected static final String FIELD_NAME_ASYNC_OPTIONS = "asyncOptions"; + + @JsonProperty(FIELD_NAME_ML_PREDICT_SPEC) + protected final MLPredictSpec mlPredictSpec; + + @JsonProperty(FIELD_NAME_MODEL_SPEC) + protected final ModelSpec modelSpec; + + @JsonProperty(FIELD_NAME_ASYNC_OPTIONS) + protected final @Nullable FunctionCallUtil.AsyncOptions asyncOptions; + + protected CommonExecMLPredictTableFunction( + int id, + ExecNodeContext context, + ReadableConfig persistedConfig, + MLPredictSpec mlPredictSpec, + ModelSpec modelSpec, + @Nullable FunctionCallUtil.AsyncOptions asyncOptions, + List inputProperties, + RowType outputType, + String description) { + super(id, context, persistedConfig, inputProperties, outputType, description); + this.mlPredictSpec = mlPredictSpec; + this.modelSpec = modelSpec; + this.asyncOptions = asyncOptions; + } + + @Override + protected Transformation translateToPlanInternal( + PlannerBase planner, ExecNodeConfig config) { + Transformation inputTransformation = + (Transformation) getInputEdges().get(0).translateToPlan(planner); + + ModelProvider provider = modelSpec.getModelProvider(planner.getFlinkContext()); + boolean async = asyncOptions != null; + UserDefinedFunction predictFunction = findModelFunction(provider, async); + FlinkContext context = planner.getFlinkContext(); + DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); + + RowType inputType = (RowType) getInputEdges().get(0).getOutputType(); + RowType modelOutputType = + (RowType) + modelSpec + .getContextResolvedModel() + .getResolvedModel() + .getResolvedOutputSchema() + .toPhysicalRowDataType() + .getLogicalType(); + return async + ? 
createAsyncModelPredict( + inputTransformation, + config, + planner.getFlinkContext().getClassLoader(), + dataTypeFactory, + inputType, + modelOutputType, + (RowType) getOutputType(), + (AsyncPredictFunction) predictFunction) + : createModelPredict( + inputTransformation, + config, + planner.getFlinkContext().getClassLoader(), + dataTypeFactory, + inputType, + modelOutputType, + (RowType) getOutputType(), + (PredictFunction) predictFunction); + } + + private Transformation createModelPredict( + Transformation inputTransformation, + ExecNodeConfig config, + ClassLoader classLoader, + DataTypeFactory dataTypeFactory, + RowType inputRowType, + RowType modelOutputType, + RowType resultRowType, + PredictFunction predictFunction) { + GeneratedFunction> generatedFetcher = + MLPredictCodeGenerator.generateSyncPredictFunction( + config, + classLoader, + dataTypeFactory, + inputRowType, + modelOutputType, + resultRowType, + mlPredictSpec.getFeatures(), + predictFunction, + modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(), + config.get(PipelineOptions.OBJECT_REUSE)); + GeneratedCollector> generatedCollector = + MLPredictCodeGenerator.generateCollector( + new CodeGeneratorContext(config, classLoader), + inputRowType, + modelOutputType, + (RowType) getOutputType()); + MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector); + SimpleOperatorFactory operatorFactory = + SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner)); + return ExecNodeUtil.createOneInputTransformation( + inputTransformation, + createTransformationMeta(ML_PREDICT_TRANSFORMATION, config), + operatorFactory, + InternalTypeInfo.of(getOutputType()), + inputTransformation.getParallelism(), + false); + } + + @SuppressWarnings("unchecked") + private Transformation createAsyncModelPredict( + Transformation inputTransformation, + ExecNodeConfig config, + ClassLoader classLoader, + DataTypeFactory dataTypeFactory, + RowType inputRowType, + RowType modelOutputType, + RowType resultRowType, + AsyncPredictFunction asyncPredictFunction) { + FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType> + generatedFuncWithType = + MLPredictCodeGenerator.generateAsyncPredictFunction( + config, + classLoader, + dataTypeFactory, + inputRowType, + modelOutputType, + resultRowType, + mlPredictSpec.getFeatures(), + asyncPredictFunction, + modelSpec + .getContextResolvedModel() + .getIdentifier() + .asSummaryString()); + AsyncFunction asyncFunc = + new AsyncMLPredictRunner( + (GeneratedFunction) generatedFuncWithType.tableFunc(), + Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity); + return ExecNodeUtil.createOneInputTransformation( + inputTransformation, + createTransformationMeta(ML_PREDICT_TRANSFORMATION, config), + new AsyncWaitOperatorFactory<>( + asyncFunc, + asyncOptions.asyncTimeout, + asyncOptions.asyncBufferCapacity, + asyncOptions.asyncOutputMode), + InternalTypeInfo.of(getOutputType()), + inputTransformation.getParallelism(), + false); + } + + private UserDefinedFunction findModelFunction(ModelProvider provider, boolean async) { + ModelPredictRuntimeProviderContext context = + new ModelPredictRuntimeProviderContext( + modelSpec.getContextResolvedModel().getResolvedModel(), + Configuration.fromMap(mlPredictSpec.getRuntimeConfig())); + if (async) { + if (provider instanceof AsyncPredictRuntimeProvider) { + return ((AsyncPredictRuntimeProvider) provider).createAsyncPredictFunction(context); + } + } else { + if (provider instanceof PredictRuntimeProvider) { + return 
((PredictRuntimeProvider) provider).createPredictFunction(context); + } + } + + throw new TableException( + "Required " + + (async ? "async" : "sync") + + " model function by planner, but ModelProvider " + + "does not offer a valid model function."); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java index 6d1fe8a8cee39..29efe72312e71 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java @@ -19,49 +19,18 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; import org.apache.flink.FlinkVersion; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.dag.Transformation; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.configuration.PipelineOptions; import org.apache.flink.configuration.ReadableConfig; -import org.apache.flink.streaming.api.functions.async.AsyncFunction; -import org.apache.flink.streaming.api.operators.ProcessOperator; -import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; -import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory; -import org.apache.flink.table.api.TableException; -import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.data.RowData; -import org.apache.flink.table.functions.AsyncPredictFunction; -import org.apache.flink.table.functions.PredictFunction; -import org.apache.flink.table.functions.UserDefinedFunction; -import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; -import org.apache.flink.table.ml.ModelProvider; -import org.apache.flink.table.ml.PredictRuntimeProvider; -import org.apache.flink.table.planner.calcite.FlinkContext; -import org.apache.flink.table.planner.codegen.CodeGeneratorContext; -import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator; -import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator; -import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; -import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; -import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata; import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; import org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTranslator; +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecMLPredictTableFunction; import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec; import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; -import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; -import org.apache.flink.table.runtime.collector.ListenableCollector; -import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext; -import org.apache.flink.table.runtime.generated.GeneratedCollector; -import 
org.apache.flink.table.runtime.generated.GeneratedFunction; -import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner; -import org.apache.flink.table.runtime.operators.ml.MLPredictRunner; -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.util.Preconditions; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; @@ -80,27 +49,12 @@ "table.exec.async-ml-predict.timeout", "table.exec.async-ml-predict.output-mode" }, - producedTransformations = StreamExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION, + producedTransformations = CommonExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION, minPlanVersion = FlinkVersion.v2_1, minStateVersion = FlinkVersion.v2_1) -public class StreamExecMLPredictTableFunction extends ExecNodeBase +public class StreamExecMLPredictTableFunction extends CommonExecMLPredictTableFunction implements MultipleTransformationTranslator, StreamExecNode { - public static final String ML_PREDICT_TRANSFORMATION = "ml-predict-table-function"; - - public static final String FIELD_NAME_ML_PREDICT_SPEC = "mlPredictSpec"; - public static final String FIELD_NAME_MODEL_SPEC = "modelSpec"; - public static final String FIELD_NAME_ASYNC_OPTIONS = "asyncOptions"; - - @JsonProperty(FIELD_NAME_ML_PREDICT_SPEC) - private final MLPredictSpec mlPredictSpec; - - @JsonProperty(FIELD_NAME_MODEL_SPEC) - private final ModelSpec modelSpec; - - @JsonProperty(FIELD_NAME_ASYNC_OPTIONS) - private final @Nullable FunctionCallUtil.AsyncOptions asyncOptions; - public StreamExecMLPredictTableFunction( ReadableConfig persistedConfig, MLPredictSpec mlPredictSpec, @@ -133,154 +87,15 @@ public StreamExecMLPredictTableFunction( @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List inputProperties, @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType, @JsonProperty(FIELD_NAME_DESCRIPTION) String description) { - super(id, context, persistedConfig, inputProperties, outputType, description); - this.mlPredictSpec = mlPredictSpec; - this.modelSpec = modelSpec; - this.asyncOptions = asyncOptions; - } - - @Override - protected Transformation translateToPlanInternal( - PlannerBase planner, ExecNodeConfig config) { - Transformation inputTransformation = - (Transformation) getInputEdges().get(0).translateToPlan(planner); - - ModelProvider provider = modelSpec.getModelProvider(planner.getFlinkContext()); - boolean async = asyncOptions != null; - UserDefinedFunction predictFunction = findModelFunction(provider, async); - FlinkContext context = planner.getFlinkContext(); - DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); - - RowType inputType = (RowType) getInputEdges().get(0).getOutputType(); - RowType modelOutputType = - (RowType) - modelSpec - .getContextResolvedModel() - .getResolvedModel() - .getResolvedOutputSchema() - .toPhysicalRowDataType() - .getLogicalType(); - return async - ? 
createAsyncModelPredict( - inputTransformation, - config, - planner.getFlinkContext().getClassLoader(), - dataTypeFactory, - inputType, - modelOutputType, - (RowType) getOutputType(), - (AsyncPredictFunction) predictFunction) - : createModelPredict( - inputTransformation, - config, - planner.getFlinkContext().getClassLoader(), - dataTypeFactory, - inputType, - modelOutputType, - (RowType) getOutputType(), - (PredictFunction) predictFunction); - } - - private Transformation createModelPredict( - Transformation inputTransformation, - ExecNodeConfig config, - ClassLoader classLoader, - DataTypeFactory dataTypeFactory, - RowType inputRowType, - RowType modelOutputType, - RowType resultRowType, - PredictFunction predictFunction) { - GeneratedFunction> generatedFetcher = - MLPredictCodeGenerator.generateSyncPredictFunction( - config, - classLoader, - dataTypeFactory, - inputRowType, - modelOutputType, - resultRowType, - mlPredictSpec.getFeatures(), - predictFunction, - modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(), - config.get(PipelineOptions.OBJECT_REUSE)); - GeneratedCollector> generatedCollector = - MLPredictCodeGenerator.generateCollector( - new CodeGeneratorContext(config, classLoader), - inputRowType, - modelOutputType, - (RowType) getOutputType()); - MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector); - SimpleOperatorFactory operatorFactory = - SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner)); - return ExecNodeUtil.createOneInputTransformation( - inputTransformation, - createTransformationMeta(ML_PREDICT_TRANSFORMATION, config), - operatorFactory, - InternalTypeInfo.of(getOutputType()), - inputTransformation.getParallelism(), - false); - } - - @SuppressWarnings("unchecked") - private Transformation createAsyncModelPredict( - Transformation inputTransformation, - ExecNodeConfig config, - ClassLoader classLoader, - DataTypeFactory dataTypeFactory, - RowType inputRowType, - RowType modelOutputType, - RowType resultRowType, - AsyncPredictFunction asyncPredictFunction) { - FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType> - generatedFuncWithType = - MLPredictCodeGenerator.generateAsyncPredictFunction( - config, - classLoader, - dataTypeFactory, - inputRowType, - modelOutputType, - resultRowType, - mlPredictSpec.getFeatures(), - asyncPredictFunction, - modelSpec - .getContextResolvedModel() - .getIdentifier() - .asSummaryString()); - AsyncFunction asyncFunc = - new AsyncMLPredictRunner( - (GeneratedFunction) generatedFuncWithType.tableFunc(), - Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity); - return ExecNodeUtil.createOneInputTransformation( - inputTransformation, - createTransformationMeta(ML_PREDICT_TRANSFORMATION, config), - new AsyncWaitOperatorFactory<>( - asyncFunc, - asyncOptions.asyncTimeout, - asyncOptions.asyncBufferCapacity, - asyncOptions.asyncOutputMode), - InternalTypeInfo.of(getOutputType()), - inputTransformation.getParallelism(), - false); - } - - private UserDefinedFunction findModelFunction(ModelProvider provider, boolean async) { - ModelPredictRuntimeProviderContext context = - new ModelPredictRuntimeProviderContext( - modelSpec.getContextResolvedModel().getResolvedModel(), - Configuration.fromMap(mlPredictSpec.getRuntimeConfig())); - if (async) { - if (provider instanceof AsyncPredictRuntimeProvider) { - return ((AsyncPredictRuntimeProvider) provider).createAsyncPredictFunction(context); - } - } else { - if (provider instanceof PredictRuntimeProvider) { - return 
((PredictRuntimeProvider) provider).createPredictFunction(context); - } - } - - throw new TableException( - "Required " - + (async ? "async" : "sync") - + " model function by planner, but ModelProvider " - + "does not offer a valid model function."); + super( + id, + context, + persistedConfig, + mlPredictSpec, + modelSpec, + asyncOptions, + inputProperties, + outputType, + description); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalMLPredictTableFunction.java new file mode 100644 index 0000000000000..a0822bc071a92 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalMLPredictTableFunction.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.physical.batch; + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.calcite.RexModelCall; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecMLPredictTableFunction; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalMLPredictTableFunction; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; + +import java.util.List; +import java.util.Map; + +/** Batch physical RelNode for ml predict table function. 
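The async path above is driven by the options declared in the exec-node metadata. A sketch of tuning them, with illustrative values only — the option keys come from `consumedOptions` earlier in this patch, and the MAP key 'async' is the one read through MLPredictRuntimeConfigOptions; table and model names remain placeholders:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class AsyncMLPredictConfigExample {
    public static void main(String[] args) {
        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());
        // Job-level async knobs; these feed AsyncWaitOperatorFactory via AsyncOptions.
        tEnv.getConfig().set("table.exec.async-ml-predict.max-concurrent-operations", "100");
        tEnv.getConfig().set("table.exec.async-ml-predict.timeout", "3 min");
        tEnv.getConfig().set("table.exec.async-ml-predict.output-mode", "ALLOW_UNORDERED");
        // Per-query override through the optional config map (the 4th operand).
        tEnv.executeSql(
                "SELECT * FROM ML_PREDICT(TABLE features, MODEL classifier, DESCRIPTOR(f0), "
                        + "MAP['async', 'true'])");
    }
}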
*/ +public class BatchPhysicalMLPredictTableFunction extends CommonPhysicalMLPredictTableFunction + implements BatchPhysicalRel { + + public BatchPhysicalMLPredictTableFunction( + RelOptCluster cluster, + RelTraitSet traits, + RelNode inputRel, + FlinkLogicalTableFunctionScan scan, + RelDataType outputRowType, + Map runtimeConfig) { + super(cluster, traits, inputRel, scan, outputRowType, runtimeConfig); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new BatchPhysicalMLPredictTableFunction( + getCluster(), traitSet, inputs.get(0), scan, getRowType(), runtimeConfig); + } + + @Override + public ExecNode translateToExecNode() { + RexModelCall modelCall = extractOperand(operand -> operand instanceof RexModelCall); + return new BatchExecMLPredictTableFunction( + ShortcutUtils.unwrapTableConfig(this), + buildMLPredictSpec(runtimeConfig), + buildModelSpec(modelCall), + buildAsyncOptions(modelCall, runtimeConfig), + InputProperty.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType()), + getRelDetailedDescription()); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalMLPredictTableFunction.java new file mode 100644 index 0000000000000..81e127ba15088 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalMLPredictTableFunction.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
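The common physical node defined next decomposes the ML_PREDICT call by operand type via extractOperand(...). Schematically, the operands line up as follows (a sketch; table, model, and column names are placeholders):

public class MLPredictCallShape {
    // Schematic only: each operand below is matched by one of the
    // extractOperand(...) predicates in the class that follows.
    static final String CALL =
            "ML_PREDICT("
                    + "TABLE features, "        // RexTableArgCall: input row type
                    + "MODEL classifier, "      // RexModelCall: ModelSpec + ModelProvider
                    + "DESCRIPTOR(f0), "        // SqlDescriptorOperator: feature FieldRefs
                    + "MAP['async', 'false'])"; // optional runtime config (4th operand)
}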
+ */ + +package org.apache.flink.table.planner.plan.nodes.physical.common; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions; +import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; +import org.apache.flink.table.ml.ModelProvider; +import org.apache.flink.table.ml.PredictRuntimeProvider; +import org.apache.flink.table.planner.calcite.RexModelCall; +import org.apache.flink.table.planner.calcite.RexTableArgCall; +import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec; +import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.nodes.physical.FlinkPhysicalRel; +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam; +import org.apache.flink.table.planner.plan.utils.MLPredictUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlDescriptorOperator; + +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** Common physical node for {@code ML_PREDICT}. 
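Before the class body, a compact paraphrase of the sync/async resolution that isAsyncMLPredict implements further below — illustrative only, not part of the patch: prefer async when the provider offers it, honor an explicit 'async' request, and fail when the provider cannot satisfy the requested mode.

public final class AsyncModeResolution {
    /** Paraphrase of isAsyncMLPredict; requested == null means "no preference". */
    static boolean resolveAsync(boolean supportsSync, boolean supportsAsync, Boolean requested) {
        if (!supportsSync && !supportsAsync) {
            throw new IllegalStateException("Unknown model provider.");
        }
        if (requested == null) {
            return supportsAsync; // prefer async when available
        }
        if (requested && !supportsAsync) {
            throw new IllegalStateException("Async mode required but not supported.");
        }
        if (!requested && !supportsSync) {
            throw new IllegalStateException("Sync mode required but not supported.");
        }
        return requested;
    }
}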
*/ +public abstract class CommonPhysicalMLPredictTableFunction extends SingleRel + implements FlinkPhysicalRel { + + protected final RelDataType outputRowType; + protected final FlinkLogicalTableFunctionScan scan; + protected final Map runtimeConfig; + + protected CommonPhysicalMLPredictTableFunction( + RelOptCluster cluster, + RelTraitSet traits, + RelNode inputRel, + FlinkLogicalTableFunctionScan scan, + RelDataType outputRowType, + Map runtimeConfig) { + super(cluster, traits, inputRel); + this.scan = scan; + this.outputRowType = outputRowType; + this.runtimeConfig = runtimeConfig; + } + + @Override + protected RelDataType deriveRowType() { + return outputRowType; + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw) + .item("invocation", scan.getCall()) + .item("rowType", getRowType()); + } + + public RexNode getMLPredictCall() { + return scan.getCall(); + } + + protected MLPredictSpec buildMLPredictSpec(Map runtimeConfig) { + RexTableArgCall tableCall = extractOperand(operand -> operand instanceof RexTableArgCall); + RexCall descriptorCall = + extractOperand( + operand -> + operand instanceof RexCall + && ((RexCall) operand).getOperator() + instanceof SqlDescriptorOperator); + Map column2Index = new HashMap<>(); + List fieldNames = tableCall.getType().getFieldNames(); + for (int i = 0; i < fieldNames.size(); i++) { + column2Index.put(fieldNames.get(i), i); + } + List features = + descriptorCall.getOperands().stream() + .map( + operand -> { + if (operand instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) operand; + String fieldName = RexLiteral.stringValue(literal); + Integer index = column2Index.get(fieldName); + if (index == null) { + throw new TableException( + String.format( + "Field %s is not found in input schema: %s.", + fieldName, tableCall.getType())); + } + return new FunctionCallUtil.FieldRef(index); + } else { + throw new TableException( + String.format( + "Unknown operand for descriptor operator: %s.", + operand)); + } + }) + .collect(Collectors.toList()); + return new MLPredictSpec(features, runtimeConfig); + } + + protected ModelSpec buildModelSpec(RexModelCall modelCall) { + ModelSpec modelSpec = new ModelSpec(modelCall.getContextResolvedModel()); + modelSpec.setModelProvider(modelCall.getModelProvider()); + return modelSpec; + } + + protected @Nullable FunctionCallUtil.AsyncOptions buildAsyncOptions( + RexModelCall modelCall, Map runtimeConfig) { + boolean isAsyncEnabled = isAsyncMLPredict(modelCall.getModelProvider(), runtimeConfig); + if (isAsyncEnabled) { + return MLPredictUtil.getMergedMLPredictAsyncOptions( + runtimeConfig, ShortcutUtils.unwrapTableConfig(getCluster())); + } else { + return null; + } + } + + @SuppressWarnings("unchecked") + protected Optional extractOptionalOperand(Predicate predicate) { + return (Optional) + ((RexCall) scan.getCall()).getOperands().stream().filter(predicate).findFirst(); + } + + @SuppressWarnings("unchecked") + protected T extractOperand(Predicate predicate) { + return (T) + extractOptionalOperand(predicate) + .orElseThrow( + () -> + new TableException( + String.format( + "MLPredict doesn't contain specified operand: %s", + scan.getCall().toString()))); + } + + protected boolean isAsyncMLPredict(ModelProvider provider, Map runtimeConfig) { + boolean syncFound = false; + boolean asyncFound = false; + Optional requiredMode = + Configuration.fromMap(runtimeConfig) + .getOptional(MLPredictRuntimeConfigOptions.ASYNC); + + if (provider instanceof PredictRuntimeProvider) { + 
syncFound = true; + } + if (provider instanceof AsyncPredictRuntimeProvider) { + asyncFound = true; + } + + if (!syncFound && !asyncFound) { + throw new TableException( + String.format( + "Unknown model provider found: %s.", provider.getClass().getName())); + } + + if (requiredMode.isEmpty()) { + return asyncFound; + } else if (requiredMode.get()) { + if (!asyncFound) { + throw new TableException( + String.format( + "Require async mode, but model provider %s doesn't support async mode.", + provider.getClass().getName())); + } + return true; + } else { + if (!syncFound) { + throw new TableException( + String.format( + "Require sync mode, but model provider %s doesn't support sync mode.", + provider.getClass().getName())); + } + return false; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java index 0016ec9893d4a..99594ea578e49 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java @@ -18,52 +18,26 @@ package org.apache.flink.table.planner.plan.nodes.physical.stream; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.table.api.TableException; -import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions; -import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; -import org.apache.flink.table.ml.ModelProvider; -import org.apache.flink.table.ml.PredictRuntimeProvider; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexModelCall; -import org.apache.flink.table.planner.calcite.RexTableArgCall; import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; -import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec; -import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; -import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; -import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam; -import org.apache.flink.table.planner.plan.utils.MLPredictUtil; +import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalMLPredictTableFunction; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.RelWriter; -import org.apache.calcite.rel.SingleRel; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexLiteral; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlDescriptorOperator; -import javax.annotation.Nullable; - -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.function.Predicate; -import java.util.stream.Collectors; /** Stream physical 
RelNode for ml predict table function. */ -public class StreamPhysicalMLPredictTableFunction extends SingleRel implements StreamPhysicalRel { - - private final RelDataType outputRowType; - private final FlinkLogicalTableFunctionScan scan; - private final Map runtimeConfig; +public class StreamPhysicalMLPredictTableFunction extends CommonPhysicalMLPredictTableFunction + implements StreamPhysicalRel { public StreamPhysicalMLPredictTableFunction( RelOptCluster cluster, @@ -72,10 +46,7 @@ public StreamPhysicalMLPredictTableFunction( FlinkLogicalTableFunctionScan scan, RelDataType outputRowType, Map runtimeConfig) { - super(cluster, traits, inputRel); - this.scan = scan; - this.outputRowType = outputRowType; - this.runtimeConfig = runtimeConfig; + super(cluster, traits, inputRel, scan, outputRowType, runtimeConfig); } @Override @@ -101,135 +72,4 @@ public ExecNode translateToExecNode() { FlinkTypeFactory.toLogicalRowType(getRowType()), getRelDetailedDescription()); } - - @Override - protected RelDataType deriveRowType() { - return outputRowType; - } - - @Override - public RelWriter explainTerms(RelWriter pw) { - return super.explainTerms(pw) - .item("invocation", scan.getCall()) - .item("rowType", getRowType()); - } - - public RexNode getMLPredictCall() { - return scan.getCall(); - } - - private MLPredictSpec buildMLPredictSpec(Map runtimeConfig) { - RexTableArgCall tableCall = extractOperand(operand -> operand instanceof RexTableArgCall); - RexCall descriptorCall = - extractOperand( - operand -> - operand instanceof RexCall - && ((RexCall) operand).getOperator() - instanceof SqlDescriptorOperator); - Map column2Index = new HashMap<>(); - List fieldNames = tableCall.getType().getFieldNames(); - for (int i = 0; i < fieldNames.size(); i++) { - column2Index.put(fieldNames.get(i), i); - } - List features = - descriptorCall.getOperands().stream() - .map( - operand -> { - if (operand instanceof RexLiteral) { - RexLiteral literal = (RexLiteral) operand; - String fieldName = RexLiteral.stringValue(literal); - Integer index = column2Index.get(fieldName); - if (index == null) { - throw new TableException( - String.format( - "Field %s is not found in input schema: %s.", - fieldName, tableCall.getType())); - } - return new FunctionCallUtil.FieldRef(index); - } else { - throw new TableException( - String.format( - "Unknown operand for descriptor operator: %s.", - operand)); - } - }) - .collect(Collectors.toList()); - return new MLPredictSpec(features, runtimeConfig); - } - - private ModelSpec buildModelSpec(RexModelCall modelCall) { - ModelSpec modelSpec = new ModelSpec(modelCall.getContextResolvedModel()); - modelSpec.setModelProvider(modelCall.getModelProvider()); - return modelSpec; - } - - private @Nullable FunctionCallUtil.AsyncOptions buildAsyncOptions( - RexModelCall modelCall, Map runtimeConfig) { - boolean isAsyncEnabled = isAsyncMLPredict(modelCall.getModelProvider(), runtimeConfig); - if (isAsyncEnabled) { - return MLPredictUtil.getMergedMLPredictAsyncOptions( - runtimeConfig, ShortcutUtils.unwrapTableConfig(getCluster())); - } else { - return null; - } - } - - @SuppressWarnings("unchecked") - private Optional extractOptionalOperand(Predicate predicate) { - return (Optional) - ((RexCall) scan.getCall()).getOperands().stream().filter(predicate).findFirst(); - } - - @SuppressWarnings("unchecked") - private T extractOperand(Predicate predicate) { - return (T) - extractOptionalOperand(predicate) - .orElseThrow( - () -> - new TableException( - String.format( - "MLPredict doesn't contain specified 
operand: %s", - scan.getCall().toString()))); - } - - private boolean isAsyncMLPredict(ModelProvider provider, Map runtimeConfig) { - boolean syncFound = false; - boolean asyncFound = false; - Optional requiredMode = - Configuration.fromMap(runtimeConfig) - .getOptional(MLPredictRuntimeConfigOptions.ASYNC); - - if (provider instanceof PredictRuntimeProvider) { - syncFound = true; - } - if (provider instanceof AsyncPredictRuntimeProvider) { - asyncFound = true; - } - - if (!syncFound && !asyncFound) { - throw new TableException( - String.format( - "Unknown model provider found: %s.", provider.getClass().getName())); - } - - if (requiredMode.isEmpty()) { - return asyncFound; - } else if (requiredMode.get()) { - if (!asyncFound) { - throw new TableException( - String.format( - "Require async mode, but model provider %s doesn't support async mode.", - provider.getClass().getName())); - } - return true; - } else { - if (!syncFound) { - throw new TableException( - String.format( - "Require sync mode, but model provider %s doesn't support sync mode.", - provider.getClass().getName())); - } - return false; - } - } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunctionRule.java index 7b709db6e3bbb..5c3f90103616f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunctionRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunctionRule.java @@ -24,6 +24,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.rules.physical.common.PhysicalMLPredictTableFunctionRule; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.utils.ShortcutUtils; @@ -70,7 +71,7 @@ public boolean matches(RelOptRuleCall call) { final RexCall rexCall = (RexCall) scan.getCall(); final FunctionDefinition definition = ShortcutUtils.unwrapFunctionDefinition(rexCall); return definition != null - && !StreamPhysicalMLPredictTableFunctionRule.isMLPredictFunction(definition) + && !PhysicalMLPredictTableFunctionRule.isMLPredictFunction(definition) && definition.getKind() == FunctionKind.PROCESS_TABLE; } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/PhysicalMLPredictTableFunctionRule.java similarity index 77% rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/PhysicalMLPredictTableFunctionRule.java index 34b84399d1ed9..7d5a61ee41977 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java +++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/PhysicalMLPredictTableFunctionRule.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.planner.plan.nodes.physical.stream; +package org.apache.flink.table.planner.plan.rules.physical.common; import org.apache.flink.configuration.Configuration; import org.apache.flink.table.api.ValidationException; @@ -27,8 +27,11 @@ import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; import org.apache.flink.table.ml.PredictRuntimeProvider; import org.apache.flink.table.planner.calcite.RexModelCall; +import org.apache.flink.table.planner.plan.nodes.FlinkConvention; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalMLPredictTableFunction; +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMLPredictTableFunction; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.types.logical.LogicalType; @@ -57,22 +60,30 @@ /** * Rule to convert a {@link FlinkLogicalTableFunctionScan} with ml_predict call into a {@link - * StreamPhysicalMLPredictTableFunction}. + * BatchPhysicalMLPredictTableFunction} or {@link StreamPhysicalMLPredictTableFunction}. */ -public class StreamPhysicalMLPredictTableFunctionRule extends ConverterRule { +public class PhysicalMLPredictTableFunctionRule extends ConverterRule { private static final String CONFIG_ERROR_MESSAGE = "Config parameter of ML_PREDICT function should be a MAP data type consisting String literals."; - public static final StreamPhysicalMLPredictTableFunctionRule INSTANCE = - new StreamPhysicalMLPredictTableFunctionRule( + public static final PhysicalMLPredictTableFunctionRule BATCH_INSTANCE = + new PhysicalMLPredictTableFunctionRule( + Config.INSTANCE.withConversion( + FlinkLogicalTableFunctionScan.class, + FlinkConventions.LOGICAL(), + FlinkConventions.BATCH_PHYSICAL(), + "PhysicalMLPredictTableFunctionRule:Batch")); + + public static final PhysicalMLPredictTableFunctionRule STREAM_INSTANCE = + new PhysicalMLPredictTableFunctionRule( Config.INSTANCE.withConversion( FlinkLogicalTableFunctionScan.class, FlinkConventions.LOGICAL(), FlinkConventions.STREAM_PHYSICAL(), - "StreamPhysicalModelTableFunctionRule")); + "PhysicalMLPredictTableFunctionRule:Stream")); - private StreamPhysicalMLPredictTableFunctionRule(Config config) { + private PhysicalMLPredictTableFunctionRule(Config config) { super(config); } @@ -93,23 +104,34 @@ public boolean matches(RelOptRuleCall call) { @Override public @Nullable RelNode convert(RelNode rel) { final FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) rel; - final RelNode newInput = - RelOptRule.convert(scan.getInput(0), FlinkConventions.STREAM_PHYSICAL()); + final FlinkConvention convention = (FlinkConvention) getOutConvention(); + final RelNode newInput = RelOptRule.convert(scan.getInput(0), convention); - final RelTraitSet providedTraitSet = - rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()); + final RelTraitSet providedTraitSet = rel.getTraitSet().replace(convention); // Extract and validate configuration from the 4th operand if present final RexCall rexCall = (RexCall) scan.getCall(); final Map runtimeConfig = buildRuntimeConfig(rexCall); - return new 
StreamPhysicalMLPredictTableFunction( - scan.getCluster(), - providedTraitSet, - newInput, - scan, - scan.getRowType(), - runtimeConfig); + if (convention == FlinkConventions.BATCH_PHYSICAL()) { + return new BatchPhysicalMLPredictTableFunction( + scan.getCluster(), + providedTraitSet, + newInput, + scan, + scan.getRowType(), + runtimeConfig); + } else if (convention == FlinkConventions.STREAM_PHYSICAL()) { + return new StreamPhysicalMLPredictTableFunction( + scan.getCluster(), + providedTraitSet, + newInput, + scan, + scan.getRowType(), + runtimeConfig); + } else { + throw new UnsupportedOperationException("Unsupported convention: " + convention); + } } public static boolean isMLPredictFunction(FunctionDefinition definition) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java index cd24cc9c82b55..6b5719cfe1ba6 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java @@ -36,6 +36,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecHashJoin; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecLimit; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecLookupJoin; +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecMLPredictTableFunction; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecMatch; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecNestedLoopJoin; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecOverAggregate; @@ -201,6 +202,7 @@ private ExecNodeMetadataUtil() { add(BatchExecMatch.class); add(BatchExecOverAggregate.class); add(BatchExecRank.class); + add(BatchExecMLPredictTableFunction.class); } }; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 6af00f475c4ff..40acd26b9bdd0 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -21,7 +21,7 @@ import org.apache.flink.table.planner.plan.nodes.logical._ import org.apache.flink.table.planner.plan.rules.logical._ import org.apache.flink.table.planner.plan.rules.physical.FlinkExpandConversionRule import org.apache.flink.table.planner.plan.rules.physical.batch._ -import org.apache.flink.table.planner.plan.rules.physical.common.PhysicalVectorSearchTableFunctionRule +import org.apache.flink.table.planner.plan.rules.physical.common.{PhysicalMLPredictTableFunctionRule, PhysicalVectorSearchTableFunctionRule} import org.apache.calcite.rel.core.RelFactories import org.apache.calcite.rel.logical.{LogicalIntersect, LogicalMinus, LogicalUnion} @@ -430,6 +430,7 @@ object FlinkBatchRuleSets { BatchPhysicalMatchRule.INSTANCE, // ml PhysicalVectorSearchTableFunctionRule.BATCH_INSTANCE, + PhysicalMLPredictTableFunctionRule.BATCH_INSTANCE, // correlate BatchPhysicalConstantTableFunctionScanRule.INSTANCE, BatchPhysicalCorrelateRule.INSTANCE, diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index c733b19ed20ba..a8b1e37b04441 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -18,10 +18,10 @@ package org.apache.flink.table.planner.plan.rules import org.apache.flink.table.planner.plan.nodes.logical._ -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalMLPredictTableFunctionRule, StreamPhysicalProcessTableFunctionRule} +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalProcessTableFunctionRule import org.apache.flink.table.planner.plan.rules.logical.{JoinToMultiJoinRule, _} import org.apache.flink.table.planner.plan.rules.physical.FlinkExpandConversionRule -import org.apache.flink.table.planner.plan.rules.physical.common.PhysicalVectorSearchTableFunctionRule +import org.apache.flink.table.planner.plan.rules.physical.common.{PhysicalMLPredictTableFunctionRule, PhysicalVectorSearchTableFunctionRule} import org.apache.flink.table.planner.plan.rules.physical.stream._ import org.apache.calcite.rel.core.RelFactories @@ -490,7 +490,7 @@ object FlinkStreamRuleSets { // process table function StreamPhysicalProcessTableFunctionRule.INSTANCE, // model TVFs - StreamPhysicalMLPredictTableFunctionRule.INSTANCE, + PhysicalMLPredictTableFunctionRule.STREAM_INSTANCE, PhysicalVectorSearchTableFunctionRule.STREAM_INSTANCE, // join StreamPhysicalJoinRule.INSTANCE, diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.java new file mode 100644 index 0000000000000..e19a045b7f429 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.batch.sql; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.planner.plan.common.MLPredictTableFunctionTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +/** Tests for ML_PREDICT table function in batch mode. 
*/ +public class MLPredictTableFunctionTest extends MLPredictTableFunctionTestBase { + + @Override + protected TableTestUtil getUtil() { + return batchTestUtil(TableConfig.getDefault()); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/common/MLPredictTableFunctionTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/common/MLPredictTableFunctionTestBase.java new file mode 100644 index 0000000000000..b922e66d8a6ea --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/common/MLPredictTableFunctionTestBase.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.common; + +import org.apache.flink.table.api.ExplainDetail; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Collections; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Base test class for ML_PREDICT table function. 
*/ +public abstract class MLPredictTableFunctionTestBase extends TableTestBase { + + protected TableTestUtil util; + + protected abstract TableTestUtil getUtil(); + + @BeforeEach + public void setup() { + util = getUtil(); + + if (util.isBounded()) { + // Create test table + util.tableEnv() + .executeSql( + "CREATE TABLE MyTable (\n" + + " a INT,\n" + + " b BIGINT,\n" + + " c STRING,\n" + + " d DECIMAL(10, 3)\n" + + ") with (\n" + + " 'connector' = 'values',\n" + + " 'bounded' = 'true'\n" + + ")"); + } else { + // Create test table + util.tableEnv() + .executeSql( + "CREATE TABLE MyTable (\n" + + " a INT,\n" + + " b BIGINT,\n" + + " c STRING,\n" + + " d DECIMAL(10, 3),\n" + + " rowtime TIMESTAMP(3),\n" + + " proctime as PROCTIME(),\n" + + " WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + } + + // Create test model + util.tableEnv() + .executeSql( + "CREATE MODEL MyModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(e STRING, f ARRAY)\n" + + "with (\n" + + " 'provider' = 'test-model',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'task' = 'text_generation'\n" + + ")"); + } + + @Test + public void testSimpleArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, " + + "MODEL MyModel, " + + "DESCRIPTOR(a, b)))"; + util.verifyRelPlan(sql); + } + + @Test + public void testNamedArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " + + "MODEL => MODEL MyModel, " + + "ARGS => DESCRIPTOR(a, b)))"; + util.verifyRelPlan(sql); + } + + @Test + public void testOptionalNamedArguments() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " + + "MODEL => MODEL MyModel, " + + "ARGS => DESCRIPTOR(a, b)," + + "CONFIG => MAP['key', 'value']))"; + util.verifyRelPlan(sql); + } + + @Test + public void testConfigWithCast() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'timeout', '100s']))"; + util.verifyRelPlan(sql); + } + + @Test + public void testTooFewArguments() { + String sql = "SELECT *\n" + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .hasMessageContaining("No match found for function signature ML_PREDICT"); + } + + @Test + public void testNonExistModel() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL NonExistModel, DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Object 'NonExistModel' not found"); + } + + @Test + public void testConflictOutputColumnName() { + util.tableEnv() + .executeSql( + "CREATE MODEL ConflictModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(c STRING, d ARRAY)\n" + + "with (\n" + + " 'task' = 'text_generation',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'provider' = 'test-model'" + + ")"); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; + util.verifyRelPlan(sql); + } + + @Test + public void testMissingModelParam() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, DESCRIPTOR(a, b), DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Invalid argument value. 
Argument 'MODEL' expects a model to be passed."); + } + + @Test + public void testMismatchInputSize() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b, c)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Number of descriptor columns (3) does not match model input size (2)."); + } + + @Test + public void testNonExistColumn() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(no_col)))"; + + if (util.isBounded()) { + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Descriptor column 'no_col' not found in table columns. Available columns: a, b, c, d."); + } else { + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Descriptor column 'no_col' not found in table columns. Available columns: a, b, c, d, rowtime, proctime."); + } + } + + @Test + public void testNonSimpleColumn() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(MyTable.a)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Third argument must be a descriptor with simple column names for ML_PREDICT function."); + } + + @ParameterizedTest + @MethodSource("compatibleTypeProvider") + public void testCompatibleInputTypes(String tableType, String modelType) { + // Create test table with dynamic type + String bounded = util.isBounded() ? ",\n 'bounded' = 'true'\n" : "\n"; + util.tableEnv() + .executeSql( + String.format( + "CREATE TABLE TypeTable (\n" + + " col %s\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + bounded + + ")", + tableType)); + + // Create test model with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE MODEL TypeModel\n" + + "INPUT (x %s)\n" + + "OUTPUT (res STRING)\n" + + "with (\n" + + " 'task' = 'text_generation',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'provider' = 'test-model'" + + ")", + modelType)); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; + util.verifyRelPlan(sql); + } + + @ParameterizedTest + @MethodSource("incompatibleTypeProvider") + public void testIncompatibleInputTypes(String tableType, String modelType) { + // Create test table with dynamic type + String bounded = util.isBounded() ? 
",\n 'bounded' = 'true'\n" : "\n"; + util.tableEnv() + .executeSql( + String.format( + "CREATE TABLE TypeTable (\n" + + " col %s\n" + + ") with (\n" + + " 'connector' = 'values'" + + bounded + + ")", + tableType)); + + // Create test model with dynamic type + util.tableEnv() + .executeSql( + String.format( + "CREATE MODEL TypeModel\n" + + "INPUT (x %s)\n" + + "OUTPUT (res STRING)\n" + + "with (\n" + + " 'provider' = 'openai'\n" + + ")", + modelType)); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasStackTraceContaining("cannot be assigned to model input type"); + } + + @Test + public void testIllegalConfig() { + assertThatThrownBy( + () -> + util.verifyRelPlan( + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', true]))")) + .isInstanceOf(ValidationException.class) + .hasRootCauseMessage( + "Invalid argument type at position 3. Data type MAP expected but MAP NOT NULL passed."); + + assertThatThrownBy( + () -> + util.verifyRelPlan( + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'yes']))")) + .hasCauseInstanceOf(ValidationException.class) + .hasStackTraceContaining("Failed to parse the config."); + + assertThatThrownBy( + () -> + util.verifyRelPlan( + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'max-concurrent-operations', '-1']))")) + .hasCauseInstanceOf(ValidationException.class) + .hasStackTraceContaining( + "Invalid runtime config option 'max-concurrent-operations'. Its value should be positive integer but was -1."); + + assertThatThrownBy( + () -> + util.verifyRelPlan( + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'capacity', CAST(-1 AS STRING)]))")) + .hasCauseInstanceOf(ValidationException.class) + .hasStackTraceContaining( + "Config parameter should be a MAP data type consisting of String literals."); + + assertThatThrownBy( + () -> + util.verifyExecPlan( + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true']))")) + .isInstanceOf(TableException.class) + .hasMessageContaining( + "Require async mode, but model provider org.apache.flink.table.factories.TestModelProviderFactory$TestModelProviderMock doesn't support async mode."); + } + + @Test + public void testNonExistProvider() { + util.tableEnv() + .executeSql( + "CREATE MODEL ConflictModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(c STRING, d ARRAY)\n" + + "with (\n" + + " 'task' = 'text_generation',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'provider' = 'non-exist-model'" + + ")"); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "Unable to create a model provider for model 'default_catalog.default_database.ConflictModel'."); + } + + @Test + public void testNonPredictProvider() { + util.tableEnv() + .executeSql( + "CREATE MODEL ConflictModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(c STRING, d ARRAY)\n" + + "with (\n" + + " 'task' = 'text_generation',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'provider' = 'non-exist-model'" + + ")"); + + String sql = + "SELECT *\n" + + "FROM 
TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(ValidationException.class) + .hasMessageContaining( + "Unable to create a model provider for model 'default_catalog.default_database.ConflictModel'."); + } + + @Test + public void testNotMLPredictRuntimeProvider() { + util.tableEnv() + .executeSql( + "CREATE MODEL ConflictModel\n" + + "INPUT (a INT, b BIGINT)\n" + + "OUTPUT(c STRING, d ARRAY)\n" + + "with (\n" + + " 'task' = 'text_generation',\n" + + " 'endpoint' = 'someendpoint',\n" + + " 'provider' = 'non-predict-model'" + + ")"); + + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(TableException.class) + .hasMessageContaining( + "This exception indicates that the query uses an unsupported SQL feature."); + } + + @Test + public void testInputTableIsInsertOnlyStream() { + String sql = + "SELECT *\n" + + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b)))"; + util.verifyRelPlan( + sql, + JavaScalaConversionUtil.toScala( + Collections.singletonList(ExplainDetail.CHANGELOG_MODE))); + } + + @Test + public void testInputTableIsCdcStream() { + if (util.isBounded()) { + // CDC table is unbounded stream, skip the bounded test + return; + } + util.tableEnv() + .executeSql( + "CREATE TABLE CdcTable(\n" + + " a INT,\n" + + " b BIGINT,\n" + + " PRIMARY KEY (a) NOT ENFORCED\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'changelog-mode' = 'I,UA,UB,D'" + + ")"); + String sql = + "SELECT *\n" + "FROM ML_PREDICT(TABLE CdcTable, MODEL MyModel, DESCRIPTOR(a, b))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .isInstanceOf(TableException.class) + .hasMessageContaining( + "StreamPhysicalMLPredictTableFunction doesn't support consuming update and delete changes which is produced by node TableSourceScan(table=[[default_catalog, default_database, CdcTable]], fields=[a, b])"); + } + + protected static Stream compatibleTypeProvider() { + return Stream.of( + // NOT NULL to NULLABLE type + Arguments.of("STRING NOT NULL", "STRING"), + + // Exact matches - primitive types + Arguments.of("BOOLEAN", "BOOLEAN"), + Arguments.of("TINYINT", "TINYINT"), + Arguments.of("SMALLINT", "SMALLINT"), + Arguments.of("INT", "INT"), + Arguments.of("BIGINT", "BIGINT"), + Arguments.of("FLOAT", "FLOAT"), + Arguments.of("DOUBLE", "DOUBLE"), + Arguments.of("DECIMAL(10,2)", "DECIMAL(10,2)"), + Arguments.of("STRING", "STRING"), + Arguments.of("BINARY(10)", "BINARY(10)"), + Arguments.of("VARBINARY(10)", "VARBINARY(10)"), + Arguments.of("DATE", "DATE"), + Arguments.of("TIME(3)", "TIME(3)"), + Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), + Arguments.of("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ(3)"), + + // Numeric type promotions + Arguments.of("TINYINT", "SMALLINT"), + Arguments.of("SMALLINT", "INT"), + Arguments.of("INT", "BIGINT"), + Arguments.of("FLOAT", "DOUBLE"), + Arguments.of("DECIMAL(5,2)", "DECIMAL(10,2)"), + Arguments.of( + "DECIMAL(10,2)", "DECIMAL(5,2)"), // This is also allowed, is this a bug? 
+ + // String type compatibility + Arguments.of("CHAR(10)", "STRING"), + Arguments.of("VARCHAR(20)", "STRING"), + + // Temporal types + Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), + Arguments.of("DATE", "DATE"), + Arguments.of("TIME(3)", "TIME(3)"), + + // Array types + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + + // Map types + Arguments.of("MAP", "MAP"), + Arguments.of("MAP", "MAP"), + Arguments.of("MAP>", "MAP>"), + + // Row types + Arguments.of("ROW", "ROW"), + Arguments.of( + "ROW", "ROW"), // Different field name + Arguments.of( + "ROW>", "ROW>"), + Arguments.of( + "ROW>", + "ROW>"), + + // Nested complex types + Arguments.of( + "ROW, b MAP>>", + "ROW, b MAP>>"), + Arguments.of( + "MAP>>", + "MAP>>")); + } + + protected static Stream incompatibleTypeProvider() { + return Stream.of( + // NULLABLE to NOT NULL type + Arguments.of("STRING", "STRING NOT NULL"), + + // Incompatible primitive types + Arguments.of("BOOLEAN", "INT"), + Arguments.of("STRING", "INT"), + Arguments.of("INT", "STRING"), + Arguments.of("TIMESTAMP(3)", "INT"), + Arguments.of("DATE", "TIMESTAMP(3)"), + Arguments.of("BINARY(10)", "STRING"), + + // Incompatible numeric types (wrong direction) + Arguments.of("BIGINT", "INT"), // Cannot downcast + Arguments.of("DOUBLE", "FLOAT"), // Cannot downcast + + // Incompatible array types + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("ARRAY", "ARRAY"), + Arguments.of("INT", "ARRAY"), + + // Incompatible map types + Arguments.of("MAP", "MAP"), // Key type mismatch + Arguments.of("MAP", "MAP"), // Value type mismatch + Arguments.of("MAP", "MAP"), // Cannot downcast value + Arguments.of("MAP", "MAP"), // Cannot downcast key + + // Incompatible row types + Arguments.of("ROW", "ROW"), // Field type mismatch + Arguments.of("ROW", "ROW"), // Field count mismatch + + // Incompatible nested types + Arguments.of( + "ROW, b MAP>", + "ROW, b MAP>"), + Arguments.of("MAP>", "MAP>"), + Arguments.of("ARRAY>", "ARRAY>")); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/batch/MLPredictBatchRestoreTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/batch/MLPredictBatchRestoreTest.java new file mode 100644 index 0000000000000..41fe1defbf103 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/batch/MLPredictBatchRestoreTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.nodes.exec.batch; + +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.table.planner.plan.nodes.exec.testutils.BatchRestoreTestBase; +import org.apache.flink.table.planner.plan.utils.ExecNodeMetadataUtil; +import org.apache.flink.table.test.program.TableTestProgram; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_UNORDERED_ML_PREDICT; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_ML_PREDICT; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_ML_PREDICT_WITH_RUNTIME_CONFIG; +import static org.assertj.core.api.Assertions.assertThat; + +/** Restore tests for {@link BatchExecMLPredictTableFunction}. */ +public class MLPredictBatchRestoreTest extends BatchRestoreTestBase { + + public MLPredictBatchRestoreTest() { + super(BatchExecMLPredictTableFunction.class); + } + + @Test + public void testExecNodeMetadataContainsRequiredOptions() { + assertThat( + new HashSet<>( + Arrays.asList( + Objects.requireNonNull( + ExecNodeMetadataUtil.consumedOptions( + BatchExecMLPredictTableFunction.class))))) + .isEqualTo( + Arrays.asList( + ExecutionConfigOptions + .TABLE_EXEC_ASYNC_ML_PREDICT_MAX_CONCURRENT_OPERATIONS, + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_TIMEOUT, + ExecutionConfigOptions + .TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE) + .stream() + .map(ConfigOption::key) + .collect(Collectors.toSet())); + } + + @Override + public List<TableTestProgram> programs() { + return List.of( + SYNC_ML_PREDICT, ASYNC_UNORDERED_ML_PREDICT, SYNC_ML_PREDICT_WITH_RUNTIME_CONFIG); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/BatchRestoreTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/BatchRestoreTestBase.java index 3dbcad99db57c..c17bee760e14f 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/BatchRestoreTestBase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/BatchRestoreTestBase.java @@ -23,11 +23,13 @@ import org.apache.flink.table.api.PlanReference; import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.config.TableConfigOptions; +import org.apache.flink.table.planner.factories.TestValuesModelFactory; import org.apache.flink.table.planner.factories.TestValuesTableFactory; import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata; import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecNode; import org.apache.flink.table.planner.plan.utils.ExecNodeMetadataUtil; +import org.apache.flink.table.test.program.ModelTestStep; import org.apache.flink.table.test.program.SinkTestStep; import org.apache.flink.table.test.program.SourceTestStep; import org.apache.flink.table.test.program.SqlTestStep; @@ -103,6 +105,7 @@ public EnumSet<TestKind> supportedSetupSteps() { return EnumSet.of( TestKind.CONFIG, TestKind.FUNCTION, + TestKind.MODEL, TestKind.SOURCE_WITH_RESTORE_DATA, 
TestKind.SOURCE_WITH_DATA, TestKind.SINK_WITH_RESTORE_DATA, @@ -152,6 +155,13 @@ public void generateCompiledPlans(TableTestProgram program) { TableConfigOptions.PLAN_COMPILE_CATALOG_OBJECTS, TableConfigOptions.CatalogPlanCompilation.SCHEMA); + for (ModelTestStep modelTestStep : program.getSetupModelTestSteps()) { + final Map<String, String> options = new HashMap<>(); + options.put("provider", "values"); + options.put("data-id", TestValuesModelFactory.registerData(modelTestStep.data)); + modelTestStep.apply(tEnv, options); + } + for (SourceTestStep sourceTestStep : program.getSetupSourceTestSteps()) { final String id = TestValuesTableFactory.registerData(sourceTestStep.dataBeforeRestore); final Map<String, String> options = new HashMap<>(); @@ -198,8 +208,14 @@ void loadAndRunCompiledPlan(TableTestProgram program, ExecNodeMetadata metadata) program.getSetupConfigOptionTestSteps().forEach(s -> s.apply(tEnv)); - for (SourceTestStep sourceTestStep : program.getSetupSourceTestSteps()) { + for (ModelTestStep modelTestStep : program.getSetupModelTestSteps()) { + final Map<String, String> options = new HashMap<>(); + options.put("provider", "values"); + options.put("data-id", TestValuesModelFactory.registerData(modelTestStep.data)); + modelTestStep.apply(tEnv, options); + } + for (SourceTestStep sourceTestStep : program.getSetupSourceTestSteps()) { List<Row> data = new ArrayList<>(); data.addAll(sourceTestStep.dataBeforeRestore); data.addAll(sourceTestStep.dataAfterRestore); diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java index 1c20fdb7ed958..b409e02cf9db9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java @@ -21,32 +21,31 @@ import org.apache.flink.table.api.ExplainDetail; import org.apache.flink.table.api.TableConfig; import org.apache.flink.table.api.TableException; -import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.plan.common.MLPredictTableFunctionTestBase; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; -import org.apache.flink.table.planner.utils.TableTestBase; import org.apache.flink.table.planner.utils.TableTestUtil; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import java.util.Collections; -import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** Tests for model table value function. */ -public class MLPredictTableFunctionTest extends TableTestBase { +/** Tests for ML_PREDICT table function in stream mode. 
*/ +public class MLPredictTableFunctionTest extends MLPredictTableFunctionTestBase { - private TableTestUtil util; + @Override + protected TableTestUtil getUtil() { + return streamTestUtil(TableConfig.getDefault()); + } + @Override @BeforeEach public void setup() { - util = streamTestUtil(TableConfig.getDefault()); + util = getUtil(); - // Create test table + // Create test table with stream-specific columns util.tableEnv() .executeSql( "CREATE TABLE MyTable (\n" @@ -68,330 +67,12 @@ public void setup() { + "INPUT (a INT, b BIGINT)\n" + "OUTPUT(e STRING, f ARRAY)\n" + "with (\n" - + " 'provider' = 'test-model',\n" // test model provider defined in - // TestModelProviderFactory in - // flink-table-common + + " 'provider' = 'test-model',\n" + " 'endpoint' = 'someendpoint',\n" + " 'task' = 'text_generation'\n" + ")"); } - @Test - public void testSimpleArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, " - + "MODEL MyModel, " - + "DESCRIPTOR(a, b)))"; - util.verifyRelPlan(sql); - } - - @Test - public void testNamedArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " - + "MODEL => MODEL MyModel, " - + "ARGS => DESCRIPTOR(a, b)))"; - util.verifyRelPlan(sql); - } - - @Test - public void testOptionalNamedArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(INPUT => TABLE MyTable, " - + "MODEL => MODEL MyModel, " - + "ARGS => DESCRIPTOR(a, b)," - + "CONFIG => MAP['key', 'value']))"; - util.verifyRelPlan(sql); - } - - @Test - public void testConfigWithCast() { - // 'async' and 'timeout' in the map are both cast to VARCHAR(7) - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'timeout', '100s']))"; - util.verifyRelPlan(sql); - } - - @Test - public void testTooFewArguments() { - String sql = "SELECT *\n" + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .hasMessageContaining( - "No match found for function signature ML_PREDICT(, ).\n" - + "Supported signatures are:\n" - + "ML_PREDICT(INPUT => {TABLE, ROW SEMANTIC TABLE}, MODEL => {MODEL}, ARGS => DESCRIPTOR, CONFIG => MAP)"); - } - - @Test - public void testTooManyArguments() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .hasMessageContaining( - "No match found for function signature ML_PREDICT(, , , <(CHAR(3), CHAR(5)) MAP>, ).\n" - + "Supported signatures are:\n" - + "ML_PREDICT(INPUT => {TABLE, ROW SEMANTIC TABLE}, MODEL => {MODEL}, ARGS => DESCRIPTOR, CONFIG => MAP)"); - } - - @Test - public void testNonExistModel() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL NonExistModel, DESCRIPTOR(a, b), MAP['key', 'value'], 'arg0'))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasMessageContaining("Object 'NonExistModel' not found"); - } - - @Test - public void testConflictOutputColumnName() { - util.tableEnv() - .executeSql( - "CREATE MODEL ConflictModel\n" - + "INPUT (a INT, b BIGINT)\n" - + "OUTPUT(c STRING, d ARRAY)\n" - + "with (\n" - + " 'task' = 'text_generation',\n" - + " 'endpoint' = 'someendpoint',\n" - + " 'provider' = 'test-model'" - + ")"); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; - util.verifyRelPlan(sql); - 
} - - @Test - public void testMissingModelParam() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, DESCRIPTOR(a, b), DESCRIPTOR(a, b)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasRootCauseMessage( - "Invalid argument value. Argument 'MODEL' expects a model to be passed."); - } - - @Test - public void testMismatchInputSize() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b, c)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasRootCauseMessage( - "Number of descriptor columns (3) does not match model input size (2)."); - } - - @Test - public void testNonExistColumn() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(no_col)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasRootCauseMessage( - "Descriptor column 'no_col' not found in table columns. Available columns: a, b, c, d, rowtime, proctime."); - } - - @Test - public void testNonSimpleColumn() { - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(MyTable.a)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasRootCauseMessage( - "Third argument must be a descriptor with simple column names for ML_PREDICT function."); - } - - @ParameterizedTest - @MethodSource("compatibleTypeProvider") - public void testCompatibleInputTypes(String tableType, String modelType) { - // Create test table with dynamic type - util.tableEnv() - .executeSql( - String.format( - "CREATE TABLE TypeTable (\n" - + " col %s\n" - + ") with (\n" - + " 'connector' = 'values'\n" - + ")", - tableType)); - - // Create test model with dynamic type - util.tableEnv() - .executeSql( - String.format( - "CREATE MODEL TypeModel\n" - + "INPUT (x %s)\n" - + "OUTPUT (res STRING)\n" - + "with (\n" - + " 'task' = 'text_generation',\n" - + " 'endpoint' = 'someendpoint',\n" - + " 'provider' = 'test-model'" - + ")", - modelType)); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; - util.verifyRelPlan(sql); - } - - @ParameterizedTest - @MethodSource("incompatibleTypeProvider") - public void testIncompatibleInputTypes(String tableType, String modelType) { - // Create test table with dynamic type - util.tableEnv() - .executeSql( - String.format( - "CREATE TABLE TypeTable (\n" - + " col %s\n" - + ") with (\n" - + " 'connector' = 'values'\n" - + ")", - tableType)); - - // Create test model with dynamic type - util.tableEnv() - .executeSql( - String.format( - "CREATE MODEL TypeModel\n" - + "INPUT (x %s)\n" - + "OUTPUT (res STRING)\n" - + "with (\n" - + " 'provider' = 'openai'\n" - + ")", - modelType)); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE TypeTable, MODEL TypeModel, DESCRIPTOR(col)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasStackTraceContaining("cannot be assigned to model input type"); - } - - @Test - public void testIllegalConfig() { - assertThatThrownBy( - () -> - util.verifyRelPlan( - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', true]))")) - .isInstanceOf(ValidationException.class) - .hasRootCauseMessage( - "Invalid argument type at position 3. 
Data type MAP expected but MAP NOT NULL passed."); - - assertThatThrownBy( - () -> - util.verifyRelPlan( - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'yes']))")) - .hasCauseInstanceOf(ValidationException.class) - .hasStackTraceContaining("Failed to parse the config."); - - assertThatThrownBy( - () -> - util.verifyRelPlan( - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'max-concurrent-operations', '-1']))")) - .hasCauseInstanceOf(ValidationException.class) - .hasStackTraceContaining( - "Invalid runtime config option 'max-concurrent-operations'. Its value should be positive integer but was -1."); - - assertThatThrownBy( - () -> - util.verifyRelPlan( - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'capacity', CAST(-1 AS STRING)]))")) - .hasCauseInstanceOf(ValidationException.class) - .hasStackTraceContaining( - "Config parameter should be a MAP data type consisting of String literals."); - - assertThatThrownBy( - () -> - util.verifyExecPlan( - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true']))")) - .isInstanceOf(TableException.class) - .hasMessageContaining( - "Require async mode, but model provider org.apache.flink.table.factories.TestModelProviderFactory$TestModelProviderMock doesn't support async mode."); - } - - @Test - public void testNonExistProvider() { - util.tableEnv() - .executeSql( - "CREATE MODEL ConflictModel\n" - + "INPUT (a INT, b BIGINT)\n" - + "OUTPUT(c STRING, d ARRAY)\n" - + "with (\n" - + " 'task' = 'text_generation',\n" - + " 'endpoint' = 'someendpoint',\n" - + " 'provider' = 'non-exist-model'" - + ")"); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasMessageContaining( - "Unable to create a model provider for model 'default_catalog.default_database.ConflictModel'."); - } - - @Test - public void testNonPredictProvider() { - util.tableEnv() - .executeSql( - "CREATE MODEL ConflictModel\n" - + "INPUT (a INT, b BIGINT)\n" - + "OUTPUT(c STRING, d ARRAY)\n" - + "with (\n" - + " 'task' = 'text_generation',\n" - + " 'endpoint' = 'someendpoint',\n" - + " 'provider' = 'non-exist-model'" - + ")"); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(ValidationException.class) - .hasMessageContaining( - "Unable to create a model provider for model 'default_catalog.default_database.ConflictModel'."); - } - - @Test - public void testNotMLPredictRuntimeProvider() { - util.tableEnv() - .executeSql( - "CREATE MODEL ConflictModel\n" - + "INPUT (a INT, b BIGINT)\n" - + "OUTPUT(c STRING, d ARRAY)\n" - + "with (\n" - + " 'task' = 'text_generation',\n" - + " 'endpoint' = 'someendpoint',\n" - + " 'provider' = 'non-predict-model'" - + ")"); - - String sql = - "SELECT *\n" - + "FROM TABLE(ML_PREDICT(TABLE MyTable, MODEL ConflictModel, DESCRIPTOR(a, b)))"; - assertThatThrownBy(() -> util.verifyRelPlan(sql)) - .isInstanceOf(TableException.class) - .hasMessageContaining( - "This exception indicates that the query uses an unsupported SQL feature."); - } - @Test public void testInputTableIsInsertOnlyStream() { String sql = @@ -422,114 +103,4 @@ public void 
testInputTableIsCdcStream() { .hasMessageContaining( "StreamPhysicalMLPredictTableFunction doesn't support consuming update and delete changes which is produced by node TableSourceScan(table=[[default_catalog, default_database, CdcTable]], fields=[a, b])"); } - - private static Stream compatibleTypeProvider() { - return Stream.of( - // NOT NULL to NULLABLE type - Arguments.of("STRING NOT NULL", "STRING"), - - // Exact matches - primitive types - Arguments.of("BOOLEAN", "BOOLEAN"), - Arguments.of("TINYINT", "TINYINT"), - Arguments.of("SMALLINT", "SMALLINT"), - Arguments.of("INT", "INT"), - Arguments.of("BIGINT", "BIGINT"), - Arguments.of("FLOAT", "FLOAT"), - Arguments.of("DOUBLE", "DOUBLE"), - Arguments.of("DECIMAL(10,2)", "DECIMAL(10,2)"), - Arguments.of("STRING", "STRING"), - Arguments.of("BINARY(10)", "BINARY(10)"), - Arguments.of("VARBINARY(10)", "VARBINARY(10)"), - Arguments.of("DATE", "DATE"), - Arguments.of("TIME(3)", "TIME(3)"), - Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), - Arguments.of("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ(3)"), - - // Numeric type promotions - Arguments.of("TINYINT", "SMALLINT"), - Arguments.of("SMALLINT", "INT"), - Arguments.of("INT", "BIGINT"), - Arguments.of("FLOAT", "DOUBLE"), - Arguments.of("DECIMAL(5,2)", "DECIMAL(10,2)"), - Arguments.of( - "DECIMAL(10,2)", "DECIMAL(5,2)"), // This is also allowed, is this a bug? - - // String type compatibility - Arguments.of("CHAR(10)", "STRING"), - Arguments.of("VARCHAR(20)", "STRING"), - - // Temporal types - Arguments.of("TIMESTAMP(3)", "TIMESTAMP(3)"), - Arguments.of("DATE", "DATE"), - Arguments.of("TIME(3)", "TIME(3)"), - - // Array types - Arguments.of("ARRAY", "ARRAY"), - Arguments.of("ARRAY", "ARRAY"), - Arguments.of("ARRAY", "ARRAY"), - Arguments.of("ARRAY", "ARRAY"), - - // Map types - Arguments.of("MAP", "MAP"), - Arguments.of("MAP", "MAP"), - Arguments.of("MAP>", "MAP>"), - - // Row types - Arguments.of("ROW", "ROW"), - Arguments.of( - "ROW", "ROW"), // Different field name - Arguments.of( - "ROW>", "ROW>"), - Arguments.of( - "ROW>", - "ROW>"), - - // Nested complex types - Arguments.of( - "ROW, b MAP>>", - "ROW, b MAP>>"), - Arguments.of( - "MAP>>", - "MAP>>")); - } - - private static Stream incompatibleTypeProvider() { - return Stream.of( - // NULLABLE to NOT NULL type - Arguments.of("STRING", "STRING NOT NULL"), - - // Incompatible primitive types - Arguments.of("BOOLEAN", "INT"), - Arguments.of("STRING", "INT"), - Arguments.of("INT", "STRING"), - Arguments.of("TIMESTAMP(3)", "INT"), - Arguments.of("DATE", "TIMESTAMP(3)"), - Arguments.of("BINARY(10)", "STRING"), - - // Incompatible numeric types (wrong direction) - Arguments.of("BIGINT", "INT"), // Cannot downcast - Arguments.of("DOUBLE", "FLOAT"), // Cannot downcast - - // Incompatible array types - Arguments.of("ARRAY", "ARRAY"), - Arguments.of("ARRAY", "ARRAY"), - Arguments.of("INT", "ARRAY"), - - // Incompatible map types - Arguments.of("MAP", "MAP"), // Key type mismatch - Arguments.of("MAP", "MAP"), // Value type mismatch - Arguments.of("MAP", "MAP"), // Cannot downcast value - Arguments.of("MAP", "MAP"), // Cannot downcast key - - // Incompatible row types - Arguments.of("ROW", "ROW"), // Field type mismatch - Arguments.of("ROW", "ROW"), // Field count mismatch - - // Incompatible nested types - Arguments.of( - "ROW, b MAP>", - "ROW, b MAP>"), - Arguments.of("MAP>", "MAP>"), - Arguments.of("ARRAY>", "ARRAY>")); - } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/AsyncMLPredictITCase.java 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/AsyncMLPredictITCase.java new file mode 100644 index 0000000000000..16ae7ee80e3b1 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/AsyncMLPredictITCase.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.batch.table; + +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecMLPredictTableFunction; +import org.apache.flink.table.planner.runtime.utils.MLPredictITCaseBase; + +import org.junit.jupiter.api.BeforeEach; + +/** ITCase for async ML_PREDICT in batch mode. Tests {@link BatchExecMLPredictTableFunction}. */ +public class AsyncMLPredictITCase extends MLPredictITCaseBase { + + private StreamExecutionEnvironment env; + + @BeforeEach + @Override + public void before() throws Exception { + env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.getConfig().enableObjectReuse(); + super.before(); + } + + @Override + protected TableEnvironment getTableEnvironment() { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inBatchMode().build(); + return StreamTableEnvironment.create(env, settings); + } + + @Override + protected boolean isAsync() { + return true; + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/MLPredictITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/MLPredictITCase.java new file mode 100644 index 0000000000000..bed4fc2d410eb --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/table/MLPredictITCase.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.batch.table; + +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecMLPredictTableFunction; +import org.apache.flink.table.planner.runtime.utils.MLPredictITCaseBase; + +/** ITCase for {@link BatchExecMLPredictTableFunction}. */ +public class MLPredictITCase extends MLPredictITCaseBase { + + @Override + protected TableEnvironment getTableEnvironment() { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inBatchMode().build(); + return StreamTableEnvironment.create( + StreamExecutionEnvironment.getExecutionEnvironment(), settings); + } + + @Override + protected boolean isAsync() { + return false; + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java index 3e2dea310a588..0c95639310d97 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java @@ -18,271 +18,36 @@ package org.apache.flink.table.planner.runtime.stream.table; -import org.apache.flink.core.testutils.FlinkAssertions; -import org.apache.flink.table.api.config.ExecutionConfigOptions; -import org.apache.flink.table.planner.factories.TestValuesModelFactory; -import org.apache.flink.table.planner.factories.TestValuesTableFactory; -import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase; -import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; -import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; -import org.apache.flink.types.Row; -import org.apache.flink.util.CollectionUtil; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.planner.runtime.utils.MLPredictITCaseBase; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.TestTemplate; -import org.junit.jupiter.api.extension.ExtendWith; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeoutException; +/** ITCase for async ML_PREDICT in stream mode. */ +public class AsyncMLPredictITCase extends MLPredictITCaseBase { -import static org.assertj.core.api.Assertions.assertThatList; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** ITCase for async ML_PREDICT. 
*/ -@ExtendWith(ParameterizedTestExtension.class) -public class AsyncMLPredictITCase extends StreamingWithStateTestBase { - - private final Boolean objectReuse; - private final ExecutionConfigOptions.AsyncOutputMode asyncOutputMode; - - public AsyncMLPredictITCase( - StateBackendMode backend, - Boolean objectReuse, - ExecutionConfigOptions.AsyncOutputMode asyncOutputMode) { - super(backend); - - this.objectReuse = objectReuse; - this.asyncOutputMode = asyncOutputMode; - } - - private final List data = - Arrays.asList( - Row.of(1L, 12, "Julian"), - Row.of(2L, 15, "Hello"), - Row.of(3L, 15, "Fabian"), - Row.of(8L, 11, "Hello world"), - Row.of(9L, 12, "Hello world!")); - - private final List dataWithNull = - Arrays.asList( - Row.of(15L, null, "Hello"), - Row.of(3L, 15, "Fabian"), - Row.of(11L, null, "Hello world"), - Row.of(9L, 12, "Hello world!")); - - private final Map> id2features = new HashMap<>(); - - { - id2features.put(Row.of(1L), Collections.singletonList(Row.of("x1", 1, "z1"))); - id2features.put(Row.of(2L), Collections.singletonList(Row.of("x2", 2, "z2"))); - id2features.put(Row.of(3L), Collections.singletonList(Row.of("x3", 3, "z3"))); - id2features.put(Row.of(8L), Collections.singletonList(Row.of("x8", 8, "z8"))); - id2features.put(Row.of(9L), Collections.singletonList(Row.of("x9", 9, "z9"))); - } - - private final Map> idLen2features = new HashMap<>(); - - { - idLen2features.put( - Row.of(15L, null), Collections.singletonList(Row.of("x1", 1, "zNull15"))); - idLen2features.put(Row.of(15L, 15), Collections.singletonList(Row.of("x1", 1, "z1515"))); - idLen2features.put(Row.of(3L, 15), Collections.singletonList(Row.of("x2", 2, "z315"))); - idLen2features.put( - Row.of(11L, null), Collections.singletonList(Row.of("x3", 3, "zNull11"))); - idLen2features.put(Row.of(11L, 11), Collections.singletonList(Row.of("x3", 3, "z1111"))); - idLen2features.put(Row.of(9L, 12), Collections.singletonList(Row.of("x8", 8, "z912"))); - idLen2features.put(Row.of(12L, 12), Collections.singletonList(Row.of("x8", 8, "z1212"))); - } - - private final Map> content2vector = new HashMap<>(); - - { - content2vector.put( - Row.of("Julian"), - Collections.singletonList(Row.of((Object) new Float[] {1.0f, 2.0f, 3.0f}))); - content2vector.put( - Row.of("Hello"), - Collections.singletonList(Row.of((Object) new Float[] {2.0f, 3.0f, 4.0f}))); - content2vector.put( - Row.of("Fabian"), - Collections.singletonList(Row.of((Object) new Float[] {3.0f, 4.0f, 5.0f}))); - content2vector.put( - Row.of("Hello world"), - Collections.singletonList(Row.of((Object) new Float[] {4.0f, 5.0f, 6.0f}))); - content2vector.put( - Row.of("Hello world!"), - Collections.singletonList(Row.of((Object) new Float[] {5.0f, 6.0f, 7.0f}))); - } + private StreamExecutionEnvironment env; @BeforeEach - public void before() { + @Override + public void before() throws Exception { + env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + env.getConfig().enableObjectReuse(); super.before(); - if (objectReuse) { - env().getConfig().enableObjectReuse(); - } else { - env().getConfig().disableObjectReuse(); - } - tEnv().getConfig() - .set( - ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, - asyncOutputMode); - - createScanTable("src", data); - createScanTable("nullable_src", dataWithNull); - - tEnv().executeSql( - String.format( - "CREATE MODEL m1\n" - + "INPUT (a BIGINT)\n" - + "OUTPUT (x STRING, y INT, z STRING)\n" - + "WITH (\n" - + " 'provider' = 'values'," - + " 'async' = 'true'," - + " 'data-id' = '%s'" - + 
")", - TestValuesModelFactory.registerData(id2features))); - tEnv().executeSql( - String.format( - "CREATE MODEL m2\n" - + "INPUT (a BIGINT, b INT)\n" - + "OUTPUT (x STRING, y INT, z STRING)\n" - + "WITH (\n" - + " 'provider' = 'values'," - + " 'async' = 'true'," - + " 'data-id' = '%s'" - + ")", - TestValuesModelFactory.registerData(idLen2features))); - tEnv().executeSql( - String.format( - "CREATE MODEL m3\n" - + "INPUT (content STRING)\n" - + "OUTPUT (vector ARRAY)\n" - + "WITH (\n" - + " 'provider' = 'values'," - + " 'data-id' = '%s'," - + " 'latency' = '1000'," - + " 'async' = 'true'" - + ")", - TestValuesModelFactory.registerData(content2vector))); - } - - @TestTemplate - public void testAsyncMLPredict() { - assertThatList( - CollectionUtil.iteratorToList( - tEnv().executeSql( - "SELECT id, z FROM ML_PREDICT(TABLE src, MODEL m1, DESCRIPTOR(`id`))") - .collect())) - .containsExactlyInAnyOrder( - Row.of(1L, "z1"), - Row.of(2L, "z2"), - Row.of(3L, "z3"), - Row.of(8L, "z8"), - Row.of(9L, "z9")); - } - - @TestTemplate - public void testAsyncMLPredictWithMultipleFields() { - assertThatList( - CollectionUtil.iteratorToList( - tEnv().executeSql( - "SELECT id, len, z FROM ML_PREDICT(TABLE nullable_src, MODEL m2, DESCRIPTOR(`id`, `len`))") - .collect())) - .containsExactlyInAnyOrder( - Row.of(3L, 15, "z315"), - Row.of(9L, 12, "z912"), - Row.of(11L, null, "zNull11"), - Row.of(15L, null, "zNull15")); - } - - @TestTemplate - public void testAsyncMLPredictWithConstantValues() { - assertThatList( - CollectionUtil.iteratorToList( - tEnv().executeSql( - "WITH v(id) AS (SELECT * FROM (VALUES (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)))) " - + "SELECT * FROM ML_PREDICT(INPUT => TABLE v, MODEL => MODEL `m1`, ARGS => DESCRIPTOR(`id`))") - .collect())) - .containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"), Row.of(2L, "x2", 2, "z2")); - } - - @TestTemplate - public void testAsyncPredictWithRuntimeConfig() { - assertThatThrownBy( - () -> - tEnv().executeSql( - "SELECT id, vector FROM ML_PREDICT(TABLE src, MODEL m3, DESCRIPTOR(`content`), MAP['timeout', '1ms'])") - .await()) - .satisfies( - FlinkAssertions.anyCauseMatches( - TimeoutException.class, "Async function call has timed out.")); } - private void createScanTable(String tableName, List data) { - String dataId = TestValuesTableFactory.registerData(data); - tEnv().executeSql( - String.format( - "CREATE TABLE `%s`(\n" - + " id BIGINT," - + " len INT," - + " content STRING," - + " PRIMARY KEY (`id`) NOT ENFORCED" - + ") WITH (" - + " 'connector' = 'values'," - + " 'data-id' = '%s'" - + ")", - tableName, dataId)); + @Override + protected TableEnvironment getTableEnvironment() { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inStreamingMode().build(); + return StreamTableEnvironment.create(env, settings); } - @Parameters(name = "backend = {0}, objectReuse = {1}, asyncOutputMode = {2}") - public static Collection parameters() { - return Arrays.asList( - new Object[][] { - { - StreamingWithStateTestBase.HEAP_BACKEND(), - true, - ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED - }, - { - StreamingWithStateTestBase.HEAP_BACKEND(), - true, - ExecutionConfigOptions.AsyncOutputMode.ORDERED - }, - { - StreamingWithStateTestBase.HEAP_BACKEND(), - false, - ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED - }, - { - StreamingWithStateTestBase.HEAP_BACKEND(), - false, - ExecutionConfigOptions.AsyncOutputMode.ORDERED - }, - { - StreamingWithStateTestBase.ROCKSDB_BACKEND(), - true, - ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED - 
}, - { - StreamingWithStateTestBase.ROCKSDB_BACKEND(), - true, - ExecutionConfigOptions.AsyncOutputMode.ORDERED - }, - { - StreamingWithStateTestBase.ROCKSDB_BACKEND(), - false, - ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED - }, - { - StreamingWithStateTestBase.ROCKSDB_BACKEND(), - false, - ExecutionConfigOptions.AsyncOutputMode.ORDERED - } - }); + @Override + protected boolean isAsync() { + return true; } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java index 4bf0f44f46541..c4d34fb0b1fc9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java @@ -18,176 +18,25 @@ package org.apache.flink.table.planner.runtime.stream.table; -import org.apache.flink.table.api.Model; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.planner.factories.TestValuesModelFactory; -import org.apache.flink.table.planner.factories.TestValuesTableFactory; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction; -import org.apache.flink.table.planner.runtime.utils.StreamingTestBase; -import org.apache.flink.types.ColumnList; -import org.apache.flink.types.Row; -import org.apache.flink.util.CollectionUtil; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList; +import org.apache.flink.table.planner.runtime.utils.MLPredictITCaseBase; /** ITCase for {@link StreamExecMLPredictTableFunction}. 
*/ -public class MLPredictITCase extends StreamingTestBase { - - private final List<Row> data = - Arrays.asList( - Row.of(1L, 12, "Julian"), - Row.of(2L, 15, "Hello"), - Row.of(3L, 15, "Fabian"), - Row.of(8L, 11, "Hello world"), - Row.of(9L, 12, "Hello world!")); - - private final List<Row> dataWithNull = - Arrays.asList( - Row.of(null, 15, "Hello"), - Row.of(3L, 15, "Fabian"), - Row.of(null, 11, "Hello world"), - Row.of(9L, 12, "Hello world!")); - - private final Map<Row, List<Row>> id2features = new HashMap<>(); - - { - id2features.put(Row.of(1L), Collections.singletonList(Row.of("x1", 1, "z1"))); - id2features.put(Row.of(2L), Collections.singletonList(Row.of("x2", 2, "z2"))); - id2features.put(Row.of(3L), Collections.singletonList(Row.of("x3", 3, "z3"))); - id2features.put(Row.of(8L), Collections.singletonList(Row.of("x8", 8, "z8"))); - id2features.put(Row.of(9L), Collections.singletonList(Row.of("x9", 9, "z9"))); - } - - private final Map<Row, List<Row>> idLen2features = new HashMap<>(); - - { - idLen2features.put(Row.of(null, 15), Collections.singletonList(Row.of("x1", 1, "zNull15"))); - idLen2features.put(Row.of(15L, 15), Collections.singletonList(Row.of("x1", 1, "z1515"))); - idLen2features.put(Row.of(3L, 15), Collections.singletonList(Row.of("x2", 2, "z315"))); - idLen2features.put(Row.of(null, 11), Collections.singletonList(Row.of("x3", 3, "zNull11"))); - idLen2features.put(Row.of(11L, 11), Collections.singletonList(Row.of("x3", 3, "z1111"))); - idLen2features.put(Row.of(9L, 12), Collections.singletonList(Row.of("x8", 8, "z912"))); - idLen2features.put(Row.of(12L, 12), Collections.singletonList(Row.of("x8", 8, "z1212"))); - } - - @BeforeEach - public void before() throws Exception { - super.before(); - createScanTable("src", data); - createScanTable("nullable_src", dataWithNull); - - tEnv().executeSql( - String.format( - "CREATE MODEL m1\n" - + "INPUT (a BIGINT)\n" - + "OUTPUT (x STRING, y INT, z STRING)\n" - + "WITH (\n" - + " 'provider' = 'values'," - + " 'data-id' = '%s'" - + ")", - TestValuesModelFactory.registerData(id2features))); - tEnv().executeSql( - String.format( - "CREATE MODEL m2\n" - + "INPUT (a BIGINT, b INT)\n" - + "OUTPUT (x STRING, y INT, z STRING)\n" - + "WITH (\n" - + " 'provider' = 'values'," - + " 'data-id' = '%s'" - + ")", - TestValuesModelFactory.registerData(idLen2features))); - } - - @Test - public void testMLPredict() { - List<Row> result = - CollectionUtil.iteratorToList( - tEnv().executeSql( - "SELECT id, z " - + "FROM ML_PREDICT(TABLE src, MODEL m1, DESCRIPTOR(`id`)) ") - .collect()); - - assertThatList(result) - .containsExactlyInAnyOrder( - Row.of(1L, "z1"), - Row.of(2L, "z2"), - Row.of(3L, "z3"), - Row.of(8L, "z8"), - Row.of(9L, "z9")); - } - - @Test - public void testMLPredictWithMultipleFields() { - List<Row> result = - CollectionUtil.iteratorToList( - tEnv().executeSql( - "SELECT id, len, z " - + "FROM ML_PREDICT(TABLE nullable_src, MODEL m2, DESCRIPTOR(`id`, `len`)) ") - .collect()); - - assertThatList(result) - .containsExactlyInAnyOrder( - Row.of(3L, 15, "z315"), - Row.of(9L, 12, "z912"), - Row.of(null, 11, "zNull11"), - Row.of(null, 15, "zNull15")); - } - - @Test - public void testPredictWithConstantValues() { - List<Row> result = - CollectionUtil.iteratorToList( - tEnv().executeSql( - "WITH v(id) AS (SELECT * FROM (VALUES (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)))) " - + "SELECT * FROM ML_PREDICT( " - + " INPUT => TABLE v, " - + " MODEL => MODEL `m1`, " - + " ARGS => DESCRIPTOR(`id`) " - + ")") - .collect()); - - assertThatList(result) - .containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"),
Row.of(2L, "x2", 2, "z2")); - } +public class MLPredictITCase extends MLPredictITCaseBase { - @Test - public void testPredictTableApiWithView() { - Model model = tEnv().fromModel("m1"); - Table table = tEnv().from("src"); - tEnv().createView("view_src", model.predict(table, ColumnList.of("id"))); - List results = - CollectionUtil.iteratorToList( - tEnv().executeSql("select * from view_src").collect()); - assertThatList(results) - .containsExactlyInAnyOrder( - Row.of(1L, 12, "Julian", "x1", 1, "z1"), - Row.of(2L, 15, "Hello", "x2", 2, "z2"), - Row.of(3L, 15, "Fabian", "x3", 3, "z3"), - Row.of(8L, 11, "Hello world", "x8", 8, "z8"), - Row.of(9L, 12, "Hello world!", "x9", 9, "z9")); + @Override + protected TableEnvironment getTableEnvironment() { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inStreamingMode().build(); + return StreamTableEnvironment.create( + StreamExecutionEnvironment.getExecutionEnvironment(), settings); } - private void createScanTable(String tableName, List data) { - String dataId = TestValuesTableFactory.registerData(data); - tEnv().executeSql( - String.format( - "CREATE TABLE `%s`(\n" - + " id BIGINT,\n" - + " len INT,\n" - + " content STRING\n" - + ") WITH (\n" - + " 'connector' = 'values',\n" - + " 'data-id' = '%s'\n" - + ")", - tableName, dataId)); + @Override + protected boolean isAsync() { + return false; } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/MLPredictITCaseBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/MLPredictITCaseBase.java new file mode 100644 index 0000000000000..4e1eb8daf686c --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/MLPredictITCaseBase.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.runtime.utils; + +import org.apache.flink.core.testutils.FlinkAssertions; +import org.apache.flink.table.api.Model; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.planner.factories.TestValuesModelFactory; +import org.apache.flink.table.planner.factories.TestValuesTableFactory; +import org.apache.flink.test.junit5.MiniClusterExtension; +import org.apache.flink.types.ColumnList; +import org.apache.flink.types.Row; +import org.apache.flink.util.CollectionUtil; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeoutException; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList; + +/** Base ITCase to verify {@code ML_PREDICT} function. */ +public abstract class MLPredictITCaseBase { + + @RegisterExtension + private static final MiniClusterExtension MINI_CLUSTER_EXTENSION = new MiniClusterExtension(); + + protected TableEnvironment tEnv; + + protected abstract TableEnvironment getTableEnvironment(); + + protected abstract boolean isAsync(); + + private final List<Row> data = + Arrays.asList( + Row.of(1L, 12, "Julian"), + Row.of(2L, 15, "Hello"), + Row.of(3L, 15, "Fabian"), + Row.of(8L, 11, "Hello world"), + Row.of(9L, 12, "Hello world!")); + + private final List<Row> dataWithNull = + Arrays.asList( + Row.of(null, 15, "Hello"), + Row.of(3L, 15, "Fabian"), + Row.of(null, 11, "Hello world"), + Row.of(9L, 12, "Hello world!")); + + private final Map<Row, List<Row>> id2features = new HashMap<>(); + + { + id2features.put(Row.of(1L), Collections.singletonList(Row.of("x1", 1, "z1"))); + id2features.put(Row.of(2L), Collections.singletonList(Row.of("x2", 2, "z2"))); + id2features.put(Row.of(3L), Collections.singletonList(Row.of("x3", 3, "z3"))); + id2features.put(Row.of(8L), Collections.singletonList(Row.of("x8", 8, "z8"))); + id2features.put(Row.of(9L), Collections.singletonList(Row.of("x9", 9, "z9"))); + } + + private final Map<Row, List<Row>> idLen2features = new HashMap<>(); + + { + idLen2features.put(Row.of(null, 15), Collections.singletonList(Row.of("x1", 1, "zNull15"))); + idLen2features.put(Row.of(15L, 15), Collections.singletonList(Row.of("x1", 1, "z1515"))); + idLen2features.put(Row.of(3L, 15), Collections.singletonList(Row.of("x2", 2, "z315"))); + idLen2features.put(Row.of(null, 11), Collections.singletonList(Row.of("x3", 3, "zNull11"))); + idLen2features.put(Row.of(11L, 11), Collections.singletonList(Row.of("x3", 3, "z1111"))); + idLen2features.put(Row.of(9L, 12), Collections.singletonList(Row.of("x8", 8, "z912"))); + idLen2features.put(Row.of(12L, 12), Collections.singletonList(Row.of("x8", 8, "z1212"))); + } + + private final Map<Row, List<Row>> content2vector = new HashMap<>(); + + { + content2vector.put( + Row.of("Julian"), + Collections.singletonList(Row.of((Object) new Float[] {1.0f, 2.0f, 3.0f}))); + content2vector.put( + Row.of("Hello"), + Collections.singletonList(Row.of((Object) new Float[] {2.0f, 3.0f, 4.0f}))); + content2vector.put( + Row.of("Fabian"), + Collections.singletonList(Row.of((Object) new Float[] {3.0f, 4.0f, 5.0f}))); + content2vector.put( + Row.of("Hello world"), +
Collections.singletonList(Row.of((Object) new Float[] {4.0f, 5.0f, 6.0f}))); + content2vector.put( + Row.of("Hello world!"), + Collections.singletonList(Row.of((Object) new Float[] {5.0f, 6.0f, 7.0f}))); + } + + @BeforeEach + public void before() throws Exception { + tEnv = getTableEnvironment(); + createScanTable("src", data); + createScanTable("nullable_src", dataWithNull); + + tEnv.executeSql( + String.format( + "CREATE MODEL m1\n" + + "INPUT (a BIGINT)\n" + + "OUTPUT (x STRING, y INT, z STRING)\n" + + "WITH (\n" + + " 'provider' = 'values'," + + " 'async' = '%s'," + + " 'data-id' = '%s'" + + ")", + isAsync(), TestValuesModelFactory.registerData(id2features))); + tEnv.executeSql( + String.format( + "CREATE MODEL m2\n" + + "INPUT (a BIGINT, b INT)\n" + + "OUTPUT (x STRING, y INT, z STRING)\n" + + "WITH (\n" + + " 'provider' = 'values'," + + " 'async' = '%s'," + + " 'data-id' = '%s'" + + ")", + isAsync(), TestValuesModelFactory.registerData(idLen2features))); + tEnv.executeSql( + String.format( + "CREATE MODEL m3\n" + + "INPUT (content STRING)\n" + + "OUTPUT (vector ARRAY<FLOAT>)\n" + + "WITH (\n" + + " 'provider' = 'values'," + + " 'data-id' = '%s'," + + " 'latency' = '1000'," + + " 'async' = '%s'" + + ")", + TestValuesModelFactory.registerData(content2vector), isAsync())); + } + + @Test + public void testMLPredict() { + List<Row> result = + CollectionUtil.iteratorToList( + tEnv.executeSql( + "SELECT id, z " + + "FROM ML_PREDICT(TABLE src, MODEL m1, DESCRIPTOR(`id`)) ") + .collect()); + + assertThatList(result) + .containsExactlyInAnyOrder( + Row.of(1L, "z1"), + Row.of(2L, "z2"), + Row.of(3L, "z3"), + Row.of(8L, "z8"), + Row.of(9L, "z9")); + } + + @Test + public void testMLPredictWithMultipleFields() { + List<Row> result = + CollectionUtil.iteratorToList( + tEnv.executeSql( + "SELECT id, len, z " + + "FROM ML_PREDICT(TABLE nullable_src, MODEL m2, DESCRIPTOR(`id`, `len`)) ") + .collect()); + + assertThatList(result) + .containsExactlyInAnyOrder( + Row.of(3L, 15, "z315"), + Row.of(9L, 12, "z912"), + Row.of(null, 11, "zNull11"), + Row.of(null, 15, "zNull15")); + } + + @Test + public void testPredictWithConstantValues() { + List<Row> result = + CollectionUtil.iteratorToList( + tEnv.executeSql( + "WITH v(id) AS (SELECT * FROM (VALUES (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)))) " + + "SELECT * FROM ML_PREDICT( " + + " INPUT => TABLE v, " + + " MODEL => MODEL `m1`, " + + " ARGS => DESCRIPTOR(`id`) " + + ")") + .collect()); + + assertThatList(result) + .containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"), Row.of(2L, "x2", 2, "z2")); + } + + @Test + public void testPredictTableApiWithView() { + Model model = tEnv.fromModel("m1"); + Table table = tEnv.from("src"); + tEnv.createView("view_src", model.predict(table, ColumnList.of("id"))); + List<Row> results = + CollectionUtil.iteratorToList(tEnv.executeSql("select * from view_src").collect()); + assertThatList(results) + .containsExactlyInAnyOrder( + Row.of(1L, 12, "Julian", "x1", 1, "z1"), + Row.of(2L, 15, "Hello", "x2", 2, "z2"), + Row.of(3L, 15, "Fabian", "x3", 3, "z3"), + Row.of(8L, 11, "Hello world", "x8", 8, "z8"), + Row.of(9L, 12, "Hello world!", "x9", 9, "z9")); + } + + @Test + public void testPredictWithRuntimeConfig() { + if (!isAsync()) { + // Only test async timeout for async mode + return; + } + assertThatThrownBy( + () -> + tEnv.executeSql( + "SELECT id, vector FROM ML_PREDICT(TABLE src, MODEL m3, DESCRIPTOR(`content`), MAP['timeout', '1ms'])") + .await()) + .satisfies( + FlinkAssertions.anyCauseMatches( + TimeoutException.class, "Async function call has timed
out.")); + } + + private void createScanTable(String tableName, List data) { + String dataId = TestValuesTableFactory.registerData(data); + String bounded = tEnv instanceof StreamExecutionEnvironment ? "false" : "true"; + tEnv.executeSql( + String.format( + "CREATE TABLE `%s`(\n" + + " id BIGINT,\n" + + " len INT,\n" + + " content STRING\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'data-id' = '%s',\n" + + " 'bounded' = '%s'\n" + + ")", + tableName, dataId, bounded)); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.xml new file mode 100644 index 0000000000000..8e5c5c6696bca --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/MLPredictTableFunctionTest.xml @@ -0,0 +1,939 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + TABLE MyTable, MODEL => MODEL MyModel, ARGS => DESCRIPTOR(a, b)))]]> + + + + + + + + + + + TABLE MyTable, MODEL => MODEL MyModel, ARGS => DESCRIPTOR(a, b),CONFIG => MAP['key', 'value']))]]> + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/async-unordered-ml-predict/plan/async-unordered-ml-predict.json b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/async-unordered-ml-predict/plan/async-unordered-ml-predict.json new file mode 100644 index 0000000000000..0c17f2850e152 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/async-unordered-ml-predict/plan/async-unordered-ml-predict.json @@ -0,0 +1,134 @@ +{ + "flinkVersion" : "2.3", + "nodes" : [ { + "id" : 4, + "type" : "batch-exec-table-source-scan_1", + "scanTableSource" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`features`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647)>", + "description" : "TableSourceScan(table=[[default_catalog, default_database, features]], fields=[id, feature])", + "dynamicFilteringDataListenerID" : "cab90d1a-d797-4bd3-ac72-890a294544a2" + }, { + 
"id" : 5, + "type" : "batch-exec-ml-predict-table-function_1", + "configuration" : { + "table.exec.async-ml-predict.max-concurrent-operations" : "10", + "table.exec.async-ml-predict.output-mode" : "ALLOW_UNORDERED", + "table.exec.async-ml-predict.timeout" : "3 min" + }, + "mlPredictSpec" : { + "features" : [ { + "type" : "FieldRef", + "index" : 1 + } ], + "runtimeConfig" : { } + }, + "modelSpec" : { + "model" : { + "identifier" : "`default_catalog`.`default_database`.`chatgpt`", + "resolvedModel" : { + "inputSchema" : { + "columns" : [ { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ] + }, + "outputSchema" : { + "columns" : [ { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ] + } + } + } + }, + "asyncOptions" : { + "capacity " : 10, + "timeout" : 180000, + "output-mode" : "UNORDERED" + }, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "MLPredictTableFunction(invocation=[ML_PREDICT(TABLE(#0), Model(MODEL default_catalog.default_database.chatgpt), DESCRIPTOR(_UTF-16LE'feature'), DEFAULT())], rowType=[RecordType(INTEGER id, VARCHAR(2147483647) feature, VARCHAR(2147483647) category)])" + }, { + "id" : 6, + "type" : "batch-exec-sink_1", + "configuration" : { + "table.exec.sink.not-null-enforcer" : "ERROR", + "table.exec.sink.type-length-enforcer" : "IGNORE" + }, + "dynamicTableSink" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`sink_t`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + }, { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "BLOCKING", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "Sink(table=[default_catalog.default_database.sink_t], fields=[id, feature, category])" + } ], + "edges" : [ { + "source" : 4, + "target" : 5, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 5, + "target" : 6, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + } ] +} \ No newline at end of file diff --git a/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict-with-runtime-options/plan/sync-ml-predict-with-runtime-options.json b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict-with-runtime-options/plan/sync-ml-predict-with-runtime-options.json new file mode 100644 index 0000000000000..7d97064c47ced --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict-with-runtime-options/plan/sync-ml-predict-with-runtime-options.json @@ -0,0 +1,132 @@ +{ + "flinkVersion" : "2.3", + "nodes" : [ { + "id" : 7, + "type" : "batch-exec-table-source-scan_1", + "scanTableSource" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`features`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : 
"id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647)>", + "description" : "TableSourceScan(table=[[default_catalog, default_database, features]], fields=[id, feature])", + "dynamicFilteringDataListenerID" : "6163b6a7-cef6-4b79-acb1-fe5baefe0566" + }, { + "id" : 8, + "type" : "batch-exec-ml-predict-table-function_1", + "configuration" : { + "table.exec.async-ml-predict.max-concurrent-operations" : "10", + "table.exec.async-ml-predict.output-mode" : "ORDERED", + "table.exec.async-ml-predict.timeout" : "3 min" + }, + "mlPredictSpec" : { + "features" : [ { + "type" : "FieldRef", + "index" : 1 + } ], + "runtimeConfig" : { + "async" : "false" + } + }, + "modelSpec" : { + "model" : { + "identifier" : "`default_catalog`.`default_database`.`chatgpt`", + "resolvedModel" : { + "inputSchema" : { + "columns" : [ { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ] + }, + "outputSchema" : { + "columns" : [ { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ] + } + } + } + }, + "asyncOptions" : null, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "MLPredictTableFunction(invocation=[ML_PREDICT(TABLE(#0), Model(MODEL default_catalog.default_database.chatgpt), DESCRIPTOR(_UTF-16LE'feature'), MAP(_UTF-16LE'async', _UTF-16LE'false'))], rowType=[RecordType(INTEGER id, VARCHAR(2147483647) feature, VARCHAR(2147483647) category)])" + }, { + "id" : 9, + "type" : "batch-exec-sink_1", + "configuration" : { + "table.exec.sink.not-null-enforcer" : "ERROR", + "table.exec.sink.type-length-enforcer" : "IGNORE" + }, + "dynamicTableSink" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`sink_t`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + }, { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "BLOCKING", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "Sink(table=[default_catalog.default_database.sink_t], fields=[id, feature, category])" + } ], + "edges" : [ { + "source" : 7, + "target" : 8, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 8, + "target" : 9, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + } ] +} \ No newline at end of file diff --git a/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict/plan/sync-ml-predict.json b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict/plan/sync-ml-predict.json new file mode 100644 index 0000000000000..8ec7835fc96f2 --- /dev/null +++ 
b/flink-table/flink-table-planner/src/test/resources/restore-tests/batch-exec-ml-predict-table-function_1/sync-ml-predict/plan/sync-ml-predict.json @@ -0,0 +1,130 @@ +{ + "flinkVersion" : "2.3", + "nodes" : [ { + "id" : 1, + "type" : "batch-exec-table-source-scan_1", + "scanTableSource" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`features`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647)>", + "description" : "TableSourceScan(table=[[default_catalog, default_database, features]], fields=[id, feature])", + "dynamicFilteringDataListenerID" : "7b204f5a-fc72-4c5d-9cae-4dfd20ec4ad8" + }, { + "id" : 2, + "type" : "batch-exec-ml-predict-table-function_1", + "configuration" : { + "table.exec.async-ml-predict.max-concurrent-operations" : "10", + "table.exec.async-ml-predict.output-mode" : "ORDERED", + "table.exec.async-ml-predict.timeout" : "3 min" + }, + "mlPredictSpec" : { + "features" : [ { + "type" : "FieldRef", + "index" : 1 + } ], + "runtimeConfig" : { } + }, + "modelSpec" : { + "model" : { + "identifier" : "`default_catalog`.`default_database`.`chatgpt`", + "resolvedModel" : { + "inputSchema" : { + "columns" : [ { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + } ] + }, + "outputSchema" : { + "columns" : [ { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ] + } + } + } + }, + "asyncOptions" : null, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "MLPredictTableFunction(invocation=[ML_PREDICT(TABLE(#0), Model(MODEL default_catalog.default_database.chatgpt), DESCRIPTOR(_UTF-16LE'feature'), DEFAULT())], rowType=[RecordType(INTEGER id, VARCHAR(2147483647) feature, VARCHAR(2147483647) category)])" + }, { + "id" : 3, + "type" : "batch-exec-sink_1", + "configuration" : { + "table.exec.sink.not-null-enforcer" : "ERROR", + "table.exec.sink.type-length-enforcer" : "IGNORE" + }, + "dynamicTableSink" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`sink_t`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "id", + "dataType" : "INT NOT NULL" + }, { + "name" : "feature", + "dataType" : "VARCHAR(2147483647)" + }, { + "name" : "category", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_id", + "type" : "PRIMARY_KEY", + "columns" : [ "id" ] + } + } + } + } + }, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "BLOCKING", + "priority" : 0 + } ], + "outputType" : "ROW<`id` INT NOT NULL, `feature` VARCHAR(2147483647), `category` VARCHAR(2147483647)>", + "description" : "Sink(table=[default_catalog.default_database.sink_t], fields=[id, feature, category])" + } ], + "edges" : [ { + "source" : 1, + "target" : 2, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 2, + "target" : 3, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + } ] +} \ No newline at end of file
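---
A note for readers following the test layout: the stream-mode subclasses in this patch plug into MLPredictITCaseBase through exactly two hooks, getTableEnvironment() and isAsync(). For orientation, here is a minimal sketch of what a batch-mode counterpart could look like. It is illustrative only: the patch's actual batch ITCase sources (runtime/batch/table) are not reproduced in this excerpt, so the class body below is an assumption built solely on the hooks visible above.

    package org.apache.flink.table.planner.runtime.batch.table;

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.planner.runtime.utils.MLPredictITCaseBase;

    /** Illustrative batch-mode counterpart of the stream ITCases above (not the patch's code). */
    public class MLPredictITCase extends MLPredictITCaseBase {

        @Override
        protected TableEnvironment getTableEnvironment() {
            // Batch mode uses a plain TableEnvironment; the base class then
            // registers its 'values' sources as bounded for this environment.
            EnvironmentSettings settings =
                    EnvironmentSettings.newInstance().inBatchMode().build();
            return TableEnvironment.create(settings);
        }

        @Override
        protected boolean isAsync() {
            // Sync ML_PREDICT; an async variant would return true here,
            // which the base class forwards to the model's 'async' option.
            return false;
        }
    }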