#!/usr/bin/env python3
from beancount import loader
from beancount.query import query
from beancount.core.data import Custom
from beancount.core.amount import Amount, add, sub
from beancount.parser import printer
import argparse
from datetime import date
from dateutil.relativedelta import relativedelta
from tabulate import tabulate
from decimal import Decimal
from functools import reduce


class bcolors:
    """ANSI SGR escape sequences used to colour terminal output.

    Wrap text as ``f"{bcolors.FAIL}text{bcolors.ENDC}"``; ``ENDC`` (SGR 0)
    resets every attribute set before it.
    """

    # Bright foreground colours (SGR 91-96).
    FAIL = '\033[91m'       # bright red
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    OKBLUE = '\033[94m'     # bright blue
    HEADER = '\033[95m'     # bright magenta
    OKCYAN = '\033[96m'     # bright cyan

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset all attributes.
    ENDC = '\033[0m'


def get_budget_entries(entries, period, start_date):
    """Collect the budget in effect for each account as of *start_date*.

    Considers ``Custom`` directives whose values are laid out as
    ``<account> <period> <amount>``, whose period equals *period*, and whose
    date is on or before *start_date*. When an account has several such
    directives, the one with the latest date wins, so a budget change takes
    effect from its own date onwards.

    Args:
        entries: Parsed beancount directives.
        period: Budget period to select, e.g. ``"monthly"`` or ``"yearly"``.
        start_date: ISO-8601 date string of the report start.

    Returns:
        A list of dicts with keys ``date``, ``account``, ``period`` and
        ``budget`` (one per account), ordered by each account's first
        qualifying directive.
    """
    cutoff = date.fromisoformat(start_date)
    latest = {}
    for entry in entries:
        if not isinstance(entry, Custom):
            continue
        values = entry.values
        # Other custom directives may carry fewer than three values; skip them
        # instead of raising IndexError.
        if len(values) < 3 or values[1].value != period or entry.date > cutoff:
            continue
        account = values[0].value
        previous = latest.get(account)
        # Keep the most recent directive; re-assigning an existing key keeps
        # the account's original position in the dict (insertion order).
        if previous is None or entry.date >= previous["date"]:
            latest[account] = {
                "date": entry.date,
                "account": account,
                "period": values[1].value,
                "budget": values[2].value
            }
    return list(latest.values())


def get_equity_amounts(entries, options, period, start_date):
    """Sum the Miquel equity accounts over one budget period.

    The window starts at *start_date* (inclusive) and ends one month later
    when *period* is ``"monthly"``, otherwise one year later (exclusive).

    Returns:
        Dict mapping each matching Equity account name to the
        ``sum_position`` value the query returned for it.
    """
    if period == "monthly":
        window = relativedelta(months=1)
    else:
        window = relativedelta(years=1)
    window_end = date.fromisoformat(start_date) + window

    # One clause per line, joined with newlines, exactly as sent to BQL.
    equity_query = "\n".join((
        "SELECT account, sum(position) FROM",
        f"date >= {start_date} AND",
        f"date < {window_end.isoformat()}",
        'WHERE account ~ "Equity:(LloguerMiquel|FacturesUtilitatsMiquel)"',
    ))
    _, rows = query.run_query(entries, options, equity_query)
    return {row.account: row.sum_position for row in rows}


def get_expenses(entries, options, period, start_date):
    """Sum postings per Expenses/Liabilities account over one budget period.

    The window starts at *start_date* (inclusive) and ends one month later
    when *period* is ``"monthly"``, otherwise one year later (exclusive).

    Returns:
        Dict mapping account name (any account matching ``Expenses`` or
        ``Liabilities``) to the ``sum_position`` value the query returned.
    """
    first_day = date.fromisoformat(start_date)
    if period == "monthly":
        first_day_after = first_day + relativedelta(months=1)
    else:
        first_day_after = first_day + relativedelta(years=1)

    clauses = [
        "SELECT account, sum(position) FROM",
        f"date >= {start_date} AND",
        f"date < {first_day_after.isoformat()}",
        'WHERE account ~ "Expenses" OR account ~ "Liabilities"',
    ]
    result = query.run_query(entries, options, "\n".join(clauses))
    rows = result[1]
    return {row.account: row.sum_position for row in rows}


def build_budget(budget_entries, expenses, equity_amounts, total_positive_expenses):
    """Build the per-account rows of the budget report.

    Args:
        budget_entries: Dicts with at least ``account`` and ``budget`` (a
            beancount ``Amount``) keys.
        expenses: Mapping of account name to the inventory summed over the
            report period.
        equity_amounts: Mapping of equity account name to its inventory
            summed over the report period.
        total_positive_expenses: ``Amount`` used as the denominator of the
            "Total (%)" column.

    Returns:
        A list of row dicts keyed by column header, suitable for ``tabulate``.
    """
    # Expense accounts whose period balance is reduced by the balance of the
    # paired Equity account before being compared with the budget.
    equity_offsets = {
        "Expenses:Lloguer": "Equity:LloguerMiquel",
        "Expenses:FacturesUtilitats": "Equity:FacturesUtilitatsMiquel",
    }
    result = []
    for entry in budget_entries:
        account = entry["account"]
        budget = entry["budget"]
        expense = Amount(Decimal(0), budget.currency)
        expense_perc = 0
        total_perc = 0
        remaining = budget
        if account in expenses:
            # get_only_position() handles single-position inventories only and
            # may return None for an inventory that netted to zero; in that
            # case the expense stays at zero.
            position = expenses[account].get_only_position()
            if position is not None:
                expense = position.units

            equity_account = equity_offsets.get(account)
            if equity_account is not None and equity_account in equity_amounts:
                equity_position = equity_amounts[equity_account].get_only_position()
                if equity_position is not None:
                    expense = sub(expense, equity_position.units)

            # A zero budget would raise a decimal division error; leave 0%.
            if budget.number:
                expense_perc = (expense.number / budget.number) * 100
            # Calculate percentage of total positive expenses only
            if total_positive_expenses.number > 0 and expense.number > 0:
                total_perc = (expense.number /
                              total_positive_expenses.number) * 100
            remaining = sub(remaining, expense)
        result.append({
            "Account": account,
            "Budget": budget.to_string(),
            "Expense": expense,
            "Expense (%)": "{}{:,.2f}%{}".format(
                bcolors.FAIL if expense_perc >= 100 else '',
                expense_perc, bcolors.ENDC),
            "Total (%)": "{:,.2f}%".format(total_perc),
            "Remaining": remaining
        })
    return result


def print_report(budget_report, period, start_date, budget_sum, expenses_sum):
    """Print the report header, the totals and the per-account table.

    Args:
        budget_report: Row dicts keyed by column header.
        period: Budget period shown in the header.
        start_date: Report start date shown in the header.
        budget_sum: Total budget ``Amount``.
        expenses_sum: Total expense ``Amount``; printed in red when it reaches
            or exceeds *budget_sum*.
    """
    print(f"Budget Report (period={period}, start_date={start_date})")
    print(f"Budget: {budget_sum}")
    # Both totals are assumed to share one currency, so comparing the numbers
    # is the intended check. Keeping the f-string on one line also keeps the
    # file parseable before Python 3.12, where a newline inside a replacement
    # field is a SyntaxError.
    over_budget = expenses_sum.number >= budget_sum.number
    color = bcolors.FAIL if over_budget else ''
    print(f"{color}Expenses: {expenses_sum}{bcolors.ENDC}")
    print(tabulate(budget_report, headers="keys", numalign="right", floatfmt=".2f"))


def main(argv=None):
    """Parse the command line, load the ledger and print the budget report.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            keeps argparse's default behaviour.
    """
    parser = argparse.ArgumentParser(description='Generate budget report')
    parser.add_argument('start_date', metavar='start_date', type=str, nargs=1,
                        help='Start date (end date will be one month after if '
                        'monthly report or one year after if yearly report)')
    parser.add_argument('-p', metavar='period', type=str,
                        choices=["monthly", "yearly"], default="monthly",
                        required=False,
                        help='Period (monthly or yearly)')
    parser.add_argument('-f', '--file', metavar='file', type=str,
                        default="ledger/main.beancount",
                        help='Beancount ledger to load (default: %(default)s)')

    args = parser.parse_args(argv)
    start_date = args.start_date[0]
    period = args.p

    entries, errors, options = loader.load_file(args.file)

    if errors:
        printer.print_errors(errors)

    budget_entries = get_budget_entries(entries, period, start_date)
    if not budget_entries:
        # Without any budget there is no currency to total in; exit cleanly
        # instead of failing with IndexError on budget_entries[0].
        parser.exit(1, f"No {period} budget entries found on or before "
                       f"{start_date}\n")

    # TODO: Multiple currencies
    currency = budget_entries[0]["budget"].currency
    budget_sum = reduce(
        lambda total, budget_entry: add(total, budget_entry["budget"]),
        budget_entries,
        Amount(Decimal(0), currency)
    )
    expenses = get_expenses(entries, options, period, start_date)
    equity_amounts = get_equity_amounts(entries, options, period, start_date)

    # Expense accounts whose balance is reduced by the paired Equity account
    # (mirrors the deduction applied per row of the report).
    equity_offsets = {
        "Expenses:Lloguer": "Equity:LloguerMiquel",
        "Expenses:FacturesUtilitats": "Equity:FacturesUtilitatsMiquel",
    }

    # Calculate total of positive expenses only for percentage calculation
    positive_expenses_sum = Amount(Decimal(0), currency)
    for entry in budget_entries:
        account = entry["account"]
        if account not in expenses:
            continue
        # get_only_position() may return None for an inventory that netted to
        # zero; such an account contributes nothing to the positive total.
        position = expenses[account].get_only_position()
        if position is None:
            continue
        expense = position.units

        equity_account = equity_offsets.get(account)
        if equity_account is not None and equity_account in equity_amounts:
            equity_position = equity_amounts[equity_account].get_only_position()
            if equity_position is not None:
                expense = sub(expense, equity_position.units)

        if expense.number > 0:
            positive_expenses_sum = add(positive_expenses_sum, expense)

    budget_report = build_budget(
        budget_entries, expenses, equity_amounts, positive_expenses_sum)
    print_report(budget_report, period, start_date,
                 budget_sum, positive_expenses_sum)


# Run the CLI only when executed as a script, so importing this module (for
# reuse or tests) does not parse sys.argv and load the ledger.
if __name__ == "__main__":
    main()
