Source code for tbp.monty.frameworks.environment_utils.graph_utils

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from typing import Optional


def get_edge_index(graph, previous_node, new_node) -> Optional[int]:
    """Look up the ID of the edge that goes from one node to another.

    Row 0 of ``graph.edge_index`` is compared against ``previous_node`` and
    row 1 against ``new_node``; an edge matches when both comparisons hold
    at the same column.

    Args:
        graph: torch_geometric.data graph (must expose ``edge_index``).
        previous_node: node ID of the first node of the edge.
        new_node: node ID of the second node of the edge.

    Returns:
        The column index in ``graph.edge_index`` of the first matching edge,
        or None if no edge connects the two nodes in that direction.
    """
    sources, targets = graph.edge_index[0], graph.edge_index[1]
    is_match = (sources == previous_node) & (targets == new_node)

    # Guard clause: nothing to index into when no column matches.
    if not is_match.any():
        return None

    # Several columns may match; only the first one is reported.
    matching_ids = is_match.nonzero().view(-1)
    return matching_ids[0].item()