from kfrm_tool.include.db.db_connection import PostgreSQLConnection
from kfrm_tool.include.setting.logger import KFRMLogger


class PgQuery:
    """Thin query layer over ``PostgreSQLConnection`` for PostGIS lookups.

    All SQL is rendered as text and passed to ``self.pg.excValue`` (rows
    returned) or ``self.pg.excNone`` (no result).  String values and
    identifiers that come from callers are escaped with ``_literal`` /
    ``_ident`` so that embedded quotes cannot break or alter the statement.
    """

    def __init__(self):
        self.pg = PostgreSQLConnection()

    # helpers
    @staticmethod
    def _literal(value):
        """Render *value* as a single-quoted SQL string literal.

        Embedded single quotes are doubled (the SQL-standard escape).  Assumes
        the server runs with ``standard_conforming_strings = on``, which is
        PostgreSQL's default, so backslashes need no special treatment.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _ident(name):
        """Render *name* as a double-quoted SQL identifier (``"`` is doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    # select
    def getGeometryInfo(self, tableName):
        """Return the ``geometry_columns`` row registered for *tableName*.

        :param tableName: value matched against ``f_table_name``.
        :return: the first matching row; per the original layout this is
            (catalog, schema, tableName, geomName, coord, srid, geomType).
        :raises IndexError: if no row is registered for *tableName*.
        """
        sql = f"SELECT * FROM geometry_columns WHERE f_table_name={self._literal(tableName)}"
        KFRMLogger.debug(f"pg : {sql}")
        sql_result = self.pg.excValue(sql)
        if not sql_result:
            raise IndexError(f"No geometry_columns entry for table {tableName!r}")
        # return : catalog, schema, tableName, geomName, coord, srid, geomType
        return sql_result[0]

    def getProj4OrSrid(self, type, ref):  # noqa: A002 - public keyword kept for callers
        """Look up an SRID by proj4 text, or proj4 text by SRID.

        If ``spatial_ref_sys`` has no matching row, the reference system
        described by *ref* is inserted and the lookup is repeated.

        :param type: ``1`` returns the ``srid`` whose ``proj4text`` contains
            ``ref[2]`` (whitespace-normalised, case-insensitive); any other
            value returns the ``proj4text`` of SRID ``ref[0]``.
        :param ref: sequence ``(srid, auth_name, proj4, srtext)``.
        :return: first column of the first matching row.
        :raises Warning: if the missing reference system cannot be inserted.
        """
        # ref = srid, auth_name, proj, srtext
        if type == 1:  # get srid
            # Collapse every whitespace run so spacing differences do not
            # defeat the substring match.
            content = " ".join(ref[2].split())
            pattern = self._literal(f"%{content}%")
            target, where = "srid", f"proj4text ilike {pattern}"
        else:  # get proj4
            target, where = "proj4text", f"srid = {ref[0]}"

        sql = f"SELECT {target} from spatial_ref_sys where {where};"
        ret = self.pg.excValue(sql)
        if ret:
            return ret[0][0]

        try:
            # Store the same whitespace-normalised proj4 text the type-1
            # lookup searches for, so the re-query below can find it.
            proj4 = " ".join(str(ref[2]).split())
            inSql = (
                "INSERT INTO spatial_ref_sys VALUES("
                f"{ref[0]}, {self._literal(ref[1].upper())}, {ref[0]}, "
                f"{self._literal(ref[3])}, {self._literal(proj4)});"
            )
            self.pg.excNone(inSql)
        except Exception as err:
            raise Warning(
                "It is not the specified coordinate system. Please change to a different coordinate system."
            ) from err

        return self.pg.excValue(sql)[0][0]

    def getTableHeader(self, tableName):
        """Return the column names of *tableName* in table-definition order."""
        sql = (
            "SELECT column_name FROM information_schema.columns "
            f"WHERE table_name={self._literal(tableName)} ORDER BY ordinal_position;"
        )
        result = self.pg.excValue(sql)
        return [h[0] for h in result]

    def getTableAllContents(self, tableName, columns="*", order=""):
        """Return ``SELECT {columns} FROM tableName {order}`` rows.

        *tableName* is quoted as an identifier.  *columns* and *order* are raw
        SQL fragments inserted verbatim and must come from trusted code.
        """
        sql = f"SELECT {columns} FROM {self._ident(tableName)} {order};"
        return self.pg.excValue(sql)

    def getAreaSql(self, tableName, sigCd):
        """Build (not run) a SELECT matching rows whose ``sig_cd`` starts with any code.

        :param tableName: table to select from (quoted as an identifier).
        :param sigCd: comma-separated code prefixes, e.g. ``"11, 26"``; the
            surrounding whitespace of each code is ignored.
        :return: the SQL text.
        """
        codes = [code.strip() for code in sigCd.split(",")]
        conditions = [f"sig_cd like {self._literal(code + '%')}" for code in codes]
        sql = f"SELECT * FROM {self._ident(tableName)} WHERE {' or '.join(conditions)};"
        KFRMLogger.debug(f"pg : {sql}")
        return sql

    def getIntersectAreaName(self, year, sqliteGeom, sqliteRef):
        """Return the boundary codes and names that intersect *sqliteGeom*.

        :param year: suffix of the ``sgg_boundary_{year}`` / ``vSgg_boundary_{year}``
            relations.
        :param sqliteGeom: WKT of the input geometry.
        :param sqliteRef: reference tuple passed to ``getProj4OrSrid(1, ...)`` to
            resolve the input geometry's SRID.
        :return: ``[sig_cds, areas]``, each a ``", "``-joined string.  The two
            lists are filtered for empty values independently, so their entries
            may not line up one-to-one.
        """
        postGeomInfo = self.getGeometryInfo(f"sgg_boundary_{year}")
        postgisGeom = self._ident(postGeomInfo[3])
        postgisSrid = postGeomInfo[5]
        sqliteSrid = self.getProj4OrSrid(1, sqliteRef)
        geomLiteral = self._literal(sqliteGeom)

        # Both geometries are brought into the boundary table's SRID before
        # intersecting; ST_Buffer(..., 0) is applied to the input geometry.
        sql = f""" SELECT sig_cd, l_admin as area
                    FROM vSgg_boundary_{year} WHERE
                    st_intersects(
                        ST_Buffer(ST_Transform(ST_SetSRID(ST_GeomFromText({geomLiteral}), {sqliteSrid}), {postgisSrid}), 0),
                        ST_Transform(ST_SetSRID({postgisGeom}, {postgisSrid}), {postgisSrid})
                    ) ORDER BY area asc;"""

        rst = self.pg.excValue(sql)
        area = [a[1] for a in rst if a[1]]
        sig_cd = [a[0] for a in rst if a[0]]
        return [", ".join(sig_cd), ", ".join(area)]

    def getPopuUnitCosts(self, year):
        """Return all ``population_unit_cost`` rows for *year*."""
        sql = f'SELECT * FROM "population_unit_cost" WHERE year={year}'
        return self.pg.excValue(sql)

    def getGridContent(self, wkt, srid=5179):
        """Return ``(gid, value, wkt)`` rows of ``base_grid_500m`` intersecting *wkt*.

        :param wkt: geometry as WKT text.
        :param srid: SRID the WKT is expressed in (defaults to 5179, the value
            previously hard-coded here).
        """
        sql = f"""
            SELECT gid, value, ST_AsText(geometry) as wkt
            FROM base_grid_500m
            WHERE ST_Intersects(geometry, ST_GeomFromText({self._literal(wkt)}, {int(srid)}));
        """
        return self.pg.excValue(sql)
