Commit e67526d

Author: EC2 Default User
Commit message: add preprocess and conditional steps
1 parent 04892e1 commit e67526d

File tree

5 files changed: +204 -270 lines

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+高雄醫學大學附設中和紀念醫院
+佛教慈濟醫療財團法人花蓮慈濟醫院
+基督復臨安息日會醫療財團法人臺
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+國立臺灣大學醫學院附設醫院
+三軍總醫院
+臺北榮民總醫院
+國泰綜合醫院
+臺北及林口長庚醫院
+馬偕紀念醫院
+新光吳火獅紀念醫院
+臺北市立萬芳醫院
+亞東紀念醫院
+臺中榮民總醫院
+中山醫學大學附設醫院
+中國醫藥大學附設醫院
+彰化基督教醫療財團法人彰化基督教醫院
+國立成功大學醫學院附設醫院
+奇美醫院
+高雄榮民總醫院
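
The two new text files above are lists of Taiwanese hospital names, presumably the text corpus consumed by the generate-ocr-train-data processing image introduced in pipeline.py below; their file names are not shown in this diff. A minimal sketch of how such a list could be loaded inside preprocess.py, assuming UTF-8 files with one name per line and a hypothetical file name hospital_names.txt:

from pathlib import Path

# Hypothetical path and file name; the actual names and location are not visible in this diff.
corpus_path = Path("/opt/ml/processing/input/data/hospital_names.txt")

# One hospital name per line, UTF-8 encoded; blank lines are skipped.
names = [line.strip() for line in corpus_path.read_text(encoding="utf-8").splitlines() if line.strip()]
print(f"Loaded {len(names)} names, e.g. {names[0]}")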

sagemaker_pipelines/paddleocr/pipeline.py

Lines changed: 68 additions & 204 deletions
@@ -27,7 +27,7 @@
     ScriptProcessor,
 )
 from sagemaker.sklearn.processing import SKLearnProcessor
-from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
+from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import (
     ConditionStep,
 )
@@ -47,7 +47,7 @@
 from sagemaker.workflow.step_collections import RegisterModel

 from botocore.exceptions import ClientError
-
+import boto3

 BASE_DIR = os.path.dirname(os.path.realpath(__file__))

@@ -76,72 +76,7 @@ def get_session(region, default_bucket):
         default_bucket=default_bucket,
     )

-def resolve_ecr_uri_from_image_versions(sagemaker_session, image_versions, image_name):
-    """ Gets ECR URI from image versions
-    Args:
-        sagemaker_session: boto3 session for sagemaker client
-        image_versions: list of the image versions
-        image_name: Name of the image
-
-    Returns:
-        ECR URI of the image version
-    """

-    #Fetch image details to get the Base Image URI
-    for image_version in image_versions:
-        if image_version['ImageVersionStatus'] == 'CREATED':
-            image_arn = image_version["ImageVersionArn"]
-            version = image_version["Version"]
-            logger.info(f"Identified the latest image version: {image_arn}")
-            response = sagemaker_session.sagemaker_client.describe_image_version(
-                ImageName=image_name,
-                Version=version
-            )
-            return response['ContainerImage']
-    return None
-
-def resolve_ecr_uri(sagemaker_session, image_arn):
-    """Gets the ECR URI from the image name
-
-    Args:
-        sagemaker_session: boto3 session for sagemaker client
-        image_name: name of the image
-
-    Returns:
-        ECR URI of the latest image version
-    """
-
-    # Fetching image name from image_arn (^arn:aws(-[\w]+)*:sagemaker:.+:[0-9]{12}:image/[a-z0-9]([-.]?[a-z0-9])*$)
-    image_name = image_arn.partition("image/")[2]
-    try:
-        # Fetch the image versions
-        next_token=''
-        while True:
-            response = sagemaker_session.sagemaker_client.list_image_versions(
-                ImageName=image_name,
-                MaxResults=100,
-                SortBy='VERSION',
-                SortOrder='DESCENDING',
-                NextToken=next_token
-            )
-            ecr_uri = resolve_ecr_uri_from_image_versions(sagemaker_session, response['ImageVersions'], image_name)
-            if "NextToken" in response:
-                next_token = response["NextToken"]
-
-            if ecr_uri is not None:
-                return ecr_uri
-
-        # Return error if no versions of the image found
-        error_message = (
-            f"No image version found for image name: {image_name}"
-        )
-        logger.error(error_message)
-        raise Exception(error_message)
-
-    except (ClientError, sagemaker_session.sagemaker_client.exceptions.ResourceNotFound) as e:
-        error_message = e.response["Error"]["Message"]
-        logger.error(error_message)
-        raise Exception(error_message)

 def get_pipeline(
     region,
@@ -167,12 +102,12 @@ def get_pipeline(
     default_bucket = sagemaker_session.default_bucket()
     if role is None:
         role = sagemaker.session.get_execution_role(sagemaker_session)
-
-    # parameters for pipeline execution
-    # processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
-    # processing_instance_type = ParameterString(
-    #     name="ProcessingInstanceType", default_value="ml.m5.xlarge"
-    # )
+    # parameters for pipeline execution
+    sess = boto3.Session()
+    processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
+    processing_instance_type = ParameterString(
+        name="ProcessingInstanceType", default_value="ml.m5.xlarge"
+    )
     training_instance_type = ParameterString(
         name="TrainingInstanceType", default_value="ml.p2.xlarge"
     )
@@ -184,60 +119,40 @@ def get_pipeline(
     )
     input_data = ParameterString(
         name="InputDataUrl",
-        default_value="s3://{}/DEMO-paddle-byo/".format(default_bucket)
+        default_value="s3://{}/PaddleOCR/input/data".format(default_bucket)
     )
+    account = sess.client("sts").get_caller_identity()["Account"]
+    region = sess.region_name
+    data_generate_image_name = "generate-ocr-train-data"
+    train_image_name = "paddle"
+    data_generate_image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, data_generate_image_name)

-    training_image_name = "paddle"
-    inference_image_name = "paddle"
-
-    # processing step for feature engineering
-    # try:
-    #     processing_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=processing_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     processing_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=processing_instance_type,
-    #     )
-    # script_processor = ScriptProcessor(
-    #     image_uri=processing_image_uri,
-    #     instance_type=processing_instance_type,
-    #     instance_count=processing_instance_count,
-    #     base_job_name=f"{base_job_prefix}/sklearn-abalone-preprocess",
-    #     command=["python3"],
-    #     sagemaker_session=sagemaker_session,
-    #     role=role,
-    # )
-    # step_process = ProcessingStep(
-    #     name="PreprocessAbaloneData",
-    #     processor=script_processor,
-    #     outputs=[
-    #         ProcessingOutput(output_name="train", source="/opt/ml/processing/train"),
-    #         ProcessingOutput(output_name="validation", source="/opt/ml/processing/validation"),
-    #         ProcessingOutput(output_name="test", source="/opt/ml/processing/test"),
-    #     ],
-    #     code=os.path.join(BASE_DIR, "preprocess.py"),
-    #     job_arguments=["--input-data", input_data],
-    # )
+    script_processor = ScriptProcessor(
+        image_uri=data_generate_image,
+        instance_type=processing_instance_type,
+        instance_count=processing_instance_count,
+        base_job_name=f"{base_job_prefix}/paddle-ocr-data-generation",
+        command=["python3"],
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    step_process = ProcessingStep(
+        name="GenerateOCRTrainingData",
+        processor=script_processor,
+        outputs=[
+            ProcessingOutput(output_name="data", source="/opt/ml/processing/input/data"),
+        ],
+        code=os.path.join(BASE_DIR, "preprocess.py"),
+        job_arguments=["--input-data", input_data],
+    )

     # training step for generating model artifacts
     model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/PaddleOCRTrain"

-    # try:
-    #     print(training_image_name)
-    #     training_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=training_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     training_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=training_instance_type,
-    #     )
-
-    training_image_uri = "230755935769.dkr.ecr.us-east-1.amazonaws.com/paddle:latest"
+
+    image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, train_image_name)
+
+    training_image_uri = image
     hyperparameters = {"epoch_num": 10,
                        "print_batch_step":5,
                        "save_epoch_step":30,
@@ -252,6 +167,12 @@ def get_pipeline(
         sagemaker_session=sagemaker_session,
         base_job_name=f"{base_job_prefix}/paddleocr-train",
         hyperparameters=hyperparameters,
+        # acc: 0.2007992007992008, norm_edit_dis: 0.7116550116550118, fps: 97.10778964378831, best_epoch: 9
+        metric_definitions=[
+            {'Name': 'validation:acc', 'Regex': '.*best metric,.*acc:(.*?),'},
+            {'Name': 'validation:norm_edit_dis', 'Regex': '.*best metric,.*norm_edit_dis:(.*?),'}
+        ]
+
     )


@@ -260,71 +181,17 @@ def get_pipeline(
         estimator=paddle_train,
         inputs={
             "training": TrainingInput(
-                s3_data=input_data,
-                content_type="text/csv",
-            )
-        },
+                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                    "data"
+                ].S3Output.S3Uri,
+                content_type="image/jpeg")
+        }
     )

-    # processing step for evaluation
-    # script_eval = ScriptProcessor(
-    #     image_uri=training_image_uri,
-    #     command=["python3"],
-    #     instance_type=processing_instance_type,
-    #     instance_count=1,
-    #     base_job_name=f"{base_job_prefix}/script-abalone-eval",
-    #     sagemaker_session=sagemaker_session,
-    #     role=role,
-    # )
-    # evaluation_report = PropertyFile(
-    #     name="AbaloneEvaluationReport",
-    #     output_name="evaluation",
-    #     path="evaluation.json",
-    # )
-    # step_eval = ProcessingStep(
-    #     name="EvaluateAbaloneModel",
-    #     processor=script_eval,
-    #     inputs=[
-    #         ProcessingInput(
-    #             source=step_train.properties.ModelArtifacts.S3ModelArtifacts,
-    #             destination="/opt/ml/processing/model",
-    #         ),
-    #         ProcessingInput(
-    #             source=step_process.properties.ProcessingOutputConfig.Outputs[
-    #                 "test"
-    #             ].S3Output.S3Uri,
-    #             destination="/opt/ml/processing/test",
-    #         ),
-    #     ],
-    #     outputs=[
-    #         ProcessingOutput(output_name="evaluation", source="/opt/ml/processing/evaluation"),
-    #     ],
-    #     code=os.path.join(BASE_DIR, "evaluate.py"),
-    #     property_files=[evaluation_report],
-    # )
-
-    # # register model step that will be conditionally executed
-    # model_metrics = ModelMetrics(
-    #     model_statistics=MetricsSource(
-    #         s3_uri="{}/evaluation.json".format(
-    #             step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
-    #         ),
-    #         content_type="application/json"
-    #     )
-    # )
-
-    # try:
-    #     inference_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=inference_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     inference_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=inference_instance_type,
-    #     )
-
-    inference_image_uri = "230755935769.dkr.ecr.us-east-1.amazonaws.com/paddle:latest"
+
+
+
+    inference_image_uri = image
     step_register = RegisterModel(
         name="RegisterPaddleOCRModel",
         estimator=paddle_train,
@@ -335,37 +202,34 @@ def get_pipeline(
         inference_instances=["ml.p2.xlarge"],
         transform_instances=["ml.p2.xlarge"],
         model_package_group_name=model_package_group_name,
-        approval_status=model_approval_status,
-        # model_metrics=model_metrics,
+        approval_status=model_approval_status
+    )
+
+    cond_lte = ConditionGreaterThanOrEqualTo( # You can change the condition here
+        left=step_train.properties.FinalMetricDataList[0].Value,
+        right=0.8, # You can change the threshold here
+    )
+
+    step_cond = ConditionStep(
+        name="PaddleOCRAccuracyCond",
+        conditions=[cond_lte],
+        if_steps=[step_register],
+        else_steps=[],
     )

-    # condition step for evaluating model quality and branching execution
-    # cond_lte = ConditionLessThanOrEqualTo(
-    #     left=JsonGet(
-    #         step_name=step_eval.name,
-    #         property_file=evaluation_report,
-    #         json_path="regression_metrics.mse.value"
-    #     ),
-    #     right=6.0,
-    # )
-    # step_cond = ConditionStep(
-    #     name="CheckMSEAbaloneEvaluation",
-    #     conditions=[cond_lte],
-    #     if_steps=[step_register],
-    #     else_steps=[],
-    # )

     # pipeline instance
     pipeline = Pipeline(
         name=pipeline_name,
         parameters=[
-            # processing_instance_type,
-            # processing_instance_count,
+            processing_instance_type,
+            processing_instance_count,
             training_instance_type,
             model_approval_status,
             input_data,
         ],
-        steps=[step_train, step_register],
+        steps = [step_process, step_train, step_cond],
+        # steps=[step_train, step_register],
         sagemaker_session=sagemaker_session,
     )
     return pipeline
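
With these changes the pipeline runs GenerateOCRTrainingData, feeds its "data" output into the PaddleOCR training step, and registers the model only when the training job's first final metric is at least 0.8. A minimal sketch of building and launching it, assuming the remaining get_pipeline arguments follow the standard SageMaker project template (only region and role handling are visible in this diff) and that the execution role and model package group already exist:

import sagemaker
from sagemaker_pipelines.paddleocr.pipeline import get_pipeline

# Assumes this runs where a SageMaker execution role is resolvable (e.g. a Studio notebook).
role = sagemaker.get_execution_role()

# Region and names below are placeholders.
pipeline = get_pipeline(
    region="us-east-1",
    role=role,
    model_package_group_name="PaddleOCRModelGroup",
    pipeline_name="PaddleOCRPipeline",
)

pipeline.upsert(role_arn=role)   # create or update the pipeline definition
execution = pipeline.start()     # kick off one execution
execution.wait()                 # block until it succeeds or fails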
