Skip to content

Commit 3f77fa6

Browse files
committed
Initial support for ON CONFLICT DO NOTHING
1 parent 17fffc6 commit 3f77fa6

File tree

4 files changed

+259
-110
lines changed

4 files changed

+259
-110
lines changed

psqlextra/compiler.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler
33

44

5-
class PostgresSQLReturningUpdateCompiler(SQLUpdateCompiler):
5+
class PostgresReturningUpdateCompiler(SQLUpdateCompiler):
66
"""Compiler for SQL UPDATE statements that return
77
the primary keys of the affected rows."""
88

@@ -23,42 +23,37 @@ def _form_returning(self):
2323
return 'RETURNING %s' % qn(self.query.model._meta.pk.name)
2424

2525

26-
class PostgresSQLUpsertCompiler(SQLInsertCompiler):
26+
class PostgresInsertCompiler(SQLInsertCompiler):
2727
"""Compiler for SQL INSERT statements."""
2828

29-
def as_sql(self, returning='id'):
29+
def as_sql(self, return_id=False):
3030
"""Builds the SQL INSERT statement."""
3131

3232
queries = [
33-
(self._rewrite_insert(sql, returning), params)
33+
(self._rewrite_insert(sql, return_id), params)
3434
for sql, params in super().as_sql()
3535
]
3636

3737
return queries
3838

3939
def execute_sql(self, return_id=False):
40-
returning = 'id' if return_id else '*'
41-
returning = '*'
42-
4340
# execute all the generate queries
4441
with self.connection.cursor() as cursor:
4542
rows = []
46-
for sql, params in self.as_sql(returning):
43+
for sql, params in self.as_sql(return_id):
4744
cursor.execute(sql, params)
4845
rows.append(cursor.fetchone())
4946

50-
# return the primary key, which is stored in
51-
# the first column that is returned
52-
if return_id:
53-
return dict(id=rows[0][0])
54-
5547
# create a mapping between column names and column value
56-
return {
57-
column.name: rows[0][column_index]
58-
for column_index, column in enumerate(cursor.description)
59-
}
48+
return [
49+
{
50+
column.name: row[column_index]
51+
for column_index, column in enumerate(cursor.description) if row
52+
}
53+
for row in rows
54+
]
6055

61-
def _rewrite_insert(self, sql, returning='id'):
56+
def _rewrite_insert(self, sql, return_id=False):
6257
"""Rewrites a formed SQL INSERT query to include
6358
the ON CONFLICT clause.
6459
@@ -74,31 +69,52 @@ def _rewrite_insert(self, sql, returning='id'):
7469
The specified SQL INSERT query rewritten
7570
to include the ON CONFLICT clause.
7671
"""
77-
qn = self.connection.ops.quote_name
78-
79-
# construct a list of columns to update when there's a conflict
80-
update_columns = ', '.join([
81-
'{0} = EXCLUDED.{0}'.format(qn(field.column))
82-
for field in self.query.update_fields
83-
])
8472

8573
# build the conflict target, the columns to watch
8674
# for conflict basically
8775
conflict_target = self._build_conflict_target()
8876

8977
# form the new sql query that does the insert
9078
new_sql = (
91-
'{insert} ON CONFLICT ({conflict_target}) '
92-
'DO UPDATE SET {update_columns} RETURNING {returning}'
79+
'{insert} ON CONFLICT ({conflict_target}) DO {conflict_action}'
9380
).format(
9481
insert=sql,
9582
conflict_target=conflict_target,
96-
update_columns=update_columns,
97-
returning=returning
83+
conflict_action=self._build_conflict_action(return_id)
9884
)
9985

10086
return new_sql
10187

88+
def _build_conflict_action(self, return_id=False):
89+
"""Builds the `conflict_action` for the DO clause."""
90+
91+
returning = 'id' if return_id else '*'
92+
93+
qn = self.connection.ops.quote_name
94+
95+
# construct a list of columns to update when there's a conflict
96+
if self.query.conflict_action.value == 'UPDATE':
97+
update_columns = ', '.join([
98+
'{0} = EXCLUDED.{0}'.format(qn(field.column))
99+
for field in self.query.update_fields
100+
])
101+
102+
return (
103+
'UPDATE SET {update_columns} RETURNING {returning}'
104+
).format(
105+
update_columns=update_columns,
106+
returning=returning
107+
)
108+
elif self.query.conflict_action.value == 'NOTHING':
109+
return (
110+
'NOTHING RETURNING {returning}'
111+
).format(returning=returning)
112+
113+
raise SuspiciousOperation((
114+
'%s is not a valid conflict action, specify '
115+
'ConflictAction.UPDATE or ConflictAction.NOTHING.'
116+
) % str(self.query.conflict_action))
117+
102118
def _build_conflict_target(self):
103119
"""Builds the `conflict_target` for the ON CONFLICT
104120
clause."""

0 commit comments

Comments
 (0)