"""
LangGraph Visualization Demo

This script demonstrates how to visualize a LangGraph using the official methods
as documented in: https://langchain-ai.github.io/langgraph/how-tos/visualization/
"""

from typing import Optional, TypedDict, Dict, Any, List
from langgraph.graph import StateGraph, END

# Probe two candidate import locations for a `visualize` helper. The flag
# starts out optimistic and is cleared only when both imports fail.
# NOTE(review): confirm that either import path exists in the LangGraph
# version this demo targets.
VISUALIZATION_AVAILABLE = True
try:
    from langgraph.graph import visualize as graph_visualize
except ImportError:
    try:
        from langgraph.visualize import visualize as graph_visualize
    except ImportError:
        VISUALIZATION_AVAILABLE = False

# Report the outcome once, after the probing is finished.
if VISUALIZATION_AVAILABLE:
    print("LangGraph visualization module imported successfully.")
else:
    print("LangGraph visualization module not available.")
    print("Please ensure you have the latest version of LangGraph installed.")
    print("pip install -U langgraph")

# Dummy LLM for testing
class DummyLLM:
    """Keyword-matching stand-in for a language model."""

    def invoke(self, text):
        """Return a risk label based on keywords found in *text*.

        The lookup is case-insensitive. "hoch" is checked before "mittel";
        when neither keyword occurs, "niedrig" is returned.
        """
        lowered = text.lower()
        for keyword in ("hoch", "mittel"):
            if keyword in lowered:
                return keyword
        return "niedrig"

llm = DummyLLM()

# 1. Define the state type
class State(TypedDict):
    """State schema handed to ``StateGraph`` for the risk workflow."""
    # Question to be classified; read by classify_risk.
    question: str
    # Reply text; written by the handle_*_risk node functions.
    response: Optional[str]
    # Risk label; written by classify_risk and read by route_risk.
    risk_level: Optional[str]

# 2. Create the graph
graph = StateGraph(State)

# 3. Define the nodes
def classify_risk(state: Dict[str, Any]) -> Dict[str, Any]:
    """Classify the question in *state* and return the risk-level update.

    Reads ``state["question"]``, embeds it in a prompt for the module-level
    ``llm`` and returns ``{"risk_level": <llm answer>}``.
    """
    prompt = f"Klassifiziere das Risiko: {state['question']}"
    return {"risk_level": llm.invoke(prompt)}

def handle_high_risk(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return the response update for the high-risk branch.

    *state* is accepted to satisfy the node signature but is not read.
    """
    escalation_notice = "Hohes Risiko erkannt: Eskalation an Senior Risk Manager"
    return dict(response=escalation_notice)

def handle_medium_risk(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return the response update for the medium-risk branch.

    *state* is accepted to satisfy the node signature but is not read.
    """
    review_notice = "Mittleres Risiko: Standard-Prüfprozess einleiten"
    return dict(response=review_notice)

def handle_low_risk(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return the response update for the low-risk branch.

    *state* is accepted to satisfy the node signature but is not read.
    """
    approval_notice = "Niedriges Risiko: Automatische Freigabe"
    return dict(response=approval_notice)

# 4. Add the nodes
# Registration is table-driven; the order matches the node definitions above.
for _node_name, _node_fn in (
    ("classify_risk", classify_risk),
    ("high_risk", handle_high_risk),
    ("medium_risk", handle_medium_risk),
    ("low_risk", handle_low_risk),
):
    graph.add_node(_node_name, _node_fn)

# 5. Routing logic
def route_risk(state: Dict[str, Any]) -> str:
    """Map ``state["risk_level"]`` to the name of the node to run next.

    The label is matched case-insensitively by substring: "hoch" selects
    "high_risk", otherwise "mittel" selects "medium_risk", and anything else
    falls through to "low_risk".
    """
    level = state["risk_level"].lower()
    for keyword, target in (("hoch", "high_risk"), ("mittel", "medium_risk")):
        if keyword in level:
            return target
    return "low_risk"

# 6. Edges
# The routing targets are passed explicitly (positionally, as the third
# argument) so the graph knows exactly which nodes route_risk can select.
# Without a mapping the possible targets of the conditional edge are not
# declared, which matters when the graph is drawn. The former fallback that
# called add_edge() with a function as the target was dropped: a function is
# not a valid edge target.
graph.add_conditional_edges(
    "classify_risk",
    route_risk,
    {
        "high_risk": "high_risk",
        "medium_risk": "medium_risk",
        "low_risk": "low_risk",
    },
)
print("Using add_conditional_edges method")

# Every handler node ends the run.
graph.add_edge("high_risk", END)
graph.add_edge("medium_risk", END)
graph.add_edge("low_risk", END)

# Define the entry node
graph.set_entry_point("classify_risk")

# 7. Compile
app = graph.compile()

def generate_mermaid_manually():
    """Return a hand-written Mermaid flowchart of the risk workflow.

    The result is a Markdown snippet: a leading newline, a fenced ``mermaid``
    block containing the flowchart, and a trailing newline.
    """
    fence = "```"
    indent = "    "
    stroke = "stroke:#333,stroke-width:1px;"

    # Nodes and edges, in drawing order.
    edges = (
        "START([Start]) --> classify_risk[Classify Risk]",
        "classify_risk -- risk=hoch --> high_risk[Handle High Risk]",
        "classify_risk -- risk=mittel --> medium_risk[Handle Medium Risk]",
        "classify_risk -- risk=niedrig --> low_risk[Handle Low Risk]",
        "high_risk --> END([End])",
        "medium_risk --> END",
        "low_risk --> END",
    )
    # Style class name -> fill colour.
    fills = (
        ("default", "#f9f9f9"),
        ("start", "#6BCB77"),
        ("end_fill", "#FF6B6B"),
        ("risk", "#4D96FF"),
        ("high", "#FF6B6B"),
        ("medium", "#FFD93D"),
        ("low", "#6BCB77"),
    )
    # Node id -> style class applied to it.
    styled_nodes = (
        ("START", "start"),
        ("END", "end_fill"),
        ("classify_risk", "risk"),
        ("high_risk", "high"),
        ("medium_risk", "medium"),
        ("low_risk", "low"),
    )

    rows = [f"{fence}mermaid", "flowchart TD"]
    rows.extend(indent + edge for edge in edges)
    rows.append("")
    rows.extend(
        f"{indent}classDef {style} fill:{colour},{stroke}" for style, colour in fills
    )
    rows.append("")
    rows.extend(f"{indent}class {node} {style};" for node, style in styled_nodes)
    rows.append(fence)
    return "\n" + "\n".join(rows) + "\n"

def visualize_graph():
    """Return a Markdown/Mermaid rendering of the risk workflow.

    The strategies below are tried in order; the first one that succeeds
    provides the result:

    1. ``app.get_graph().draw_mermaid()`` on the compiled graph, wrapped in
       a fenced ``mermaid`` block.
    2. ``graph_visualize(graph, format="mermaid")``, wrapped the same way
       (only when the optional helper was importable).
    3. ``graph_visualize(graph)`` with its default format, returned as-is.
    4. The hand-written diagram from ``generate_mermaid_manually()``.

    Failures of a strategy are printed and the next strategy is tried; this
    function is a best-effort boundary, so any exception triggers the fallback.

    Returns:
        A Markdown string for strategies 1, 2 and 4. For strategy 3 the
        value is whatever ``graph_visualize`` returns.
    """
    # Preferred path: ask the compiled graph itself for a Mermaid diagram.
    try:
        print("Attempting to visualize graph with get_graph().draw_mermaid()...")
        mermaid_diagram = app.get_graph().draw_mermaid()
    except Exception as e0:
        print(f"Error generating Mermaid diagram from compiled graph: {e0}")
    else:
        print("Successfully generated Mermaid diagram!")
        return f"```mermaid\n{mermaid_diagram}\n```"

    if not VISUALIZATION_AVAILABLE:
        print("Visualization not available. Using manual Mermaid diagram.")
        return generate_mermaid_manually()

    try:
        # Try to use the visualize function with mermaid format
        print("Attempting to visualize graph with Mermaid format...")
        mermaid_diagram = graph_visualize(graph, format="mermaid")
        print("Successfully generated Mermaid diagram!")
        return f"```mermaid\n{mermaid_diagram}\n```"
    except Exception as e1:
        print(f"Error generating Mermaid diagram: {e1}")
        try:
            # Try to use the visualize function with default format
            print("Attempting to visualize graph with default format...")
            diagram = graph_visualize(graph)
            print("Successfully generated diagram!")
            return diagram
        except Exception as e2:
            print(f"Error generating diagram: {e2}")
            print("Falling back to manual Mermaid diagram.")
            return generate_mermaid_manually()

def save_visualization():
    """Write the workflow visualization to disk.

    Always writes ``risk_graph_visualization.md`` (a heading, a short
    description and the output of ``visualize_graph()``) to the current
    working directory. When the optional ``graph_visualize`` helper is
    available, it additionally tries to write ``risk_graph.svg``; a failure
    there is printed and otherwise ignored.

    Both files are opened with an explicit UTF-8 encoding so the result does
    not depend on the platform's default encoding.
    """
    # Get the visualization
    visualization = visualize_graph()

    # Save to markdown file
    with open("risk_graph_visualization.md", "w", encoding="utf-8") as f:
        f.write("# Risk Classification Workflow\n\n")
        f.write("This diagram shows the flow of the risk classification process.\n\n")
        f.write(visualization)

    print("Visualization saved to risk_graph_visualization.md")

    # Try to save as SVG if possible
    if VISUALIZATION_AVAILABLE:
        try:
            print("Attempting to save as SVG...")
            svg_diagram = graph_visualize(graph, format="svg")
            with open("risk_graph.svg", "w", encoding="utf-8") as f:
                f.write(svg_diagram)
            print("SVG diagram saved to risk_graph.svg")
        except Exception as e:
            # Best effort: the Markdown file has already been written.
            print(f"Error saving SVG: {e}")

if __name__ == "__main__":
    # Run the compiled graph once on a sample question.
    sample_state = {"question": "Was tun bei hoher Marktvolatilität?"}
    print("\nTest result:", app.invoke(sample_state))

    # Produce the diagram files.
    print("\nGenerating visualization...")
    save_visualization()

    # Closing hints for the user.
    for hint in (
        "\nVisualization complete!",
        "To view the diagram:",
        "1. Open risk_graph_visualization.md in a Markdown viewer that supports Mermaid",
        "2. Or check if risk_graph.svg was created and view it in a browser or image viewer",
    ):
        print(hint)
