Skip to content

Commit 04892e1

Browse files
committed
generate training data from .txt file
1 parent 36f13e7 commit 04892e1

File tree

2 files changed

+74
-96
lines changed

2 files changed

+74
-96
lines changed

image-build-process/ocr_data_generator/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ shapely
1414
imgaug
1515
pyclipper
1616
lmdb
17+
boto3
Lines changed: 73 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,97 @@
11
"""Feature engineers the abalone dataset."""
2-
import argparse
3-
import logging
4-
import os
5-
import pathlib
6-
import requests
7-
import tempfile
82

9-
import boto3
10-
import numpy as np
11-
import pandas as pd
123

13-
from sklearn.compose import ColumnTransformer
14-
from sklearn.impute import SimpleImputer
15-
from sklearn.pipeline import Pipeline
16-
from sklearn.preprocessing import StandardScaler, OneHotEncoder
4+
import os
5+
import shutil
6+
7+
from trdg.generators import (
8+
GeneratorFromDict,
9+
GeneratorFromRandom,
10+
GeneratorFromStrings,
11+
GeneratorFromWikipedia,
12+
)
13+
import logging
14+
import argparse
15+
import pathlib
16+
import boto3
1717

1818
logger = logging.getLogger()
1919
logger.setLevel(logging.INFO)
2020
logger.addHandler(logging.StreamHandler())
2121

2222

23-
# Since we get a headerless CSV file we specify the column names here.
24-
feature_columns_names = [
25-
"sex",
26-
"length",
27-
"diameter",
28-
"height",
29-
"whole_weight",
30-
"shucked_weight",
31-
"viscera_weight",
32-
"shell_weight",
33-
]
34-
label_column = "rings"
35-
36-
feature_columns_dtype = {
37-
"sex": str,
38-
"length": np.float64,
39-
"diameter": np.float64,
40-
"height": np.float64,
41-
"whole_weight": np.float64,
42-
"shucked_weight": np.float64,
43-
"viscera_weight": np.float64,
44-
"shell_weight": np.float64,
45-
}
46-
label_column_dtype = {"rings": np.float64}
47-
48-
49-
def merge_two_dicts(x, y):
50-
"""Merges two dicts, returning a new copy."""
51-
z = x.copy()
52-
z.update(y)
53-
return z
23+
def get_strings(file_name):
    """Read the non-blank lines of a text file, one training string per line.

    Args:
        file_name: Path to the .txt file to read.

    Returns:
        List of stripped, non-empty lines in file order.
    """
    # Context manager guarantees the handle is closed (the original left
    # the file open for the garbage collector to clean up).
    with open(file_name, 'r') as f:
        return [line.strip() for line in f if line.strip()]
30+
31+
32+
def get_fonts(font_dir):
    """Return the full paths of every regular file directly inside *font_dir*.

    Subdirectories are skipped; order follows ``os.listdir``.
    """
    font_paths = []
    for entry in os.listdir(font_dir):
        candidate = os.path.join(font_dir, entry)
        if os.path.isfile(candidate):
            font_paths.append(candidate)
    return font_paths
35+
36+
37+
38+
def get_training_data_img_and_labels(string_file, font_dir, output_folder, img_prefix, limit=1000):
    """Render text strings to .jpg images and write a tab-separated label file.

    Reads one string per line from *string_file*, renders each with trdg's
    ``GeneratorFromStrings``, saves the images as ``<i>.jpg`` under
    *output_folder*, and writes ``<output_folder>/train.txt`` where each line
    is ``<img_prefix>/<i>.jpg\\t<label>``.

    Args:
        string_file: Path of a text file with one training string per line.
        font_dir: Directory holding the fonts; ``setofont.ttf`` inside it is used.
        output_folder: Existing directory for the generated images and label file.
        img_prefix: Path prefix recorded before each image name in the label
            file (e.g. "train" or "test").
        limit: Maximum number of images to generate.
    """
    strings = get_strings(string_file)
    # NOTE(review): `fonts` is collected but not passed to the generator,
    # which hardcodes setofont.ttf below — confirm whether the full font
    # list was meant to be used.
    fonts = get_fonts(font_dir)
    logger.debug("strings: %s", strings)
    logger.debug("fonts: %s", fonts)
    generator = GeneratorFromStrings(
        strings,
        fonts=[f"{font_dir}/setofont.ttf"],
        # blur=2,
        # random_blur=True
    )
    labels = []
    for i, (img, lbl) in enumerate(generator):
        # trdg generators cycle indefinitely, so we must break explicitly.
        # The original's `i <= limit` off-by-one produced limit + 1 images.
        if i >= limit:
            break
        file_name = os.path.join(output_folder, f"{i}.jpg")
        in_label_file_name = os.path.join(img_prefix, f"{i}.jpg")
        img.save(file_name)
        labels.append((in_label_file_name, lbl))

    # Context manager flushes and closes the label file (the original never
    # closed it, risking truncated output on interpreter exit).
    with open(os.path.join(output_folder, "train.txt"), 'w') as label_file:
        for entry in labels:
            label_file.write('\t'.join(entry))
            label_file.write('\n')
66+
67+
68+
69+
import sys
70+
5471

5572

5673
if __name__ == "__main__":
    # Processing-job entry point: download train/test string files from S3,
    # then render them into OCR training images + label files.
    logger.debug("Starting preprocessing.")
    parser = argparse.ArgumentParser()
    # S3 URI of the folder containing train.txt and test.txt,
    # e.g. s3://<bucket>/<key-prefix>
    parser.add_argument("--input-data", type=str, required=True)
    args = parser.parse_args()
    base_dir = "/opt/ml/processing"
    pathlib.Path(f"{base_dir}/data").mkdir(parents=True, exist_ok=True)
    input_data = args.input_data
    # s3://bucket/key → ["s3:", "", "bucket", "key", ...]
    bucket = input_data.split("/")[2]
    key = "/".join(input_data.split("/")[3:])

    logger.info("Downloading data from bucket: %s, key: %s", bucket, key)
    train_fn = f"{base_dir}/data/train.txt"
    test_fn = f"{base_dir}/data/test.txt"
    s3 = boto3.resource("s3")
    s3.Bucket(bucket).download_file(key + "/train.txt", train_fn)
    s3.Bucket(bucket).download_file(key + "/test.txt", test_fn)
    font_dir = "/opt/program/ocr_data_generator/setofont"
    train_output_folder = f"{base_dir}/train"
    test_output_folder = f"{base_dir}/test"
    # Use pathlib with exist_ok (consistent with the data dir above);
    # the original os.mkdir raised FileExistsError on re-runs.
    pathlib.Path(train_output_folder).mkdir(parents=True, exist_ok=True)
    pathlib.Path(test_output_folder).mkdir(parents=True, exist_ok=True)
    get_training_data_img_and_labels(train_fn, font_dir, train_output_folder, "train")
    get_training_data_img_and_labels(test_fn, font_dir, test_output_folder, "test")
97+

0 commit comments

Comments
 (0)