#!/usr/bin/env python

# mmfind

__version__ = "v1.9"
__version_date__ = "20140825"
__author__ = "Thomas Rosleff Soerensen"
__usage__ = """
  mmfind [options] (<multiple_fasta_file>|<directory_with_fasta_files>)

  Evaluates an alignment in multiple FASTA format for mismatches.
  Quality scores are considered if a file with scores is supplied
  (the file should have the same name but a "qual"-extension;
  example: test.fa and test.qual). Scores should be in FASTA format
  and will be mapped on the aligned sequences. mmfind is a command-
  line tool written in Python. It is tested with Python 2.6, 2.7 
  and 3.1.

  OPTIONS:

     -h     (help:) display this message.

     FILTERING:

     -a <integer>
            (aligned:) minimal number of aligned sequences (default:2).

     -L <integer>
            (length:) minimal alignment length (default:200).

     -p <integer>
            (polymorphism cutoff:) ignore alignments with percent mismatches per 
            length exceeding given value (default:3).

     -l <integer>
            (length:) maximal length of mismatch to be reported (default:3).

     -b <integer>
            (border distance:) minimal distance of a mismatch to the alignment 
            ends to be reported (default:80).

     -s <integer>
            (score:) minimal average score of the bases of a mismatch (default:20).

     -n <integer>
            (neighborhood scores:) minimal average score of the 10 neighboring bases
            of a mismatch (5 upstream, 5 downstream) (default:15).

     -A
            (all mismatches:) prevent filtering and display all mismatches (default:
            use default filtering, see above).

     OUTPUT OPTIONS:

     -o <basename_of_outfiles>
            (outfile:) files to which the reports should be appended (default:
            <basename_of_infile>.report.csv and <basename_of_infile>.mismatches.csv).

     -d
            (description:) write a descriptive headline to the report files (default:
            no headline).

     -c
            (clustal files:) write the alignments in clustal-like format
            (default: no).

     OTHER OPTIONS:

     -v     (verbose:) display processing information.


Script version: %s (%s)
Author: %s
""" % (__version__, __version_date__, __author__)

__todo__ = """
- minimal border distance according to trim coordinates
- type-filter
- status-filter
- verbose: messages
- -A-option: print out all mismatches, no filtering
"""

#################
# CONFIGURATION #
#################

# default values for commandline options
# (keys map 1:1 to the getopt flags documented in __usage__)
options = {
	'verbose': 0,          # -v: print processing information
	'ref_seq': None,       # id of an optional reference sequence
	'min_alnseq': 2,       # -a: minimal number of aligned sequences
	'min_alnlen': 200,     # -L: minimal alignment length
	'pm_cutoff': 3,        # -p: max percent mismatches per length
	'basename': '',        # basename of the current input file
	'basename_out': '',    # -o: basename for the report files
	'report_headline': False,  # -d: write a descriptive headline
	'max_pms_length': 3,   # -l: maximal reported mismatch length
	'min_border_distance': 80,  # -b: min distance to alignment ends
	'min_pms_score': 20,   # -s: min average score of mismatch bases
	'min_nbh_score': 15,   # -n: min average score of 10 neighbor bases
	'max_N': 100,          # maximal N count (not read in this chunk)
	'fasta_ext': ['fa', 'fasta', 'fas', 'mfa'],  # accepted FASTA extensions
	'clustal_out': False,  # -c: also write clustal-like .aln files
	}


###########
# IMPORTS #
###########

import sys, os, re, string
from types import LambdaType


###############################
# CONSTANTS, GLOBAL VARIABLES #
###############################

# error messages
__error_msg__ = "ERROR %02d: %s\n"
# message used by Record.__setitem__ for unsupported keys
rkey_error = "key '%s' not supported for '%s'"
# message used by Record.set_rkeys for inconsistent key lists
set_rkey_error = "rkeys doesn't match with keys"

# defined errors
# numeric code -> %-format template passed to exit_on_error()
__errorcodes__ = {
	0: "unrecognized error",
	1: "'%s' is not a file",
	2: "wrong or missing option for '%s': %s",
	13: "need filename to process",
	100: "not yet supported"
	}

# sequence processing

# base character -> bit label; a/c/g/t are single bits, IUPAC ambiguity
# codes are bit unions thereof, 16 = gap, 0 = "no data" ('.'), -1 = unknown
base2label = {
	'.': 0,
	'a': 1, 'c': 2, 'g': 4, 't': 8,
	'm': 3, 'r': 5, 'w': 9, 's': 6,
	'y': 10, 'k': 12, 'v': 7, 'h': 11,
	'd': 13, 'b': 14, 'n': 15, '-': 16,
	'?': -1, 'x': 15,
	}

# bit label -> IUPAC ambiguity character (inverse of base2label)
label2ambigbase = {
	0: '.',
	1: 'a', 2: 'c', 4: 'g', 8: 't',
	3: 'm', 5: 'r', 9: 'w', 6: 's',
	10: 'y', 12: 'k', 7: 'v', 11: 'h',
	13: 'd', 14: 'b', 15: 'n', 16: '-',
	-1: '?'
	}

# bit label -> plain base; every ambiguous label collapses to 'n'
label2base = {
	0: '.',
	1: 'a', 2: 'c', 4: 'g', 8: 't',
	3: 'n', 5: 'n', 9: 'n', 6: 'n',
	10: 'n', 12: 'n', 7: 'n', 11: 'n',
	13: 'n', 14: 'n', 15: 'n', 16: '-',
	-1: '?'
	}


##################
### EXCEPTIONS ###
##################

class Error(Exception):
    """Base exception of this module; renders as its bare message text."""

    def __init__(self, msg=''):
        Exception.__init__(self, msg)
        self._msg = msg

    def __repr__(self):
        # both repr() and str() show only the stored message
        return self._msg

    __str__ = __repr__

class MismatchError(Error):
    # common base for all mismatch-classification errors below
    pass

class NoMismatchError(MismatchError):
    # a candidate column turned out not to be polymorphic
    pass

class MixedMismatchError(MismatchError):
    # mismatch mixes SNP and indel character
    pass

class MismatchLengthError(MismatchError):
    # mismatch has an unexpected length
    pass

class MismatchExpandError(MismatchError):
    # a mismatch could not be merged into an adjacent one
    pass


###################
#### FUNCTIONS ####
###################

def exit_on_error(number, *args):
	"""Print error *number* (formatted with *args*) to stderr and exit.

	The message template is looked up in __errorcodes__; the process is
	terminated with a non-zero status so callers/shell pipelines can
	detect the failure (the original bare sys.exit() reported success).
	"""
	sys.stderr.write("mmfind_version=%s\n" % __version__)
	if args:
		error_txt = __errorcodes__[number] % tuple(args)
	else:
		error_txt = __errorcodes__[number]
	sys.stderr.write(__error_msg__ % (number, error_txt))

	sys.exit(1)

def avg(l):
	"""Arithmetic mean of the numbers in *l*.

	Uses the built-in sum() instead of reduce() (reduce is no longer a
	builtin on Python 3).  Division semantics follow the interpreter:
	on Python 2 an all-int input floor-divides, as before.
	"""
	return sum(l) / len(l)

## alignment functions

def firstbase(seq):
    """Return the index of the first base/gap character in *seq*.

    "Base/gap" means anything matched by ambigcharsp, i.e. any character
    except the '.' no-data marker.  Raises ValueError when no such
    character exists.  (The original raised StandardError, which does not
    exist on Python 3; ValueError is a StandardError subclass on Python 2,
    so existing except-clauses still match.)
    """
    m = ambigcharsp.search(seq)
    if m:
        return m.start()
    raise ValueError("no bases in sequence")

def has_gap_ends(l):
    """Return 1 if any entry of *l* starts or ends with '.' or '-', else 0.

    A non-empty plain string is treated as a one-item list.
    """
    endchars = ('.', '-')
    if l and type(l) == str:
        l = [l]
    for entry in l:
        if not entry:
            continue
        if entry[0] in endchars or entry[-1] in endchars:
            return 1
    return 0

# NOTE: these helpers originally relied on Python 2's list-returning
# map()/filter() and on O(n^2) slicing recursion; comprehensions and
# any()/all() give the same results in boolean/list contexts on both
# Python 2 and 3.

# all items of a list are of equal length
is_eqlen = lambda l: len(l) < 2 or all(len(i) == len(l[0]) for i in l)

# comparison of ambiguity characters: every *consecutive* pair shares
# at least one bit (pairwise, not a global common bit)
is_bitwise_homogen = lambda l: all(a & b for a, b in zip(l, l[1:]))

# comparison when ambiguity characters indicate heterocygosity:
# all items identical
is_equal = lambda l: all(a == b for a, b in zip(l, l[1:]))

# list contains gap character
has_gapchar_list = lambda l: any('-' in i for i in l)

# return a gap free version of list l
ungap_list = lambda l: [allgapcharsp.sub('', i) for i in l]

# count items of a list that still contain base chars after ungapping
base_count_list = lambda l: sum(1 for i in ungap_list(l) if i)

# list contains no bases
no_base_list = lambda l: not any(ungap_list(l))

# return a version of list l where all items of list rl are removed
remove_all = lambda l, rl: [i for i in l if i not in rl]

def ltgaps2dots(seq):
	"""Replace leading and trailing runs of '-' in *seq* with '.'.

	Interior gaps are kept as-is; an all-gap string becomes all dots.
	"""
	core = seq.strip('-')
	if not core:
		# nothing but gaps (or empty input)
		return '.' * len(seq)
	lead = len(seq) - len(seq.lstrip('-'))
	trail = len(seq) - len(seq.rstrip('-'))
	return '.' * lead + seq[lead:len(seq) - trail] + '.' * trail

def fastafile2dict(infile):
	"""Read a (multiple) FASTA file.

	Returns (d, ids): d maps sequence id (first word of each header line)
	to the sequence with all whitespace removed; ids keeps the ids in
	file order.  The file handle is closed deterministically ('with'),
	and regex patterns are raw strings.
	"""
	d, ids = {}, []
	with open(infile) as fh:
		text = fh.read()
	for entry in text.split('>')[1:]:
		# split header line from sequence body at the first newline/CR
		parts = re.split(r'[\n\r]', entry, 1)
		k = parts[0].split(' ', 1)[0]
		d[k] = re.sub(r'\s', '', parts[1])
		ids.append(k)
	return d, ids

def qualfile2dict(infile):
	"""Read a FASTA-style quality file: id -> list of integer scores.

	The scores must stay an indexable list: Alignment.map_scores()
	indexes them, so a Python-3 map() iterator would break there --
	a list comprehension is used instead.  The handle is closed via
	'with' and regex patterns are raw strings.
	"""
	d = {}
	with open(infile) as fh:
		text = fh.read()
	for entry in text.split('>')[1:]:
		parts = re.split(r'[\n\r]', entry, 1)
		k = parts[0].split(' ', 1)[0]
		d[k] = [int(v)
				for v in re.sub(r'[\n\r]', ' ', parts[1].strip()).split()]
	return d

def consensus(*seq, **kwargs):
	"""Return a consensus string for two or more equal-length sequences.

	Keyword options:
	  use_ambigchars (True)        -- emit IUPAC ambiguity codes; when
	                                  False, a column containing an
	                                  ambiguous input base yields '?'
	  remove_gaponly_sites (False) -- drop gap-only / no-data columns
	  clip_single_sequence_ends (False) -- suppress leading/trailing
	                                  columns covered by one sequence only
	  indelchar ('I')              -- emitted where base and gap co-occur

	Case encodes coverage: lower case = one sequence, upper case = 2+.
	Raises ValueError for fewer than two sequences or unequal lengths.
	"""

	# check sequence arguments

	if len(seq) < 2:
		raise ValueError('sequences should be more than one')

	seqlen = len(seq[0])
	# BUGFIX: the original did `raise('...')` -- raising a plain string is
	# itself a TypeError at runtime -- and used Python-2 list-filter
	# semantics for the length check.
	if any(len(s) != seqlen for s in seq[1:]):
		raise ValueError('sequences should be of equal length')

	# setting processing options

	use_ambigchars = kwargs.get('use_ambigchars', True)
	if use_ambigchars:
		f2b = label2ambigbase
	else:
		f2b = label2base
	remove_gaponly_sites = kwargs.get('remove_gaponly_sites', False)
	clip_single_sequence_ends = kwargs.get('clip_single_sequence_ends', False)
	indelchar = kwargs.get('indelchar', 'I')
	is_end = 1 # status flag for evaluation of clipping ends
	cache = []
	cons = []

	# walk through sequences, column by column
	for i in range(seqlen):

		# counting valid bases per position
		c = 0
		# consensus base at position (bitwise OR of per-sequence labels)
		base = 0

		for s in seq:

			label = base2label[s[i].lower()]
			if not use_ambigchars and label not in (0, 1, 2, 4, 8, 16):
				# original contains ambiguity character
				base = -1
				break
			if s[i] == '.':
				# no data available for original sequence
				continue
			base |= label
			c += 1

		if base in (0, 16) and remove_gaponly_sites:
			# only gap or "no data available"
			continue
		elif base > 16:
			# base and gap detected -> indel
			cons.append(indelchar)
		else:
			if c == 1:
				# only data from one sequence available
				if clip_single_sequence_ends:
					if not is_end:
						is_end = 1
					cache.append(f2b[base].lower())
				else:
					cons.append(f2b[base].lower())
			else:
				is_end *= 0
				if cache:
					# flush cached single-coverage bases; drop them only
					# when they form a clipped leading end
					if cons or (not clip_single_sequence_ends):
						cons += cache
					cache = []
				cons.append(f2b[base].upper())

	return "".join(cons)

def process_file(infile, options):
	"""Evaluate one aligned multi-FASTA file for mismatches.

	Returns a dict with:
	  'ar' -- MismatchEvaluationRecordCSV summary of the alignment
	          (a negative 'error' value marks a skipped file)
	  'mr' -- list of MismatchRecordCSV, one per mismatch that passed
	          the filter thresholds in *options*
	"""

	data = {'mr': []}

	dd, ids = fastafile2dict(infile)

	# quality scores are optional: look for <dirname>/<basename>.qual
	use_qual = False
	fn = "%s.qual" % os.path.join(options['dirname'], options['basename'])
	if options['verbose']:
		print "Searching for '%s'" % fn ####
	if os.path.isfile(fn):
		use_qual = True
	
	if use_qual:
		if options['verbose']:
			print "Reading qual-file '%s'." % fn
		dq = qualfile2dict(fn)
	
	d = {
			'id': options['basename'],
			'aln_length': 0,
			'mismatches': 0,
			'snp_count': 0,
			'error': 0, 
			}
	ar = MismatchEvaluationRecordCSV(d)
	
	# NOTE(review): srcseq_id is assigned but never used below
	srcseq_id = ''
	if not options['ref_seq']:
		ignore_ref = True
	else:
		if options['verbose']:
			print "Using reference sequence '%s'" % options['ref_seq']
		ignore_ref = False

	aln = Alignment(
			dd, 
			options['ref_seq'],
			complete=1, 
			leftpos=1, 
			flanking=0,
			with_labels=True,
			first_seq_gaps_as_introns=False,
			heterozygous=False,
			)

	coor = aln.trimcoordinates(options['min_alnseq'])

	# pre-filters: equal lengths, enough sequences, long enough alignment
	if not is_eqlen(dd.values()):
		ar['error'] = -42
	elif len(aln.keys()) < int(options['min_alnseq']):
		ar['error'] = -602
	elif not coor or (coor[1] - coor[0] + 1) < options['min_alnlen']:
		ar['error'] = -605
	# NOTE(review): Record.__setitem__ stringifies values, so after the
	# assignments above ar['error'] is e.g. '-42', and on Python 2
	# '-42' < 0 is False (str compares greater than int) -- this early
	# return may never fire; verify intended behavior.
	if ar['error'] < 0:
		data['ar'] = ar
		return data

	if use_qual:
		aln.map_scores(dq)

	#aln.pprint(with_scores=1) ####
	
	ar['aln_length'] = len(aln)
	# NOTE(review): the next line is a no-op expression statement --
	# presumably leftover debug output; confirm before removing
	len(aln.keys()), aln.leftpos, aln.rightpos
	
	number_of_aligned_bases = aln.aligned_bases_count()
	aln.evaluate_matching_bases()
	
	aln.evaluate_mismatches()
	# mismatch columns = alignment length minus non-polymorphic ('.') columns
	ar['mismatches'] = len(aln) - aln.associated_sequences['mismatches'].count('.')
	mm = aln.get_mismatches(100)
	
	if options['verbose']:
		print "Found %s aligned bases." % number_of_aligned_bases
		print "Detected %s mismatches." % ar['mismatches']
	#mm = aln.get_mismatches(100)
	
	if options['clustal_out']:
		fo = open("%s.aln" % os.path.join(options['dirname'], options['basename']), 'w')
		fo.write("%s\n" % aln.pprint(50, 1, 1, 1, 1))
		fo.close()

	# calculate false observations
	#false_observations = aln.associated_sequences['mismatches'].count('!')
	#ar['false_observations'] = false_observations

	mismatches = 0
	type_count = {}      # mismatch events per type code
	type_count_b = {}    # mismatch bases per type code
	
	for m in mm:
		start = m.i
		end = m.i + len(m)
		# skip mismatches flagged '!' (failed check_mismatches) or
		# overlapping missing data ('.')
		if '!' in aln.associated_sequences['mismatches'][start:end]:
			continue
		if '.' in ''.join(m.values()):
			continue
		# Sequence.ltgaps2dots( *seq )
		consargs = {
				    'use_ambigchars': False,
				    'indelchar': 'N',
				    'remove_gaponly_sites': True
				    }
		# flanking consensus strings left/right of the mismatch
		consl = consensus(*m.left_bases.values(), **consargs)
		consr = consensus(*m.right_bases.values(), **consargs)
		mbd = min([len(consl), len(consr)])
		# scoring polymorphisms
		# NOTE(review): without a qual file these stay '' and the
		# Python-2 comparisons '' >= <int> below are always True, i.e.
		# score filters are effectively disabled -- verify intended
		pm_sc_avg, surr_sc_avg  = '', ''
		if use_qual:
			pm_scs, surr_scs  = [], []
			for name in m.keys():
				pm_scs.append(avg(m.scores[name]))
				surr_scs.append(avg(m.left_bases.scores[name][-5:]+m.right_bases.scores[name][:5]))
			pm_sc_avg, surr_sc_avg = min(pm_scs), min(surr_scs)
		# polymorphism status (bitmask: 1=length ok, 2=border ok, 4=score ok)
		m_status = 0
		if len(m) <= options['max_pms_length']:
			m_status |= 1
		if mbd > options['min_border_distance']:
			m_status |= 2
		if surr_sc_avg >= options['min_nbh_score'] and pm_sc_avg >= options['min_pms_score']:
			m_status |= 4
		mr = MismatchRecordCSV(
			{
				'id': options['basename'],
				'type': {'s': 'SNV', 'S': 'MNV', 'i': 'InDel', 'I': 'InDel', 'M': 'Mixed'}[m.type],
				'label': aln.associated_sequences['mismatches'][start:end],
				'startpos_aln': start + 1,
				'length': str(len(m)),
				'consensus': consl + '[' + '/'.join(m.values()) + ']' + consr,
				'min_border_distance': mbd,
				'polymorphism_avg_score': str(pm_sc_avg),
				'neighborhood_avg_score': str(surr_sc_avg),
				'n_count': consl.lower().count('n') + consr.lower().count('n'),
				'status': m_status,
				}
			)
		# filtering
		if (
			len(m) <= options['max_pms_length']
			and mbd >= options['min_border_distance']
			and pm_sc_avg >= options['min_pms_score']
			and surr_sc_avg >= options['min_nbh_score']
			):
			if m.type in type_count:
				type_count[m.type] += 1
				type_count_b[m.type] += int(mr['length'])
			else:
				type_count[m.type] = 1
				type_count_b[m.type] = int(mr['length'])
			mismatches += len(m)
			data['mr'].append(mr)
	ar['snp_count'] = type_count.get('s', '0')
	ar['mb_snp_count'] = type_count.get('S', '0')
	ar['indel_count'] = type_count.get('i', '0')
	ar['mb_indel_count'] = type_count.get('I', '0')
	ar['snp_b_count'] = type_count_b.get('s', '0')
	ar['mb_snp_b_count'] = type_count_b.get('S', '0')
	ar['indel_b_count'] = type_count_b.get('i', '0')
	ar['mb_indel_b_count'] = type_count_b.get('I', '0')
	# percent mismatching columns over aligned base columns
	poly_percent = 100 * int(ar['mismatches']) / float(number_of_aligned_bases)
	ar['poly_percent'] = "% 5.2f" % poly_percent

	#if ar['poly_fraction'] > (int(options['pm_cutoff']) / 100.):
	if poly_percent > options['pm_cutoff']:
		# too polymorphic -> discard all mismatch records (-p cutoff)
		ar['error'] = -656
		data['mr'] = []

	else:
		ar['status'] = 1

	data['ar'] = ar
	return data


#################
#### CLASSES ####
#################

class Record(dict):
	"""A dict-like record with a fixed, ordered set of supported keys.

	Values assigned through __setitem__ are stringified; reading an
	unknown/unset key returns ''.  Subclasses declare rkeys (supported
	keys, in output order) and skeys (pretty column headers).
	"""

	# supported keys
	rkeys = []
	skeys = []

	def __init__(self, d=None):
		dict.__init__(self, d or {})
		# copy as a list of strings: on Python 3 map() would return a
		# one-shot iterator, breaking the repeated `k in self.rkeys` tests
		self.rkeys = [str(k) for k in self.__class__.rkeys]

	def __setitem__(self, k, v):
		"""Store str(v) under str(k); only supported keys are accepted."""
		k, v = str(k), str(v)
		if k in self.rkeys:
			dict.__setitem__(self, k, v)
		else:
			raise KeyError(rkey_error % (k, self.__class__))

	def __getitem__(self, k):
		# unknown or unset keys read as the empty string
		return dict.get(self, str(k), '')

	def __repr__(self):
		return self.record()

	def set_rkeys(self, rkeys):
		"""Restrict/reorder the active keys; each must be supported.

		BUGFIX: the original reduce() started from 1 and OR-ed the
		membership tests, so the guard always passed; now every key
		must already be in self.rkeys, otherwise KeyError is raised.
		"""
		rkeys = [str(k) for k in rkeys]
		if all(k in self.rkeys for k in rkeys):
			self.rkeys = rkeys
		else:
			raise KeyError(set_rkey_error)

	def keys(self):
		return self.rkeys

	def values(self):
		return [self[k] for k in self.rkeys]

	def items(self):
		return [(k, self[k]) for k in self.rkeys]

	def record(self, record_delim="----\n"):
		"""'key=value' lines followed by a record delimiter."""
		t = ["%s=%s" % (k, self[k]) for k in self.rkeys]
		t.append(record_delim)
		return "\n".join(t)

	def header(self, delim='\t', enclose='"'):
		# BUGFIX: delim was accepted but ignored (hard-coded tab);
		# 'enclose' is still unused, kept for interface compatibility
		return "%s\n" % delim.join(self.keys())

	def header2(self, delim='\t', enclose='"'):
		"""Headline built from the pretty skeys column names."""
		return "%s\n" % delim.join(self.__class__.skeys)

	def csv(self, delim='\t', enclose='"'):
		"""One delimited data line in rkeys order."""
		return "%s\n" % delim.join(map(str, self.values()))

	def xml(self):
		# not implemented
		pass

class MismatchRecordCSV(Record):
	# one output row per reported mismatch (see process_file);
	# skeys are the pretty column headers, parallel to rkeys

	rkeys = [
		'id',
		'type',
		'startpos_aln',
		'label',
		'length',
		'consensus',
		'min_border_distance',
		'polymorphism_avg_score',
		'neighborhood_avg_score',
		'n_count',
		'status',
		]
	skeys = [
		'ID',
		'TYPE',
		'ALN_S',
		'PS_CONS',
		'PS_LEN',
		'CONS',
		'MINBORD',
		'PSAVGSC',
		'NGAVGSC',
		'N_COUNT',
		'STATUS',
		]

class MismatchEvaluationRecordCSV(Record):
	# one summary row per processed alignment file (see process_file);
	# skeys are the pretty column headers, parallel to rkeys

	rkeys = [
		'id',
		'aln_length',
		'mismatches',
		'snp_count',
		'snp_b_count',
		'mb_snp_count',
		'mb_snp_b_count',
		'indel_count',
		'indel_b_count',
		'mb_indel_count',
		'mb_indel_b_count',
		'poly_percent',
		'status',
		'error',
		]
	skeys = [
		'ID',
		'ALN_LEN',
		'MISM',
		'SNVS',
		'SNV_B',
		'MNVS',
		'MNV_B',
		'S_IND',
		'S_IND_B',
		'M_IND',
		'M_IND_B',
		'P_PERC',
		'STATUS',
		'ERROR',
		]

class TWDict(dict):
	"""Two-way dictionary: forward key->value plus a reverse value->key map.

	The value->key relation must be one-to-one.  Supplementary keys
	(*supp*) are forward-only aliases and never enter the reverse map.
	"""

	def __init__(self, d=None, supp=None):
		"""supp are supplementary keys referencing a value.
		These will not be backtranslated."""
		# avoid shared mutable default arguments
		if d is None:
			d = {}
		dict.__init__(self, d)
		self.r = {}
		for k, v in d.items():
			self._set_reverse(k, v)
		# dict.update() bypasses our __setitem__, so supp entries are
		# deliberately absent from the reverse map
		self.update(supp or {})

	def _set_reverse(self, k, v):
		# register v -> k; a duplicate value would make backtranslation
		# ambiguous
		if v not in self.r:
			self.r[v] = k
		else:
			raise ValueError("key-value-pairs not one-to-one")

	def key2value(self, k):
		"""Forward lookup (plain dict access)."""
		return dict.__getitem__(self, k)

	def value2key(self, v):
		"""Reverse lookup; KeyError for supp-only or unknown values."""
		return self.r[v]

	def __setitem__(self, k, v):
		dict.__setitem__(self, k, v)
		self._set_reverse(k, v)

# two-way translation table: base character <-> bit label
# ('x' is a forward-only alias for 'n' = 15)
trd = TWDict({'a': 1, 'c': 2, 'g': 4, 't': 8,
              'm': 3, 'r': 5, 'w': 9, 's': 6,
              'y': 10, 'k': 12, 'v': 7, 'h': 11,
              'd': 13, 'b': 14, 'n': 15, '-': 16},
             {'x': 15})
# every recognised base/gap character (dict keys, including 'x')
ambigchars = trd.keys()
# matches any base or gap character, i.e. anything except '.'
# NOTE(review): '-' is joined into the character class at its dict-key
# position; if it lands between two other characters it forms a range --
# verify on interpreters with different dict ordering.
ambigcharsp = re.compile("[%s]" % ''.join(ambigchars), re.I)
# escaped gap characters (unused here?) and their compiled patterns
allgapchars = ['\.', '\-']
allgapcharsp = re.compile("[\.\-]")
# matches a gap/dot at either end of a string
gapendp = re.compile("^[\.\-]|[\.\-]$")

class Dict(dict):

	# thin alias serving as the common base of Alignment and Bases;
	# adds no behavior of its own
	pass

class Alignment(Dict):
	"""A multiple alignment: maps sequence id -> aligned sequence string.

	Terminal gap runs are normalised to '.' ("no data") unless built with
	nodots=1.  Optional extras travel with the dict: per-base scores
	(self.scores), annotation strings (self.associated_sequences with
	keys 'mismatches'/'matches'/'coding_state') and a reference sequence
	id (self.seq_ref) used for biological coordinates.
	"""

	# when true (Mismatch subclass) sequences are stored verbatim
	nodots = 0

	# one-letter mismatch type codes used by Mismatch objects
	mismatch_types = {
		's': 'SNV', # single-nucleotide variant
		'S': 'MNV', # multi-nucleotide variant
		'i': 'InDel',
		'I': 'InDel',
		'M': 'Mixed',
		}

	def __init__(self, d={}, ref='', **kwargs):
		# NOTE(review): mutable default d={} is only read, never mutated,
		# so it is harmless -- but an empty d skips all attribute setup.

		dict.__init__(self, {})

		if d:
			d = dict(d)
			if self.__class__ == Mismatch:
				self.left_bases = kwargs.get('left_bases', {})
				self.right_bases = kwargs.get('right_bases', {})
				self.i = kwargs.get('i', 0)
			self.comment = ''
			self._check_length(d, kwargs)
			self.associated_sequences = kwargs.get('associated_sequences', {})
			self.scores = kwargs.get('scores', {})
			self.left_scores = kwargs.get('left_scores', {})
			self.right_scores = kwargs.get('right_scores', {})
			if kwargs.get('nodots', self.__class__.nodots):
				self._store_sequences(d)
			else:
				self._ends2dots(d)
			if ref:
				self.seq_ref = ref
				self._calculate_coordinates(kwargs)
			else:
				self.seq_ref = None
				self.flanking = 0
				self.leftpos = kwargs.get('leftpos', 1)
				self.rightpos = len(d.values()[0])
			# other attributes
			self.with_labels = kwargs.get('with_labels', False)
			self.first_seq_gaps_as_introns = kwargs.get('first_seq_gaps_as_introns', False)
			self.heterozygous = kwargs.get('heterozygous', False)
			if self.__class__ == Mismatch:
				self._check_mismatch()

	# FORMATTING (INTERNAL)

	def _ends2dots(self, d):
		# store sequences with terminal gap runs converted to '.'
		for k in d.keys():
			self[k] = ltgaps2dots(d[k])

	def _store_sequences(self, d):
		# store sequences verbatim (nodots mode)
		for k in d.keys():
			self[k] = d[k]

	def _check_length(self, d, kwargs):
		# with complete=1, right-pad shorter sequences with '.';
		# otherwise unequal lengths are fatal
		if not is_eqlen(d.values()):
			if kwargs.get('complete', 0):
				maxlen = max(map(len, d.values()))
				for k in d.keys():
					d[k] = d[k] + (maxlen - len(d[k])) * '.'
			else:
				raise StandardError("lengths differ")

	def __len__(self):
		# alignment length = length of any member sequence; None if empty
		try:
			return len(self.values()[0])
		except IndexError:
			return None

	def aligned_bases_count(self, min_aligned=2):
		"""Number of columns covered by at least min_aligned real bases."""
		n = 0
		for i in range(len(self)):
			if base_count_list(self.item(i).values()) >= min_aligned:
				n += 1
		return n

	# BIOLOGICAL INDEXING

	def _calculate_coordinates(self, kwargs):
		# derive leftpos/rightpos (and the flanking flag) from seq_ref
		try:
			if has_gap_ends(self[self.seq_ref]):
				self.flanking = 1
				self.leftpos = kwargs.get('leftpos', 1) \
								   + kwargs.get('flanking', 0) - 1
				self.rightpos = self._calculate_pos(len(self)-1)[2]
			else:
				self.flanking = 0
				self.leftpos = kwargs.get('leftpos', 1) \
								   + kwargs.get('flanking', 0)
				self.rightpos = self._calculate_pos(len(self)-1)[0]
		except IndexError:
			self.flanking = 1
			self.leftpos = kwargs.get('leftpos', 1) \
							   + kwargs.get('flanking', 0) - 1
			self.rightpos = self.leftpos

	def _calculate_pos(self, i):
		"""Map column index *i* to (pos, flanking_left, flanking_right).

		NOTE(review): calls ungap(), which is not defined in this chunk
		(only ungap_list is) -- presumably a module-level helper defined
		elsewhere; verify it exists.
		"""
		if self.seq_ref:
			bases = len(ungap(self[self.seq_ref][:i+1]))
			if ambigcharsp.search(self[self.seq_ref][i]):
				pos = self.leftpos + bases - 1 + self.flanking
			else:
				pos = None
			try:
				flpos = pos - 1
				frpos = pos + 1
			except:
				# pos is None (gap column): fall back to flanking coords
				flpos = self.leftpos + bases - 1 + self.flanking
				frpos = flpos + 1
		else:
			pos = self.leftpos + i - 1
			flpos, frpos = None, None
		return (pos, flpos, frpos)

	# STRING REPRESENTATIONS

	def __str__(self):
		# reference sequence (if any) first, remaining ids sorted
		tmpk = self.keys()
		tmpk.sort()
		if self.seq_ref:
			tmpk.remove(self.seq_ref)
			tmpk = [self.seq_ref] + tmpk
		lines = map(lambda k: "%12s %s" % (str(k)[:12], self[k]), tmpk)
		return '\n'.join(lines)

	def pprint(self, width=60, with_mm=0, with_num=0, with_match=0, with_scores=0):
		"""Clustal-like pretty print in blocks of *width* columns.

		Optional rows: mismatch/coding annotation (with_mm), match stars
		(with_match), first digit of each base score (with_scores), and
		per-block base coordinates (with_num).
		"""
		t = []
		tmpk = self.keys()
		tmpk.sort()
		if self.seq_ref:
			tmpk.remove(self.seq_ref)
			tmpk = [self.seq_ref] + tmpk
		if with_num:
			s, e = 0, 0
			coor = dict(zip(tmpk, len(tmpk)*[0]))
			coor_new = coor.copy()
			for i in range(0, len(self), width):
				sl = self.slice(i, i+width)
				for k in tmpk:
					coor_new[k] +=  base_count_list(sl[k])
				# NOTE(review): uses loop variable k after the loop, i.e.
				# the slice length of the last id -- confirm intended
				e += len(sl[k])
				t.append("%12s % 6s %s % 6s" % ('', s+1, ' '*len(sl[k]), e))
				lines = map(lambda k: "%12s % 6s %s % 6s" % (str(k)[:12], [coor[k]+1, ''][not base_count_list(sl[k])], sl[k], [coor_new[k], ''][not base_count_list(sl[k])]), tmpk)
				t += lines
				if with_mm and self.associated_sequences.has_key('mismatches'):
					t.append("%12s        %s" % ('mismatches', self.associated_sequences['mismatches'][i:i+width]))
				if with_mm and self.associated_sequences.has_key('coding_state'):
					t.append("%12s        %s" % ('coding', self.associated_sequences['coding_state'][i:i+width]))
				if with_match and self.associated_sequences.has_key('matches'):
					t.append("%12s        %s" % ('matches', self.associated_sequences['matches'][i:i+width]))
				if with_scores and self.scores:
					lines = map(lambda k: "%12s        %s" % ('', "".join(map(lambda v: ("%02d" % v)[0], sl.scores[k]))), tmpk)
					t += lines
				t.append('')
				s = e
				coor = coor_new.copy()
		else:
			for i in range(0, len(self), width):
				sl = self.slice(i, i+width)
				t.append(str(sl))
				if with_mm and self.associated_sequences.has_key('mismatches'):
					t.append("%12s %s" % ('mismatches', self.associated_sequences['mismatches'][i:i+width]))
				if with_mm and self.associated_sequences.has_key('coding_state'):
					t.append("%12s %s" % ('coding', self.associated_sequences['coding_state'][i:i+width]))
				if with_match and self.associated_sequences.has_key('matches'):
					t.append("%12s        %s" % ('matches', self.associated_sequences['matches'][i:i+width]))
				if with_scores and self.scores:
					lines = map(lambda k: "%12s        %s" % ('', "".join(map(lambda v: ("%02d" % v)[0], sl.scores[k]))), tmpk)
					t += lines
				t.append('')
		return '\n'.join(t)

	def as_fasta(self):
		# not implemented
		pass

	def as_csv(self):
		# not implemented
		pass

	def as_records(self):
		# not implemented
		pass

	def map_scores(self, scores):
		"""Map raw per-base scores onto the aligned (gapped) sequences.

		'.' columns get 0; gap columns get a fixed placeholder score of
		20; real bases consume the next raw score.
		"""
		for k in self.keys():
			if k in scores:
				i, tmp = 0, []
				for c in self[k]:
					if c == '.':
						tmp.append(0)
						continue
					elif c == '-':
						tmp.append(20)
						continue
					tmp.append(scores[k][i])
					i += 1
				self.scores[k] = tmp

	def ids(self, pattern=''):
		# ids filtered by a predicate (lambda) or a regex pattern
		if not pattern:
			return self.keys()
		if type(pattern) == LambdaType:
			return filter(lambda id: pattern(id), self.keys())
		else:
			return filter(lambda id: re.search(pattern, id), self.keys())

	def sequences(self, pattern=''):
		# sequences filtered by a predicate (lambda) or a regex pattern
		if not pattern:
			return self.values()
		if type(pattern) == LambdaType:
			return filter(lambda id: pattern(id), self.values())
		else:
			return filter(lambda id: re.search(pattern, id), self.values())

	def bases(self, i, pattern=''):
		"""Returns a list of the bases of a position."""
		return map(lambda id: self[id][i], self.ids(pattern))

	def item(self, i, **kwargs):
		"""Returns an Alignment object!"""
		j = i + 1
		return self.slice(i, j, **kwargs)

	def slice(self, i, j, **kwargs):
		"""Returns an Alignment object!"""
		d = {}
		for k in self.keys():
			d[k] = self[k][i:j]
		if i == j:
			lpos, flpos, frpos = None, i, i
		else:
			lpos, flpos, frpos = self._calculate_pos(i)
		if lpos:
			dd = {'leftpos': lpos, 'flanking': 0, 'nodots': 1}
		else:
			dd = {'leftpos': flpos, 'flanking': 1, 'nodots': 1}
		dd['comment'] = self.comment
		dd['with_labels'] = True
		dd['first_seq_gaps_as_introns'] = self.first_seq_gaps_as_introns
		dd['heterozygous'] = self.heterozygous
		tmp = {}
		for k in self.associated_sequences.keys():
			tmp[k] = self.associated_sequences[k][i:j]
		dd['associated_sequences'] = tmp
		# NOTE(review): the same tmp dict is reused below, so
		# dd['scores'] and dd['associated_sequences'] reference one
		# object and score keys leak into associated_sequences -- verify
		for k in self.scores.keys():
			tmp[k] = self.scores[k][i:j]
		dd['scores'] = tmp
		dd.update(kwargs)
		return self.__class__(d, self.seq_ref, **dd)

	# TRIMMING, REMOVAL OF COLUMNS

	def trim(self, cmp_cutoff=2):
		# slice away leading/trailing columns covered by fewer than
		# cmp_cutoff bases; None when no such column exists
		for i in range(len(self)):
			if base_count_list(self.item(i).values()) >= cmp_cutoff:
				break
		for j in range(len(self)-1, -1, -1):
			if base_count_list(self.item(j).values()) >= cmp_cutoff:
				break
		if i <= j:
			return self.slice(i, j+1)
		else:
			return None

	def trimcoordinates(self, cmp_cutoff=2):
		# like trim(), but return the (start, end) slice coordinates
		for i in range(len(self)):
			if base_count_list(self.item(i).values()) >= cmp_cutoff:
				break
		for j in range(len(self)-1, -1, -1):
			if base_count_list(self.item(j).values()) >= cmp_cutoff:
				break
		if i <= j:
			return i, j+1
		else:
			return None

	def clip_leading_nonbase_items(self):
		# drop leading columns that contain no bases at all
		for i in range(len(self)):
			if not no_base_list(self.item(i).values()):
				break
		return self.slice(i, len(self))

	def without_nonbase_items(self):
		# rebuild the alignment keeping only columns that contain bases
		tmp = {}
		for k in self.keys():
			tmp[k] = ''
		for i in range(len(self)):
			if not no_base_list(self.item(i).values()):
				 for k in self.keys():
					 tmp[k] += self[k][i]
		return self.__class__(tmp, self.seq_ref, **self.__dict__)

	def get_surrounding_bases(self, start, end, n):
		# (up to n columns left of start, up to n columns right of end)
		return (self.slice(max(0, start-n), start),
				self.slice(end, min(len(self), end+n)))

	def set_surrounding_bases(self, dl, dr):
		self.left_bases = dl
		self.right_bases = dr

	# EVALUATION OF COLUMNS

	def evaluate_coding_state(self):
		"""Mark gap runs of the reference as introns ('i'), rest as 'e'.

		Only active when first_seq_gaps_as_introns is an int (the minimal
		intron length); also sets intron_count / intron_length.
		"""
		ref = self[self.seq_ref]
		ic, il = 0, 0
		self.associated_sequences['coding_state'] = ['e'] * len(self)
		if type(self.first_seq_gaps_as_introns) == int:
			min_intron_len = self.first_seq_gaps_as_introns
			offset = 0
			while 1:
				m = re.search("(\-{%s,})" % min_intron_len, ref[offset:])
				if m:
					startval = m.start(1) + offset
					endval = m.end(1) + offset
					for i in range(startval, endval, 1):
						self.associated_sequences['coding_state'][i] = 'i'
					offset = endval
					ic += 1
					il += (m.end(1) - m.start(1))
				else:
					break
		self.associated_sequences['coding_state'] = ''.join(self.associated_sequences['coding_state'])
		self.intron_count = ic
		self.intron_length = il

	def evaluate_consensus(self):
		"""Set self.consensus: per column the bitwise-AND base of all
		covering sequences (upper case for 2+, lower case for one)."""
		clist = []
		for i in range(len(self)):
			v = self.item(i).values()
			v = remove_all(v, ['.', '-'])
			v = map(lambda ov: trd.key2value(ov.lower()), v)
			if len(v) > 1:
				c = trd.value2key(reduce(int.__and__, v)).upper()
			else:
				c = trd.value2key(v[0]).lower()
			clist.append(c)
		self.consensus = ''.join(clist)

	def evaluate_matching_bases(self, cmp_cutoff=2):
		"""Set associated_sequences['matches']: '*' where at least
		cmp_cutoff sequences agree (bitwise), ' ' elsewhere."""
		clist = []
		for i in range(len(self)):
			v = self.item(i).values()
			v = remove_all(v, ['.'])
			v = map(lambda ov: trd.key2value(ov.lower()), v)
			if len(v) >= cmp_cutoff:
				if is_bitwise_homogen(v):
					c = '*'
				else:
					c = ' '
			else:
				c = ' '
			clist.append(c)
		self.associated_sequences['matches'] = ''.join(clist)

	def evaluate_mismatches(self):
		"""Build associated_sequences['mismatches'], one char per column:
		'.' non-polymorphic, an IUPAC label (or 'S') for a SNP column,
		'I' for an indel column, '?' otherwise."""

		mmt = []

		for i in range(len(self)):

			bases = filter(lambda b: b != '.', self.bases(i))

			brepr = map(lambda b: trd.key2value(b.lower()), bases)
			brepr_nfree = filter(lambda v: v != 15, brepr)
			label = reduce(lambda a, b: a | b, brepr_nfree, 0)

			if (is_equal(brepr) 
				or (not self.heterozygous and is_bitwise_homogen(brepr))
				or (self.heterozygous and is_equal(brepr_nfree))
				):
				# non polymorphic sites, gap only sites,
				# sites with ambiguous chars and not stringently polymorphic
				mmt.append('.')
			elif label < 16:
				# definitely a SNP
				if self.with_labels:
					mmt.append(label2ambigbase[label])
				else:
					mmt.append('S')
			else:

				# first sequence is denoted as EST and gaps should be taken
				# as introns
				if self.first_seq_gaps_as_introns and len(brepr) > 1:
					alabel = reduce(lambda a, b: a | b, brepr[1:], 0)
					if (not self.heterozygous and is_bitwise_homogen(brepr[1:])) or is_equal(brepr[1:]):
						mmt.append('.')
					elif alabel < 16:
						if self.with_labels:
							mmt.append(label2ambigbase[alabel])
						else:
							mmt.append('S')
					elif alabel in [17, 18, 20, 24]:
						# gap (16) plus exactly one base bit -> indel
						mmt.append('I')
					else:
						mmt.append('?')

				else:
					if len(brepr) <= 1:
						print "ERROR"
						sys.exit()
					if label in [17, 18, 20, 24]:
						# definitely an indel
						mmt.append('I')
					else:
						mmt.append('?')

		self.associated_sequences['mismatches'] = ''.join(mmt)

	def get_mismatches(self, number_surr_bases=100, no_S=False):
		"""Scan the mismatch annotation and return a list of Mismatch
		objects, each with number_surr_bases of flanking context."""

		# no_S: treat multibase SNP events as multiple onebase SNPs

		mm = []

		if not self.associated_sequences.has_key('mismatches'):
			self.evaluate_mismatches()
		mms = self.associated_sequences['mismatches']

		if no_S:
			p = '(([^\.I\?\!])|\.([I\?]+)(?=\.))'
		else:
			p = '(\.([^\.I\?\!]+)(?=\.)|\.([I\?]+)(?=\.))'

		mo_tmp, start, end = None, None, None

		if self.seq_ref:
			offset = firstbase(self[self.seq_ref])
		else:
			offset = 0

		while 1:

			# search from firstbase or end of last match
			m = re.search(p, mms[offset:])

			if not m:
				# no match -> stop search
				break

			# snp or indel?
			if m.group(2):
				start, end = m.start(2), m.end(2)
			elif m.group(3):
				start, end = m.start(3), m.end(3)

			lpos, flpos, dummy = self._calculate_pos(start + offset)
			rpos, dummy, frpos = self._calculate_pos(end - 1 + offset)
			if lpos and rpos:
				dd = {
					'leftpos': lpos, 
					'rightpos': rpos,
					'flanking': 0, 
					}
			else:
				dd = {
					'leftpos': flpos, 
					'rightpos': frpos, 
					'flanking': 1, 
					}
			dd['nodots'] = 1
			dd['i'] = start + offset
			dd['with_labels'] = True
			dd['first_seq_gaps_as_introns'] = False
			dd['heterozygous'] = self.heterozygous
			dd['associated_sequences'] = {}
			for k in self.associated_sequences:
				dd['associated_sequences'][k] = self.associated_sequences[k][start + offset:end + offset]
			dd['scores'] = {}
			for k in self.scores:
				dd['scores'][k] = self.scores[k][start + offset:end + offset]
			mo = Mismatch(dict(self.slice(start + offset, end + offset)), self.seq_ref, **dd)
			dl, dr = self.get_surrounding_bases(start + offset, end + offset, number_surr_bases)
			mo.set_surrounding_bases(dl, dr)
			mm.append(mo)
			# continue behind the match (end-1 keeps the trailing '.'
			# available as the leading '.' of the next pattern match)
			offset += end - 1

		return mm

	def check_mismatches(self, check=None):
		"""Re-check annotated mismatch columns on the id subset selected
		by the *check* predicate; flag false observations with '!'."""

		if type(check) != LambdaType:
			return

		mmt = list(self.associated_sequences['mismatches'])

		for i in range(len(self)):

			bases = filter(lambda b: b != '.', self.bases(i, check))
			brepr = map(lambda b: trd.key2value(b.lower()), bases)

			if is_bitwise_homogen(brepr) or self.heterozygous:
				# non polymorphic sites, gap only sites,
				# sites with ambiguous chars and not stringently polymorphic
				continue
			else:
				# a "polymorphism" that should not be there
				mmt[i] = '!'

		self.associated_sequences['mismatches'] = ''.join(mmt)

	def detect_mismatches(self, number_surr_bases=100):
		"""Column-wise mismatch detection: build one-column Mismatch
		objects and expand adjacent ones (alternative to get_mismatches)."""
		mm = []
		mo_tmp, start, end = None, None, None
		offset = firstbase(self[self.seq_ref])
		for i in range(offset, len(self)):
			lpos, flpos, frpos = self._calculate_pos(i)
			if lpos:
				dd = {'leftpos': lpos, 'flanking': 0, 'nodots': 1, 'i': i}
			else:
				dd = {'leftpos': flpos, 'flanking': 1, 'nodots': 1, 'i': i}
			try:
				mo = Mismatch(self.item(i), self.seq_ref, **dd)
				if mo_tmp:
					try:
						mo_tmp.expand(mo)
						end = i+1
					except MismatchExpandError:
						mo_tmp.set_surrounding_bases(*self.get_surrounding_bases(start, end, number_surr_bases))
						mm.append(mo_tmp)
						mo_tmp = mo
						start, end = i, i+1
				else:
					mo_tmp = mo
					start, end = i, i+1
			except NoMismatchError:
				# column is not polymorphic: close any open mismatch
				if mo_tmp:
					mo_tmp.set_surrounding_bases(*self.get_surrounding_bases(start, end, number_surr_bases))
					mm.append(mo_tmp)
					mo_tmp = None
					start, end = None, None
			except (MixedMismatchError, MismatchLengthError):
				self.comment += "Index %s: MixedMismatchError or MismatchLengthError; " % i
				if mo_tmp:
					mo_tmp.set_surrounding_bases(*self.get_surrounding_bases(start, end, number_surr_bases))
					mm.append(mo_tmp)
					mo_tmp = None
					start, end = None, None
		if mo_tmp:
			mo_tmp.set_surrounding_bases(*self.get_surrounding_bases(start, end, number_surr_bases))
			mm.append(mo_tmp)
		return mm

	def map_alignment(self, other):
		"""Merge *other* into a new Alignment via the shared reference.

		Both alignments must use the same seq_ref with identical ungapped
		reference sequence starting at column 0; gap columns are inserted
		on either side as needed.  NOTE(review): uses ungap(), not
		defined in this chunk -- verify it exists at module level.
		"""
		# check if reference sequences fit
		assert self.seq_ref == other.seq_ref
		ref = self.seq_ref
		assert ungap(self[ref]) == ungap(other[ref])
		assert firstbase(self[ref]) == firstbase(other[ref]) == 0
		# prepare dict for temporary data storage
		d = {}
		for k in self.keys() + other.keys():
			d[k] = []
		mk_self = self.keys()
		mk_self.remove(ref)
		mk_other = other.keys()
		mk_other.remove(ref)
		# start mapping
		i, j = 0, 0
		while 1:
			try:
				c_self = self[ref][i]
			except IndexError:
				# self exhausted: copy the rest of other, pad self ids
				d[ref] += list(other[ref][j:])
				remlen = len(other[ref][j:])
				for k in mk_self:
					d[k] += ['-'] * remlen
				for k in mk_other:
					d[k] += list(other[k][j:])
				break
			try:
				c_oth = other[ref][j]
			except IndexError:
				# other exhausted: copy the rest of self, pad other ids
				d[ref] += list(self[ref][i:])
				remlen = len(self[ref][i:])
				for k in mk_self:
					d[k] += list(self[k][i:])
				for k in mk_other:
					d[k] += ['-'] * remlen
				break
			if c_self == c_oth:
				d[ref].append(c_self)
				for k in mk_self:
					d[k].append(self[k][i])
				for k in mk_other:
					d[k].append(other[k][j])
				i += 1
				j += 1
			elif c_self == '-':
				d[ref].append('-')
				for k in mk_self:
					d[k].append(self[k][i])
				for k in mk_other:
					d[k].append('-')
				i += 1
			elif c_oth == '-':
				d[ref].append('-')
				for k in mk_self:
					d[k].append('-')
				for k in mk_other:
					d[k].append(other[k][j])
				j += 1
		for k in d:
			d[k] = ''.join(d[k]).replace('.', '-')
		return Alignment(d, ref)

	def tag(self, id, seq):
		# not implemented
		pass

class Bases(Dict):
	"""Dealing with strings and lists of bases.

	Behaves as a dictionary mapping each lower-cased base character
	found in the input sequence to its number of occurrences.  The raw
	input sequence is kept in ``self.l`` so it can be re-filtered by the
	selector methods below, each of which returns a new Bases object
	restricted to one character class.
	"""

	def __init__(self, l):
		# keep the raw sequence for the selector methods
		self.l = l
		# count occurrences case-insensitively
		bcount = {}
		for b in self.l:
			b = b.lower()
			try:
				bcount[b] += 1
			except KeyError:
				bcount[b] = 1
		dict.__init__(self, bcount)

	def unambiguous(self):
		"""Return a Bases object restricted to unambiguous bases (acgt)."""
		# List comprehensions instead of filter(): on Python 3, filter()
		# returns a lazy one-shot iterator that __init__'s counting loop
		# would exhaust, leaving self.l empty for any further filtering.
		# On Python 2 this is byte-for-byte equivalent to the old code.
		return Bases([b for b in self.l if b in 'acgt'])

	def ambiguous(self):
		"""Return a Bases object restricted to IUPAC ambiguity codes."""
		return Bases([b for b in self.l if b in 'mrwsykvhdbnx'])

	def gaps(self):
		"""Return a Bases object restricted to gap characters ('-')."""
		return Bases([b for b in self.l if b == '-'])

	def dots(self):
		"""Return a Bases object restricted to dot characters ('.')."""
		return Bases([b for b in self.l if b == '.'])


class Mismatch(Alignment):

	"""A single mismatch (one or more consecutive alignment columns).

	d should be a dictionary with keys = id, values = bases
	ref is the reference, which should be an id in mm.keys()
	pos is the position of the mismatch
	"""
	# mismatch objects never carry leading-dot padding
	nodots = 1

	def _check_mismatch(self):
		"""Classify this mismatch and set ``self.type`` and ``self.coding``.

		Type codes: '!' unresolved; 's'/'S' single-/multi-column SNP;
		'i'/'I' single-/multi-column indel; 'M' mixed type.
		"""
		# 'in' replaces dict.has_key() (removed in Python 3; identical
		# semantics on Python 2)
		if 'mismatches' not in self.associated_sequences:
			self.evaluate_mismatches()
		mms = self.associated_sequences['mismatches']
		if '!' in mms:
			self.type = '!'
		elif re.search('^[^\.I\?\!]+$', mms):
			# SNP
			if len(self) == 1:
				self.type = 's'
			else:
				self.type = 'S'
		elif re.search('^[\.I]+$', mms):
			# indel
			if len(self) == 1:
				self.type = 'i'
			else:
				self.type = 'I'
		else:
			# mixed type
			self.type = 'M'
		# derive the coding state; all columns must agree ('e' = exon)
		# NOTE(review): 'e' meaning exonic/coding is inferred from the
		# attribute name -- confirm against the caller setting coding_state
		d = {}
		if 'coding_state' in self.associated_sequences:
			for b in self.associated_sequences['coding_state']:
				d[b] = d.get(b, 0) + 1
			if len(d) == 1:
				self.coding = 'e' in d
			else:
				# ValueError is a StandardError subclass on Python 2, so
				# existing 'except StandardError' handlers still catch
				# it, and (unlike StandardError) it exists on Python 3.
				raise ValueError('Error in Mismatch object: coding state inhomogen')

	def evaluate_allele_frequency(self, type):
		"""Return allele counts keyed by the observed base string.

		type is 'ABSOLUTE' for raw counts or 'PERCENT' for fractions of
		the number of member sequences; the reference sequence is
		excluded when ``self.ignore_ref`` is set.  Any other value of
		type falls through and returns None.
		"""
		d = {}
		for k in self:
			if self.ignore_ref and k == self.seq_ref:
				continue
			d[self[k]] = d.get(self[k], 0) + 1
		if type == 'ABSOLUTE':
			return d
		elif type == 'PERCENT':
			if self.ignore_ref:
				total = len(self.values()) - 1
			else:
				total = len(self.values())
			for k in d:
				# float() keeps true division on Python 2
				d[k] = float(d[k]) / total
			return d

	def _check_length(self, d, kwargs):
		"""Raise MismatchLengthError unless all base strings in d are equally long."""
		if not is_eqlen(d.values()):
			raise MismatchLengthError("lengths differ")

	def _is_snp(self):
		"""Return True if the mismatch string contains only substitution codes."""
		if re.search('^[^\.I\?]+$', self.associated_sequences['mismatches']):
			return True
		else:
			return False

	def _is_indel(self):
		"""Return True if the mismatch string contains only insertion codes."""
		if re.search('^[I\?]+$', self.associated_sequences['mismatches']):
			return True
		else:
			return False

	def expand(self, mo, strict_snps=1, **kwargs):
		"""Merge the directly adjacent Mismatch *mo* into this one.

		mo should be an Mismatch object of the same type and flanking
		orientation whose left position touches this mismatch's right
		position.  SNPs are only merged when strict_snps is false.
		Extra keyword arguments are stored on the object.  Raises
		MismatchExpandError when the two mismatches cannot be merged and
		ValueError when mo is not a Mismatch at all.
		"""
		if not self.__class__ == mo.__class__:
			raise ValueError("argument is not a Mismatch object")
		if mo.type == self.type and mo.flanking == self.flanking:
			# direction of adjacency depends on the flanking orientation
			c = {0: 1, 1: -1}[self.flanking]
			if self.type == 'indel' and mo.leftpos == self.rightpos + c:
				self._expand(mo)
			elif self.type == 'snp' \
			and mo.leftpos == (self.rightpos + c):
				if strict_snps:
					raise MismatchExpandError("no expansion of SNPs in strict mode")
				else:
					self._expand(mo)
			else:
				raise MismatchExpandError("mismatches are not consecutive")
		else:
			raise MismatchExpandError("mismatches of different type")
		self.update(kwargs)

	def _expand(self, mo):
		# append mo's bases to every member sequence; sequences absent
		# from mo are padded with dots of mo's length
		for k, v in self.items():
			self[k] += mo.get(k, '.'*len(mo))
		self.rightpos = mo.rightpos

	def _not_supported(self):
		"""Stand-in for inherited operations that make no sense on a Mismatch."""
		raise AttributeError("method is not supported")

	# these Alignment operations are disabled for Mismatch objects
	clip_leading_nonbase_items = _not_supported

	detect_mismatches = _not_supported


##############
#### MAIN ####
##############

def _parse_int_option(flag, key, valid):
	"""Parse the integer argument of command-line *flag* into options[key].

	flag  -- the option string, e.g. '-a'
	key   -- the key in the module-level options dict to set
	valid -- predicate applied to the parsed integer; must return True
	         for acceptable values

	Does nothing when the flag is absent from sys.argv.  Calls
	exit_on_error(2, ...) on a missing, non-integer or out-of-range
	argument.
	"""
	if flag not in sys.argv:
		return
	i = sys.argv.index(flag)
	try:
		# IndexError: flag was the last argument (no value follows);
		# ValueError: the value is not an integer.  BUGFIX: the original
		# caught KeyError, which list indexing never raises, so a flag
		# without a value crashed instead of reporting the error.
		arg = sys.argv[i+1]
		value = int(arg)
		if not valid(value):
			exit_on_error(2, flag, arg)
		options[key] = value
	except (IndexError, ValueError):
		exit_on_error(2, flag, ' '.join(sys.argv[i:i+2]))


def main():
	"""Command-line entry point: parse options, then process each infile."""

	# help needed?
	if len(sys.argv) == 1 or '-h' in sys.argv:
		sys.stdout.write(__usage__)
		sys.exit()

	# get / check infiles

	if len(sys.argv) < 2 or sys.argv[-1].startswith("-"):
		exit_on_error(13)

	if os.path.isfile(sys.argv[-1]):
		infiles = [sys.argv[-1]]
		# BUGFIX: was os.path.abspath(infile) -- 'infile' is not defined
		# until the processing loop below, so this path raised NameError
		options['dirname'] = os.path.dirname(os.path.abspath(sys.argv[-1]))
	elif os.path.isdir(sys.argv[-1]):
		options['dirname'] = os.path.abspath(sys.argv[-1])
		# skip extensionless files instead of crashing on rsplit()[1]
		ff = [f for f in os.listdir(options['dirname'])
		      if '.' in f and f.rsplit('.', 1)[1] in options['fasta_ext']]
		infiles = [os.path.join(options['dirname'], f) for f in ff]
	else:
		exit_on_error(1, sys.argv[-1])

	# integer parameters for processing: (flag, options key, validity check)
	_parse_int_option('-a', 'min_alnseq', lambda v: v >= 2)
	_parse_int_option('-L', 'min_alnlen', lambda v: v >= 1)
	_parse_int_option('-p', 'pm_cutoff', lambda v: 0 <= v <= 100)
	_parse_int_option('-l', 'max_pms_length', lambda v: v >= 1)
	_parse_int_option('-b', 'min_border_distance', lambda v: v >= 0)
	_parse_int_option('-s', 'min_pms_score', lambda v: v >= 0)
	_parse_int_option('-n', 'min_nbh_score', lambda v: v >= 0)

	if '-d' in sys.argv:
		options['report_headline'] = True

	if '-c' in sys.argv:
		options['clustal_out'] = True

	if '-o' in sys.argv:
		try:
			options['basename_out'] = sys.argv[sys.argv.index('-o')+1]
			if not os.path.isdir(os.path.dirname(os.path.abspath(options['basename_out']))):
				exit_on_error(1, '-o', options['basename_out'])
		except (IndexError, ValueError):
			# BUGFIX: was sys.argv.index('-r'), a copy-paste error that
			# raised ValueError here; also KeyError -> IndexError
			i = sys.argv.index('-o')
			exit_on_error(2, '-o', ' '.join(sys.argv[i:i+2]))

	if '-A' in sys.argv:
		# report all mismatches: neutralize every filter
		options['min_alnseq'] = 2
		options['min_alnlen'] = 0
		options['pm_cutoff'] = 100
		options['max_pms_length'] = 100 #??
		options['min_border_distance'] = 0
		options['min_pms_score'] = 0
		options['min_nbh_score'] = 0
		options['max_N'] = 100

	if '-v' in sys.argv:
		options['verbose'] = True

	for infile in infiles:

		options['basename'] = os.path.basename(infile.rsplit('.', 1)[0])

		if options['verbose']:
			# sys.stdout.write instead of the Python-2-only print
			# statement; consistent with the rest of the script
			sys.stdout.write("Processing '%s'\n" % infile)
		data = process_file(infile, options)

		# per-mismatch report (appended; header only for a new file)
		if data['mr']:
			if options['basename_out']:
				fn = "%s.mismatches.csv" % options['basename_out']
			else:
				fn = "%s.mismatches.csv" % options['basename']
			header = ''
			try:
				if not os.path.isfile(fn):
					header = data['mr'][0].header2()
			except IndexError:
				pass
			fo = open(fn, 'a')
			try:
				if options['report_headline']:
					fo.write(header)
				fo.write(''.join([d.csv() for d in data['mr']]))
			finally:
				# close even if a report object fails to serialize
				fo.close()

		# per-alignment report (appended; header only for a new file)
		if options['basename_out']:
			fn = "%s.alignments.csv" % options['basename_out']
		else:
			fn = "%s.alignments.csv" % options['basename']
		header = ''
		if not os.path.isfile(fn):
			header = data['ar'].header2()
		fo_report = open(fn, 'a')
		try:
			if options['report_headline']:
				fo_report.write(header)
			fo_report.write(data['ar'].csv())
		finally:
			fo_report.close()


### main ###

# run only when executed as a script, not when imported as a module
if __name__ == "__main__":
	main()
