import logging
from dataclasses import dataclass
from typing import Optional, List, Callable

import pandas as pd

from pydantic import ValidationError, BaseModel
from sqlalchemy import create_engine, Engine
from sqlalchemy.orm import sessionmaker, Session

import nsys_cpu_stats.trace_utils as tu
import server.app.routes.report as local_report_routes
import server.app.routes.hotspot_analysis as local_hotspot_analysis_routes
import server.app.routes.threading_analysis as local_threading_analysis_routes

# pylint: disable-next=unused-import
import server.app.models.models  # noqa: F401
from server.app.models.base import Base
from server.app.models.hotspot_analysis.frame_times import FrameTimesCreate
from server.app.models.hotspot_analysis.hotspot_analysis import HotspotAnalysisCreate, HotspotAnalysisCustomFieldCreate, HotspotAnalysisCustomChartRowCreate
from server.app.models.hotspot_analysis.region import RegionCreate
from server.app.models.hotspot_analysis.region_thread import RegionThreadCreate
from server.app.models.report import ReportCreate
from server.app.models.threading_analysis.application_stats import ApplicationStatsCreate
from server.app.models.threading_analysis.cpu_info import CpuInfoCreate
from server.app.models.threading_analysis.health_check import HealthCheckMetricType, HealthCheckCreate
from server.app.models.threading_analysis.process_utilization import ProcessUtilisationCreate
from server.app.models.threading_analysis.thread_concurrency import ThreadConcurrencyCreate
from server.app.models.threading_analysis.thread_time import ThreadTimeCreate
from server.app.models.threading_analysis.threading_analysis import ThreadingAnalysisCreate

logger = logging.getLogger(__name__)


@dataclass
class DatabaseIds:
    """Ids under which one exported entity is stored, one per target database."""
    # Id in the remote database. NOTE(review): no code visible in this file
    # assigns it, so it appears to stay None -- confirm whether remote export
    # still exists elsewhere.
    remote: Optional[int] = None
    # Id in the local database (the ``id`` attribute of the object returned by
    # the local route call).
    local: Optional[int] = None


class ReportExporter:
    """Exports analysis results as pydantic models into a local SQLite database.

    The models are written through the server route functions
    (``server.app.routes.*``) using a SQLAlchemy session.
    """
    # SQLAlchemy URL (``sqlite:///<file>``) of the local database.
    local_database_url: Optional[str]

    # Set to True by __init__ when a local database file was supplied.
    is_local_db_available = False

    # Ids of the report created by create_remote_report(); None until then.
    # The other export methods return None while this is unset.
    database_report_ids: Optional[DatabaseIds] = None

    # Session factory created by __setup_local_database_connection();
    # stays None when no local database is configured.
    SessionMaker = None

    def __init__(self, local_database_file: Optional[str] = None):
        """Create an exporter, optionally bound to a local SQLite database file.

        Args:
            local_database_file: Path of the SQLite file. When omitted (or
                empty) no local database is configured.
        """
        # Always define the attribute so that reading it never raises
        # AttributeError on an exporter created without a database file.
        self.local_database_url = None
        if local_database_file:
            self.local_database_url = f'sqlite:///{local_database_file}'
            self.is_local_db_available = True

        # No-op when no local database file was given.
        self.__setup_local_database_connection()


    def __setup_local_database_connection(self):
        """Create the engine, the schema and the session factory for the local database.

        Does nothing when no local database is available.
        """
        if self.is_local_db_available:
            db_engine: Engine = create_engine(
                self.local_database_url,
                connect_args={"check_same_thread": False},
                pool_size=30,
                max_overflow=-1,
            )
            # Create every table registered on the declarative Base.
            Base.metadata.create_all(bind=db_engine)

            self.SessionMaker = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    def __export_model(self,
                       model: BaseModel,
                       local_call: Callable[..., object],
                       parent_id_field_name: Optional[str] = None,
                       parent_ids: Optional[DatabaseIds] = None,
                       entity_ids: Optional[DatabaseIds] = None
                       ) -> DatabaseIds:
        """Store ``model`` in the local database by calling ``local_call``.

        Args:
            model: Pydantic model to store.
            local_call: Route function, invoked as
                ``local_call([entity_id,] model, session)``.
            parent_id_field_name: Name of the field on ``model`` that
                references the parent entity.
            parent_ids: Ids of the parent entity. When given together with
                ``parent_id_field_name``, the local id is written into that
                field on a copy of ``model``.
            entity_ids: Ids of an already stored entity. When given, the local
                id is passed as the first positional argument of
                ``local_call``.

        Returns:
            A ``DatabaseIds`` whose ``local`` field is the ``id`` attribute of
            the object returned by ``local_call``. It is None when no local
            database is configured, when the call failed, or when the returned
            object has no ``id``. ``remote`` is never set here.
        """
        result: DatabaseIds = DatabaseIds()

        # Without a local database there is no session factory; calling it
        # would raise TypeError, so leave early with empty ids.
        if not self.is_local_db_available or self.SessionMaker is None:
            return result

        model_local = model
        if parent_ids is not None and parent_id_field_name:
            model_local = model.model_copy(update={parent_id_field_name: parent_ids.local})

        args_local: list = []
        if entity_ids is not None:
            args_local.append(entity_ids.local)
        args_local.append(model_local)

        # The context manager closes the session (releasing its connection)
        # on both the success and the failure path.
        with self.SessionMaker() as session:
            try:
                local_object_in_db = local_call(*args_local, session)
                # Read the id while the session is still open, in case the
                # attribute has to be refreshed from the database.
                result.local = getattr(local_object_in_db, 'id', None)
            except Exception as e:  # pylint: disable=broad-except
                # Best effort: a failed local export must not abort the caller.
                call_name = getattr(local_call, '__name__', repr(local_call))
                logger.exception("Error during creation local object: %s ; %s", call_name, e)

        return result

    def create_remote_report(self, meta_info: tu.SourceMetaInfo) -> Optional[DatabaseIds]:
        """Create the report entry from ``meta_info`` and remember its ids.

        Args:
            meta_info: Source meta information the report fields are read from.

        Returns:
            The ids of the exported report, or None when the ``ReportCreate``
            model fails validation.

        Raises:
            ValueError: If a report has already been created by this exporter.
        """
        if self.database_report_ids is not None:
            raise ValueError("Report already exists")

        try:
            report_fields = {
                'name': meta_info.report_name,
                # presumably falls back to -1 when the value cannot be converted -- confirm in tu.safe_type
                'gtl_dq_job_id': tu.safe_type(meta_info.gtl_dq_job_id, int, -1),
                'gtl_application_id': tu.safe_type(meta_info.gtl_application_id, int, -1),
                'gtl_application_name': meta_info.gtl_application_name,
                'fc_job_id': meta_info.fc_job_id,
                'fc_task_id': meta_info.fc_task_id,
                'nsys_source_resolved_gtlfs_uuid': meta_info.report_resolved_nsys_id,
                'nsys_source_resolved_gtlfs_path': meta_info.report_resolved_nsys_path,
                'nsys_source_gtlfs_uuid': meta_info.report_source_nsys_id,
                'nsys_source_gtlfs_path': meta_info.report_source_nsys_path,
                'nsys_source_sqlite_gtlfs_uuid': meta_info.report_source_sqlite_id,
                'nsys_source_sqlite_gtlfs_path': meta_info.report_source_sqlite_path,
            }
            report_create = ReportCreate(**report_fields)
        except ValidationError as e:
            logger.error(f"Error during creation remote report validation: {e}")
            return None

        report_ids = self.__export_model(report_create, local_report_routes.create_report)
        self.database_report_ids = report_ids
        return report_ids

    def export_threading_analysis(self, df_dict: dict[str, pd.DataFrame]) -> Optional[DatabaseIds]:
        """Build a ``ThreadingAnalysisCreate`` from report dataframes and export it.

        Args:
            df_dict: Dataframes keyed by table name. ``cpu_info`` and
                ``health_check`` are read unconditionally; the other tables
                are optional and treated as empty when missing.

        Returns:
            The ids of the exported threading analysis, or None when no report
            has been created yet.
        """
        if self.database_report_ids is None:
            return None

        # TODO: get source data classes instead of pandas dataframes
        # TODO: handle validation errors

        def column_as_dict(table_name: str, index_column: str, value_column: str) -> dict:
            # Map index_column -> value_column of an optional table ({} when absent).
            if table_name not in df_dict:
                return {}
            return df_dict[table_name].set_index(index_column)[value_column].to_dict()

        threading_stats = column_as_dict('threading_stats', 'Statistics', 'Values')

        def int_stat(stat_name: str):
            return tu.safe_type(threading_stats.get(stat_name, 0), int)

        def float_stat(stat_name: str):
            return tu.safe_float(threading_stats.get(stat_name, 0))

        application_stats = ApplicationStatsCreate(
            process_name=threading_stats.get('process_name', ''),
            total_thread_count=int_stat('total_thread_count'),
            active_thread_count=int_stat('active_thread_count'),
            job_thread_count=int_stat('job_thread_count'),
            serial_thread_count=int_stat('serial_thread_count'),
            concurrency_of_active_threads=int_stat('concurrency_of_active_threads'),
            start_time_s=float_stat('start_time_(s)'),
            end_time_s=float_stat('end_time_(s)'),
            average_cpu_frametime_ms=float_stat('average_cpu_frametime_(ms)'),
            average_gpu_frametime_ms=float_stat('average_gpu_frametime_(ms)'),
            cpu_idle_pct=float_stat('cpu_idle_(%)'),
            gpu_idle_pct=float_stat('gpu_idle_(%)'),
            total_thread_utilisation_pct=float_stat('total_thread_utilisation_(%)'),
            busiest_thread_util_pct=float_stat('busiest_thread_util_(%)'),
            median_job_thread_util_pct=float_stat('median_job_thread_util_(%)'),
            serial_work_util_pct=float_stat('serial_work_util_(%)'),
            one_thread_active_pct=float_stat('one_thread_active_(%)'),
            two_threads_active_pct=float_stat('two_threads_active_(%)'),
        )

        cpu_info_values = df_dict['cpu_info'].set_index('Statistics')['Values'].to_dict()
        cpu_info = CpuInfoCreate.model_validate(cpu_info_values)

        # The metric type is derived from the suffix of the description text;
        # anything without a known suffix is treated as a float.
        metric_type_by_suffix = (
            ('(bool)', HealthCheckMetricType.BOOL),
            ('(%)', HealthCheckMetricType.PERCENT),
            ('Count', HealthCheckMetricType.INTEGER),
        )

        def get_metric_type(description: str) -> HealthCheckMetricType:
            for suffix, metric_type in metric_type_by_suffix:
                if description.endswith(suffix):
                    return metric_type
            return HealthCheckMetricType.FLOAT

        health_check: List[HealthCheckCreate] = [
            HealthCheckCreate(
                description=check_row.get('Description', ''),
                metric_value=tu.safe_float(check_row.get('Health Metric', ''), 0, check_bool=True),
                metric_type=get_metric_type(check_row.get('Description', '')),
                warning=check_row.get('Warning', ''),
                flag=check_row.get('Flag', '')
            )
            for _, check_row in df_dict['health_check'].iterrows()
        ]

        process_utilisation: List[ProcessUtilisationCreate] = []
        if 'process_utilisation' in df_dict:
            process_utilisation = [
                ProcessUtilisationCreate(
                    process=process_row['Processes'],
                    utilisation_pct=tu.safe_float(process_row['Utilisation (%)'], 0)
                )
                for _, process_row in df_dict['process_utilisation'].iterrows()
            ]

        concurrency_pct_by_count = column_as_dict('thread_concurrency', 'Number of Threads', 'Concurrency (%)')
        concurrency_time_by_count = column_as_dict('thread_concurrency_time', 'Number of Threads', 'Concurrency Time (ms)')
        # Only thread counts present in the percentage table are exported.
        thread_concurrency: List[ThreadConcurrencyCreate] = [
            ThreadConcurrencyCreate(
                thread_index=tu.safe_type(thread_count, int, 0),
                concurrency_time_ms=tu.safe_float(concurrency_time_by_count.get(thread_count), 0),
                concurrency_pct=tu.safe_float(concurrency_pct_by_count.get(thread_count), 0)
            )
            for thread_count in concurrency_pct_by_count
        ]

        time_by_thread = column_as_dict('thread_time', 'Threads', 'Time (ms)')
        utilisation_by_thread = column_as_dict('thread_utilisation', 'Threads', 'Utilisation (%)')
        per_core_utilisation_by_thread = {}
        if 'cpu_thread_utilisation' in df_dict:
            # One row per thread, one column per 'P/E Core' value.
            pivoted = df_dict['cpu_thread_utilisation'].pivot_table(
                index='Threads', columns='P/E Core', values='Utilisation (%)', aggfunc='first'
            )
            pivoted = pivoted.reset_index()
            pivoted.columns.name = None
            per_core_utilisation_by_thread = (
                pivoted.rename_axis(None, axis=1).set_index('Threads').to_dict(orient='index')
            )

        # Every thread mentioned in any of the three tables gets an entry.
        thread_names = set(time_by_thread) | set(utilisation_by_thread) | set(per_core_utilisation_by_thread)
        thread_time: List[ThreadTimeCreate] = sorted(
            (
                ThreadTimeCreate(
                    name=thread_name,
                    time_ms=tu.safe_float(time_by_thread.get(thread_name, 0), 0),
                    utilisation_pct=tu.safe_float(utilisation_by_thread.get(thread_name, 0), 0),
                    utilisation_p_core_pct=tu.safe_float(
                        per_core_utilisation_by_thread.get(thread_name, {}).get('P Core', None), None),
                    utilisation_e_core_pct=tu.safe_float(
                        per_core_utilisation_by_thread.get(thread_name, {}).get('E Core', None), None)
                )
                for thread_name in thread_names
            ),
            key=lambda entry: entry.name
        )

        threading_analysis = ThreadingAnalysisCreate(
            # Placeholder; replaced with the real report id during export.
            report_id=0,
            application_stats=application_stats,
            cpu_info=cpu_info,
            health_check=health_check,
            process_utilisation=process_utilisation,
            thread_concurrency=thread_concurrency,
            thread_time=thread_time
        )
        return self.__export_model(
            threading_analysis,
            local_threading_analysis_routes.create_threading_analysis,
            'report_id',
            self.database_report_ids
        )

    def create_remote_hotspot_region(self, region_create: RegionCreate, hotspot_analysis_ids: Optional[DatabaseIds]) -> Optional[DatabaseIds]:
        """Export a hotspot region that belongs to the given hotspot analysis.

        Returns the ids of the exported region, or None when no report has
        been created yet.
        """
        report_exists = self.database_report_ids is not None
        if not report_exists:
            return None
        return self.__export_model(
            model=region_create,
            local_call=local_hotspot_analysis_routes.create_region,
            parent_id_field_name='hotspot_analysis_id',
            parent_ids=hotspot_analysis_ids,
        )

    def patch_remote_hotspot_region(self, region_ids: Optional[DatabaseIds], region_patch: RegionCreate.model_as_partial()):
        """Apply ``region_patch`` to the already exported region identified by ``region_ids``.

        Does nothing when no report has been created yet.
        """
        if not self.database_report_ids:
            return
        self.__export_model(
            region_patch,
            local_hotspot_analysis_routes.patch_region,
            entity_ids=region_ids,
        )

    def create_remote_hotspot_thread(self, thread_create: RegionThreadCreate, region_ids: Optional[DatabaseIds]) -> Optional[DatabaseIds]:
        """Export a region thread that belongs to the given region.

        Returns the ids of the exported thread, or None when no report has
        been created yet.
        """
        report_exists = self.database_report_ids is not None
        if not report_exists:
            return None
        return self.__export_model(
            model=thread_create,
            local_call=local_hotspot_analysis_routes.create_region_thread,
            parent_id_field_name='region_id',
            parent_ids=region_ids,
        )

    @tu.timeit
    def create_remote_hotspot_analysis(self,
                                       region_overview: dict,
                                       capture_info: dict,
                                       cpu_frametimes: Optional[pd.DataFrame],
                                       gpu_frametimes: Optional[pd.DataFrame],
                                       sorted_gpumetric_regions_values: Optional[pd.DataFrame],
                                       sorted_gpumetric_regions_durations: Optional[pd.DataFrame],
                                       app_health_gpu_bound: Optional[bool],
                                       report_range: str,
                                       report_threading: str
                                       ) -> Optional[DatabaseIds]:
        """Build a ``HotspotAnalysisCreate`` from the report data and export it.

        Args:
            region_overview: Overview values keyed by their display label.
                Labels outside the built-in set are exported as custom fields.
            capture_info: Capture values keyed by their display label.
            cpu_frametimes: Frame table with ``start_ns`` and ``duration_ms``
                columns; no frame times are exported when None.
            gpu_frametimes: Frame table whose column ``0`` holds the GPU
                duration; joined to ``cpu_frametimes`` on the index.
            sorted_gpumetric_regions_values: Source of the custom chart rows;
                takes precedence over ``sorted_gpumetric_regions_durations``.
            sorted_gpumetric_regions_durations: Source of the custom chart
                rows when no values table is given.
            app_health_gpu_bound: Stored as is on the analysis.
            report_range: Stored as ``hotspot_range`` and used in chart labels.
            report_threading: Stored as ``hotspot_report_threading``.

        Returns:
            The ids of the exported hotspot analysis, or None when no report
            has been created yet.
        """
        if self.database_report_ids is None:
            return None
        # TODO: do not use pandas dataframes. get data from source

        # Overview labels that map onto dedicated model fields.
        builtin_overview_fields = {
            'Region Description', 'Number of regions', 'Fastest (ms)', 'Slowest (ms)',
            'Median (ms)', 'Mean (ms)', 'Average Micro-stutters (ms)', 'Region Count',
        }

        def overview_float(label: str):
            return tu.safe_float(region_overview.get(label, None))

        def capture_float(label: str):
            return tu.safe_float(capture_info.get(label, ''))

        hotspot_analysis_create = HotspotAnalysisCreate(
            # Placeholder; replaced with the real report id during export.
            report_id=0,

            process_name=capture_info.get('Process Name', ''),
            process_id=tu.safe_type(capture_info.get('PID', 0), int, 0),
            start_time_s=capture_float('Start Time (s)'),
            end_time_s=capture_float('End Time (s)'),
            duration_s=capture_float('Duration (s)'),

            region_description=region_overview.get('Region Description', ''),
            number_of_regions=tu.safe_type(region_overview.get('Number of regions', None), int, None),
            fastest_ms=overview_float('Fastest (ms)'),
            slowest_ms=overview_float('Slowest (ms)'),
            median_ms=overview_float('Median (ms)'),
            mean_ms=overview_float('Mean (ms)'),
            average_micro_stutters_ms=overview_float('Average Micro-stutters (ms)'),
            region_count=tu.safe_type(region_overview.get('Region Count', 0), int, None),

            app_health_gpu_bound=app_health_gpu_bound,
            hotspot_range=report_range,
            hotspot_report_threading=report_threading,

            frame_times=[],
            custom_fields=[],
            custom_chart_rows=[],
            custom_chart_title=None,
            custom_chart_x_axis_name=None,
            custom_chart_y_axis_name=None,
        )

        # Every overview entry without a dedicated field becomes a custom field.
        hotspot_analysis_create.custom_fields.extend(
            HotspotAnalysisCustomFieldCreate(name=field_name, value=str(field_value))
            for field_name, field_value in region_overview.items()
            if field_name not in builtin_overview_fields
        )

        if cpu_frametimes is not None:
            if gpu_frametimes is not None:
                gpu_column = gpu_frametimes[[0]].rename(columns={0: 'duration_gpu_ms'})
            else:
                # No GPU data: an all-None column keeps the join shape uniform.
                gpu_column = pd.DataFrame(
                    {'duration_gpu_ms': [None] * len(cpu_frametimes)},
                    index=cpu_frametimes.index
                )

            frametimes = cpu_frametimes.join(gpu_column)
            hotspot_analysis_create.frame_times.extend(
                FrameTimesCreate(
                    frame_index=frame_index,
                    start_ns=tu.safe_type(frame_row['start_ns'], int, 0),
                    duration_cpu_ms=tu.safe_float(frame_row['duration_ms'], 0),
                    duration_gpu_ms=tu.safe_float(frame_row['duration_gpu_ms'], None)
                )
                for frame_index, frame_row in frametimes.iterrows()
            )

        # Pick the chart source; summed values win over durations.
        chart_source: Optional[pd.DataFrame] = None
        chart_title = None
        chart_y_axis_name = None
        if sorted_gpumetric_regions_values is not None:
            chart_source = sorted_gpumetric_regions_values
            chart_title = f'GPU Metrics Ranges based on Largest Sum of {report_range}'
            chart_y_axis_name = f'{report_range} Metric Summed Value'
        elif sorted_gpumetric_regions_durations is not None:
            chart_source = sorted_gpumetric_regions_durations
            chart_title = f'GPU Metrics Ranges based on Duration of {report_range}'
            chart_y_axis_name = f'{report_range} Range Duration (ms)'

        if chart_source is not None:
            hotspot_analysis_create.custom_chart_title = chart_title
            hotspot_analysis_create.custom_chart_x_axis_name = 'Range Count'
            hotspot_analysis_create.custom_chart_y_axis_name = chart_y_axis_name
            hotspot_analysis_create.custom_chart_rows.extend(
                HotspotAnalysisCustomChartRowCreate(row_index=row_index, value=chart_row[0])
                for row_index, chart_row in chart_source.iterrows()
            )

        return self.__export_model(
            hotspot_analysis_create,
            local_hotspot_analysis_routes.create_hotspot_analysis,
            'report_id',
            self.database_report_ids
        )
