import { Message } from "src/types/models";
import { createWithEqualityFn } from "zustand/traditional";

import {
  ConversationState,
  TaskData,
  TaskState,
  StaticConversation,
} from "../types";
import { getTaskState } from "../utils";

/**
 * Shape of the static side-by-side (SxS) task store: the loaded task data,
 * the task's workflow state, and the actions that read or update them.
 */
interface StaticSxsStateSlice {
  // ---- State ----
  /** Current step of the task workflow. */
  taskState: TaskState;
  /** Loaded task payload; `null` until task data has been stored. */
  taskData: TaskData | null;

  // ---- Setters ----
  /** Overwrites `taskState` directly. */
  setTaskState: (taskState: TaskState) => void;
  /** Stores `taskData` and re-derives `taskState` from it. */
  setTaskData: (taskData: TaskData) => void;

  // ---- Conversation actions ----
  /** Marks the conversation with the given id as completed. */
  endConversation: (conversationId: number) => void;
  /** Appends a conversation to the end of the task's conversation list. */
  addConversation: (conversation: StaticConversation) => void;
  /** Replaces the current (first non-completed) conversation. */
  skipConversation: (conversation: StaticConversation) => void;
  /** Appends a message to the current conversation's context. */
  addMessage: (message: Message) => void;

  // ---- Queries ----
  /** Index of the first non-completed conversation, or `null` if none. */
  getCurrentConversationIndex: () => number | null;
  /** Whether the task has reached `TaskState.FINISHED`. */
  isFinished: () => boolean;
}

/**
 * Zustand store backing a static side-by-side (SxS) task.
 *
 * `taskData` holds the task's conversations, and `taskState` is re-derived
 * from it via `getTaskState` every time `setTaskData` runs. Every action
 * below builds new `taskData` / `convos` / conversation objects instead of
 * editing existing ones in place. This lets selectors compared with
 * `Object.is` observe each change, and it leaves objects held by earlier
 * state snapshots untouched.
 */
const useStaticSxsStore = createWithEqualityFn<StaticSxsStateSlice>(
  (set, get) => ({
    taskState: TaskState.NEED_CONFIDENTIALITY_AGREEMENT,
    taskData: null,

    /** Overwrites the task state without touching `taskData`. */
    setTaskState: (taskState: TaskState) => set({ taskState }),

    /** Stores new task data and recomputes `taskState` from it. */
    setTaskData: (taskData: TaskData) =>
      set({ taskData, taskState: getTaskState(taskData) }),

    /**
     * Marks the conversation whose `id` equals `conversationId` as
     * COMPLETED. No-op when no task data is loaded.
     */
    endConversation: (conversationId: number) => {
      const { taskData, setTaskData } = get();
      if (!taskData) {
        return;
      }
      // Copy the matching conversation rather than assigning to it: the
      // original object is still referenced by the previous state snapshot.
      const convos = taskData.convos.map((conversation) =>
        conversation.id === conversationId
          ? { ...conversation, state: ConversationState.COMPLETED }
          : conversation
      );
      setTaskData({ ...taskData, convos });
    },

    /**
     * Index of the first conversation that is not COMPLETED, or `null` when
     * there is no task data or every conversation is completed.
     */
    getCurrentConversationIndex: () => {
      const { taskData } = get();
      if (!taskData) {
        return null;
      }
      const index = taskData.convos.findIndex(
        (conversation) => conversation.state !== ConversationState.COMPLETED
      );
      return index > -1 ? index : null;
    },

    /** True once `taskState` is FINISHED. */
    isFinished: () => get().taskState === TaskState.FINISHED,

    /**
     * Appends `message` to the current conversation's `context`. No-op when
     * there is no task data or no current conversation.
     */
    addMessage: (message: Message) => {
      const { taskData, getCurrentConversationIndex, setTaskData } = get();
      const currentConversationIndex = getCurrentConversationIndex();
      if (!taskData || currentConversationIndex === null) {
        return;
      }
      // Build a new context array instead of push(): the old array is shared
      // with the previous state, so mutating it would hide the change from
      // Object.is-based selectors.
      const convos = taskData.convos.map((conversation, index) =>
        index === currentConversationIndex
          ? { ...conversation, context: [...conversation.context, message] }
          : conversation
      );
      setTaskData({ ...taskData, convos });
    },

    /** Appends `conversation` to the end of the list. No-op without task data. */
    addConversation: (conversation: StaticConversation) => {
      const { taskData, setTaskData } = get();
      if (!taskData) {
        return;
      }
      setTaskData({
        ...taskData,
        convos: [...taskData.convos, conversation],
      });
    },

    /**
     * Replaces the current conversation with `replacement`. No-op when there
     * is no task data or no current conversation.
     */
    skipConversation: (replacement: StaticConversation) => {
      const { taskData, getCurrentConversationIndex, setTaskData } = get();
      const currentConversationIndex = getCurrentConversationIndex();
      if (!taskData || currentConversationIndex === null) {
        return;
      }
      const convos = taskData.convos.map((conversation, index) =>
        index === currentConversationIndex ? replacement : conversation
      );
      setTaskData({ ...taskData, convos });
    },
  }),
  Object.is
);

export default useStaticSxsStore;
