import { createStore } from 'zustand';
import {
  Connection,
  Edge,
  EdgeChange,
  Node,
  NodeChange,
  addEdge,
  OnNodesChange,
  OnEdgesChange,
  OnConnect,
  applyNodeChanges,
  applyEdgeChanges,
} from 'reactflow';

/**
 * A request to replace the `data` payload of a single node.
 */
export interface OnNodeDataChange {
  /** Id of the node whose `data` is replaced. */
  id: string;
  /**
   * Replacement value for the node's `data`. Typed `any`, so it is not
   * checked against the node's own data type.
   */
  data: any;
}

/** An ordered batch of node data replacements, each wrapped in `change`. */
export type NodeDataChanges = { change: OnNodeDataChange }[];

/** Signature of a handler that accepts a batch of node data replacements. */
export type OnNodeDataChanges = (dataChanges: NodeDataChanges) => void;

/**
 * Handler that receives a complete node list. In `createFlowStore` it
 * overwrites the stored nodes with the given array.
 */
export type OnNodesSet<Data, U extends string | undefined> = (
  nextNodes: Array<Node<Data, U>>,
) => void;
/**
 * Handler that receives a complete edge list. In `createFlowStore` it
 * overwrites the stored edges with the given array.
 */
export type OnEdgesSet = (nextEdges: Array<Edge>) => void;

/**
 * State and actions held by the store that `createFlowStore` builds.
 *
 * `Data` and `U` are forwarded to React Flow's `Node<Data, U>` as the node
 * data type and the node type string respectively.
 */
export type RFState<Data, U extends string | undefined> = {
  /** Current node list. */
  nodes: Array<Node<Data, U>>;
  /** Current edge list. */
  edges: Array<Edge>;
  /** Handler for batches of React Flow `NodeChange`s. */
  onNodesChange: OnNodesChange;
  /** Handler for batches of React Flow `EdgeChange`s. */
  onEdgesChange: OnEdgesChange;
  /** Replaces the node list wholesale. */
  onNodesSet: OnNodesSet<Data, U>;
  /** Replaces the edge list wholesale. */
  onEdgesSet: OnEdgesSet;
  /** Handler for a new React Flow `Connection`. */
  onConnect: OnConnect;
  /** Replaces the `data` of the nodes named in the batch. */
  onNodesDataChange: (dataChanges: NodeDataChanges) => void;
};

/**
 * Replaces the `data` field of the nodes targeted by `changes`.
 *
 * Every node appears exactly once in the result, in its original order:
 * - A node with no matching change is returned as-is (same object reference).
 * - A node with a matching change is shallow-copied with `data` replaced.
 * - When several changes target the same node, the last one wins.
 * - Changes whose id matches no node are ignored.
 *
 * @param changes - Data replacements to apply.
 * @param nodes - Current nodes; neither the array nor its nodes are mutated.
 * @returns A new array with the same length and order as `nodes`.
 */
const applyNodeDataChanges = <Data, U extends string | undefined>(
  changes: NodeDataChanges,
  nodes: Node<Data, U>[],
): Node<Data, U>[] => {
  // Index once by node id so a later change overrides an earlier one, rather
  // than emitting one copy of the node per change (which duplicated node ids).
  const nextDataById = new Map<string, Data>();
  for (const { change } of changes) {
    nextDataById.set(change.id, change.data);
  }

  return nodes.map((node) =>
    // has() rather than a get() !== undefined check, so an explicit
    // `undefined` data value is still applied.
    nextDataById.has(node.id)
      ? { ...node, data: nextDataById.get(node.id) as Data }
      : node,
  );
};

/**
 * Creates a vanilla zustand store holding a React Flow graph.
 *
 * Actions behave as follows:
 * - The change-based actions (`onNodesChange`, `onNodesDataChange`,
 *   `onEdgesChange`, `onConnect`) derive the new slice from the current state.
 * - The `*Set` actions overwrite a slice outright.
 * - Each action writes only the `nodes` or `edges` key.
 *
 * @param initialNodes - Nodes the store starts with.
 * @param initialEdges - Edges the store starts with.
 * @returns A zustand store API whose state is `RFState<Data, U>`.
 */
export const createFlowStore = <Data, U extends string | undefined>(
  initialNodes: Node<Data, U>[],
  initialEdges: Edge[],
) =>
  createStore<RFState<Data, U>>((set) => ({
    nodes: initialNodes,
    edges: initialEdges,
    onNodesChange: (changes: NodeChange[]) => {
      set((state) => ({
        // The call is only parameterised on `Data`, so the result is cast
        // back to the store's `Node<Data, U>[]`.
        nodes: applyNodeChanges<Data>(changes, state.nodes) as Array<
          Node<Data, U>
        >,
      }));
    },
    onNodesDataChange: (changes: NodeDataChanges) => {
      set((state) => ({ nodes: applyNodeDataChanges(changes, state.nodes) }));
    },
    onEdgesChange: (changes: EdgeChange[]) => {
      set((state) => ({ edges: applyEdgeChanges(changes, state.edges) }));
    },
    onNodesSet: (nodes: Node<Data, U>[]) => {
      set({ nodes });
    },
    onEdgesSet: (edges: Edge[]) => {
      set({ edges });
    },
    onConnect: (connection: Connection) => {
      set((state) => ({ edges: addEdge(connection, state.edges) }));
    },
  }));
