`
wx1568037608
  • 浏览: 35829 次
最近访客 更多访客>>
文章分类
社区版块
存档分类
最新评论

Bert系列(三)——源码解读之Pre-train

 
阅读更多

https://www.jianshu.com/p/22e462f01d8c

pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现实(在Google Cloud TPU v2 上训练BERT-Base要花费近500刀,耗时达到两周。在GPU上可想而知只会更贵),但是学习bert的预训练方法可以为我们弄懂整个bert的运行流程提供莫大的帮助。预训练涉及到的模块有点多,所以这也将会是一篇长文,在能简略的地方我尽量简略,还是那句话,我的文章只能是起到一个导读的作用,如果想摸清里面的各种细节还是要自己把源码过一遍的。

pre-train涉及到的模块分为以下三个,我将为大家一一介绍:

1.tokenization.py

2.create_pretraining_data.py

3.run_pretraining.py

其中tokenization是对原始句子内容的解析,分为BasicTokenizer和WordpieceTokenizer两个,不只是在预训练中,在fine-tune和推断过程同样要用到它;create_pretraining_data顾名思义就是将原始语料转换成适合模型预训练的输入数据;run_pretraining就是预训练的执行代码了。

一、tokenization.py

1、BasicTokenizer

class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True):
    self.do_lower_case = do_lower_case

  def tokenize(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._clean_text(text)

    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
        (cp >= 0x3400 and cp <= 0x4DBF) or  #
        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or  #
        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
      return True
    return False

  def _clean_text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

BasicTokenizer的主要是进行unicode转换、标点符号分割、小写转换、中文字符分割、去除重音符号等操作,最后返回的是关于词的数组(中文是字的数组)

2、WordpieceTokenizer

class WordpieceTokenizer(object):
  """Runs WordPiece tokenziation."""

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    text = convert_to_unicode(text)
    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue
      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens

WordpieceTokenizer的目的是将合成词分解成类似词根一样的词片。例如将"unwanted"分解成["un", "##want", "##ed"]这么做的目的是防止因为词的过于生僻没有被收录进词典最后只能以[UNK]代替的局面,因为英语当中这样的合成词非常多,词典不可能全部收录。

3、FullTokenizer

class FullTokenizer(object):
  """Runs end-to-end tokenziation."""

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

FullTokenizer的作用就很显而易见了,对一个文本段进行以上两种解析,最后返回词(字)的数组,同时还提供token到id的索引以及id到token的索引。这里的token可以理解为文本段处理过后的最小单元。

二、create_pretraining_data.py

1、配置

flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
    "output_file", None,
    "Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")

配置input_file、output_file分别代表输入的源语料文件和处理过的预料文件地址;

do_lower_case:是否全部转为小写字母,是否转换成小写字母的意义在Bert系列(一)——demo运行里面已经说过了。

dupe_factor:默认重复10次,目的是可以生成不同情况的masks;

short_seq_prob:构造长度小于指定"max_seq_length"的样本比例。因为在fine-tune过程里面输入的target_seq_length是可变的(小于等于max_seq_length),那么为了防止过拟合也需要在pre-train的过程当中构造一些短的样本。

2、main入口

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  output_files = FLAGS.output_file.split(",")
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)

从入口开始看,步骤很简单:1)构造tokenizer ;2)构造instances ;3)保存instances

3、构造instances

def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]
  for input_file in input_files:
    with tf.gfile.GFile(input_file, "r") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break
        line = line.strip()

        # Empty lines are used as document delimiters
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # Remove empty documents
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

  rng.shuffle(instances)
  return instances

这一步是阅读数据,数据的输入文本可以是一个文件也可以是用逗号分割的若干文件;
文件里用换行来表示句子的边界,即一句一行,同理段落之间用空一行来表示段落的边界,一个段落表示成一个document;具体的构造方法在create_instances_from_document函数里面。

def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the `A`
        # (first) sentence.
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # Random next
        is_random_next = False
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
  
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        # Actual next
        else:
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)

        tokens.append("[SEP]")
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1
  return instances

这一段算是整个模块的核心了。

instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)

1)一个instance 包含一个tokens,实际上就是输入的词序列;该序列表现形式为:

[CLS] A [SEP] B [SEP]

A=[token_0, token_1, ...,token_i]
B=[token_i+1, token_i+2, ...,token_n-1]

其中:
2<= n < max_seq_length - 3 (in short_seq_prob)
n=max_seq_length - 3 (in 1-short_seq_prob)

token 最后表现形式如下图所示:


 
tokens示意图

segment_ids 指的形式为[0,0,0...1,1,111] 0的个数为i+1个,1的个数为max_seq_length - (i+1)
对应到模型输入就是token_type

is_random_next:其实就是上图的Label,0.5的概率为True(和当只有一个segment的时候),如果为True则B和A不属于同一document。剩下的情况为False,则B为A同一document的后续句子。

masked_lm_positions:序列里被[MASK]的位置;

masked_lm_labels:序列里被[MASK]的token

2)在create_masked_lm_predictions函数里,一个序列在指定MASK数量之后,有80%被真正MASK,10%还是保留原来token,10%被随机替换成其他token。

4、保存instance

def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
  """Create TF example files from `TrainingInstance`s."""
  writers = []
  for output_file in output_files:
    writers.append(tf.python_io.TFRecordWriter(output_file))

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    writers[writer_index].write(tf_example.SerializeToString())
    writer_index = (writer_index + 1) % len(writers)

    total_written += 1

    if inst_index < 20:
      tf.logging.info("*** Example ***")
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        tf.logging.info(
            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

  for writer in writers:
    writer.close()

  tf.logging.info("Wrote %d total instances", total_written)

instance保存没什么好说的,只有两点:

while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

1)之前不是有short_seq_prob的概率导致样本的长度小于max_predictions_per_seq吗,这里把这些样本补齐,padding为0,同样的还有input_mask和segment_ids;
2) 把instance的is_random_next转化成变量next_sentence_label保存。

为了验证这个数据模块对中文输入输出的支持,我做了个测试:

python3 create_pretraining_data.py   --input_file=/tmp/zh_test.txt   --output_file=/tmp/output.txt   --vocab_file=$BERT_ZH_DIR/vocab.txt

zh_test.txt是我脸滚键盘随意输入的一些汉字,共有两段,每段两句话:

酒店附近开房的艰苦的飞机飞抵发窘惹风波,觉得覅奇偶均衡能否v不。
极度疯狂减肥的人能否打开v高科技就而后就覅哦冏结构i恶如桂萼黑人牙膏覅u我也【发票未开u俄日附件二我就佛i额外阶级感v,我为何军方的我i和服i好热哦iu均为辐9为u

ui和覅文化覅哦佛为进度覅u蛊蛾i巨乳古人规格i兼顾如果我是破看到v个ui就火热i今年的付款了几个vi哦素问。就觉发给金佛i为借口破碎的梦
i觉得覅u而非各位i风格较为哦个粉色哦i多发几个v二哥i文件哦i怪兽决斗盘可加热管覅u个人文集狗哥

vocab.txt是下载的bert中文预训练模型里的词典

最后的部分输出如下所示:

INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] i 觉 得 [UNK] u [MASK] 非 [MASK] 位 i 风 格 较 ##by 哦 个 驅 色 哦 i 多 发 [MASK] 个 v 二 哥 i 文 件 哦 i 怪 [MASK] 决 斗 盘 可 加 热 管 [MASK] u [MASK] [MASK] 文 集 狗 哥 [SEP] [MASK] [UNK] 奇 偶 均 衡 能 否 v 不 。 极 [MASK] 疯 狂 减 肥 的 人 能 否 打 开 v 高 科 技 就 而 [MASK] 就 [UNK] 哦 冏 结 构 i 恶 如 桂 萼 黑 人 牙 膏 [UNK] u 我 也 【 发 票 未 开 [MASK] 俄 日 [MASK] 件 二 我 就 佛 i 额 [MASK] 阶 [MASK] 感 v [MASK] 我 为 [MASK] 军 方 [SEP]
INFO:tensorflow:input_ids: 101 151 6230 2533 100 163 103 7478 103 855 151 7599 3419 6772 8684 1521 702 7705 5682 1521 151 1914 1355 103 702 164 753 1520 151 3152 816 1521 151 2597 103 1104 3159 4669 1377 1217 4178 5052 103 163 103 103 3152 7415 4318 1520 102 103 100 1936 981 1772 6130 5543 1415 164 679 511 3353 103 4556 4312 1121 5503 4638 782 5543 1415 2802 2458 164 7770 4906 2825 2218 5445 103 2218 100 1521 1087 5310 3354 151 2626 1963 3424 5861 7946 782 4280 5601 100 163 2769 738 523 1355 4873 3313 2458 103 915 3189 103 816 753 2769 2218 867 151 7583 103 7348 103 2697 164 103 2769 711 103 1092 3175 102
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:masked_lm_positions: 6 8 14 17 23 34 42 44 45 46 51 63 80 105 108 116 118 121 124 0
INFO:tensorflow:masked_lm_ids: 5445 1392 711 5106 1126 1077 100 702 782 3152 2533 2428 1400 163 7353 1912 5277 8024 862 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
INFO:tensorflow:next_sentence_labels: 1

可以看到token序列里的中文确实是以字的形式出现的

三、run_pretraining.py

终于到预训练的执行模块了,里面大部分都是tensorflow训练的常规代码,感觉没什么好分析的。

看过前面的内容和我前两章内容的朋友我想已经初步知道预训练的整个逻辑了,这里作一个简单的介绍:

1、X和Y的确定

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = features["next_sentence_labels"]
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

其中input_ids、input_mask 、segment_ids 作为X,剩下的masked_lm_positions、masked_lm_ids 、masked_lm_weights 、next_sentence_labels 共同作为Y

2、 loss

    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss + next_sentence_loss

可以看到loss 分别由masked_lm_loss和next_sentence_loss组成,masked_lm_loss针对的是语言模型对MASK起来的标签的预测,即上下文语境预测当前词;而next_sentence_loss是对于句子关系的预测。前者在迁移学习中可以用于标注类任务(分词、NER等),后者可以用于句子关系任务(QA、自然语言推理等)。

需要多说一句的是,masked_lm_loss,用到了模型的sequence_output和embedding_table,这是因为对多个MASK的标签进行预测是一个标注问题,所以需要获取最后一层的整个sequence,而embedding_table用来反embedding,这样就映射到token的学习了。而next_sentence_loss用到的是pooled_output,对应的是第一个token [CLS],它一般用于分类任务的学习。

总结:

本文介绍了以下几个内容:

1、tokenization模块:我把它叫做对原始文本段的解析,只有解析过后才能标准化输入;

2、create_pretraining_data模块:对原始数据进行转换,原始数据本是无标签的数据,通过句子的拼接可以产生句子关系的标签,通过MASK可以产生标注的标签,其本质是语言模型的应用;

3、run_pretraining模块:在执行预训练的时候针对以上两种标签分别利用bert模型的不同输出部件,计算loss,然后进行梯度下降优化。

本文系列
Bert系列(一)——demo运行
Bert系列(二)——模型主体源码解读
Bert系列(四)——源码解读之Fine-tune
Bert系列(五)——中文分词实践 F1 97.8%(附代码)

Reference
1.https://github.com/google-research/bert
2.BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding



作者:西溪雷神
链接:https://www.jianshu.com/p/22e462f01d8c
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
分享到:
评论

相关推荐

    Pre-trained-BERT-model-using-own-corpus

    https://towardsdatascience.com/how-to-build-a-wordpiece-tokenizer-for-bert-f505d97dddbb to train tokenizer https://towardsdatascience.com/how-to-train-a-bert-model-from-scratch-72cfce554fc6

    BERT: Pre-training of Deep Bidirectional Transformers翻译

    该模型的创新之处在于其预训练机制,它能够从未标记的文本中学习到深度的双向表示。传统的语言模型往往只关注单向的上下文信息,如仅依赖左侧或右侧的文本,而BERT则同时考虑左右两侧的上下文,这使得模型能够更全面...

    人工智能文本分类-采用Keras和Keras-bert实现文本多标签分类任务-对BERT进行微调(源码+文档说明)

    人工智能文本分类-采用Keras和Keras-bert实现文本多标签分类任务-对BERT进行微调(源码+文档说明),含有代码注释,满分大作业资源,新手也可看懂,期末大作业、课程设计、高分必看,下载下来,简单部署,就可以使用...

    基于Pytorch的BERT-IDCNN-BILSTM-CRF中文实体识别实现

    基于Pytorch的BERT-IDCNN-BILSTM-CRF中文实体识别实现 模型训练(可选) 下载pytorch_model.bin到data/bert 下载训练集和测试集到data/ 检查配置constants.py 执行train.py,命令为 python train.py 中文命名实体...

    Davlan/bert-base-multilingual-cased-ner-hrl NER命名实体识别模型

    《Davlan/bert-base-multilingual-cased-ner-hrl:多语言命名实体识别的深度学习模型》 在自然语言处理(NLP)领域,命名实体识别(NER)是一项重要的任务,它涉及到从文本中识别出具有特定意义的实体,如人名、...

    Bert-Chinese-Text-Classification-Pytorch-master.zip.zip

    标题 "Bert-Chinese-Text-Classification-Pytorch-master.zip.zip" 暗示这是一个包含BERT(Bidirectional Encoder Representations from Transformers)模型的中文文本分类项目,基于PyTorch实现。这个压缩包提供了...

    BERT-BiLSTM-CRF-master.zip

    # BERT-BiLSTM-CRF BERT-BiLSTM-CRF的Keras版实现 ## BERT配置 1. 首先需要下载Pre-trained的BERT模型,本文用的是Google开源的中文BERT模型: - ...

    基于BERT-BILSTM-CRF进行中文命名实体识别python+数据+模型(高分项目源码).rar

    基于BERT-BILSTM-CRF进行中文命名实体识别python+数据+模型(高分项目源码) --checkpoint:模型和配置保存位置 --model_hub:预训练模型 ----chinese-bert-wwm-ext: --------vocab.txt --------pytorch_model.bin ...

    Python自然语言处理-BERT实战

    通俗讲解BERT模型中所涉及的核心知识点(Transformer,self-attention等),基于google开源BERT项目从零开始讲解如何搭建自然语言处理通用框架,通过debug源码详细解读其中每一核心代码模块的功能与作用。最后基于...

    bert-base-chinese.zip

    在这个场景中,我们关注的是BERT的一个特定版本——"BERT-base-Chinese",这是一个针对中文语言的预训练模型。 1. **BERT的结构与原理**: BERT模型基于Transformer的编码器部分,由多个自注意力层和前馈神经网络...

    Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks下载

    ### Sentence-BERT: Sentence Embeddings Using Siamese BERT-Networks #### 概述 在自然语言处理(NLP)领域,预训练模型如BERT(Bidirectional Encoder Representations from Transformers)及其变种RoBERTa已...

    huggingface的bert-base-chinese

    【标题】"huggingface的bert-base-chinese" 指的是Hugging Face平台上由Google提供的预训练模型,它是BERT模型的一个中文版本。BERT(Bidirectional Encoder Representations from Transformers)是由Google在2018年...

    BERT-Relation-Extraction

    《BERT在关系抽取中的应用——基于PyTorch实现》 关系抽取是自然语言处理领域中的一个关键任务,它旨在从文本中识别出实体之间的语义关系。近年来,随着深度学习的发展,尤其是Transformer架构的提出,BERT...

    Kaleido-BERT Vision-Language Pre-Training on Fashion Domain.pdf

    "Kaleido-BERT: Vision-Language Pre-training on Fashion Domain" 是一项创新性的研究,旨在改进现有的VLP模型,以更好地理解和处理与时尚相关的图像和文本数据。 Kaleido-BERT的核心创新在于引入了一种名为...

    bert-base-chinese.rar

    本文将围绕标题“bert-base-chinese.rar”所代表的PyTorch实现的中文BERT模型进行详细介绍,旨在帮助读者深入理解其工作原理,并探讨如何应用于中文短文本分类、问答系统等NLP任务。 1. **BERT模型概述** BERT模型...

    bert模型的源代码-基于tensorflow框架

    **BERT模型概述** BERT(Bidirectional Encoder Representations from Transformers)是由Google AI Language团队在2018年提出的预训练语言模型。它通过Transformer架构实现了对输入文本的双向上下文理解,打破了...

    Bert-Chinese-Text-Classification-Pytorch:使用Bert,ERNIE,进行中文文本分类

    数据集包括一系列中文文本样本和对应的标签。可以将数据集整理为CSV或其他常见格式。确保数据集具有足够的样本且标签正确。 模型选择和准备:在中文文本分类任务中,Bert和ERNIE等预训练模型已经在自然语言处理领域...

    bert_model.ckpt.data-00000-of-00001

    bert_model.ckpt.data-00000-of-00001

    基于BERT的文本纠错项目python源码+使用说明+数据.zip

    --model_name_or_path=bert-base-chinese \ --do_train \ --train_data_file=$TRAIN_FILE \ --do_eval \ --eval_data_file=$TEST_FILE \ --mlm --num_train_epochs=3 ``` 或者使用 ``` python -m run_...

Global site tag (gtag.js) - Google Analytics