Skip to content

Commit 8ea9b21

Browse files
committed
Support binary data transfer in COPY FROM
My benchmark of transferring the integers from 0 to 1,000,000 both as an integer and as a string was about the same speed as the old text-based transfer. I believe that the binary transfer will start to show significant benefits when transferring binary data, other fields that don't need to be represented as fields and also means that the user doesn't need to worry about escapping their data.
1 parent 4db2fde commit 8ea9b21

File tree

3 files changed

+256
-0
lines changed

3 files changed

+256
-0
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,111 @@ public struct PostgresCopyFromWriter: Sendable {
9898
}
9999
}
100100

101+
102+
/// Handle to send binary data for a `COPY ... FROM STDIN` query to the backend.
103+
///
104+
/// It takes care of serializing `PostgresEncodable` column types into the binary format that Postgres expects.
105+
public struct PostgresBinaryCopyFromWriter: ~Copyable {
106+
/// Handle to serialize columns into a row that is being written by `PostgresBinaryCopyFromWriter`.
107+
public struct ColumnWriter: ~Copyable {
108+
/// The `PostgresBinaryCopyFromWriter` that is gathering the serialized data.
109+
///
110+
/// We need to model this as `UnsafeMutablePointer` because we can't express in the Swift type system that
111+
/// `ColumnWriter` never exceeds the lifetime of `PostgresBinaryCopyFromWriter`.
112+
@usableFromInline
113+
let underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>
114+
115+
/// The number of columns that have been written by this `ColumnWriter`.
116+
@usableFromInline
117+
var columns: UInt16 = 0
118+
119+
@usableFromInline
120+
init(underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>) {
121+
self.underlying = underlying
122+
}
123+
124+
/// Serialize a single column to a row.
125+
///
126+
/// - Important: It is critical that that data type encoded here exactly matches the data type in the
127+
/// database. For example, if the database stores an a 4-bit integer the corresponding `writeColumn` must
128+
/// be called with an `Int32`. Serializing an integer of a different width will cause a deserialization
129+
/// failure in the backend.
130+
@inlinable
131+
public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
132+
columns += 1
133+
try underlying.pointee.writeColumn(column)
134+
}
135+
}
136+
137+
/// The underlying `PostgresCopyFromWriter` that sends the serialized data to the backend.
138+
@usableFromInline let underlying: PostgresCopyFromWriter
139+
140+
/// The buffer in which we accumulate binary data. Once this buffer exceeds `bufferSize`, we flush it to
141+
/// the backend.
142+
@usableFromInline var buffer = ByteBuffer()
143+
144+
/// Once `buffer` exceeds this size, it gets flushed to the backend.
145+
@usableFromInline let bufferSize: Int
146+
147+
init(underlying: PostgresCopyFromWriter, bufferSize: Int) {
148+
self.underlying = underlying
149+
// Allocate 10% more than the buffer size because we only flush the buffer once it has exceeded `bufferSize`
150+
buffer.reserveCapacity(bufferSize + bufferSize / 10)
151+
self.bufferSize = bufferSize
152+
}
153+
154+
/// Serialize a single row to the backend. Call `writeColumn` on `columnWriter` for every column that should be
155+
/// included in the row.
156+
@inlinable
157+
public mutating func writeRow(_ body: (_ columnWriter: inout ColumnWriter) throws -> Void) async throws {
158+
// Write a placeholder for the number of columns
159+
let columnIndex = buffer.writerIndex
160+
buffer.writeInteger(UInt16(0))
161+
162+
let columns = try withUnsafeMutablePointer(to: &self) { pointerToSelf in
163+
// Important: We need to ensure that `pointerToSelf` (and thus `ColumnWriter`) does not exceed the lifetime
164+
// of `self` because it is holding an unsafe reference to it.
165+
//
166+
// We achieve this because `ColumnWriter` is non-Copyable and thus the client can't store a copy to it.
167+
// Furthermore, `columnWriter` is destroyed before the end of `withUnsafeMutablePointer`, which holds `self`
168+
// alive.
169+
var columnWriter = ColumnWriter(underlying: pointerToSelf)
170+
171+
try body(&columnWriter)
172+
173+
return columnWriter.columns
174+
}
175+
176+
// Fill in the number of columns
177+
buffer.setInteger(columns, at: columnIndex)
178+
179+
if buffer.readableBytes > bufferSize {
180+
try await flush()
181+
}
182+
}
183+
184+
/// Serialize a single column to the buffer. Should only be called by `ColumnWriter`.
185+
@inlinable
186+
mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
187+
guard let column else {
188+
buffer.writeInteger(Int32(-1))
189+
return
190+
}
191+
try buffer.writeLengthPrefixed(as: Int32.self) { buffer in
192+
let startIndex = buffer.writerIndex
193+
try column.encode(into: &buffer, context: .default)
194+
return buffer.writerIndex - startIndex
195+
}
196+
}
197+
198+
/// Flush any pending data in the buffer to the backend.
199+
@usableFromInline
200+
mutating func flush() async throws {
201+
try await underlying.write(buffer)
202+
buffer.clear()
203+
}
204+
}
205+
101206
/// Specifies the format in which data is transferred to the backend in a COPY operation.
102207
///
103208
/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings
@@ -113,15 +218,25 @@ public struct PostgresCopyFromFormat: Sendable {
113218
public init() {}
114219
}
115220

221+
/// Options that can be used to modify the `binary` format of a COPY operation.
222+
public struct BinaryOptions: Sendable {
223+
public init() {}
224+
}
225+
116226
enum Format {
117227
case text(TextOptions)
228+
case binary(BinaryOptions)
118229
}
119230

120231
var format: Format
121232

122233
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
123234
return PostgresCopyFromFormat(format: .text(options))
124235
}
236+
237+
public static func binary(_ options: BinaryOptions) -> PostgresCopyFromFormat {
238+
return PostgresCopyFromFormat(format: .binary(options))
239+
}
125240
}
126241

127242
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
@@ -153,6 +268,8 @@ private func buildCopyFromQuery(
153268
// Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
154269
queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'")
155270
}
271+
case .binary:
272+
queryOptions.append("FORMAT binary")
156273
}
157274
precondition(!queryOptions.isEmpty)
158275
query += " WITH ("
@@ -162,6 +279,49 @@ private func buildCopyFromQuery(
162279
}
163280

164281
extension PostgresConnection {
282+
/// Copy data into a table using a `COPY <table name> FROM STDIN` query, transferring data in a binary format.
283+
///
284+
/// - Parameters:
285+
/// - table: The name of the table into which to copy the data.
286+
/// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied.
287+
/// - bufferSize: How many bytes to accumulate a local buffer before flushing it to the database. Can affect
288+
/// performance characteristics of the copy operation.
289+
/// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the
290+
/// writer provided by the closure to send data to the backend and return from the closure once all data is sent.
291+
/// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown
292+
/// by the `copyFromBinary` function.
293+
///
294+
/// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be
295+
/// susceptible to SQL injection. Ensure no untrusted data is contained in these strings.
296+
public func copyFromBinary(
297+
table: String,
298+
columns: [String] = [],
299+
options: PostgresCopyFromFormat.BinaryOptions = .init(),
300+
bufferSize: Int = 100_000,
301+
logger: Logger,
302+
file: String = #fileID,
303+
line: Int = #line,
304+
writeData: (inout PostgresBinaryCopyFromWriter) async throws -> Void
305+
) async throws {
306+
try await copyFrom(table: table, columns: columns, format: .binary(PostgresCopyFromFormat.BinaryOptions()), logger: logger) { writer in
307+
var header = ByteBuffer()
308+
header.writeString("PGCOPY\n")
309+
header.writeInteger(UInt8(0xff))
310+
header.writeString("\r\n\0")
311+
312+
// Flag fields
313+
header.writeInteger(UInt32(0))
314+
315+
// Header extension area length
316+
header.writeInteger(UInt32(0))
317+
try await writer.write(header)
318+
319+
var binaryWriter = PostgresBinaryCopyFromWriter(underlying: writer, bufferSize: bufferSize)
320+
try await writeData(&binaryWriter)
321+
try await binaryWriter.flush()
322+
}
323+
}
324+
165325
/// Copy data into a table using a `COPY <table name> FROM STDIN` query.
166326
///
167327
/// - Parameters:

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,4 +487,38 @@ final class IntegrationTests: XCTestCase {
487487
XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror
488488
}
489489
}
490+
491+
func testCopyFromBinary() async throws {
492+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
493+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
494+
let eventLoop = eventLoopGroup.next()
495+
496+
let conn = try await PostgresConnection.test(on: eventLoop).get()
497+
defer { XCTAssertNoThrow(try conn.close().wait()) }
498+
499+
_ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get()
500+
_ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get()
501+
502+
try await conn.copyFromBinary(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in
503+
let records: [(id: Int, name: String)] = [
504+
(1, "Alice"),
505+
(42, "Bob")
506+
]
507+
for record in records {
508+
try await writer.writeRow { columnWriter in
509+
try columnWriter.writeColumn(Int32(record.id))
510+
try columnWriter.writeColumn(record.name)
511+
}
512+
}
513+
}
514+
let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) }
515+
guard rows.count == 2 else {
516+
XCTFail("Expected 2 columns, received \(rows.count)")
517+
return
518+
}
519+
XCTAssertEqual(rows[0].0, 1)
520+
XCTAssertEqual(rows[0].1, "Alice")
521+
XCTAssertEqual(rows[1].0, 42)
522+
XCTAssertEqual(rows[1].1, "Bob")
523+
}
490524
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,68 @@ import Synchronization
936936
}
937937
}
938938

939+
@Test func testCopyFromBinary() async throws {
940+
try await self.withAsyncTestingChannel { connection, channel in
941+
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> Void in
942+
taskGroup.addTask {
943+
try await connection.copyFromBinary(table: "copy_table", logger: .psqlTest) {
944+
writer in
945+
try await writer.writeRow { columnWriter in
946+
try columnWriter.writeColumn(Int32(1))
947+
try columnWriter.writeColumn("Alice")
948+
}
949+
try await writer.writeRow { columnWriter in
950+
try columnWriter.writeColumn(Int32(2))
951+
try columnWriter.writeColumn("Bob")
952+
}
953+
}
954+
}
955+
956+
let copyRequest = try await channel.waitForUnpreparedRequest()
957+
#expect(copyRequest.parse.query == #"COPY "copy_table" FROM STDIN WITH (FORMAT binary)"#)
958+
959+
try await channel.sendUnpreparedRequestWithNoParametersBindResponse()
960+
try await channel.writeInbound(
961+
PostgresBackendMessage.copyInResponse(
962+
.init(format: .binary, columnFormats: [.binary, .binary])))
963+
964+
let copyData = try await channel.waitForCopyData()
965+
#expect(copyData.result == .done)
966+
var data = copyData.data
967+
// Signature
968+
#expect(data.readString(length: 7) == "PGCOPY\n")
969+
#expect(data.readInteger(as: UInt8.self) == 0xff)
970+
#expect(data.readString(length: 3) == "\r\n\0")
971+
// Flags
972+
#expect(data.readInteger(as: UInt32.self) == 0)
973+
// Header extension area length
974+
#expect(data.readInteger(as: UInt32.self) == 0)
975+
976+
struct Row: Equatable {
977+
let id: Int32
978+
let name: String
979+
}
980+
var rows: [Row] = []
981+
while data.readableBytes > 0 {
982+
// Number of columns
983+
#expect(data.readInteger(as: UInt16.self) == 2)
984+
// 'id' column
985+
#expect(data.readInteger(as: UInt32.self) == 4)
986+
let id = data.readInteger(as: Int32.self)
987+
// 'name' column length
988+
let nameLength = data.readInteger(as: UInt32.self)
989+
let name = data.readString(length: Int(try #require(nameLength)))
990+
rows.append(Row(id: try #require(id), name: try #require(name)))
991+
}
992+
#expect(rows == [Row(id: 1, name: "Alice"), Row(id: 2, name: "Bob")])
993+
try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1"))
994+
995+
try await channel.waitForPostgresFrontendMessage(\.sync)
996+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
997+
}
998+
}
999+
}
1000+
9391001
func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws {
9401002
let eventLoop = NIOAsyncTestingEventLoop()
9411003
let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in

0 commit comments

Comments
 (0)