# Copyright (C) 2023 - 2025 ANSYS, Inc. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import typing
import numpy as np
import pandas as pd
from ansys.dyna.core import Deck
[docs]
def get_nid_to_index_mapping(nodes) -> typing.Dict:
    """Given a node id, output the node index as a dict.

    Maps every value of the ``nid`` column of ``nodes`` to its 0-based row
    *position*. The position (not the row label) is the point index of an
    array built from the rows in order, e.g. ``nodes[["x", "y", "z"]].to_numpy()``.
    Row labels may be duplicated when several NODE frames were concatenated,
    so they must not be used as point indices.

    If a node id appears more than once, its last occurrence wins.
    An empty frame (including one with no ``nid`` column) yields ``{}``.
    """
    if len(nodes) == 0:
        return {}
    # .tolist() gives plain Python scalars and avoids iterrows, which is slow
    # and upcasts the whole row (including nid) to float.
    return {nid: position for position, nid in enumerate(nodes["nid"].tolist())}
[docs]
def merge_keywords(
    deck: Deck,
) -> typing.Tuple[pd.DataFrame, typing.Dict]:
    """
    Merge mesh keywords.

    Given a deck, merges specific keywords (NODE, ELEMENT_SHELL, ELEMENT_BEAM, ELEMENT_SOLID)
    and returns them as data frames.

    Returns
    -------
    nodes : pd.DataFrame
        The ``nodes`` frames of all NODE keywords, concatenated with a fresh
        0..n-1 row index (an empty DataFrame when the deck has no NODE keyword).
    elements : dict
        Maps "SHELL", "BEAM" and "SOLID" to the concatenated ``elements`` frames
        of the ELEMENT keywords with that subkeyword, each with a fresh 0..n-1
        row index (an empty DataFrame when there are none).
    """
    # ignore_index=True: each keyword's frame carries its own row index, so a
    # plain concat can produce duplicate labels. Row order is what callers use
    # as the point/cell position, so a clean positional index is rebuilt here.
    nodes_temp = [kwd.nodes for kwd in deck.get_kwds_by_type("NODE")]
    nodes = pd.concat(nodes_temp, ignore_index=True) if nodes_temp else pd.DataFrame()
    # Fetch the ELEMENT keywords once instead of once per subkeyword.
    element_kwds = list(deck.get_kwds_by_type("ELEMENT"))
    df_list = {}
    for item in ["SHELL", "BEAM", "SOLID"]:
        matching_elements = [kwd.elements for kwd in element_kwds if kwd.subkeyword == item]
        df_list[item] = pd.concat(matching_elements, ignore_index=True) if matching_elements else pd.DataFrame()
    return nodes, df_list
[docs]
def process_nodes(nodes_df):
    """Return the node coordinates as an ``(n_nodes, 3)`` NumPy array.

    The ``x``, ``y`` and ``z`` columns of ``nodes_df`` are extracted in that
    order. Row order is preserved, so row ``i`` of the result is the
    coordinate of the ``i``-th row of ``nodes_df``. A missing coordinate column
    raises ``KeyError``.
    """
    coordinate_columns = ["x", "y", "z"]
    return nodes_df.loc[:, coordinate_columns].to_numpy()
[docs]
def shell_facet_array(facets: pd.DataFrame) -> np.array:
    """
    Get the shell facet array from a DataFrame row.

    ``facets`` is one row of a shell-element frame: a sequence of node ids
    (integers) or NAs with max length of 8. The node count is the number of
    non-empty entries (not NA and not 0) filled consecutively from the left.
    The first empty entry, or the end of the row, terminates the count, and
    anything after it is ignored.

    Valid rows contain 3, 4, 6, or 8 nodes. Quadratic nodes are not plotted,
    so 6 and 8 collapse to their first 3 and 4 nodes. Any other count is
    invalid and yields an empty int32 array.

    Returns an int32 array of length 4 or 5 using the pyvista layout for
    facets, which includes a length prefix:
        [1,2,3]=>[3,1,2,3]
        [1,2,3,0]=>[3,1,2,3]
        [1,2,3,NA]=>[3,1,2,3]
        [1,2,3,4,NA,NA,NA,NA]=>[4,1,2,3,4]
    """
    node_ids = []
    for item in facets:
        # stop at the first empty column
        if pd.isna(item) or item == 0:
            break
        node_ids.append(item)
    n_nodes = len(node_ids)
    if n_nodes in (3, 6):
        n_corners = 3
    elif n_nodes in (4, 8):
        n_corners = 4
    else:
        # invalid; keep the int dtype so concatenation with valid rows does not upcast
        return np.empty(0, dtype=np.int32)
    return np.array([n_corners, *node_ids[:n_corners]], dtype=np.int32)
[docs]
def solid_array(solids: pd.DataFrame):
    """
    Get the flattened face list of a solid element from a DataFrame row.

    ``solids`` is one row of a solid-element frame: a sequence of node ids
    (integers) or NAs with max length of 28. The node count is the number of
    non-empty entries (not NA and not 0) filled consecutively from the left.
    The first empty entry, or the end of the row, terminates the count, and
    anything after it is ignored. Nodes are addressed by position, so the
    row's index labels do not matter.

    Valid rows contain 4, 6, or 8 nodes and are expanded into their faces:
        4 nodes: 4 triangular faces
        6 nodes: 2 triangular + 3 quadrilateral faces
        8 nodes: 6 quadrilateral faces
    Each face is written with a length prefix (pyvista layout), e.g.
        [1,2,3,4] => [3,1,2,3, 3,1,2,4, 3,1,3,4, 3,2,3,4]
    Returns a flat list; any other node count is invalid and returns [].
    """
    # FACES CREATED BY THE SOLIDS BASED ON MANUAL
    # A DUMMY ZERO IS PUT AS A PLACEHOLDER FOR THE LEN PREFIX;
    # the other entries are 1-based positions into the row.
    four_node_faces = [[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 3, 4], [0, 2, 3, 4]]
    six_node_faces = [
        [0, 1, 2, 5],
        [0, 3, 4, 6],
        [0, 2, 3, 5, 6],
        [0, 1, 5, 6, 4],
        [0, 1, 2, 3, 4],
    ]
    eight_node_faces = [
        [0, 1, 2, 3, 4],
        [0, 1, 2, 5, 6],
        [0, 5, 6, 7, 8],
        [0, 3, 4, 7, 8],
        [0, 2, 3, 6, 7],
        [0, 1, 4, 5, 8],
    ]
    faces_by_node_count = {4: four_node_faces, 6: six_node_faces, 8: eight_node_faces}
    # Positional copy: avoids the deprecated positional fallback of Series[int]
    # on label-indexed rows.
    values = list(solids)
    n_nodes = len(values)
    for idx, item in enumerate(values):
        # find the first empty column
        if pd.isna(item) or item == 0:
            n_nodes = idx
            break
    faces = faces_by_node_count.get(n_nodes)
    if faces is None:
        # invalid
        return []
    return [len(face) - 1 if i == 0 else values[i - 1] for face in faces for i in face]
[docs]
def line_array(lines: pd.DataFrame) -> np.array:
    """
    Convert a DataFrame row to a line cell array.

    ``lines`` is one row of a beam-element frame: a sequence of node ids
    (integers) or NAs. Only the first two entries are used, and both must be
    present (not NA and not 0). Entries after them are ignored. Rows that are
    too short, or that have an empty entry in either of the first two slots,
    are invalid and yield an empty int32 array.

    Returns an int32 array of length 3 using the pyvista layout, which
    includes a length prefix:
        [1,2]=>[2,1,2]
        [1,2,3,0]=>[2,1,2]
        [1,NA]=>[]
        [1]=>[]
    """
    values = list(lines)
    if len(values) < 2:
        # not enough entries for both end nodes
        return np.empty(0, dtype=np.int32)
    start, end = values[0], values[1]
    if pd.isna(start) or start == 0 or pd.isna(end) or end == 0:
        return np.empty(0, dtype=np.int32)
    return np.array([2, start, end], dtype=np.int32)
[docs]
def map_facet_nid_to_index(flat_facets: np.array, mapping: typing.Dict) -> np.array:
    """Convert a flat, length-prefixed cell list from node ids to point indices.

    ``flat_facets`` is laid out as repeated groups ``[count, id_1, ..., id_count]``
    (facets or lines, numbered as in dyna). Each ``count`` is copied through
    unchanged, and every id is replaced by ``mapping[id]``, giving the
    numbering pyvista expects.

    Returns an int32 array with the same length as ``flat_facets``. A lookup of
    an id missing from ``mapping`` raises (``KeyError`` for a dict).
    """
    values = np.asarray(flat_facets)
    indexed = np.empty(len(flat_facets), dtype=np.int32)
    # Number of ids still to translate in the current group;
    # zero means the next value is a count prefix.
    pending = 0
    for position, value in enumerate(values):
        if pending:
            indexed[position] = mapping[value]
            pending -= 1
        else:
            indexed[position] = value
            pending = int(value)
    return indexed
[docs]
def get_pyvista():
    """Import and return the ``pyvista`` module on demand.

    Plotting is only supported when pyvista is installed, so it is imported
    here rather than at module import time.

    Raises
    ------
    Exception
        If pyvista cannot be imported. The original ``ImportError`` is chained
        as ``__cause__`` so the underlying reason stays visible.
    """
    try:
        import pyvista as pv
    except ImportError as err:
        raise Exception("plot is only supported if pyvista is installed") from err
    return pv
[docs]
def get_polydata(deck: Deck, cwd=None):
    """Create the pyvista UnstructuredGrid for plotting a deck's nodes and elements.

    The deck is expanded with ``recurse=True``. Its NODE and ELEMENT keywords
    are merged, and the beam, shell and solid elements become line, quad and
    solid cells of a ``pyvista.UnstructuredGrid``. Per-cell ``part_ids`` and
    ``element_ids`` arrays are attached as cell data, concatenated in the
    order lines, shells, solids. This assumes pyvista lays the cells out in
    the insertion order of the cell-type dict; TODO confirm.

    Parameters
    ----------
    deck : Deck
        Deck to convert.
    cwd : optional
        Working directory passed to ``deck.expand``. When ``None`` it is not
        passed, and ``expand`` uses its own default.

    Returns
    -------
    pyvista.UnstructuredGrid

    Raises
    ------
    Exception
        If pyvista is not installed, or if the deck has no nodes or no
        shell/beam/solid elements.
    """
    # import this lazily (otherwise this adds over a second to the import time of pyDyna)
    pv = get_pyvista()
    # flatten deck
    if cwd is not None:
        flat_deck = deck.expand(cwd=cwd, recurse=True)
    else:
        flat_deck = deck.expand(recurse=True)
    # get dataframes for each element type
    nodes_df, element_dict = merge_keywords(flat_deck)
    shells_df = element_dict["SHELL"]
    beams_df = element_dict["BEAM"]
    solids_df = element_dict["SOLID"]
    # Validate before touching columns: without a NODE keyword nodes_df is a
    # bare empty DataFrame, and reading x/y/z from it would raise KeyError
    # instead of this message.
    if len(nodes_df) == 0 or len(shells_df) + len(beams_df) + len(solids_df) == 0:
        raise Exception("missing node or element keyword to plot")
    nodes_list = process_nodes(nodes_df)
    mapping = get_nid_to_index_mapping(nodes_df)
    # get the node information, element_ids and part_ids
    facets, shell_eids, shell_pids = extract_shell_facets(shells_df, mapping)
    lines, line_eids, line_pids = extract_lines(beams_df, mapping)
    solids_info = extract_solids(solids_df, mapping)
    # celltype_dict for beam and shell; [:, 1:] drops the length-prefix column
    celltype_dict = {
        pv.CellType.LINE: lines.reshape([-1, 3])[:, 1:],
        pv.CellType.QUAD: facets.reshape([-1, 5])[:, 1:],
    }
    # dict of cell types for node counts
    solid_celltype = {
        4: pv.CellType.TETRA,
        5: pv.CellType.PYRAMID,
        6: pv.CellType.WEDGE,
        8: pv.CellType.HEXAHEDRON,
    }
    # Update celltype_dict with solid elements, collecting their ids in the same order
    solid_pid_chunks = []
    solid_eid_chunks = []
    for n_points, elements in solids_info.items():
        if len(elements[0]) == 0:
            continue
        temp_solids, temp_solids_eids, temp_solids_pids = elements
        celltype_dict[solid_celltype[n_points]] = temp_solids.reshape([-1, n_points + 1])[:, 1:]
        solid_pid_chunks.append(temp_solids_pids)
        solid_eid_chunks.append(temp_solids_eids)
    solids_pids = np.concatenate(solid_pid_chunks, axis=0) if solid_pid_chunks else np.empty(0, dtype=int)
    solids_eids = np.concatenate(solid_eid_chunks, axis=0) if solid_eid_chunks else np.empty(0, dtype=int)
    # Create UnstructuredGrid
    plot_data = pv.UnstructuredGrid(celltype_dict, nodes_list)
    # Mapping part_ids and element_ids
    plot_data.cell_data["part_ids"] = np.concatenate((line_pids, shell_pids, solids_pids), axis=0)
    plot_data.cell_data["element_ids"] = np.concatenate((line_eids, shell_eids, solids_eids), axis=0)
    return plot_data
[docs]
def plot_deck(deck, **args):
    """Plot the deck.

    Recognised keyword arguments (all optional):

    cwd
        Working directory forwarded to ``get_polydata``. Defaults to ``None``,
        meaning ``deck.expand`` is called without an explicit ``cwd``.
    scalars
        Forwarded to ``plot``. Takes precedence over ``color``; when given,
        ``color`` is not forwarded.
    color
        Forwarded to ``plot`` when ``scalars`` is not given. Defaults to
        ``pyvista.global_theme.color``.

    Any other keyword arguments are forwarded unchanged to the grid's
    ``plot`` method, whose return value is returned.
    """
    # import this lazily (otherwise this adds over a second to the import time of pyDyna)
    pv = get_pyvista()
    # Default to None (not ""), so get_polydata takes its "no cwd" branch.
    plot_data = get_polydata(deck, args.pop("cwd", None))
    color = args.pop("color", None)
    scalars = args.pop("scalars", None)
    if scalars is not None:
        return plot_data.plot(scalars=scalars, **args)  # User specified scalars
    if color is None:
        color = pv.global_theme.color  # Default color
    return plot_data.plot(color=color, **args)