Skip to content

Commit 0c6e0ab

Browse files
committed
Add extra tests for specific conflict actions
1 parent 1018221 commit 0c6e0ab

File tree

4 files changed

+94
-6
lines changed

4 files changed

+94
-6
lines changed

psqlextra/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def _rewrite_insert(self, sql, params, return_id=False):
9191
) % str(self.query.conflict_action))
9292

9393
def _rewrite_insert_update(self, sql, params, returning):
94+
"""Rewrites a formed SQL INSERT query to include
95+
the ON CONFLICT DO UPDATE clause."""
96+
9497
update_columns = ', '.join([
9598
'{0} = EXCLUDED.{0}'.format(self.qn(field.column))
9699
for field in self.query.update_fields
@@ -114,6 +117,9 @@ def _rewrite_insert_update(self, sql, params, returning):
114117
)
115118

116119
def _rewrite_insert_nothing(self, sql, params, returning):
120+
"""Rewrites a formed SQL INSERT query to include
121+
the ON CONFLICT DO NOTHING clause."""
122+
117123
# build the conflict target, the columns to watch
118124
# for conflicts
119125
conflict_target = self._build_conflict_target()

tests/test_on_conflict.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
15-
def test_simple(conflict_action):
15+
def test_on_conflict(conflict_action):
1616
"""Tests whether simple inserts work correctly."""
1717

1818
model = get_fake_model({
@@ -40,7 +40,7 @@ def test_simple(conflict_action):
4040

4141

4242
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
43-
def test_auto_fields(conflict_action):
43+
def test_on_conflict_auto_fields(conflict_action):
4444
"""Asserts that fields that automatically add something
4545
to the model automatically still work properly when upserting."""
4646

@@ -81,7 +81,7 @@ def test_auto_fields(conflict_action):
8181

8282

8383
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
84-
def test_foreign_key(conflict_action):
84+
def test_on_conflict_foreign_key(conflict_action):
8585
"""Asserts that models with foreign key relationships
8686
can safely be inserted."""
8787

@@ -125,7 +125,7 @@ def test_foreign_key(conflict_action):
125125

126126

127127
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
128-
def test_get_partial(conflict_action):
128+
def test_on_conflict_partial_get(conflict_action):
129129
"""Asserts that when doing a insert_and_get with
130130
only part of the columns on the model, all fields
131131
are returned properly."""
@@ -162,7 +162,7 @@ def test_get_partial(conflict_action):
162162

163163

164164
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
165-
def test_invalid_conflict_target(conflict_action):
165+
def test_on_conflict_invalid_target(conflict_action):
166166
"""Tests whether specifying a invalid value
167167
for `conflict_target` raises an error."""
168168

@@ -186,7 +186,7 @@ def test_invalid_conflict_target(conflict_action):
186186

187187

188188
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
189-
def test_outdated_model(conflict_action):
189+
def test_on_conflict_outdated_model(conflict_action):
190190
"""Tests whether insert properly handles
191191
fields that are in the database but not on the model.
192192

tests/test_on_conflict_nothing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from django.db import models
2+
3+
from psqlextra import HStoreField
4+
from psqlextra.query import ConflictAction
5+
6+
from .fake_model import get_fake_model
7+
8+
9+
def test_on_conflict_nothing():
10+
"""Tests whether simple insert NOTHING works correctly."""
11+
12+
model = get_fake_model({
13+
'title': HStoreField(uniqueness=['key1']),
14+
'cookies': models.CharField(max_length=255, null=True)
15+
})
16+
17+
obj1 = (
18+
model.objects
19+
.on_conflict([('title', 'key1')], ConflictAction.NOTHING)
20+
.insert_and_get(title={'key1': 'beer'}, cookies='cheers')
21+
)
22+
23+
obj1.refresh_from_db()
24+
assert obj1.title['key1'] == 'beer'
25+
assert obj1.cookies == 'cheers'
26+
27+
obj2 = (
28+
model.objects
29+
.on_conflict([('title', 'key1')], ConflictAction.NOTHING)
30+
.insert_and_get(title={'key1': 'beer'}, cookies='choco')
31+
)
32+
33+
obj1.refresh_from_db()
34+
obj2.refresh_from_db()
35+
36+
# assert that the 'cookies' field didn't change
37+
assert obj1.id == obj2.id
38+
assert obj1.title['key1'] == 'beer'
39+
assert obj1.cookies == 'cheers'
40+
assert obj2.title['key1'] == 'beer'
41+
assert obj2.cookies == 'cheers'

tests/test_on_conflict_update.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from django.db import models
2+
3+
from psqlextra import HStoreField
4+
from psqlextra.query import ConflictAction
5+
6+
from .fake_model import get_fake_model
7+
8+
9+
def test_on_conflict_update():
10+
"""Tests whether simple upserts works correctly."""
11+
12+
model = get_fake_model({
13+
'title': HStoreField(uniqueness=['key1']),
14+
'cookies': models.CharField(max_length=255, null=True)
15+
})
16+
17+
obj1 = (
18+
model.objects
19+
.on_conflict([('title', 'key1')], ConflictAction.UPDATE)
20+
.insert_and_get(title={'key1': 'beer'}, cookies='cheers')
21+
)
22+
23+
obj1.refresh_from_db()
24+
assert obj1.title['key1'] == 'beer'
25+
assert obj1.cookies == 'cheers'
26+
27+
obj2 = (
28+
model.objects
29+
.on_conflict([('title', 'key1')], ConflictAction.UPDATE)
30+
.insert_and_get(title={'key1': 'beer'}, cookies='choco')
31+
)
32+
33+
obj1.refresh_from_db()
34+
obj2.refresh_from_db()
35+
36+
# assert both objects are the same
37+
assert obj1.id == obj2.id
38+
assert obj1.title['key1'] == 'beer'
39+
assert obj1.cookies == 'choco'
40+
assert obj2.title['key1'] == 'beer'
41+
assert obj2.cookies == 'choco'

0 commit comments

Comments
 (0)