From e9c2d4bcb5eaeecca9364b03bf4c6b7c70f812b0 Mon Sep 17 00:00:00 2001
From: Ohad Mosafi
Date: Sun, 27 Sep 2020 16:55:00 -0700
Subject: [PATCH 1/3] [samples] add inference sample for simple snp trainer

---
 .../rnn_consensus_trainer.py                  |  2 +-
 samples/simple_snp_trainer/cnn_snp_infer.py   | 78 +++++++++++++++++++
 samples/simple_snp_trainer/cnn_snp_trainer.py |  2 +-
 3 files changed, 80 insertions(+), 2 deletions(-)
 create mode 100755 samples/simple_snp_trainer/cnn_snp_infer.py

diff --git a/samples/simple_consensus_caller/rnn_consensus_trainer.py b/samples/simple_consensus_caller/rnn_consensus_trainer.py
index d9f779c..5155d0a 100644
--- a/samples/simple_consensus_caller/rnn_consensus_trainer.py
+++ b/samples/simple_consensus_caller/rnn_consensus_trainer.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-"""A sample program highlighting usage of VariantWorks SDK to write a simple SNP variant caller using a CNN."""
+"""A sample program highlighting usage of VariantWorks SDK to write a simple consensus training tool."""
 
 import argparse
 
diff --git a/samples/simple_snp_trainer/cnn_snp_infer.py b/samples/simple_snp_trainer/cnn_snp_infer.py
new file mode 100755
index 0000000..0953ad8
--- /dev/null
+++ b/samples/simple_snp_trainer/cnn_snp_infer.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+#
+# Copyright 2020 NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""A sample program highlighting usage of VariantWorks SDK to write a simple SNP variant caller using a CNN."""
+
+import argparse
+
+import nemo
+
+from variantworks.dataloader import HDFDataLoader
+from variantworks.networks import AlexNet
+from variantworks.neural_types import ReadPileupNeuralType, VariantZygosityNeuralType
+
+
+def create_model():
+    """Return neural network to test."""
+    # Neural Network
+    alexnet = AlexNet(num_input_channels=2, num_output_logits=3)
+
+    return alexnet
+
+
+def infer(parsed_args):
+    """Infer a sample model."""
+    # Create neural factory as per NeMo requirements.
+    nf = nemo.core.NeuralModuleFactory(
+        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir=parsed_args.model_dir)
+
+    model = create_model()
+
+    # Create test DAG
+    test_dataset = HDFDataLoader(parsed_args.test_hdf, batch_size=32,
+                                 shuffle=True, num_workers=parsed_args.threads,
+                                 tensor_keys=["encodings", "labels"],
+                                 tensor_dims=[('B', 'C', 'H', 'W'), tuple('B')],
+                                 tensor_neural_types=[ReadPileupNeuralType(), VariantZygosityNeuralType()])
+    encoding, vz_labels = test_dataset()
+
+    vz = model(encoding=encoding)
+
+    nf.infer([vz], checkpoint_dir=parsed_args.model_dir, verbose=True)
+
+
+def build_parser():
+    """Build parser object with options for sample."""
+    import multiprocessing
+
+    parser = argparse.ArgumentParser(
+        description="Simple model inference SNP caller based on VariantWorks.")
+    parser.add_argument("--test-hdf",
+                        help="HDF with examples for testing.",
+                        required=True)
+    parser.add_argument("-t", "--threads", type=int,
+                        help="Threads to use for parallel loading.",
+                        required=False, default=multiprocessing.cpu_count())
+    parser.add_argument("--model-dir", type=str,
+                        help="Directory for loading saved trained model checkpoints.",
+                        required=False, default="./models")
+    return parser
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+    infer(args)
diff --git a/samples/simple_snp_trainer/cnn_snp_trainer.py b/samples/simple_snp_trainer/cnn_snp_trainer.py
index 1ec1cd2..d8aa031 100755
--- a/samples/simple_snp_trainer/cnn_snp_trainer.py
+++ b/samples/simple_snp_trainer/cnn_snp_trainer.py
@@ -113,7 +113,7 @@ def train(args):
 def build_parser():
     """Build parser object with options for sample."""
     parser = argparse.ArgumentParser(
-        description="Simple SNP caller based on VariantWorks.")
+        description="Simple model training for SNP caller based on VariantWorks.")
 
     parser.add_argument("--train-hdf",
                         help="HDF with examples for training.",

From c786f44f92512b381fa3e00bd7ea4a986a17710f Mon Sep 17 00:00:00 2001
From: Ohad Mosafi
Date: Thu, 8 Oct 2020 07:02:00 -0700
Subject: [PATCH 2/3] [samples] simple snp inference - uses ReadPileupDataLoader
 instead of HDFDataLoader

---
 samples/simple_snp_trainer/cnn_snp_infer.py | 81 +++++++++++++++------
 variantworks/io/vcfio.py                    |  5 ++
 2 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/samples/simple_snp_trainer/cnn_snp_infer.py b/samples/simple_snp_trainer/cnn_snp_infer.py
index 0953ad8..2e5f161 100755
--- a/samples/simple_snp_trainer/cnn_snp_infer.py
+++ b/samples/simple_snp_trainer/cnn_snp_infer.py
@@ -18,11 +18,14 @@
 
 import argparse
 
+import os
 import nemo
+import torch
 
-from variantworks.dataloader import HDFDataLoader
+from variantworks.dataloader import ReadPileupDataLoader
+from variantworks.encoders import PileupEncoder, ZygosityLabelDecoder
+from variantworks.io.vcfio import VCFReader, VCFWriter
 from variantworks.networks import AlexNet
-from variantworks.neural_types import ReadPileupNeuralType, VariantZygosityNeuralType
 
 
 def create_model():
@@ -37,39 +40,69 @@ def infer(parsed_args):
     """Infer a sample model."""
     # Create neural factory as per NeMo requirements.
     nf = nemo.core.NeuralModuleFactory(
-        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir=parsed_args.model_dir)
+        placement=nemo.core.neural_factory.DeviceType.GPU,
+        checkpoint_dir=parsed_args.model_dir)
+
+    vcf_readers = []
+    for tp_file in parsed_args.tp_vcf_files:
+        vcf_readers.append(VCFReader(vcf=tp_file, bams=[parsed_args.bam], is_fp=False))
+    for fp_file in parsed_args.fp_vcf_files:
+        vcf_readers.append(VCFReader(vcf=fp_file, bams=[parsed_args.bam], is_fp=True))
+
+    # Set up encoder for read pileup samples.
+    sample_encoder = PileupEncoder(window_size=100, max_reads=100,
+                                   layers=[PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY])
+    test_dataset = ReadPileupDataLoader(ReadPileupDataLoader.Type.TEST, vcf_readers,
+                                        batch_size=32, shuffle=False, sample_encoder=sample_encoder)
 
     model = create_model()
 
-    # Create test DAG
-    test_dataset = HDFDataLoader(parsed_args.test_hdf, batch_size=32,
-                                 shuffle=True, num_workers=parsed_args.threads,
-                                 tensor_keys=["encodings", "labels"],
-                                 tensor_dims=[('B', 'C', 'H', 'W'), tuple('B')],
-                                 tensor_neural_types=[ReadPileupNeuralType(), VariantZygosityNeuralType()])
-    encoding, vz_labels = test_dataset()
+    encoding = test_dataset()
 
+    # Execute inference
     vz = model(encoding=encoding)
 
-    nf.infer([vz], checkpoint_dir=parsed_args.model_dir, verbose=True)
+    inferred_results = nf.infer([vz], checkpoint_dir=parsed_args.model_dir, verbose=True)
+
+    # Decode inference results to labels
+    inferred_zygosity = list()
+    zyg_decoder = ZygosityLabelDecoder()
+    for tensor_batches in inferred_results:
+        for batch in tensor_batches:
+            predicted_classes = torch.argmax(batch, dim=1)
+            inferred_zygosity.extend([zyg_decoder(pred)
+                                      for pred in predicted_classes])
+
+    # Create output file for each vcf reader
+    start_reader_idx = 0
+    for vcf_reader in vcf_readers:
+        input_vcf_df = vcf_reader.dataframe
+        gt_col = "{}_GT".format(vcf_reader.samples[0])
+        assert (gt_col in input_vcf_df)
+        # Update GT column data
+        reader_len = len(input_vcf_df[gt_col])
+        input_vcf_df[gt_col] = inferred_zygosity[start_reader_idx:start_reader_idx+reader_len]
+        start_reader_idx += reader_len
+        output_path = '{}_{}.{}'.format(
+            "inferred", "".join(os.path.basename(vcf_reader.file_path).split('.')[0:-1]), 'vcf')
+        vcf_writer = VCFWriter(input_vcf_df, output_path=output_path, sample_names=vcf_reader.samples)
+        vcf_writer.write_output(input_vcf_df)
 
 
 def build_parser():
     """Build parser object with options for sample."""
-    import multiprocessing
-
-    parser = argparse.ArgumentParser(
+    args_parser = argparse.ArgumentParser(
         description="Simple model inference SNP caller based on VariantWorks.")
-    parser.add_argument("--test-hdf",
-                        help="HDF with examples for testing.",
-                        required=True)
-    parser.add_argument("-t", "--threads", type=int,
-                        help="Threads to use for parallel loading.",
-                        required=False, default=multiprocessing.cpu_count())
-    parser.add_argument("--model-dir", type=str,
-                        help="Directory for loading saved trained model checkpoints.",
-                        required=False, default="./models")
-    return parser
+    args_parser.add_argument("--tp-vcf-files", nargs="+",
+                             help="List of TP VCF files to infer.", default=[], required=True)
+    args_parser.add_argument("--fp-vcf-files", nargs="+",
+                             help="List of FP VCF files to infer.", default=[])
+    args_parser.add_argument("--bam", type=str,
+                             help="BAM file with reads.", required=True)
+    args_parser.add_argument("--model-dir", type=str,
+                             help="Directory for loading saved trained model checkpoints.",
+                             required=False, default="./models")
+    return args_parser
 
 
 if __name__ == "__main__":
diff --git a/variantworks/io/vcfio.py b/variantworks/io/vcfio.py
index f02d0ff..ffa40a1 100644
--- a/variantworks/io/vcfio.py
+++ b/variantworks/io/vcfio.py
@@ -173,6 +173,11 @@ def __init__(self,
         # Parse the VCF
         self._parallel_parse_vcf()
 
+    @property
+    def file_path(self):
+        """Get VCF file path for this VCF reader."""
+        return self._vcf
+
     @property
     def samples(self):
         """Get list of samples names in VCF file."""

From f99d77eae4ff47d6c76e30d98a47dfc5a8b0b0bf Mon Sep 17 00:00:00 2001
From: Ohad Mosafi
Date: Thu, 8 Oct 2020 07:03:27 -0700
Subject: [PATCH 3/3] Update readme with S3 links to data for the snp samples

---
 docs/source/README.rst | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/docs/source/README.rst b/docs/source/README.rst
index f7b34f8..a756d94 100644
--- a/docs/source/README.rst
+++ b/docs/source/README.rst
@@ -42,3 +42,52 @@ Getting Started
 
     # Install pre-push hooks to run tests
     ln -nfs $(readlink -f hooks/pre-push) .git/hooks/pre-push
+
+Sample Data
+---------------
+We provide sample data for exploring the *simple_snp_trainer* sample scripts. The data was derived from
+https://github.com/clara-parabricks/DL4VC/blob/master/docs/Data.md by generating variant candidates and then intersecting them with a known truth set for the variants in each region.
+
+* BAM files:
+    #. Chr 1
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr1.bam
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr1.bam.bai
+    #. Chr 10
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr10.bam
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr10.bam.bai
+    #. Chr 17
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr17.bam
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878-50x.sort.chr17.bam.bai
+
+* VCF files:
+    #. Chr 1
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr1.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr1.vcf (False Positive)
+    #. Chr 1 (subset, first 7000 variants)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr1_7000samples.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr1_7000samples.vcf (False Positive)
+    #. Chr 10
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr10.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr10.vcf (False Positive)
+    #. Chr 10 (subset, first 7000 variants)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr10_7000samples.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr10_7000samples.vcf (False Positive)
+    #. Chr 17
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr17.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr17.vcf (False Positive)
+    #. Chr 17 (subset, first 7000 variants)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_TP_chr17_7000samples.vcf (True Positive)
+        * https://variantworks.s3.us-east-2.amazonaws.com/HG001-NA12878_FP_chr17_7000samples.vcf (False Positive)
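+
+For example, to run the inference sample on the chr1 subset above (a minimal
+sketch: the flags follow ``cnn_snp_infer.py``'s argument parser, and it assumes
+the chr1 subset VCFs and BAM have been downloaded into the working directory
+and that trained model checkpoints are available under ``./models``):
+
+.. code-block:: bash
+
+    python samples/simple_snp_trainer/cnn_snp_infer.py \
+        --tp-vcf-files HG001-NA12878_TP_chr1_7000samples.vcf \
+        --fp-vcf-files HG001-NA12878_FP_chr1_7000samples.vcf \
+        --bam HG001-NA12878-50x.sort.chr1.bam \
+        --model-dir ./models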