@@ -70,50 +70,69 @@ def _rewrite_insert(self, sql, return_id=False):
7070 to include the ON CONFLICT clause.
7171 """
7272
73+ returning = 'id' if return_id else '*'
74+
75+ if self .query .conflict_action .value == 'UPDATE' :
76+ return self ._rewrite_insert_update (sql , returning )
77+ elif self .query .conflict_action .value == 'NOTHING' :
78+ return self ._rewrite_insert_nothing (sql , returning )
79+
80+ raise SuspiciousOperation ((
81+ '%s is not a valid conflict action, specify '
82+ 'ConflictAction.UPDATE or ConflictAction.NOTHING.'
83+ ) % str (self .query .conflict_action ))
84+
85+ def _rewrite_insert_update (self , sql , returning ):
86+ qn = self .connection .ops .quote_name
87+
88+ update_columns = ', ' .join ([
89+ '{0} = EXCLUDED.{0}' .format (qn (field .column ))
90+ for field in self .query .update_fields
91+ ])
92+
7393 # build the conflict target, the columns to watch
74- # for conflict basically
94+ # for conflicts
7595 conflict_target = self ._build_conflict_target ()
7696
77- # form the new sql query that does the insert
78- new_sql = (
79- '{insert} ON CONFLICT ({conflict_target}) DO {conflict_action }'
97+ return (
98+ '{insert} ON CONFLICT ({conflict_target}) DO UPDATE'
99+ ' SET {update_columns} RETURNING {returning }'
80100 ).format (
81101 insert = sql ,
82102 conflict_target = conflict_target ,
83- conflict_action = self ._build_conflict_action (return_id )
103+ update_columns = update_columns ,
104+ returning = returning
84105 )
85106
86- return new_sql
87-
88- def _build_conflict_action (self , return_id = False ):
89- """Builds the `conflict_action` for the DO clause."""
90-
91- returning = 'id' if return_id else '*'
92-
107+ def _rewrite_insert_nothing (self , sql , returning ):
93108 qn = self .connection .ops .quote_name
94109
95- # construct a list of columns to update when there's a conflict
96- if self .query .conflict_action .value == 'UPDATE' :
97- update_columns = ', ' .join ([
98- '{0} = EXCLUDED.{0}' .format (qn (field .column ))
99- for field in self .query .update_fields
100- ])
101-
102- return (
103- 'UPDATE SET {update_columns} RETURNING {returning}'
104- ).format (
105- update_columns = update_columns ,
106- returning = returning
107- )
108- elif self .query .conflict_action .value == 'NOTHING' :
109- return (
110- 'NOTHING RETURNING {returning}'
111- ).format (returning = returning )
110+ # build the conflict target, the columns to watch
111+ # for conflicts
112+ conflict_target = self ._build_conflict_target ()
112113
113- raise SuspiciousOperation ((
114- '%s is not a valid conflict action, specify '
115- 'ConflictAction.UPDATE or ConflictAction.NOTHING.'
116- ) % str (self .query .conflict_action ))
114+ select_columns = ', ' .join ([
115+ '{0} = \' {1}\' ' .format (qn (column ), getattr (self .query .objs [0 ], column ))
116+ for column in self .query .conflict_target
117+ ])
118+
119+ # this looks complicated, and it is, but it is for a reason... a normal
120+ # ON CONFLICT DO NOTHING doesn't return anything if the row already exists
121+ # so we do DO UPDATE instead that never executes to lock the row, and then
122+ # select from the table in case we're dealing with an existing row..
123+ return (
124+ 'WITH insdata AS ('
125+ '{insert} ON CONFLICT ({conflict_target}) DO UPDATE'
126+ ' SET id = NULL WHERE FALSE RETURNING {returning})'
127+ ' SELECT * FROM insdata UNION ALL'
128+ ' SELECT {returning} FROM {table} WHERE {select_columns} LIMIT 1;'
129+ ).format (
130+ insert = sql ,
131+ conflict_target = conflict_target ,
132+ returning = returning ,
133+ table = self .query .objs [0 ]._meta .db_table ,
134+ select_columns = select_columns
135+ )
117136
118137 def _build_conflict_target (self ):
119138 """Builds the `conflict_target` for the ON CONFLICT
0 commit comments