`
wx1568037608
  • 浏览: 33483 次
最近访客 更多访客>>
文章分类
社区版块
存档分类
最新评论

PyTorch学习笔记之CBOW模型实践

 
阅读更多


复制代码

 1 import torch
 2 from torch import nn, optim
 3 from torch.autograd import Variable
 4 import torch.nn.functional as F
 5 
 6 CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
 7 raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ')
 8 
 9 vocab = set(raw_text)
10 word_to_idx = {word: i for i, word in enumerate(vocab)}
11 
12 data = []
13 for i in range(CONTEXT_SIZE, len(raw_text)-CONTEXT_SIZE):
14     context = [raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]]
15     target = raw_text[i]
16     data.append((context, target))
17 
18 
19 class CBOW(nn.Module):
20     def __init__(self, n_word, n_dim, context_size):
21         super(CBOW, self).__init__()
22         self.embedding = nn.Embedding(n_word, n_dim)
23         self.linear1 = nn.Linear(2*context_size*n_dim, 128)
24         self.linear2 = nn.Linear(128, n_word)
25 
26     def forward(self, x):
27         x = self.embedding(x)
28         x = x.view(1, -1)
29         x = self.linear1(x)
30         x = F.relu(x, inplace=True)
31         x = self.linear2(x)
32         x = F.log_softmax(x)
33         return x
34 
35 
36 model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE)
37 if torch.cuda.is_available():
38     model = model.cuda()
39 
40 criterion = nn.CrossEntropyLoss()
41 optimizer = optim.SGD(model.parameters(), lr=1e-3)
42 
43 for epoch in range(100):
44     print('epoch {}'.format(epoch))
45     print('*'*10)
46     running_loss = 0
47     for word in data:
48         context, target = word
49         context = Variable(torch.LongTensor([word_to_idx[i] for i in context]))
50         target = Variable(torch.LongTensor([word_to_idx[target]]))
51         if torch.cuda.is_available():
52             context = context.cuda()
53             target = target.cuda()
54         # forward
55         out = model(context)
56         loss = criterion(out, target)
57         running_loss += loss.data[0]
58         # backward
59         optimizer.zero_grad()
60         loss.backward()
61         optimizer.step()
62     print('loss: {:.6f}'.format(running_loss / len(data)))
分享到:
评论

相关推荐

    中英文语料训练CBOW模型获得词向量(pytorch实现)

    自然语言处理第二次作业: data文件夹中存储语料(中文语料以及英文语料由老师提供,另一份为...script文件夹中为CBOW的脚本,同时处理中文语料与英文语料 运行步骤:在脚本中确定训练中文或者是英语后,直接运行即可

    B站刘二大人Pytorch课程学习笔记及课后作业

    本文基于B站刘二大人PyTorch实践课程的学习笔记,深入探讨了PyTorch的核心概念、模型构建以及卷积神经网络的应用。 1. **概述** PyTorch的基础包括线性代数、概率论和Python基础知识。监督学习是主要的机器学习...

    笔记pytorch学习笔记

    【PyTorch学习笔记概述】 PyTorch是Facebook开源的一款深度学习框架,它以其灵活性、易用性和强大的计算能力在学术界和工业界都受到了广泛的欢迎。本笔记将基于B站牛二大人的讲解,深入探讨PyTorch的核心概念、基本...

    自然语言处理-pytorch-CBOW实验数据集

    ### 自然语言处理-pytorch-...通过对上述知识点的学习,我们可以了解到CBOW模型的基本概念、应用场景以及如何使用PyTorch进行实现。通过实践操作,读者能够更好地掌握CBOW模型的原理及其在自然语言处理领域的应用价值。

    PyTorch深度学习实践_pytorch_深度学习_

    pytorch深度学习实践,深度学习实践入门,内附pdf,代码。

    pytorch学习笔记

    pytorch学习笔记

    pytorch 学习笔记

    pytorch 学习笔记

    基于pytorch的模型剪枝+模型量化+BN合并+TRT部署(cifar数据)

    在深度学习领域,模型优化是提高模型性能和部署效率的关键环节。本项目聚焦于四个关键技术:模型剪枝、模型量化、批归一化(BN)层的合并以及使用TensorRT进行部署,这些技术都是针对PyTorch框架进行的。下面将详细...

    深度学习框架pytorch入门与实践pdf书与代码.zip

    《深度学习框架PyTorch入门与实践》是一本旨在引导初学者和有一定基础的开发者深入理解并掌握PyTorch这一强大深度学习库的书籍。PyTorch是Facebook AI Research(FAIR)团队开发的一个开源机器学习框架,以其易用性...

    2024年度最新PyTorch深度学习实践

    2024年度最新PyTorch深度学习实践2024年度最新PyTorch深度学习实践2024年度最新PyTorch深度学习实践2024年度最新PyTorch深度学习实践2024年度最新PyTorch深度学习实践2024年度最新PyTorch深度学习实践 2024年度最新...

    人脸识别项目实战arcface-pytorch源码+预训练模型+测试集.zip

    人脸识别项目实战arcface-pytorch源码+预训练模型+测试集.zip人脸识别项目实战arcface-pytorch源码+预训练模型+测试集.zip人脸识别项目实战arcface-pytorch源码+预训练模型+测试集.zip人脸识别项目实战arcface-...

    人工智能-项目实践-模型压缩-针对pytorch模型的自动化模型结构分析和修改工具集,包含自动分析模型结构的模型压缩算法库

    人工智能-项目实践-模型压缩-针对pytorch模型的自动化模型结构分析和修改工具集,包含自动分析模型结构的模型压缩算法库 requirement onnx>=1.6 onnxruntime>=1.5 pytorch>=1.7 tensorboardX>=1.8 scikit-learn ...

    pytorch课堂笔记.pdf

    而PyTorch则为研究人员和开发人员提供了一个强大的工具,使得构建和训练深度学习模型变得更为容易和高效。通过深入理解这些概念,我们可以更好地把握机器学习和深度学习的未来发展方向,并将其应用到解决现实世界的...

    深度学习框架pytorch入门与实践源代码.rar

    《PyTorch深度学习入门与实践》源代码...通过掌握PyTorch的基本概念、核心模块和实践技巧,可以高效地进行深度学习模型的设计和训练。提供的源代码将帮助读者从实践中深化对PyTorch的理解,进一步提升深度学习能力。

    【Pytorch 技术文档】Pytorch基础教程之torchserve模型部署解析

    【Pytorch 技术文档】Pytorch基础教程之torchserve模型部署解析

    pytorch性别识别.pt模型

    这是我写性别识别demo训练出来的模型,大家可以下载使用。本人亲自迁移到Android设备,完全没有问题。

    Python-一些关于pytorch深入学习的笔记

    以上这些是PyTorch深度学习笔记中可能涉及的主要内容,通过学习和实践,你将能够熟练地运用PyTorch构建和训练深度学习模型,解决实际问题。在"Deep-Learning-master"这个项目中,你可能会发现更具体的示例代码和实践...

    PyTorch深度学习实践

    《PyTorch深度学习实践》课程是一门针对Python编程者和数据科学家的进阶课程,主要聚焦于使用PyTorch框架进行深度学习模型的设计与实现。PyTorch是Facebook开源的一个强大工具,它以其灵活性、易用性和强大的计算...

    PyTorch 模型训练实⽤教程_余霆嵩(去水印)

    本教程内容主要为在 PyTorch 中训练一个模型所可能涉及到的方法及函 数,并且对 PyTorch 提供的数据增强方法(22 个)、权值初始化方法(10 个)、损失函数(17 个)、优化器(6 个)及 tensorboardX 的方法(13 个...

Global site tag (gtag.js) - Google Analytics