Skip to content

Commit ad336d1

Browse files
committed
Basic, actualy working bulk upserts
1 parent 6a94907 commit ad336d1

File tree

7 files changed

+45
-32
lines changed

7 files changed

+45
-32
lines changed

psqlextra/manager.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import django
44
from django.conf import settings
5-
from django.core.exceptions import ImproperlyConfigured
65
from django.db import models, transaction
76
from django.db.models.sql import UpdateQuery
87
from django.db.models.sql.constants import CURSOR
98
from django.db.models.fields import NOT_PROVIDED
9+
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
1010

1111
from . import signals
1212
from .compiler import (PostgresReturningUpdateCompiler,
@@ -143,8 +143,13 @@ def bulk_insert(self, rows):
143143
Returns:
144144
"""
145145

146-
for row in rows:
147-
self.insert(**row)
146+
if self.conflict_target or self.conflict_action:
147+
compiler = self._build_insert_compiler(rows)
148+
compiler.execute_sql(return_id=True)
149+
return
150+
151+
# no special action required, use the standard Django bulk_create(..)
152+
super().bulk_create([self.model(**fields) for fields in rows])
148153

149154
def insert(self, **fields):
150155
"""Creates a new record in the database.
@@ -161,7 +166,7 @@ def insert(self, **fields):
161166
"""
162167

163168
if self.conflict_target or self.conflict_action:
164-
compiler = self._build_insert_compiler(**fields)
169+
compiler = self._build_insert_compiler([fields])
165170
rows = compiler.execute_sql(return_id=True)
166171
if 'id' in rows[0]:
167172
return rows[0]['id']
@@ -189,7 +194,7 @@ def insert_and_get(self, **fields):
189194
# no special action required, use the standard Django create(..)
190195
return super().create(**fields)
191196

192-
compiler = self._build_insert_compiler(**fields)
197+
compiler = self._build_insert_compiler([fields])
193198
rows = compiler.execute_sql(return_id=False)
194199

195200
columns = rows[0]
@@ -248,27 +253,46 @@ def upsert_and_get(self, conflict_target: List, fields: Dict):
248253
self.on_conflict(conflict_target, ConflictAction.UPDATE)
249254
return self.insert_and_get(**fields)
250255

251-
def _build_insert_compiler(self, **fields):
256+
def _build_insert_compiler(self, rows: List[Dict]):
252257
"""Builds the SQL compiler for a insert query.
253258
259+
Arguments:
260+
rows:
261+
A list of dictionaries, where each entry
262+
describes a record to insert.
263+
254264
Returns:
255265
The SQL compiler for the insert.
256266
"""
257267

258-
# create an empty object to store the result in
259-
obj = self.model(**fields)
268+
# create model objects, we also have to detect cases
269+
# such as:
270+
# [dict(first_name='swen'), dict(fist_name='swen', last_name='kooij')]
271+
# we need to be certain that each row specifies the exact same
272+
# amount of fields/columns
273+
objs = []
274+
field_count = len(rows[0])
275+
for index, row in enumerate(rows):
276+
if field_count != len(row):
277+
raise SuspiciousOperation((
278+
'In bulk upserts, you cannot have rows with different field '
279+
'configurations. Row {0} has a different field config than '
280+
'the first row.'
281+
).format(index))
282+
283+
objs.append(self.model(**row))
260284

261285
# indicate this query is going to perform write
262286
self._for_write = True
263287

264288
# get the fields to be used during update/insert
265-
insert_fields, update_fields = self._get_upsert_fields(fields)
289+
insert_fields, update_fields = self._get_upsert_fields(rows[0])
266290

267291
# build a normal insert query
268292
query = PostgresInsertQuery(self.model)
269293
query.conflict_action = self.conflict_action
270294
query.conflict_target = self.conflict_target
271-
query.values([obj], insert_fields, update_fields)
295+
query.values(objs, insert_fields, update_fields)
272296

273297
# use the postgresql insert query compiler to transform the insert
274298
# into an special postgresql insert
@@ -401,11 +425,11 @@ def __del__(self):
401425
if self._signals_connected is False:
402426
return
403427

404-
django.db.models.signals.post_save.disconnect(
405-
self._on_model_save, sender=self.model)
428+
# django.db.models.signals.post_save.disconnect(
429+
# self._on_model_save, sender=self.model)
406430

407-
django.db.models.signals.pre_delete.disconnect(
408-
self._on_model_delete, sender=self.model)
431+
# django.db.models.signals.pre_delete.disconnect(
432+
# self._on_model_delete, sender=self.model)
409433

410434
def get_queryset(self):
411435
"""Gets the query set to be used on this manager."""

tests/benchmarks/test_insert_nothing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import uuid
22

3-
from django.test import TestCase
3+
import pytest
4+
45
from django.db import models, transaction
56
from django.db.utils import IntegrityError
6-
import pytest
77

88
from psqlextra.query import ConflictAction
99

tests/benchmarks/test_upsert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import uuid
22

3-
from django.test import TestCase
3+
import pytest
4+
45
from django.db import models, transaction
56
from django.db.utils import IntegrityError
6-
import pytest
77

88
from ..fake_model import get_fake_model
99

tests/benchmarks/test_upsert_bulk.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
import pytest
44

5-
from django.test import TestCase
6-
from django.db import models, transaction
7-
from django.db.utils import IntegrityError
5+
from django.db import models
86

97
from psqlextra.query import ConflictAction
108

tests/test_hstore_field.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
from django.db import models
2-
31
from psqlextra import HStoreField
4-
from psqlextra.expressions import HStoreRef
5-
6-
from .fake_model import get_fake_model
72

83

94
def test_deconstruct():

tests/test_query_values.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import pytest
22

3-
from collections import namedtuple
4-
5-
from django.db.models import CharField, ForeignKey
3+
from django.db.models import ForeignKey
64

75
from psqlextra import HStoreField
8-
from psqlextra.expressions import HStoreRef
96

107
from .fake_model import get_fake_model
118

@@ -68,7 +65,7 @@ def test_values_hstore_key_through_fk():
6865
})
6966

7067
fobj = fmodel.objects.create(name={'en': 'swen', 'ar': 'arabic swen'})
71-
obj = model.objects.create(fk=fobj)
68+
model.objects.create(fk=fobj)
7269

7370
result = list(model.objects.values('fk__name__ar'))[0]
7471
assert result['fk__name__ar'] == fobj.name['ar']

tests/test_signals.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def mock_signal_handler(signal):
2020
'flag': models.BooleanField(default=False)
2121
})
2222

23-
2423
signal_handler = Mock()
2524
signal.connect(signal_handler, sender=model, weak=False)
2625

0 commit comments

Comments
 (0)