/cwltool/subgraph.py

https://github.com/common-workflow-language/cwltool · Python · 146 lines · 116 code · 21 blank · 9 comment · 45 complexity · d9849a90360635fa0cf514920d48dfe1 MD5 · raw file

  1. import urllib
  2. from collections import namedtuple
  3. from typing import Dict, MutableMapping, MutableSequence, Optional, Set, Tuple, cast
  4. from ruamel.yaml.comments import CommentedMap
  5. from .utils import CWLObjectType, aslist
  6. from .workflow import Workflow
  7. Node = namedtuple("Node", ("up", "down", "type"))
  8. UP = "up"
  9. DOWN = "down"
  10. INPUT = "input"
  11. OUTPUT = "output"
  12. STEP = "step"
  13. def subgraph_visit(
  14. current: str, nodes: MutableMapping[str, Node], visited: Set[str], direction: str,
  15. ) -> None:
  16. if current in visited:
  17. return
  18. visited.add(current)
  19. if direction == DOWN:
  20. d = nodes[current].down
  21. if direction == UP:
  22. d = nodes[current].up
  23. for c in d:
  24. subgraph_visit(c, nodes, visited, direction)
  25. def declare_node(nodes: Dict[str, Node], nodeid: str, tp: Optional[str]) -> Node:
  26. if nodeid in nodes:
  27. n = nodes[nodeid]
  28. if n.type is None:
  29. nodes[nodeid] = Node(n.up, n.down, tp)
  30. else:
  31. nodes[nodeid] = Node([], [], tp)
  32. return nodes[nodeid]
  33. def get_subgraph(roots: MutableSequence[str], tool: Workflow) -> CommentedMap:
  34. if tool.tool["class"] != "Workflow":
  35. raise Exception("Can only extract subgraph from workflow")
  36. nodes = {} # type: Dict[str, Node]
  37. for inp in tool.tool["inputs"]:
  38. declare_node(nodes, inp["id"], INPUT)
  39. for out in tool.tool["outputs"]:
  40. declare_node(nodes, out["id"], OUTPUT)
  41. for i in aslist(out.get("outputSource", [])):
  42. # source is upstream from output (dependency)
  43. nodes[out["id"]].up.append(i)
  44. # output is downstream from source
  45. declare_node(nodes, i, None)
  46. nodes[i].down.append(out["id"])
  47. for st in tool.tool["steps"]:
  48. step = declare_node(nodes, st["id"], STEP)
  49. for i in st["in"]:
  50. if "source" not in i:
  51. continue
  52. for src in aslist(i["source"]):
  53. # source is upstream from step (dependency)
  54. step.up.append(src)
  55. # step is downstream from source
  56. declare_node(nodes, src, None)
  57. nodes[src].down.append(st["id"])
  58. for out in st["out"]:
  59. # output is downstream from step
  60. step.down.append(out)
  61. # step is upstream from output
  62. declare_node(nodes, out, None)
  63. nodes[out].up.append(st["id"])
  64. # Find all the downstream nodes from the starting points
  65. visited_down = set() # type: Set[str]
  66. for r in roots:
  67. if nodes[r].type == OUTPUT:
  68. subgraph_visit(r, nodes, visited_down, UP)
  69. else:
  70. subgraph_visit(r, nodes, visited_down, DOWN)
  71. def find_step(stepid: str) -> Optional[CWLObjectType]:
  72. for st in tool.steps:
  73. if st.tool["id"] == stepid:
  74. return st.tool
  75. return None
  76. # Now make sure all the nodes are connected to upstream inputs
  77. visited = set() # type: Set[str]
  78. rewire = {} # type: Dict[str, Tuple[str, str]]
  79. for v in visited_down:
  80. visited.add(v)
  81. if nodes[v].type in (STEP, OUTPUT):
  82. for u in nodes[v].up:
  83. if u in visited_down:
  84. continue
  85. if nodes[u].type == INPUT:
  86. visited.add(u)
  87. else:
  88. # rewire
  89. df = urllib.parse.urldefrag(u)
  90. rn = df[0] + "#" + df[1].replace("/", "_")
  91. if nodes[v].type == STEP:
  92. wfstep = find_step(v)
  93. if wfstep is not None:
  94. for inp in cast(
  95. MutableSequence[CWLObjectType], wfstep["inputs"]
  96. ):
  97. if u in inp["source"]:
  98. rewire[u] = (rn, inp["type"])
  99. break
  100. else:
  101. raise Exception("Could not find step %s" % v)
  102. extracted = CommentedMap()
  103. for f in tool.tool:
  104. if f in ("steps", "inputs", "outputs"):
  105. extracted[f] = []
  106. for i in tool.tool[f]:
  107. if i["id"] in visited:
  108. if f == "steps":
  109. for inport in i["in"]:
  110. if "source" not in inport:
  111. continue
  112. if isinstance(inport["source"], MutableSequence):
  113. inport["source"] = [
  114. rewire[s][0]
  115. for s in inport["source"]
  116. if s in rewire
  117. ]
  118. elif inport["source"] in rewire:
  119. inport["source"] = rewire[inport["source"]][0]
  120. extracted[f].append(i)
  121. else:
  122. extracted[f] = tool.tool[f]
  123. for rv in rewire.values():
  124. extracted["inputs"].append({"id": rv[0], "type": rv[1]})
  125. return extracted