import React, { useEffect, useRef } from "react";
import { select } from "d3-selection";
import { axisTop, axisRight, axisBottom, axisLeft } from "d3-axis";
import { makeStyles } from "@material-ui/core";
import { COLORS } from "../../../../constants";
import { format } from "d3-format";

/**
 * Distance (in SVG user units) used to push tick labels and axis titles
 * away from the edges of the plotting area.
 */
const AXIS_OFFSET = 20;

// Typography applied to every <text> element inside an axis group.
const axisTextStyle = () => ({
  textRendering: "optimizeLegibility",
  stroke: "none",
  fill: COLORS.text,
  fontSize: "12px",
  fontWeight: "normal",
  lineHeight: "1.46",
  textAnchor: "middle",
});

// Hairline styling for the axis domain path and the tick / grid lines.
const axisStrokeStyle = () => ({
  stroke: COLORS.veryLightGrey,
  fill: "none",
  strokeWidth: "1px",
});

// The theme argument supplied by makeStyles is not needed for these rules.
const useStyles = makeStyles(() => {
  const tickStyle = { textAnchor: "start" };
  return {
    axis: {
      "& text": { ...axisTextStyle() },
      "& .tick": tickStyle,
      "& path, & line": { ...axisStrokeStyle() },
    },
  };
});

const Axes = ({
  xScale,
  yScale,
  height,
  width,
  showXAxisLabel = false,
  showYAxisLabel = false,
  xTickFormat = format(".0f"),
  yTickFormat = format(".0f"),
  xAxisLabel = "x-axis-label",
  yAxisLabel = "y-axis-label",
  showYAxis = true,
  showSimpleXAxis = false,
  isLastChart,
}) => {
  const classes = useStyles();
  const xAxisRef = useRef(null);
  const yAxisRef = useRef(null);

  useEffect(() => {
    showSimpleXAxis ? renderSimpleXAxis() : renderXAxis();
    renderXAxis();
    showYAxis && renderYAxis();
  });

  const renderXAxis = () => {
    const xAxis = select(xAxisRef.current);
    const xAxisTop = axisTop(xScale)
      .tickFormat(xTickFormat)
      .ticks(5)
      .tickSize(-height);
    xAxis.call(xAxisTop);
    xAxis.select(".domain").remove();
    xAxis.selectAll(".tick text").attr("dy", height + AXIS_OFFSET);
  };

  const renderSimpleXAxis = () => {
    const xAxis = select(xAxisRef.current);
    const xAxisBottom = axisBottom(xScale).tickFormat(xTickFormat).ticks(5);
    xAxis.call(xAxisBottom);
    xAxis.select(".domain").remove();
    xAxis.selectAll(".tick line").remove();
  };

  const renderYAxis = () => {
    //console.log("renderYAxis")
    const yAxis = select(yAxisRef.current);
    const yAxisRight = axisRight(yScale)
      .tickFormat(yTickFormat)
      .ticks(5)
      .tickSize(width);
    yAxis.call(yAxisRight);
    yAxis.select(".domain").remove();
    yAxis.selectAll(".tick text").attr("dx", -width - AXIS_OFFSET);
    //console.log("yAxis", yAxis)
  };

  return (
    <g className={classes.axis}>
      <g ref={xAxisRef}>
        {showXAxisLabel && (
          <text
            dy="1.5em"
            transform={`translate(${width / 2},${height + AXIS_OFFSET})`}
          >
            {xAxisLabel}
          </text>
        )}
      </g>
      <g ref={yAxisRef}>
        {showYAxisLabel && (
          <text
            dy="-1.5em"
            transform={`translate(${-AXIS_OFFSET},${height / 2}) rotate(-90)`}
          >
            {yAxisLabel}
          </text>
        )}
      </g>
    </g>
  );
};

export default Axes;
