datetime:2024-01-19 14:37
author:nzb

图是一种常见的数据结构,用于表示一组对象之间的关系。在图的存储方式中,两种常见的方法是邻接表和邻接矩阵。

  • 邻接表

    • 在邻接表中,图的每个顶点都有一个关联的链表,链表中存储了与该顶点相邻的其他顶点。
    • 对于有向图或无向图,每个顶点对应一个链表,链表中的元素表示与该顶点直接相邻的顶点。
    • 对于有权图,链表中通常包含边的权重。
    • 邻接表对稀疏图更为节省空间,因为它只存储实际存在的边。
  • 邻接矩阵

    • 在邻接矩阵中,图的关系被表示为一个二维数组(矩阵)。
    • 对于无向图,矩阵是对称的;对于有向图,矩阵不一定对称。
    • 矩阵中的元素表示两个顶点之间是否存在边,以及可能包含边的权重。
    • 对于稠密图,邻接矩阵通常更为高效,因为它直接表示所有可能的边。
  • 无向图是一种特殊的有向图
  • 选择邻接表还是邻接矩阵取决于图的特性以及使用场景。邻接表适用于稀疏图,而邻接矩阵适用于稠密图。

  • 邻接表可直接查出一个点后续有多少个邻居点

  • 邻接矩阵可直接查出每条边
  A      B
 |  \  /
 |  3\/2
7|   /\
 |  /  \
 C--5-- D

# 邻接表
A: C(7), D(3)
B: C(2)
C: A(7), B(2), D(5)
D: A(3), C(5)

# 邻接矩阵
     0   1   2   3
     A   B   C   D
0 A  0   +∞  7   3
1 B  +∞  0   2   +∞
2 C  7   2   0   5
3 D  3   +∞  5   0

数据结构


import typing


class Node:

    def __init__(self, value: typing.Any):
        self.value = value
        # 无向图的入度和出度相等
        self.node_in: int = 0  # 点的入度
        self.node_out: int = 0  # 点的出度
        self.next_nodes: typing.List[Node] = list()  # 指向的邻居,只关心发散出去的
        self.edges: typing.List[Edge] = list()  # 有哪些边属于该点


class Edge:

    def __init__(self, weight: int, from_node: Node, to_node: Node):
        self.weight: int = weight
        self.from_node: Node = from_node
        self.to_node: Node = to_node

        # 实现 __hash__ 方法

    def __hash__(self):
        # 可以根据对象的属性生成一个唯一的哈希值
        return hash((self.from_node, self.to_node, self.weight))

    def __lt__(self, other):
        return self.weight < other.weight

    def __eq__(self, other):
        return self.weight == other.weight

    def __gt__(self, other):
        return self.weight > other.weight


class Graph:

    def __init__(self):
        self.nodes: typing.Dict[int, Node] = dict()  # key: 点编号,node点实例
        self.edges: typing.Set[Edge] = set()


def create_graph(graph_data: list):
    graph = Graph()
    for it in graph_data:
        weight = it[0]
        from_node = it[1]
        to_node = it[2]
        if from_node not in graph.nodes:
            graph.nodes[from_node] = Node(from_node)
        if to_node not in graph.nodes:
            graph.nodes[to_node] = Node(to_node)

        from_node_ins = graph.nodes.get(from_node)
        to_node_ins = graph.nodes.get(to_node)
        edge = Edge(weight, from_node_ins, to_node_ins)
        from_node_ins.next_nodes.append(to_node_ins)
        from_node_ins.node_out += 1
        to_node_ins.node_in += 1
        from_node_ins.edges.append(edge)
        graph.edges.add(edge)
    return graph


data = [
    [5, 0, 1],  # 权重5,0指向1
    [3, 1, 2],
    [7, 0, 2]
]

res = create_graph(data)
print(res)

图的宽度优先遍历

跟二叉树的宽度优先遍历的区别就是,二叉树没有环,而图有,要解决的就是不进入环,出不来的问题

  • 1,利用队列实现
  • 2,从源节点开始依次按照宽度进队列,然后弹出
  • 3,每弹出一个点,把该节点所有没有进过队列的邻接点放入队列
  • 4,直到队列变空
  • 5,注意处理有环图

from collections import deque

#  A------E
#  | \  /
#  |  \/
#  |   B
#  |  /\
#  | /  \
#  C------D
data = [
    [5, "A", "E"],  # A -> E
    [5, "A", "B"],
    [5, "A", "C"],
    [5, "B", "C"],
    [5, "B", "A"],
    [5, "B", "E"],
    [5, "B", "D"],
    [5, "C", "A"],
    [5, "C", "B"],
    [5, "C", "D"],
    [5, "D", "C"],
    [5, "D", "B"],
    [5, "D", "E"],
    [5, "E", "A"],
    [5, "E", "B"],
    [5, "E", "D"],
]

res = create_graph(data)


def dfs_width(node: Node):
    if not node:
        return
    dq = deque([node])
    # 用于处理有环的图,如果无环可以不用加这个
    # 实际上,为了更快,可以把哈希表换成数组,因为现实中,城市编号不会特别大,数据的速度比哈希表的常数时间快
    node_set = set()
    node_set.add(node)
    while dq:
        cur = dq.popleft()
        # 打印或处理(定制)
        print(cur.value, end=" ")
        for next_node in cur.next_nodes:
            if next_node not in node_set:
                node_set.add(next_node)
                dq.append(next_node)


dfs_width(res.nodes["A"])

广度优先遍历

  • 1,利用栈实现
  • 2,从源节点开始把节点按照深度放入栈,然后弹出
  • 3,每弹出一个点,把该节点下一个没有进过栈的邻接点放入栈,弹出的节点先重新入栈
  • 4,直到栈变空
  • 5,就是一条路走到底
  • 6,注意处理有环图

#       A
#     / | \
#    /  |  \
#   /   |   \
#  B----C----E
#       |   /
#       |  /
#       | /
#       D

data = [
    [5, "A", "B"],
    [5, "A", "E"],  # A -> E
    [5, "A", "C"],
    [5, "B", "C"],
    [5, "B", "A"],
    [5, "C", "A"],
    [5, "C", "B"],
    [5, "C", "D"],
    [5, "C", "E"],
    [5, "D", "C"],
    [5, "D", "E"],
    [5, "E", "A"],
    [5, "E", "C"],
    [5, "E", "D"],
]

res = create_graph(data)

"""
A 先加入栈,处理,栈不为空
弹出A,A的下一个节点有B,C,E,不妨处理B,集合不含B,重新把A入栈,然后B入栈,集合加入B,处理B,break跳出,不再看C,E
弹出B, ->A,C,A已经在集合里面跳过,来到C,B重新入栈,C入栈,处理C,break
弹出C, ->A,B,D,E,A,B已经在集合跳过,来到D, C重新入栈,D入栈,处理D,break
弹出D, ->C,E,C跳过,来到E, D,E入栈,处理E,break
弹出E, ->A,C,D都在集合里跳过
弹出D
弹出C
弹出B
弹出A
"""


def dfs_scope(node: Node):
    if not node:
        return
    stack = [node]
    # 用于处理有环的图,如果无环可以不用加这个
    # 实际上,为了更快,可以把哈希表换成数组,因为现实中,城市编号不会特别大,数据的速度比哈希表的常数时间快
    node_set = set()
    node_set.add(node)
    # 打印或替换成其他处理函数
    print(node.value, end=" ")
    while stack:
        cur = stack.pop()
        for next_node in cur.next_nodes:
            # 如果有一条没走完,它会在继续重新入栈,继续往下走
            # 一路上走到黑
            if next_node not in node_set:
                stack.append(cur)
                stack.append(next_node)
                node_set.add(next_node)
                # 打印或替换成其他处理函数
                print(next_node.value, end=" ")
                break


dfs_scope(res.nodes["A"])

拓扑排序算法

适用范围:要求有向图,且有入度为0的节点,且没有环

比如程序编译依赖包关系

-----------
|         |
|         V
A -> B -> C -> D
     |         ^
     |         |
     -----------

# -----------
# |         |
# |         V
# A -> B -> C -> D
#      |         ^
#      |         |
#      -----------

data = [
    [5, "A", "B"],
    [5, "B", "C"],
    [5, "C", "D"],
    [5, "A", "C"],
    [5, "B", "D"],

]

graph = create_graph(data)

from collections import deque


def topology_sort(graph: Graph):
    node_map = dict()  # node: 剩余入度
    dq0 = deque()  # 入度为0的点
    for node in graph.nodes.values():
        node_map[node] = node.node_in
        if node.node_in == 0:
            dq0.append(node)

    res = []  # 拓扑排序的结果
    while dq0:
        cur = dq0.popleft()
        res.append(cur)
        # 擦除当前节点的影响
        for next_node in cur.next_nodes:
            node_map[next_node] -= 1
            if node_map[next_node] == 0:
                dq0.append(next_node)
    return res


for i in topology_sort(graph):
    print(i.value)

最小生成树

使用并查集,查询和合并的速度是常数级别

#        7      10万
#    A------B--------E
#    | \    |
#    |  \   |
#   2|100\  |1000
#    |    \ |
#    C------D
#        4

# 集合查询、合并
# 刚开始:{A}, {B}, {C}, {D}, {E}
# 加上2这条边:这条边的from(A)和to(B)不在一个集合里,所以加上,然后A,C合并:{A, C}, {B}, {D}, {E}
# 加上4这条边:这条边的from(C)和to(D)不在一个集合里,所以加上,然后A,C合并:{A, C, D}, {B}, {E}
# 加上7这条边:这条边的from(A)和to(B)不在一个集合里,所以加上,然后A,C合并:{A, C, D, B}, {E}
# 加上100这条边:这条边的from(A)和to(D)在一个集合里,所以这条边不要,{A, C, D, B}, {E}
# 加上1000这条边:这条边的from(B)和to(D)在一个集合里,所以这条边不要,{A, C, D, B}, {E}
# 加上10万这条边:这条边的from(B)和to(E)不在一个集合里,所以加上,然后A,C合并:{A, C, D, B, E}

# 差不多这样的结构
import typing


class UnionFind:
    def __init__(self, nodes: typing.List[Node]):
        self.node_map = {}
        for node in nodes:
            self.node_map[node] = {node}

    def find(self, from_node: Node, to_node: Node):
        return self.node_map.get(from_node) == self.node_map.get(to_node)

    def union(self, from_node: Node, to_node: Node):
        from_set = self.node_map.get(from_node)
        to_set = self.node_map.get(to_node)
        if from_set is not None and to_set is not None and from_set != to_set:
            for node in to_set:
                from_set.add(node)
                self.node_map[node] = from_set  # 被合并的集合里面的节点需要指向,最新的集合地址
#        3
#    A------B
#    | \  / |
# 100| 7\/5 |
#    |  /\  |2
#    | /  \ |
#    C------D
#      1000


# 保证连通性的同时,权值最小

#        3
#    A------B
#          /|
#         /5|
#        /  |2
#       /   |
#      C    D

kruskal算法

适用范围:要求无向图

  • kruskal算法以边的角度出发
  • 把边排序,依次选择最小的边,看这条边加上,看有没有形成环
    • 没有,要这条边
    • 有,不要这条边
  • 因此需要一种检测有没有形成环的功能(并查集)
#        7      10万
#    A------B--------E
#    | \    |
#    |  \   |
#   2|100\  |1000
#    |    \ |
#    C------D
#        4

data = [
    [7, "A", "B"],
    [100000, "B", "E"],
    [2, "A", "C"],
    [100, "A", "D"],
    [1000, "B", "D"],
    [4, "C", "D"],

]

graph = create_graph(data)

import typing
import heapq


def kruskalMST(graph: Graph):
    union_find = UnionFind(list(graph.nodes.values()))
    # 优先级队列(小根堆),该方法需要Edge实现比较魔术函数
    hq = []
    for value in graph.edges:  # M条边
        heapq.heappush(hq, value)  # O(logM)

    result: typing.List[Edge] = []
    while hq:  # M条边
        cur_edge = heapq.heappop(hq)  # O(logM)
        if not union_find.find(cur_edge.from_node, cur_edge.to_node):  # 查
            result.append(cur_edge)
            union_find.union(cur_edge.from_node, cur_edge.to_node)  # 合并
    return result


res = kruskalMST(graph)
for it in res:
    print(it.weight)

prim算法

适用范围:要求无向图


#        7      10万
#    A------B--------E
#    | \    |
#    |  \   |
#   2|100\  |1000
#    |    \ |
#    C------D
#        4

data = [
    [7, "A", "B"],
    [100000, "B", "E"],
    [2, "A", "C"],
    [100, "A", "D"],
    [1000, "B", "D"],
    [4, "C", "D"],

]

graph = create_graph(data)

import typing
import heapq


def primMST(graph: Graph):
    hq = []
    node_set: typing.Set[Node] = set()  # 是否是新点
    result: typing.List[Edge] = []  # 依次挑选的边
    # 用于处理森林的情况,比如多个不连通的图,各自生成最小生成树
    for node in graph.nodes.values():
        # 新的一个点,联通的图的话,这里开始就行
        if node not in node_set:
            node_set.add(node)
            for edge in node.edges:  # 该点所有的边加入到优先级队列中
                heapq.heappush(hq, edge)
            while hq:
                edge = heapq.heappop(hq)  # 弹出最小的边
                to_node = edge.to_node  # 可能是新的点
                if to_node not in node_set:  # 如果边的to节点不在集合里面,就是新点
                    node_set.add(to_node)
                    result.append(edge)  # 加入结果中
                    for edge in to_node.edges:
                        heapq.heappush(hq, edge)  # 一条边可能会重复加入,但是不会影响结果,应该已经在集合里面了,最多增加常数时间
    return result


res = primMST(graph)
for it in res:
    print(it.weight)

Dijkstra算法

解决一个点到所有点的最短路径

适用范围:没有权值为负数的边

#         D
#       / | \
#      /  |  \
#     /   |   \
#   9/    |7   \16
#   /     |     \
#  A--15--C--14--E
#   \     |     /
#   3\    |2   /200
#     \   |   /
#      \  |  /
#       \ | /
#         B

"""
# 先生成一张表:A到各个点的距离,初始位正无穷
    A   B    C    D    E
A   0  +∞   +∞   +∞   +∞


每一次在表里选距离最短的点A
然后看从这个点出发的边,能不能把这张表的记录变得更小
A -> B(3)
    A   B    C    D    E
A   0  3   +∞   +∞   +∞

A -> B(15)
    A   B    C    D    E
A   0   3    15   +∞   +∞

A -> D(9)
    A   B    C    D    E
A   0   3    15   9   +∞
 (不动了)


在剩下的记录中找距离最短的点B
B(3) -> A(3), 到A, 3+3 > 0, 跳过

B(3) -> C(2), 到C, 3+2 < 15, 改写
    A     B     C     D     E
A   0     3     5     9    +∞

B(3) -> E(200), 到C, 3+200 < +∞, 改写
    A     B     C     D     E
A   0     3     5     9     203
 (不动了) (不动了)



在剩下的记录中找距离最短的点C
C(5) -> A(15), 到A, 5+15 > 0, 跳过

C(5) -> B(2), 到C, 5+2 > 3, 跳过

C(5) -> E(14), 到C, 5+14 < 203, 改写
    A     B     C     D     E
A   0     3     5     9     203

C(5) -> D(7), 到C, 5+7 > 9, 跳过
    A     B     C     D     E
A   0     3     5     9     19
(不动了)(不动了)(不动了)

以此类推
"""

#         D
#       / | \
#      /  |  \
#     /   |   \
#   9/    |7   \16
#   /     |     \
#  A--15--C--14--E
#   \     |     /
#   3\    |2   /200
#     \   |   /
#      \  |  /
#       \ | /
#         B

data = [
    [3, "A", "B"],
    [9, "A", "D"],
    [15, "A", "C"],
    [3, "B", "A"],
    [2, "B", "C"],
    [200, "B", "E"],
    [15, "C", "A"],
    [2, "C", "B"],
    [14, "C", "E"],
    [7, "C", "D"],
    [9, "D", "A"],
    [7, "D", "C"],
    [16, "D", "E"],
    [16, "E", "D"],
    [14, "E", "C"],
    [200, "E", "B"]
]

graph = create_graph(data)


def get_min_distance_and_unselected_node(node_map: dict, select_set: set):
    min_node = None
    min_dist = float("inf")
    for node, dist in node_map.items():
        if node not in select_set and dist < min_dist:
            min_node = node
            min_dist = dist
    return min_node


def dijkstra(node: Node):
    # 从head出发到所有点的最小距离
    # key: 从head出发到达的点
    # value: 从head出发到key最小的距离
    # 如果在表中,没有T记录,含义是从head出发到T这个点的距离为正无穷
    node_map = {node: 0}
    # 已经求过距离的点,不在处理
    selected_nodes = set()  # 锁住不动的点
    min_node = get_min_distance_and_unselected_node(node_map, selected_nodes)
    while min_node:
        dist = node_map.get(min_node)
        for edge in min_node.edges:
            to_node = edge.to_node
            # 新增
            if to_node not in node_map:  # 没有,正无穷,新增
                node_map[to_node] = dist + edge.weight
            else:
                # 更新
                node_map[to_node] = min(node_map.get(to_node), dist + edge.weight)  # 有没有变小
        selected_nodes.add(min_node)  # 锁住,该点不用操作了
        min_node = get_min_distance_and_unselected_node(node_map, selected_nodes)
    return node_map


data = dijkstra(graph.nodes["A"])
for k, v in data.items():
    print(k.value, v)
  • 堆结构改写Dijkstra算法

指标不变,只是加快了常数时间

  • 从源点出发,源点放入小根堆
  • 源点弹出,从该点出发的到的点和距离放入小根堆
  • 再弹出堆顶,查看从该点出发到的点有没有让距离变小,变小放入小根堆
  • 依次类推下去,利用小根堆依次弹出数据,源点出发到每个点的最短距离就出来了

import typing


class Node:

    def __init__(self, value: typing.Any):
        self.value = value
        # 无向图的入度和出度相等
        self.node_in: int = 0  # 点的入度
        self.node_out: int = 0  # 点的出度
        self.next_nodes: typing.List[Node] = list()  # 指向的邻居,只关心发散出去的
        self.edges: typing.List[Edge] = list()  # 有哪些边属于该点


class Edge:

    def __init__(self, weight: int, from_node: Node, to_node: Node):
        self.weight: int = weight
        self.from_node: Node = from_node
        self.to_node: Node = to_node

        # 实现 __hash__ 方法

    def __hash__(self):
        # 可以根据对象的属性生成一个唯一的哈希值
        return hash((self.from_node, self.to_node, self.weight))

    def __lt__(self, other):
        return self.weight < other.weight

    def __eq__(self, other):
        return self.weight == other.weight

    def __gt__(self, other):
        return self.weight > other.weight


class Graph:

    def __init__(self):
        self.nodes: typing.Dict[int, Node] = dict()  # key: 点编号,node点实例
        self.edges: typing.Set[Edge] = set()


def create_graph(graph_data: list):
    graph = Graph()
    for it in graph_data:
        weight = it[0]
        from_node = it[1]
        to_node = it[2]
        if from_node not in graph.nodes:
            graph.nodes[from_node] = Node(from_node)
        if to_node not in graph.nodes:
            graph.nodes[to_node] = Node(to_node)

        from_node_ins = graph.nodes.get(from_node)
        to_node_ins = graph.nodes.get(to_node)
        edge = Edge(weight, from_node_ins, to_node_ins)
        from_node_ins.next_nodes.append(to_node_ins)
        from_node_ins.node_out += 1
        to_node_ins.node_in += 1
        from_node_ins.edges.append(edge)
        graph.edges.add(edge)
    return graph


class NodeRecord:
    def __init__(self, node: Node, distance: int):
        self.node = node
        self.distance = distance


class NodeHeap:
    def __init__(self, node_size: int):
        self.nodes: typing.List[Node] = [None] * node_size  # 堆
        self.heap_index_map: typing.Dict[Node, int] = dict()  # 在堆上的位置索引
        self.distance_map: typing.Dict[Node, int] = dict()
        self.size = 0

    def add_or_update_or_ignore(self, node: Node, distance: int):
        # 在堆上,更新
        if self._in_heap(node):
            self.distance_map.update({node: min(self.distance_map[node], distance)})  # 更新最小距离
            # 变小了,可能经历一个往上的过程
            self._heap_insert(self.heap_index_map[node])
        # 没进来过
        if not self._is_entered(node):
            self.nodes[self.size] = node  # 新增
            self.heap_index_map[node] = self.size
            self.distance_map[node] = distance
            # 可能经历一个往上的过程
            self._heap_insert(self.size)
            self.size += 1
        # 已经弹出过的,-1的忽略

    def is_empty(self):
        return self.size == 0

    def pop(self) -> NodeRecord:
        node_record = NodeRecord(self.nodes[0], self.distance_map[self.nodes[0]])
        self._swap(0, self.size - 1)  # 最后一个元素放到堆顶
        self.heap_index_map[self.nodes[self.size - 1]] = -1  # 因为上面交换到最后一个元素,然后弹出了,所以置为-1
        self.distance_map.pop(self.nodes[self.size - 1])  # 删掉对应节点距离信息
        self.nodes[self.size - 1] = None  # 最后一个位置释放掉
        self.size -= 1
        self._heapify(0, self.size)
        return node_record

    def _is_entered(self, node: Node) -> bool:
        """
        node是否进过堆
        :param node:
        :return: 进过返回对应索引或-1,否则未进入
        -1 代办进来过,但处理完了
        """
        # return self.heap_index_map.get(node)  # 索引为0 会被当成False
        return node in self.heap_index_map

    def _in_heap(self, node: Node) -> bool:
        """
        在不在堆上
        进来过,并且不等于-1
        -1 表示(弹出了)不在堆上
        :param node:
        :return:
        """
        return self._is_entered(node) and self.heap_index_map.get(node, -1) != -1

    def _swap(self, index1: int, index2: int):
        """
        调整堆
        :param index1:
        :param index2:
        :return:
        """
        # 索引对换更新
        self.heap_index_map.update({self.nodes[index1]: index2, self.nodes[index2]: index1})
        # self.heap_index_map[self.nodes[index1]], self.heap_index_map[self.nodes[index2]] = index2, index1
        # 堆元素对换
        self.nodes[index1], self.nodes[index2] = self.nodes[index2], self.nodes[index1]

    # 小根堆调整
    def _heap_insert(self, index: int):
        """
        根据距离往上调整
        :param index:
        :return:
        """
        parent_idx = max(0, (index - 1) >> 1)
        while self.distance_map[self.nodes[index]] < self.distance_map[self.nodes[parent_idx]]:
            self._swap(index, parent_idx)
            index, parent_idx = parent_idx, max(0, (index - 1) >> 1)

    def _heapify(self, index: int, size: int):
        left = 2 * index + 1
        while left < size:
            smallest = left + 1 if left + 1 < size and self.distance_map[self.nodes[left + 1]] < self.distance_map[
                self.nodes[left]] else left
            # 最小的都比父大,完成调整
            if self.distance_map[self.nodes[smallest]] >= self.distance_map[self.nodes[index]]:
                break
            self._swap(index, smallest)
            index, left = smallest, 2 * index + 1


def dijkstra_heap(head: Node, size: int):
    """
    利用小根堆改进dijkstra算法
    从head出发,所有head能到到达的节点,生成到达每个节点的最小路径记录并返回
    :param head:
    :param size:
    :return:
    """
    node_heap = NodeHeap(size)
    node_heap.add_or_update_or_ignore(head, 0)
    result: typing.Dict[Node, int] = dict()
    while not node_heap.is_empty():
        cur = node_heap.pop()
        for edge in cur.node.edges:
            node_heap.add_or_update_or_ignore(edge.to_node, cur.distance + edge.weight)
        result[cur.node] = cur.distance
    return result


#         D
#       / | \
#      /  |  \
#     /   |   \
#   9/    |7   \16
#   /     |     \
#  A--15--C--14--E
#   \     |     /
#   3\    |2   /200
#     \   |   /
#      \  |  /
#       \ | /
#         B

data = [
    [3, "A", "B"],
    [9, "A", "D"],
    [15, "A", "C"],
    [3, "B", "A"],
    [2, "B", "C"],
    [200, "B", "E"],
    [15, "C", "A"],
    [2, "C", "B"],
    [14, "C", "E"],
    [7, "C", "D"],
    [9, "D", "A"],
    [7, "D", "C"],
    [16, "D", "E"],
    [16, "E", "D"],
    [14, "E", "C"],
    [200, "E", "B"]
]

graph = create_graph(data)

data = dijkstra_heap(graph.nodes["A"], size=len(graph.nodes))
for k, v in data.items():
    print(k.value, v)

A*算法

  • A*解决的是指定源点,指定目标点,求源点到达目标点的最短距离,增加了当前点到终点的启发函数(预估函数)
  • Dijkstra的小根堆是根据源点到当前点的代价排序的, 而A*根据从源点到当前点的距离 + 当前点到目标点的估计距离来进行排序,剩下的所有细节和Dijkstra算法一致
  • Dijkstra算法其实是一种特殊的A*算法,只是当前点到目标点的估计距离都是0

启发函数要求

当前点到终点的预估距离 <= 当前点到终点的真实最短距离(如果有障碍物,真实最短距离因为需要绕障肯定就变长)

预估函数是一种吸引力

  • 合适的吸引力可以提升算法的速度,吸引力过强会出现错误
  • 保证预估距离 <= 真实最短距离的情况下,尽量接近真实最短距离,可以做到功能正确且最快

预估终点距离经常选择

  • 曼哈顿距离(适用二维网格图,只能上下左右,不能斜线,如果能走斜线就不适用)
  • 欧式距离
  • 对角线距离:max(|行 - 行|, |列 - 列|)(适用可以走斜线的二维网格)

Q R S T U
J K L M V
E F G N W
B C H O X
A D I P Y

目标:从 G 到 U
限制:只能上下左右

BFS:宽度优先遍历
第一层:L F H N
第二层:K S M E C ...
...

Dijkstra:引入了堆可能比BFS差
上下左右暂开,距离都是1放入小根堆
弹出,继续往复


A*

      原点G到当前点距离   预估到目标距离(曼哈顿距离)      堆           
G           0               4                   G(0+4)    第1次弹出     
L           1               3                   L(1+3)    第2次弹出
F           1               5                   F(1+5)
S           2               2                   S(2+2)    第3次弹出
...


真实距离

G  L  F  S
0  1  1  2

import heapq
import random
import math
import time


class AStarAlgorithm:

    def __init__(self):
        # x坐标变化值,0: 上,1: 右,2: 下,3: 左
        # y坐标变化值,索引加一
        self.move_direction = (-1, 0, 1, 0, -1)

    def dijkstra(self, grid, start_x, start_y, target_x, target_y):
        """
        特殊的A*算法
        grid[i][j] == 0 代表障碍物
        grid[i][j] == 1 代表道路
        只能走上下左右,不能走斜线
        返回从(start_x, start_y)到(target_x, target_y)的最短距离
        :return:
        """
        if grid[start_x][start_y] == 0 or grid[target_x][target_y] == 0:
            return -1
        m, n = len(grid), len(grid[0])
        distance = [[float("inf") for _ in range(n)] for _ in range(m)]
        visited = [[False for _ in range(n)] for _ in range(m)]

        distance[start_x][start_y] = 1
        hq = []  # 小根堆
        heapq.heappush(hq, (1, start_x, start_y))  # (从源点出发到当前点距离, 行, 列)

        while hq:
            _, x, y = heapq.heappop(hq)
            if visited[x][y]:
                continue
            visited[x][y] = True

            if x == target_x and y == target_y:  # 找到目标点
                return distance[x][y]

            for i in range(4):
                nx = x + self.move_direction[i]
                ny = y + self.move_direction[i + 1]
                # 未越界,可走,未访问过,距离小于存在的距离
                if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1 and not visited[nx][ny]
                    and distance[x][y] + 1 < distance[nx][ny]:
                distance[nx][ny] = distance[x][y] + 1
                heapq.heappush(hq, (distance[x][y] + 1, nx, ny))

    return -1


def a_star(self, grid, start_x, start_y, target_x, target_y):
    if grid[start_x][start_y] == 0 or grid[target_x][target_y] == 0:
        return -1
    m, n = len(grid), len(grid[0])
    distance = [[float("inf") for _ in range(n)] for _ in range(m)]
    visited = [[False for _ in range(n)] for _ in range(m)]
    distance[start_x][start_y] = 1

    hq = []
    heapq.heappush(hq, (1 + self.manhattan_distance(start_x, start_y, target_x, target_y), start_x, start_y))

    while hq:
        _, x, y = heapq.heappop(hq)
        if visited[x][y]:
            continue
        visited[x][y] = True

        if x == target_x and y == target_y:
            return distance[x][y]

        for i in range(4):
            nx = x + self.move_direction[i]
            ny = y + self.move_direction[i + 1]
            if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1 and not visited[nx][ny]
                and distance[x][y] + 1 < distance[nx][ny]:
            distance[nx][ny] = distance[x][y] + 1
            heapq.heappush(hq,
                           (distance[x][y] + 1 + self.manhattan_distance(nx, ny, target_x, target_y), nx, ny))


return -1


@staticmethod
def manhattan_distance(sx, sy, tx, ty):
    """曼哈顿距离"""
    return abs(tx - sx) + abs(ty - sy)


@staticmethod
def diagonal_distance(sx, sy, tx, ty):
    """对角线距离"""
    return max(abs(tx - sx), abs(ty - sy))


@staticmethod
def euclidean_distance(sx, sy, tx, ty):
    """欧式距离"""
    return math.hypot(tx - sx, ty - sy)


@staticmethod
def random_grid(n):
    # 30%概率为0
    return [[0 if random.random() < 0.3 else 1 for _ in range(n)] for _ in range(n)]


def main():
    ins = AStarAlgorithm()

    # length = 100
    # test_times = 10000
    # print("功能测试开始")
    # for _ in range(test_times):
    #     n = int(random.random() * length) + 2
    #     grid = ins.random_grid(n)
    #     start_x, start_y = int(random.random() * n), int(random.random() * n)
    #     target_x, target_y = int(random.random() * n), int(random.random() * n)
    #     ans1 = ins.dijkstra(grid, start_x, start_y, target_x, target_y)
    #     ans2 = ins.a_star(grid, start_x, start_y, target_x, target_y)
    #     if ans1 != ans2:
    #         print("Oops, It's wrong")
    # print("功能测试结束")

    print("性能测试开始")
    grid = ins.random_grid(1000)
    sx, sy, tx, ty = 0, 0, 990, 990
    start_time = time.time()
    ans1 = ins.dijkstra(grid, sx, sy, tx, ty)
    print(f"Dijkstra算法结果:{ans1}, 运行时间: {int((time.time() - start_time) * 1e3)} ms")

    start_time = time.time()
    ans2 = ins.a_star(grid, sx, sy, tx, ty)
    print(f"A*算法结果:{ans2}, 运行时间: {int((time.time() - start_time) * 1e3)} ms")
    print("性能测试结束")
    # 性能测试开始
    # Dijkstra算法结果:1985, 运行时间: 6829 ms
    # A*算法结果:1985, 运行时间: 1204 ms
    # 性能测试结束


if __name__ == '__main__':
    main()

Dijkstra算法和A*算法区别

  • A*算法与Dijkstra算法最大的区别就是提供了启发函数,通过结合已经走过的路径长度和启发函数的估计值,选择下一个要走的路;
  • Dijkstra算法是一种贪心算法,它只考虑已知的路径长度,没有考虑目标位置
  • A*算法因为提供了启发函数,通常情况下比Dijkstra算法快,特别是目标位置距离起始位置较远,有多条路径可以选择的情况
  • Dijkstra算法在某些情况下可能比A*算法快,比如路径代价均匀分布,导致启发函数估计值不够精确,就会导致大量的路径遍历

results matching ""

    No results matching ""