2121from lib .core .usecases .base import BaseUseCase
2222from lib .core .usecases .base import BaseUserBasedUseCase
2323
24-
2524logger = logging .getLogger ("sa" )
2625
2726
@@ -491,8 +490,12 @@ def validate_project_type(self):
491490
492491 def execute (self ):
493492 if self .is_valid ():
494- project_settings = self ._service_provider .projects .list_settings (project = self ._project ).data
495- step_setting = next ((i for i in project_settings if i .attribute == "WorkflowType" ), None )
493+ project_settings = self ._service_provider .projects .list_settings (
494+ project = self ._project
495+ ).data
496+ step_setting = next (
497+ (i for i in project_settings if i .attribute == "WorkflowType" ), None
498+ )
496499 if step_setting .value == constants .StepsType .BASIC :
497500 data = []
498501 steps = self ._service_provider .projects .list_steps (self ._project ).data
@@ -508,9 +511,11 @@ def execute(self):
508511 data .append (step_data )
509512 self ._response .data = data
510513 else :
511- steps = self ._service_provider .projects .list_keypoint_steps (self ._project ).data
512- raise NotImplementedError
513-
514+ self ._response .data = (
515+ self ._service_provider .projects .list_keypoint_steps (
516+ self ._project
517+ ).data ["steps" ]
518+ )
514519 return self ._response
515520
516521
@@ -586,17 +591,40 @@ def __init__(
586591 service_provider : BaseServiceProvider ,
587592 steps : list ,
588593 project : ProjectEntity ,
594+ connections : List [List [int ]] = None ,
589595 ):
590596 super ().__init__ ()
591597 self ._service_provider = service_provider
592598 self ._steps = steps
599+ self ._connections = connections
593600 self ._project = project
594601
595602 def validate_project_type (self ):
596603 if self ._project .type in constants .LIMITED_FUNCTIONS :
597604 raise AppValidationException (
598605 constants .LIMITED_FUNCTIONS [self ._project .type ]
599606 )
607+
608+ def validate_connections (self ):
609+ if not self ._connections :
610+ return
611+
612+ if len (self ._connections ) > len (self ._steps ):
613+ raise AppValidationException (
614+ "Invalid connections: more connections than steps."
615+ )
616+
617+ possible_connections = set (range (1 , len (self ._steps ) + 1 ))
618+ for connection_group in self ._connections :
619+ if len (set (connection_group )) != len (connection_group ):
620+ raise AppValidationException (
621+ "Invalid connections: duplicates in a connection group."
622+ )
623+ if not set (connection_group ).issubset (possible_connections ):
624+ raise AppValidationException (
625+ "Invalid connections: index out of allowed range."
626+ )
627+
600628 def set_basic_steps (self , annotation_classes ):
601629 annotation_classes_map = {}
602630 annotations_classes_attributes_map = {}
@@ -618,9 +646,7 @@ def set_basic_steps(self, annotation_classes):
618646 project = self ._project ,
619647 steps = self ._steps ,
620648 )
621- existing_steps = self ._service_provider .projects .list_steps (
622- self ._project
623- ).data
649+ existing_steps = self ._service_provider .projects .list_steps (self ._project ).data
624650 existing_steps_map = {}
625651 for steps in existing_steps :
626652 existing_steps_map [steps .step ] = steps .id
@@ -630,12 +656,10 @@ def set_basic_steps(self, annotation_classes):
630656 annotation_class_name = step ["className" ]
631657 for attribute in step ["attribute" ]:
632658 attribute_name = attribute ["attribute" ]["name" ]
633- attribute_group_name = attribute ["attribute" ]["attribute_group" ][
634- "name"
635- ]
659+ attribute_group_name = attribute ["attribute" ]["attribute_group" ]["name" ]
636660 if not annotations_classes_attributes_map .get (
637- f"{ annotation_class_name } __{ attribute_group_name } __{ attribute_name } " ,
638- None ,
661+ f"{ annotation_class_name } __{ attribute_group_name } __{ attribute_name } " ,
662+ None ,
639663 ):
640664 raise AppException (
641665 f"Attribute group name or attribute name not found { attribute_group_name } ."
@@ -656,10 +680,44 @@ def set_basic_steps(self, annotation_classes):
656680 attributes = req_data ,
657681 )
658682
659- def set_keypoint_steps (self , annotation_classes ):
683+ @staticmethod
684+ def _validate_keypoint_steps (annotation_classes , steps ):
685+ class_group_attrs_map = {}
686+ for annotation_class in annotation_classes :
687+ class_group_attrs_map [annotation_class .id ] = dict ()
688+ for group in annotation_class .attribute_groups :
689+ class_group_attrs_map [annotation_class .id ][group .id ] = []
690+ for attribute in group .attributes :
691+ class_group_attrs_map [annotation_class .id ][group .id ].append (
692+ attribute .id
693+ )
694+ for step in steps :
695+ class_id = step .get ("class_id" , None )
696+ if not class_id or class_id not in class_group_attrs_map :
697+ raise AppException ("Annotation class not found." )
698+ attributes = step .get ("attribute" , None )
699+ if not attributes :
700+ continue
701+ for attr in attributes :
702+ try :
703+ _id , group_id = attr ["attribute" ].get ("id" , None ), attr [
704+ "attribute"
705+ ].get ("group_id" , None )
706+ assert _id in class_group_attrs_map [class_id ][group_id ]
707+ except (KeyError , AssertionError ):
708+ raise AppException ("Invalid steps provided." )
709+
710+ def set_keypoint_steps (self , annotation_classes , steps , connections ):
711+ self ._validate_keypoint_steps (annotation_classes , steps )
712+ for i in range (1 , len (self ._steps ) + 1 ):
713+ step = self ._steps [i - 1 ]
714+ step ["id" ] = i
715+ if "attribute" not in step :
716+ step ["attribute" ] = []
660717 self ._service_provider .projects .set_keypoint_steps (
661718 project = self ._project ,
662- steps = self ._steps ,
719+ steps = steps ,
720+ connections = connections ,
663721 )
664722
665723 def execute (self ):
@@ -669,13 +727,24 @@ def execute(self):
669727 Condition ("project_id" , self ._project .id , EQ )
670728 ).data
671729
672- project_settings = self ._service_provider .projects .list_settings (project = self ._project ).data
673- step_setting = next ((i for i in project_settings if i .attribute == "WorkflowType" ), None )
730+ project_settings = self ._service_provider .projects .list_settings (
731+ project = self ._project
732+ ).data
733+ step_setting = next (
734+ (i for i in project_settings if i .attribute == "WorkflowType" ), None
735+ )
736+ if (
737+ self ._connections is not None
738+ and step_setting .value == constants .StepsType .BASIC .value
739+ ):
740+ raise AppException ("Can't update steps type. " )
674741
675742 if step_setting .value == constants .StepsType .BASIC :
676743 self .set_basic_steps (annotation_classes )
677744 else :
678- self .set_keypoint_steps (annotation_classes )
745+ self .set_keypoint_steps (
746+ annotation_classes , self ._steps , self ._connections
747+ )
679748
680749 return self ._response
681750
0 commit comments