import { createReducer } from '@reduxjs/toolkit'
import {
  addMulticallListeners,
  errorFetchingMulticallResults,
  fetchingMulticallResults,
  removeMulticallListeners,
  toMultiCallKey,
  updateMulticallResults,
} from './actions'
import { ChainId, ChainIdMap, initializeObjectChainIdMap } from 'constants/chainId'

/**
 * Cached state of a single call (identified by its call key) on one chain.
 * All fields are optional because an entry is first created with only
 * `fetchingBlockNumber` and gains `data` / `blockNumber` later.
 */
interface CallResultObject {
  // result payload stored by updateMulticallResults; set to null when the fetch errored
  data?: string | null
  // block number that `data` (or the error marker) was recorded for
  blockNumber?: number
  // block number of the fetch currently in flight; removed once that fetch errors,
  // and absent on entries written by updateMulticallResults
  fetchingBlockNumber?: number
}

/**
 * Shape of the multicall slice. Both maps are keyed by chain id first.
 */
export interface MulticallState {
  // Listener bookkeeping, nested as: chain id -> call key -> blocksPerFetch -> listener count.
  // For every call key it records how many listeners asked for each
  // "blocks per fetch" preference.
  // Optional: the reducer creates this map lazily the first time a listener
  // action is handled.
  callListeners?: ChainIdMap<Record<string, Record<number, number>>>

  // Latest known result per call, nested as: chain id -> call key -> result object.
  callResults: ChainIdMap<Record<string, CallResultObject | undefined>>
}

// Starting state of the slice. `callListeners` is intentionally left out here;
// the add/remove listener handlers create it on demand.
const initialState: MulticallState = {
  callResults: initializeObjectChainIdMap(),
}

/**
 * Default `blocksPerFetch` applied when a listener action does not supply one.
 *
 * @param chainId chain the listener is registered on (may be undefined)
 * @returns 15 for Base, Polygon zkEVM and Arbitrum One; 1 for every other value
 */
function getBlocksPerFetchForChainId(chainId: number | undefined): number {
  // NOTE(review): presumably these chains get a larger interval because they
  // produce blocks quickly — confirm the rationale with the original author.
  const usesWideInterval =
    chainId === ChainId.BASE || chainId === ChainId.POLYGON_ZKEVM || chainId === ChainId.ARBITRUM_ONE

  return usesWideInterval ? 15 : 1
}

/**
 * Reducer for the multicall slice.
 *
 * Handles five actions:
 * - addMulticallListeners / removeMulticallListeners: maintain, per chain and
 *   call key, a count of listeners for each blocksPerFetch value.
 * - fetchingMulticallResults: record the block number a call is being fetched for.
 * - errorFetchingMulticallResults: mark a call as failed for the block it was being fetched for.
 * - updateMulticallResults: store fresh results unless a newer one is already stored.
 *
 * State is updated by direct mutation inside the handlers, as the original
 * implementation does.
 */
export default createReducer(initialState, builder =>
  builder
    // Increment the listener count of every call under the requested blocksPerFetch.
    .addCase(addMulticallListeners, (state, { payload }) => {
      const {
        calls,
        chainId,
        // falls back to the per-chain default when no option is provided
        options: { blocksPerFetch = getBlocksPerFetchForChainId(chainId) } = {},
      } = payload

      // the listener map does not exist in the initial state; create it on first use
      if (!state.callListeners) {
        state.callListeners = initializeObjectChainIdMap()
      }
      const listenersByChain = state.callListeners

      if (listenersByChain[chainId] == null) {
        listenersByChain[chainId] = {}
      }
      const chainListeners = listenersByChain[chainId]

      calls.forEach(call => {
        const callKey = toMultiCallKey(call)
        if (chainListeners[callKey] == null) {
          chainListeners[callKey] = {}
        }
        const countsByInterval = chainListeners[callKey]
        countsByInterval[blocksPerFetch] = (countsByInterval[blocksPerFetch] ?? 0) + 1
      })
    })
    // Decrement the listener count of every call; drop the blocksPerFetch entry when it reaches zero.
    .addCase(removeMulticallListeners, (state, { payload }) => {
      const {
        chainId,
        calls,
        // falls back to the per-chain default when no option is provided
        options: { blocksPerFetch = getBlocksPerFetchForChainId(chainId) } = {},
      } = payload

      // the listener map is created here too if it was never initialised
      if (!state.callListeners) {
        state.callListeners = initializeObjectChainIdMap()
      }

      const chainListeners = state.callListeners[chainId]
      if (!chainListeners) {
        return
      }

      calls.forEach(call => {
        const countsByInterval = chainListeners[toMultiCallKey(call)]
        if (!countsByInterval) return

        const remaining = countsByInterval[blocksPerFetch]
        // nothing registered (missing or zero) for this interval: leave state untouched
        if (!remaining) return

        if (remaining === 1) {
          // last listener for this interval: remove the entry instead of storing 0
          delete countsByInterval[blocksPerFetch]
        } else {
          countsByInterval[blocksPerFetch] = remaining - 1
        }
      })
    })
    // Record which block each call is currently being fetched for.
    .addCase(fetchingMulticallResults, (state, { payload }) => {
      const { chainId, fetchingBlockNumber, calls } = payload

      if (state.callResults[chainId] == null) {
        state.callResults[chainId] = {}
      }
      const chainResults = state.callResults[chainId]

      calls.forEach(call => {
        const callKey = toMultiCallKey(call)
        const existing = chainResults[callKey]

        if (!existing) {
          chainResults[callKey] = { fetchingBlockNumber }
          return
        }

        // keep the recorded block when it is already at or past the one in this action
        if ((existing.fetchingBlockNumber ?? 0) >= fetchingBlockNumber) {
          return
        }
        existing.fetchingBlockNumber = fetchingBlockNumber
      })
    })
    // Mark calls as failed, but only for the exact block they were being fetched for.
    .addCase(errorFetchingMulticallResults, (state, { payload }) => {
      const { fetchingBlockNumber, chainId, calls } = payload

      if (state.callResults[chainId] == null) {
        state.callResults[chainId] = {}
      }
      const chainResults = state.callResults[chainId]

      calls.forEach(call => {
        const existing = chainResults[toMultiCallKey(call)]
        // only should be dispatched if we are already fetching
        if (!existing) {
          return
        }

        // a mismatch means this error does not belong to the fetch recorded on the entry
        if (existing.fetchingBlockNumber !== fetchingBlockNumber) {
          return
        }

        delete existing.fetchingBlockNumber
        existing.data = null
        existing.blockNumber = fetchingBlockNumber
      })
    })
    // Store new results, skipping any call whose stored result is from a later block.
    .addCase(updateMulticallResults, (state, { payload }) => {
      const { chainId, results, blockNumber } = payload

      if (state.callResults[chainId] == null) {
        state.callResults[chainId] = {}
      }
      const chainResults = state.callResults[chainId]

      Object.keys(results).forEach(callKey => {
        const storedBlock = chainResults[callKey]?.blockNumber ?? 0
        if (storedBlock > blockNumber) return

        // the entry is replaced wholesale, so any fetchingBlockNumber on it is dropped
        chainResults[callKey] = {
          data: results[callKey],
          blockNumber,
        }
      })
    }),
)
