Skip to content

Commit 6a94907

Browse files
committed
Add tests for bulk upserts
1 parent f3b1bb6 commit 6a94907

File tree

5 files changed

+188
-82
lines changed

5 files changed

+188
-82
lines changed

psqlextra/manager.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,23 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action):
129129
self.conflict_action = action
130130
return self
131131

132+
def bulk_insert(self, rows):
133+
"""Creates multiple new records in the database.
134+
135+
This allows specifying custom conflict behavior using .on_conflict().
136+
If no special behavior was specified, this uses the normal Django create(..)
137+
138+
Arguments:
139+
rows:
140+
An array of dictionaries, where each dictionary
141+
describes the fields to insert.
142+
143+
Returns:
144+
"""
145+
146+
for row in rows:
147+
self.insert(**row)
148+
132149
def insert(self, **fields):
133150
"""Creates a new record in the database.
134151

tests/benchmarks/test_insert_nothing.py

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,49 +10,45 @@
1010
from ..fake_model import get_fake_model
1111

1212

13-
@pytest.mark.django_db()
14-
class TestInsertNothing(TestCase):
15-
16-
@pytest.mark.benchmark()
17-
@staticmethod
18-
def test_insert_nothing_traditional(benchmark):
19-
model = get_fake_model({
20-
'field': models.CharField(max_length=255, unique=True)
21-
})
22-
23-
random_value = str(uuid.uuid4())[:8]
24-
model.objects.create(field=random_value)
25-
26-
def _traditional_insert(model, random_value):
27-
"""Performs a concurrency safe insert the
28-
traditional way."""
29-
30-
try:
31-
with transaction.atomic():
32-
return model.objects.create(field=random_value)
33-
except IntegrityError:
34-
return model.objects.filter(field=random_value).first()
35-
36-
benchmark(_traditional_insert, model, random_value)
37-
38-
@pytest.mark.benchmark()
39-
@staticmethod
40-
def test_insert_nothing_native(benchmark):
41-
model = get_fake_model({
42-
'field': models.CharField(max_length=255, unique=True)
43-
})
44-
45-
random_value = str(uuid.uuid4())[:8]
46-
model.objects.create(field=random_value)
47-
48-
def _native_insert(model, random_value):
49-
"""Performs a concurrency safeinsert
50-
using the native PostgreSQL conflict resolution."""
51-
52-
return (
53-
model.objects
54-
.on_conflict(['field'], ConflictAction.NOTHING)
55-
.insert_and_get(field=random_value)
56-
)
57-
58-
benchmark(_native_insert, model, random_value)
13+
@pytest.mark.benchmark()
14+
def test_insert_nothing_traditional(benchmark):
15+
model = get_fake_model({
16+
'field': models.CharField(max_length=255, unique=True)
17+
})
18+
19+
random_value = str(uuid.uuid4())[:8]
20+
model.objects.create(field=random_value)
21+
22+
def _traditional_insert(model, random_value):
23+
"""Performs a concurrency safe insert the
24+
traditional way."""
25+
26+
try:
27+
with transaction.atomic():
28+
return model.objects.create(field=random_value)
29+
except IntegrityError:
30+
return model.objects.filter(field=random_value).first()
31+
32+
benchmark(_traditional_insert, model, random_value)
33+
34+
35+
@pytest.mark.benchmark()
36+
def test_insert_nothing_native(benchmark):
37+
model = get_fake_model({
38+
'field': models.CharField(max_length=255, unique=True)
39+
})
40+
41+
random_value = str(uuid.uuid4())[:8]
42+
model.objects.create(field=random_value)
43+
44+
def _native_insert(model, random_value):
45+
"""Performs a concurrency safe insert
46+
using the native PostgreSQL conflict resolution."""
47+
48+
return (
49+
model.objects
50+
.on_conflict(['field'], ConflictAction.NOTHING)
51+
.insert_and_get(field=random_value)
52+
)
53+
54+
benchmark(_native_insert, model, random_value)

tests/benchmarks/test_upsert.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,51 +8,46 @@
88
from ..fake_model import get_fake_model
99

1010

11-
@pytest.mark.django_db()
1211
@pytest.mark.benchmark()
13-
class TestUpsert(TestCase):
12+
def test_upsert_traditional(benchmark):
13+
model = get_fake_model({
14+
'field': models.CharField(max_length=255, unique=True)
15+
})
1416

15-
@pytest.mark.benchmark()
16-
@staticmethod
17-
def test_upsert_traditional(benchmark):
18-
model = get_fake_model({
19-
'field': models.CharField(max_length=255, unique=True)
20-
})
17+
random_value = str(uuid.uuid4())[:8]
18+
model.objects.create(field=random_value)
2119

22-
random_value = str(uuid.uuid4())[:8]
23-
model.objects.create(field=random_value)
20+
def _traditional_upsert(model, random_value):
21+
"""Performs a concurrency safe upsert
22+
the traditional way."""
2423

25-
def _traditional_upsert(model, random_value):
26-
"""Performs a concurrency safe upsert
27-
the traditional way."""
24+
try:
2825

29-
try:
26+
with transaction.atomic():
27+
return model.objects.create(field=random_value)
28+
except IntegrityError:
29+
model.objects.update(field=random_value)
30+
return model.objects.get(field=random_value)
3031

31-
with transaction.atomic():
32-
return model.objects.create(field=random_value)
33-
except IntegrityError:
34-
model.objects.update(field=random_value)
35-
return model.objects.get(field=random_value)
32+
benchmark(_traditional_upsert, model, random_value)
3633

37-
benchmark(_traditional_upsert, model, random_value)
3834

39-
@pytest.mark.benchmark()
40-
@staticmethod
41-
def test_upsert_native(benchmark):
42-
model = get_fake_model({
43-
'field': models.CharField(max_length=255, unique=True)
44-
})
35+
@pytest.mark.benchmark()
36+
def test_upsert_native(benchmark):
37+
model = get_fake_model({
38+
'field': models.CharField(max_length=255, unique=True)
39+
})
4540

46-
random_value = str(uuid.uuid4())[:8]
47-
model.objects.create(field=random_value)
41+
random_value = str(uuid.uuid4())[:8]
42+
model.objects.create(field=random_value)
4843

49-
def _native_upsert(model, random_value):
50-
"""Performs a concurrency safe upsert
51-
using the native PostgreSQL upsert."""
44+
def _native_upsert(model, random_value):
45+
"""Performs a concurrency safe upsert
46+
using the native PostgreSQL upsert."""
5247

53-
return model.objects.upsert_and_get(
54-
conflict_target=['field'],
55-
fields=dict(field=random_value)
56-
)
48+
return model.objects.upsert_and_get(
49+
conflict_target=['field'],
50+
fields=dict(field=random_value)
51+
)
5752

58-
benchmark(_native_upsert, model, random_value)
53+
benchmark(_native_upsert, model, random_value)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import uuid
2+
3+
import pytest
4+
5+
from django.test import TestCase
6+
from django.db import models, transaction
7+
from django.db.utils import IntegrityError
8+
9+
from psqlextra.query import ConflictAction
10+
11+
from ..fake_model import get_fake_model
12+
13+
ROW_COUNT = 10000
14+
15+
16+
@pytest.mark.benchmark()
17+
def test_upsert_bulk_naive(benchmark):
18+
model = get_fake_model({
19+
'field': models.CharField(max_length=255, unique=True)
20+
})
21+
22+
rows = []
23+
random_values = []
24+
for i in range(0, ROW_COUNT):
25+
random_value = str(uuid.uuid4())
26+
random_values.append(random_value)
27+
rows.append(model(field=random_value))
28+
29+
model.objects.bulk_create(rows)
30+
31+
def _native_upsert(model, random_values):
32+
"""Performs a concurrency safe upsert, one row at a time,
33+
using the native PostgreSQL upsert."""
34+
35+
rows = [
36+
dict(field=random_value)
37+
for random_value in random_values
38+
]
39+
40+
for row in rows:
41+
model.objects.on_conflict(['field'], ConflictAction.UPDATE).insert(**row)
42+
43+
benchmark(_native_upsert, model, random_values)
44+
45+
46+
@pytest.mark.benchmark()
47+
def test_upsert_bulk_native(benchmark):
48+
model = get_fake_model({
49+
'field': models.CharField(max_length=255, unique=True)
50+
})
51+
52+
rows = []
53+
random_values = []
54+
for i in range(0, ROW_COUNT):
55+
random_value = str(uuid.uuid4())
56+
random_values.append(random_value)
57+
rows.append(model(field=random_value))
58+
59+
model.objects.bulk_create(rows)
60+
61+
def _native_upsert(model, random_values):
62+
"""Performs a concurrency safe bulk upsert in a single call
63+
using the native PostgreSQL upsert."""
64+
65+
rows = [
66+
dict(field=random_value)
67+
for random_value in random_values
68+
]
69+
70+
model.objects.on_conflict(['field'], ConflictAction.UPDATE).bulk_insert(rows)
71+
72+
benchmark(_native_upsert, model, random_values)

tests/test_on_conflict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,29 @@ def test_on_conflict_default_value_no_overwrite():
387387

388388
assert obj1.id == obj2.id
389389
assert obj2.title == 'mytitle'
390+
391+
392+
def test_on_conflict_bulk():
393+
"""Tests whether using `on_conflict` with `bulk_insert`
394+
properly works."""
395+
396+
model = get_fake_model({
397+
'title': models.CharField(max_length=255, unique=True)
398+
})
399+
400+
rows = [
401+
dict(title='this is my title'),
402+
dict(title='this is another title'),
403+
dict(title='and another one')
404+
]
405+
406+
(
407+
model.objects
408+
.on_conflict(['title'], ConflictAction.UPDATE)
409+
.bulk_insert(rows)
410+
)
411+
412+
assert model.objects.all().count() == len(rows)
413+
414+
for index, obj in enumerate(list(model.objects.all())):
415+
assert obj.title == rows[index]['title']

0 commit comments

Comments
 (0)