mle-interview
  • 面试指南
  • 数据结构与算法
    • 列表
      • 912. Sort an Array
      • 215. Kth Largest Element
      • 977. Squares of a Sorted Array
      • 605. Can Place Flowers
      • 59. Spiral Matrix II
      • 179. Largest Number
      • 31. Next Permutation
    • 二分查找
      • 704. Binary Search
      • 69. Sqrt(x)
      • 278. First Bad Version
      • 34. Find First and Last Position of Element in Sorted Array
      • 33. Search in Rotated Sorted Array
      • 81. Search in Rotated Sorted Array II
      • 162. Find Peak Element
      • 4. Median of Two Sorted Arrays
      • 1095. Find in Mountain Array
      • 240. Search a 2D Matrix II
      • 540. Single Element in a Sorted Array
      • 528. Random Pick with Weight
      • 1300. Sum of Mutated Array Closest to Target
      • 410. Split Array Largest Sum
      • 1044. Longest Duplicate Substring
      • *644. Maximum Average Subarray II
      • *1060. Missing Element in Sorted Array
      • *1062. Longest Repeating Substring
      • *1891. Cutting Ribbons
    • 双指针
      • 26. Remove Duplicate Numbers in Array
      • 283. Move Zeroes
      • 75. Sort Colors
      • 88. Merge Sorted Arrays
      • 167. Two Sum II - Input array is sorted
      • 11. Container With Most Water
      • 42. Trapping Rain Water
      • 15. 3Sum
      • 16. 3Sum Closest
      • 18. 4Sum
      • 454. 4Sum II
      • 409. Longest Palindrome
      • 125. Valid Palindrome
      • 647. Palindromic Substrings
      • 209. Minimum Size Subarray Sum
      • 5. Longest Palindromic Substring
      • 395. Longest Substring with At Least K Repeating Characters
      • 424. Longest Repeating Character Replacement
      • 76. Minimum Window Substring
      • 3. Longest Substring Without Repeating Characters
      • 1004. Max Consecutive Ones III
      • 1658. Minimum Operations to Reduce X to Zero
      • *277. Find the Celebrity
      • *340. Longest Substring with At Most K Distinct Characters
    • 链表
      • 203. Remove Linked List Elements
      • 19. Remove Nth Node From End of List
      • 876. Middle of the Linked List
      • 206. Reverse Linked List
      • 92. Reverse Linked List II
      • 24. Swap Nodes in Pairs
      • 707. Design Linked List
      • 148. Sort List
      • 160. Intersection of Two Linked Lists
      • 141. Linked List Cycle
      • 142. Linked List Cycle II
      • 328. Odd Even Linked List
    • 哈希表
      • 706. Design HashMap
      • 1. Two Sum
      • 146. LRU Cache
      • 128. Longest Consecutive Sequence
      • 73. Set Matrix Zeroes
      • 380. Insert Delete GetRandom O(1)
      • 49. Group Anagrams
      • 350. Intersection of Two Arrays II
      • 299. Bulls and Cows
      • *348. Design Tic-Tac-Toe
    • 字符串
      • 242. Valid Anagram
      • 151. Reverse Words in a String
      • 205. Isomorphic Strings
      • 647. Palindromic Substrings
      • 696. Count Binary Substrings
      • 28. Find the Index of the First Occurrence in a String
      • *186. Reverse Words in a String II
    • 栈与队列
      • 225. Implement Stack using Queues
      • 54. Spiral Matrix
      • 155. Min Stack
      • 232. Implement Queue using Stacks
      • 150. Evaluate Reverse Polish Notation
      • 224. Basic Calculator
      • 20. Valid Parentheses
      • 1472. Design Browser History
      • 1209. Remove All Adjacent Duplicates in String II
      • 1249. Minimum Remove to Make Valid Parentheses
      • *281. Zigzag Iterator
      • *1429. First Unique Number
      • *346. Moving Average from Data Stream
    • 优先队列/堆
      • 692. Top K Frequent Words
      • 347. Top K Frequent Elements
      • 973. K Closest Points
      • 23. Merge K Sorted Lists
      • 264. Ugly Number II
      • 378. Kth Smallest Element in a Sorted Matrix
      • 295. Find Median from Data Stream
      • 767. Reorganize String
      • 1438. Longest Continuous Subarray With Absolute Diff Less Than or Equal to Limit
      • 895. Maximum Frequency Stack
      • 1705. Maximum Number of Eaten Apples
      • *1086. High Five
    • 深度优先DFS
      • 二叉树
      • 543. Diameter of Binary Tree
      • 101. Symmetric Tree
      • 124. Binary Tree Maximum Path Sum
      • 226. Invert Binary Tree
      • 104. Maximum Depth of Binary Tree
      • 951. Flip Equivalent Binary Trees
      • 236. Lowest Common Ancestor of a Binary Tree
      • 987. Vertical Order Traversal of a Binary Tree
      • 572. Subtree of Another Tree
      • 863. All Nodes Distance K in Binary Tree
      • 1110. Delete Nodes And Return Forest
      • 230. Kth Smallest element in a BST
      • 98. Validate Binary Search Tree
      • 235. Lowest Common Ancestor of a Binary Search Tree
      • 669. Trim a Binary Search Tree
      • 700. Search in a Binary Search Tree
      • 108. Convert Sorted Array to Binary Search Tree
      • 450. Delete Node in a BST
      • 938. Range Sum of BST
      • *270. Closest Binary Search Tree Value
      • *333. Largest BST Subtree
      • *285. Inorder Successor in BST
      • *1485. Clone Binary Tree With Random Pointer
      • 回溯
      • 39. Combination Sum
      • 78. Subsets
      • 46. Permutation
      • 77. Combinations
      • 17. Letter Combinations of a Phone Number
      • 51. N-Queens
      • 93. Restore IP Addresses
      • 22. Generate Parentheses
      • 856. Score of Parentheses
      • 301. Remove Invalid Parentheses
      • 37. Sodoku Solver
      • 图DFS
      • 126. Word Ladder II
      • 212. Word Search II
      • 79. Word Search
      • 399. Evaluate Division
      • 1376. Time Needed to Inform All Employees
      • 131. Palindrome Partitioning
      • 491. Non-decreasing Subsequences
      • 698. Partition to K Equal Sum Subsets
      • 526. Beautiful Arrangement
      • 139. Word Break
      • 377. Combination Sum IV
      • 472. Concatenated Words
      • 403. Frog Jump
      • 329. Longest Increasing Path in a Matrix
      • 797. All Paths From Source to Target
      • 695. Max Area of Island
      • 341. Flatten Nested List Iterator
      • 394. Decode String
      • *291. Word Pattern II
      • *694. Number of Distinct Islands
      • *1274. Number of Ships in a Rectangle
      • *1087. Brace Expansion
    • 广度优先BFS
      • 102. Binary Tree Level Order Traversal
      • 103. Binary Tree Zigzag Level Order Traversal
      • 297. Serialize and Deserialize Binary Tree
      • 310. Minimum Height Trees
      • 127. Word Ladder
      • 934. Shortest Bridge
      • 200. Number of Islands
      • 133. Clone Graph
      • 130. Surrounded Regions
      • 752. Open the Lock
      • 815. Bus Routes
      • 1091. Shortest Path in Binary Matrix
      • 542. 01 Matrix
      • 1293. Shortest Path in a Grid with Obstacles Elimination
      • 417. Pacific Atlantic Water Flow
      • 207. Course Schedule
      • 210. Course Schedule II
      • 787. Cheapest Flights Within K Stops
      • 444. Sequence Reconstruction
      • 994. Rotting Oranges
      • 785. Is Graph Bipartite?
      • *366. Find Leaves of Binary Tree
      • *314. Binary Tree Vertical Order Traversal
      • *269. Alien Dictionary
      • *323. Connected Component in Undirected Graph
      • *490. The Maze
    • 动态规划
      • 70. Climbing Stairs
      • 72. Edit Distance
      • 377. Combination Sum IV
      • 1335. Minimum Difficulty of a Job Schedule
      • 97. Interleaving String
      • 472. Concatenated Words
      • 403. Frog Jump
      • 674. Longest Continuous Increasing Subsequence
      • 62. Unique Paths
      • 64. Minimum Path Sum
      • 368. Largest Divisible Subset
      • 300. Longest Increasing Subsequence
      • 354. Russian Doll Envelopes
      • 121. Best Time to Buy and Sell Stock
      • 132. Palindrome Partitioning II
      • 312. Burst Balloons
      • 1143. Longest Common Subsequence
      • 718. Maximum Length of Repeated Subarray
      • 174. Dungeon Game
      • 115. Distinct Subsequences
      • 91. Decode Ways
      • 639. Decode Ways II
      • 712. Minimum ASCII Delete Sum for Two Strings
      • 221. Maximal Square
      • 1277. Count Square Submatrices with All Ones
      • 198. House Robber
      • 213. House Robber II
      • 1235. Maximum Profit in Job Scheduling
      • 740. Delete and Earn
      • 87. Scramble String
      • 1140. Stone Game II
      • 322. Coin Change
      • 518. Coin Change II
      • 1048. Longest String Chain
      • 44. Wildcard Matching
      • 10. Regular Expression Matching
      • 32. Longest Valid Parentheses
      • 1043. Partition Array for Maximum Sum
      • *256. Paint House
      • 926. Flip String to Monotone Increasing
      • *1062. Longest Repeating Substring
      • *1216. Valid Palindrome III
    • 贪心
      • 56. Merge Intervals
      • 621. Task Scheduler
      • 135. Candy
      • 376. Wiggle Subsequence
      • 55. Jump Game
      • 134. Gas Station
      • 1005. Maximize Sum Of Array After K Negations
      • 406. Queue Reconstruction by Height
      • 452. Minimum Number of Arrows to Burst Balloons
      • 738. Monotone Increasing Digits
    • 单调栈
      • 739. Daily Temperatures
      • 503. Next Greater Element II
      • 901. Online Stock Span
      • 85. Maximum Rectangle
      • 84. Largest Rectangle in Histogram
      • 907. Sum of Subarray Minimums
      • 239. Sliding Window Maximum
    • 前缀和
      • 53. Maximum Subarray
      • 523. Continuous Subarray Sum
      • 304. Range Sum Query 2D - Immutable
      • 1423. Maximum Points You Can Obtain from Cards
      • 1031. Maximum Sum of Two Non-Overlapping Subarrays
    • 并查集
      • 684. Redundant Connection
      • 721. Accounts Merge
      • 547. Number of Provinces
      • 737. Sentence Similarity II
      • *305. Number of Islands II
    • 字典树trie
      • 208. Implement Trie
      • 211. Design Add and Search Words Data Structure
      • 1268. Search Suggestions System
      • *1166. Design File System
      • *642. Design Search Autocomplete System
    • 扫描线sweep line
      • 253. Meeting Room II
      • 1094. Car Pooling
      • 218. The Skyline Problem
      • *759. Employee Free Time
    • tree map
      • 729. My Calendar I
      • 981. Time Based Key-Value Store
      • 846. Hand of Straights
      • 480. Sliding Window Median
      • 318. Count of Smaller Numbers After Self
    • 数学类
      • 50. Pow(x, n)
      • *311. Sparse Matrix Multiplication
      • 382. Linked List Random Node
      • 398. Random Pick Index
      • 29. Divide Two Integers
    • 设计类
      • 1603. Design Parking System
      • 355. Design Twitter
      • 1396. Design Underground System
      • *359. Logger Rate Limiter
      • *353. Design Snake Game
      • *379. Design Phone Directory
      • *588. Design In-Memory File System
      • *1244. Design A Leaderboard
    • SQL
  • 机器学习
    • 数学基础
    • 评价指标
    • 线性回归
    • 逻辑回归
    • 树模型
    • 深度学习
    • 支持向量机
    • KNN
    • 无监督学习
    • k-means
    • 强化学习 RL
    • 自然语言处理 NLP
    • 大语言模型 LLM
    • 机器视觉 CV
    • 多模态 MM
    • 分布式机器学习
    • 推荐系统
    • 异常检测与风控
    • 模型解释性
    • 多任务学习
    • MLops
    • 特征工程
    • 在线学习
    • 硬件 cuda/triton
    • 产品case分析
    • 项目deep dive
    • 机器学习代码汇总
  • 系统设计
    • 面向对象设计
      • 电梯设计
      • 停车场设计
      • Unix文件系统设计
    • 系统设计
      • 设计社交网站Twitter
      • 设计视频网站Youtube
      • 短网址系统
      • 爬虫系统
      • 任务调度系统
      • 日志系统
      • 分布式缓存
      • 广告点击聚合系统
      • webhook
    • 机器学习系统设计
      • 推荐系统
      • 搜索引擎
      • Youtube视频推荐
      • Twitter推荐
      • 广告点击预测
      • 新闻推送推荐
      • POI推荐
      • Youtube视频搜索
      • 有害内容检测
      • 大模型RAG
      • 大模型Agent
      • 信贷风控
      • 朋友推荐
      • 去重复性/版权检测
      • 情感分析
      • 目标检测
      • 问答系统
      • 知识图谱问答
  • 行为面试
    • 领导力法则
    • 问答举例
  • 案例分享
    • 准备工作
    • 面试小抄
    • 面试之后
Powered by GitBook
On this page
  • 1. 优化 Optimizer
  • 1.1 前向后向传播
  • 1.2 优化 Optimizer
  • 1.3 学习率scheduler
  • 1.4 初始化
  • 2. 损失函数
  • 3. 网络模型结构
  • 3.1 多层感知机 MLP
  • 3.2 卷积神经网络 CNN
  • 3.3 循环神经网络 RNN
  • 3.4 Transformer
  • 3.5 正则化
  • 3.6 激活函数
  • 3.7 标准化 Norm
  • 3.8 dropout
  • 3.9 pool
  • 4. 训练
  • reference
  1. 机器学习

深度学习

Previous树模型Next支持向量机

Last updated 21 days ago

1. 优化 Optimizer

1.1 前向后向传播

  • pytorch和jax的backprop

  • 训练神经网络的一次迭代分三步:(1)前向传递计算损失函数;(2)后向传递计算梯度;(3)优化器更新模型参数

    • 前向传播,根据预测值和标签计算损失函数,以及损失函数对应的梯度。损失函数类的设计有正向值计算方法和梯度计算方法, 损失函数对y_hat的偏微分

    • 从loss梯度后向传播,计算每一个训练参数的grad。每一层后向传播的输入都是后面层的梯度。每一层有前向方法f(x)和后向方法f(grad)

    • 根据参数值和参数梯度进行优化更新参数: optimizer(w, w_grad)

1.2 优化 Optimizer

梯度:

  • slope of a curve at a given point

  • 从单变量看,抖的时候就走的步子大一点,缓的时候就走的小一点. 多个变量的不同变化决定了整体优化方向

动量

  • 除了此刻的输入外,还考虑上一时刻的输出. 有的优化根据历史梯度计算一阶动量和二阶动量

SGD原理

  • 考虑加入惯性,引入一阶动量,SGD with Momentum

  • Very flexible—can use other loss functions

  • Can be parallelized

  • Slower—does not converge as quickly

  • Harder to handle the unobserved entries (need to use negative sampling or gravity)

Adam和adgrad区别和应用场景

  • Adam: 每个参数梯度增加了一阶动量(momentum)和二阶动量(variance),Adaptive + Momentum. 通过其来自适应控制步长,当梯度较小时,整体的学习率就会增加,反之会缩小

RAdam

  • 用指数滑动平均去估计梯度每个分量的一阶矩(动量)和二阶矩(自适应学习率),并用二阶矩去 normalize 一阶矩,得到每一步的更新量

AdamW

  • 模型的优化方向是"历史动量"和"当前数据梯度"共同决定的

对抗训练

  • 在训练过程中产生一些攻击样本,相当于是加了一层正则化,给神经网络的随机梯度优化限制了一个李普希茨的约束

牛顿法

  • 梯度下降是用平面来逼近局部,牛顿法是用曲面逼近局部

Batch Size

  • 用尽可能能塞进内存的batch size去train模型,提升训练速度. 但也存在trade-off

    • batch size过小,波动会比较大,不太容易收敛。但这种波峰,也有助于跳出局部最优,模型更容易有更好的泛化能力

    • batch size变大,步数整体变少,训练的步数更少,本来就波动就小,步数也少,同样本的情况下,你收敛的会更慢

1.3 学习率scheduler

  • LR与batch_size

    • 常用的heuristic 是 LR 应该与 batch size 的增长倍数的开方成正比,从而保证 variance 与梯度成比例的增长

# Cyclic LR, 每隔一段时间重启学习率,这样在单位时间内能收敛到多个局部最小值,可以得到很多个模型做集成
scheduler = lambda x: ((LR_INIT-LR_MIN)/2)*(np.cos(PI*(np.mod(x-1,CYCLE)/(CYCLE)))+1)+LR_MIN

# warp up, 有助于减缓模型在初始阶段对mini-batch的提前过拟合现象,保持分布的平稳,同时有助于保持模型深层的稳定性
warmup_steps = int(batches_per_epoch * 5)

1.4 初始化

  • 权重为什么不能被初始化为0?

    • 会导致激活后具有相同的值,网络相当于只有一个隐含层节点一样, hidden size失去意义

2. 损失函数

  • MSE

    • prediction made by model trained with MSE loss is always normally distributed

  • cross entropy/ 对数损失

    • nn.CrossEntropyLoss(pred, label) = nn.NLLLoss(torch.log(nn.Softmax(pred)), label)

  • Focal loss

    • 对CE loss增加了一个调制系数来降低容易样本的权重值,使得训练过程更加关注困难样本。增加的这个系数就是评价难易,也就是概率的gamma次方

3. 网络模型结构

3.1 多层感知机 MLP

向量内积

  • 表征两个向量的夹角,表征一个向量在另一个向量上的投影

  • 表征加权平均

import numpy as np

class Dense:
    def __init__(self, input_size, output_size):
        self.weights = np.random.rand(input_size, output_size) - 0.5
        self.bias = np.random.rand(1, output_size) - 0.5

    def forward_propagation(self, input_data):
        self.input = input_data
        self.output = np.dot(self.input, self.weights) + self.bias
        return self.output

    def backward_propagation(self, output_error, learning_rate):
        input_error = np.dot(output_error, self.weights.T)
        weight_error = np.dot(self.input.T, output_error)
        self.weights -= learning_rate * weight_error
        self.bias -= learning_rate * output_error
        return input_error

3.2 卷积神经网络 CNN

  • CNN的归纳偏置Inductive Bias:locality(局部性)和translation equivariance(平移等变性)

  • Convolution is a mathematical operation trying to learn the values of filter(s) using backprop, where we have an input I, and an argument, kernel K to produce an output that expresses how the shape of one is modified by another.

  • Convolutional layer is core building block of CNN, it helps with feature detection.

  • Kernel K is a set of learnable filters and is small spatially compared to the image but extends through the full depth of the input image.

  • Dimension of the feature map as a function of the input image size(W), kernel size(F), Stride(S) and Padding(P) is (W−F+2P)/S+1

  • No. of parameters = (Kernel size _ Kernel size _ Dimension ) + 1 (bias)

  • 感受野(Receptive Field)

  • 1*1卷积的作用 (1) 改变通道数 (2) 用于语义分割等密集预测

  • 卷积计算时可以等价于一个大矩阵一次性运算(Orthogonal Convolutional Neural Networks)

  • 在线卷积(Online Convolution)是在数据流式输入的情况下,实时计算卷积操作

import numpy as np

def conv2d(inputs, kernels, bias, stride, padding):
    """ 正向卷积操作
    inputs: 输入数据,形状为 (C, H, W)
    kernels: 卷积核,形状为 (F, C, HH, WW),C是图片输入层数,F是输出层数 filters
    bias: 偏置,形状为 (F,)
    stride: 步长
    padding: 填充
    """
    # 获取输入数据和卷积核的形状
    C, H, W = inputs.shape
    F, _, HH, WW = kernels.shape

    # 对输入数据进行填充。在第一个轴(通常是通道轴)上不进行填充,在第二个轴和第三个轴(通常是高度和宽度轴)上在开始和结束位置都填充padding个值
    inputs_pad = np.pad(inputs, ((0, 0), (padding, padding), (padding, padding)))

    # 初始化输出数据,卷积后的图像size大小
    H_out = 1 + (H + 2 * padding - HH) // stride
    W_out = 1 + (W + 2 * padding - WW) // stride
    outputs = np.zeros((F, H_out, W_out))

    for i in range(H_out):
        for j in range(W_out):  # 找到out图像对于的原始图像区域,然后对图像进行sum和bias
            inputs_slice = inputs_pad[:, i*stride:i*stride+HH, j*stride:j*stride+WW]
            # axis=(1, 2, 3)表示在通道、高度和宽度这三个轴上进行求和
            outputs[:, i, j] = np.sum(inputs_slice * kernels, axis=(1, 2, 3)) + bias
    return outputs

3.3 循环神经网络 RNN

  • RNN的归纳偏置inductive bias:sequentiality和time invariance,即序列顺序上的time-steps有联系,和时间变换的不变性(rnn权重共享)

  • 梯度爆炸与梯度消失

    • 梯度消失:在反向传播过程中累计梯度一直相乘,当很多小于1的梯度出现时导致前面的梯度很小,难以学习long-term dependencies

      • 一般改进: 改进模型

    • 梯度爆炸:the exploding gradient problem当梯度较大,链式法则导致连乘过大,数值不稳定

      • 一般改进: 梯度截断, 权重衰减

    • 通过多个gate

  • 长距离依赖问题

  • 计算复杂度:

    • LSTM: 序列长度 x(hidden**2)

3.4 Transformer

  • 模型结构

    • encoder: embed + layer(self-attention, skip-connect, ln, ffn, skip-connect, ln) * 6

    • decoder: embed + layer(self-attention, cross-attention, ffn, skip-connect, ln) * 6

  • attention

    • 通过使用key向量,模型可以学习到不同模块之间的相似性和差异性,即对于不同的query向量,它可以通过计算query向量与key向量之间的相似度,来确定哪些key向量与该query向量最相似。

    • kq的计算结果,形成一个(n,n)邻接矩阵,再与v相乘形成加权平均的消息传递

    • MLP-mixer提出的抽象,attention是token-mixing,ffn是channel-mixing

    • 多头注意力?增强网络的容量和表达能力,类比CNN中的不同channel

    • 时间和空间复杂度

      • sequence length n, vector representations d. QK矩阵相乘复杂度为O(n^2 d), softmax与V相乘复杂度O(n^2 d)

      • FFN复杂度:O(n d^2)

    • 优化:kv-cache,MQA,GQA,MLA

      • kv cache: 空间换时间,自回归中每次生成一个token,前面的token计算存在重复性

      • Multi Query Attention: MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量

      • Group Query Attention: 将查询头分成N组,每个组共享一个Key 和 Value 矩阵

      • Flash attention: 利用GPU硬件非均匀的存储器层次结构; compute attention by blocks to reduce global memory access

  • attention为什么除以根号d

    • 称为attention的temperature。如果输入向量的维度d比较大,那么内积的结果也可能非常大,这会导致注意力分数也变得非常大,可能会使得softmax函数的计算变得不稳定(接近one-hot, 梯度消失),并且会影响模型的训练和推理效果。通过除以根号d,可以将注意力分数缩小到一个合适的范围内,从而使softmax函数计算更加稳定,并且更容易收敛。

    • Google T5采用Xavier初始化缓解梯度消失,从而不需要除根号d

  • 使用正弦作为位置编码

    • self-attention无法表达位置信息,由位置编码提供位置信息

    • 绝对位置编码的案例(sinusoidal、learned)、相对位置编码的案例(T5、XLNet、DeBERTa、ALiBi等)、旋转位置编码(RoPE、xPos)

  • Positional Encoding/Embedding 区别

    • 学习式(learned):直接将位置编码当作可训练参数,比如最大长度为 512,编码维度为 768,那么就初始化一个 512×768 的矩阵作为位置向量,让它随着训练过程更新。BERT、GPT 等模型所用的就是这种位置编码

    • 固定式(fixed):位置编码通过三角函数公式计算得出

  • masking

    • Q*K结果上,加一个很大的负数

  • LSTM相比Transformer有什么优势

  • attention瓶颈

    • low rank,talking-head

  • Transformer是如何处理可变长度数据的?

    • 可变长度的意思: 模型训练好了,一个新的序列长度样本也可以作为输入. 但一个batch内仍需要padding到同一长度

    • 只需要保持参数矩阵维度与输入序列的长度无关,例如全连接层针对feature, 都不影响sequence维度; attention等也都是

  • warmup预热学习率

    • 在训练开始时,模型的参数初始值是随机的,模型还没有学到有效的特征表示。如果此时直接使用较大的学习率进行训练,可能会导致模型的参数值更新过快,从而影响模型的稳定性和收敛速度。此时使用warmup预热学习率的策略可以逐渐增加学习率,使得模型参数逐渐收敛到一定的范围内,提高模型的稳定性和收敛速度。

  • KV Cache

    • 加速推断, 解码过程是一个token一个token生成,如果每一次解码都从输入开始拼接好解码的token,那么会有非常多的重复计算

    • 矩阵乘法性质: 矩阵可以分块,将矩阵A拆分为[:s], [s]两部分,分别和矩阵B相乘,那么最终结果可以直接拼接

def scaled_dot_product(q, k, v, softmax, attention_mask, attention_dropout):
    outputs = tf.matmul(q, k, transpose_b=True)
    dk = tf.math.sqrt(tf.cast(q.shape[-1], dtype=tf.float32))
    outputs = outputs / dk
    # if attention_mask is not None:
    #     outputs = outputs + (1 - attention_mask) * -1e9

    outputs = softmax(outputs, mask=attention_mask)
    outputs = Dropout(rate=attention_dropout)(outputs)
    outputs = tf.matmul(outputs, v)  # shape: (m,Tx,depth), same shape as q,k,v
    return outputs

# multi-head有多种写法: 变成4维的 (batch_size, -1, num_heads, d_k), 变成3维的(batch * num_heads, -1, d_k), 以及下面的循环
class FullAttention(tf.keras.layers.Layer):
    def __init__(self,d_model, num_of_heads, dropout, d_out=None):
        super().__init__()
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.dropout = dropout
        self.depth = d_model // num_of_heads
        self.wq = [Dense(self.depth//2, use_bias=False) for i in range(num_of_heads)]
        self.wk = [Dense(self.depth//2, use_bias=False) for i in range(num_of_heads)]
        self.wv = [Dense(self.depth//2, use_bias=False) for i in range(num_of_heads)]
        self.wo = Dense(d_model if d_out is None else d_out, use_bias=False)
        self.softmax = tf.keras.layers.Softmax()

    def call(self, q, k, v, attention_mask=None, training=False):
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](q)
            K = self.wk[i](k)
            V = self.wv[i](v)
            multi_attn.append(scaled_dot_product(Q, K, V, self.softmax, attention_mask, self.dropout))

        multi_attn = tf.concat(multi_attn, axis=-1)
        multi_head_attention = self.wo(multi_attn)

        return multi_head_attention

3.5 正则化

  • 对模型施加显式的正则化约束

    • L1/L2 weight decay

    • dropout,

    • batch normalization,residual learning, label smoothing

  • 利用数据增广的方法,通过数据层面对模型施加隐式正则化约束

# 标签平滑,hard label转变成soft label,使网络优化更加平滑。有效正则化工具,通过在均匀分布和hard标签之间应用加权平均值来生成soft标签。用于减少训练的过拟合问题并进一步提高分类性能
targets = (1 - label_smooth) * targets + label_smooth / num_classes

3.6 激活函数

  • 激活函数引入非线性

  • sigmoid, relu, mish

3.7 标准化 Norm

Batch Norm

  • BN用来减少 “Internal Covariate Shift” 来加速网络的训练,BN 和 ResNet 的作用类似,都使得 loss landscape 变得更加光滑了 (How Does Batch Normalization Help Optimization)

  • BN在训练和测试过程中,其均值和方差的计算方式是不同的。测试过程中采用的是基于训练时估计的统计值,训练过程中则是采用指数加权平均计算

  • 注意有可训练的参数scale和bias

  • BN,当 batch 较小时不具备统计意义,而加大的 batch 又受硬件的影响;BN 适用于 DNN、CNN 之类固定深度的神经网络,而对于 RNN 这类 sequence 长度不一致的神经网络来说,会出现 sequence 长度不同的情况

  • 分布式训练时,BN的跨卡通信

  • 注意在CV应用时,BM不仅仅在Batch_size维度上进行norm, 还在图像的(H, W) 或者时序的 sequence_length 维度上进行norm

Layer Norm

  • layer normalization 有助于得到一个球体空间中符合0均值1方差高斯分布的 embedding, batch normalization不具备这个功能

  • LayerNorm可以对输入进行归一化,使得每个神经元的输入具有相似的分布特征,从而有助于网络的训练和泛化性能。此外,由于归一化的系数是可学习的,网络可以根据输入数据的特点自适应地学习到合适的归一化系数。

  • 加速模型的训练。由于输入已经被归一化,不同特征之间的尺度差异较小,因此优化过程更容易收敛,加快了模型的训练速度。

  • 为什么不用batch norm? BN广泛用于CV,针对同一特征、跨样本开展归一。样本之间仍然具有可比较性,但特征与特征之间不再具有可比较性。NLP中关键的不在于样本中同一特征的可比较

  • 由于BN需要统计不同样本统计值,因此分布式训练需要sync BatchNorm, Layer Norm则不需要

  • PreNorm/PostNorm

# layer norm: https://www.kaggle.com/code/cpmpml/graph-transfomer?scriptVersionId=24171638&cellId=18
mean = K.mean(inputs, axis=-1, keepdims=True)
variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
std = K.sqrt(variance + self.epsilon)
outputs = (inputs - mean) / std
if self.scale:
    outputs *= self.gamma
if self.center:
    outputs += self.beta

GroupNorm

def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N,C,H,W]
    # gamma, beta: scale and offset, with shape [1,C,1,1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    x = tf.reshape(x, [N, G, C // G, H, W])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta
  • RMSNorm舍弃了中心化操作(re-centering),归一化过程只实现缩放(re-scaling),缩放系数是均方根(RMS)

3.8 dropout

  • 训练时,根据binomial分布随机将一些节点置为0,概率为p,剩神经元通过乘一个系数(1/(1-p))保持该层的均值和方差不变;预测时不丢弃神经元,所有神经元输出会被乘以(1-p)

  • 参考AlphaDropout,普通dropout+selu激活函数会导致在回归问题中出现偏差

3.9 pool

import numpy as np

def get_pools(img: np.array, pool_size: int, stride: int) -> np.array:
    pools = []

    # Iterate over all row blocks (single block has `stride` rows)
    for i in np.arange(img.shape[0], step=stride):
        # Iterate over all column blocks (single block has `stride` columns)
        for j in np.arange(img.shape[0], step=stride):
            mat = img[i:i+pool_size, j:j+pool_size]
            # Make sure it's rectangular - has the shape identical to the pool size
            if mat.shape == (pool_size, pool_size):
                # Append to the list of pools
                pools.append(mat)
    return np.array(pools)

def max_pooling(pools: np.array) -> np.array:
    num_pools = pools.shape[0]  # Total number of pools
    # Shape of the matrix after pooling - Square root of the number of pools
    tgt_shape = (int(np.sqrt(num_pools)), int(np.sqrt(num_pools)))

    pooled = []
    for pool in pools:
        pooled.append(np.max(pool))
    return np.array(pooled).reshape(tgt_shape)

4. 训练

  • 混合精度训练

    • 一般用单精度(FP32),半精度(FP16)可以降低内存消耗

    • 训练时,权重、激活值和梯度都使用FP16进行计算,但是会保存FP32的权重值,在梯度更新时对FP32的权重进行更新。在下一步训练时将FP32的权重值转换为FP16再进行FWD和BWD的计算。因为使用FP16进行梯度更新的话,有一些梯度过小就会变成0,从而没有更新。还有权重值比梯度值大太多的时候,相减也会导致梯度消失。

  • 分布式训练

    • 模型并行

    • 数据并行: 同步(synchronous), 异步(asynchronous); Parameter Server, Ring-AllReduce

  • 梯度累积

    • 多步前向之后才后向传播

  • Gradient-Checkpointing

    • 计算梯度的时候,到某一层重新计算激活值,而不保存

reference

  • https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-use-pytorch-loss-functions.md

  • https://www.kaggle.com/code/isbhargav/guide-to-pytorch-learning-rate-scheduling

ce=−ylog(p)−(1−y)log(1−p)ce = - ylog(p) - (1-y)log(1-p)ce=−ylog(p)−(1−y)log(1−p)
Loss=−1N∑n=1N∑i=1Cyn,ilog⁡(pn,i)Loss = -\frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{C} y_{n,i} \log(p_{n,i})Loss=−N1​n=1∑N​i=1∑C​yn,i​log(pn,i​)

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V

PE∗(pos,2i)=sin(pos/100002i/d∗model) \begin{aligned} PE*{(pos,2i)} & = sin(pos/10000^{2i/d*{model}})\ \end{aligned}PE∗(pos,2i)​=sin(pos/100002i/d∗model) ​

PE∗(pos,2i+1)=cos(pos/100002i/d∗model) \begin{aligned} PE*{(pos,2i+1)} & = cos(pos/10000^{2i/d*{model}}) \ \end{aligned}PE∗(pos,2i+1)​=cos(pos/100002i/d∗model) ​

https://github.com/EurekaLabsAI/micrograd
binary cross entropy
Implementing Synchronized Multi-GPU Batch Normalization
RMSNorm - Root Mean Square Layer Normalization
小白都能看懂的超详细Attention机制详解 - 雅正冲蛋的文章 - 知乎
https://github.com/tmheo/deep_learning_study
https://zybuluo.com/hanbingtao/note/581764
深度网络loss除以10和学习率除以10是不是等价的? - 走遍山水路的回答 - 知乎
深度学习中,是否应该打破正负样本1:1的迷信思想? - 密排六方橘子的回答 - 知乎
为什么Layer Norm反向传播的梯度会接近零? - JoJoJoJoya的回答 - 知乎
对比pytorch中的BatchNorm和LayerNorm层 - 严昕的文章 - 知乎
LSTM如何来避免梯度弥散和梯度爆炸? - Quokka的回答 - 知乎
NLP中的Transformer架构在训练和测试时是如何做到decoder的并行化的? - 市井小民的回答 - 知乎
碎碎念:Transformer的细枝末节 - 小莲子的文章 - 知乎
优化时该用SGD,还是用Adam?
对数损失函数
详解深度学习中的梯度消失、爆炸原因及其解决方法 - DoubleV的文章 - 知乎
浅谈后向传递的计算量大约是前向传递的两倍 - 回旋托马斯x的文章 - 知乎
从 0 手撸一个 pytorch - 易迟的文章 - 知乎
如何理解Adam算法(Adaptive Moment Estimation)? - Summer Clover的回答 - 知乎
五、参数量、计算量FLOPS推导 - 小明的HZ的文章 - 知乎
PyTorch 源码解读系列 - OpenMMLab的文章 - 知乎
万字综述,核心开发者全面解读PyTorch内部机制
一文搞懂混合精度训练原理 (常用O1) - APlayBoy的文章 - 知乎