# Source code for tbp.monty.frameworks.environment_utils.graph_utils
# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import numpy as np
def get_edge_index(graph, previous_node, new_node):
    """Get the edge index between two nodes in a graph.

    The graph's ``edge_index`` is read as a two-row structure: row 0 holds the
    source node of every edge and row 1 holds the matching target node. The
    returned value is the column position of the first edge that goes from
    ``previous_node`` to ``new_node``.

    Args:
        graph: torch_geometric.data graph. Only its ``edge_index`` attribute
            is accessed.
        previous_node: node ID of the first (source) node of the edge.
        new_node: node ID of the second (target) node of the edge.

    Returns:
        Edge ID (column index into ``graph.edge_index``) of the first edge
        from ``previous_node`` to ``new_node``, or None if the graph contains
        no such edge.
    """
    sources = graph.edge_index[0]
    targets = graph.edge_index[1]
    # A single combined mask replaces the per-edge Python loop: an edge matches
    # only when both of its endpoints match.
    # NOTE(review): assumes edge_index supports elementwise `==` and `&` (numpy
    # array or CPU torch tensor) -- confirm against callers.
    matching_edges = np.where((sources == previous_node) & (targets == new_node))[0]
    if len(matching_edges) == 0:
        # Callers previously received an implicit None here; keep that contract
        # but make it explicit.
        return None
    # np.where yields indices in ascending order, so the first entry is the
    # same edge the original linear scan would have returned.
    return matching_edges[0]