Skip to content

Commit 9a17031

Browse files
committed
Keypoint support draft
1 parent f114078 commit 9a17031

File tree

5 files changed

+149
-95
lines changed

5 files changed

+149
-95
lines changed

src/superannotate/lib/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from lib.core.enums import ImageQuality
1111
from lib.core.enums import ProjectStatus
1212
from lib.core.enums import ProjectType
13+
from lib.core.enums import StepsType
1314
from lib.core.enums import TrainingStatus
1415
from lib.core.enums import UploadState
1516
from lib.core.enums import UserRole
@@ -186,6 +187,7 @@ def setup_logging(level=DEFAULT_LOGGING_LEVEL, file_path=LOG_FILE_LOCATION):
186187
FolderStatus,
187188
ProjectStatus,
188189
ProjectType,
190+
StepsType,
189191
UserRole,
190192
UploadState,
191193
TrainingStatus,

src/superannotate/lib/core/enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ class ProjectType(BaseTitledEnum):
116116
def images(self):
117117
return self.VECTOR.value, self.PIXEL.value, self.TILED.value
118118

119+
class StepsType(Enum):
120+
INITIAL = 1
121+
BASIC = 2
122+
KEYPOINT = 3
123+
119124

120125
class UserRole(BaseTitledEnum):
121126
CONTRIBUTOR = "Contributor", 4

src/superannotate/lib/core/serviceproviders.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,18 @@ def set_settings(
264264
def list_steps(self, project: entities.ProjectEntity):
265265
raise NotImplementedError
266266

267+
@abstractmethod
268+
def list_keypoint_steps(self, project: entities.ProjectEntity):
269+
raise NotImplementedError
270+
267271
@abstractmethod
268272
def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity):
269273
raise NotImplementedError
270274

275+
@abstractmethod
276+
def set_keypoint_steps(self, project: entities.ProjectEntity, steps):
277+
raise NotImplementedError
278+
271279
@abstractmethod
272280
def set_steps(self, project: entities.ProjectEntity, steps: list):
273281
raise NotImplementedError

src/superannotate/lib/core/usecases/projects.py

Lines changed: 119 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from typing import List
55

6-
import lib.core as constances
6+
import lib.core as constants
77
from lib.core.conditions import Condition
88
from lib.core.conditions import CONDITION_EQ as EQ
99
from lib.core.entities import ContributorEntity
@@ -228,12 +228,12 @@ def __init__(
228228

229229
def validate_settings(self):
230230
for setting in self._project.settings[:]:
231-
if setting.attribute not in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES:
231+
if setting.attribute not in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES:
232232
self._project.settings.remove(setting)
233233
if setting.attribute == "ImageQuality" and isinstance(setting.value, str):
234-
setting.value = constances.ImageQuality(setting.value).value
234+
setting.value = constants.ImageQuality(setting.value).value
235235
elif setting.attribute == "FrameRate":
236-
if not self._project.type == constances.ProjectType.VIDEO.value:
236+
if not self._project.type == constants.ProjectType.VIDEO.value:
237237
raise AppValidationException(
238238
"FrameRate is available only for Video projects"
239239
)
@@ -263,14 +263,14 @@ def validate_project_name(self):
263263
if (
264264
len(
265265
set(self._project.name).intersection(
266-
constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
266+
constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
267267
)
268268
)
269269
> 0
270270
):
271271
self._project.name = "".join(
272272
"_"
273-
if char in constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
273+
if char in constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
274274
else char
275275
for char in self._project.name
276276
)
@@ -291,7 +291,7 @@ def validate_project_name(self):
291291
def execute(self):
292292
if self.is_valid():
293293
# new projects can only have the status of NotStarted
294-
self._project.status = constances.ProjectStatus.NotStarted.value
294+
self._project.status = constants.ProjectStatus.NotStarted.value
295295
response = self._service_provider.projects.create(self._project)
296296
if not response.ok:
297297
self._response.errors = response.error
@@ -326,7 +326,7 @@ def execute(self):
326326
data["classes"] = self._project.classes
327327
logger.info(
328328
f"Created project {entity.name} (ID {entity.id}) "
329-
f"with type {constances.ProjectType(self._response.data.type).name}."
329+
f"with type {constants.ProjectType(self._response.data.type).name}."
330330
)
331331
return self._response
332332

@@ -368,12 +368,12 @@ def __init__(
368368

369369
def validate_settings(self):
370370
for setting in self._project.settings[:]:
371-
if setting.attribute not in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES:
371+
if setting.attribute not in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES:
372372
self._project.settings.remove(setting)
373373
if setting.attribute == "ImageQuality" and isinstance(setting.value, str):
374-
setting.value = constances.ImageQuality(setting.value).value
374+
setting.value = constants.ImageQuality(setting.value).value
375375
elif setting.attribute == "FrameRate":
376-
if not self._project.type == constances.ProjectType.VIDEO.value:
376+
if not self._project.type == constants.ProjectType.VIDEO.value:
377377
raise AppValidationException(
378378
"FrameRate is available only for Video projects"
379379
)
@@ -404,14 +404,14 @@ def validate_project_name(self):
404404
if (
405405
len(
406406
set(self._project.name).intersection(
407-
constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
407+
constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
408408
)
409409
)
410410
> 0
411411
):
412412
self._project.name = "".join(
413413
"_"
414-
if char in constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
414+
if char in constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES
415415
else char
416416
for char in self._project.name
417417
)
@@ -484,26 +484,33 @@ def __init__(self, project: ProjectEntity, service_provider: BaseServiceProvider
484484
self._service_provider = service_provider
485485

486486
def validate_project_type(self):
487-
if self._project.type in constances.LIMITED_FUNCTIONS:
487+
if self._project.type in constants.LIMITED_FUNCTIONS:
488488
raise AppValidationException(
489-
constances.LIMITED_FUNCTIONS[self._project.type]
489+
constants.LIMITED_FUNCTIONS[self._project.type]
490490
)
491491

492492
def execute(self):
493493
if self.is_valid():
494-
data = []
495-
steps = self._service_provider.projects.list_steps(self._project).data
496-
for step in steps:
497-
step_data = step.dict()
498-
annotation_classes = self._service_provider.annotation_classes.list(
499-
Condition("project_id", self._project.id, EQ)
500-
).data
501-
for annotation_class in annotation_classes:
502-
if annotation_class.id == step.class_id:
503-
step_data["className"] = annotation_class.name
504-
break
505-
data.append(step_data)
506-
self._response.data = data
494+
project_settings = self._service_provider.projects.list_settings(project=self._project).data
495+
step_setting = next((i for i in project_settings if i.attribute == "WorkflowType"), None)
496+
if step_setting.value == constants.StepsType.BASIC:
497+
data = []
498+
steps = self._service_provider.projects.list_steps(self._project).data
499+
for step in steps:
500+
step_data = step.dict()
501+
annotation_classes = self._service_provider.annotation_classes.list(
502+
Condition("project_id", self._project.id, EQ)
503+
).data
504+
for annotation_class in annotation_classes:
505+
if annotation_class.id == step.class_id:
506+
step_data["className"] = annotation_class.name
507+
break
508+
data.append(step_data)
509+
self._response.data = data
510+
else:
511+
steps = self._service_provider.projects.list_keypoint_steps(self._project).data
512+
raise NotImplementedError
513+
507514
return self._response
508515

509516

@@ -524,19 +531,19 @@ def validate_image_quality(self):
524531
if setting["attribute"].lower() == "imagequality" and isinstance(
525532
setting["value"], str
526533
):
527-
setting["value"] = constances.ImageQuality(setting["value"]).value
534+
setting["value"] = constants.ImageQuality(setting["value"]).value
528535
return
529536

530537
def validate_project_type(self):
531538
for attribute in self._to_update:
532539
if attribute.get(
533540
"attribute", ""
534541
) == "ImageQuality" and self._project.type in [
535-
constances.ProjectType.VIDEO.value,
536-
constances.ProjectType.DOCUMENT.value,
542+
constants.ProjectType.VIDEO.value,
543+
constants.ProjectType.DOCUMENT.value,
537544
]:
538545
raise AppValidationException(
539-
constances.DEPRICATED_DOCUMENT_VIDEO_MESSAGE
546+
constants.DEPRICATED_DOCUMENT_VIDEO_MESSAGE
540547
)
541548

542549
def execute(self):
@@ -552,7 +559,7 @@ def execute(self):
552559
for new_setting in self._to_update:
553560
if (
554561
new_setting["attribute"]
555-
in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES
562+
in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES
556563
):
557564
new_settings_to_update.append(
558565
SettingEntity(
@@ -586,73 +593,90 @@ def __init__(
586593
self._project = project
587594

588595
def validate_project_type(self):
589-
if self._project.type in constances.LIMITED_FUNCTIONS:
596+
if self._project.type in constants.LIMITED_FUNCTIONS:
590597
raise AppValidationException(
591-
constances.LIMITED_FUNCTIONS[self._project.type]
598+
constants.LIMITED_FUNCTIONS[self._project.type]
592599
)
600+
def set_basic_steps(self, annotation_classes):
601+
annotation_classes_map = {}
602+
annotations_classes_attributes_map = {}
603+
for annotation_class in annotation_classes:
604+
annotation_classes_map[annotation_class.name] = annotation_class.id
605+
for attribute_group in annotation_class.attribute_groups:
606+
for attribute in attribute_group.attributes:
607+
annotations_classes_attributes_map[
608+
f"{annotation_class.name}__{attribute_group.name}__{attribute.name}"
609+
] = attribute.id
610+
611+
for step in [step for step in self._steps if "className" in step]:
612+
if step.get("id"):
613+
del step["id"]
614+
step["class_id"] = annotation_classes_map.get(step["className"], None)
615+
if not step["class_id"]:
616+
raise AppException("Annotation class not found.")
617+
self._service_provider.projects.set_steps(
618+
project=self._project,
619+
steps=self._steps,
620+
)
621+
existing_steps = self._service_provider.projects.list_steps(
622+
self._project
623+
).data
624+
existing_steps_map = {}
625+
for steps in existing_steps:
626+
existing_steps_map[steps.step] = steps.id
627+
628+
req_data = []
629+
for step in self._steps:
630+
annotation_class_name = step["className"]
631+
for attribute in step["attribute"]:
632+
attribute_name = attribute["attribute"]["name"]
633+
attribute_group_name = attribute["attribute"]["attribute_group"][
634+
"name"
635+
]
636+
if not annotations_classes_attributes_map.get(
637+
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}",
638+
None,
639+
):
640+
raise AppException(
641+
f"Attribute group name or attribute name not found {attribute_group_name}."
642+
)
643+
644+
if not existing_steps_map.get(step["step"], None):
645+
raise AppException("Couldn't find step in steps")
646+
req_data.append(
647+
{
648+
"workflow_id": existing_steps_map[step["step"]],
649+
"attribute_id": annotations_classes_attributes_map[
650+
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}"
651+
],
652+
}
653+
)
654+
self._service_provider.projects.set_project_step_attributes(
655+
project=self._project,
656+
attributes=req_data,
657+
)
658+
659+
def set_keypoint_steps(self, annotation_classes):
660+
self._service_provider.projects.set_keypoint_steps(
661+
project=self._project,
662+
steps=self._steps,
663+
)
593664

594665
def execute(self):
595666
if self.is_valid():
667+
596668
annotation_classes = self._service_provider.annotation_classes.list(
597669
Condition("project_id", self._project.id, EQ)
598670
).data
599-
annotation_classes_map = {}
600-
annotations_classes_attributes_map = {}
601-
for annotation_class in annotation_classes:
602-
annotation_classes_map[annotation_class.name] = annotation_class.id
603-
for attribute_group in annotation_class.attribute_groups:
604-
for attribute in attribute_group.attributes:
605-
annotations_classes_attributes_map[
606-
f"{annotation_class.name}__{attribute_group.name}__{attribute.name}"
607-
] = attribute.id
608-
609-
for step in [step for step in self._steps if "className" in step]:
610-
if step.get("id"):
611-
del step["id"]
612-
step["class_id"] = annotation_classes_map.get(step["className"], None)
613-
if not step["class_id"]:
614-
raise AppException("Annotation class not found.")
615-
self._service_provider.projects.set_steps(
616-
project=self._project,
617-
steps=self._steps,
618-
)
619-
existing_steps = self._service_provider.projects.list_steps(
620-
self._project
621-
).data
622-
existing_steps_map = {}
623-
for steps in existing_steps:
624-
existing_steps_map[steps.step] = steps.id
625-
626-
req_data = []
627-
for step in self._steps:
628-
annotation_class_name = step["className"]
629-
for attribute in step["attribute"]:
630-
attribute_name = attribute["attribute"]["name"]
631-
attribute_group_name = attribute["attribute"]["attribute_group"][
632-
"name"
633-
]
634-
if not annotations_classes_attributes_map.get(
635-
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}",
636-
None,
637-
):
638-
raise AppException(
639-
f"Attribute group name or attribute name not found {attribute_group_name}."
640-
)
641671

642-
if not existing_steps_map.get(step["step"], None):
643-
raise AppException("Couldn't find step in steps")
644-
req_data.append(
645-
{
646-
"workflow_id": existing_steps_map[step["step"]],
647-
"attribute_id": annotations_classes_attributes_map[
648-
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}"
649-
],
650-
}
651-
)
652-
self._service_provider.projects.set_project_step_attributes(
653-
project=self._project,
654-
attributes=req_data,
655-
)
672+
project_settings = self._service_provider.projects.list_settings(project=self._project).data
673+
step_setting = next((i for i in project_settings if i.attribute == "WorkflowType"), None)
674+
675+
if step_setting.value == constants.StepsType.BASIC:
676+
self.set_basic_steps(annotation_classes)
677+
else:
678+
self.set_keypoint_steps(annotation_classes)
679+
656680
return self._response
657681

658682

@@ -744,11 +768,11 @@ def execute(self):
744768
team_users = set()
745769
project_users = {user.user_id for user in self._project.users}
746770
for user in self._team.users:
747-
if user.user_role == constances.UserRole.CONTRIBUTOR.value:
771+
if user.user_role == constants.UserRole.CONTRIBUTOR.value:
748772
team_users.add(user.email)
749773
# collecting pending team users which is not admin
750774
for user in self._team.pending_invitations:
751-
if user["user_role"] == constances.UserRole.CONTRIBUTOR.value:
775+
if user["user_role"] == constants.UserRole.CONTRIBUTOR.value:
752776
team_users.add(user["email"])
753777
# collecting pending project users which is not admin
754778
for user in self._project.unverified_users:
@@ -831,9 +855,9 @@ def execute(self):
831855
response = self._service_provider.invite_contributors(
832856
team_id=self._team.id,
833857
# REMINDER UserRole.VIEWER is the contributor for the teams
834-
team_role=constances.UserRole.ADMIN.value
858+
team_role=constants.UserRole.ADMIN.value
835859
if self._set_admin
836-
else constances.UserRole.CONTRIBUTOR.value,
860+
else constants.UserRole.CONTRIBUTOR.value,
837861
emails=to_add,
838862
)
839863
invited, failed = (

0 commit comments

Comments
 (0)