"""Workflow for RNA-Seq Analysis
  :Author: Arvind Rasi Subramaniam
  :Date: 7 Dec 2020
"""

import pandas as pd

# Reference data shared by the rules below.
# Passed to `hisat2 -x` by the align_genome rule.
hisat2_index_location = "/fh/fast/subramaniam_a/db/rasi/hisat2/grch38_tran/"
# Passed as the third argument to calculate_gene_counts.R by get_gene_counts.
annotations_file = "/fh/fast/subramaniam_a/db/rasi/genomes/human/hg38/gencode/annotations/gencode.v32.annotation.gff3.gz"

# Tab-separated sample sheet; the rules read its `sample_name` and `srr`
# columns. Echoed so the table appears in the workflow's console output.
sample_annotations = pd.read_csv("sample_annotations.tsv", sep="\t")
print(sample_annotations)


rule all:
  """Final targets of the workflow: one gene-count table per sample"""
  input:
    # One target per entry in the `sample_name` column of the sample sheet.
    expand(
      "data/gene_counts/{sample_name}.tsv.gz",
      sample_name=sample_annotations["sample_name"].tolist(),
    ),


rule get_fastq:
  """Download the FASTQ file for one sample from SRA"""
  params:
    # SRA run accession for this sample, looked up in the sample sheet.
    # `.item()` raises unless exactly one row matches the sample name.
    srr = lambda wildcards: sample_annotations.loc[sample_annotations['sample_name'] == wildcards.sample_name, 'srr'].item(),
  output:
    'data/fastq/{sample_name}.fastq.gz'
  log:
    'data/fastq/{sample_name}.log'
  # fastq-dump writes <accession>.fastq.gz into data/fastq/, so the file is
  # renamed afterwards to the sample-based name that later rules expect.
  shell:
    """
    module load SRA-Toolkit/2.10.4-ubuntu64
    fastq-dump {params.srr} --gzip -O data/fastq/ &> {log}
    mv data/fastq/{params.srr}.fastq.gz {output}
    """


rule align_genome:
  """Align the reads of one sample with HISAT2 and write SAM alignments"""
  input:
    'data/fastq/{sample_name}.fastq.gz'
  params:
    # Value handed to `hisat2 -x`.
    hisat2_index = hisat2_index_location
  output:
    # temp(): Snakemake deletes the SAM once the rules consuming it are done.
    temp('data/alignments/{sample_name}.sam')
  log:
    # Must be declared as `log`, not as a second `output`: the command below
    # redirects stderr to {log}, and {output} has to expand to the SAM path
    # alone because it is passed to `-S`.
    'data/alignments/{sample_name}.log'
  threads:
    8
  shell:
    """
    module load HISAT2/2.1.0-foss-2018b;
    hisat2 \
    -k 2 \
    --un /dev/null \
    --threads {threads} \
    --no-spliced-alignment \
    --rna-strandness F \
    --no-unal \
    -x {params.hisat2_index} \
    -U {input} \
    -S {output} \
    2> {log}
    """


rule sort_and_index_bam:
  """Turn the SAM alignments of one sample into a sorted, indexed BAM file"""
  input:
    "data/alignments/{sample_name}.sam"
  log:
    "data/alignments/{sample_name}.samtools.log"
  output:
    # Intermediate BAM; temp() lets Snakemake delete it when no longer needed.
    unsorted_bam = temp("data/alignments/{sample_name}.unsorted.bam"),
    bam = "data/alignments/{sample_name}.bam",
    bam_index = "data/alignments/{sample_name}.bam.bai"
  # No `threads:` directive is declared, so {threads} expands to Snakemake's
  # default of 1. The first command truncates the log; the later ones append.
  shell:
    """
    module load SAMtools/1.10-GCCcore-8.3.0;
    samtools view -@ {threads} -b {input} > {output.unsorted_bam} 2> {log}
    samtools sort -@ {threads} {output.unsorted_bam} > {output.bam} 2>> {log}
    samtools index -@ {threads} {output.bam} 2>> {log}
    """


rule get_alignment_stats:
  """Compute basic alignment statistics for one sample with samtools stats"""
  input:
    "data/alignments/{sample_name}.bam"
  log:
    # Same path as the log of sort_and_index_bam; stderr is appended (2>>)
    # rather than overwriting what that rule wrote.
    "data/alignments/{sample_name}.samtools.log"
  output:
    "data/alignment_stats/{sample_name}.stats.txt"
  shell:
    """
    module load SAMtools/1.10-GCCcore-8.3.0;
    samtools stats {input} > {output} 2>> {log}
    """


rule get_gene_counts:
  """Count the reads falling within each gene for one sample"""
  input:
    "data/alignments/{sample_name}.bam"
  output:
    "data/gene_counts/{sample_name}.tsv.gz"
  log:
    "data/gene_counts/{sample_name}.log"
  params:
    # Annotation file handed to the R script as its third argument.
    annotations = annotations_file
  # Script arguments, in order: input BAM, output table, annotation file.
  # Only stderr is captured in the log.
  shell:
    """
    Rscript calculate_gene_counts.R {input} {output} {params.annotations} 2> {log}
    """
