Source code for tamr_toolbox.workflow.concurrent.Graph

import logging
from typing import List, Set, Tuple

from dataclasses import dataclass
import networkx as nx

from tamr_unify_client import Client
from tamr_unify_client.project.resource import Project
from tamr_toolbox.models.project_type import ProjectType
from typing import Dict, Any

LOGGER = logging.getLogger(__name__)


[docs]@dataclass() class Graph: """ A dataclass for holding a set of Tamr project dependencies (edges), and the generated graph from them. """ edges: set directed_graph: nx.DiGraph
def _get_upstream_projects(project: Project, *, all_projects: List[Project]) -> List[Project]: """ get projects immediately upstream of a given project Args: project: the project to check all_projects: a list of all Projects on the Tamr Core instance Returns: A list of project names upstream of the project """ client = project.client # find upstream datasets - if GR project just get input datasets if ProjectType[project.type] == ProjectType.GOLDEN_RECORDS: upstream_datasets = [x for x in project.input_datasets().stream()] # else find the upstream datasets of the UD (not input datasets to capture datasets used in Tx) else: unified_dataset_id = project.unified_dataset().relative_id unified_dataset = client.datasets.by_relative_id(unified_dataset_id) upstream_datasets = unified_dataset.upstream_datasets() upstream_project_names = [] # walk through upstream datasets for upstream_result in upstream_datasets: # get the upstream object as a dataset upstream_dataset = client.datasets.by_resource_id(upstream_result.resource_id) # see if it is the output of a project and if so add to the list upstream_dataset_projects = set( x.project_name for x in upstream_dataset.usage().usage.output_from_project_steps ) upstream_project_names.extend([x for x in upstream_dataset_projects]) # we have all projects in all_projects variables so only return those whose # name is in upstream_project_names return [x for x in all_projects if x.name in upstream_project_names] def _build_edges( project: Project, client: Client, *, edges: set = None, all_projects: List[Project] ) -> Set[Tuple[str, str]]: """ builds a set of tuples of all edges of format (source, target) Args: project: the project to get edges for client: tamr client edges: set of tuples (source, target) all_projects: a list of all Projects on the Tamr Core instance Returns: """ upstream_projects = _get_upstream_projects(project, all_projects=all_projects) if edges is None: edges = set() for upstream_project in upstream_projects: # add the edge for this upstream dataset edges.add((upstream_project.name, project.name)) # if we've already walked backward for this upstream dataset don't keep walking # you know we've walked it if it shows up as the target in an edge (hence the [1] index) if any(upstream_project.name == x[1] for x in edges): LOGGER.debug( f"skipping dataset {upstream_project.name} since it is already" f" in edges as target: {[ x for x in edges if x[1] ==upstream_project.name]}" ) continue else: # and then go to it and get it's upstream datasets further_upstream_edges = _build_edges( upstream_project, client, edges=set(edges), all_projects=all_projects ) # print(f"adding further upstream edges {further_upstream_edges}") edges = edges.union(further_upstream_edges) return edges
[docs]def from_edges(edges: Set[tuple]) -> Graph: """ Directly build a graph from a list of edges - tuples of format (source, target) dependencies Args: edges: List of edges in tuple format Returns: Graph object """ graph = nx.DiGraph() graph.add_edges_from(edges) return Graph(edges=edges, directed_graph=graph)
[docs]def from_project_list(projects: List[Project], client: Client) -> Graph: """ Creates a graph from a list of datasets Args: projects: list of Tamr dataset objects client: tamr client Returns: A Graph object built from the dependencies of the datasets passed """ # start with empty set graph_edges = set() # save all projects to list so we don't have to hit the API every time all_projects = [x for x in client.projects.stream()] # for each dataset get the edges and take union for project in projects: graph_edges = set( graph_edges.union(_build_edges(project, client, all_projects=all_projects)) ) graph = nx.DiGraph() graph.add_edges_from(graph_edges) return Graph(edges=graph_edges, directed_graph=graph)
[docs]def get_source_nodes(graph: Graph) -> List[str]: """ Gives all source nodes in a graph Args: graph: Graph for which to find source nodes Returns: List of node names """ di_graph = graph.directed_graph source_nodes = [] for n in di_graph.nodes(): if not [x for x in di_graph.predecessors(n)]: source_nodes.append(n) return source_nodes
[docs]def get_end_nodes(graph: Graph): """ Returns all end nodes in a directed graph Args: graph: Graph for which to find end nodes Returns: List of names of all end nodes """ di_graph = graph.directed_graph end_nodes = [] for n in di_graph.nodes(): if not [x for x in di_graph.successors(n)]: end_nodes.append(n) return end_nodes
[docs]def get_projects_by_tier(graph: Graph) -> Dict[int, Any]: """ Find the different projects at each tier Args: graph: the Graph for which to generate the tiers Returns: A json dict who's structure is {'tier': [paths_at_that_tier], ...} e.g. {1: ['SM_project_1', 'Classification_project_1'], 2: ['Mastering_project'], 3: ['Golden_records_project']} """ source_nodes = get_source_nodes(graph) path_tier_dict = {0: []} for n in graph.directed_graph.nodes: # just add source nodes to tier 0 if n in source_nodes: path_tier_dict[0].append(n) continue # since all dependent projects must be run first # the tier is the maximum length of all simple paths to this node # so get the max length from all source nodes tier = 0 for s in source_nodes: current_paths = [x for x in nx.all_simple_paths(graph.directed_graph, s, n)] # skip source nodes that don't link to this project if not current_paths: continue max_path_length = len(max(current_paths, key=lambda x: len(x))) if max_path_length > tier: tier = max_path_length # decrement tier to count from 0 tier = tier - 1 # check if tier exists in the dict, if so add to it, else create entry if tier in path_tier_dict: path_tier_dict[tier].append(n) else: path_tier_dict[tier] = [n] return {k: set(v) for k, v in path_tier_dict.items()}
[docs]def get_all_downstream_nodes(graph: Graph, node: str) -> Set[str]: """ Get all nodes downstream of this one (i.e. they have a path from this node to them) Args: graph: the graph to use node: the node to check Returns: A list of downstream node names """ downstream_paths = [] diGraph = graph.directed_graph for n in diGraph.nodes(): all_paths = [x for x in nx.all_simple_paths(diGraph, node, n)] if all_paths: for path in all_paths: downstream_paths.extend(path) # the above generates a list of things like # [projectA, projectB, projectA, projectC] if projectA goes to both project_B and project_C # so make a set, then remove the node itself downstream_nodes = {x for x in set(downstream_paths) if x != node} return downstream_nodes
[docs]def get_successors(graph: Graph, node: str) -> Set[str]: """ Get all successor nodes to the current node Args: graph: the graph to use node: the node to check Returns: A set of nodes that are successors to the current node """ di_graph = graph.directed_graph return set(x for x in di_graph.successors(node))
[docs]def get_predecessors(graph: Graph, node: str) -> Set[str]: """ Get all predecessor nodes to the current node Args: graph: the graph to use node: the node to check Returns: A set of nodes that are predecessors to the current node """ di_graph = graph.directed_graph return set(x for x in di_graph.predecessors(node))
[docs]def add_edges(graph: Graph, edges: Set[tuple]) -> Graph: """ Takes an existing graph and creates a new one with the new edge Args: graph: the graph to start with edges: the edges to add Returns: A copy of initial graph with new edge """ old_edges = set(graph.directed_graph.edges) new_edges = old_edges.union(edges) return from_edges(new_edges)