Skip to content

Commit 552fb27

Browse files
committed
Custom CompositePrimaryKey and foreign key support on partitioned models
1 parent a5bd4a6 commit 552fb27

File tree

2 files changed

+325
-7
lines changed

2 files changed

+325
-7
lines changed

psqlextra/models/partitioned.py

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from typing import Iterable
1+
from typing import Iterable, List, Optional, Tuple
22

3+
from django.core.exceptions import ImproperlyConfigured
4+
from django.db import models
35
from django.db.models.base import ModelBase
6+
from django.db.models.fields.composite import CompositePrimaryKey
7+
from django.db.models.options import Options
48

59
from psqlextra.types import PostgresPartitioningMethod
610

@@ -20,19 +24,140 @@ class PostgresPartitionedModelMeta(ModelBase):
2024
default_key: Iterable[str] = []
2125

2226
def __new__(cls, name, bases, attrs, **kwargs):
23-
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
24-
meta_class = attrs.pop("PartitioningMeta", None)
27+
partitioning_meta_class = attrs.pop("PartitioningMeta", None)
28+
29+
partitioning_method = getattr(partitioning_meta_class, "method", None)
30+
partitioning_key = getattr(partitioning_meta_class, "key", None)
31+
special = getattr(partitioning_meta_class, "special", None)
2532

26-
method = getattr(meta_class, "method", None)
27-
key = getattr(meta_class, "key", None)
33+
if special:
34+
cls._create_primary_key(attrs, partitioning_key)
2835

2936
patitioning_meta = PostgresPartitionedModelOptions(
30-
method=method or cls.default_method, key=key or cls.default_key
37+
method=partitioning_method or cls.default_method,
38+
key=partitioning_key or cls.default_key,
3139
)
3240

41+
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
3342
new_class.add_to_class("_partitioning_meta", patitioning_meta)
3443
return new_class
3544

45+
@classmethod
46+
def _create_primary_key(cls, attrs, partitioning_key: Optional[List[str]]):
47+
pk = cls._find_primary_key(attrs)
48+
if pk and isinstance(pk[1], CompositePrimaryKey):
49+
return
50+
51+
if not pk:
52+
attrs["id"] = attrs.get("id") or cls._create_auto_field(attrs)
53+
pk_fields = ["id"]
54+
else:
55+
pk_fields = [pk[0]]
56+
57+
unique_pk_fields = set(pk_fields + (partitioning_key or []))
58+
if len(unique_pk_fields) <= 1:
59+
return
60+
61+
auto_generated_pk = CompositePrimaryKey(*sorted(unique_pk_fields))
62+
attrs["pk"] = auto_generated_pk
63+
64+
@classmethod
65+
def _create_auto_field(cls, attrs):
66+
app_label = attrs.get("app_label")
67+
meta_class = attrs.get("Meta", None)
68+
69+
pk_class = Options(meta_class, app_label)._get_default_pk_class()
70+
return pk_class(verbose_name="ID", primary_key=True, auto_created=True)
71+
72+
@classmethod
73+
def _find_primary_key(cls, attrs) -> Optional[Tuple[str, models.Field]]:
74+
"""Gets the field that has been marked by the user as the primary key
75+
field for this model.
76+
77+
This is quite complex because Django allows a variety of options:
78+
79+
1. No PK at all. In this case, Django generates one named `id`
80+
as an auto-increment integer (AutoField)
81+
82+
2. One field that has `primary_key=True`. Any field can have
83+
this attribute, but Django would error if there were more.
84+
85+
3. One field named `pk`.
86+
87+
4. One field that has `primary_key=True` and a field that
88+
is of type `CompositePrimaryKey` that includes that
89+
field.
90+
91+
Since a table can only have one primary key, our goal here
92+
is to find the field (if any) that is going to become
93+
the primary key of the table.
94+
95+
Our logic is straight forward:
96+
97+
1. If there is a `CompositePrimaryKey`, that field becomes the primary key.
98+
99+
2. If there is a field with `primary_key=True`, that field becomes the primary key.
100+
101+
3. There is no primary key.
102+
"""
103+
104+
fields = {
105+
name: value
106+
for name, value in attrs.items()
107+
if isinstance(value, models.Field)
108+
}
109+
110+
fields_marked_as_pk = {
111+
name: value for name, value in fields.items() if value.primary_key
112+
}
113+
114+
# We cannot let the user define a field named `pk` that is not a CompositePrimaryKey
115+
# already because when we generate a primary key, we want to name it `pk`.
116+
field_named_pk = attrs.get("pk")
117+
if field_named_pk and not field_named_pk.primary_key:
118+
raise ImproperlyConfigured(
119+
"You cannot define a field named `pk` that is not a primary key."
120+
)
121+
122+
if field_named_pk:
123+
if not isinstance(field_named_pk, CompositePrimaryKey):
124+
raise ImproperlyConfigured(
125+
"You cannot define a field named `pk` that is not a composite primary key on a partitioned model. Either make `pk` a CompositePrimaryKey or rename it."
126+
)
127+
128+
return ("pk", field_named_pk)
129+
130+
if not fields_marked_as_pk:
131+
return None
132+
133+
# Make sure the user didn't define N primary keys. Django would also warn
134+
# about this.
135+
#
136+
# One exception is a set up such as:
137+
#
138+
# >>> id = models.AutoField(primary_key=True)
139+
# >>> timestamp = models.DateTimeField()
140+
# >>> pk = models.CompositePrimaryKey("id", "timestamp")
141+
#
142+
# In this case, both `id` and `pk` are marked as primary key. Django
143+
# allows this and just ignores the `primary_key=True` attribute
144+
# on all the other fields except the composite one.
145+
#
146+
# We also handle this as expected and treat the CompositePrimaryKey
147+
# as the primary key.
148+
sorted_fields_marked_as_pk = sorted(
149+
list(fields_marked_as_pk.items()),
150+
key=lambda pair: 0
151+
if isinstance(pair[1], CompositePrimaryKey)
152+
else 1,
153+
)
154+
if len(sorted_fields_marked_as_pk[1:]) > 1:
155+
raise ImproperlyConfigured(
156+
"You cannot mark more than one fields as a primary key."
157+
)
158+
159+
return sorted_fields_marked_as_pk[0]
160+
36161

37162
class PostgresPartitionedModel(
38163
PostgresModel, metaclass=PostgresPartitionedModelMeta

tests/test_partitioned_model.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
import django
2+
import pytest
3+
4+
from django.core.exceptions import ImproperlyConfigured
5+
from django.db import models
6+
17
from psqlextra.models import PostgresPartitionedModel
28
from psqlextra.types import PostgresPartitioningMethod
39

4-
from .fake_model import define_fake_partitioned_model
10+
from .fake_model import define_fake_model, define_fake_partitioned_model
511

612

713
def test_partitioned_model_abstract():
@@ -70,3 +76,190 @@ def test_partitioned_model_key_option_none():
7076
model = define_fake_partitioned_model(partitioning_options=dict(key=None))
7177

7278
assert model._partitioning_meta.key == []
79+
80+
81+
@pytest.mark.skipif(
82+
django.VERSION < (5, 2),
83+
reason="Django < 5.2 doesn't implement composite primary keys",
84+
)
85+
def test_partitioned_model_custom_composite_primary_key_with_auto_field():
86+
model = define_fake_partitioned_model(
87+
fields={
88+
"auto_id": models.AutoField(),
89+
"my_custom_pk": models.CompositePrimaryKey("auto_id", "timestamp"),
90+
"timestamp": models.DateTimeField(),
91+
},
92+
partitioning_options=dict(key=["timestamp"], special=True),
93+
)
94+
95+
assert isinstance(model._meta.pk, models.CompositePrimaryKey)
96+
assert model._meta.pk.name == "my_custom_pk"
97+
assert model._meta.pk.columns == ("auto_id", "timestamp")
98+
99+
100+
@pytest.mark.skipif(
101+
django.VERSION < (5, 2),
102+
reason="Django < 5.2 doesn't implement composite primary keys",
103+
)
104+
def test_partitioned_model_custom_composite_primary_key_with_id_field():
105+
model = define_fake_partitioned_model(
106+
fields={
107+
"id": models.IntegerField(),
108+
"my_custom_pk": models.CompositePrimaryKey("id", "timestamp"),
109+
"timestamp": models.DateTimeField(),
110+
},
111+
partitioning_options=dict(key=["timestamp"], special=True),
112+
)
113+
114+
assert isinstance(model._meta.pk, models.CompositePrimaryKey)
115+
assert model._meta.pk.name == "my_custom_pk"
116+
assert model._meta.pk.columns == ("id", "timestamp")
117+
118+
119+
@pytest.mark.skipif(
120+
django.VERSION < (5, 2),
121+
reason="Django < 5.2 doesn't implement composite primary keys",
122+
)
123+
def test_partitioned_model_custom_composite_primary_key_named_id():
124+
model = define_fake_partitioned_model(
125+
fields={
126+
"other_field": models.TextField(),
127+
"id": models.CompositePrimaryKey("other_field", "timestamp"),
128+
"timestamp": models.DateTimeField(),
129+
},
130+
partitioning_options=dict(key=["timestamp"], special=True),
131+
)
132+
133+
assert isinstance(model._meta.pk, models.CompositePrimaryKey)
134+
assert model._meta.pk.name == "id"
135+
assert model._meta.pk.columns == ("other_field", "timestamp")
136+
137+
138+
@pytest.mark.skipif(
139+
django.VERSION < (5, 2),
140+
reason="Django < 5.2 doesn't implement composite primary keys",
141+
)
142+
def test_partitioned_model_field_named_pk_not_composite_not_primary():
143+
with pytest.raises(ImproperlyConfigured):
144+
define_fake_partitioned_model(
145+
fields={
146+
"pk": models.TextField(),
147+
"id": models.CompositePrimaryKey("other_field", "timestamp"),
148+
"timestamp": models.DateTimeField(),
149+
},
150+
partitioning_options=dict(key=["timestamp"], special=True),
151+
)
152+
153+
154+
@pytest.mark.skipif(
155+
django.VERSION < (5, 2),
156+
reason="Django < 5.2 doesn't implement composite primary keys",
157+
)
158+
def test_partitioned_model_field_named_pk_not_composite():
159+
with pytest.raises(ImproperlyConfigured):
160+
define_fake_partitioned_model(
161+
fields={
162+
"pk": models.AutoField(primary_key=True),
163+
"timestamp": models.DateTimeField(),
164+
},
165+
partitioning_options=dict(key=["timestamp"], special=True),
166+
)
167+
168+
169+
@pytest.mark.skipif(
170+
django.VERSION < (5, 2),
171+
reason="Django < 5.2 doesn't implement composite primary keys",
172+
)
173+
def test_partitioned_model_field_multiple_pks():
174+
with pytest.raises(ImproperlyConfigured):
175+
define_fake_partitioned_model(
176+
fields={
177+
"id": models.AutoField(primary_key=True),
178+
"another_pk": models.TextField(primary_key=True),
179+
"timestamp": models.DateTimeField(),
180+
"real_pk": models.CompositePrimaryKey("id", "timestamp"),
181+
},
182+
partitioning_options=dict(key=["timestamp"], special=True),
183+
)
184+
185+
186+
@pytest.mark.skipif(
187+
django.VERSION < (5, 2),
188+
reason="Django < 5.2 doesn't implement composite primary keys",
189+
)
190+
def test_partitioned_model_no_pk_defined():
191+
model = define_fake_partitioned_model(
192+
fields={
193+
"timestamp": models.DateTimeField(),
194+
},
195+
partitioning_options=dict(key=["timestamp"], special=True),
196+
)
197+
198+
assert isinstance(model._meta.pk, models.CompositePrimaryKey)
199+
assert model._meta.pk.name == "pk"
200+
assert model._meta.pk.columns == ("id", "timestamp")
201+
202+
id_field = model._meta.get_field("id")
203+
assert id_field.name == "id"
204+
assert id_field.column == "id"
205+
assert isinstance(id_field, models.AutoField)
206+
assert id_field.primary_key is True
207+
208+
209+
@pytest.mark.skipif(
210+
django.VERSION < (5, 2),
211+
reason="Django < 5.2 doesn't implement composite primary keys",
212+
)
213+
def test_partitioned_model_composite_primary_key():
214+
model = define_fake_partitioned_model(
215+
fields={
216+
"id": models.AutoField(primary_key=True),
217+
"pk": models.CompositePrimaryKey("id", "timestamp"),
218+
"timestamp": models.DateTimeField(),
219+
},
220+
partitioning_options=dict(key=["timestamp"], special=True),
221+
)
222+
223+
assert isinstance(model._meta.pk, models.CompositePrimaryKey)
224+
assert model._meta.pk.name == "pk"
225+
assert model._meta.pk.columns == ("id", "timestamp")
226+
227+
228+
@pytest.mark.skipif(
229+
django.VERSION < (5, 2),
230+
reason="Django < 5.2 doesn't implement composite primary keys",
231+
)
232+
def test_partitioned_model_composite_primary_key_foreign_key():
233+
model = define_fake_partitioned_model(
234+
fields={
235+
"timestamp": models.DateTimeField(),
236+
},
237+
partitioning_options=dict(key=["timestamp"], special=True),
238+
)
239+
240+
define_fake_model(
241+
fields={
242+
"model": models.ForeignKey(model, on_delete=models.CASCADE),
243+
},
244+
)
245+
246+
247+
@pytest.mark.skipif(
248+
django.VERSION < (5, 2),
249+
reason="Django < 5.2 doesn't implement composite primary keys",
250+
)
251+
def test_partitioned_model_custom_composite_primary_key_foreign_key():
252+
model = define_fake_partitioned_model(
253+
fields={
254+
"id": models.TextField(primary_key=True),
255+
"timestamp": models.DateTimeField(),
256+
"custom": models.CompositePrimaryKey("id", "timestamp"),
257+
},
258+
partitioning_options=dict(key=["timestamp"], special=True),
259+
)
260+
261+
define_fake_model(
262+
fields={
263+
"model": models.ForeignKey(model, on_delete=models.CASCADE),
264+
},
265+
)

0 commit comments

Comments
 (0)