1+ """Example workflow pipeline script for abalone pipeline.
2+
3+ . -RegisterModel
4+ .
5+ Process-> Train -> Evaluate -> Condition .
6+ .
7+ . -(stop)
8+
9+ Implements a get_pipeline(**kwargs) method.
10+ """
import logging
import os

import boto3
import sagemaker
import sagemaker.session

from botocore.exceptions import ClientError
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.model_metrics import (
    MetricsSource,
    ModelMetrics,
)
from sagemaker.processing import (
    ProcessingInput,
    ProcessingOutput,
    ScriptProcessor,
)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.condition_step import (
    ConditionStep,
)
from sagemaker.workflow.functions import (
    JsonGet,
)
from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
)
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
from sagemaker.workflow.steps import (
    ProcessingStep,
    TrainingStep,
)
from sagemaker.workflow.step_collections import RegisterModel

BASE_DIR = os.path.dirname(os.path.realpath(__file__))

logger = logging.getLogger(__name__)

def get_session(region, default_bucket):
    """Gets the sagemaker session based on the region.

    Args:
        region: the aws region to start the session
        default_bucket: the bucket to use for storing the artifacts

    Returns:
        `sagemaker.session.Session` instance
    """
    boto_session = boto3.Session(region_name=region)

    sagemaker_client = boto_session.client("sagemaker")
    runtime_client = boto_session.client("sagemaker-runtime")
    return sagemaker.session.Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=default_bucket,
    )


def get_pipeline(
    region,
    role=None,
    default_bucket=None,
    model_package_group_name="PaddleOCRPackageGroup",
    pipeline_name="PaddleOCRPipeline",
    base_job_prefix="PaddleOCR",
    project_id="SageMakerProjectId",
):
    """Gets a SageMaker ML Pipeline instance that trains and registers the PaddleOCR model.

    Args:
        region: AWS region to create and run the pipeline.
        role: IAM role to create and run steps and pipeline.
        default_bucket: the bucket to use for storing the artifacts
        model_package_group_name: the model package group used when registering the model
        pipeline_name: the name of the SageMaker pipeline
        base_job_prefix: prefix applied to the processing and training job names
        project_id: the SageMaker project id

    Returns:
        an instance of a pipeline
    """
    sagemaker_session = get_session(region, default_bucket)
    if not default_bucket:
        default_bucket = sagemaker_session.default_bucket()
    if role is None:
        role = sagemaker.session.get_execution_role(sagemaker_session)

    # Parameters for pipeline execution
    processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
    processing_instance_type = ParameterString(
        name="ProcessingInstanceType", default_value="ml.m5.xlarge"
    )
    training_instance_type = ParameterString(
        name="TrainingInstanceType", default_value="ml.p2.xlarge"
    )
    inference_instance_type = ParameterString(
        name="InferenceInstanceType", default_value="ml.p2.xlarge"
    )
    model_approval_status = ParameterString(
        name="ModelApprovalStatus", default_value="PendingManualApproval"
    )
    input_data = ParameterString(
        name="InputDataUrl",
        default_value="s3://{}/PaddleOCR/input/data".format(default_bucket),
    )

    # Resolve the account id and build the ECR image URI for the data-generation container
    sess = boto3.Session(region_name=region)
    account = sess.client("sts").get_caller_identity()["Account"]
    data_generate_image_name = "generate-ocr-train-data"
    train_image_name = "paddle"
    data_generate_image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, data_generate_image_name)

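    # Processing step: run preprocess.py inside the data-generation container to
    # produce the OCR training data consumed by the training step below.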
    script_processor = ScriptProcessor(
        image_uri=data_generate_image,
        instance_type=processing_instance_type,
        instance_count=processing_instance_count,
        base_job_name=f"{base_job_prefix}/paddle-ocr-data-generation",
        command=["python3"],
        sagemaker_session=sagemaker_session,
        role=role,
    )
    step_process = ProcessingStep(
        name="GenerateOCRTrainingData",
        processor=script_processor,
        outputs=[
            ProcessingOutput(output_name="data", source="/opt/ml/processing/input/data"),
        ],
        code=os.path.join(BASE_DIR, "preprocess.py"),
        job_arguments=["--input-data", input_data],
    )

    # Training step for generating model artifacts
    model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/PaddleOCRTrain"
    image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, train_image_name)
    training_image_uri = image

    hyperparameters = {
        "epoch_num": 10,
        "print_batch_step": 5,
        "save_epoch_step": 30,
        "pretrained_model": "/opt/program/pretrain/ch_ppocr_mobile_v2.0_rec_train/best_accuracy",
    }

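    # Estimator that runs the custom PaddleOCR training container; the metric
    # regexes below scrape the "best metric" line from the training log.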
    paddle_train = Estimator(
        image_uri=training_image_uri,
        instance_type=training_instance_type,
        role=role,
        instance_count=1,
        output_path=model_path,
        sagemaker_session=sagemaker_session,
        base_job_name=f"{base_job_prefix}/paddleocr-train",
        hyperparameters=hyperparameters,
        # Example log line: acc: 0.2007992007992008, norm_edit_dis: 0.7116550116550118, fps: 97.10778964378831, best_epoch: 9
        metric_definitions=[
            {"Name": "validation:acc", "Regex": ".*best metric,.*acc:(.*?),"},
            {"Name": "validation:norm_edit_dis", "Regex": ".*best metric,.*norm_edit_dis:(.*?),"},
        ],
    )

    step_train = TrainingStep(
        name="TrainPaddleOCRModel",
        estimator=paddle_train,
        inputs={
            "training": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "data"
                ].S3Output.S3Uri,
                content_type="image/jpeg",
            )
        },
    )

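    # Register the trained model in the model package group; the approval status
    # parameter (default "PendingManualApproval") gates deployment.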
    inference_image_uri = image
    step_register = RegisterModel(
        name="RegisterPaddleOCRModel",
        estimator=paddle_train,
        image_uri=inference_image_uri,
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.p2.xlarge"],
        transform_instances=["ml.p2.xlarge"],
        model_package_group_name=model_package_group_name,
        approval_status=model_approval_status,
    )

    cond_gte = ConditionGreaterThanOrEqualTo(  # You can change the condition here
        left=step_train.properties.FinalMetricDataList[0].Value,
        right=0.8,  # You can change the threshold here
    )

    step_cond = ConditionStep(
        name="PaddleOCRAccuracyCond",
        conditions=[cond_gte],
        if_steps=[step_register],
        else_steps=[],
    )

    # Pipeline instance
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[
            processing_instance_type,
            processing_instance_count,
            training_instance_type,
            model_approval_status,
            input_data,
        ],
        steps=[step_process, step_train, step_cond],
        sagemaker_session=sagemaker_session,
    )
    return pipeline
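

# Example usage (a minimal sketch; the region and role ARN below are placeholder
# assumptions, not values taken from this project):
#
#   pipeline = get_pipeline(
#       region="us-east-1",
#       role="arn:aws:iam::<account-id>:role/<sagemaker-execution-role>",
#   )
#   pipeline.upsert(role_arn="arn:aws:iam::<account-id>:role/<sagemaker-execution-role>")
#   execution = pipeline.start()
#   execution.wait()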