88import net .snowflake .client .jdbc .internal .org .bouncycastle .pkcs .PKCSException ;
99import org .embulk .config .ConfigDiff ;
1010import org .embulk .config .ConfigException ;
11+ import org .embulk .config .ConfigSource ;
1112import org .embulk .config .TaskSource ;
1213import org .embulk .output .jdbc .*;
1314import org .embulk .output .snowflake .PrivateKeyReader ;
2425import org .embulk .util .config .ConfigDefault ;
2526
2627public class SnowflakeOutputPlugin extends AbstractJdbcOutputPlugin {
27- private StageIdentifier stageIdentifier ;
28-
2928 public interface SnowflakePluginTask extends PluginTask {
3029 @ Config ("driver_path" )
3130 @ ConfigDefault ("null" )
@@ -75,6 +74,10 @@ public interface SnowflakePluginTask extends PluginTask {
7574 @ Config ("empty_field_as_null" )
7675 @ ConfigDefault ("true" )
7776 public boolean getEmtpyFieldAsNull ();
77+
78+ @ Config ("delete_stage_on_error" )
79+ @ ConfigDefault ("false" )
80+ public boolean getDeleteStageOnError ();
7881 }
7982
8083 @ Override
@@ -144,25 +147,39 @@ protected JdbcOutputConnector getConnector(PluginTask task, boolean retryableMet
144147 }
145148
146149 @ Override
147- public ConfigDiff resume (
148- TaskSource taskSource , Schema schema , int taskCount , OutputPlugin .Control control ) {
149- throw new UnsupportedOperationException ("snowflake output plugin does not support resuming" );
150- }
151-
152- @ Override
153- protected void doCommit (JdbcOutputConnection con , PluginTask task , int taskCount )
154- throws SQLException {
155- super .doCommit (con , task , taskCount );
156- SnowflakeOutputConnection snowflakeCon = (SnowflakeOutputConnection ) con ;
157-
150+ public ConfigDiff transaction (
151+ ConfigSource config , Schema schema , int taskCount , OutputPlugin .Control control ) {
152+ PluginTask task = CONFIG_MAPPER .map (config , this .getTaskClass ());
158153 SnowflakePluginTask t = (SnowflakePluginTask ) task ;
159- if (this .stageIdentifier == null ) {
160- this .stageIdentifier = StageIdentifierHolder .getStageIdentifier (t );
154+ StageIdentifier stageIdentifier = StageIdentifierHolder .getStageIdentifier (t );
155+ ConfigDiff configDiff ;
156+ SnowflakeOutputConnection snowflakeCon = null ;
157+
158+ try {
159+ snowflakeCon = (SnowflakeOutputConnection ) getConnector (task , true ).connect (true );
160+ snowflakeCon .runCreateStage (stageIdentifier );
161+ configDiff = super .transaction (config , schema , taskCount , control );
162+ if (t .getDeleteStage ()) {
163+ snowflakeCon .runDropStage (stageIdentifier );
164+ }
165+ } catch (Exception e ) {
166+ if (t .getDeleteStage () && t .getDeleteStageOnError ()) {
167+ try {
168+ snowflakeCon .runDropStage (stageIdentifier );
169+ } catch (SQLException ex ) {
170+ throw new RuntimeException (ex );
171+ }
172+ }
173+ throw new RuntimeException (e );
161174 }
162175
163- if (t .getDeleteStage ()) {
164- snowflakeCon .runDropStage (this .stageIdentifier );
165- }
176+ return configDiff ;
177+ }
178+
179+ @ Override
180+ public ConfigDiff resume (
181+ TaskSource taskSource , Schema schema , int taskCount , OutputPlugin .Control control ) {
182+ throw new UnsupportedOperationException ("snowflake output plugin does not support resuming" );
166183 }
167184
168185 @ Override
@@ -179,20 +196,11 @@ protected BatchInsert newBatchInsert(PluginTask task, Optional<MergeConfig> merg
179196 throw new UnsupportedOperationException (
180197 "Snowflake output plugin doesn't support 'merge_direct' mode." );
181198 }
182-
183- SnowflakePluginTask t = (SnowflakePluginTask ) task ;
184- // TODO: put some where executes once
185- if (this .stageIdentifier == null ) {
186- SnowflakeOutputConnection snowflakeCon =
187- (SnowflakeOutputConnection ) getConnector (task , true ).connect (true );
188- this .stageIdentifier = StageIdentifierHolder .getStageIdentifier (t );
189- snowflakeCon .runCreateStage (this .stageIdentifier );
190- }
191199 SnowflakePluginTask pluginTask = (SnowflakePluginTask ) task ;
192200
193201 return new SnowflakeCopyBatchInsert (
194202 getConnector (task , true ),
195- this . stageIdentifier ,
203+ StageIdentifierHolder . getStageIdentifier ( pluginTask ) ,
196204 false ,
197205 pluginTask .getMaxUploadRetries (),
198206 pluginTask .getEmtpyFieldAsNull ());
0 commit comments