use crate::common::BitSet;
use crate::core::SegmentReader;
use crate::query::ConstScorer;
use crate::query::{BitSetDocSet, Explanation};
use crate::query::{Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::DocId;
use crate::TantivyError;
use crate::{Result, SkipResult};
use std::sync::Arc;
use tantivy_fst::Automaton;

/// A weight struct for Fuzzy Term and Regex Queries
///
/// It is generic over the automaton type `A`; the automaton decides which
/// terms of the targeted field are considered a match.
pub struct AutomatonWeight<A> {
    /// Field whose inverted index (and term dictionary) is searched.
    field: Field,
    /// Automaton run against the term dictionary, held behind an `Arc`.
    automaton: Arc<A>,
}

impl<A> AutomatonWeight<A>
where
    A: Automaton + Send + Sync + 'static,
{
    /// Create a new `AutomatonWeight` for the given `field`.
    ///
    /// `automaton` can be anything convertible into an `Arc<A>`; it is the
    /// automaton that will be run against the field's term dictionary.
    pub fn new<IntoArcA: Into<Arc<A>>>(field: Field, automaton: IntoArcA) -> AutomatonWeight<A> {
        let automaton: Arc<A> = automaton.into();
        AutomatonWeight { field, automaton }
    }

    /// Returns a stream over the entries of `term_dict` selected by the
    /// automaton. The automaton is only borrowed for the lifetime of the stream.
    fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> {
        term_dict.search(&*self.automaton).into_stream()
    }
}

impl<A> Weight for AutomatonWeight<A>
where
    A: Automaton + Send + Sync + 'static,
{
    /// Builds a constant-score scorer over the documents of the segment that
    /// contain at least one term selected by the automaton.
    ///
    /// Every matching doc id is collected into a bitset up front; the
    /// resulting doc set is then wrapped in a `ConstScorer` built with `boost`.
    fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
        let mut matching_docs = BitSet::with_max_value(reader.max_doc());

        let inverted_index = reader.inverted_index(self.field);
        let mut matching_terms = self.automaton_stream(inverted_index.terms());
        while matching_terms.advance() {
            // Only doc ids are needed, hence the `Basic` record option.
            let mut postings = inverted_index
                .read_block_postings_from_terminfo(matching_terms.value(), IndexRecordOption::Basic);
            while postings.advance() {
                postings.docs().iter().for_each(|&doc| {
                    matching_docs.insert(doc);
                });
            }
        }

        let doc_set = BitSetDocSet::from(matching_docs);
        Ok(Box::new(ConstScorer::new(doc_set, boost)))
    }

    /// Explains the score of `doc`.
    ///
    /// A scorer with a boost of `1.0` is built and moved to `doc`. If the
    /// document is reached, an explanation with value `1.0` is returned;
    /// otherwise a `TantivyError::InvalidArgument` is returned.
    fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
        let mut scorer = self.scorer(reader, 1.0f32)?;
        if scorer.skip_next(doc) != SkipResult::Reached {
            return Err(TantivyError::InvalidArgument(
                "Document does not exist".to_string(),
            ));
        }
        Ok(Explanation::new("AutomatonScorer", 1.0f32))
    }
}

#[cfg(test)]
mod tests {
    use super::AutomatonWeight;
    use crate::query::Weight;
    use crate::schema::{Schema, STRING};
    use crate::Index;
    use tantivy_fst::Automaton;

    /// Builds an in-RAM index with a single `title` field holding three
    /// documents, added in this order: "abc", "bcd", "abcd".
    fn create_index() -> Index {
        let mut schema_builder = Schema::builder();
        let title = schema_builder.add_text_field("title", STRING);
        let index = Index::create_in_ram(schema_builder.build());
        let mut writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
        writer.add_document(doc!(title=>"abc"));
        writer.add_document(doc!(title=>"bcd"));
        writer.add_document(doc!(title=>"abcd"));
        assert!(writer.commit().is_ok());
        index
    }

    /// States of the `PrefixedByA` automaton.
    enum State {
        /// No byte has been consumed yet.
        Start,
        /// The first byte was not `b'a'`; no match is possible anymore.
        NotMatching,
        /// The first byte was `b'a'`; any continuation matches.
        AfterA,
    }

    /// Automaton accepting every byte sequence whose first byte is `b'a'`.
    struct PrefixedByA;

    impl Automaton for PrefixedByA {
        type State = State;

        fn start(&self) -> Self::State {
            State::Start
        }

        fn is_match(&self, state: &Self::State) -> bool {
            match state {
                State::AfterA => true,
                State::Start | State::NotMatching => false,
            }
        }

        fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
            match (state, byte) {
                // Only the very first byte decides the outcome.
                (State::Start, b'a') => State::AfterA,
                (State::Start, _) => State::NotMatching,
                // Both remaining states are absorbing.
                (State::AfterA, _) => State::AfterA,
                (State::NotMatching, _) => State::NotMatching,
            }
        }
    }

    /// The scorer must yield exactly docs 0 ("abc") and 2 ("abcd"), each
    /// with the score given as boost.
    #[test]
    fn test_automaton_weight() {
        let index = create_index();
        let field = index.schema().get_field("title").unwrap();
        let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = automaton_weight.scorer(segment_reader, 1.0f32).unwrap();
        for &expected_doc in [0u32, 2u32].iter() {
            assert!(scorer.advance());
            assert_eq!(scorer.doc(), expected_doc);
            assert_eq!(scorer.score(), 1.0f32);
        }
        assert!(!scorer.advance());
    }

    /// The boost handed to `scorer` must be reported as the score.
    #[test]
    fn test_automaton_weight_boost() {
        let index = create_index();
        let field = index.schema().get_field("title").unwrap();
        let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = automaton_weight.scorer(segment_reader, 1.32f32).unwrap();
        assert!(scorer.advance());
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.32f32);
    }
}
