當(dāng)前位置:首頁 > 芯聞號(hào) > 技術(shù)解析
[導(dǎo)讀]為增進(jìn)大家對(duì)pytorch的了解,本文將對(duì)pytorch的簡(jiǎn)單知識(shí)加以講解。如果你對(duì)本文內(nèi)容具有興趣,不妨繼續(xù)往下閱讀哦。

Pytorch作為深度學(xué)習(xí)庫,常被使用。原因在于,pytorch代碼更為簡(jiǎn)單。不管是深度學(xué)習(xí)新手還是老手,pytorch都是一大利器。為增進(jìn)大家對(duì)pytorch的了解,本文將對(duì)pytorch的簡(jiǎn)單知識(shí)加以講解。如果你對(duì)本文內(nèi)容具有興趣,不妨繼續(xù)往下閱讀哦。

1. overview

不同于 theano,tensorflow 等低層程序庫,或者 keras、sonnet 等高層 wrapper,pytorch 是一種自成體系的深度學(xué)習(xí)庫(圖1)。

圖1. 幾種深度學(xué)習(xí)程序庫對(duì)比

如圖2所示,pytorch 由低層到上層主要有三大塊功能模塊。

圖2. pytorch 主要功能模塊

1.1 張量計(jì)算引擎(tensor computaTIon)

Tensor 計(jì)算引擎,類似 numpy 和 matlab,基本對(duì)象是tensor(類比 numpy 中的 ndarray 或 matlab 中的 array)。除提供基于 CPU 的常用操作的實(shí)現(xiàn)外,pytorch 還提供了高效的 GPU 實(shí)現(xiàn),這對(duì)于深度學(xué)習(xí)至關(guān)重要。

1.2 自動(dòng)求導(dǎo)機(jī)制(autograd)

由于深度學(xué)習(xí)模型日趨復(fù)雜,因此,對(duì)自動(dòng)求導(dǎo)的支持對(duì)于學(xué)習(xí)框架變得必不可少。pytorch 采用了動(dòng)態(tài)求導(dǎo)機(jī)制,使用類似方法的框架包括: chainer,dynet。作為對(duì)比,theano,tensorflow 采用靜態(tài)自動(dòng)求導(dǎo)機(jī)制。

1.3 神經(jīng)網(wǎng)絡(luò)的高層庫(NN)

pytorch 還提供了高層的。對(duì)于常用的網(wǎng)絡(luò)結(jié)構(gòu),如全連接、卷積、RNN 等。同時(shí),pytorch 還提供了常用的、opTImizer 及參數(shù)。

這里,我們重點(diǎn)關(guān)注如何自定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)。

2. 自定義 Module

圖3. pytorch Module

module 是 pytorch 組織神經(jīng)網(wǎng)絡(luò)的基本方式。Module 包含了模型的參數(shù)以及計(jì)算邏輯。FuncTIon 承載了實(shí)際的功能,定義了前向和后向的計(jì)算邏輯。

下面以最簡(jiǎn)單的 MLP 網(wǎng)絡(luò)結(jié)構(gòu)為例,介紹下如何實(shí)現(xiàn)自定義網(wǎng)絡(luò)結(jié)構(gòu)。完整代碼可以參見repo。

2.1 FuncTIon

Function 是 pytorch 自動(dòng)求導(dǎo)機(jī)制的核心類。Function 是無參數(shù)或者說無狀態(tài)的,它只負(fù)責(zé)接收輸入,返回相應(yīng)的輸出;對(duì)于反向,它接收輸出相應(yīng)的梯度,返回輸入相應(yīng)的梯度。

這里我們只關(guān)注如何自定義 Function。Function 的定義見。下面是簡(jiǎn)化的代碼段:

class Function(object):

def forward(self, *input):

raise NotImplementedError

def backward(self, *grad_output):

raise NotImplementedError

forward 和 backward 的輸入和輸出都是 Tensor 對(duì)象。

Function 對(duì)象是 callable 的,即可以通過()的方式進(jìn)行調(diào)用。其中調(diào)用的輸入和輸出都為 Variable 對(duì)象。下面的示例了如何實(shí)現(xiàn)一個(gè) ReLU 激活函數(shù)并進(jìn)行調(diào)用:

import torch

from torch.autograd import Function

class ReLUF(Function):

def forward(self, input):

self.save_for_backward(input)

output = input.clamp(min=0)

return output

def backward(self, output_grad):

input = self.to_save[0]

input_grad = output_grad.clone()

input_grad[input < 0] = 0

return input_grad

## Test

if __name__ == "__main__":

from torch.autograd import Variable

torch.manual_seed(1111)

a = torch.randn(2, 3)

va = Variable(a, requires_grad=True)

vb = ReLUF()(va)

print va.data, vb.data

vb.backward(torch.ones(va.size()))

print vb.grad.data, va.grad.data

如果 backward 中需要用到 forward 的輸入,需要在 forward 中顯式的保存需要的輸入。在上面的代碼中,forward 利用self.save_for_backward函數(shù),將輸入暫時(shí)保存,并在 backward 中利用saved_tensors (python tuple 對(duì)象) 取出。

顯然,forward 的輸入應(yīng)該和 backward 的輸入相對(duì)應(yīng);同時(shí),forward 的輸出應(yīng)該和 backward 的輸入相匹配。

由于 Function 可能需要暫存 input tensor,因此,建議不復(fù)用 Function 對(duì)象,以避免遇到內(nèi)存提前釋放的問題。如所示,forward的每次調(diào)用都重新生成一個(gè) ReLUF 對(duì)象,而不能在初始化時(shí)生成在 forward 中反復(fù)調(diào)用。

2.2 Module

類似于 Function,Module 對(duì)象也是 callable 是,輸入和輸出也是 Variable。不同的是,Module 是[可以]有參數(shù)的。Module 包含兩個(gè)主要部分:參數(shù)及計(jì)算邏輯(Function 調(diào)用)。由于ReLU激活函數(shù)沒有參數(shù),這里我們以最基本的全連接層為例來說明如何自定義Module。

全連接層的運(yùn)算邏輯定義如下 Function:

import torch

from torch.autograd import Function

class LinearF(Function):

def forward(self, input, weight, bias=None):

self.save_for_backward(input, weight, bias)

output = torch.mm(input, weight.t())

if bias is not None:

output += bias.unsqueeze(0).expand_as(output)

return output

def backward(self, grad_output):

input, weight, bias = self.saved_tensors

grad_input = grad_weight = grad_bias = None

if self.needs_input_grad[0]:

grad_input = torch.mm(grad_output, weight)

if self.needs_input_grad[1]:

grad_weight = torch.mm(grad_output.t(), input)

if bias is not None and self.needs_input_grad[2]:

grad_bias = grad_output.sum(0).squeeze(0)

if bias is not None:

return grad_input, grad_weight, grad_bias

else:

return grad_input, grad_weight

為一個(gè)元素為 bool 型的 tuple,長(zhǎng)度與 forward 的參數(shù)數(shù)量相同,用來標(biāo)識(shí)各個(gè)輸入是否輸入計(jì)算梯度;對(duì)于無需梯度的輸入,可以減少不必要的計(jì)算。

Function(此處為 LinearF) 定義了基本的計(jì)算邏輯,Module 只需要在初始化時(shí)為參數(shù)分配內(nèi)存空間,并在計(jì)算時(shí),將參數(shù)傳遞給相應(yīng)的 Function 對(duì)象。代碼如下:

import torch

import torch.nn as nn

class Linear(nn.Module):

def __init__(self, in_features, out_features, bias=True):

super(Linear, self).__init__()

self.in_features = in_features

self.out_features = out_features

self.weight = nn.Parameter(torch.Tensor(out_features, in_features))

if bias:

self.bias = nn.Parameter(torch.Tensor(out_features))

else:

self.register_parameter('bias', None)

def forward(self, input):

return LinearF()(input, self.weight, self.bias)

需要注意的是,參數(shù)是內(nèi)存空間由 tensor 對(duì)象維護(hù),但 tensor 需要包裝為一個(gè)Parameter 對(duì)象。Parameter 是 Variable 的特殊子類,僅有是不同是 Parameter 默認(rèn)requires_grad為 True。Varaible 是自動(dòng)求導(dǎo)機(jī)制的核心類,此處暫不介紹,參見。

3. 自定義循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)

我們嘗試自己定義一個(gè)更復(fù)雜的 Module ——RNN。這里,我們只定義最基礎(chǔ)的 vanilla RNN(圖4),基本的計(jì)算公式如下:

ht=relu(W?x+U?ht?1)

圖4. RNN

更復(fù)雜的 LSTM、GRU 或者其他變種的實(shí)現(xiàn)也非常類似。

3.1 定義 Cell

import torch

from torch.nn import Module, Parameter

class RNNCell(Module):

def __init__(self, input_size, hidden_size):

super(RNNCell, self).__init__()

self.input_size = input_size

self.hidden_size = hidden_size

self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))

self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))

self.bias_ih = Parameter(torch.Tensor(hidden_size))

self.bias_hh = Parameter(torch.Tensor(hidden_size))

self.reset_parameters()

def reset_parameters(self):

stdv = 1.0 / math.sqrt(self.hidden_size)

for weight in self.parameters():

weight.data.uniform_(-stdv, stdv)

def forward(self, input, h):

output = LinearF()(input, self.weight_ih, self.bias_ih) + LinearF()(h, self.weight_hh, self.bias_hh)

output = ReLUF()(output)

return output

3.2 定義完整的 RNN

import torch

from torch.nn import Module

class RNN(Moudule):

def __init__(self, input_size, hidden_size):

super(RNN, self).__init__()

self.input_size = input_size

self.hidden_size = hidden_size

sef.cell = RNNCell(input_size, hidden_size)

def forward(self, inputs, initial_state):

time_steps = inputs.size(1)

state = initial_state

outputs = []

for t in range(time_steps):

state = self.cell(inputs[:, t, :], state)

outputs.append(state)

return outputs

討論

pytorch 的 Module 結(jié)構(gòu)是傳承自 torch,這一點(diǎn)也同樣被 keras (functional API)所借鑒。 在 caffe 等一些[早期的]深度學(xué)習(xí)框架中,network 是由于若干 layer ,經(jīng)由不同的拓?fù)浣Y(jié)構(gòu)組成的。而在 (pyt)torch 中沒有 layer 和 network 是區(qū)分,一切都是 callable 的 Module。Module 的調(diào)用的輸入和輸出都是 tensor (由 Variable 封裝),用戶可以非常自然的構(gòu)造任意有向無環(huán)的網(wǎng)絡(luò)結(jié)構(gòu)(DAG)。

同時(shí), pytorch 的 autograd 機(jī)制封裝的比較淺,可以比較容易的定制反傳或修改梯度。這對(duì)有些算法是非常重要。

總之,僅就自定義算法而言,pytorch 是一個(gè)非常優(yōu)雅的深度學(xué)習(xí)框架。

以上便是此次小編帶來的“pytorch”相關(guān)內(nèi)容,通過本文,希望大家對(duì)上述知識(shí)具備一定的了解。如果你喜歡本文,不妨持續(xù)關(guān)注我們網(wǎng)站哦,小編將于后期帶來更多精彩內(nèi)容。最后,十分感謝大家的閱讀,have a nice day!

本站聲明: 本文章由作者或相關(guān)機(jī)構(gòu)授權(quán)發(fā)布,目的在于傳遞更多信息,并不代表本站贊同其觀點(diǎn),本站亦不保證或承諾內(nèi)容真實(shí)性等。需要轉(zhuǎn)載請(qǐng)聯(lián)系該專欄作者,如若文章內(nèi)容侵犯您的權(quán)益,請(qǐng)及時(shí)聯(lián)系本站刪除。
換一批
延伸閱讀

9月2日消息,不造車的華為或?qū)⒋呱龈蟮莫?dú)角獸公司,隨著阿維塔和賽力斯的入局,華為引望愈發(fā)顯得引人矚目。

關(guān)鍵字: 阿維塔 塞力斯 華為

倫敦2024年8月29日 /美通社/ -- 英國汽車技術(shù)公司SODA.Auto推出其旗艦產(chǎn)品SODA V,這是全球首款涵蓋汽車工程師從創(chuàng)意到認(rèn)證的所有需求的工具,可用于創(chuàng)建軟件定義汽車。 SODA V工具的開發(fā)耗時(shí)1.5...

關(guān)鍵字: 汽車 人工智能 智能驅(qū)動(dòng) BSP

北京2024年8月28日 /美通社/ -- 越來越多用戶希望企業(yè)業(yè)務(wù)能7×24不間斷運(yùn)行,同時(shí)企業(yè)卻面臨越來越多業(yè)務(wù)中斷的風(fēng)險(xiǎn),如企業(yè)系統(tǒng)復(fù)雜性的增加,頻繁的功能更新和發(fā)布等。如何確保業(yè)務(wù)連續(xù)性,提升韌性,成...

關(guān)鍵字: 亞馬遜 解密 控制平面 BSP

8月30日消息,據(jù)媒體報(bào)道,騰訊和網(wǎng)易近期正在縮減他們對(duì)日本游戲市場(chǎng)的投資。

關(guān)鍵字: 騰訊 編碼器 CPU

8月28日消息,今天上午,2024中國國際大數(shù)據(jù)產(chǎn)業(yè)博覽會(huì)開幕式在貴陽舉行,華為董事、質(zhì)量流程IT總裁陶景文發(fā)表了演講。

關(guān)鍵字: 華為 12nm EDA 半導(dǎo)體

8月28日消息,在2024中國國際大數(shù)據(jù)產(chǎn)業(yè)博覽會(huì)上,華為常務(wù)董事、華為云CEO張平安發(fā)表演講稱,數(shù)字世界的話語權(quán)最終是由生態(tài)的繁榮決定的。

關(guān)鍵字: 華為 12nm 手機(jī) 衛(wèi)星通信

要點(diǎn): 有效應(yīng)對(duì)環(huán)境變化,經(jīng)營(yíng)業(yè)績(jī)穩(wěn)中有升 落實(shí)提質(zhì)增效舉措,毛利潤(rùn)率延續(xù)升勢(shì) 戰(zhàn)略布局成效顯著,戰(zhàn)新業(yè)務(wù)引領(lǐng)增長(zhǎng) 以科技創(chuàng)新為引領(lǐng),提升企業(yè)核心競(jìng)爭(zhēng)力 堅(jiān)持高質(zhì)量發(fā)展策略,塑強(qiáng)核心競(jìng)爭(zhēng)優(yōu)勢(shì)...

關(guān)鍵字: 通信 BSP 電信運(yùn)營(yíng)商 數(shù)字經(jīng)濟(jì)

北京2024年8月27日 /美通社/ -- 8月21日,由中央廣播電視總臺(tái)與中國電影電視技術(shù)學(xué)會(huì)聯(lián)合牽頭組建的NVI技術(shù)創(chuàng)新聯(lián)盟在BIRTV2024超高清全產(chǎn)業(yè)鏈發(fā)展研討會(huì)上宣布正式成立。 活動(dòng)現(xiàn)場(chǎng) NVI技術(shù)創(chuàng)新聯(lián)...

關(guān)鍵字: VI 傳輸協(xié)議 音頻 BSP

北京2024年8月27日 /美通社/ -- 在8月23日舉辦的2024年長(zhǎng)三角生態(tài)綠色一體化發(fā)展示范區(qū)聯(lián)合招商會(huì)上,軟通動(dòng)力信息技術(shù)(集團(tuán))股份有限公司(以下簡(jiǎn)稱"軟通動(dòng)力")與長(zhǎng)三角投資(上海)有限...

關(guān)鍵字: BSP 信息技術(shù)
關(guān)閉
關(guān)閉