diff --git a/configure.py b/configure.py index 8572fa7fdb249d1ae6f39202b871d17cd44755f8..6279c4261044e3f33519ece5f2ac19af2acb505d 100644 --- a/configure.py +++ b/configure.py @@ -25,10 +25,12 @@ import re import subprocess import sys +# pylint: disable=g-import-not-at-top try: from shutil import which except ImportError: from distutils.spawn import find_executable as which +# pylint: enable=g-import-not-at-top _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') @@ -485,7 +487,10 @@ def set_cc_opt_flags(environ_cp): cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', question, default_cc_opt_flags) for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt)) + host_opt = '-march=native' # It should be safe on the same build host. + write_to_bazelrc( + 'build:opt --cxxopt=%s --copt=%s' % (opt, opt) + + ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt)) def set_tf_cuda_clang(environ_cp): diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 893175373f670b38b03b3492ddffb16102708160..6ef4860f35835e59be3452b57204d42c82d0816b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -130,7 +130,9 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, stack.push_back(src); } Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output(); + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); Node* dst_copy = (*node_map)[e->dst()->id()]; output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2c5d910d58d8e486deb1f644cbda11261b93a84e..e420f21ca33fe7de9b33f404ce04eae62d9c041e 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -77,18 +77,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( out_shape.dim_sizes()); } - // Degenerate case: single slice. - if (num_indices == 1) { - auto index = builder->Reshape(indices, {1}); - auto start_index = builder->Pad( - index, XlaHelpers::Zero(builder, index_type), - xla::MakeEdgePaddingConfig( - {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice = - builder->DynamicSlice(input, start_index, slice_shape.dim_sizes()); - return builder->Reshape(slice, out_shape.dim_sizes()); - } - // Specify the shape of the loop-carried Tensor tuple. xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 24774c4c2a385d9aabd22a550bd8be3acf409d85..763d94e94c2167f47b3f0777a31815f02791aa9e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1309,7 +1309,7 @@ Status ComputationBuilder::SetReturnValue( } StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand) { + const ComputationDataHandle& operand, int64 num_parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1317,6 +1317,7 @@ StatusOr ComputationBuilder::IsConstant( IsConstantRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; + request.set_num_parameters(num_parameters); IsConstantResponse response; VLOG(2) << "making IsConstant request"; @@ -1330,7 +1331,8 @@ StatusOr ComputationBuilder::IsConstant( } StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout) { + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1341,6 +1343,9 @@ StatusOr> ComputationBuilder::ComputeConstant( if (output_layout != nullptr) { *request.mutable_output_layout() = *output_layout; } + for (const auto& param : parameters) { + *request.add_parameters() = param.ToProto(); + } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index bc7ad06a3fec4aecbe11312982542c9d615b4911..8e1b4be1f3ebf8e3f530b053447f86f7a2f56fa7 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -746,11 +746,12 @@ class ComputationBuilder { ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests - // whether a computation is a compile-time constant without evaluating the - // computation. - StatusOr IsConstant(const ComputationDataHandle& operand); + // constant does not depend on parameters with higher index then + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters = 0); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -795,7 +796,7 @@ class ComputationBuilder { float epsilon, int64 feature_index); // Computes the value of a constant indicated by a - // ComputationDataHandle. + // ComputationDataHandle using a non-optimized interpreter on the host. // // The operand must be from the computation currently being built - // i.e., returned from this builder with no intervening call to @@ -803,8 +804,11 @@ class ComputationBuilder { // that may stop working at any time. // // The operand must represent a constant value, which in this case - // means that it must not statically depend on a parameter to the - // computation that is being built. + // means that it must not statically depend on any parameter of the + // computation that is being built other then the ones specified on the + // paramtere list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. // // `IsConstant` can be used to test whether a computation is a compile-time // constant without evaluation it. `ComputeConstant` only succeeds for @@ -822,7 +826,8 @@ class ComputationBuilder { // will be stored using that layout. StatusOr> ComputeConstant( const ComputationDataHandle& operand, - const Layout* output_layout = nullptr); + const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); // Returns a new ComputationBuilder whose resultant Computation is used only // by this ComputationBuilder. The sub-ComputationBuilder has the same diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8536429846f87fd5c4b073cc4b13b3f1c5eb2e5c..b422b22df9cfbefb6611fcb229ed42e67fe3a0d8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -101,6 +101,11 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } + std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < assign2.logical_buffer_id(); + }); return proto; } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index b49047283119fb2f10b9f68eaa37a7bdc27f63a6..81c29e4726c7be53b433be896f558f502e43c885 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -52,7 +52,7 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::IRBuilder<> ir_builder(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setUnsafeAlgebra(); + fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2d32e59d36c4e3026e0e151561db3076146fabe4..7e0d182b365c35788195e70dc35c3923ed8991bb 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -88,6 +88,16 @@ class Executable { tensorflow::gtl::ArraySlice> arguments); + // Populates `hlo_execution_profile` from `executor`. This is implicit in any + // Execute* API call that takes a hlo_execution_profile argument, but must be + // called explicitly for other (async, for example) variants after the stream + // has completed. + virtual Status PopulateExecutionProfile( + HloExecutionProfile* hlo_execution_profile, + perftools::gputools::StreamExecutor* executor) { + return Status::OK(); + } + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. The ExecuteOnStream overloads have diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e09899e48d389716226661c5ae61defb066362aa..5107ac782d7c93dfa17969338bf97c9fd9bb1516 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1901,12 +1901,13 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_successors_.empty()) { - extra.push_back(StrCat( - "control-successors=", - Join(control_successors_, ", ", [](string* out, HloInstruction* succ) { - StrAppend(out, succ->name()); - }))); + if (!control_predecessors_.empty()) { + extra.push_back(StrCat("control-predecessors={", + Join(control_predecessors_, ", ", + [](string* out, HloInstruction* pre) { + StrAppend(out, pre->name()); + }), + "}")); } return extra; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index aaa4e3a2e3a7523a6779f86ba0c20cb48c23f600..f463e57d995c0f0549872a1a0bf20a3ead626dc8 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -41,11 +41,21 @@ namespace se = ::perftools::gputools; namespace xla { /*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const char* filename, +HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { HloProto proto; - TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), - filename, &proto)); + + const Status s = + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); + + if (!s.ok()) { + const Status s2 = + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); + if (!s2.ok()) { + return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); + } + } + TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -56,7 +66,7 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename, } /*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextDumpFile(const char* filename, +HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, const DebugOptions& debug_options) { string hlo_string; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), @@ -66,6 +76,19 @@ HloRunner::ReadModuleFromHloTextDumpFile(const char* filename, return tools::Parse(hlo_string, config); } +/*static*/ StatusOr> HloRunner::ReadModule( + const std::string& filename, const DebugOptions& debug_options) { + auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); + if (module.ok()) { + return module; + } + const std::string e = module.status().error_message(); + module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); + return module.ok() ? std::move(module) + : Status(module.status().code(), + e + "\n" + module.status().error_message()); +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct HloRunner::EigenThreadPoolWrapper { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index b0e2b980e22d3348f63a528cd49f0b814540d549..a5732848c6b4191faf8d7b07c749132ca8b14413 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -44,15 +44,23 @@ class HloRunner { ~HloRunner(); - // Reads the binary proto file in xla.HloProto format, creates and returns the - // HloModule. + // Reads the proto file in xla.HloProto format, creates and returns the + // HloModule. Will try to parse the filename as binary proto, then try as + // text proto if that fails. static StatusOr> ReadModuleFromHloProtoFile( - const char* filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. static StatusOr> ReadModuleFromHloTextDumpFile( - const char* filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options); + + // Tries to parse the filename specified first as binary proto format, then + // as a textual proto format, then textual IR, then gives up if both fail. + // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used + // explicitly when you know the format, this if you don't. + static StatusOr> ReadModule( + const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 5dff4b5778970dd473c5f158b3828a850847d1ff..956c0d5f05288e32c626f247ce8356c60d17808d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -555,8 +555,9 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { llvm::FastMathFlags flags; if (fast_math_enabled) { - // UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal. - flags.setUnsafeAlgebra(); + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, + // AllowReciprocal, AllowContract, and ApproxFunc. + flags.setFast(); } return flags; } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index bac33d8102e07766531a4ce6eac77aff4971bfef..71afbee456b0f5eb67cb092d84f8e95ea1038c54 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -490,14 +490,20 @@ Service::ExecuteParallelAndRegisterResult( std::vector> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags) { + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; + // Device ID to stream executor, populated only with devices that are being + // profiled. + std::map index_to_profiled_streams; + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, backend->computation_placer()->AssignDevices( options_.number_of_replicas(), executables.size())); @@ -510,6 +516,21 @@ Service::ExecuteParallelAndRegisterResult( backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); + if (replica == 0 && profile != nullptr) { + timers.emplace_back( + new perftools::gputools::Timer(streams.back()->parent())); + streams.back() + ->InitTimer(timers.back().get()) + .ThenStartTimer(timers.back().get()); + CHECK(timers.front() != nullptr); + } + + if (replica == 0 && + executables[i]->module_config().debug_options().xla_hlo_profile() && + executables[i]->hlo_profiling_enabled()) { + index_to_profiled_streams[i] = streams.back().get(); + } + // Set up run options. ExecutableRunOptions options; options.set_stream(streams.back().get()); @@ -526,6 +547,10 @@ Service::ExecuteParallelAndRegisterResult( perftools::gputools::DeviceMemoryBase result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + if (replica == 0 && profile != nullptr) { + streams.back()->ThenStopTimer(timers.back().get()); + } + // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { @@ -543,6 +568,69 @@ Service::ExecuteParallelAndRegisterResult( } } + // For every stream that had profiling enabled, obtain and debug-dump the HLO + // profile. + for (auto& index_to_profiled_stream : index_to_profiled_streams) { + int64 device = index_to_profiled_stream.first; + se::Stream* stream = index_to_profiled_stream.second; + HloExecutionProfile hlo_profile; + TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( + &hlo_profile, stream->parent())); + + std::unordered_set profiled_computations = + hlo_profile.profiled_computations(); + // To ensure we have print the profiles in a stable order, iterate over the + // computations in post order. + auto& module = executables[device]->module(); + std::list all_computations = + module.MakeComputationPostOrder(); + for (xla::HloComputation* computation : all_computations) { + if (profiled_computations.count(computation) > 0) { + string profile_string = hlo_profile.ToString( + *computation, streams[0]->parent()->GetDeviceDescription(), + executables[device]->CreateCostAnalysis().get()); + if (!profile_string.empty()) { + LOG(INFO) << "HLO profile for execution on device " << device + << ":\n"; + XLA_LOG_LINES(tensorflow::INFO, profile_string); + } + } + } + hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", + &hlo_profile); + } + + if (profile != nullptr) { + CHECK(!timers.empty()); + std::vector timer_nanoseconds; + timer_nanoseconds.reserve(timers.size()); + for (auto& timer : timers) { + timer_nanoseconds.push_back(timer->Nanoseconds()); + } + uint64 nanoseconds = + *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end()); + + // Merge in run-time profile information from execution_profile on the + // zeroth device. + profile->MergeFrom(executables[0]->execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(nanoseconds); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + return result_handles; } @@ -715,14 +803,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Execute the generated executables in parallel and return the device // handles for each computation's output. + ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, execute_backend_.get(), device_handles, - computation_names)); + computation_names, &profile)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; + *response.mutable_profile() = profile; *result->add_responses() = response; } @@ -1082,8 +1172,9 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); return tensorflow::Status::OK(); @@ -1101,8 +1192,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { return InvalidArgument("Operand to ComputeConstant depends on parameter."); } @@ -1141,8 +1233,18 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); + std::vector parameters(arg->parameters_size()); + for (int64 i = 0; i < arg->parameters_size(); ++i) { + parameters[i] = Literal(arg->parameters(i)); + } + std::vector parameter_ptrs; + std::transform(parameters.begin(), parameters.end(), + std::back_inserter(parameter_ptrs), + [](const Literal& literal) { return &literal; }); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate(*module, parameter_ptrs)); // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 2452259f736054b5bf1f03fc5103d65eded7f398..6646be2e9aa43763b93bcea7a1df9d10580f162c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -327,7 +327,8 @@ class Service : public ServiceInterface { arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags); + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile); // Convenience function for adding a function to a user computation. template diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 006c814996df9b209e6cd4d75bc04689c4e297c5..e9d182509b5356d32b667b7921e2843d30faeb9b 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1482,14 +1482,15 @@ UserComputation::ComputeProgramShape( namespace { -// A visitor which checks whether an operation is a compile-time constant. That -// is, the operation does not depend on any parameter instructions. The visitor -// walks the computation starting at a given operation and sets is_constant to -// false iff a parameter or RNG operation is encountered. -void ConstantVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - std::set* visited, bool* is_constant) { - if (visited->count(handle.handle()) != 0 || !*is_constant) { +// A visitor which checks whether an operation is pure functional meaning that +// it doesn't depend on any parameter with an index higher then num_parameters. +// The visitor walks the computation starting at a given operation and sets +// is_functional to false iff a parameter or RNG operation is encountered. +void PureFunctionalVisitor(const SessionComputation& session_computation, + const ComputationDataHandle& handle, + int64 num_parameters, std::set* visited, + bool* is_functional) { + if (visited->count(handle.handle()) != 0 || !*is_functional) { return; } @@ -1497,7 +1498,7 @@ void ConstantVisitor(const SessionComputation& session_computation, session_computation.requests().at(handle.handle()); switch (request.request().op_case()) { case OpRequest::kRngRequest: - *is_constant = false; + *is_functional = false; break; case OpRequest::kConstantRequest: @@ -1506,41 +1507,43 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kGetTupleElementRequest: { const GetTupleElementRequest& get_tuple_element_request = request.request().get_tuple_element_request(); - ConstantVisitor(session_computation, get_tuple_element_request.operand(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + get_tuple_element_request.operand(), num_parameters, + visited, is_functional); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); - ConstantVisitor(session_computation, slice_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, slice_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicSliceRequest: { const DynamicSliceRequest& dynamic_slice_request = request.request().dynamic_slice_request(); - ConstantVisitor(session_computation, dynamic_slice_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, - dynamic_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicUpdateSliceRequest: { const DynamicUpdateSliceRequest& dynamic_update_slice_request = request.request().dynamic_update_slice_request(); - ConstantVisitor(session_computation, - dynamic_update_slice_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.update(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.update(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } @@ -1549,7 +1552,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().concatenate_request(); for (const ComputationDataHandle& handle : concatenate_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1557,61 +1561,63 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kConvolveRequest: { const ConvolveRequest& convolve_request = request.request().convolve_request(); - ConstantVisitor(session_computation, convolve_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, convolve_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convolve_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, convolve_request.rhs(), + num_parameters, visited, is_functional); break; } case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kInfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kOutfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_constant=false in other similar + // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCustomCallRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kSendRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kRecvRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kMapRequest: { const MapRequest& map_request = request.request().map_request(); for (const ComputationDataHandle& handle : map_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself. break; @@ -1619,10 +1625,10 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceRequest: { const ReduceRequest& reduce_request = request.request().reduce_request(); - ConstantVisitor(session_computation, reduce_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, reduce_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reduce_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, reduce_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1630,10 +1636,12 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceWindowRequest: { const ReduceWindowRequest& reduce_window_request = request.request().reduce_window_request(); - ConstantVisitor(session_computation, reduce_window_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, reduce_window_request.init_value(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + reduce_window_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + reduce_window_request.init_value(), num_parameters, + visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1641,13 +1649,15 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kSelectAndScatterRequest: { const SelectAndScatterRequest& select_and_scatter_request = request.request().select_and_scatter_request(); - ConstantVisitor(session_computation, select_and_scatter_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, select_and_scatter_request.source(), - visited, is_constant); - ConstantVisitor(session_computation, - select_and_scatter_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.source(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the select and scatter // computations themselves. break; @@ -1656,76 +1666,80 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); - ConstantVisitor(session_computation, broadcast_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, broadcast_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReshapeRequest: { const ReshapeRequest& reshape_request = request.request().reshape_request(); - ConstantVisitor(session_computation, reshape_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reshape_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReverseRequest: { const ReverseRequest& reverse_request = request.request().reverse_request(); - ConstantVisitor(session_computation, reverse_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reverse_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kPadRequest: { const PadRequest& pad_request = request.request().pad_request(); - ConstantVisitor(session_computation, pad_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, pad_request.padding_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, pad_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, pad_request.padding_value(), + num_parameters, visited, is_functional); break; } case OpRequest::kParameterRequest: { - *is_constant = false; + const ParameterRequest& parameter_request = + request.request().parameter_request(); + if (parameter_request.parameter() >= num_parameters) { + *is_functional = false; + } break; } case OpRequest::kConvertRequest: { const ConvertRequest& convert_request = request.request().convert_request(); - ConstantVisitor(session_computation, convert_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); - ConstantVisitor(session_computation, while_request.init(), visited, - is_constant); + PureFunctionalVisitor(session_computation, while_request.init(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the condition and body // computations themselves. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); - ConstantVisitor(session_computation, ternary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.rhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.ehs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), + num_parameters, visited, is_functional); break; } case OpRequest::kTransposeRequest: { const TransposeRequest& transpose_request = request.request().transpose_request(); - ConstantVisitor(session_computation, transpose_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, transpose_request.operand(), + num_parameters, visited, is_functional); break; } @@ -1734,7 +1748,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().variadic_op_request(); for (const ComputationDataHandle& handle : variadic_op_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1742,67 +1757,74 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); - ConstantVisitor(session_computation, unary_op_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, unary_op_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); - ConstantVisitor(session_computation, - batch_norm_training_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.offset(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.offset(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormInferenceRequest: { const BatchNormInferenceRequest& batch_norm_inference_request = request.request().batch_norm_inference_request(); - ConstantVisitor(session_computation, - batch_norm_inference_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.offset(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.variance(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.scale(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.offset(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.mean(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.variance(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); - ConstantVisitor(session_computation, batch_norm_grad_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.variance(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_grad_request.grad_output(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.variance(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.grad_output(), + num_parameters, visited, is_functional); break; } case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); - ConstantVisitor(session_computation, binary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, binary_op_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, binary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, binary_op_request.rhs(), + num_parameters, visited, is_functional); break; } @@ -1817,8 +1839,8 @@ void ConstantVisitor(const SessionComputation& session_computation, } // namespace -StatusOr UserComputation::IsConstant( - const ComputationDataHandle& handle) { +StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, + int64 num_parameters) { tensorflow::mutex_lock lock(mutex_); // Verify that the handle is valid. @@ -1829,7 +1851,8 @@ StatusOr UserComputation::IsConstant( bool is_constant = true; std::set visited; - ConstantVisitor(session_computation_, handle, &visited, &is_constant); + PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, + &is_constant); return is_constant; } diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index dabf68e298ed2600d5248b7b8c7b1e014efedb14..ac879ce55a75f6241a39f935b79017be46c1816b 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -250,9 +250,11 @@ class UserComputation { StatusOr> ComputeProgramShape( VersionedComputationHandle::Version version) const; - // Returns true if the given data handle does not depend on any - // parameters. That is, the value can be computed at compile time. - StatusOr IsConstant(const ComputationDataHandle& handle); + // Returns true if the given data handle does not depend on any parameter with + // index higher then num_parameters. That is, the value can be computed at + // compile time if we know the first num_parameters arguments. + StatusOr IsConstant(const ComputationDataHandle& handle, + int64 num_parameters); // Returns the output shape of the operation indicated by the given handle. StatusOr GetShape(const ComputationDataHandle& handle); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index b2e9743af79d0e4658451e7a9522c338036851ba..d423c78476dde18d209b5efac9e8f77da41bfeb4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -71,24 +71,27 @@ class ComputeConstantTest : public ::testing::Test { StatusOr> ComputeConstantLiteral( Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr) { - TF_ASSIGN_OR_RETURN(auto computed, - builder->ComputeConstant(operand, output_layout)); + ComputationBuilder* builder, Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( + operand, output_layout, parameters)); return std::move(computed); } template - StatusOr ComputeConstantScalar(Client* client, - const ComputationDataHandle& operand, - ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, - ComputeConstantLiteral(client, operand, builder)); + StatusOr ComputeConstantScalar( + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN( + auto literal, + ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder) { - StatusOr result = builder->IsConstant(operand); + ComputationBuilder* builder, int64 num_parameters = 0) { + StatusOr result = builder->IsConstant(operand, num_parameters); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } @@ -138,7 +141,25 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, DirectParam) { +TEST_F(ComputeConstantTest, Param) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); + auto computation = b.Add(param, b.ConstantR0(1.5f)); + + std::vector arguments; + arguments.emplace_back(*Literal::CreateR0(42.5f)); + EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); + + auto value = + ComputeConstantScalar(client, computation, &b, arguments); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } +} + +TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); @@ -152,7 +173,7 @@ TEST_F(ComputeConstantTest, DirectParam) { } } -TEST_F(ComputeConstantTest, IndirectParam) { +TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 71a1b0abee51ba2819daed23208b0da8d5107207..3b29a2eb9e04cc8f5bd55be00bfc6e6ad0b985c2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,6 +357,111 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. +TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(N); + auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. +TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto xla_while = builder.While(condition, body, init); + + auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), + builder.GetTupleElement(xla_while, 2)); + auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + std::vector expected = {6.f, 6.f, 6.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Tests a while node when the result type T is a Tuple. // // tuple> result(0, vector(10, 0.0f)); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 5de73ee866388e6b5b1a330b282792d1da4100c2..6c2e37e3b5cdd73157279fb171d3332aa9854184 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -58,6 +58,7 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseSharding(HloInstruction* instruction); + bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseNonTupleLiteral(std::unique_ptr* literal, @@ -436,10 +437,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - // Parse "sharding=". - if (lexer_.GetKind() == TokKind::kComma) { - if (!ParseSharding(instruction)) { - return false; + + bool has_sharding = false; + bool has_control = false; + while (EatIfPresent(TokKind::kComma)) { + string attribute_name; + if (!ParseAttributeName(&attribute_name)) { + return TokenError("expects ', sharding=' or ', control-predecessors='"); + } + + if (attribute_name == "sharding") { + // Parse "sharding=". + if (has_sharding) { + return TokenError("expects at most 1 'sharding='"); + } + has_sharding = true; + if (!ParseSharding(instruction)) { + return false; + } + } else if (attribute_name == "control-predecessors") { + // Parse "control-predecessors" + if (has_control) { + return TokenError("expects at most 1 'control-predecessors='"); + } + has_control = true; + if (!ParseControlPredecessors(instruction)) { + return false; + } + } else { + return TokenError(StrCat("unexpected attribute: ", attribute_name)); } } @@ -449,15 +475,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' // dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list bool HloParser::ParseSharding(HloInstruction* instruction) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { - return false; - } - string attribute_name; - if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") { - return TokenError("expects attribute name: sharding"); - } - if (!ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; @@ -577,6 +594,34 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { return true; } +// '{' name+ '}' +bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of control predecessors")) { + return false; + } + do { + string name; + if (!ParseName(&name)) { + return TokenError("expects a control predecessor"); + } + HloInstruction* pre = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!pre) { + return TokenError( + StrCat("control predecessor ", name, " is not defined: ")); + } + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); + } + } while (EatIfPresent(TokKind::kComma)); + + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of control predecessors"); +} + bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index e065af7da61fb6096c820825cbd81d4a788c93fe..359256f0646367f8af13439b30067624defcd44c 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -214,7 +214,7 @@ R"(HloModule TwoSendRecvBothWayRecvFist_module: ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0} + %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} } )" diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index ce3c3eee68ad7f7ebb42836e3cae14803f8650d7..710bb6ff25bf649693165c5e9fb6bc50e81db4ca 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -361,6 +361,7 @@ message WaitForExecutionResponse { message IsConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; + int64 num_parameters = 3; } message IsConstantResponse { @@ -371,6 +372,7 @@ message ComputeConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; Layout output_layout = 3; + repeated LiteralProto parameters = 4; } message ComputeConstantResponse { diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index a0606427a526ffc67e10d12a084eabc64564e4ab..6ed177e001758ad8c566c7965e1ec10ae5235fc8 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -399,7 +399,7 @@ ASBSQueue::~ASBSQueue() { template Status ASBSQueue::Schedule(std::unique_ptr* task) { - bool added_new_batch = false; + ASBSBatch* new_batch = nullptr; size_t size = (*task)->size(); if (size > options_.max_batch_size) { return errors::InvalidArgument("Task size ", size, @@ -418,15 +418,14 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { current_batch_ = nullptr; } if (!current_batch_) { - added_new_batch = true; num_enqueued_batches_++; - current_batch_ = + current_batch_ = new_batch = new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); } current_batch_->AddTask(std::move(*task)); num_enqueued_tasks_++; } - if (added_new_batch) scheduler_->AddBatch(current_batch_); + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 4d9fd753233f520b961559e2dc3adbafc5bb1d2f..cebe3474ca9251971c23bde9e82564189c1ee624 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -208,7 +208,7 @@ def extract_features(features, feature_columns): if tensor.dtype == dtypes.float32: if len(tensor.shape) > 1 and tensor.shape[1] > 1: unstacked = array_ops.unstack(tensor, axis=1) - for i in xrange(len(unstacked)): + for i in range(len(unstacked)): dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i)) dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1])) else: diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 68234911a3fda9df6c65f32b088d0968a6f37c00..4b60460cb22b3937065b9cb7f71061019d9f0a4e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -224,6 +224,7 @@ add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/keras") add_python_module("tensorflow/python/keras/activations") add_python_module("tensorflow/python/keras/applications") +add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") add_python_module("tensorflow/python/keras/applications/inception_v3") add_python_module("tensorflow/python/keras/applications/mobilenet") add_python_module("tensorflow/python/keras/applications/resnet50") diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 6c46acf20442c2cc435829afa57e8383b493d6af..824ac4298f88a0372743324793f6de453dae71c8 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@make_saveable_from_iterator @@read_batch_features @@unbatch +@@parallel_interleave @@rejection_resample @@sloppy_interleave @@ -50,6 +51,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index 3dd920415df8b6d367d690132b647318ab8ba25b..bfb7d5a9002787f6544d383de58150661ac2bde3 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -191,9 +191,9 @@ def main(_): train_dir = None test_dir = None summary_writer = tf.contrib.summary.create_summary_file_writer( - train_dir, flush_secs=10) + train_dir, flush_millis=10000) test_summary_writer = tf.contrib.summary.create_summary_file_writer( - test_dir, flush_secs=10, name='test') + test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 318962c634e0d050b35da5efc405400380c1b759..609cbd28772c3ae8da70648ca5b1b264a8a255e2 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -248,9 +248,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) train_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "train"), flush_secs=10) + os.path.join(log_dir, "train"), flush_millis=10000) test_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "eval"), flush_secs=10, name="eval") + os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): for epoch in range(FLAGS.num_epochs): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index a0f83ac10555913b5be177f0f2b00b2b0e30494a..6eb2cfdaca7840c4a5dd8cffc9620aaf3f96a1de 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -7,6 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") filegroup( name = "all_files", @@ -30,6 +31,7 @@ py_library( ":head", ":logit_fns", ":multi_head", + ":replicate_model_fn", "//tensorflow/python:util", ], ) @@ -227,9 +229,69 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], ) + +py_library( + name = "replicate_model_fn", + srcs = [ + "python/estimator/replicate_model_fn.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:util", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "replicate_model_fn_test", + size = "small", + srcs = ["python/estimator/replicate_model_fn_test.py"], + additional_deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ":replicate_model_fn", + ], + tags = ["requires-gpu-sm35"], +) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..7005a647db599dfa386f34406911febe1d9d5651 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -0,0 +1,470 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to replicate model_fn's over local GPUs. + +This file contains util that allow to replicate `Estimator.model_fn` over +GPUs. Replicated version of a `model_fn` is returned that can subsequently +be used with `Estimator`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import device_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util +from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients as gradients_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import training_util + + +def replicate_model_fn(model_fn, optimizer_fn, devices=None): + """Replicate `Estimator.model_fn` over GPUs within a single host. + + The given `model_fn` specifies a single forward pass of a model. To replicate + such a model over GPUs, each GPU gets its own instance of the forward pass + (a.k.a. a tower). The input features and labels get sharded into the chunks + that correspond to the number of GPUs. Each tower computes its own loss based + on its input. For each such loss, gradients are computed. After that, the + available losses are summed to form aggregated loss. The available + gradients are summed too. Then, they update weights using the specified + optimizer. + + If `devices` are `None`, then all available GPUs are going to be used for + replication. If no GPUs are available, then the model is going to be + placed on the CPU. + + Two modes of local replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + + Here is an example of how one might use their `model_fn` to run over GPUs: + ```python + def optimizer_fn(): + return tf.train.GradientDescentOptimizer(learning_rate=0.001) + ... + def model_fn(...): # See `model_fn` in `Estimator`. + loss = ... + if mode == tf.estimator.ModeKeys.TRAIN: + # See the section below on `EstimatorSpec.train_op`. + return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop()) + + # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. + return EstimatorSpec(...) + ... + classifier = tf.estimator.Estimator( + model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + ``` + + On `EstimatorSpec.train_op`: + `model_fn` returns `EstimatorSpec.train_op` for + `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. + `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there + is no need to use an optimizer inside the user's `model_fn`. The + `EstimatorSpec.loss` subgraph is going to be executed, while + `EstimatorSpec.train_op` isn't going to be executed. One could pass + `train_op=tf.noop()` to `EstimatorSpec`. + + On sharding input features and labels: + Input features and labels are split for consumption by each tower. They are + split across the dimension 0. Features and labels need to be batch major. + + On reduction algorithms: + Certain algorithms were chosen for aggregating results of computations on + multiple towers: + - Losses from all towers are reduced using sum. + - Gradients are reduced using sum for each trainable variable. + - `eval_metrics_ops` are reduced per metric using `reduce_mean`. + - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are + reduced using concatenation. + - For all other fields of `EstimatorSpec` the values of the first tower + are taken. + + On replication of variables: + Variables are not duplicated between towers. Instead, they are placed on a + single device as defined above and shared across towers. + + Other current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. That is required for + `tf.contrib.estimator.add_metrics`. + + Args: + model_fn: `model_fn` as defined in `Estimator`. See the section above about + the train_op argument of `EstimatorSpec`. + optimizer_fn: a function that returns an optimizer instance. The function + may accept one `params` argument. This is the `params` argument as + defined by `Estimator`. See the `Estimator` documentation for details. + devices: Optional list of devices to replicate the model across. This + argument can be used to replice only on the subset of available GPUs. + If `None`, then all available GPUs are going to be used for replication. + If no GPUs are available, then the model is going to be placed on the CPU. + + Returns: + A replicated version of the supplied `model_fn`. Returned function that + conforms to the requirements of `Estimator`'s `model_fn` and can be used + instead of the supplied `model_fn`. + """ + if not devices: + devices = _get_local_devices('GPU') or _get_local_devices('CPU') + + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + + tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' + 'server device is going to be {}.'.format( + devices, local_ps_device)) + + def replicated_model_fn(mode, features, labels, params=None, config=None): + """Replicated version of `model_fn` to be used instead.""" + feature_shards, label_shards = _split_batch( + features, labels, len(devices), device=local_ps_device) + tower_specs = _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=feature_shards, + labels=label_shards, + params=params, + config=config, + devices=devices, + local_ps_device=local_ps_device) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = _minimize_towers(tower_specs, + _call_optimizer_fn(optimizer_fn, params)) + return _train_spec( + tower_specs, train_op, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.EVAL: + return _eval_spec(tower_specs, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return _predict_spec(tower_specs, aggregation_device=local_ps_device) + + return replicated_model_fn + + +def _get_local_devices(device_type): + local_device_protos = device_lib.list_local_devices() + return [ + device.name + for device in local_device_protos + if device.device_type == device_type + ] + + +def _split_batch(features, labels, number_of_shards, device): + """Split input features and labes into batches.""" + + def split_dictionary(dictionary): + shards = [{} for _ in range(number_of_shards)] + for name, tensor in six.iteritems(dictionary): + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard + return shards + + with ops_lib.name_scope('split_inputs'): + with ops_lib.device(device): + if isinstance(features, dict): + feature_shards = split_dictionary(features) + else: + feature_shards = array_ops.split(features, number_of_shards) + + if labels is None: + label_shards = None + elif isinstance(labels, dict): + label_shards = split_dictionary(labels) + else: + label_shards = array_ops.split(labels, number_of_shards) + return feature_shards, label_shards + + +_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}' + + +def _get_loss_towers(model_fn, + mode, + features, + labels, + params, + config, + devices, + local_ps_device, + name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): + """Replicate the loss computation across devices.""" + tower_specs = [] + + model_fn_args = util.fn_args(model_fn) + optional_params = {} + if 'params' in model_fn_args: + optional_params['params'] = copy.deepcopy(params) + if 'config' in model_fn_args: + optional_params['config'] = copy.deepcopy(config) + + for i, device in enumerate(devices): + is_the_first_tower = (i == 0) + + device_setter = _local_device_setter( + worker_device=device, ps_device=local_ps_device) + + # We would like to preserve the names of the variables and ops that a user + # might be relying on. Names with prefix are going to resolve to variables + # and ops of the first tower. + name_scope = name_scope_pattern + if is_the_first_tower: + name_scope = '' + + with variable_scope.variable_scope('', reuse=not is_the_first_tower): + with ops_lib.name_scope(name_scope.format(i)): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_specs.append( + model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params)) + return tower_specs + + +def _local_device_setter(ps_device, worker_device): + """A device setter that puts distributes Var/Ops to PS/workers.""" + ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] + + def local_device_chooser(op): + current_device = framework_device.DeviceSpec.from_string(op.device or '') + + node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if node_def.op in ps_ops: + ps_device_spec = framework_device.DeviceSpec.from_string( + '{}'.format(ps_device)) + + ps_device_spec.merge_from(current_device) + return ps_device_spec.to_string() + else: + worker_device_spec = framework_device.DeviceSpec.from_string( + worker_device or '') + worker_device_spec.merge_from(current_device) + return worker_device_spec.to_string() + + return local_device_chooser + + +def _minimize_towers(tower_specs, optimizer): + """Aggregate and apply gradients for computed losses.""" + grad_lists = {} + for tower_spec in tower_specs: + with ops_lib.device(tower_spec.loss.device): + variables = variables_lib.trainable_variables() + gradients = gradients_lib.gradients(tower_spec.loss, variables) + + for var, grad in zip(variables, gradients): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + + train_op = optimizer.apply_gradients( + aggregated_grads, global_step=training_util.get_global_step()) + + return train_op + + +def _call_optimizer_fn(optimizer_fn, params): + arguments = {} + optimizer_fn_arguments = util.fn_args(optimizer_fn) + if 'params' in optimizer_fn_arguments: + arguments['params'] = params + return optimizer_fn(**arguments) + + +def _compute_sum_on_device(values, device, name=None): + with ops_lib.device(device): + return math_ops.add_n(values, name=name) + + +def _train_spec(tower_specs, + train_op, + aggregation_device, + aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN + estimator_spec['train_op'] = train_op + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + + eval_metric_ops_lists = {} + for tower_spec in tower_specs: + metrics = tower_spec.eval_metric_ops or {} + for name, (_, update_op) in six.iteritems(metrics): + update_ops = eval_metric_ops_lists.setdefault(name, ([])) + update_ops.append(update_op) + + eval_metric_ops = {} + for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): + with ops_lib.control_dependencies(eval_metric_ops_lists[name]): + # This operation reduces local variables across all metrics, yet is + # called for every metric. This is redundant and it's done because + # it is hard to know what local variables correspond to what metric. + # Estimator is going to execute all `reduced_update_op`s as part of + # a group inside a single `Session.run()` call, which will avoid duplicate + # computation. + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + eval_metric_ops[name] = (metric_tensor, reduced_update_op) + + estimator_spec['eval_metric_ops'] = eval_metric_ops + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _reduce_metric_variables(number_of_towers): + """Aggregate local variables used in metrics into the first tower.""" + if number_of_towers == 1: + return control_flow_ops.no_op() + + metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) + variables_per_tower = len(metric_variables) // number_of_towers + + if len(metric_variables) % number_of_towers != 0: + raise ValueError( + 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.' + ' Expected {} local variables, but got {} instead.'.format( + variables_per_tower * number_of_towers, len(metric_variables))) + + # `metric_variables` has the size of `variables_per_tower` x + # number_of_towers. Each tower is produced by calling the same model_fn. + # First `variables_per_tower` correspond to the first tower. Each such + # variable has an replica at the `(variables_per_tower * i)` position, where + # `i` is `[1.. number_of_towers]`. We are going to add values from replicas + # to each variable of the first tower. We then zero out replica values, so + # that `_reduce_metric_variables` operation is idempotent. If a metric + # is then computed based on local variables from the first tower, then the + # resulting metric is an estimate for all `number_of_towers` towers. + ops = [] + for i in range(0, variables_per_tower): + next_replica_id = i + variables_per_tower + replicas = [ + metric_variables[replica_id] + for replica_id in range(next_replica_id, len(metric_variables), + variables_per_tower) + ] # `replicas` doesn't contain the first-tower variable. + + reduce_op = state_ops.assign_add(metric_variables[i], + math_ops.add_n(replicas)) + + with ops_lib.control_dependencies([reduce_op]): + for replica in replicas: + zeros_for_replica = array_ops.zeros( + array_ops.shape(replica), dtype=replica.dtype) + zero_out_replica_op = state_ops.assign(replica, zeros_for_replica) + ops.append(zero_out_replica_op) + + return control_flow_ops.group(*ops) + + +def _predict_spec(tower_specs, aggregation_device): + """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT + + with ops_lib.device(aggregation_device): + estimator_spec['predictions'] = _concat_tensor_dicts( + *[tower_spec.predictions for tower_spec in tower_specs]) + + export_outputs_dict = _dict_concat( + *[tower_spec.export_outputs for tower_spec in tower_specs]) + + export_outputs = {} + for name, export_output_list in six.iteritems(export_outputs_dict): + if isinstance(export_output_list[0], export_output_lib.PredictOutput): + export_outputs[name] = export_output_lib.PredictOutput( + outputs=_concat_tensor_dicts(*[ + export_output.outputs for export_output in export_output_list + ])) + elif isinstance(export_output_list[0], + export_output_lib.RegressionOutput): + export_outputs[name] = export_output_lib.RegressionOutput( + value=array_ops.concat( + [export_output.value for export_output in export_output_list], + axis=0)) + elif isinstance(export_output_list[0], + export_output_lib.ClassificationOutput): + scores = None + if export_output_list[0].scores is not None: + scores = array_ops.concat( + [export_output.scores for export_output in export_output_list], + axis=0) + + classes = None + if export_output_list[0].classes is not None: + classes = array_ops.stack( + [export_output.classes for export_output in export_output_list], + axis=0) + + export_outputs[name] = export_output_lib.ClassificationOutput( + scores=scores, classes=classes) + + estimator_spec['export_outputs'] = export_outputs + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _concat_tensor_dicts(*tensor_dicts): + return { + name: array_ops.concat(tensors, axis=0, name=name) + for name, tensors in six.iteritems(_dict_concat(*tensor_dicts)) + } + + +def _dict_concat(*dicts): + list_dict = {} + for d in dicts: + if d is None: + continue + + for k, v in six.iteritems(d): + list_dict.setdefault(k, []).append(v) + return list_dict diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..10b47fba5af0f2a036df637a4f4f996d388270c6 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -0,0 +1,901 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities that replicate `Estimator.model_fn` over GPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import shutil +import tempfile +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import replicate_model_fn +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent + + +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def test_complete_flow(self): + n_classes = 3 + input_dimension = 2 + batch_size = 12 + + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) + x_data = data.reshape(batch_size, input_dimension) + y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, batch_size=batch_size, shuffle=False) + + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + + estimator = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + + # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True + # during export_savedmodel and then switch this test to replicate over + # GPUs instead of CPUs. + estimator = estimator_lib.Estimator( + model_fn=replicate_model_fn.replicate_model_fn( + estimator.model_fn, + optimizer_fn, + devices=['/cpu:0', '/cpu:0', '/cpu:0']), + model_dir=estimator.model_dir, + config=estimator.config, + params=estimator.params) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +class ReplicateModelTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = None + if mode is not model_fn_lib.ModeKeys.PREDICT: + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=control_flow_ops.no_op()) # This train_op isn't actually used. + + def optimizer_fn(self, params): + return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_train_spec_with_optimizer_without_params(self): + + def optimizer_fn_without_params(): + return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: # pylint: disable=unused-variable + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + optimizer_fn_without_params, + devices=['/gpu:0', '/gpu:1']) + # This call is going to fail if `replicated_model_fn` is still passing + # `params` inside `optimizer_fn`, even though the latter doesn't take any: + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + del estimator_spec + + def test_eval(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_eval_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + +class GetLossTowersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + + return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss)) + + def test_gradients_are_computed(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_device='/gpu:0', + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('Sum:0', tower_specs[0].loss.name) + self.assertEqual(1.0, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(2.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + +class SplitBatchTest(test_util.TensorFlowTestCase): + + def evaluate_shards(self, first_list, second_list): + evaluate_items = lambda x: x.eval() + return list(map(evaluate_items, first_list)), list( + map(evaluate_items, second_list)) + + def test_simple_half_split(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) + + def test_to_each_their_own(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 4, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards) + self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) + + def test_one_batch(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) + + def test_half_split_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) + self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + + def test_one_batch_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0, 2.0, 3.0], + feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0, 6.0, 7.0], + feature_shards[0]['second'].eval()) + self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval()) + + def test_feature_and_label_dictionaries(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]} + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0], label_shards[0]['first'].eval()) + self.assertAllEqual([12.0], label_shards[0]['second'].eval()) + self.assertAllEqual([11], label_shards[1]['first'].eval()) + self.assertAllEqual([13.0], label_shards[1]['second'].eval()) + + +class TrainSpecTest(test_util.TensorFlowTestCase): + + expected_predictions = {} + + def create_estimator_spec(self, loss): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=loss, + train_op=loss, # Not used; currently required. + predictions=self.expected_predictions) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def test_example(self): + with self.test_session() as session: + tower_losses = list(map(self.create_constant_loss, [2, 4, 6])) + tower_specs = list(map(self.create_estimator_spec, tower_losses)) + + expected_train_op = tower_losses[1] + + estimator_spec = replicate_model_fn._train_spec( + tower_specs, expected_train_op, aggregation_device='/gpu:0') + + self.assertEqual(expected_train_op, estimator_spec.train_op) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + self.assertEqual(self.expected_predictions, estimator_spec.predictions) + + +class EvalSpecTest(test_util.TensorFlowTestCase): + + def create_estimator_spec(self, loss, metrics): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def create_eval_metrics(self, noise): + predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise]) + labels = np.array([0.1, 0.2, 0.3, 0.6]) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + return metrics + + def test_example(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [2, 4, 6]) + tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((12 - 2) / 12, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + + def test_handles_single_tower(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [5]) + tower_metrics = map(self.create_eval_metrics, [0.2]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((4 - 1) / 4, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(5, session.run(estimator_spec.loss)) + + +class PredictSpecTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([features[0], features[0]]), c) + + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions={ + 'probabilities': predictions + }) + + def test_example(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.1], [0.2]], + labels=[[], []], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_device='/gpu:0', + ) + session.run(variables.global_variables_initializer()) + + estimator_spec = replicate_model_fn._predict_spec( + tower_specs, aggregation_device='/gpu:0') + + self.assertEqual('/device:GPU:0', + estimator_spec.predictions['probabilities'].device) + self.assertAllClose({ + 'probabilities': np.array([0.35, 0.35, 0.45, 0.45]) + }, session.run(estimator_spec.predictions)) + + +class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): + + def create_metric_variable(self, initial_value, name): + return variable_scope.variable( + initial_value, + trainable=False, + collections=[ops_lib.GraphKeys.METRIC_VARIABLES], + validate_shape=True, + name=name) + + def create_tower_metrics(self, tower_id): + with variable_scope.variable_scope('', reuse=(tower_id != 0)): + self.create_metric_variable(1.3 * (tower_id + 1), 'total') + self.create_metric_variable(2.3 * (tower_id + 1), 'count') + self.create_metric_variable( + np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total') + + def test_example(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7] + # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4] + # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1] + # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2] + # Towers are accumulated in the first tower. + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_reduce_is_idempotent(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + for _ in range(20): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_handles_single_tower(self): + with self.test_session() as session: + self.create_tower_metrics(0) + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=1)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(1.3, local_metrics[0], 0.01) + self.assertNear(2.3, local_metrics[1], 0.01) + self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01) + + def test_doesnt_accept_uneven_number_of_variables(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + self.create_metric_variable(-1.0, 'oddball') + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + with self.assertRaisesRegexp(ValueError, ''): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + +class MergeExportOutputsTest(test_util.TensorFlowTestCase): + + def optimizer_fn(self): + return gradient_descent.GradientDescentOptimizer(1.0) + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = {'probabilities': math_ops.multiply(features, c)} + loss = losses.absolute_difference( + labels=labels, + predictions=predictions['probabilities'], + reduction=losses.Reduction.SUM) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']), + 'auc': metrics_lib.auc(labels, predictions['probabilities']) + } + tensor_string_repr = str(features) + classes = constant_op.constant( + re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1), + dtype=dtypes.string) + + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(predictions), + 'classification_output': + export_output.ClassificationOutput(predictions['probabilities'], + classes), + 'classification_scores': + export_output.ClassificationOutput( + scores=predictions['probabilities']), + 'classification_classes': + export_output.ClassificationOutput(classes=classes), + 'regression_output': + export_output.RegressionOutput(predictions['probabilities']), + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + eval_metric_ops=metrics, + predictions=predictions, + train_op=loss, # This train_op isn't actually used. + export_outputs=export_outputs) + + def replicate_estimator_spec(self, session): + features = np.array([0.01, 0.002]) + labels = np.array([0.01, 0.02]) + + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, {}) + session.run(variables.global_variables_initializer()) + return estimator_spec + + def test_merde_predict_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + { + 'probabilities': np.array([0.1, 0.02]) + }, + session.run(estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs)) + + def test_merge_classification_output_scores_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_output'].scores)) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_output'].classes)) + + def test_merge_classification_output_scores(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_scores'].scores)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_scores'].classes) + + def test_merge_classification_output_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_classes'].classes)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_classes'].scores) + + def test_merge_regression_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run(estimator_spec.export_outputs['regression_output'].value)) + + +class GetLocalDevicesTest(test_util.TensorFlowTestCase): + + def test_there_is_at_least_a_cpu(self): + self.assertTrue(replicate_model_fn._get_local_devices('CPU')) + + def test_there_is_no_xpu(self): + self.assertFalse( + replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. + + def test_whether_there_is_a_gpu(self): + self.assertEqual( + len(replicate_model_fn._get_local_devices('GPU')), + test.is_gpu_available()) + + +class LocalDeviceSetterTest(test_util.TensorFlowTestCase): + + def test_vars_are_on_ps_but_ops_are_on_workers(self): + local_device_setter = replicate_model_fn._local_device_setter( + ps_device='/device:GPU:3', worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + c = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', c.device) + + cc = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', cc.device) + + ccc = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', ccc.device) + + c_op = array_ops.concat(c, axis=0) + self.assertEqual('/device:GPU:2', c_op.device) + + cc_op = array_ops.concat(cc, axis=0) + self.assertEqual('/device:GPU:2', cc_op.device) + + +class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): + + def test_example(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertEqual(10.0, session.run(total)) + + +class ConcatTensorDictsTest(test_util.TensorFlowTestCase): + + def test_example(self): + tensor_dicts = [ + { + 'a': np.array([1.0, 2.0]), + 'b': np.array([11.0]), + 'c': np.array([21.0]), + }, + { + 'a': np.array([3.0]), + 'b': np.array([12.0, 13.0]), + }, + { + 'b': np.array([14.0]), + }, + ] + + with self.test_session() as session: + self.assertAllClose({ + 'a': np.array([1.0, 2.0, 3.0]), + 'b': np.array([11.0, 12.0, 13.0, 14.0]), + 'c': np.array([21.0]), + }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 891425fd8cae6fbbf60d30cbd9137c049073456c..e8dad886a1409babdf4ea47b9cd05def1f1ce25e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -24,6 +24,7 @@ tf_custom_op_py_library( "python/framework/__init__.py", "python/framework/checkpoint_utils.py", "python/framework/experimental.py", + "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", "python/ops/accumulate_n_v2.py", @@ -32,6 +33,7 @@ tf_custom_op_py_library( "python/ops/checkpoint_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + "python/ops/sort_ops.py", "python/ops/variables.py", ], dso = [ @@ -231,6 +233,17 @@ py_test( ], ) +py_test( + name = "graph_util_test", + srcs = ["python/framework/graph_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + ], +) + py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], @@ -307,6 +320,20 @@ py_test( ], ) +py_test( + name = "sort_ops_test", + size = "medium", + srcs = ["python/ops/sort_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 8421ba7c0423c6ed274f92ba74930822d0171e05..3f592611830e40a30392239c85486a2fad15a2a2 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -79,6 +79,8 @@ See the @{$python/contrib.framework} guide. @@load_embedding_initializer @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer + +@@sort """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py index c8e6a4685498a4d89cef44f6a9a3acbe7557cb67..2d49771ab756359712a3ee0b23649c231678f952 100644 --- a/tensorflow/contrib/framework/python/framework/__init__.py +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.experimental import experimental +from tensorflow.contrib.framework.python.framework.graph_util import * from tensorflow.contrib.framework.python.framework.tensor_util import * # pylint: enable=wildcard-import from tensorflow.python.util import decorator_utils diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab8711db4650921e0d366a91adfe2f68b5a42f9 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to manipulate a tensor graph in python. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import copy +import six + +# pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary +from tensorflow.python.framework.graph_util_impl import _node_name + +__all__ = ["fuse_op"] + + +def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, + output_quantized, op_name, op_type): + """Fuse subgraph between input_nodes and output_nodes into a single custom op. + + Args: + graph_def: A graph_pb2.GraphDef proto. + input_nodes: input nodes to the subgraph to be fused. + output_nodes: output nodes to the subgraph to be fused. + output_dtypes: A list of output datatypes for the custom op + output_quantized: A boolean flag that indicates if output is quantized + op_name: fused op name. + op_type: fused op type. + Returns: + The GraphDef of the new graph. + + Raises: + TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto. + """ + + if not isinstance(graph_def, graph_pb2.GraphDef): + raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") + + if isinstance(input_nodes, six.string_types): + raise TypeError("input_nodes must be a list.") + + if isinstance(output_nodes, six.string_types): + raise TypeError("output_nodes must be a list.") + + name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( + graph_def) + _assert_nodes_are_present(name_to_node, input_nodes + output_nodes) + + # Nodes upto and including input_nodes + reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name) + # Nodes upto and including output_nodes + reachable_by_output = _bfs_for_reachable_nodes(output_nodes, + name_to_input_name) + + # Set of nodes in the list input_nodes + input_nodes_set = set(input_nodes) + + # Set of nodes in the list output_nodes + output_nodes_set = set(output_nodes) + + nodes_post_output = [] + for node in graph_def.node: + n = _node_name(node.name) + if n in reachable_by_output: + if n not in reachable_by_input and n not in output_nodes_set: + # n is between input and output, i.e., part of the fused op + next_to_visit = [n] + while next_to_visit: + cur_node = next_to_visit[0] + del next_to_visit[0] + if cur_node in reachable_by_input and cur_node not in input_nodes_set: + raise TypeError("Node %s uses input %s not in input_nodes." % + (n, cur_node)) + if cur_node not in input_nodes_set: + next_to_visit += name_to_input_name[cur_node] + else: + nodes_post_output.append(n) + + # Add all nodes upto the input nodes + out = graph_pb2.GraphDef() + reachable_by_input_sorted = sorted( + list(reachable_by_input), key=lambda n: name_to_seq_num[n]) + for node in reachable_by_input_sorted: + out.node.extend([copy.deepcopy(name_to_node[node])]) + + # Add the custom op + new_node = node_def_pb2.NodeDef() + for node in input_nodes: + new_node.input.append(node) + new_node.attr["_output_types"].list.type[:] = output_dtypes + new_node.attr["_output_quantized"].b = output_quantized + new_node.op = op_type + new_node.name = op_name + out.node.extend([new_node]) + + # Add the nodes in the output of the custom op + for index, n in enumerate(output_nodes): + assert len(name_to_node[n].input) == 1 + new_node = copy.deepcopy(name_to_node[n]) + del new_node.input[:] + new_node.input.append(op_name + (":" + str(index) if index != 0 else "")) + out.node.extend([new_node]) + + # Add the nodes post output_nodes + for n in nodes_post_output: + out.node.extend([copy.deepcopy(name_to_node[n])]) + + out.library.CopyFrom(graph_def.library) + out.versions.CopyFrom(graph_def.versions) + return out diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..87b992e22e1ad3aa20389d0834eeb3a5972c676e --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""@graph_util tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.framework import graph_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +def GetNewNode(name, op, input_nodes): + new_node = node_def_pb2.NodeDef() + new_node.op = op + new_node.name = name + for node in input_nodes: + new_node.input.append(node) + return new_node + + +class GraphUtilTest(test.TestCase): + + def testGraphUtil(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op( + graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2') + self.assertEqual(len(fused_graph_def.node), 4) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[1].input[0], 'A') + self.assertEqual(fused_graph_def.node[1].op, 'Op2') + self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[2].name, 'D') + self.assertEqual(fused_graph_def.node[3].name, 'E') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index edef37cf0c0719bf10a4c75c34adb30b9716cdcd..685bb94779762ce46ee342e7e0a182c54be64743 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * +from tensorflow.contrib.framework.python.ops.sort_ops import * from tensorflow.contrib.framework.python.ops.variables import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8f62f0ea7b9b561f235b9496ffda97a9f378d530 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for sorting tensors. + +@@sort +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + + +def sort(values, axis=-1, direction='ASCENDING', name=None): + """Sorts a tensor. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + name: Optional name for the operation. + + Returns: + A `Tensor` with the same dtype and shape as `values`, with the elements + sorted along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + with framework_ops.name_scope(name, 'sort'): + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error + + values = framework_ops.convert_to_tensor(values, name='values') + + return _SORT_IMPL[direction](values, axis_static) + + +def _descending_sort(values, axis): + """Sorts values in reverse using `top_k`. + + Args: + values: Tensor of numeric values. + axis: Index of the axis which values should be sorted along. + + Returns: + The sorted values. + """ + k = array_ops.shape(values)[axis] + rank = array_ops.rank(values) + # Fast path: sorting the last axis. + if axis == -1 or axis + 1 == values.get_shape().ndims: + return nn_ops.top_k(values, k)[0] + + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Make axis a Tensor with the real axis index if needed. + axis += rank + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + values, unused_indices = nn_ops.top_k(top_k_input, k) + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return array_ops.transpose(values, transposition) + + +def _ascending_sort(values, axis): + # Negate the values to get the ascending order from descending sort. + values_or_indices = _descending_sort(-values, axis) + return -values_or_indices + + +_SORT_IMPL = { + 'ASCENDING': _ascending_sort, + 'DESCENDING': _descending_sort, +} diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ae502f10d98ee14d8bea2f76b18bedb935cea --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the sort wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import sort_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class SortTest(test.TestCase): + + def testRandom_lowDimensionality(self): + self._testRandom_lowDimensionality(negative_axis=False) + + def testRandom_lowDimensionality_negative(self): + self._testRandom_lowDimensionality(negative_axis=True) + + def _testRandom_lowDimensionality(self, negative_axis): + np.random.seed(42) + for _ in range(20): + rank = np.random.randint(1, 3) + shape = [np.random.randint(0, 20) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + if negative_axis: + sort_axis = -1 - sort_axis + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testRandom_highDimensionality(self): + np.random.seed(100) + for _ in range(20): + rank = np.random.randint(5, 15) + shape = [np.random.randint(1, 4) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testScalar(self): + # Create an empty scalar where the static shape is unknown. + zeros_length_1 = array_ops.zeros( + random_ops.random_uniform([1], minval=0, maxval=1, dtype=dtypes.int32), + dtype=dtypes.int32) + scalar = array_ops.zeros(zeros_length_1) + + sort = sort_ops.sort(scalar) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + sort.eval() + + def testNegativeOutOfBounds_staticShape(self): + arr = constant_op.constant([3, 4, 5]) + with self.assertRaises(ValueError): + sort_ops.sort(arr, axis=-4) + + def testDescending(self): + arr = np.random.random((10, 5, 5)) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=0)[::-1], + sort_ops.sort( + constant_op.constant(arr), + axis=0, + direction='DESCENDING').eval()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ecfa6baeff35ce6b25185a686d528ec73f17c1ee..56e31985936c22d9b5d6c85fff067118152e220d 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -246,8 +246,8 @@ def image(name, tensor, bad_color=None, max_images=3, family=None): """Writes an image summary if possible.""" def function(tag, scope): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + if bad_color is None else bad_color) # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index f6309e2e72f75a4ba5b323b4d7348c49555d522e..0e1fca3d3c8b6f3a19b3e989dbee1863475796c5 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -95,3 +95,10 @@ tf_proto_library_cc( cc_api_version = 2, visibility = ["//visibility:public"], ) + +tf_proto_library_cc( + name = "tf_op_stats_proto", + srcs = ["tf_op_stats.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto new file mode 100644 index 0000000000000000000000000000000000000000..5b2dbb31243d401fbab31bab5bc86133896693fe --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -0,0 +1,127 @@ +// This proto describes the format of tensorflow operation level stats for +// profiling (in tensorboard) purpose. + +syntax = "proto2"; + +package tensorflow.tpu; + +// Result proto for OpMetrics. +message OpMetricsResult { + // True if this OP is executed on the device; False if it is executed on the + // host. + optional bool on_device = 1; + reserved 2; // was uint32 id. + // Name of this OP. + optional string name = 3; + // Rank of this OP. + optional uint64 rank = 4; + // The starting time in cycles of the last instance of this OP executed. + optional double last_starttime_in_cycles = 5; + // The ending time in cycles of the last instance of this OP executed. + optional double last_endtime_in_cycles = 6; + // If this OP (say A), is an immediate child of another OP (say B), this field + // stores the sum of duration in microseconds of A inside B. If A appears more + // than once in B, the duration of all A's appearances will be added together. + // This sum will be reset after the self-time of B is calculated so that it + // can be reused for a new parent OP. + optional double sum_of_duration_in_us_as_children = 7; + // Number of instances that this OP occurred. + optional uint64 occurrences = 8; + // Total time in microseconds spent in this OP (accumulated + // over all of its occurrences). + optional double total_time_in_us = 9; + // Total self time in microseconds spent in this OP + // (accumulated over all of its occurrences). + optional double total_self_time_in_us = 10; + // The total self time as a fraction of sum of all OP's + // total self time on the host. + optional double host_total_self_time_as_fraction_of_all_op_time = 11; + // Cumulative total self time in fraction on the host. + optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = + 12; + // The total self time as a fraction of sum of all OP's + // total self time on the device. + optional double device_total_self_time_as_fraction_of_all_op_time = 13; + // Cumulative total self time in fraction on the device. + optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = + 14; + // Total number of FLOPs incurred by this OP. + optional double total_flops = 15; + // Total time in microseconds that the MXU is occupied by this OP. + optional double total_bytes_accessed = 16; + // Total time in microseconds that the MXU is occupied by this OP. + optional double mxu_occupancy_in_us = 17; + // Total time in microseconds that the XU is occupied by this OP. + optional double xu_occupancy_in_us = 18; + // Total DMA access stall time in microseconds. + optional double total_dma_stall_in_us = 19; +} + +// Result proto for OpMetricsDb. +message OpMetricsDbResult { + // A bunch of OpMetricsResults. + repeated OpMetricsResult metrics_db = 1; +} + +// Result proto for StepInfo. +message StepInfoResult { + // The (micro) step number. + optional uint32 step_num = 1; + // The step duration in picoseconds. + optional uint64 duration_ps = 2; + // The infeed duration in picoseconds. + // Can turn into a map if we want a variable number of ops. + optional uint64 infeed_duration_ps = 3; +} + +// Result proto for a sequence of steps. +message StepSequenceResult { + // A sequence of StepInfoResults. + repeated StepInfoResult step_sequence = 1; +} + +// Result proto for a StepDatabase. +message StepDatabaseResult { + // A map from core_id to StepSequenceResult. + map step_sequence_per_core = 1; +} + +// Result proto for Dashboard data. +message DashboardResult { + // The total iteration time in nanoseconds. + optional double iteration_time_ns = 1; + // The total number of iterations. + optional int32 num_iterations = 2; + // The total computation time in nanoseconds. + optional double computation_time_ns = 3; + // The total number of computations. + optional int32 num_computations = 4; +} + +// Result proto for HloExtraInfo. +message HloExtraInfoResult { + // Category of the HLO op given by the compiler. + optional string category = 1; + // The long name of the HLO that includes the dimensions. + optional string long_name = 2; +} + +// Result proto for HloExtraInfoMap. +message HloExtraInfoMapResult { + // A map from HLO name to HloExtraInfo. + map hlo_extrainfo_map = 1; +} + +// Result proto for TfStatsHelper. +message TfOpStats { + // The result for the TF-metric database. + optional OpMetricsDbResult tf_metrics_db = 1; + // The result for the HLO-metric database. + optional OpMetricsDbResult hlo_metrics_db = 2; + // The result for the step database. + optional StepDatabaseResult step_db = 3; + // The result for the TPU dashboard. + optional DashboardResult dashboard = 4; + // The result for the HloExtraInfoMap. + optional HloExtraInfoMapResult hlo_extrainfo_map = 5; +} diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 5a3b8314291951b5dfce091dccb0dc9e5f7af3b5..060b3f912926fbaa56bc1150e50434a7ad22c847 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] # TODO(b/65703635): Flip the value and remove all dead code. -_WRAP_INPUT_FN_INTO_WHILE_LOOP = False +_WRAP_INPUT_FN_INTO_WHILE_LOOP = True def _create_global_step(graph): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7b535da0b2f751ef258580b678bec5022d671b82..9530af637ef953c293472d926281de77cf626752 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1414,16 +1414,19 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/tracing.h", ] +# Replicated for lib_internal and lib_internal_impl. +LIB_INTERNAL_DEFINES = (tf_additional_lib_defines() + [ + "TF_USE_SNAPPY", + ] + tf_additional_verbs_lib_defines() + + tf_additional_mpi_lib_defines() + + tf_additional_gdr_lib_defines()) + cc_library( name = "lib_internal", srcs = LIB_INTERNAL_PRIVATE_HEADERS, hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), - defines = tf_additional_lib_defines() + [ - "TF_USE_SNAPPY", - ] + tf_additional_verbs_lib_defines() + - tf_additional_mpi_lib_defines() + - tf_additional_gdr_lib_defines(), + defines = LIB_INTERNAL_DEFINES, linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], @@ -1477,6 +1480,7 @@ cc_library( ), hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), + defines = LIB_INTERNAL_DEFINES, deps = tf_additional_lib_deps() + [ ":lib_hash_crc32c_accelerate_internal", ":lib_proto_parsing", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index ceeb172fa0a9abf2ab7adcfc801b4bcb5fa04381..d95d958d5afaad58bdec82183be3d3a09cf4605d 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -46,92 +46,218 @@ constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; constexpr char kOverridesFilePath[] = "tensorflow/cc/ops/op_gen_overrides.pbtxt"; -constexpr char kApiDefFileFormat[] = "api_def_%c.pbtxt"; -constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; +constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt"; +constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -// Get map from first character to ApiDefs for ops -// that start with that character. -std::unordered_map GenerateApiDef( - const OpList& ops, const OpGenOverrides& overrides) { +void FillBaseApiDef(ApiDef* api_def, const OpDef& op) { + api_def->set_graph_op_name(op.name()); + // Add arg docs + for (auto& input_arg : op.input_arg()) { + if (!input_arg.description().empty()) { + auto* api_def_in_arg = api_def->add_in_arg(); + api_def_in_arg->set_name(input_arg.name()); + api_def_in_arg->set_description(input_arg.description()); + } + } + for (auto& output_arg : op.output_arg()) { + if (!output_arg.description().empty()) { + auto* api_def_out_arg = api_def->add_out_arg(); + api_def_out_arg->set_name(output_arg.name()); + api_def_out_arg->set_description(output_arg.description()); + } + } + // Add attr docs + for (auto& attr : op.attr()) { + if (!attr.description().empty()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + api_def_attr->set_description(attr.description()); + } + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); +} + +// Checks if arg1 should be before arg2 according to ordering in args. +bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2, + const protobuf::RepeatedPtrField& args) { + for (auto& arg : args) { + if (arg.name() == arg2->name()) { + return false; + } else if (arg.name() == arg1->name()) { + return true; + } + } + return false; +} + +// Checks if attr1 should be before attr2 according to ordering in op_def. +bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2, + const OpDef& op_def) { + for (auto& attr : op_def.attr()) { + if (attr.name() == attr2->name()) { + return false; + } else if (attr.name() == attr1->name()) { + return true; + } + } + return false; +} + +// Applies renames to args. +void ApplyArgOverrides( + protobuf::RepeatedPtrField* args, + const protobuf::RepeatedPtrField& renames, + const protobuf::RepeatedPtrField& op_args, + const string& op_name) { + for (auto& rename : renames) { + // First check if rename is valid. + bool valid = false; + for (const auto& op_arg : op_args) { + if (op_arg.name() == rename.from()) { + valid = true; + } + } + QCHECK(valid) << rename.from() << " is not a valid argument for " + << op_name; + bool found_arg = false; + // If Arg is already in ApiDef, just update it. + for (int i = 0; i < args->size(); ++i) { + auto* arg = args->Mutable(i); + if (arg->name() == rename.from()) { + arg->set_rename_to(rename.to()); + found_arg = true; + break; + } + } + if (!found_arg) { // not in ApiDef, add a new arg. + auto* new_arg = args->Add(); + new_arg->set_name(rename.from()); + new_arg->set_rename_to(rename.to()); + } + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(args->pointer_begin(), args->pointer_end(), + [&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) { + return CheckArgBefore(arg1, arg2, op_args); + }); +} + +// Returns existing attribute with the given name if such +// attribute exists. Otherwise, adds a new attribute and returns it. +ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) { + // If Attr is already in ApiDef, just update it. + for (int i = 0; i < api_def->attr_size(); ++i) { + auto* attr = api_def->mutable_attr(i); + if (attr->name() == attr_name) { + return attr; + } + } + // Add a new Attr. + auto* new_attr = api_def->add_attr(); + new_attr->set_name(attr_name); + return new_attr; +} + +// Applies renames and default values to attributes. +void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override, + const OpDef& op_def) { + for (auto& attr_rename : op_override.attr_rename()) { + auto* attr = FindOrAddAttr(api_def, attr_rename.from()); + attr->set_rename_to(attr_rename.to()); + } + + for (auto& attr_default : op_override.attr_default()) { + auto* attr = FindOrAddAttr(api_def, attr_default.name()); + *(attr->mutable_default_value()) = attr_default.value(); + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(api_def->mutable_attr()->pointer_begin(), + api_def->mutable_attr()->pointer_end(), + [&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) { + return CheckAttrBefore(attr1, attr2, op_def); + }); +} + +void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op, + const OpGenOverride& op_override) { + // Fill ApiDef with data based on op and op_override. + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + api_def->add_endpoint()->set_name(op_override.rename_to()); + } else if (!op_override.alias().empty()) { + api_def->add_endpoint()->set_name(op.name()); + } + + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + + ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(), + op.input_arg(), api_def->graph_op_name()); + ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(), + op.output_arg(), api_def->graph_op_name()); + ApplyAttrOverrides(api_def, op_override, op); +} + +// Get map from ApiDef file path to corresponding ApiDefs proto. +std::unordered_map GenerateApiDef( + const string& api_def_dir, const OpList& ops, + const OpGenOverrides& overrides) { std::unordered_map name_to_override; for (const auto& op_override : overrides.op()) { name_to_override[op_override.name()] = op_override; } - std::unordered_map api_defs_map; + std::unordered_map api_defs_map; for (const auto& op : ops.op()) { CHECK(!op.name().empty()) << "Encountered empty op name: %s" << op.DebugString(); - const char file_id = toupper(op.name()[0]); - CHECK(isalpha(file_id)) << "Unexpected op name: " << op.name(); - ApiDef* api_def = api_defs_map[file_id].add_op(); - api_def->set_graph_op_name(op.name()); + string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat); + file_path = strings::Printf(file_path.c_str(), op.name().c_str()); + ApiDef* api_def = api_defs_map[file_path].add_op(); + FillBaseApiDef(api_def, op); if (name_to_override.find(op.name()) != name_to_override.end()) { - const auto& op_override = name_to_override[op.name()]; - // Set visibility - if (op_override.skip()) { - api_def->set_visibility(ApiDef_Visibility_SKIP); - } else if (op_override.hide()) { - api_def->set_visibility(ApiDef_Visibility_HIDDEN); - } - // Add endpoints - if (!op_override.rename_to().empty()) { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op_override.rename_to()); - } else { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op.name()); - } - for (auto& alias : op_override.alias()) { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(alias); - } - // Add attributes - for (auto& attr : op.attr()) { - auto* api_def_attr = api_def->add_attr(); - api_def_attr->set_name(attr.name()); - for (auto& attr_override : op_override.attr_default()) { - if (attr.name() == attr_override.name()) { - *(api_def_attr->mutable_default_value()) = attr_override.value(); - } - } - for (auto& attr_rename : op_override.attr_rename()) { - if (attr.name() == attr_rename.from()) { - api_def_attr->set_rename_to(attr_rename.to()); - } - } - } - } else { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op.name()); + ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]); } - // Add docs - api_def->set_summary(op.summary()); - api_def->set_description(op.description()); } return api_defs_map; } -// Reads golden api defs file with the given suffix. -string GetGoldenApiDefsStr(Env* env, const string& api_files_dir, char suffix) { - string file_path = strings::Printf( - io::JoinPath(api_files_dir, kApiDefFileFormat).c_str(), suffix); - if (env->FileExists(file_path).ok()) { +// Reads golden ApiDef files and returns a map from file name to ApiDef file +// contents. +std::unordered_map GetGoldenApiDefs( + Env* env, const string& api_files_dir) { + std::vector matching_paths; + TF_CHECK_OK(env->GetMatchingPaths( + io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); + + std::unordered_map file_path_to_api_def; + for (auto& file_path : matching_paths) { string file_contents; - TF_EXPECT_OK(ReadFileToString(env, file_path, &file_contents)); - return file_contents; + TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); + file_path_to_api_def[file_path] = file_contents; } - return ""; + return file_path_to_api_def; } void RunApiTest(bool update_api_def, const string& api_files_dir) { // Read C++ overrides file - string overrides_file_contents; + OpGenOverrides overrides; Env* env = Env::Default(); - TF_EXPECT_OK( - ReadFileToString(env, kOverridesFilePath, &overrides_file_contents)); + TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides)); // Read all ops OpList ops; @@ -139,29 +265,22 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { const std::vector multi_line_fields = {"description"}; // Get expected ApiDefs - OpGenOverrides overrides; - auto new_api_defs_map = GenerateApiDef(ops, overrides); + const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides); bool updated_at_least_one_file = false; + const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir); - for (char c : kAlphabet) { - string golden_api_defs_str = GetGoldenApiDefsStr(env, api_files_dir, c); - string new_api_defs_str = new_api_defs_map[c].DebugString(); + for (auto new_api_entry : new_api_defs_map) { + const auto& file_path = new_api_entry.first; + const auto& golden_api_defs_str = golden_api_defs_map.at(file_path); + string new_api_defs_str = new_api_entry.second.DebugString(); new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); if (golden_api_defs_str == new_api_defs_str) { continue; } if (update_api_def) { - string output_file_path = - io::JoinPath(api_files_dir, strings::Printf(kApiDefFileFormat, c)); - if (new_api_defs_str.empty()) { - std::cout << "Deleting " << output_file_path << "..." << std::endl; - TF_EXPECT_OK(env->DeleteFile(output_file_path)); - } else { - std::cout << "Updating " << output_file_path << "..." << std::endl; - TF_EXPECT_OK( - WriteStringToFile(env, output_file_path, new_api_defs_str)); - } + std::cout << "Updating " << file_path << "..." << std::endl; + TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str)); updated_at_least_one_file = true; } else { EXPECT_EQ(golden_api_defs_str, new_api_defs_str) @@ -170,6 +289,21 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { } } + for (const auto& golden_api_entry : golden_api_defs_map) { + const auto& file_path = golden_api_entry.first; + if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) { + if (update_api_def) { + std::cout << "Deleting " << file_path << "..." << std::endl; + TF_EXPECT_OK(env->DeleteFile(file_path)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ("", golden_api_entry.second) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + } + if (update_api_def && !updated_at_least_one_file) { std::cout << "Api def files are already up to date." << std::endl; } diff --git a/tensorflow/core/api_def/base_api/api_def_A.pbtxt b/tensorflow/core/api_def/base_api/api_def_A.pbtxt deleted file mode 100644 index 8193d1bc624535c7894430284686e8664fb71a2d..0000000000000000000000000000000000000000 --- a/tensorflow/core/api_def/base_api/api_def_A.pbtxt +++ /dev/null @@ -1,670 +0,0 @@ -op { - graph_op_name: "Abort" - endpoint { - name: "Abort" - } - summary: "Raise a exception to abort the process when called." - description: <= 2." -} -op { - graph_op_name: "AdjustContrastv2" - endpoint { - name: "AdjustContrastv2" - } - summary: "Adjust the contrast of one or more images." - description: < [2.0132, 1.056] -``` - -@compatibility(numpy) -Equivalent to np.angle. -@end_compatibility -END -} -op { - graph_op_name: "Any" - endpoint { - name: "Any" - } - summary: "Computes the \"logical or\" of elements across dimensions of a tensor." - description: < l1 else 0.0 -accum = accum_new -END -} -op { - graph_op_name: "ApplyFtrlV2" - endpoint { - name: "ApplyFtrlV2" - } - summary: "Update \'*var\' according to the Ftrl-proximal scheme." - description: < l1 else 0.0 -accum = accum_new -END -} -op { - graph_op_name: "ApplyGradientDescent" - endpoint { - name: "ApplyGradientDescent" - } - summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it." -} -op { - graph_op_name: "ApplyMomentum" - endpoint { - name: "ApplyMomentum" - } - summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you" - description: <= 2." +} diff --git a/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..429a5e4434e011d1ba43847b9abf8877b4d41e7a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "AdjustContrastv2" + endpoint { + name: "AdjustContrast" + } + in_arg { + name: "images" + description: < [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Any.pbtxt b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..09fd4e0b6036447dfe355ff56da29e276de62f2b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt @@ -0,0 +1,42 @@ +op { + graph_op_name: "Any" + endpoint { + name: "Any" + } + endpoint { + name: "ReduceAny" + } + in_arg { + name: "input" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..974f3adc196129f9fe83d098c22dc3cd237263d6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -0,0 +1,75 @@ +op { + graph_op_name: "ApplyFtrlV2" + in_arg { + name: "var" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2f38ebd1b8c89a1a65368d3da38cead73225ada5 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt @@ -0,0 +1,35 @@ +op { + graph_op_name: "ApplyGradientDescent" + in_arg { + name: "var" + description: < -1. +END + } + attr { + name: "scientific" + description: < -1. +END + } + attr { + name: "fill" + description: < -1. If empty, pads with spaces. +Another typical value is '0'. String cannot be longer than 1 character. +END + } + summary: "Converts each entry in the given tensor to strings. Supports many numeric" + description: <