Skip to content

Commit 7d883ad

Browse files
committed
Add keypoint handling
1 parent 9a17031 commit 7d883ad

File tree

10 files changed

+495
-36
lines changed

10 files changed

+495
-36
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ History
77
All release highlights of this project will be documented in this file.
88

99
4.4.34 - April 11, 2025
10-
______________________
10+
_______________________
1111

1212
**Added**
1313

docs/source/api_reference/api_project.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ Projects
2424
.. automethod:: superannotate.SAClient.add_contributors_to_project
2525
.. automethod:: superannotate.SAClient.get_project_settings
2626
.. automethod:: superannotate.SAClient.set_project_default_image_quality_in_editor
27-
.. automethod:: superannotate.SAClient.set_project_steps
2827
.. automethod:: superannotate.SAClient.get_project_steps
28+
.. automethod:: superannotate.SAClient.set_project_steps
2929
.. automethod:: superannotate.SAClient.get_component_config

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
from lib.core.entities.work_managament import WMUserTypeEnum
7575
from lib.core.jsx_conditions import EmptyQuery
7676

77-
7877
logger = logging.getLogger("sa")
7978

8079
NotEmptyStr = constr(strict=True, min_length=1)
@@ -1485,10 +1484,11 @@ def get_project_steps(self, project: Union[str, dict]):
14851484
:param project: project name or metadata
14861485
:type project: str or dict
14871486
1488-
:return: project steps
1489-
:rtype: list of dicts
1487+
:return: A list of step dictionaries,
1488+
or a dictionary containing both steps and their connections (for Keypoint workflows).
1489+
:rtype: list of dicts or dict
14901490
1491-
Response Example:
1491+
Response Example for General Annotation Project:
14921492
::
14931493
14941494
[
@@ -1507,6 +1507,34 @@ def get_project_steps(self, project: Union[str, dict]):
15071507
}
15081508
]
15091509
1510+
Response Example for Keypoint Annotation Project:
1511+
::
1512+
1513+
{
1514+
"steps": [
1515+
{
1516+
"step": 1,
1517+
"className": "Left Shoulder",
1518+
"class_id": "1",
1519+
"attribute": [
1520+
{
1521+
"attribute": {
1522+
"id": 123,
1523+
"group_id": 12
1524+
}
1525+
}
1526+
]
1527+
},
1528+
{
1529+
"step": 2,
1530+
"class_id": "2",
1531+
"className": "Right Shoulder",
1532+
}
1533+
],
1534+
"connections": [
1535+
[1, 2]
1536+
]
1537+
}
15101538
"""
15111539
project_name, _ = extract_project_folder(project)
15121540
project = self.controller.get_project(project_name)
@@ -2503,7 +2531,12 @@ def download_export(
25032531
if response.errors:
25042532
raise AppException(response.errors)
25052533

2506-
def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict]):
2534+
def set_project_steps(
2535+
self,
2536+
project: Union[NotEmptyStr, dict],
2537+
steps: List[dict],
2538+
connections: List[List[int]] = None,
2539+
):
25072540
"""Sets project's steps.
25082541
25092542
:param project: project name or metadata
@@ -2512,7 +2545,11 @@ def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict]
25122545
:param steps: new workflow list of dicts
25132546
:type steps: list of dicts
25142547
2515-
Request Example:
2548+
:param connections: Defines connections between keypoint annotation steps.
2549+
Each inner list specifies a pair of step IDs indicating a connection.
2550+
:type connections: list of dicts
2551+
2552+
Request Example for General Annotation Project:
25162553
::
25172554
25182555
sa.set_project_steps(
@@ -2533,10 +2570,40 @@ def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict]
25332570
}
25342571
]
25352572
)
2573+
2574+
Request Example for Keypoint Annotation Project:
2575+
::
2576+
2577+
sa.set_project_steps(
2578+
project="Pose Estimation Project",
2579+
steps=[
2580+
{
2581+
"step": 1,
2582+
"class_id": 12,
2583+
"attribute": [
2584+
{
2585+
"attribute": {
2586+
"id": 123,
2587+
"group_id": 12
2588+
}
2589+
}
2590+
]
2591+
},
2592+
{
2593+
"step": 2,
2594+
"class_id": 13
2595+
}
2596+
],
2597+
connections=[
2598+
[1, 2]
2599+
]
2600+
)
25362601
"""
25372602
project_name, _ = extract_project_folder(project)
25382603
project = self.controller.get_project(project_name)
2539-
response = self.controller.projects.set_steps(project, steps=steps)
2604+
response = self.controller.projects.set_steps(
2605+
project, steps=steps, connections=connections
2606+
)
25402607
if response.errors:
25412608
raise AppException(response.errors)
25422609

src/superannotate/lib/core/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class ProjectType(BaseTitledEnum):
116116
def images(self):
117117
return self.VECTOR.value, self.PIXEL.value, self.TILED.value
118118

119+
119120
class StepsType(Enum):
120121
INITIAL = 1
121122
BASIC = 2

src/superannotate/lib/core/serviceproviders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity):
273273
raise NotImplementedError
274274

275275
@abstractmethod
276-
def set_keypoint_steps(self, project: entities.ProjectEntity, steps):
276+
def set_keypoint_steps(self, project: entities.ProjectEntity, steps, connections):
277277
raise NotImplementedError
278278

279279
@abstractmethod

src/superannotate/lib/core/usecases/projects.py

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from lib.core.usecases.base import BaseUseCase
2222
from lib.core.usecases.base import BaseUserBasedUseCase
2323

24-
2524
logger = logging.getLogger("sa")
2625

2726

@@ -491,8 +490,12 @@ def validate_project_type(self):
491490

492491
def execute(self):
493492
if self.is_valid():
494-
project_settings = self._service_provider.projects.list_settings(project=self._project).data
495-
step_setting = next((i for i in project_settings if i.attribute == "WorkflowType"), None)
493+
project_settings = self._service_provider.projects.list_settings(
494+
project=self._project
495+
).data
496+
step_setting = next(
497+
(i for i in project_settings if i.attribute == "WorkflowType"), None
498+
)
496499
if step_setting.value == constants.StepsType.BASIC:
497500
data = []
498501
steps = self._service_provider.projects.list_steps(self._project).data
@@ -508,9 +511,11 @@ def execute(self):
508511
data.append(step_data)
509512
self._response.data = data
510513
else:
511-
steps = self._service_provider.projects.list_keypoint_steps(self._project).data
512-
raise NotImplementedError
513-
514+
self._response.data = (
515+
self._service_provider.projects.list_keypoint_steps(
516+
self._project
517+
).data["steps"]
518+
)
514519
return self._response
515520

516521

@@ -586,17 +591,40 @@ def __init__(
586591
service_provider: BaseServiceProvider,
587592
steps: list,
588593
project: ProjectEntity,
594+
connections: List[List[int]] = None,
589595
):
590596
super().__init__()
591597
self._service_provider = service_provider
592598
self._steps = steps
599+
self._connections = connections
593600
self._project = project
594601

595602
def validate_project_type(self):
596603
if self._project.type in constants.LIMITED_FUNCTIONS:
597604
raise AppValidationException(
598605
constants.LIMITED_FUNCTIONS[self._project.type]
599606
)
607+
608+
def validate_connections(self):
609+
if not self._connections:
610+
return
611+
612+
if len(self._connections) > len(self._steps):
613+
raise AppValidationException(
614+
"Invalid connections: more connections than steps."
615+
)
616+
617+
possible_connections = set(range(1, len(self._steps) + 1))
618+
for connection_group in self._connections:
619+
if len(set(connection_group)) != len(connection_group):
620+
raise AppValidationException(
621+
"Invalid connections: duplicates in a connection group."
622+
)
623+
if not set(connection_group).issubset(possible_connections):
624+
raise AppValidationException(
625+
"Invalid connections: index out of allowed range."
626+
)
627+
600628
def set_basic_steps(self, annotation_classes):
601629
annotation_classes_map = {}
602630
annotations_classes_attributes_map = {}
@@ -618,9 +646,7 @@ def set_basic_steps(self, annotation_classes):
618646
project=self._project,
619647
steps=self._steps,
620648
)
621-
existing_steps = self._service_provider.projects.list_steps(
622-
self._project
623-
).data
649+
existing_steps = self._service_provider.projects.list_steps(self._project).data
624650
existing_steps_map = {}
625651
for steps in existing_steps:
626652
existing_steps_map[steps.step] = steps.id
@@ -630,12 +656,10 @@ def set_basic_steps(self, annotation_classes):
630656
annotation_class_name = step["className"]
631657
for attribute in step["attribute"]:
632658
attribute_name = attribute["attribute"]["name"]
633-
attribute_group_name = attribute["attribute"]["attribute_group"][
634-
"name"
635-
]
659+
attribute_group_name = attribute["attribute"]["attribute_group"]["name"]
636660
if not annotations_classes_attributes_map.get(
637-
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}",
638-
None,
661+
f"{annotation_class_name}__{attribute_group_name}__{attribute_name}",
662+
None,
639663
):
640664
raise AppException(
641665
f"Attribute group name or attribute name not found {attribute_group_name}."
@@ -656,10 +680,44 @@ def set_basic_steps(self, annotation_classes):
656680
attributes=req_data,
657681
)
658682

659-
def set_keypoint_steps(self, annotation_classes):
683+
@staticmethod
684+
def _validate_keypoint_steps(annotation_classes, steps):
685+
class_group_attrs_map = {}
686+
for annotation_class in annotation_classes:
687+
class_group_attrs_map[annotation_class.id] = dict()
688+
for group in annotation_class.attribute_groups:
689+
class_group_attrs_map[annotation_class.id][group.id] = []
690+
for attribute in group.attributes:
691+
class_group_attrs_map[annotation_class.id][group.id].append(
692+
attribute.id
693+
)
694+
for step in steps:
695+
class_id = step.get("class_id", None)
696+
if not class_id or class_id not in class_group_attrs_map:
697+
raise AppException("Annotation class not found.")
698+
attributes = step.get("attribute", None)
699+
if not attributes:
700+
continue
701+
for attr in attributes:
702+
try:
703+
_id, group_id = attr["attribute"].get("id", None), attr[
704+
"attribute"
705+
].get("group_id", None)
706+
assert _id in class_group_attrs_map[class_id][group_id]
707+
except (KeyError, AssertionError):
708+
raise AppException("Invalid steps provided.")
709+
710+
def set_keypoint_steps(self, annotation_classes, steps, connections):
711+
self._validate_keypoint_steps(annotation_classes, steps)
712+
for i in range(1, len(self._steps) + 1):
713+
step = self._steps[i - 1]
714+
step["id"] = i
715+
if "attribute" not in step:
716+
step["attribute"] = []
660717
self._service_provider.projects.set_keypoint_steps(
661718
project=self._project,
662-
steps=self._steps,
719+
steps=steps,
720+
connections=connections,
663721
)
664722

665723
def execute(self):
@@ -669,13 +727,24 @@ def execute(self):
669727
Condition("project_id", self._project.id, EQ)
670728
).data
671729

672-
project_settings = self._service_provider.projects.list_settings(project=self._project).data
673-
step_setting = next((i for i in project_settings if i.attribute == "WorkflowType"), None)
730+
project_settings = self._service_provider.projects.list_settings(
731+
project=self._project
732+
).data
733+
step_setting = next(
734+
(i for i in project_settings if i.attribute == "WorkflowType"), None
735+
)
736+
if (
737+
self._connections is not None
738+
and step_setting.value == constants.StepsType.BASIC.value
739+
):
740+
raise AppException("Can't update steps type. ")
674741

675742
if step_setting.value == constants.StepsType.BASIC:
676743
self.set_basic_steps(annotation_classes)
677744
else:
678-
self.set_keypoint_steps(annotation_classes)
745+
self.set_keypoint_steps(
746+
annotation_classes, self._steps, self._connections
747+
)
679748

680749
return self._response
681750

src/superannotate/lib/infrastructure/controller.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,13 @@ def list_steps(self, project: ProjectEntity):
489489
)
490490
return use_case.execute()
491491

492-
def set_steps(self, project: ProjectEntity, steps: List):
492+
def set_steps(
493+
self, project: ProjectEntity, steps: List, connections: List[List[int]] = None
494+
):
493495
use_case = usecases.SetStepsUseCase(
494496
service_provider=self.service_provider,
495497
steps=steps,
498+
connections=connections,
496499
project=project,
497500
)
498501
return use_case.execute()

src/superannotate/lib/infrastructure/services/project.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ def list_steps(self, project: entities.ProjectEntity):
107107
)
108108

109109
def list_keypoint_steps(self, project: entities.ProjectEntity):
110-
return self.client.request(
111-
self.URL_KEYPOINT_STEPS.format(project.id),
112-
"get"
113-
)
110+
return self.client.request(self.URL_KEYPOINT_STEPS.format(project.id), "get")
114111

115112
def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity):
116113
return self.client.request(
@@ -119,11 +116,16 @@ def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity):
119116
data={"steps": [step]},
120117
)
121118

122-
def set_keypoint_steps(self, project: entities.ProjectEntity, steps):
119+
def set_keypoint_steps(self, project: entities.ProjectEntity, steps, connections):
123120
return self.client.request(
124121
self.URL_SET_KEYPOINT_STEPS.format(project.id),
125122
"post",
126-
data={"steps": steps},
123+
data={
124+
"steps": {
125+
"steps": steps,
126+
"connections": connections if connections else [],
127+
}
128+
},
127129
)
128130

129131
# TODO check

0 commit comments

Comments
 (0)