企业查在线查询官网/seo优化专员招聘
前提条件:
对 C++ 有一定的了解
必须已下载 TensorFlow 源代码,并且能够编译它
我们将支持文件格式的任务分为以下两部分:
文件格式:我们使用读取器 tf.data.Dataset 从文件中读取原始记录(通常由标量字符串张量表示,但也可以有更多结构)
记录格式:我们使用解码器或解析操作将字符串记录转换为可供 TensorFlow 使用的张量
例如,要重新实现 tf.contrib.data.make_csv_dataset 函数,我们可以使用 tf.data.TextLineDataset 提取记录,然后使用 tf.data.Dataset.map 和 tf.decode_csv 从数据集中的每行文本解析 CSV 记录。
为文件格式编写 Dataset
tf.data.Dataset 表示元素序列,其中每个元素可以是文件中的各条记录。下面是几个已经内置到 TensorFlow 中的 “读取器” 数据集示例:
tf.data.TFRecordDataset(源代码位于 kernels/data/reader_dataset_ops.cc)
tf.data.FixedLengthRecordDataset(源代码位于 kernels/data/reader_dataset_ops.cc)
tf.data.TextLineDataset(源代码位于 kernels/data/reader_dataset_ops.cc)
其中每一种实现都包含三个相关类:
tensorflow::DatasetOpKernel 子类(例如 TextLineDatasetOp),可告知 TensorFlow 如何在其 MakeDataset() 方法中根据某操作的输入和属性构造数据集对象
tensorflow::GraphDatasetBase 子类(例如 TextLineDatasetOp::Dataset),它表示数据集本身的不可变定义,并告知 TensorFlow 如何在其 MakeIteratorInternal() 方法中针对该数据集构造迭代器对象
tensorflow::DatasetIterator 子类(例如 TextLineDatasetOp::Dataset::Iterator),它表示迭代器针对特定数据集的可变状态,并告知 TensorFlow 如何在其 GetNextInternal() 方法中从迭代器获取下一个元素
最重要的方法是 GetNextInternal() 方法,因为它定义了如何从文件中实际读取记录并将这些记录表示为一个或多个 Tensor 对象。
要创建一个名为(例如)MyReaderDataset 的新读取器数据集,您需要:
在 C++ 中,定义实现读取逻辑的 tensorflow::DatasetOpKernel、tensorflow::GraphDatasetBase 和 tensorflow::DatasetIterator 的子类
在 C++ 中,注册名为 "MyReaderDataset" 的新读取器操作和内核
在 Python 中,定义 tf.data.Dataset 的子类(名为 MyReaderDataset)
您可以将所有 C++ 代码放在一个文件中,例如 my_reader_dataset_op.cc。如果您熟悉如何添加操作,将会有所帮助。您可以将以下框架当做着手点开始实现:
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace myproject {
namespace {
using ::tensorflow::DT_STRING;
using ::tensorflow::PartialTensorShape;
using ::tensorflow::Status;
class MyReaderDatasetOp : public tensorflow::DatasetOpKernel {
public:
MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {
// Parse and validate any attrs that define the dataset using
// `ctx->GetAttr()`, and store them in member variables.
}
void MakeDataset(tensorflow::OpKernelContext* ctx,
tensorflow::DatasetBase** output) override {
// Parse and validate any input tensors that define the dataset using
// `ctx->input()` or the utility function
// `ParseScalarArgument(ctx, &arg)`.
// Create the dataset object, passing any (already-validated) arguments from
// attrs or input tensors.
*output = new Dataset(ctx);
}
private:
class Dataset : public tensorflow::GraphDatasetBase {
public:
Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
std::unique_ptr<:iteratorbase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<:iteratorbase>(new Iterator(
{this, tensorflow::strings::StrCat(prefix, "::MyReader")}));
}
// Record structure: Each record is represented by a scalar string tensor.
//
// Dataset elements can have a fixed number of components of different
// types and shapes; replace the following two methods to customize this
// aspect of the dataset.
const tensorflow::DataTypeVector& output_dtypes() const override {
static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector& output_shapes() const override {
static std::vector* shapes =
new std::vector({ {}});
return *shapes;
}
string DebugString() const override { return "MyReaderDatasetOp::Dataset"; }
protected:
// Optional: Implementation of `GraphDef` serialization for this dataset.
//
// Implement this method if you want to be able to save and restore
// instances of this dataset (and any iterators over it).
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
// Construct nodes to represent any of the input tensors from this
// object's member variables using `b->AddScalar()` and `b->AddVector()`.
std::vector<:node> input_tensors;
TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
return Status::OK();
}
private:
class Iterator : public tensorflow::DatasetIterator {
public:
explicit Iterator(const Params& params)
: DatasetIterator(params), i_(0) {}
// Implementation of the reading logic.
//
// The example implementation in this file yields the string "MyReader!"
// ten times. In general there are three cases:
//
// 1. If an element is successfully read, store it as one or more tensors
// in `*out_tensors`, set `*end_of_sequence = false` and return
// `Status::OK()`.
// 2. If the end of input is reached, set `*end_of_sequence = true` and
// return `Status::OK()`.
// 3. If an error occurs, return an error status using one of the helper
// functions from "tensorflow/core/lib/core/errors.h".
Status GetNextInternal(tensorflow::IteratorContext* ctx,
std::vector<:tensor>* out_tensors,
bool* end_of_sequence) override {
// NOTE: `GetNextInternal()` may be called concurrently, so it is
// recommended that you protect the iterator state with a mutex.
tensorflow::mutex_lock l(mu_);
if (i_ < 10) {
// Create a scalar string tensor and add it to the output.
tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
record_tensor.scalar()() = "MyReader!";
out_tensors->emplace_back(std::move(record_tensor));
++i_;
*end_of_sequence = false;
} else {
*end_of_sequence = true;
}
return Status::OK();
}
protected:
// Optional: Implementation of iterator state serialization for this
// iterator.
//
// Implement these two methods if you want to be able to save and restore
// instances of this iterator.
Status SaveInternal(tensorflow::IteratorStateWriter* writer) override {
tensorflow::mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
return Status::OK();
}
Status RestoreInternal(tensorflow::IteratorContext* ctx,
tensorflow::IteratorStateReader* reader) override {
tensorflow::mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
return Status::OK();
}
private:
tensorflow::mutex mu_;
int64 i_ GUARDED_BY(mu_);
};
};
};
// Register the op definition for MyReaderDataset.
//
// Dataset ops always have a single output, of type `variant`, which represents
// the constructed `Dataset` object.
//
// Add any attrs and input tensors that define the dataset here.
REGISTER_OP("MyReaderDataset")
.Output("handle: variant")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
// Register the kernel implementation for MyReaderDataset.
REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU),
MyReaderDatasetOp);
} // namespace
} // namespace myproject
最后一步是编译 C++ 代码并添加 Python 封装容器。要执行此操作,最简单的方法是编译动态库(例如,名为 "my_reader_dataset_op.so"),并添加一个作为 tf.data.Dataset 子类的 Python 类来封装它。下面提供了一个示例 Python 程序:
import tensorflow as tf
# Assumes the file is in the current working directory.
my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
class MyReaderDataset(tf.data.Dataset):
def __init__(self):
super(MyReaderDataset, self).__init__()
# Create any input attrs or tensors as members of this class.
def _as_variant_tensor(self):
# Actually construct the graph node for the dataset op.
#
# This method will be invoked when you create an iterator on this dataset
# or a dataset derived from it.
return my_reader_dataset_module.my_reader_dataset()
# The following properties define the structure of each element: a scalar
# tf.string
tensor. Change these properties to match the `output_dtypes()`
# and `output_shapes()` methods of `MyReaderDataset::Dataset` if you modify
# the structure of each element.
@property
def output_types(self):
return tf.string
@property
def output_shapes(self):
return tf.TensorShape([])
@property
def output_classes(self):
return tf.Tensor
if __name__ == "__main__":
# Create a MyReaderDataset and print its elements.
with tf.Session() as sess:
iterator = MyReaderDataset().make_one_shot_iterator()
next_element = iterator.get_next()
try:
while True:
print(sess.run(next_element)) # Prints "MyReader!" ten times.
except tf.errors.OutOfRangeError:
pass
您可以在 tensorflow/python/data/ops/dataset_ops.py 中查看一些示例 Dataset 封装容器类(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py)。
为记录格式编写操作
通常,这是一项将标量字符串记录作为输入的普通操作,因此请按照相关说明添加操作。您可以选择将标量字符串键作为输入,并将其包含在报告数据格式不正确的错误消息中。这样,用户就可以更轻松地跟踪错误数据的来源。
用于解码记录的操作示例:
tf.parse_single_example(和 tf.parse_example)
tf.decode_csv
tf.decode_raw
请注意,使用多个操作解码特定记录格式会很有用。例如,您可以将图像作为字符串另存到 tf.train.Example 协议缓冲区中。根据该图像的格式,您可以采用 tf.parse_single_example 操作的相应输出,并调用 tf.image.decode_jpeg、tf.image.decode_png 或 tf.decode_raw。通常采用 tf.decode_raw 的输出并使用 tf.slice 和 tf.reshape 提取各部分。
更多 AI 相关阅读:
TensorFlow Lite 对象检测
标贝科技:TensorFlow 框架提升语音合成效果
使用 TensorFlow Model Analysis 提升模型质量