datetime:2024-01-19 14:37
author:nzb
图
图是一种常见的数据结构,用于表示一组对象之间的关系。在图的存储方式中,两种常见的方法是邻接表和邻接矩阵。
邻接表
- 在邻接表中,图的每个顶点都有一个关联的链表,链表中存储了与该顶点相邻的其他顶点。
- 对于有向图或无向图,每个顶点对应一个链表,链表中的元素表示与该顶点直接相邻的顶点。
- 对于有权图,链表中通常包含边的权重。
- 邻接表对稀疏图更为节省空间,因为它只存储实际存在的边。
邻接矩阵
- 在邻接矩阵中,图的关系被表示为一个二维数组(矩阵)。
- 对于无向图,矩阵是对称的;对于有向图,矩阵不一定对称。
- 矩阵中的元素表示两个顶点之间是否存在边,以及可能包含边的权重。
- 对于稠密图,邻接矩阵通常更为高效,因为它直接表示所有可能的边。
- 无向图是一种特殊的有向图
选择邻接表还是邻接矩阵取决于图的特性以及使用场景。邻接表适用于稀疏图,而邻接矩阵适用于稠密图。
邻接表可直接查出一个点后续有多少个邻居点
- 邻接矩阵可直接查出每条边
A B
| \ /
| 3\/2
7| /\
| / \
C--5-- D
# 邻接表
A: C(7), D(3)
B: C(2)
C: A(7), B(2), D(5)
D: A(3), C(5)
# 邻接矩阵
0 1 2 3
A B C D
0 A 0 +∞ 7 3
1 B +∞ 0 2 +∞
2 C 7 2 0 5
3 D 3 +∞ 5 0
数据结构
import typing
class Node:
def __init__(self, value: typing.Any):
self.value = value
# 无向图的入度和出度相等
self.node_in: int = 0 # 点的入度
self.node_out: int = 0 # 点的出度
self.next_nodes: typing.List[Node] = list() # 指向的邻居,只关心发散出去的
self.edges: typing.List[Edge] = list() # 有哪些边属于该点
class Edge:
def __init__(self, weight: int, from_node: Node, to_node: Node):
self.weight: int = weight
self.from_node: Node = from_node
self.to_node: Node = to_node
# 实现 __hash__ 方法
def __hash__(self):
# 可以根据对象的属性生成一个唯一的哈希值
return hash((self.from_node, self.to_node, self.weight))
def __lt__(self, other):
return self.weight < other.weight
def __eq__(self, other):
return self.weight == other.weight
def __gt__(self, other):
return self.weight > other.weight
class Graph:
def __init__(self):
self.nodes: typing.Dict[int, Node] = dict() # key: 点编号,node点实例
self.edges: typing.Set[Edge] = set()
def create_graph(graph_data: list):
graph = Graph()
for it in graph_data:
weight = it[0]
from_node = it[1]
to_node = it[2]
if from_node not in graph.nodes:
graph.nodes[from_node] = Node(from_node)
if to_node not in graph.nodes:
graph.nodes[to_node] = Node(to_node)
from_node_ins = graph.nodes.get(from_node)
to_node_ins = graph.nodes.get(to_node)
edge = Edge(weight, from_node_ins, to_node_ins)
from_node_ins.next_nodes.append(to_node_ins)
from_node_ins.node_out += 1
to_node_ins.node_in += 1
from_node_ins.edges.append(edge)
graph.edges.add(edge)
return graph
data = [
[5, 0, 1], # 权重5,0指向1
[3, 1, 2],
[7, 0, 2]
]
res = create_graph(data)
print(res)
图的宽度优先遍历
跟二叉树的宽度优先遍历的区别就是,二叉树没有环,而图有,要解决的就是不进入环,出不来的问题
- 1,利用队列实现
- 2,从源节点开始依次按照宽度进队列,然后弹出
- 3,每弹出一个点,把该节点所有没有进过队列的邻接点放入队列
- 4,直到队列变空
- 5,注意处理有环图
from collections import deque
# A------E
# | \ /
# | \/
# | B
# | /\
# | / \
# C------D
data = [
[5, "A", "E"], # A -> E
[5, "A", "B"],
[5, "A", "C"],
[5, "B", "C"],
[5, "B", "A"],
[5, "B", "E"],
[5, "B", "D"],
[5, "C", "A"],
[5, "C", "B"],
[5, "C", "D"],
[5, "D", "C"],
[5, "D", "B"],
[5, "D", "E"],
[5, "E", "A"],
[5, "E", "B"],
[5, "E", "D"],
]
res = create_graph(data)
def dfs_width(node: Node):
if not node:
return
dq = deque([node])
# 用于处理有环的图,如果无环可以不用加这个
# 实际上,为了更快,可以把哈希表换成数组,因为现实中,城市编号不会特别大,数据的速度比哈希表的常数时间快
node_set = set()
node_set.add(node)
while dq:
cur = dq.popleft()
# 打印或处理(定制)
print(cur.value, end=" ")
for next_node in cur.next_nodes:
if next_node not in node_set:
node_set.add(next_node)
dq.append(next_node)
dfs_width(res.nodes["A"])
广度优先遍历
- 1,利用栈实现
- 2,从源节点开始把节点按照深度放入栈,然后弹出
- 3,每弹出一个点,把该节点下一个没有进过栈的邻接点放入栈,弹出的节点先重新入栈
- 4,直到栈变空
- 5,就是一条路走到底
- 6,注意处理有环图
# A
# / | \
# / | \
# / | \
# B----C----E
# | /
# | /
# | /
# D
data = [
[5, "A", "B"],
[5, "A", "E"], # A -> E
[5, "A", "C"],
[5, "B", "C"],
[5, "B", "A"],
[5, "C", "A"],
[5, "C", "B"],
[5, "C", "D"],
[5, "C", "E"],
[5, "D", "C"],
[5, "D", "E"],
[5, "E", "A"],
[5, "E", "C"],
[5, "E", "D"],
]
res = create_graph(data)
"""
A 先加入栈,处理,栈不为空
弹出A,A的下一个节点有B,C,E,不妨处理B,集合不含B,重新把A入栈,然后B入栈,集合加入B,处理B,break跳出,不再看C,E
弹出B, ->A,C,A已经在集合里面跳过,来到C,B重新入栈,C入栈,处理C,break
弹出C, ->A,B,D,E,A,B已经在集合跳过,来到D, C重新入栈,D入栈,处理D,break
弹出D, ->C,E,C跳过,来到E, D,E入栈,处理E,break
弹出E, ->A,C,D都在集合里跳过
弹出D
弹出C
弹出B
弹出A
"""
def dfs_scope(node: Node):
if not node:
return
stack = [node]
# 用于处理有环的图,如果无环可以不用加这个
# 实际上,为了更快,可以把哈希表换成数组,因为现实中,城市编号不会特别大,数据的速度比哈希表的常数时间快
node_set = set()
node_set.add(node)
# 打印或替换成其他处理函数
print(node.value, end=" ")
while stack:
cur = stack.pop()
for next_node in cur.next_nodes:
# 如果有一条没走完,它会在继续重新入栈,继续往下走
# 一路上走到黑
if next_node not in node_set:
stack.append(cur)
stack.append(next_node)
node_set.add(next_node)
# 打印或替换成其他处理函数
print(next_node.value, end=" ")
break
dfs_scope(res.nodes["A"])
拓扑排序算法
适用范围:要求有向图,且有入度为0的节点,且没有环
比如程序编译依赖包关系
-----------
| |
| V
A -> B -> C -> D
| ^
| |
-----------
# -----------
# | |
# | V
# A -> B -> C -> D
# | ^
# | |
# -----------
data = [
[5, "A", "B"],
[5, "B", "C"],
[5, "C", "D"],
[5, "A", "C"],
[5, "B", "D"],
]
graph = create_graph(data)
from collections import deque
def topology_sort(graph: Graph):
node_map = dict() # node: 剩余入度
dq0 = deque() # 入度为0的点
for node in graph.nodes.values():
node_map[node] = node.node_in
if node.node_in == 0:
dq0.append(node)
res = [] # 拓扑排序的结果
while dq0:
cur = dq0.popleft()
res.append(cur)
# 擦除当前节点的影响
for next_node in cur.next_nodes:
node_map[next_node] -= 1
if node_map[next_node] == 0:
dq0.append(next_node)
return res
for i in topology_sort(graph):
print(i.value)
最小生成树
使用并查集,查询和合并的速度是常数级别
# 7 10万
# A------B--------E
# | \ |
# | \ |
# 2|100\ |1000
# | \ |
# C------D
# 4
# 集合查询、合并
# 刚开始:{A}, {B}, {C}, {D}, {E}
# 加上2这条边:这条边的from(A)和to(B)不在一个集合里,所以加上,然后A,C合并:{A, C}, {B}, {D}, {E}
# 加上4这条边:这条边的from(C)和to(D)不在一个集合里,所以加上,然后A,C合并:{A, C, D}, {B}, {E}
# 加上7这条边:这条边的from(A)和to(B)不在一个集合里,所以加上,然后A,C合并:{A, C, D, B}, {E}
# 加上100这条边:这条边的from(A)和to(D)在一个集合里,所以这条边不要,{A, C, D, B}, {E}
# 加上1000这条边:这条边的from(B)和to(D)在一个集合里,所以这条边不要,{A, C, D, B}, {E}
# 加上10万这条边:这条边的from(B)和to(E)不在一个集合里,所以加上,然后A,C合并:{A, C, D, B, E}
# 差不多这样的结构
import typing
class UnionFind:
def __init__(self, nodes: typing.List[Node]):
self.node_map = {}
for node in nodes:
self.node_map[node] = {node}
def find(self, from_node: Node, to_node: Node):
return self.node_map.get(from_node) == self.node_map.get(to_node)
def union(self, from_node: Node, to_node: Node):
from_set = self.node_map.get(from_node)
to_set = self.node_map.get(to_node)
if from_set is not None and to_set is not None and from_set != to_set:
for node in to_set:
from_set.add(node)
self.node_map[node] = from_set # 被合并的集合里面的节点需要指向,最新的集合地址
# 3
# A------B
# | \ / |
# 100| 7\/5 |
# | /\ |2
# | / \ |
# C------D
# 1000
# 保证连通性的同时,权值最小
# 3
# A------B
# /|
# /5|
# / |2
# / |
# C D
kruskal算法
适用范围:要求无向图
kruskal
算法以边的角度出发- 把边排序,依次选择最小的边,看这条边加上,看有没有形成环
- 没有,要这条边
- 有,不要这条边
- 因此需要一种检测有没有形成环的功能(并查集)
# 7 10万
# A------B--------E
# | \ |
# | \ |
# 2|100\ |1000
# | \ |
# C------D
# 4
data = [
[7, "A", "B"],
[100000, "B", "E"],
[2, "A", "C"],
[100, "A", "D"],
[1000, "B", "D"],
[4, "C", "D"],
]
graph = create_graph(data)
import typing
import heapq
def kruskalMST(graph: Graph):
union_find = UnionFind(list(graph.nodes.values()))
# 优先级队列(小根堆),该方法需要Edge实现比较魔术函数
hq = []
for value in graph.edges: # M条边
heapq.heappush(hq, value) # O(logM)
result: typing.List[Edge] = []
while hq: # M条边
cur_edge = heapq.heappop(hq) # O(logM)
if not union_find.find(cur_edge.from_node, cur_edge.to_node): # 查
result.append(cur_edge)
union_find.union(cur_edge.from_node, cur_edge.to_node) # 合并
return result
res = kruskalMST(graph)
for it in res:
print(it.weight)
prim算法
适用范围:要求无向图
# 7 10万
# A------B--------E
# | \ |
# | \ |
# 2|100\ |1000
# | \ |
# C------D
# 4
data = [
[7, "A", "B"],
[100000, "B", "E"],
[2, "A", "C"],
[100, "A", "D"],
[1000, "B", "D"],
[4, "C", "D"],
]
graph = create_graph(data)
import typing
import heapq
def primMST(graph: Graph):
hq = []
node_set: typing.Set[Node] = set() # 是否是新点
result: typing.List[Edge] = [] # 依次挑选的边
# 用于处理森林的情况,比如多个不连通的图,各自生成最小生成树
for node in graph.nodes.values():
# 新的一个点,联通的图的话,这里开始就行
if node not in node_set:
node_set.add(node)
for edge in node.edges: # 该点所有的边加入到优先级队列中
heapq.heappush(hq, edge)
while hq:
edge = heapq.heappop(hq) # 弹出最小的边
to_node = edge.to_node # 可能是新的点
if to_node not in node_set: # 如果边的to节点不在集合里面,就是新点
node_set.add(to_node)
result.append(edge) # 加入结果中
for edge in to_node.edges:
heapq.heappush(hq, edge) # 一条边可能会重复加入,但是不会影响结果,应该已经在集合里面了,最多增加常数时间
return result
res = primMST(graph)
for it in res:
print(it.weight)
Dijkstra算法
解决一个点到所有点的最短路径
适用范围:没有权值为负数的边
# D
# / | \
# / | \
# / | \
# 9/ |7 \16
# / | \
# A--15--C--14--E
# \ | /
# 3\ |2 /200
# \ | /
# \ | /
# \ | /
# B
"""
# 先生成一张表:A到各个点的距离,初始位正无穷
A B C D E
A 0 +∞ +∞ +∞ +∞
每一次在表里选距离最短的点A
然后看从这个点出发的边,能不能把这张表的记录变得更小
A -> B(3)
A B C D E
A 0 3 +∞ +∞ +∞
A -> B(15)
A B C D E
A 0 3 15 +∞ +∞
A -> D(9)
A B C D E
A 0 3 15 9 +∞
(不动了)
在剩下的记录中找距离最短的点B
B(3) -> A(3), 到A, 3+3 > 0, 跳过
B(3) -> C(2), 到C, 3+2 < 15, 改写
A B C D E
A 0 3 5 9 +∞
B(3) -> E(200), 到C, 3+200 < +∞, 改写
A B C D E
A 0 3 5 9 203
(不动了) (不动了)
在剩下的记录中找距离最短的点C
C(5) -> A(15), 到A, 5+15 > 0, 跳过
C(5) -> B(2), 到C, 5+2 > 3, 跳过
C(5) -> E(14), 到C, 5+14 < 203, 改写
A B C D E
A 0 3 5 9 203
C(5) -> D(7), 到C, 5+7 > 9, 跳过
A B C D E
A 0 3 5 9 19
(不动了)(不动了)(不动了)
以此类推
"""
# D
# / | \
# / | \
# / | \
# 9/ |7 \16
# / | \
# A--15--C--14--E
# \ | /
# 3\ |2 /200
# \ | /
# \ | /
# \ | /
# B
data = [
[3, "A", "B"],
[9, "A", "D"],
[15, "A", "C"],
[3, "B", "A"],
[2, "B", "C"],
[200, "B", "E"],
[15, "C", "A"],
[2, "C", "B"],
[14, "C", "E"],
[7, "C", "D"],
[9, "D", "A"],
[7, "D", "C"],
[16, "D", "E"],
[16, "E", "D"],
[14, "E", "C"],
[200, "E", "B"]
]
graph = create_graph(data)
def get_min_distance_and_unselected_node(node_map: dict, select_set: set):
min_node = None
min_dist = float("inf")
for node, dist in node_map.items():
if node not in select_set and dist < min_dist:
min_node = node
min_dist = dist
return min_node
def dijkstra(node: Node):
# 从head出发到所有点的最小距离
# key: 从head出发到达的点
# value: 从head出发到key最小的距离
# 如果在表中,没有T记录,含义是从head出发到T这个点的距离为正无穷
node_map = {node: 0}
# 已经求过距离的点,不在处理
selected_nodes = set() # 锁住不动的点
min_node = get_min_distance_and_unselected_node(node_map, selected_nodes)
while min_node:
dist = node_map.get(min_node)
for edge in min_node.edges:
to_node = edge.to_node
# 新增
if to_node not in node_map: # 没有,正无穷,新增
node_map[to_node] = dist + edge.weight
else:
# 更新
node_map[to_node] = min(node_map.get(to_node), dist + edge.weight) # 有没有变小
selected_nodes.add(min_node) # 锁住,该点不用操作了
min_node = get_min_distance_and_unselected_node(node_map, selected_nodes)
return node_map
data = dijkstra(graph.nodes["A"])
for k, v in data.items():
print(k.value, v)
- 堆结构改写
Dijkstra
算法
指标不变,只是加快了常数时间
- 从源点出发,源点放入小根堆
- 源点弹出,从该点出发的到的点和距离放入小根堆
- 再弹出堆顶,查看从该点出发到的点有没有让距离变小,变小放入小根堆
- 依次类推下去,利用小根堆依次弹出数据,源点出发到每个点的最短距离就出来了
import typing
class Node:
def __init__(self, value: typing.Any):
self.value = value
# 无向图的入度和出度相等
self.node_in: int = 0 # 点的入度
self.node_out: int = 0 # 点的出度
self.next_nodes: typing.List[Node] = list() # 指向的邻居,只关心发散出去的
self.edges: typing.List[Edge] = list() # 有哪些边属于该点
class Edge:
def __init__(self, weight: int, from_node: Node, to_node: Node):
self.weight: int = weight
self.from_node: Node = from_node
self.to_node: Node = to_node
# 实现 __hash__ 方法
def __hash__(self):
# 可以根据对象的属性生成一个唯一的哈希值
return hash((self.from_node, self.to_node, self.weight))
def __lt__(self, other):
return self.weight < other.weight
def __eq__(self, other):
return self.weight == other.weight
def __gt__(self, other):
return self.weight > other.weight
class Graph:
def __init__(self):
self.nodes: typing.Dict[int, Node] = dict() # key: 点编号,node点实例
self.edges: typing.Set[Edge] = set()
def create_graph(graph_data: list):
graph = Graph()
for it in graph_data:
weight = it[0]
from_node = it[1]
to_node = it[2]
if from_node not in graph.nodes:
graph.nodes[from_node] = Node(from_node)
if to_node not in graph.nodes:
graph.nodes[to_node] = Node(to_node)
from_node_ins = graph.nodes.get(from_node)
to_node_ins = graph.nodes.get(to_node)
edge = Edge(weight, from_node_ins, to_node_ins)
from_node_ins.next_nodes.append(to_node_ins)
from_node_ins.node_out += 1
to_node_ins.node_in += 1
from_node_ins.edges.append(edge)
graph.edges.add(edge)
return graph
class NodeRecord:
def __init__(self, node: Node, distance: int):
self.node = node
self.distance = distance
class NodeHeap:
def __init__(self, node_size: int):
self.nodes: typing.List[Node] = [None] * node_size # 堆
self.heap_index_map: typing.Dict[Node, int] = dict() # 在堆上的位置索引
self.distance_map: typing.Dict[Node, int] = dict()
self.size = 0
def add_or_update_or_ignore(self, node: Node, distance: int):
# 在堆上,更新
if self._in_heap(node):
self.distance_map.update({node: min(self.distance_map[node], distance)}) # 更新最小距离
# 变小了,可能经历一个往上的过程
self._heap_insert(self.heap_index_map[node])
# 没进来过
if not self._is_entered(node):
self.nodes[self.size] = node # 新增
self.heap_index_map[node] = self.size
self.distance_map[node] = distance
# 可能经历一个往上的过程
self._heap_insert(self.size)
self.size += 1
# 已经弹出过的,-1的忽略
def is_empty(self):
return self.size == 0
def pop(self) -> NodeRecord:
node_record = NodeRecord(self.nodes[0], self.distance_map[self.nodes[0]])
self._swap(0, self.size - 1) # 最后一个元素放到堆顶
self.heap_index_map[self.nodes[self.size - 1]] = -1 # 因为上面交换到最后一个元素,然后弹出了,所以置为-1
self.distance_map.pop(self.nodes[self.size - 1]) # 删掉对应节点距离信息
self.nodes[self.size - 1] = None # 最后一个位置释放掉
self.size -= 1
self._heapify(0, self.size)
return node_record
def _is_entered(self, node: Node) -> bool:
"""
node是否进过堆
:param node:
:return: 进过返回对应索引或-1,否则未进入
-1 代办进来过,但处理完了
"""
# return self.heap_index_map.get(node) # 索引为0 会被当成False
return node in self.heap_index_map
def _in_heap(self, node: Node) -> bool:
"""
在不在堆上
进来过,并且不等于-1
-1 表示(弹出了)不在堆上
:param node:
:return:
"""
return self._is_entered(node) and self.heap_index_map.get(node, -1) != -1
def _swap(self, index1: int, index2: int):
"""
调整堆
:param index1:
:param index2:
:return:
"""
# 索引对换更新
self.heap_index_map.update({self.nodes[index1]: index2, self.nodes[index2]: index1})
# self.heap_index_map[self.nodes[index1]], self.heap_index_map[self.nodes[index2]] = index2, index1
# 堆元素对换
self.nodes[index1], self.nodes[index2] = self.nodes[index2], self.nodes[index1]
# 小根堆调整
def _heap_insert(self, index: int):
"""
根据距离往上调整
:param index:
:return:
"""
parent_idx = max(0, (index - 1) >> 1)
while self.distance_map[self.nodes[index]] < self.distance_map[self.nodes[parent_idx]]:
self._swap(index, parent_idx)
index, parent_idx = parent_idx, max(0, (index - 1) >> 1)
def _heapify(self, index: int, size: int):
left = 2 * index + 1
while left < size:
smallest = left + 1 if left + 1 < size and self.distance_map[self.nodes[left + 1]] < self.distance_map[
self.nodes[left]] else left
# 最小的都比父大,完成调整
if self.distance_map[self.nodes[smallest]] >= self.distance_map[self.nodes[index]]:
break
self._swap(index, smallest)
index, left = smallest, 2 * index + 1
def dijkstra_heap(head: Node, size: int):
"""
利用小根堆改进dijkstra算法
从head出发,所有head能到到达的节点,生成到达每个节点的最小路径记录并返回
:param head:
:param size:
:return:
"""
node_heap = NodeHeap(size)
node_heap.add_or_update_or_ignore(head, 0)
result: typing.Dict[Node, int] = dict()
while not node_heap.is_empty():
cur = node_heap.pop()
for edge in cur.node.edges:
node_heap.add_or_update_or_ignore(edge.to_node, cur.distance + edge.weight)
result[cur.node] = cur.distance
return result
# D
# / | \
# / | \
# / | \
# 9/ |7 \16
# / | \
# A--15--C--14--E
# \ | /
# 3\ |2 /200
# \ | /
# \ | /
# \ | /
# B
data = [
[3, "A", "B"],
[9, "A", "D"],
[15, "A", "C"],
[3, "B", "A"],
[2, "B", "C"],
[200, "B", "E"],
[15, "C", "A"],
[2, "C", "B"],
[14, "C", "E"],
[7, "C", "D"],
[9, "D", "A"],
[7, "D", "C"],
[16, "D", "E"],
[16, "E", "D"],
[14, "E", "C"],
[200, "E", "B"]
]
graph = create_graph(data)
data = dijkstra_heap(graph.nodes["A"], size=len(graph.nodes))
for k, v in data.items():
print(k.value, v)
A*算法
A*
解决的是指定源点,指定目标点,求源点到达目标点的最短距离,增加了当前点到终点的启发函数(预估函数)Dijkstra
的小根堆是根据源点到当前点的代价排序的, 而A*
根据从源点到当前点的距离 + 当前点到目标点的估计距离来进行排序,剩下的所有细节和Dijkstra
算法一致Dijkstra
算法其实是一种特殊的A*
算法,只是当前点到目标点的估计距离都是0
启发函数要求
当前点到终点的预估距离 <= 当前点到终点的真实最短距离(如果有障碍物,真实最短距离因为需要绕障肯定就变长)
预估函数是一种吸引力
- 合适的吸引力可以提升算法的速度,吸引力过强会出现错误
- 保证预估距离 <= 真实最短距离的情况下,尽量接近真实最短距离,可以做到功能正确且最快
预估终点距离经常选择
- 曼哈顿距离(适用二维网格图,只能上下左右,不能斜线,如果能走斜线就不适用)
- 欧式距离
- 对角线距离:
max(|行 - 行|, |列 - 列|)
(适用可以走斜线的二维网格)
Q R S T U
J K L M V
E F G N W
B C H O X
A D I P Y
目标:从 G 到 U
限制:只能上下左右
BFS:宽度优先遍历
第一层:L F H N
第二层:K S M E C ...
...
Dijkstra:引入了堆可能比BFS差
上下左右暂开,距离都是1放入小根堆
弹出,继续往复
A*
原点G到当前点距离 预估到目标距离(曼哈顿距离) 堆
G 0 4 G(0+4) 第1次弹出
L 1 3 L(1+3) 第2次弹出
F 1 5 F(1+5)
S 2 2 S(2+2) 第3次弹出
...
真实距离
G L F S
0 1 1 2
import heapq
import random
import math
import time
class AStarAlgorithm:
def __init__(self):
# x坐标变化值,0: 上,1: 右,2: 下,3: 左
# y坐标变化值,索引加一
self.move_direction = (-1, 0, 1, 0, -1)
def dijkstra(self, grid, start_x, start_y, target_x, target_y):
"""
特殊的A*算法
grid[i][j] == 0 代表障碍物
grid[i][j] == 1 代表道路
只能走上下左右,不能走斜线
返回从(start_x, start_y)到(target_x, target_y)的最短距离
:return:
"""
if grid[start_x][start_y] == 0 or grid[target_x][target_y] == 0:
return -1
m, n = len(grid), len(grid[0])
distance = [[float("inf") for _ in range(n)] for _ in range(m)]
visited = [[False for _ in range(n)] for _ in range(m)]
distance[start_x][start_y] = 1
hq = [] # 小根堆
heapq.heappush(hq, (1, start_x, start_y)) # (从源点出发到当前点距离, 行, 列)
while hq:
_, x, y = heapq.heappop(hq)
if visited[x][y]:
continue
visited[x][y] = True
if x == target_x and y == target_y: # 找到目标点
return distance[x][y]
for i in range(4):
nx = x + self.move_direction[i]
ny = y + self.move_direction[i + 1]
# 未越界,可走,未访问过,距离小于存在的距离
if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1 and not visited[nx][ny]
and distance[x][y] + 1 < distance[nx][ny]:
distance[nx][ny] = distance[x][y] + 1
heapq.heappush(hq, (distance[x][y] + 1, nx, ny))
return -1
def a_star(self, grid, start_x, start_y, target_x, target_y):
if grid[start_x][start_y] == 0 or grid[target_x][target_y] == 0:
return -1
m, n = len(grid), len(grid[0])
distance = [[float("inf") for _ in range(n)] for _ in range(m)]
visited = [[False for _ in range(n)] for _ in range(m)]
distance[start_x][start_y] = 1
hq = []
heapq.heappush(hq, (1 + self.manhattan_distance(start_x, start_y, target_x, target_y), start_x, start_y))
while hq:
_, x, y = heapq.heappop(hq)
if visited[x][y]:
continue
visited[x][y] = True
if x == target_x and y == target_y:
return distance[x][y]
for i in range(4):
nx = x + self.move_direction[i]
ny = y + self.move_direction[i + 1]
if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1 and not visited[nx][ny]
and distance[x][y] + 1 < distance[nx][ny]:
distance[nx][ny] = distance[x][y] + 1
heapq.heappush(hq,
(distance[x][y] + 1 + self.manhattan_distance(nx, ny, target_x, target_y), nx, ny))
return -1
@staticmethod
def manhattan_distance(sx, sy, tx, ty):
"""曼哈顿距离"""
return abs(tx - sx) + abs(ty - sy)
@staticmethod
def diagonal_distance(sx, sy, tx, ty):
"""对角线距离"""
return max(abs(tx - sx), abs(ty - sy))
@staticmethod
def euclidean_distance(sx, sy, tx, ty):
"""欧式距离"""
return math.hypot(tx - sx, ty - sy)
@staticmethod
def random_grid(n):
# 30%概率为0
return [[0 if random.random() < 0.3 else 1 for _ in range(n)] for _ in range(n)]
def main():
ins = AStarAlgorithm()
# length = 100
# test_times = 10000
# print("功能测试开始")
# for _ in range(test_times):
# n = int(random.random() * length) + 2
# grid = ins.random_grid(n)
# start_x, start_y = int(random.random() * n), int(random.random() * n)
# target_x, target_y = int(random.random() * n), int(random.random() * n)
# ans1 = ins.dijkstra(grid, start_x, start_y, target_x, target_y)
# ans2 = ins.a_star(grid, start_x, start_y, target_x, target_y)
# if ans1 != ans2:
# print("Oops, It's wrong")
# print("功能测试结束")
print("性能测试开始")
grid = ins.random_grid(1000)
sx, sy, tx, ty = 0, 0, 990, 990
start_time = time.time()
ans1 = ins.dijkstra(grid, sx, sy, tx, ty)
print(f"Dijkstra算法结果:{ans1}, 运行时间: {int((time.time() - start_time) * 1e3)} ms")
start_time = time.time()
ans2 = ins.a_star(grid, sx, sy, tx, ty)
print(f"A*算法结果:{ans2}, 运行时间: {int((time.time() - start_time) * 1e3)} ms")
print("性能测试结束")
# 性能测试开始
# Dijkstra算法结果:1985, 运行时间: 6829 ms
# A*算法结果:1985, 运行时间: 1204 ms
# 性能测试结束
if __name__ == '__main__':
main()
Dijkstra算法和A*算法区别
A*
算法与Dijkstra
算法最大的区别就是提供了启发函数,通过结合已经走过的路径长度和启发函数的估计值,选择下一个要走的路;- 而
Dijkstra
算法是一种贪心算法,它只考虑已知的路径长度,没有考虑目标位置 A*
算法因为提供了启发函数,通常情况下比Dijkstra
算法快,特别是目标位置距离起始位置较远,有多条路径可以选择的情况Dijkstra
算法在某些情况下可能比A*
算法快,比如路径代价均匀分布,导致启发函数估计值不够精确,就会导致大量的路径遍历