import { App } from 'antd'
import { applyPatch } from 'fast-json-patch'
import { useRef, useState } from 'react'

/**
 * Tunable knobs for a single SRAG debug request.
 *
 * `querySrag` forwards every field under `input.params.srag_params`, using the
 * snake_case key noted on each member below.
 */
export interface SragParams {
  /** Model identifier, sent as `model` (e.g. `'openai/gpt-4o-mini'`). */
  model: string
  /** Debug verbosity, sent as `debug_level`; only 0, 1 or 2 are allowed. */
  debugLevel: 0 | 1 | 2
  /** Sent as `use_query_rewrite`; presumably switches the query-rewrite step on/off. */
  useQueryRewrite: boolean
  /** Prompt text sent as `query_rewrite_prompt`. */
  queryRewritePrompt: string
  /** Sent as `intention_enable`; presumably switches the intention step on/off. */
  intentionEnable: boolean
  /** Prompt text sent as `intention_prompt`. */
  intentionPrompt: string
  /** Prompt text sent as `chat_prompt`. */
  chatPrompt: string
  /** Sent as `use_ner`; presumably toggles named-entity recognition. */
  useNer: boolean
}

/**
 * One entry of the output document that is assembled, patch by patch, from
 * the `event: data` messages of the debug stream.
 *
 * Keys presumably mirror the server payload verbatim (including the
 * `use_snippt_url` spelling), so they are kept in snake_case as-is.
 */
export interface Output {
  action: string
  result: {
    /** Answer text keyed by position; index 0 is the text streamed into the UI. */
    generic_llm: Record<number, string>
    /** Diagnostic details reported for the run; structure is server-defined. */
    log_dict: {
      latency_log_str: string
      query_rewrite_t: number
      rewrite_query: {
        llm_query: string
        search_query: string
      }
      recognize_names: string[]
      crawl_contents_url: string[]
      use_snippt_url: string[][]
    }
  }
}

/** One question/answer turn shown in the debug conversation. */
export interface SragMessage {
  /** The question text that was submitted. */
  query: string
  /** `false` while the answer is still pending; `true` once it has settled. */
  done: boolean
  /** Output captured when the stream finished; absent while pending or if none was produced. */
  output?: Output
}

/**
 * POSTs a debug query to the SRAG streaming endpoint and hands back a reader
 * over the raw response body.
 *
 * @param postId - Conversation id sent as `post_id`; pass `''` to start without prior context.
 * @param message - User input, sent as `human_input`.
 * @param params - Tunables, forwarded in snake_case under `input.params.srag_params`.
 * @param gateway - Full endpoint URL override. When omitted or empty, the
 *   `VITE_FAVIE_GATEWAY_HOST` env value + `/dashboard/stream_debug` is used.
 * @returns A reader over the undecoded response bytes; the caller parses the stream.
 * @throws Error when the response status is not 2xx (message always contains
 *   the numeric status) or when the response carries no body.
 */
export async function querySrag(
  postId: string,
  message: string,
  params: SragParams,
  gateway?: string
): Promise<ReadableStreamDefaultReader<Uint8Array>> {
  // `||` rather than `??` on purpose: callers store "no override" as an empty string.
  const url = gateway || `${import.meta.env.VITE_FAVIE_GATEWAY_HOST}/dashboard/stream_debug`

  const res = await fetch(url, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json'
    },
    body: JSON.stringify({
      input: {
        source: 'DEBUG',
        post_id: postId,
        human_input: message,
        params: {
          srag_params: {
            model: params.model,
            use_query_rewrite: params.useQueryRewrite,
            use_ner: params.useNer,
            debug_level: params.debugLevel,
            intention_enable: params.intentionEnable,
            query_rewrite_prompt: params.queryRewritePrompt,
            chat_prompt: params.chatPrompt,
            intention_prompt: params.intentionPrompt
          }
        }
      }
    })
  })

  if (!res.ok) {
    // statusText is empty for HTTP/2 responses, so never rely on it alone:
    // an empty Error message would surface as a blank error toast.
    throw new Error(`HTTP ${res.status}${res.statusText ? ` ${res.statusText}` : ''}`)
  }

  if (!res.body) {
    throw new Error('no response body')
  }

  return res.body.getReader()
}

/**
 * Splits buffered event-stream text into complete, non-empty lines.
 *
 * Lines may end in `\r\n` or `\n`; a trailing `\r` is stripped. Whatever
 * follows the last newline is an unfinished line and is returned as `rest`,
 * so the caller can prepend it to the next decoded chunk.
 */
function splitSseLines(buffer: string): { lines: string[]; rest: string } {
  const parts = buffer.split('\n')
  // The last segment has not been terminated yet (it is '' when the buffer ends in '\n').
  const rest = parts.pop() ?? ''
  const lines = parts
    .map(line => (line.endsWith('\r') ? line.slice(0, -1) : line))
    .filter(line => line.length > 0)
  return { lines, rest }
}

/**
 * Reads the value of an event-stream `field: value` line, dropping one
 * optional leading space. Returns `undefined` when the line is a different field.
 */
function readSseField(line: string, field: string): string | undefined {
  const prefix = `${field}:`
  if (!line.startsWith(prefix)) {
    return undefined
  }
  const value = line.slice(prefix.length)
  return value.startsWith(' ') ? value.slice(1) : value
}

/**
 * State and actions for one SRAG debug chat pane.
 *
 * `ask()` appends the current query to `messages` (or replaces them when
 * `withContext` is off), streams the answer from `querySrag`, and exposes the
 * in-progress answer text as `content`. When the stream reports `end`, the
 * pending message is marked done with the assembled output. If the request
 * fails, or the stream closes without `end`, the pending message is marked done
 * without output, and the error (if any) is shown via antd's `message.error`.
 */
export default function useSrag() {
  const { message } = App.useApp()
  const [query, setQuery] = useState('')
  const [params, setParams] = useState<SragParams>({
    debugLevel: 1,
    model: 'openai/gpt-4o-mini',
    chatPrompt: '',
    queryRewritePrompt: '',
    intentionPrompt: '',
    useNer: true,
    useQueryRewrite: true,
    intentionEnable: true
  })
  const [gateway, setGateway] = useState('')
  const [withContext, setWithContext] = useState(true)
  // post_id reported by the server's `metadata` event; reused for follow-ups when withContext is on.
  const [postId, setPostId] = useState('')
  const [messages, setMessages] = useState<SragMessage[]>([])
  const [loading, setLoading] = useState(false)
  const [content, setContent] = useState('')
  // Document that `event: data` JSON patches are applied to while streaming.
  const output = useRef<{ output?: { output: Output[] } }>({})

  const changeParams = (params: Partial<SragParams>) => {
    setParams(prev => ({ ...prev, ...params }))
  }

  /** Marks the still-pending entry for `query` as done, attaching `final` (may be undefined). */
  const finishMessage = (query: string, final?: Output) => {
    setMessages(prev => prev.map(v => (v.query === query && !v.done ? { ...v, done: true, output: final } : v)))
  }

  /**
   * Consumes the event stream line by line. Lines are buffered across chunks,
   * so an event may be split at any byte boundary.
   *
   * @returns `true` once `event: end` was handled, `false` if the stream closed first.
   */
  const handleStream = async (stream: ReadableStreamDefaultReader<Uint8Array>, query: string): Promise<boolean> => {
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    // Name from the most recent `event:` line whose `data:` line has not arrived yet.
    let pendingEvent: string | undefined

    while (true) {
      const { done, value } = await stream.read()
      if (done) {
        return false
      }

      const { lines, rest } = splitSseLines(buffer + decoder.decode(value, { stream: true }))
      buffer = rest

      for (const line of lines) {
        const eventName = readSseField(line, 'event')
        if (eventName !== undefined) {
          if (eventName.trim() === 'end') {
            finishMessage(query, output.current.output?.output[0])
            setContent('')
            output.current = {}
            // Nothing after `end` is used; release the connection instead of
            // waiting for the server to close it.
            void stream.cancel().catch(() => undefined)
            return true
          }
          pendingEvent = eventName.trim()
          continue
        }

        const data = readSseField(line, 'data')
        if (data === undefined || pendingEvent === undefined) {
          continue
        }
        const event = pendingEvent
        pendingEvent = undefined

        if (event === 'metadata') {
          const meta = JSON.parse(data) as { post_id: string }
          setPostId(meta.post_id)
        } else if (event === 'data') {
          output.current = applyPatch(output.current, JSON.parse(data)).newDocument
          const llm = output.current.output?.output[0]?.result?.generic_llm
          if (llm) {
            setContent(llm[0] ?? '')
          }
        }
      }
    }
  }

  const ask = async () => {
    if (!query) {
      return
    }

    // With context the turn is appended; without it the history restarts from this turn.
    setMessages(prev => [...(withContext ? prev : []), { query, done: false }])

    setLoading(true)
    // Start from an empty patch target in case an earlier stream was cut short.
    output.current = {}
    let ended = false
    try {
      const reader = await querySrag(withContext ? postId : '', query, params, gateway)
      ended = await handleStream(reader, query)
    } catch (error) {
      message.error(error instanceof Error ? error.message : String(error))
    } finally {
      if (!ended) {
        // The request failed or the stream closed before `end`: settle the pending
        // entry so it does not stay in progress forever, and so a retry of the same
        // query does not also match (and overwrite) this stale entry.
        finishMessage(query)
        setContent('')
        output.current = {}
      }
      setLoading(false)
    }
  }

  return {
    query,
    params,
    withContext,
    gateway,
    messages,
    content,
    loading,
    ask,
    changeParams,
    changeQuery: setQuery,
    changeGateway: setGateway,
    changeWithContext: setWithContext,
    changeMessages: setMessages
  }
}

/**
 * Runs two independent `useSrag` panes side by side for A/B comparison.
 *
 * Turning compare mode on copies the left pane's query, params, gateway,
 * context flag and message history into the right pane. `submitQuery` always
 * asks on the left and, while comparing, on the right as well.
 */
export function useSragCompare() {
  const left = useSrag()
  const right = useSrag()
  const [compare, setCompare] = useState(false)

  // Shape one hook instance into the prop bag a comparison pane consumes.
  const toPane = (pane: ReturnType<typeof useSrag>) => ({
    query: pane.query,
    params: pane.params,
    withContext: pane.withContext,
    gateway: pane.gateway,
    messages: pane.messages,
    content: pane.content,
    loading: pane.loading,
    onParamsChange: pane.changeParams,
    onQueryChange: pane.changeQuery,
    onGatewayChange: pane.changeGateway,
    onWithContextChange: pane.changeWithContext
  })

  const compareQuery = (enabled: boolean) => {
    if (enabled) {
      const { query: leftQuery, params: leftParams, gateway: leftGateway, withContext: leftContext, messages: leftMessages } = left
      right.changeQuery(leftQuery)
      right.changeParams(leftParams)
      right.changeGateway(leftGateway)
      right.changeWithContext(leftContext)
      right.changeMessages(leftMessages)
    }

    setCompare(enabled)
  }

  const submitQuery = () => {
    const activePanes = compare ? [left, right] : [left]
    for (const pane of activePanes) {
      // ask() reports its own errors, so the returned promise is intentionally not awaited.
      void pane.ask()
    }
  }

  return {
    left: toPane(left),
    right: toPane(right),
    compare,
    compareQuery,
    submitQuery
  }
}
