import { produce } from "immer";
import { BotResponseOption, Message, Scenario } from "src/types/models";
import { shuffleArray } from "src/utils/array";
import { createWithEqualityFn } from "zustand/traditional";

import {
  InteractiveConversation,
  InteractiveConversationState,
  TaskData,
  TaskState,
} from "../types";
import {
  MaybeScenarioOrSelector,
  canCompleteConversation,
  getRandomNumber,
  getScenarioAtTurn,
  getTaskState,
} from "../utils";

/**
 * State and operations exposed by `useInteractiveSideBySideStore`.
 *
 * Members are grouped as: raw state, plain setters, task/conversation actions,
 * and getters that derive values from the current state.
 */
type InteractiveSideBySideStateSlice = {
  // ---------------------------------------------------------------- state
  /** Error shown when starting the task fails; `null` when there is none. */
  startError: string | null;
  /** Error shown when a chat request fails; `null` when there is none. */
  chatError: string | null;
  taskState: TaskState;
  /** One entry per logged refresh attempt (the reply texts at that time). */
  refreshHistory: string[][];
  /** Loaded task payload; `null` until `setTaskData` is called. */
  taskData: TaskData | null;
  options: BotResponseOption[];
  isLoadingOptions: boolean;
  preflightId: number | null;

  // -------------------------------------------------------------- setters
  setStartError: (error: string | null) => void;
  setChatError: (error: string | null) => void;
  setTaskState: (taskState: TaskState) => void;
  setTaskData: (taskData: TaskData) => void;
  /** Replaces the history of the current (first non-completed) conversation. */
  setHistory: (history: Message[]) => void;
  setOptions: (options: BotResponseOption[]) => void;
  setIsLoadingOptions: (isLoadingOptions: boolean) => void;
  setPreflightId: (preflightId: number) => void;
  setRefreshHistory: (refreshHistory: string[][]) => void;

  // -------------------------------------------------------------- actions
  /** Records a scenario choice for the current conversation at `turn`. */
  selectScenario: (scenario: Scenario, turn: number) => void;
  /** Marks the conversation with `conversationId` as completed. */
  endConversation: (conversationId: number) => void;
  /** Appends the current replies to `refreshHistory`. */
  logCurrentRefreshAttempt: () => void;

  // -------------------------------------------------------------- getters
  getCurrentConversation: () => InteractiveConversation | null;
  isLastConversation: (conversationId: number) => boolean;
  hasEnoughOptions: () => boolean;
  getReplies: () => string[] | null;
  getMaxRefreshAttempts: () => number;
  getShouldWriteAlternative: () => boolean;
  getShouldShowBothOptionsAreBadButton: () => boolean;
  getShouldShowBothOptionsAreGoodButton: () => boolean;
  canCompleteConversation: () => boolean;
  getScenarioAtTurn: (turn: number) => MaybeScenarioOrSelector;
  getNumRequiredAiMessages: () => number;
  isSelectingScenario: () => boolean;
  getIsScenarioRequired: () => boolean;
};

/**
 * Whether a conversation still needs work, i.e. its state is not
 * `InteractiveConversationState.COMPLETED`. The first conversation matching
 * this predicate is treated as the "current" conversation throughout the store.
 */
const isIncompleteConversation = (
  conversation: Pick<InteractiveConversation, "state">
): boolean => conversation.state !== InteractiveConversationState.COMPLETED;

/**
 * Zustand store backing the interactive side-by-side task.
 *
 * Conversations are worked through in order: the "current" conversation is the
 * first one that is not COMPLETED. Getter-style members (`get*`, `is*`, `has*`,
 * `can*`) only read the latest state via `get()`; they never call `set`.
 */
const useInteractiveSideBySideStore =
  createWithEqualityFn<InteractiveSideBySideStateSlice>(
    (set, get) => ({
      startError: null,
      chatError: null,
      taskState: TaskState.NEED_CONFIDENTIALITY_AGREEMENT,
      refreshHistory: [],
      taskData: null,
      isLoadingOptions: false,
      options: [],
      preflightId: null,

      setStartError: (error: string | null) =>
        set(() => ({ startError: error })),
      setChatError: (error: string | null) => set(() => ({ chatError: error })),
      setTaskState: (taskState: TaskState) => set(() => ({ taskState })),

      /**
       * Replace the message history of the current conversation.
       * No-op when there is no task data or every conversation is completed
       * (previously this indexed `conversations[-1]` and threw).
       */
      setHistory: (history: Message[]) =>
        set(
          produce((state: InteractiveSideBySideStateSlice) => {
            const current = state.taskData?.conversations.find(
              isIncompleteConversation
            );
            if (!current) {
              return;
            }
            current.history = history;
          })
        ),

      /**
       * Record the scenario picked for the current conversation.
       * With a task-level `multiple_scenarios_config` the choice is appended as
       * a checkpoint at `turn`; otherwise it replaces the conversation's single
       * scenario. No-op when there is no current conversation.
       */
      selectScenario: (scenario: Scenario, turn: number) =>
        set(
          produce((state: InteractiveSideBySideStateSlice) => {
            const { taskData } = state;
            const current = taskData?.conversations.find(
              isIncompleteConversation
            );
            if (!taskData || !current) {
              return;
            }
            if (taskData.multiple_scenarios_config) {
              current.scenario_checkpoints.push({ scenario, turn });
            } else {
              current.scenario = scenario;
            }
          })
        ),

      /**
       * Install freshly loaded task data and derive `taskState` from it.
       *
       * Every conversation gets `num_ai_turns = taskData.num_turns`. When the
       * task defines a `multiple_scenarios_config`, each conversation also gets
       * its own copy of that config. The copy's scenario list is shuffled when
       * `should_shuffle` is set, and each scenario gets a randomly drawn
       * `num_ai_turns`. The caller's `taskData` objects are not mutated.
       */
      setTaskData: (taskData: TaskData) => {
        const sharedScenariosConfig = taskData.multiple_scenarios_config;
        const conversations = taskData.conversations.map(
          (conversation): InteractiveConversation => {
            if (!sharedScenariosConfig) {
              return { ...conversation, num_ai_turns: taskData.num_turns };
            }
            // Copy before shuffling so the shared config array is never
            // reordered in place, regardless of shuffleArray's contract.
            const scenarioConfigs = sharedScenariosConfig.should_shuffle
              ? shuffleArray([...sharedScenariosConfig.single_scenario_configs])
              : sharedScenariosConfig.single_scenario_configs;
            return {
              ...conversation,
              num_ai_turns: taskData.num_turns,
              multiple_scenarios_config: {
                ...sharedScenariosConfig,
                single_scenario_configs: scenarioConfigs.map((config) => ({
                  ...config,
                  // NOTE(review): the `+ 1` suggests getRandomNumber's upper
                  // bound is exclusive (so max_num_ai_turns is inclusive) —
                  // confirm against its implementation.
                  num_ai_turns: getRandomNumber(
                    config.min_num_ai_turns,
                    config.max_num_ai_turns + 1
                  ),
                })),
              },
            };
          }
        );
        const newTaskData = { ...taskData, conversations };
        set(() => ({
          taskState: getTaskState(newTaskData),
          taskData: newTaskData,
        }));
      },

      /** Mark the conversation with `conversationId` as COMPLETED and refresh `taskState`. */
      endConversation: (conversationId: number) => {
        const { taskData } = get();
        if (!taskData) {
          return;
        }
        const newTaskData = {
          ...taskData,
          conversations: taskData.conversations.map((conversation) =>
            conversation.id === conversationId
              ? { ...conversation, state: InteractiveConversationState.COMPLETED }
              : conversation
          ),
        };
        set(() => ({
          taskState: getTaskState(newTaskData),
          taskData: newTaskData,
        }));
      },

      /** True when `conversationId` is the id of the final conversation in the task. */
      isLastConversation: (conversationId: number) => {
        const conversations = get().taskData?.conversations ?? [];
        return conversations[conversations.length - 1]?.id === conversationId;
      },

      /**
       * The first non-completed conversation, augmented with its `index` in
       * `taskData.conversations`, or `null` when there is no task data or all
       * conversations are completed. Returns a fresh object on every call.
       */
      getCurrentConversation: () => {
        const { taskData } = get();
        if (!taskData) {
          return null;
        }
        const index = taskData.conversations.findIndex(isIncompleteConversation);
        if (index < 0) {
          return null;
        }
        return { ...taskData.conversations[index], index };
      },

      /**
       * Whether the annotator must write their own alternative response.
       * Retrieval tasks ask immediately. Otherwise the annotator must first
       * use up `max_refresh_attempts` refreshes; when that value is unset or
       * 0, the alternative is required immediately.
       */
      getShouldWriteAlternative: () => {
        const { taskData, refreshHistory } = get();
        if (!taskData || !taskData.require_alternative_response) {
          return false;
        }
        // Retrieval tasks ask for an alternative immediately, without gating
        // on a minimum number of rerolls.
        if (taskData.retrieval_annotation_form) {
          return true;
        }
        const maxRefreshAttempts = taskData.max_refresh_attempts;
        return !maxRefreshAttempts || refreshHistory.length >= maxRefreshAttempts;
      },

      /**
       * Whether to offer the "both options are bad" button. It is always
       * shown when no alternative response is required or refreshing is
       * disabled (`max_refresh_attempts === 0`). Otherwise it is shown only
       * while refresh attempts remain.
       */
      getShouldShowBothOptionsAreBadButton: () => {
        const { taskData, refreshHistory } = get();
        if (!taskData) {
          return false;
        }
        if (
          !taskData.require_alternative_response ||
          taskData.max_refresh_attempts === 0
        ) {
          return true;
        }
        return (
          taskData.max_refresh_attempts > 0 &&
          refreshHistory.length < taskData.max_refresh_attempts
        );
      },

      getShouldShowBothOptionsAreGoodButton: () =>
        Boolean(get().taskData?.show_both_options_are_good),

      /** Configured refresh limit, or 0 when there is no task data / no limit set. */
      getMaxRefreshAttempts: () => get().taskData?.max_refresh_attempts ?? 0,

      /**
       * True once there are at least two options and exactly one option per
       * configured model.
       */
      hasEnoughOptions: () => {
        const { taskData, options } = get();
        return Boolean(
          options.length >= 2 &&
            taskData &&
            taskData.model_options.length === options.length
        );
      },

      setOptions: (options: BotResponseOption[]) => set(() => ({ options })),
      setIsLoadingOptions: (isLoadingOptions: boolean) =>
        set(() => ({ isLoadingOptions })),
      setPreflightId: (preflightId: number) => set(() => ({ preflightId })),

      /** Text of every current option, or `null` until `hasEnoughOptions()` holds. */
      getReplies: () => {
        const { options, hasEnoughOptions } = get();
        if (!hasEnoughOptions()) {
          return null;
        }
        return options.map((option) => option.text);
      },

      canCompleteConversation: () => {
        const currentConversation = get().getCurrentConversation();
        return currentConversation
          ? canCompleteConversation(currentConversation)
          : false;
      },

      /**
       * Append the currently displayed replies to `refreshHistory`.
       * Skipped when not enough options are loaded. Logging `null` would put
       * a non-array entry into `refreshHistory` and inflate the attempt count
       * that `getShouldWriteAlternative` relies on.
       */
      logCurrentRefreshAttempt: () => {
        const replies = get().getReplies();
        if (!replies) {
          return;
        }
        set((state) => ({
          refreshHistory: [...state.refreshHistory, replies],
        }));
      },

      setRefreshHistory: (refreshHistory: string[][]) =>
        set(() => ({ refreshHistory })),

      /** Scenario (or selector) in effect at `turn` for the current conversation. */
      getScenarioAtTurn: (turn: number) => {
        const currentConversation = get().getCurrentConversation();
        if (!currentConversation) {
          return undefined;
        }
        return getScenarioAtTurn(turn, currentConversation);
      },

      /**
       * Number of AI messages the current conversation needs. This is the sum
       * of the per-scenario `num_ai_turns` when the conversation has a
       * multi-scenario config, otherwise `taskData.num_turns`.
       */
      getNumRequiredAiMessages: () => {
        const { getCurrentConversation, taskData } = get();
        const currentConversation = getCurrentConversation();
        if (!taskData || !currentConversation) {
          return 0;
        }
        const scenariosConfig = currentConversation.multiple_scenarios_config;
        if (!scenariosConfig) {
          return taskData.num_turns;
        }
        return scenariosConfig.single_scenario_configs.reduce(
          (total, config) => total + config.num_ai_turns,
          0
        );
      },

      /**
       * True when scenarios are required and the scenario slot for the next
       * turn (index = current history length) is still an unfilled selector.
       */
      isSelectingScenario: () => {
        const {
          getIsScenarioRequired,
          getCurrentConversation,
          getScenarioAtTurn,
        } = get();
        const currentConversation = getCurrentConversation();
        if (!getIsScenarioRequired() || !currentConversation) {
          return false;
        }
        return (
          getScenarioAtTurn(currentConversation.history.length)?.type ===
          "selector"
        );
      },

      /** Scenarios are required for single-scenario or multi-scenario tasks. */
      getIsScenarioRequired: () => {
        const { taskData } = get();
        return Boolean(
          taskData?.require_scenario || taskData?.multiple_scenarios_config
        );
      },
    }),
    Object.is
  );

export default useInteractiveSideBySideStore;
