#include <fstream>
#include <iostream>
#include <vector>
using namespace std;

// Returns true if the four cells starting at (x, y) and running to the right
// spell "XMAS" forwards or backwards ("SAMX").
//   lines : the grid, one string per row. It is taken by const reference; the
//           previous by-value parameter copied the whole grid on every call.
//   x, y  : column / row of the first cell of the candidate word
//   x_max : row width used for the bounds check
//   y_max : number of rows (unused here; kept so all check* share a signature)
bool checkRow(const vector<string> &lines, int x, int y, int x_max, int y_max) {
  // Need cells x .. x+3 to fit inside the row.
  if (x > x_max - 4) return false;

  if (lines[y][x] == 'X' && lines[y][x + 1] == 'M' && lines[y][x + 2] == 'A' && lines[y][x + 3] == 'S') return true;
  if (lines[y][x] == 'S' && lines[y][x + 1] == 'A' && lines[y][x + 2] == 'M' && lines[y][x + 3] == 'X') return true;
  return false;
}

// Returns true if the four cells starting at (x, y) and running downwards
// spell "XMAS" forwards or backwards ("SAMX").
//   lines : the grid, one string per row. It is taken by const reference; the
//           previous by-value parameter copied the whole grid on every call.
//   x, y  : column / row of the top cell of the candidate word
//   x_max : row width (unused here; kept so all check* share a signature)
//   y_max : number of rows used for the bounds check
bool checkCol(const vector<string> &lines, int x, int y, int x_max, int y_max) {
  // Need rows y .. y+3 to exist.
  if (y > y_max - 4) return false;

  if (lines[y][x] == 'X' && lines[y + 1][x] == 'M' && lines[y + 2][x] == 'A' && lines[y + 3][x] == 'S') return true;
  if (lines[y][x] == 'S' && lines[y + 1][x] == 'A' && lines[y + 2][x] == 'M' && lines[y + 3][x] == 'X') return true;
  return false;
}

// Returns true if the four cells starting at (x, y) and running diagonally
// down-and-left spell "XMAS" forwards or backwards ("SAMX").
//   lines : the grid, one string per row. It is taken by const reference; the
//           previous by-value parameter copied the whole grid on every call.
//   x, y  : column / row of the top-right cell of the candidate word
//   x_max : row width (unused here; kept so all check* share a signature)
//   y_max : number of rows used for the bounds check
bool checkDiagL(const vector<string> &lines, int x, int y, int x_max, int y_max) {
  // Need columns x-3 .. x and rows y .. y+3 to exist.
  if (x < 3 || y > y_max - 4) return false;

  if (lines[y][x] == 'X' && lines[y + 1][x - 1] == 'M' && lines[y + 2][x - 2] == 'A' && lines[y + 3][x - 3] == 'S') return true;
  if (lines[y][x] == 'S' && lines[y + 1][x - 1] == 'A' && lines[y + 2][x - 2] == 'M' && lines[y + 3][x - 3] == 'X') return true;
  return false;
}

// Returns true if the four cells starting at (x, y) and running diagonally
// down-and-right spell "XMAS" forwards or backwards ("SAMX").
//   lines : the grid, one string per row. It is taken by const reference; the
//           previous by-value parameter copied the whole grid on every call.
//   x, y  : column / row of the top-left cell of the candidate word
//   x_max : row width used for the bounds check
//   y_max : number of rows used for the bounds check
bool checkDiagR(const vector<string> &lines, int x, int y, int x_max, int y_max) {
  // Need columns x .. x+3 and rows y .. y+3 to exist.
  if (x > x_max - 4 || y > y_max - 4) return false;

  if (lines[y][x] == 'X' && lines[y + 1][x + 1] == 'M' && lines[y + 2][x + 2] == 'A' && lines[y + 3][x + 3] == 'S') return true;
  if (lines[y][x] == 'S' && lines[y + 1][x + 1] == 'A' && lines[y + 2][x + 2] == 'M' && lines[y + 3][x + 3] == 'X') return true;
  return false;
}

// Returns true if (x, y) is the centre of an "X-MAS": an 'A' whose two
// diagonals each read "MAS" in either direction, e.g.
//   M . S
//   . A .
//   M . S
//   lines : the grid, one string per row (const reference; no per-call copy)
//   x, y  : column / row of the candidate centre cell
//   x_max : row width used for the bounds check
//   y_max : number of rows used for the bounds check
bool checkCross(const vector<string> &lines, int x, int y, int x_max, int y_max) {
  // Bounds first: the centre and all four corners must lie inside the grid
  // before any cell is read. (The old version read lines[y][x] before checking.)
  if (y < 1 || x < 1 || y > y_max - 2 || x > x_max - 2) return false;
  if (lines[y][x] != 'A') return false;

  // A diagonal through the centre spells "MAS" (either way) exactly when its
  // two end cells are 'M' and 'S' in some order.
  auto isMandS = [](char a, char b) {
    return (a == 'M' && b == 'S') || (a == 'S' && b == 'M');
  };

  // top-left to bottom-right
  const bool mainDiag = isMandS(lines[y - 1][x - 1], lines[y + 1][x + 1]);
  // top-right to bottom-left
  const bool antiDiag = isMandS(lines[y - 1][x + 1], lines[y + 1][x - 1]);

  return mainDiag && antiDiag;
}

// Part 1: counts every occurrence of "XMAS" in the grid stored at `path`,
// horizontally, vertically and along both diagonals, in either direction.
//   path   : file containing the grid, one row per line
//   return : number of matches; 0 if the file cannot be opened or is empty
int part1(const char *path) {
  ifstream input(path);
  if (!input) {
    cerr << "part1: cannot open " << path << '\n';
    return 0;
  }

  // parse input as 2d array
  string line;
  vector<string> lines;
  while (getline(input, line)) {
    lines.push_back(line);
  }
  // Guard before touching lines[0]: indexing an empty vector is UB.
  if (lines.empty()) return 0;

  // Width comes from the first row; rows are assumed to have equal length.
  const int x_max = static_cast<int>(lines[0].size());
  const int y_max = static_cast<int>(lines.size());

  int sum = 0;

  for (int y = 0; y < y_max; y++) {
    for (int x = 0; x < x_max; x++) {
      // Each check scans one direction only (right, down, down-left,
      // down-right) and accepts both "XMAS" and "SAMX", so every occurrence
      // is counted once, at its first cell in that scan direction.
      sum += checkRow(lines, x, y, x_max, y_max) ? 1 : 0;
      sum += checkCol(lines, x, y, x_max, y_max) ? 1 : 0;
      sum += checkDiagL(lines, x, y, x_max, y_max) ? 1 : 0;
      sum += checkDiagR(lines, x, y, x_max, y_max) ? 1 : 0;
    }
  }
  return sum;
}

// Part 2: counts every "X-MAS" (an 'A' with "MAS" running through it along
// both diagonals, in either direction) in the grid stored at `path`.
//   path   : file containing the grid, one row per line
//   return : number of matches; 0 if the file cannot be opened or is empty
int part2(const char *path) {
  ifstream input(path);
  if (!input) {
    cerr << "part2: cannot open " << path << '\n';
    return 0;
  }

  // parse input as 2d array
  string line;
  vector<string> lines;
  while (getline(input, line)) {
    lines.push_back(line);
  }
  // Guard before touching lines[0]: indexing an empty vector is UB.
  if (lines.empty()) return 0;

  // Width comes from the first row; rows are assumed to have equal length.
  const int x_max = static_cast<int>(lines[0].size());
  const int y_max = static_cast<int>(lines.size());

  int sum = 0;

  // Every cell is tried as a centre; checkCross rejects border cells itself.
  for (int y = 0; y < y_max; y++) {
    for (int x = 0; x < x_max; x++) {
      sum += checkCross(lines, x, y, x_max, y_max) ? 1 : 0;
    }
  }
  return sum;
}

// Runs both puzzle parts, first on the "example" file and then on "input",
// printing one labelled result per line.
int main(int argc, char *argv[]) {
  struct Job {
    const char *label;
    int (*solve)(const char *);
    const char *file;
  };

  // Order matters: it fixes the order of the printed lines.
  const Job jobs[] = {
      {"Example1: ", part1, "example"},
      {"Input1: ", part1, "input"},
      {"Example2: ", part2, "example"},
      {"Input2: ", part2, "input"},
  };

  for (const Job &job : jobs) {
    cout << job.label << job.solve(job.file) << endl;
  }
  return 0;
}