diff --git a/src/main/scala/org/splink/cpipe/CPipe.scala b/src/main/scala/org/splink/cpipe/CPipe.scala
index ac3d654..d2eb413 100644
--- a/src/main/scala/org/splink/cpipe/CPipe.scala
+++ b/src/main/scala/org/splink/cpipe/CPipe.scala
@@ -1,7 +1,8 @@
 package org.splink.cpipe

-import org.splink.cpipe.processors.{Exporter, Exporter2, Importer}
-import org.splink.cpipe.config.{Config, Arguments}
+import org.splink.cpipe.processors.{Exporter, Exporter2, Importer, Importer2}
+import org.splink.cpipe.config.{Arguments, Config}
+
 import scala.language.implicitConversions
 import scala.util.{Failure, Success, Try}

@@ -23,11 +24,13 @@ object CPipe {
     config.mode match {
       case "import" =>
         new Importer().process(session, config)
+      case "import2" =>
+        new Importer2().process(session, config)
       case "export" =>
-        new Exporter().process(session, config)
+        new Exporter().process(session, exportConfig(config))
       case "export2" =>
         if (session.getCluster.getMetadata.getPartitioner == "org.apache.cassandra.dht.Murmur3Partitioner") {
-          new Exporter2().process(session, config)
+          new Exporter2().process(session, exportConfig(config))
         } else {
           Output.log("mode 'export2' requires the cluster to use 'Murmur3Partitioner'")
         }
@@ -62,6 +65,16 @@ object CPipe {
       conf.flags.useCompression)

+  def exportConfig(config: Config): Config = {
+    if (config.settings.threads != 1) {
+      Output.log("Export is limited to 1 thread")
+      config.copy(settings = config.settings.copy(threads = 1))
+    } else {
+      config
+    }
+  }
+
+
   object ElapsedSecondFormat {

     def zero(i: Long) = if (i < 10) s"0$i" else s"$i"

diff --git a/src/main/scala/org/splink/cpipe/JsonColumnParser.scala b/src/main/scala/org/splink/cpipe/JsonColumnParser.scala
index 2841718..8820f27 100644
--- a/src/main/scala/org/splink/cpipe/JsonColumnParser.scala
+++ b/src/main/scala/org/splink/cpipe/JsonColumnParser.scala
@@ -1,6 +1,9 @@
 package org.splink.cpipe

-import com.datastax.driver.core.{DataType, Row}
+import java.lang.{Boolean, Double, Short}
+import java.util.Date
+
+import com.datastax.driver.core.{BatchStatement, DataType, PreparedStatement, Row, Session}
 import play.api.libs.json._

 import scala.collection.JavaConverters._
@@ -8,27 +11,47 @@ import scala.util.{Failure, Success, Try}

 object JsonColumnParser {

-  case class Column(name: String, value: String, typ: DataType)
+  case class Column(name: String, value: Object, typ: DataType)
+
+  // SimpleDateFormat is not thread safe
+  private val tlDateFormat = new ThreadLocal[java.text.SimpleDateFormat]
+
+  private def dateFormat = {
+    if (tlDateFormat.get() == null) {
+      tlDateFormat.set(new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))
+    }
+    tlDateFormat.get()
+  }

   def column2Json(column: Column) = {
-    val sanitized = stripControlChars(column.value)
-    Try(Json.parse(sanitized)) match {
-      case Success(json) =>
-        val r = json match {
-          case o: JsObject => o
-          case _ => parseCassandraDataType(sanitized, column.typ)
+    val value = column.value
+
+    if (value == null) {
+      Some(JsObject(Map(column.name -> JsNull)))
+    } else {
+      val sanitized: String = value match {
+        case date: Date => dateFormat.format(date)
+        case _ => stripControlChars(value.toString)
       }
-        Some(JsObject(Map(column.name -> r)))
+      Try(Json.parse(sanitized)) match {
+        case Success(json) =>
+          val r = json match {
+            case o: JsObject => o
+            case _ => parseCassandraDataType(value, sanitized, column.typ)
+          }
-      case Failure(_) =>
-        Some(JsObject(Map(column.name -> parseCassandraDataType(sanitized, column.typ))))
-    }
+          Some(JsObject(Map(column.name -> r)))
+
+        case Failure(_) =>
+          Some(JsObject(Map(column.name -> parseCassandraDataType(value, sanitized, column.typ))))
+      }
+    }
   }

   def row2Json(row: Row) = row.getColumnDefinitions.iterator.asScala.flatMap { definition =>
-    Try(row.getObject(definition.getName).toString) match {
+    Try(row.getObject(definition.getName)) match {
       case Success(value) =>
         column2Json {
           Column(definition.getName, value, definition.getType)
@@ -73,6 +96,50 @@ object JsonColumnParser {
   }

+  def json2PreparedStatement(table: String, json: JsObject, session: Session): PreparedStatement = {
+    val str = s"INSERT INTO $table ( ${json.fields.map(_._1).mkString(", ")} ) VALUES ( ${json.fields.map(_ => "?").mkString(", ")} );"
+    session.prepare(str)
+  }
+
+  def getStringToObjectMappingForTable(session: Session, table: String): Map[String, String => Object] = {
+    val queryResult = session.execute(s"select * from $table limit 1")
+    queryResult.getColumnDefinitions.asScala.map {
+      definition => definition.getName -> getStringToObjectConversionMethod(definition.getType)
+    }.toMap
+  }
+
+  def getStringToObjectConversionMethod(dataType: DataType): String => Object = (s: String) => {
+    dataType.getName match {
+      case DataType.Name.DATE => nullOr(dateFormat.parse)(s)
+      case DataType.Name.TIMESTAMP => nullOr(dateFormat.parse)(s)
+      case DataType.Name.DOUBLE => nullOr { x: String => new Double(x.toDouble) }(s)
+      case DataType.Name.INT => nullOr { x: String => new Integer(x.toInt) }(s)
+      case DataType.Name.VARCHAR => nullOr(identity)(s)
+      case DataType.Name.BOOLEAN => nullOr { x: String => new Boolean(x == "true") }(s)
+      case DataType.Name.SMALLINT => nullOr { x: String => new Short(x.toShort) }(s)
+      case _ => throw new IllegalArgumentException(s"Please add a mapping for the '${dataType.getName}' type")
+    }
+  }
+
+  def jsValueToScalaObject(name: String, jsValue: JsValue, objectMapping: Map[String, String => Object]): Object = {
+    val v = jsValue.toString.stripPrefix("\"").stripSuffix("\"")
+    objectMapping.get(name).getOrElse(throw new IllegalArgumentException(s"$name was not found in the map $objectMapping"))(v)
+  }
+
+  def addJsonToBatch(json: JsObject, preparedStatement: PreparedStatement, batch: BatchStatement, objectMapping: Map[String, String => Object]): Unit = {
+    val values = json.fields.map { v => jsValueToScalaObject(v._1, v._2, objectMapping) }
+    batch.add(preparedStatement.bind(values: _*))
+  }
+
+  def nullOr(parser: String => Object): String => Object = (s: String) => {
+    if (s.equals("null")) {
+      null
+    } else {
+      parser(s)
+    }
+  }
+
+
   import java.util.regex.Pattern
   val pattern = Pattern.compile("[\\u0000-\\u001f]")

@@ -80,32 +147,36 @@ object JsonColumnParser {
   def stripControlChars(s: String) = pattern.matcher(s).replaceAll("")

-  def parseCassandraDataType(a: String, dt: DataType) =
-    dt.getName match {
-      case DataType.Name.ASCII => JsString(a)
-      case DataType.Name.BLOB => JsString(a)
-      case DataType.Name.DATE => JsString(a)
-      case DataType.Name.INET => JsString(a)
-      case DataType.Name.TEXT => JsString(a)
-      case DataType.Name.TIME => JsString(a)
-      case DataType.Name.TIMESTAMP => JsString(a)
-      case DataType.Name.TIMEUUID => JsString(a)
-      case DataType.Name.UUID => JsString(a)
-      case DataType.Name.VARCHAR => JsString(a)
-      case DataType.Name.BOOLEAN => JsBoolean(a == "true")
-      case DataType.Name.BIGINT => JsNumber(BigDecimal(a))
-      case DataType.Name.DECIMAL => JsNumber(BigDecimal(a))
-      case DataType.Name.DOUBLE => JsNumber(BigDecimal(a))
-      case DataType.Name.FLOAT => JsNumber(BigDecimal(a))
-      case DataType.Name.INT => JsNumber(BigDecimal(a))
-      case DataType.Name.SMALLINT => JsNumber(BigDecimal(a))
-      case DataType.Name.TINYINT => JsNumber(BigDecimal(a))
-      case DataType.Name.VARINT => JsNumber(BigDecimal(a))
-      case DataType.Name.LIST => Json.parse(a)
-      case DataType.Name.MAP => Json.parse(a)
-      case DataType.Name.SET => Json.parse(a)
-      case DataType.Name.TUPLE => Json.parse(a)
-      case DataType.Name.UDT => Json.parse(a)
-      case _ => Json.parse(a)
-    }
+  def parseCassandraDataType(v: Object, a: String, dt: DataType) = {
+    dt.getName match {
+      case DataType.Name.ASCII => JsString(a)
+      case DataType.Name.BLOB => JsString(a)
+      case DataType.Name.DATE => JsString(a)
+      case DataType.Name.INET => JsString(a)
+      case DataType.Name.TEXT => JsString(a)
+      case DataType.Name.TIME => JsString(a)
+      case DataType.Name.TIMESTAMP => JsString(a)
+      case DataType.Name.TIMEUUID => JsString(a)
+      case DataType.Name.UUID => JsString(a)
+      case DataType.Name.VARCHAR => JsString(a)
+      case DataType.Name.BOOLEAN => JsBoolean(a == "true")
+      case DataType.Name.BIGINT => JsNumber(BigDecimal(a))
+      case DataType.Name.DECIMAL => JsNumber(BigDecimal(a))
+      case DataType.Name.DOUBLE => v match {
+        case d: Double if Double.isNaN(d) => JsNull
+        case _ => JsNumber(BigDecimal(a))
+      }
+      case DataType.Name.FLOAT => JsNumber(BigDecimal(a))
+      case DataType.Name.INT => JsNumber(BigDecimal(a))
+      case DataType.Name.SMALLINT => JsNumber(BigDecimal(a))
+      case DataType.Name.TINYINT => JsNumber(BigDecimal(a))
+      case DataType.Name.VARINT => JsNumber(BigDecimal(a))
+      case DataType.Name.LIST => Json.parse(a)
+      case DataType.Name.MAP => Json.parse(a)
+      case DataType.Name.SET => Json.parse(a)
+      case DataType.Name.TUPLE => Json.parse(a)
+      case DataType.Name.UDT => Json.parse(a)
+      case _ => Json.parse(a)
+    }
+  }
 }

diff --git a/src/main/scala/org/splink/cpipe/config/Arguments.scala b/src/main/scala/org/splink/cpipe/config/Arguments.scala
index ce94919..7e82f15 100644
--- a/src/main/scala/org/splink/cpipe/config/Arguments.scala
+++ b/src/main/scala/org/splink/cpipe/config/Arguments.scala
@@ -54,6 +54,9 @@ class Arguments(arguments: Seq[String]) extends ScallopConf(arguments) {
   val fetchSize = opt[Int](default = Some(5000),
     descr = "The amount of rows which is retrieved simultaneously. Defaults to 5000.")

+  val batchSize = opt[Int](default = Some(500),
+    descr = "The amount of rows which is saved simultaneously when using mode import2. Defaults to 500.")
+
   val threads = opt[Int](default = Some(32),
     descr = "The amount of parallelism used in export2 mode. Defaults to 32 parallel requests.")

@@ -65,8 +68,9 @@ class Arguments(arguments: Seq[String]) extends ScallopConf(arguments) {
   val compression = choice(Seq("ON", "OFF"), default = Some("ON"),
     descr = "Use LZ4 compression and trade reduced network traffic for CPU cycles. Defaults to ON")

-  val mode = choice(choices = Seq("import", "export", "export2"), required = true,
+  val mode = choice(choices = Seq("import", "import2", "export", "export2"), required = true,
     descr = "Select the mode. Choose mode 'import' to import data. " +
+      "Choose mode 'import2' to import data with a prepared statement (faster, but only for tables with fixed columns); " +
       "Choose mode 'export' to export data (optional with a filter); " +
       "Choose mode 'export2' to export data using token ranges to increase performance and reduce load on the cluster. " +
       "'export2' mode cannot be combined with a filter and it requires that the cluster uses Murmur3Partitioner. " +
@@ -79,4 +83,4 @@ class Arguments(arguments: Seq[String]) extends ScallopConf(arguments) {
   }

   verify()
-}
\ No newline at end of file
+}

diff --git a/src/main/scala/org/splink/cpipe/config/Config.scala b/src/main/scala/org/splink/cpipe/config/Config.scala
index 9778b66..df99f37 100644
--- a/src/main/scala/org/splink/cpipe/config/Config.scala
+++ b/src/main/scala/org/splink/cpipe/config/Config.scala
@@ -18,6 +18,7 @@ case object Config {
       verbose <- args.verbose.toOption
       threads <- args.threads.toOption
       fetchSize <- args.fetchSize.toOption
+      batchSize <- args.batchSize.toOption
       useCompression <- args.compression.toOption.map {
         case c if c == "ON" => true
         case _ => false
@@ -42,7 +43,7 @@ case object Config {
         Selection(keyspace, table, filter),
         Credentials(username, password),
         Flags(!beQuiet, useCompression, verbose),
-        Settings(fetchSize, consistencyLevel, threads))
+        Settings(fetchSize, batchSize, consistencyLevel, threads))
     }
   }

@@ -54,4 +55,4 @@ final case class Credentials(username: String, password: String)

 final case class Flags(showProgress: Boolean, useCompression: Boolean, verbose: Boolean)

-final case class Settings(fetchSize: Int, consistencyLevel: ConsistencyLevel, threads: Int)
+final case class Settings(fetchSize: Int, batchSize: Int, consistencyLevel: ConsistencyLevel, threads: Int)

diff --git a/src/main/scala/org/splink/cpipe/processors/Importer2.scala b/src/main/scala/org/splink/cpipe/processors/Importer2.scala
new file mode 100644
index 0000000..7695188
--- /dev/null
+++ b/src/main/scala/org/splink/cpipe/processors/Importer2.scala
@@ -0,0 +1,38 @@
+package org.splink.cpipe.processors
+
+import com.datastax.driver.core.{BatchStatement, PreparedStatement, Session}
+import org.splink.cpipe.config.Config
+import org.splink.cpipe.{JsonFrame, Output, Rps}
+
+import scala.io.Source
+
+class Importer2 extends Processor {
+
+  import org.splink.cpipe.JsonColumnParser._
+
+  val rps = new Rps()
+
+  override def process(session: Session, config: Config): Int = {
+    val frame = new JsonFrame()
+    var statement: PreparedStatement = null
+    val dataTypeMapping = getStringToObjectMappingForTable(session, config.selection.table)
+
+    Source.stdin.getLines().flatMap { line =>
+      frame.push(line.toCharArray)
+    }.grouped(config.settings.batchSize).foreach { group =>
+      val batch = new BatchStatement
+      group.foreach { str =>
+        string2Json(str).foreach { json =>
+          if (statement == null) {
+            statement = json2PreparedStatement(config.selection.table, json, session)
+          }
+          addJsonToBatch(json, statement, batch, dataTypeMapping)
+          rps.compute()
+        }
+      }
+      if (config.flags.showProgress) Output.update(s"${rps.count} rows at $rps rows/sec.")
+      session.execute(batch)
+    }
+    rps.count
+  }
+}