"""
Risk Graph Visualization using Mermaid

This script creates a visualization of the risk classification workflow
using the Mermaid format as recommended in the LangGraph documentation:
https://langchain-ai.github.io/langgraph/how-tos/visualization/#mermaid
"""

from typing import Optional, TypedDict
from langgraph.graph import StateGraph, END

# Dummy LLM for testing: a keyword matcher standing in for a real model.
class DummyLLM:
    """Offline stand-in for a language model.

    ``invoke`` performs a case-insensitive substring search over the prompt and
    answers with the first risk keyword it finds, checking "hoch" before
    "mittel"; any other prompt yields "niedrig".
    """

    def invoke(self, text):
        lowered = text.lower()
        for keyword in ("hoch", "mittel"):
            if keyword in lowered:
                return keyword
        return "niedrig"


# Module-level model instance used by the classification node.
llm = DummyLLM()

# 1. Define the state type.
class _StateInput(TypedDict):
    """Keys that must be present when the graph is invoked."""

    question: str  # the question whose risk is classified


class State(_StateInput, total=False):
    """Workflow state shared by all graph nodes.

    Only ``question`` is required on input. The remaining keys are written by
    the nodes (``classify_risk`` sets ``risk_level``; the handlers set
    ``response``), so they are declared with ``total=False`` instead of being
    required-but-nullable.
    """

    response: Optional[str]  # message produced by the selected handler
    risk_level: Optional[str]  # classifier answer, e.g. "hoch"/"mittel"/"niedrig"

# 2. Create the graph builder over the State schema.
graph = StateGraph(State)

# 3. Define the nodes.
def classify_risk(state):
    """Ask the model to classify the risk of ``state["question"]``.

    Returns the partial state update ``{"risk_level": <answer>}``.

    ``route_risk`` calls ``.lower()`` on ``risk_level``, and ``State`` types it
    as a string. ``DummyLLM`` already returns a plain ``str``. Chat models
    (e.g. LangChain chat models) return a message object whose text is in
    ``.content``, so such objects are unwrapped here. Plain strings have no
    ``content`` attribute and pass through unchanged.
    """
    question = state["question"]
    answer = llm.invoke(f"Klassifiziere das Risiko: {question}")
    risk_level = getattr(answer, "content", answer)
    return {"risk_level": risk_level}

def handle_high_risk(state):
    """Handle the high-risk branch by escalating to a senior risk manager.

    ``state`` is not read; returns a partial update that sets ``response``.
    """
    message = "Hohes Risiko erkannt: Eskalation an Senior Risk Manager"
    return dict(response=message)

def handle_medium_risk(state):
    """Handle the medium-risk branch by starting the standard review process.

    ``state`` is not read; returns a partial update that sets ``response``.
    """
    message = "Mittleres Risiko: Standard-Prüfprozess einleiten"
    return dict(response=message)

def handle_low_risk(state):
    """Handle the low-risk branch by approving automatically.

    ``state`` is not read; returns a partial update that sets ``response``.
    """
    message = "Niedriges Risiko: Automatische Freigabe"
    return dict(response=message)

# 4. Register the nodes. These names are the ones used by the edges below and
# returned by route_risk ("high_risk", "medium_risk", "low_risk").
graph.add_node("classify_risk", classify_risk)
graph.add_node("high_risk", handle_high_risk)
graph.add_node("medium_risk", handle_medium_risk)
graph.add_node("low_risk", handle_low_risk)

# 5. Routing logic.
def route_risk(state):
    """Pick the handler node for the classified risk level.

    Matching is a case-insensitive substring test. "hoch" is checked before
    "mittel", so an answer mentioning both escalates to ``high_risk``.

    Returns:
        "high_risk", "medium_risk" or "low_risk".

    Raises:
        ValueError: if ``risk_level`` is missing or ``None``, i.e. the
            classifier has not produced an answer.
    """
    risk = state.get("risk_level")
    if risk is None:
        raise ValueError(
            "route_risk: state has no 'risk_level'; classify_risk must run first"
        )
    risk = risk.lower()
    if "hoch" in risk:
        return "high_risk"
    if "mittel" in risk:
        return "medium_risk"
    # NOTE(review): any unrecognised answer falls through to low_risk, which
    # auto-approves; consider failing closed if the model output is noisy.
    return "low_risk"

# 6. Edges.
# Route from the classifier to exactly one handler. The path map (third
# positional argument) lists every node route_risk may return, so LangGraph
# knows the possible targets when validating and drawing the graph.
graph.add_conditional_edges(
    "classify_risk",
    route_risk,
    {
        "high_risk": "high_risk",
        "medium_risk": "medium_risk",
        "low_risk": "low_risk",
    },
)

# Every handler ends the run.
graph.add_edge("high_risk", END)
graph.add_edge("medium_risk", END)
graph.add_edge("low_risk", END)

# Start node: every run begins with the classification.
graph.set_entry_point("classify_risk")

# 7. Compile into a runnable app.
app = graph.compile()

def generate_mermaid():
    """Return a fenced Markdown ``mermaid`` block describing the risk workflow.

    The flowchart is maintained by hand to mirror the nodes and edges wired up
    in this module; it is not derived from the compiled ``app``, so update it
    whenever the graph changes. The returned text starts with a newline and
    ends with the closing code fence plus a newline.
    """
    edges = [
        "flowchart TD",
        "    START([Start]) --> classify_risk[Classify Risk]",
        "    classify_risk -- risk=hoch --> high_risk[Handle High Risk]",
        "    classify_risk -- risk=mittel --> medium_risk[Handle Medium Risk]",
        "    classify_risk -- risk=niedrig --> low_risk[Handle Low Risk]",
        "    high_risk --> END([End])",
        "    medium_risk --> END",
        "    low_risk --> END",
    ]
    palette = [
        "    classDef default fill:#f9f9f9,stroke:#333,stroke-width:1px;",
        "    classDef start fill:#6BCB77,stroke:#333,stroke-width:1px;",
        "    classDef end_fill fill:#FF6B6B,stroke:#333,stroke-width:1px;",
        "    classDef risk fill:#4D96FF,stroke:#333,stroke-width:1px;",
        "    classDef high fill:#FF6B6B,stroke:#333,stroke-width:1px;",
        "    classDef medium fill:#FFD93D,stroke:#333,stroke-width:1px;",
        "    classDef low fill:#6BCB77,stroke:#333,stroke-width:1px;",
    ]
    assignments = [
        "    class START start;",
        "    class END end_fill;",
        "    class classify_risk risk;",
        "    class high_risk high;",
        "    class medium_risk medium;",
        "    class low_risk low;",
    ]
    # One blank line between the edge, style and class-assignment sections.
    diagram = "\n\n".join("\n".join(section) for section in (edges, palette, assignments))
    return "\n```mermaid\n" + diagram + "\n```\n"

def save_mermaid_to_file(path="risk_graph_mermaid.md"):
    """Write the Mermaid diagram into a Markdown file and print viewing hints.

    Args:
        path: Destination file. Defaults to ``risk_graph_mermaid.md``, resolved
            relative to the current working directory (the original
            hard-coded location).
    """
    mermaid_content = generate_mermaid()

    # Explicit UTF-8: without it open() uses the platform's locale encoding
    # (e.g. cp1252 on Windows), which would mis-encode non-ASCII labels such
    # as German umlauts for Markdown viewers that expect UTF-8.
    with open(path, "w", encoding="utf-8") as f:
        f.write("# Risk Classification Workflow\n\n")
        f.write("This diagram shows the flow of the risk classification process.\n\n")
        f.write(mermaid_content)

    print(f"Mermaid diagram saved to {path}")
    print("To view the diagram:")
    print("1. Open the markdown file in a Markdown viewer that supports Mermaid")
    print("2. Or copy the Mermaid code to an online Mermaid editor like https://mermaid.live")

if __name__ == "__main__":
    # Smoke-run the compiled graph on one sample question.
    # NOTE(review): DummyLLM looks for the substring "hoch", which the
    # inflected "hoher" below does not contain. This sample is therefore
    # classified "niedrig" and routed to low_risk; confirm that is intended.
    sample_state = {"question": "Was tun bei hoher Marktvolatilität?"}
    print("Test result:", app.invoke(sample_state))

    # Generate the Mermaid diagram and write it to disk.
    save_mermaid_to_file()
