#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .base import OrderedVocabulary
from collections import defaultdict
from six import iteritems
import re
import logging

logger = logging.getLogger(__name__)

class VocabExpander(OrderedVocabulary):
  def __init__(self, vocabulary, formatters, strategy):
    super(VocabExpander, self).__init__(vocabulary.words)
    self.strategy = strategy
    self._vocab = vocabulary
    self.aux_word_id = defaultdict(lambda: [])
    self.formatters = formatters
    self.expand(formatters)
    self.aux_id_word = {id_:w for w, id_ in iteritems(self.aux_word_id)}

  def __getitem__(self, key):
    try:
      return self._vocab[key]
    except KeyError as e:
      try:
        return self.aux_word_id[key]
      except KeyError as e:
        return self.approximate_ids(key)

  def __contains__(self, key):
    return ((key in self._vocab) or
            (key in self.aux_word_id) or
            self.approximate(key))

  def __len__(self):
    return len(self._vocab) + len(self.aux_word_id)
  
  def __delitem__(self):
    raise NotImplementedError("It is quite complex, let us do it in the future")

  def format(self, w):
    return [f(w) for f in self.formatters]
  
  def approximate(self, w):
    f = lambda key: (key in self._vocab) or (key in self.aux_word_id)
    return {w_:self[w_] for w_ in self.format(w) if f(w_)}
  
  def approximate_ids(self, key):
    ids = [id_ for w, id_ in self.approximate(key).items()]
    if not ids:
      raise KeyError(u"{} not found".format(key))
    else:
      if self.strategy == 'most_frequent':
        return min(ids)
      else:
        return tuple(sorted(ids))  

  def _expand(self, formatter):
    for w in self.word_id:
      w_ = formatter(w)
      if w_ not in self._vocab:
        id_ = self.word_id[w]
        self.aux_word_id[w_].append(id_)

  def expand(self, formatters):
    for formatter in formatters:
      self._expand(formatter)
    if self.strategy == 'average':
      self.aux_word_id = {w: tuple(sorted(ids)) for w, ids in iteritems(self.aux_word_id)}
    elif self.strategy == 'most_frequent':
      self.aux_word_id = {w: min(ids) for w, ids in iteritems(self.aux_word_id)}
    else:
      raise ValueError("A strategy is needed")

    words_added = self.aux_word_id.keys()
    old_no = len(self._vocab)
    new_no = len(self.aux_word_id)
    logger.info("We have {} original words.".format(old_no))
    logger.info("Added {} new words.".format(new_no))
    logger.info("The new total number of words is {}".format(len(self)))
    logger.debug(u"Words added\n{}\n".format(u" ".join(words_added)))


class CaseExpander(VocabExpander):
  def __init__(self, vocabulary, strategy='most_frequent'):
    formatters = [lambda x: x.lower(),
                  lambda x: x.title(),
                  lambda x: x.upper()]
    super(CaseExpander, self).__init__(vocabulary=vocabulary, formatters=formatters, strategy=strategy)

    
class DigitExpander(VocabExpander):
  def __init__(self, vocabulary, strategy='most_frequent'):
    pattern = re.compile("[0-9]", flags=re.UNICODE)
    formatters = [lambda x: pattern.sub("#", x)]
    super(DigitExpander, self).__init__(vocabulary=vocabulary, formatters=formatters, strategy=strategy)