# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import os
from typing import Dict, List, Optional, Union

from azure.core.credentials import TokenCredential
from azure.ai.evaluation._model_configurations import AzureAIProject
from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager
from azure.ai.evaluation._common.raiclient import MachineLearningServicesClient
from azure.ai.evaluation._constants import TokenScope
from azure.ai.evaluation._common.utils import is_onedp_project
from azure.ai.evaluation._common.onedp import AIProjectClient
from azure.ai.evaluation._common import EvaluationServiceOneDPClient
import jwt
import time
import ast

class GeneratedRAIClient:
    """Client for the Responsible AI Service using the auto-generated MachineLearningServicesClient.

    Two project shapes are supported:

    * A classic Azure AI project dict (subscription id, resource group, project name). Requests are
      sent through ``MachineLearningServicesClient(...).rai_svc``. The service endpoint comes from the
      ``RAI_SVC_URL`` environment variable when set; otherwise it is discovered via Azure Resource Manager.
    * A OneDP project (detected by ``is_onedp_project``), whose value is used directly as the
      endpoint of ``AIProjectClient``. Requests are sent through ``AIProjectClient.red_teams``.

    :param azure_ai_project: The scope of the Azure AI project. Either a dict containing subscription id,
        resource group, and project name, or a OneDP project endpoint string.
    :type azure_ai_project: Union[~azure.ai.evaluation.AzureAIProject, str]
    :param token_manager: The token manager
    :type token_manager: ~azure.ai.evaluation.simulator._model_tools._identity_manager.APITokenManager
    """

    def __init__(self, azure_ai_project: Union[AzureAIProject, str], token_manager: ManagedIdentityAPITokenManager):
        self.azure_ai_project = azure_ai_project
        self.token_manager = token_manager

        if not is_onedp_project(azure_ai_project):
            # An explicit RAI_SVC_URL override takes precedence over ARM-based service discovery.
            if "RAI_SVC_URL" in os.environ:
                endpoint = os.environ["RAI_SVC_URL"].rstrip("/")
            else:
                endpoint = self._get_service_discovery_url()

            # Create the autogenerated client and keep only its RAI service operation group.
            self._client = MachineLearningServicesClient(
                endpoint=endpoint,
                subscription_id=self.azure_ai_project["subscription_id"],
                resource_group_name=self.azure_ai_project["resource_group_name"],
                workspace_name=self.azure_ai_project["project_name"],
                credential=self.token_manager,
            ).rai_svc
        else:
            # Build the project client once and reuse it for both operation groups, instead of
            # constructing two independent clients for the same endpoint and credential.
            project_client = AIProjectClient(endpoint=azure_ai_project, credential=token_manager)
            self._client = project_client.red_teams
            self._operations_client = project_client.evaluations
            self._evaluation_onedp_client = EvaluationServiceOneDPClient(endpoint=azure_ai_project, credential=token_manager)

    def _get_service_discovery_url(self):
        """Get the service discovery URL.

        Queries the Azure Resource Manager workspace resource for the project. It returns the
        scheme and host of its ``properties.discoveryUrl``.

        :return: The service discovery URL (scheme and host only, no path)
        :rtype: str
        :raises Exception: If the ARM request does not return HTTP 200.
        """
        import requests
        from urllib.parse import urlparse

        bearer_token = self._fetch_or_reuse_token(self.token_manager)
        headers = {"Authorization": f"Bearer {bearer_token}", "Content-Type": "application/json"}

        response = requests.get(
            f"https://management.azure.com/subscriptions/{self.azure_ai_project['subscription_id']}/"
            f"resourceGroups/{self.azure_ai_project['resource_group_name']}/"
            f"providers/Microsoft.MachineLearningServices/workspaces/{self.azure_ai_project['project_name']}?"
            f"api-version=2023-08-01-preview",
            headers=headers,
            timeout=5,
        )

        if response.status_code != 200:
            msg = (
                f"Failed to connect to your Azure AI project. Please check if the project scope is configured "
                f"correctly, and make sure you have the necessary access permissions. "
                f"Status code: {response.status_code}."
            )
            raise Exception(msg)

        # Only the scheme and host of the discovery URL are used as the service endpoint.
        base_url = urlparse(response.json()["properties"]["discoveryUrl"])
        return f"{base_url.scheme}://{base_url.netloc}"

    async def get_attack_objectives(
        self, risk_category: Optional[str] = None, application_scenario: Optional[str] = None, strategy: Optional[str] = None
    ) -> Dict:
        """Get attack objectives using the auto-generated operations.

        :param risk_category: Optional risk category to filter the attack objectives
        :type risk_category: Optional[str]
        :param application_scenario: Optional description of the application scenario for context.
            NOTE(review): currently accepted but not forwarded to the service.
        :type application_scenario: Optional[str]
        :param strategy: Optional strategy to filter the attack objectives
        :type strategy: Optional[str]
        :return: The attack objectives
        :rtype: Dict
        :raises Exception: Any error raised by the underlying client is logged and re-raised.
        """
        import logging

        logger = logging.getLogger(__name__)
        try:
            # NOTE(review): when risk_category is None this sends [None]; confirm the service/serializer
            # treats that as "no filter".
            return self._client.get_attack_objectives(
                risk_types=[risk_category],
                lang="en",
                strategy=strategy,
            )
        except Exception as e:
            logger.error("Error in get_attack_objectives: %s", e)
            raise

    async def get_jailbreak_prefixes(self) -> List[str]:
        """Get jailbreak prefixes using the auto-generated operations.

        This is best-effort. If the service call fails or returns something other than a list,
        the problem is logged and ``[""]`` (a single empty prefix) is returned instead of raising.

        :return: The jailbreak prefixes
        :rtype: List[str]
        """
        import logging

        logger = logging.getLogger(__name__)
        try:
            response = self._client.get_jail_break_dataset_with_type(type="upia")
        except Exception:
            logger.exception("Failed to fetch jailbreak prefixes; falling back to an empty prefix")
            return [""]

        if isinstance(response, list):
            return response

        logger.error("Unexpected response format from get_jail_break_dataset_with_type")
        return [""]

    def _fetch_or_reuse_token(self, credential: TokenCredential, token: Optional[str] = None) -> str:
        """Get token. Fetch a new token if the current token is near expiry.

        :param credential: The Azure authentication credential.
        :type credential: ~azure.core.credentials.TokenCredential
        :param token: The Azure authentication token. Defaults to None. If none, a new token will be fetched.
        :type token: Optional[str]
        :return: The Azure authentication token.
        :rtype: str
        """
        if token:
            # Decode without verifying the signature: only the expiry claim is inspected here.
            try:
                decoded_token = jwt.decode(token, options={"verify_signature": False})
            except jwt.PyJWTError:
                decoded_token = None

            # A token without an "exp" claim cannot be judged fresh, so fall through and fetch a new one.
            exp_time = decoded_token.get("exp") if isinstance(decoded_token, dict) else None
            # Reuse the current token only if it is valid for at least another 5 minutes.
            if exp_time is not None and (exp_time - time.time()) >= 300:
                return token

        return credential.get_token(TokenScope.DEFAULT_AZURE_MANAGEMENT).token
