Skip to content

Commit ba9a036

Browse files
committed
Add support for conflict_target
1 parent 38fcc2b commit ba9a036

File tree

6 files changed

+192
-71
lines changed

6 files changed

+192
-71
lines changed

docs/features.md

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,29 @@ In order to combat this, PostgreSQL added native upserts. Also known as [`ON CON
4646

4747

4848
## upsert
49-
The `upsert` method attempts to insert a row with the specified data or updates (and overwrites) the duplicate row, and then returns the primary key of the row that was created/updated:
49+
The `upsert` method attempts to insert a row with the specified data or updates (and overwrites) the duplicate row, and then returns the primary key of the row that was created/updated.
50+
51+
Upserts work by catching conflicts. PostgreSQL requires to know which conflicts to react to. You have to specify the name of the column which's constraint you want to react to. This is specified in the `conflict_target` field. If the constraint you're trying to react to consists of multiple columns, specify multiple columns.
5052

5153
from django.db import models
5254
from psqlextra.models import PostgresModel
5355

5456
class MyModel(PostgresModel):
5557
myfield = models.CharField(max_length=255, unique=True)
5658

57-
id1 = MyModel.objects.upsert(myfield='beer')
58-
id2 = MyModel.objects.upsert(myfield='beer')
59+
id1 = MyModel.objects.upsert(
60+
conflict_target=['myfield'],
61+
fields=dict(
62+
myfield='beer'
63+
)
64+
)
65+
66+
id2 = MyModel.objects.upsert(
67+
conflict_target=['myfield'],
68+
fields=dict(
69+
myfield='beer'
70+
)
71+
)
5972

6073
assert id1 == id2
6174

@@ -73,4 +86,18 @@ Note that a single call to `upsert` results in a single `INSERT INTO ... ON CONF
7386
obj1 = MyModel.objects.upsert_and_get(myfield='beer')
7487
obj2 = MyModel.objects.upsert_and_get(myfield='beer')
7588

89+
obj1 = MyModel.objects.upsert_and_get(
90+
conflict_target=['myfield'],
91+
fields=dict(
92+
myfield='beer'
93+
)
94+
)
95+
96+
obj2 = MyModel.objects.upsert_and_get(
97+
conflict_target=['myfield'],
98+
fields=dict(
99+
myfield='beer'
100+
)
101+
)
102+
76103
assert obj1.id == obj2.id

psqlextra/compiler.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
from django.db.models.sql.compiler import SQLInsertCompiler
2-
3-
from .fields import HStoreField
41
from django.core.exceptions import SuspiciousOperation
2+
from django.db.models.sql.compiler import SQLInsertCompiler
53

64

75
class PostgresSQLUpsertCompiler(SQLInsertCompiler):
@@ -31,10 +29,13 @@ def execute_sql(self, return_id=False):
3129
# return the primary key, which is stored in
3230
# the first column that is returned
3331
if return_id:
34-
return rows[0][0]
32+
return dict(id=rows[0][0])
3533

36-
# return the entire row instead
37-
return rows[0]
34+
# create a mapping between column names and column value
35+
return {
36+
column.name: rows[0][column_index]
37+
for column_index, column in enumerate(cursor.description)
38+
}
3839

3940
def _rewrite_insert(self, sql, returning='id'):
4041
"""Rewrites a formed SQL INSERT query to include
@@ -54,63 +55,58 @@ def _rewrite_insert(self, sql, returning='id'):
5455
"""
5556
qn = self.connection.ops.quote_name
5657

57-
# ON CONFLICT requires a list of columns to operate on, form
58-
# a list of columns to pass in
59-
unique_columns = ', '.join(self._get_unique_columns())
60-
if len(unique_columns) == 0:
61-
raise SuspiciousOperation((
62-
'You\'re trying to do a upsert on a table that '
63-
'doesn\'t have any unique columns.'
64-
))
65-
6658
# construct a list of columns to update when there's a conflict
6759
update_columns = ', '.join([
6860
'{0} = EXCLUDED.{0}'.format(qn(field.column))
6961
for field in self.query.update_fields
7062
])
7163

64+
# build the conflict target, the columns to watch
65+
# for conflict basically
66+
conflict_target = self._build_conflict_target()
67+
7268
# form the new sql query that does the insert
7369
new_sql = (
74-
'{insert} ON CONFLICT ({unique_columns}) '
70+
'{insert} ON CONFLICT ({conflict_target}) '
7571
'DO UPDATE SET {update_columns} RETURNING {returning}'
7672
).format(
7773
insert=sql,
78-
unique_columns=unique_columns,
74+
conflict_target=conflict_target,
7975
update_columns=update_columns,
8076
returning=returning
8177
)
8278

8379
return new_sql
8480

85-
def _get_unique_columns(self):
86-
"""Gets a list of columns that are marked as 'UNIQUE'.
87-
88-
This is used in the ON CONFLICT clause. This also
89-
works for :see:HStoreField."""
81+
def _build_conflict_target(self):
82+
"""Builds the `conflict_target` for the ON CONFLICT
83+
clause."""
9084

9185
qn = self.connection.ops.quote_name
92-
unique_columns = []
86+
conflict_target = []
9387

94-
for field in self.query.fields:
95-
if field.unique is True:
96-
unique_columns.append(qn(field.column))
88+
if not isinstance(self.query.conflict_target, list):
89+
raise SuspiciousOperation((
90+
'%s is not a valid conflict target, specify '
91+
'a list of column names, or tuples with column '
92+
'names and hstore key.'
93+
) % str(self.query.conflict_target))
94+
95+
for field in self.query.conflict_target:
96+
if isinstance(field, str):
97+
conflict_target.append(qn(field))
9798
continue
9899

99-
# we must also go into possible tuples since those
100-
# are used to indicate "unique together"
101-
if isinstance(field, HStoreField):
102-
uniqueness = getattr(field, 'uniqueness', None)
103-
if not uniqueness:
104-
continue
105-
for key in uniqueness:
106-
if isinstance(key, tuple):
107-
for sub_key in key:
108-
unique_columns.append(
109-
'(%s->\'%s\')' % (qn(field.column), sub_key))
110-
else:
111-
unique_columns.append(
112-
'(%s->\'%s\')' % (qn(field.column), key))
113-
100+
if isinstance(field, tuple):
101+
field, key = field
102+
conflict_target.append(
103+
'(%s -> \'%s\')' % (qn(field), key))
114104
continue
115105

116-
return unique_columns
106+
raise SuspiciousOperation((
107+
'%s is not a valid conflict target, specify '
108+
'a list of column names, or tuples with column '
109+
'names and hstore key.'
110+
) % str(field))
111+
112+
return ','.join(conflict_target)

psqlextra/manager.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import copy
1+
from typing import List, Dict
22

3-
from django.db import models
3+
import django
44
from django.conf import settings
5-
from django.db.models.sql import InsertQuery
65
from django.core.exceptions import ImproperlyConfigured
7-
import django
6+
from django.db import models
87

9-
from .query import PostgresUpsertQuery
108
from .compiler import PostgresSQLUpsertCompiler
9+
from .query import PostgresUpsertQuery
1110

1211

1312
class PostgresManager(models.Manager):
@@ -28,43 +27,51 @@ def __init__(self, *args, **kwargs):
2827
'the \'psqlextra.backend\'. Set DATABASES.ENGINE.'
2928
) % db_backend)
3029

31-
def upsert(self, **kwargs) -> int:
30+
def upsert(self, conflict_target: List, fields: Dict) -> int:
3231
"""Creates a new record or updates the existing one
3332
with the specified data.
3433
3534
Arguments:
36-
kwargs:
35+
conflict_target:
36+
Fields to pass into the ON CONFLICT clause.
37+
38+
fields:
3739
Fields to insert/update.
3840
3941
Returns:
4042
The primary key of the row that was created/updated.
4143
"""
4244

43-
compiler = self._build_upsert_compiler(kwargs)
44-
return compiler.execute_sql(return_id=True)
45+
compiler = self._build_upsert_compiler(conflict_target, fields)
46+
return compiler.execute_sql(return_id=True)['id']
4547

46-
def upsert_and_get(self, **kwargs):
48+
def upsert_and_get(self, conflict_target: List, fields: Dict):
4749
"""Creates a new record or updates the existing one
4850
with the specified data and then gets the row.
4951
5052
Arguments:
51-
kwargs:
53+
conflict_target:
54+
Fields to pass into the ON CONFLICT clause.
55+
56+
fields:
5257
Fields to insert/update.
5358
5459
Returns:
5560
The model instance representing the row
5661
that was created/updated.
5762
"""
5863

59-
compiler = self._build_upsert_compiler(kwargs)
60-
row = compiler.execute_sql(return_id=False)
61-
field_names = [f.name for f in self.model._meta.concrete_fields]
62-
return self.model.from_db(self.db, field_names, row)
64+
compiler = self._build_upsert_compiler(conflict_target, fields)
65+
column_data = compiler.execute_sql(return_id=False)
66+
return self.model(**column_data)
6367

64-
def _build_upsert_compiler(self, kwargs):
68+
def _build_upsert_compiler(self, conflict_target: List, kwargs):
6569
"""Builds the SQL compiler for a insert/update query.
6670
6771
Arguments:
72+
conflict_target:
73+
Fields to pass into the ON CONFLICT clause.
74+
6875
kwargs:
6976
Field values.
7077
@@ -87,6 +94,7 @@ def _build_upsert_compiler(self, kwargs):
8794

8895
# build a normal insert query
8996
query = PostgresUpsertQuery(self.model)
97+
query.conflict_target = conflict_target
9098
query.values([obj], insert_fields, update_fields)
9199

92100
# use the upsert query compiler to transform the insert

psqlextra/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
from django.db.models.sql import InsertQuery
24

35

@@ -10,8 +12,9 @@ def __init__(self, *args, **kwargs):
1012
super(PostgresUpsertQuery, self).__init__(*args, **kwargs)
1113

1214
self.update_fields = []
15+
self.conflict_target = []
1316

14-
def values(self, objs, insert_fields, update_fields):
17+
def values(self, objs: List, insert_fields: List, update_fields: List):
1518
"""Sets the values to be used in this query.
1619
1720
Insert fields are fields that are definitely

tests/benchmarks/test_upsert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ def _native_upsert(model, random_value):
2323
"""Performs a concurrency safe upsert
2424
using the native PostgreSQL upsert."""
2525

26-
return model.objects.upsert_and_get(field=random_value)
26+
return model.objects.upsert_and_get(
27+
conflict_target=['field'],
28+
fields=dict(field=random_value)
29+
)
2730

2831

2932
@pytest.mark.django_db()

0 commit comments

Comments
 (0)