22
33import django
44from django .conf import settings
5- from django .core .exceptions import ImproperlyConfigured
65from django .db import models , transaction
76from django .db .models .sql import UpdateQuery
87from django .db .models .sql .constants import CURSOR
98from django .db .models .fields import NOT_PROVIDED
9+ from django .core .exceptions import ImproperlyConfigured , SuspiciousOperation
1010
1111from . import signals
1212from .compiler import (PostgresReturningUpdateCompiler ,
@@ -143,8 +143,13 @@ def bulk_insert(self, rows):
143143 Returns:
144144 """
145145
146- for row in rows :
147- self .insert (** row )
146+ if self .conflict_target or self .conflict_action :
147+ compiler = self ._build_insert_compiler (rows )
148+ compiler .execute_sql (return_id = True )
149+ return
150+
151+ # no special action required, use the standard Django bulk_create(..)
152+ super ().bulk_create ([self .model (** fields ) for fields in rows ])
148153
149154 def insert (self , ** fields ):
150155 """Creates a new record in the database.
@@ -161,7 +166,7 @@ def insert(self, **fields):
161166 """
162167
163168 if self .conflict_target or self .conflict_action :
164- compiler = self ._build_insert_compiler (** fields )
169+ compiler = self ._build_insert_compiler ([ fields ] )
165170 rows = compiler .execute_sql (return_id = True )
166171 if 'id' in rows [0 ]:
167172 return rows [0 ]['id' ]
@@ -189,7 +194,7 @@ def insert_and_get(self, **fields):
189194 # no special action required, use the standard Django create(..)
190195 return super ().create (** fields )
191196
192- compiler = self ._build_insert_compiler (** fields )
197+ compiler = self ._build_insert_compiler ([ fields ] )
193198 rows = compiler .execute_sql (return_id = False )
194199
195200 columns = rows [0 ]
@@ -248,27 +253,46 @@ def upsert_and_get(self, conflict_target: List, fields: Dict):
248253 self .on_conflict (conflict_target , ConflictAction .UPDATE )
249254 return self .insert_and_get (** fields )
250255
251- def _build_insert_compiler (self , ** fields ):
256+ def _build_insert_compiler (self , rows : List [ Dict ] ):
252257 """Builds the SQL compiler for a insert query.
253258
259+ Arguments:
260+ rows:
261+ A list of dictionaries, where each entry
262+ describes a record to insert.
263+
254264 Returns:
255265 The SQL compiler for the insert.
256266 """
257267
258- # create an empty object to store the result in
259- obj = self .model (** fields )
268+ # create model objects, we also have to detect cases
269+ # such as:
270+ # [dict(first_name='swen'), dict(fist_name='swen', last_name='kooij')]
271+ # we need to be certain that each row specifies the exact same
272+ # amount of fields/columns
273+ objs = []
274+ field_count = len (rows [0 ])
275+ for index , row in enumerate (rows ):
276+ if field_count != len (row ):
277+ raise SuspiciousOperation ((
278+ 'In bulk upserts, you cannot have rows with different field '
279+ 'configurations. Row {0} has a different field config than '
280+ 'the first row.'
281+ ).format (index ))
282+
283+ objs .append (self .model (** row ))
260284
261285 # indicate this query is going to perform write
262286 self ._for_write = True
263287
264288 # get the fields to be used during update/insert
265- insert_fields , update_fields = self ._get_upsert_fields (fields )
289+ insert_fields , update_fields = self ._get_upsert_fields (rows [ 0 ] )
266290
267291 # build a normal insert query
268292 query = PostgresInsertQuery (self .model )
269293 query .conflict_action = self .conflict_action
270294 query .conflict_target = self .conflict_target
271- query .values ([ obj ] , insert_fields , update_fields )
295+ query .values (objs , insert_fields , update_fields )
272296
273297 # use the postgresql insert query compiler to transform the insert
274298 # into an special postgresql insert
@@ -401,11 +425,11 @@ def __del__(self):
401425 if self ._signals_connected is False :
402426 return
403427
404- django .db .models .signals .post_save .disconnect (
405- self ._on_model_save , sender = self .model )
428+ # django.db.models.signals.post_save.disconnect(
429+ # self._on_model_save, sender=self.model)
406430
407- django .db .models .signals .pre_delete .disconnect (
408- self ._on_model_delete , sender = self .model )
431+ # django.db.models.signals.pre_delete.disconnect(
432+ # self._on_model_delete, sender=self.model)
409433
410434 def get_queryset (self ):
411435 """Gets the query set to be used on this manager."""
0 commit comments