22from django .db .models .sql .compiler import SQLInsertCompiler , SQLUpdateCompiler
33
44
5- class PostgresSQLReturningUpdateCompiler (SQLUpdateCompiler ):
5+ class PostgresReturningUpdateCompiler (SQLUpdateCompiler ):
66 """Compiler for SQL UPDATE statements that return
77 the primary keys of the affected rows."""
88
@@ -23,42 +23,37 @@ def _form_returning(self):
2323 return 'RETURNING %s' % qn (self .query .model ._meta .pk .name )
2424
2525
26- class PostgresSQLUpsertCompiler (SQLInsertCompiler ):
26+ class PostgresInsertCompiler (SQLInsertCompiler ):
2727 """Compiler for SQL INSERT statements."""
2828
29- def as_sql (self , returning = 'id' ):
29+ def as_sql (self , return_id = False ):
3030 """Builds the SQL INSERT statement."""
3131
3232 queries = [
33- (self ._rewrite_insert (sql , returning ), params )
33+ (self ._rewrite_insert (sql , return_id ), params )
3434 for sql , params in super ().as_sql ()
3535 ]
3636
3737 return queries
3838
3939 def execute_sql (self , return_id = False ):
40- returning = 'id' if return_id else '*'
41- returning = '*'
42-
4340 # execute all the generate queries
4441 with self .connection .cursor () as cursor :
4542 rows = []
46- for sql , params in self .as_sql (returning ):
43+ for sql , params in self .as_sql (return_id ):
4744 cursor .execute (sql , params )
4845 rows .append (cursor .fetchone ())
4946
50- # return the primary key, which is stored in
51- # the first column that is returned
52- if return_id :
53- return dict (id = rows [0 ][0 ])
54-
5547 # create a mapping between column names and column value
56- return {
57- column .name : rows [0 ][column_index ]
58- for column_index , column in enumerate (cursor .description )
59- }
48+ return [
49+ {
50+ column .name : row [column_index ]
51+ for column_index , column in enumerate (cursor .description ) if row
52+ }
53+ for row in rows
54+ ]
6055
61- def _rewrite_insert (self , sql , returning = 'id' ):
56+ def _rewrite_insert (self , sql , return_id = False ):
6257 """Rewrites a formed SQL INSERT query to include
6358 the ON CONFLICT clause.
6459
@@ -74,31 +69,52 @@ def _rewrite_insert(self, sql, returning='id'):
7469 The specified SQL INSERT query rewritten
7570 to include the ON CONFLICT clause.
7671 """
77- qn = self .connection .ops .quote_name
78-
79- # construct a list of columns to update when there's a conflict
80- update_columns = ', ' .join ([
81- '{0} = EXCLUDED.{0}' .format (qn (field .column ))
82- for field in self .query .update_fields
83- ])
8472
8573 # build the conflict target, the columns to watch
8674 # for conflict basically
8775 conflict_target = self ._build_conflict_target ()
8876
8977 # form the new sql query that does the insert
9078 new_sql = (
91- '{insert} ON CONFLICT ({conflict_target}) '
92- 'DO UPDATE SET {update_columns} RETURNING {returning}'
79+ '{insert} ON CONFLICT ({conflict_target}) DO {conflict_action}'
9380 ).format (
9481 insert = sql ,
9582 conflict_target = conflict_target ,
96- update_columns = update_columns ,
97- returning = returning
83+ conflict_action = self ._build_conflict_action (return_id )
9884 )
9985
10086 return new_sql
10187
88+ def _build_conflict_action (self , return_id = False ):
89+ """Builds the `conflict_action` for the DO clause."""
90+
91+ returning = 'id' if return_id else '*'
92+
93+ qn = self .connection .ops .quote_name
94+
95+ # construct a list of columns to update when there's a conflict
96+ if self .query .conflict_action .value == 'UPDATE' :
97+ update_columns = ', ' .join ([
98+ '{0} = EXCLUDED.{0}' .format (qn (field .column ))
99+ for field in self .query .update_fields
100+ ])
101+
102+ return (
103+ 'UPDATE SET {update_columns} RETURNING {returning}'
104+ ).format (
105+ update_columns = update_columns ,
106+ returning = returning
107+ )
108+ elif self .query .conflict_action .value == 'NOTHING' :
109+ return (
110+ 'NOTHING RETURNING {returning}'
111+ ).format (returning = returning )
112+
113+ raise SuspiciousOperation ((
114+ '%s is not a valid conflict action, specify '
115+ 'ConflictAction.UPDATE or ConflictAction.NOTHING.'
116+ ) % str (self .query .conflict_action ))
117+
102118 def _build_conflict_target (self ):
103119 """Builds the `conflict_target` for the ON CONFLICT
104120 clause."""
0 commit comments