summaryrefslogtreecommitdiffhomepage
path: root/util
diff options
context:
space:
mode:
authorAaron Klotz <aaron@tailscale.com>2023-09-19 14:16:15 -0600
committerAaron Klotz <aaron@tailscale.com>2023-09-19 14:18:20 -0600
commita82dfe7f99a2b9a25b490023fc63dd2ba493bf30 (patch)
treebd4d114e5c4cd9f049a664eb94d0740d3b5deef9 /util
parent19a9d9037f9770adb2cc4b812aeb1f1ff02da5af (diff)
downloadtailscale-aaron/win_process_mitigations.tar.xz
tailscale-aaron/win_process_mitigations.zip
Signed-off-by: Aaron Klotz <aaron@tailscale.com>
Diffstat (limited to 'util')
-rw-r--r--util/winutil/mksyscall.go4
-rw-r--r--util/winutil/process_windows.go577
-rw-r--r--util/winutil/process_windows_test.go17
-rw-r--r--util/winutil/subprocess_windows_test.go388
-rw-r--r--util/winutil/testdata/testprocessattributes/main_windows.go40
-rw-r--r--util/winutil/testdata/testprocessattributes/tests_windows.go57
-rw-r--r--util/winutil/winutil_windows.go79
-rw-r--r--util/winutil/zsyscall_windows.go20
8 files changed, 1101 insertions, 81 deletions
diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go
index 3c5515ee0..3bae51c4b 100644
--- a/util/winutil/mksyscall.go
+++ b/util/winutil/mksyscall.go
@@ -6,5 +6,7 @@ package winutil
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
-//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
+//sys getProcessMitigationPolicy(hProcess windows.Handle, mitigationPolicy _PROCESS_MITIGATION_POLICY, buf unsafe.Pointer, bufLen uintptr) (err error) [int32(failretval)==0] = kernel32.GetProcessMitigationPolicy
+//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [int32(failretval)==0] = advapi32.QueryServiceConfig2W
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
+//sys setProcessMitigationPolicy(mitigationPolicy _PROCESS_MITIGATION_POLICY, buf unsafe.Pointer, bufLen uintptr) (err error) [int32(failretval)==0] = kernel32.SetProcessMitigationPolicy
diff --git a/util/winutil/process_windows.go b/util/winutil/process_windows.go
new file mode 100644
index 000000000..34e0d2742
--- /dev/null
+++ b/util/winutil/process_windows.go
@@ -0,0 +1,577 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winutil
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "runtime"
+ "strings"
+ "unicode/utf16"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+// StartProcessAsChild starts exePath process as a child of parentPID.
+// StartProcessAsChild copies parentPID's environment variables into
+// the new process, along with any optional environment variables in extraEnv.
+func StartProcessAsChild(parentPID uint32, exePath string, extraEnv []string) error {
+ return StartProcessWithAttributes(exePath, ProcessAttributeEnvExtra{Slice: extraEnv}, ProcessAttributeParentProcessID(parentPID))
+}
+
+func StartProcessWithAttributes(exePath string, attrs ...any) (err error) {
+ var desktop string
+ var parentPID uint32
+ var mitigationBits uint64
+ var inheritableHandleList ProcessAttributeExplicitInheritableHandleList
+ var useStdHandles bool
+ var useToken windows.Token
+ var wd string
+ var procSA *windows.SecurityAttributes
+ var threadSA *windows.SecurityAttributes
+ var envExtra ProcessAttributeEnvExtra
+ var args []string
+ creationFlags := uint32(windows.CREATE_UNICODE_ENVIRONMENT | windows.EXTENDED_STARTUPINFO_PRESENT)
+
+ for _, attr := range attrs {
+ switch v := attr.(type) {
+ case ProcessAttributeExplicitInheritableHandleList:
+ inheritableHandleList, useStdHandles, err = v.filtered()
+ if err != nil {
+ return err
+ }
+ case *ProcessAttributeExplicitInheritableHandleList:
+ inheritableHandleList, useStdHandles, err = v.filtered()
+ if err != nil {
+ return err
+ }
+ case ProcessAttributeParentProcessID:
+ parentPID = uint32(v)
+ case *ProcessAttributeParentProcessID:
+ parentPID = uint32(*v)
+ case ProcessMitigationPolicies:
+ mitigationBits = v.asMitigationBits()
+ case *ProcessMitigationPolicies:
+ mitigationBits = v.asMitigationBits()
+ case windows.Token:
+ useToken = v
+ case ProcessAttributeGUIBindInfo:
+ desktop = v.String()
+ case *ProcessAttributeGUIBindInfo:
+ desktop = v.String()
+ case ProcessAttributeWorkingDirectory:
+ wd = v.String()
+ case *ProcessAttributeWorkingDirectory:
+ wd = v.String()
+ case ProcessAttributeSecurity:
+ procSA = v.Process
+ threadSA = v.Thread
+ case *ProcessAttributeSecurity:
+ procSA = v.Process
+ threadSA = v.Thread
+ case ProcessAttributeEnvExtra:
+ envExtra = v
+ case *ProcessAttributeEnvExtra:
+ envExtra = *v
+ case ProcessAttributeArgs:
+ args = []string(v)
+ case *ProcessAttributeArgs:
+ args = []string(*v)
+ case ProcessAttributeFlags:
+ creationFlags |= v.creationFlags()
+ case *ProcessAttributeFlags:
+ creationFlags |= v.creationFlags()
+ default:
+ return os.ErrInvalid
+ }
+ }
+
+ var attrCount uint32
+ if len(inheritableHandleList.Handles) > 0 {
+ attrCount++
+ }
+ if parentPID != 0 {
+ attrCount++
+ }
+ if mitigationBits != 0 {
+ attrCount++
+ }
+
+ var ph windows.Handle
+ var env []string
+ if parentPID == 0 {
+ env = os.Environ()
+ } else {
+ // According to https://docs.microsoft.com/en-us/windows/win32/procthread/process-security-and-access-rights
+ //
+ // ... To open a handle to another process and obtain full access rights,
+ // you must enable the SeDebugPrivilege privilege. ...
+ //
+ // But we only need PROCESS_CREATE_PROCESS. So perhaps SeDebugPrivilege is too much.
+ //
+ // https://devblogs.microsoft.com/oldnewthing/20080314-00/?p=23113
+ //
+ // TODO: try look for something less than SeDebugPrivilege
+
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ err := windows.ImpersonateSelf(windows.SecurityImpersonation)
+ if err != nil {
+ return err
+ }
+ defer windows.RevertToSelf()
+
+ err = EnableCurrentThreadPrivilege("SeDebugPrivilege")
+ if err != nil {
+ return err
+ }
+
+ ph, err = windows.OpenProcess(
+ windows.PROCESS_CREATE_PROCESS|windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_DUP_HANDLE,
+ false, parentPID)
+ if err != nil {
+ return err
+ }
+ defer windows.CloseHandle(ph)
+
+ var pt windows.Token
+ if err := windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &pt); err != nil {
+ return err
+ }
+ defer pt.Close()
+
+ env, err = pt.Environ(false)
+ if err != nil {
+ return err
+ }
+ }
+
+ env16 := envExtra.envBlock(env)
+
+ var inheritHandles bool
+ var attrList *windows.ProcThreadAttributeList
+ if attrCount > 0 {
+ attrListContainer, err := windows.NewProcThreadAttributeList(attrCount)
+ if err != nil {
+ return err
+ }
+ defer attrListContainer.Delete()
+
+ if ph != 0 {
+ attrListContainer.Update(windows.PROC_THREAD_ATTRIBUTE_PARENT_PROCESS, unsafe.Pointer(&ph), unsafe.Sizeof(ph))
+ }
+
+ if hll := uintptr(len(inheritableHandleList.Handles)); hll > 0 {
+ attrListContainer.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&inheritableHandleList.Handles[0]), hll*unsafe.Sizeof(windows.Handle(0)))
+ inheritHandles = true
+ }
+
+ if mitigationBits != 0 {
+ attrListContainer.Update(windows.PROC_THREAD_ATTRIBUTE_MITIGATION_POLICY, unsafe.Pointer(&mitigationBits), unsafe.Sizeof(mitigationBits))
+ }
+
+ attrList = attrListContainer.List()
+ }
+
+ var desktop16 *uint16
+ if desktop != "" {
+ desktop16, err = windows.UTF16PtrFromString(desktop)
+ if err != nil {
+ return err
+ }
+ }
+
+ var startupInfoFlags uint32
+ if useStdHandles {
+ startupInfoFlags |= windows.STARTF_USESTDHANDLES
+ }
+
+ siex := windows.StartupInfoEx{
+ StartupInfo: windows.StartupInfo{
+ Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{})),
+ Desktop: desktop16,
+ Flags: startupInfoFlags,
+ StdInput: inheritableHandleList.Stdin,
+ StdOutput: inheritableHandleList.Stdout,
+ StdErr: inheritableHandleList.Stderr,
+ },
+ ProcThreadAttributeList: attrList,
+ }
+
+ var wd16 *uint16
+ if wd != "" {
+ wd16, err = windows.UTF16PtrFromString(wd)
+ if err != nil {
+ return err
+ }
+ }
+
+ exePath16, err := windows.UTF16PtrFromString(exePath)
+ if err != nil {
+ return err
+ }
+
+ cmdLine, err := makeCmdLine(exePath, args)
+ if err != nil {
+ return err
+ }
+
+ var pi windows.ProcessInformation
+ if useToken == 0 {
+ err = windows.CreateProcess(exePath16, cmdLine, procSA, threadSA, inheritHandles, creationFlags, env16, wd16, &siex.StartupInfo, &pi)
+ } else {
+ err = windows.CreateProcessAsUser(useToken, exePath16, cmdLine, procSA, threadSA, inheritHandles, creationFlags, env16, wd16, &siex.StartupInfo, &pi)
+ }
+
+ runtime.KeepAlive(siex)
+
+ if err != nil {
+ return err
+ }
+ defer windows.CloseHandle(pi.Thread)
+ defer windows.CloseHandle(pi.Process)
+
+ return err
+}
+
+func makeCmdLine(exePath string, args []string) (*uint16, error) {
+ var buf strings.Builder
+
+ buf.WriteString(windows.EscapeArg(exePath))
+
+ for _, arg := range args {
+ if buf.Len() > 0 {
+ buf.WriteByte(' ')
+ }
+ buf.WriteString(windows.EscapeArg(arg))
+ }
+
+ return windows.UTF16PtrFromString(buf.String())
+}
+
+// StartProcessAsCurrentGUIUser is like StartProcessAsChild, but if finds
+// current logged in user desktop process (normally explorer.exe),
+// and passes found PID to StartProcessAsChild.
+func StartProcessAsCurrentGUIUser(exePath string, extraEnv []string) error {
+ // as described in https://devblogs.microsoft.com/oldnewthing/20190425-00/?p=102443
+ desktop, err := GetDesktopPID()
+ if err != nil {
+ return fmt.Errorf("failed to find desktop: %v", err)
+ }
+ err = StartProcessAsChild(desktop, exePath, extraEnv)
+ if err != nil {
+ return fmt.Errorf("failed to start executable: %v", err)
+ }
+ return nil
+}
+
+type ProcessAttributeEnvExtra struct {
+ Map map[string]string
+ Slice []string
+}
+
+func (ee *ProcessAttributeEnvExtra) envBlock(env []string) *uint16 {
+ var buf bytes.Buffer
+
+ for _, s := range [][]string{env, ee.Slice} {
+ for _, v := range s {
+ buf.WriteString(v)
+ buf.WriteByte(0)
+ }
+ }
+
+ for k, v := range ee.Map {
+ buf.WriteString(k)
+ buf.WriteByte('=')
+ buf.WriteString(v)
+ buf.WriteByte(0)
+ }
+
+ if buf.Len() == 0 {
+ // So that we end with a double-null in the empty env case (unlikely)
+ buf.WriteByte(0)
+ }
+
+ buf.WriteByte(0)
+
+ return &utf16.Encode([]rune(string(buf.Bytes())))[0]
+}
+
+type ProcessAttributeFlags struct {
+ BreakawayFromJob bool
+ CreateNewConsole bool
+ CreateNewProcessGroup bool
+ Detached bool
+ InheritParentAffinity bool
+ NoConsoleWindow bool
+}
+
+func (paf *ProcessAttributeFlags) creationFlags() (result uint32) {
+ if paf.BreakawayFromJob {
+ result |= windows.CREATE_BREAKAWAY_FROM_JOB
+ }
+ if paf.CreateNewConsole {
+ result |= windows.CREATE_NEW_CONSOLE
+ }
+ if paf.CreateNewProcessGroup {
+ result |= windows.CREATE_NEW_PROCESS_GROUP
+ }
+ if paf.Detached {
+ result |= windows.DETACHED_PROCESS
+ }
+ if paf.InheritParentAffinity {
+ result |= windows.INHERIT_PARENT_AFFINITY
+ }
+ if paf.NoConsoleWindow {
+ result |= windows.CREATE_NO_WINDOW
+ }
+
+ return result
+}
+
+type ProcessAttributeArgs []string
+
+type ProcessAttributeSecurity struct {
+ Process *windows.SecurityAttributes
+ Thread *windows.SecurityAttributes
+}
+
+type ProcessAttributeWorkingDirectory string
+
+func (wd *ProcessAttributeWorkingDirectory) String() string {
+ return string(*wd)
+}
+
+type ProcessAttributeGUIBindInfo struct {
+ WindowStation string
+ Desktop string
+}
+
+func (gbi *ProcessAttributeGUIBindInfo) String() string {
+ winsta := gbi.WindowStation
+ if winsta == "" {
+ winsta = "Winsta0"
+ }
+
+ desktop := gbi.Desktop
+ if desktop == "" {
+ desktop = "default"
+ }
+
+ var buf strings.Builder
+ buf.WriteString(winsta)
+ buf.WriteByte('\\')
+ buf.WriteString(desktop)
+ return buf.String()
+}
+
+type ProcessAttributeParentProcessID uint32
+
+type ProcessAttributeExplicitInheritableHandleList struct {
+ Stdin windows.Handle
+ Stdout windows.Handle
+ Stderr windows.Handle
+ Handles []windows.Handle
+}
+
+func (eihl *ProcessAttributeExplicitInheritableHandleList) filtered() (result ProcessAttributeExplicitInheritableHandleList, containsStd bool, err error) {
+ result = ProcessAttributeExplicitInheritableHandleList{
+ Stdin: eihl.Stdin,
+ Stdout: eihl.Stdout,
+ Stderr: eihl.Stderr,
+ Handles: make([]windows.Handle, 0, len(eihl.Handles)+3),
+ }
+
+ handles := make([]windows.Handle, 0, len(eihl.Handles)+3)
+
+ if result.Stdin == 0 {
+ result.Stdin = windows.Stdin
+ }
+ handles = append(handles, result.Stdin)
+
+ if result.Stdout == 0 {
+ result.Stdout = windows.Stdout
+ }
+ handles = append(handles, result.Stdout)
+
+ if result.Stderr == 0 {
+ result.Stderr = windows.Stderr
+ }
+ handles = append(handles, result.Stderr)
+
+ handles = append(handles, eihl.Handles...)
+
+ for i, h := range handles {
+ fileType, err := windows.GetFileType(h)
+ if err != nil {
+ return result, false, err
+ }
+ if fileType != windows.FILE_TYPE_DISK && fileType != windows.FILE_TYPE_PIPE {
+ continue
+ }
+
+ if err := windows.SetHandleInformation(h, windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT); err != nil {
+ return result, false, err
+ }
+
+ result.Handles = append(result.Handles, h)
+ if i < 3 {
+ // Standard handle
+ containsStd = true
+ }
+ }
+
+ return result, containsStd, nil
+}
+
+type _PROCESS_MITIGATION_POLICY int32
+
+const (
+ processDEPPolicy _PROCESS_MITIGATION_POLICY = 0
+ processASLRPolicy _PROCESS_MITIGATION_POLICY = 1
+ processDynamicCodePolicy _PROCESS_MITIGATION_POLICY = 2
+ processStrictHandleCheckPolicy _PROCESS_MITIGATION_POLICY = 3
+ processSystemCallDisablePolicy _PROCESS_MITIGATION_POLICY = 4
+ processMitigationOptionsMask _PROCESS_MITIGATION_POLICY = 5
+ processExtensionPointDisablePolicy _PROCESS_MITIGATION_POLICY = 6
+ processControlFlowGuardPolicy _PROCESS_MITIGATION_POLICY = 7
+ processSignaturePolicy _PROCESS_MITIGATION_POLICY = 8
+ processFontDisablePolicy _PROCESS_MITIGATION_POLICY = 9
+ processImageLoadPolicy _PROCESS_MITIGATION_POLICY = 10
+ processSystemCallFilterPolicy _PROCESS_MITIGATION_POLICY = 11
+ processPayloadRestrictionPolicy _PROCESS_MITIGATION_POLICY = 12
+ processChildProcessPolicy _PROCESS_MITIGATION_POLICY = 13
+ processSideChannelIsolationPolicy _PROCESS_MITIGATION_POLICY = 14
+ processUserShadowStackPolicy _PROCESS_MITIGATION_POLICY = 15
+ processRedirectionTrustPolicy _PROCESS_MITIGATION_POLICY = 16
+ processUserPointerAuthPolicy _PROCESS_MITIGATION_POLICY = 17
+ processSEHOPPolicy _PROCESS_MITIGATION_POLICY = 18
+)
+
+type processMitigationPolicyFlags struct {
+ Flags uint32
+}
+
+const (
+ _NoRemoteImages = 1
+ _NoLowMandatoryLabelImages = (1 << 1)
+ _PreferSystem32Images = (1 << 2)
+ _MicrosoftSignedOnly = 1
+ _DisableExtensionPoints = 1
+ _ProhibitDynamicCode = 1
+)
+
+type ProcessMitigationPolicies struct {
+ DisableExtensionPoints bool
+ PreferSystem32Images bool
+ ProhibitDynamicCode bool
+ ProhibitLowMandatoryLabelImages bool
+ ProhibitNonMicrosoftSignedDLLs bool
+ ProhibitRemoteImages bool
+}
+
+func CurrentProcessMitigationPolicies() (result ProcessMitigationPolicies, _ error) {
+ var flags processMitigationPolicyFlags
+ cp := windows.CurrentProcess()
+
+ if err := getProcessMitigationPolicy(cp, processExtensionPointDisablePolicy, unsafe.Pointer(&flags), unsafe.Sizeof(flags)); err != nil {
+ return result, err
+ }
+ result.DisableExtensionPoints = flags.Flags&_DisableExtensionPoints != 0
+
+ if err := getProcessMitigationPolicy(cp, processSystemCallDisablePolicy, unsafe.Pointer(&flags), unsafe.Sizeof(flags)); err != nil {
+ return result, err
+ }
+ result.ProhibitNonMicrosoftSignedDLLs = flags.Flags&_MicrosoftSignedOnly != 0
+
+ if err := getProcessMitigationPolicy(cp, processDynamicCodePolicy, unsafe.Pointer(&flags), unsafe.Sizeof(flags)); err != nil {
+ return result, err
+ }
+ result.ProhibitDynamicCode = flags.Flags&_ProhibitDynamicCode != 0
+
+ if err := getProcessMitigationPolicy(cp, processImageLoadPolicy, unsafe.Pointer(&flags), unsafe.Sizeof(flags)); err != nil {
+ return result, err
+ }
+ result.ProhibitRemoteImages = flags.Flags&_NoRemoteImages != 0
+ result.ProhibitLowMandatoryLabelImages = flags.Flags&_NoLowMandatoryLabelImages != 0
+ result.PreferSystem32Images = flags.Flags&_PreferSystem32Images != 0
+
+ return result, nil
+}
+
+func (pmp *ProcessMitigationPolicies) SetOnCurrentProcess() error {
+ if pmp.DisableExtensionPoints {
+ v := processMitigationPolicyFlags{
+ Flags: _DisableExtensionPoints,
+ }
+ if err := setProcessMitigationPolicy(processExtensionPointDisablePolicy, unsafe.Pointer(&v), unsafe.Sizeof(v)); err != nil {
+ return err
+ }
+ }
+
+ if pmp.ProhibitNonMicrosoftSignedDLLs {
+ v := processMitigationPolicyFlags{
+ Flags: _MicrosoftSignedOnly,
+ }
+ if err := setProcessMitigationPolicy(processSystemCallDisablePolicy, unsafe.Pointer(&v), unsafe.Sizeof(v)); err != nil {
+ return err
+ }
+ }
+
+ if pmp.ProhibitDynamicCode {
+ v := processMitigationPolicyFlags{
+ Flags: _ProhibitDynamicCode,
+ }
+ if err := setProcessMitigationPolicy(processDynamicCodePolicy, unsafe.Pointer(&v), unsafe.Sizeof(v)); err != nil {
+ return err
+ }
+ }
+
+ var imageLoadFlags uint32
+ if pmp.PreferSystem32Images {
+ imageLoadFlags |= _PreferSystem32Images
+ }
+ if pmp.ProhibitLowMandatoryLabelImages {
+ imageLoadFlags |= _NoLowMandatoryLabelImages
+ }
+ if pmp.ProhibitRemoteImages {
+ imageLoadFlags |= _NoRemoteImages
+ }
+
+ if imageLoadFlags != 0 {
+ v := processMitigationPolicyFlags{
+ Flags: imageLoadFlags,
+ }
+ if err := setProcessMitigationPolicy(processImageLoadPolicy, unsafe.Pointer(&v), unsafe.Sizeof(v)); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (pmp *ProcessMitigationPolicies) asMitigationBits() (result uint64) {
+ if pmp.DisableExtensionPoints {
+ result |= (1 << 32)
+ }
+ if pmp.PreferSystem32Images {
+ result |= (1 << 60)
+ }
+ if pmp.ProhibitDynamicCode {
+ result |= (1 << 36)
+ }
+ if pmp.ProhibitLowMandatoryLabelImages {
+ result |= (1 << 56)
+ }
+ if pmp.ProhibitNonMicrosoftSignedDLLs {
+ result |= (1 << 44)
+ }
+ if pmp.ProhibitRemoteImages {
+ result |= (1 << 52)
+ }
+ return result
+}
diff --git a/util/winutil/process_windows_test.go b/util/winutil/process_windows_test.go
new file mode 100644
index 000000000..dc314db2c
--- /dev/null
+++ b/util/winutil/process_windows_test.go
@@ -0,0 +1,17 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winutil
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestMitigateSelf(t *testing.T) {
+ output := strings.TrimSpace(runTestProg(t, "testprocessattributes", "MitigateSelf"))
+ want := "OK"
+ if output != want {
+ t.Errorf("%s\n", strings.TrimPrefix(output, "error: "))
+ }
+}
diff --git a/util/winutil/subprocess_windows_test.go b/util/winutil/subprocess_windows_test.go
new file mode 100644
index 000000000..886ca644d
--- /dev/null
+++ b/util/winutil/subprocess_windows_test.go
@@ -0,0 +1,388 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package winutil
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// The code in this file is adapted from internal/testenv in the Go source tree
+// and is used for writing tests that require spawning subprocesses.
+
+var toRemove []string
+
+func TestMain(m *testing.M) {
+ status := m.Run()
+ for _, file := range toRemove {
+ os.RemoveAll(file)
+ }
+ os.Exit(status)
+}
+
+var testprog struct {
+ sync.Mutex
+ dir string
+ target map[string]*buildexe
+}
+
+type buildexe struct {
+ once sync.Once
+ exe string
+ err error
+}
+
+func runTestProg(t *testing.T, binary, name string, env ...string) string {
+ exe, err := buildTestProg(t, binary)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return runBuiltTestProg(t, exe, name, env...)
+}
+
+func runBuiltTestProg(t *testing.T, exe, name string, env ...string) string {
+ cmd := exec.Command(exe, name)
+ cmd.Env = append(cmd.Env, env...)
+ if testing.Short() {
+ cmd.Env = append(cmd.Env, "RUNTIME_TEST_SHORT=1")
+ }
+ out, _ := runWithTimeout(t, cmd)
+ return string(out)
+}
+
+var serializeBuild = make(chan bool, 2)
+
+func buildTestProg(t *testing.T, binary string, flags ...string) (string, error) {
+ testprog.Lock()
+ if testprog.dir == "" {
+ dir, err := os.MkdirTemp("", "go-build")
+ if err != nil {
+ t.Fatalf("failed to create temp directory: %v", err)
+ }
+ testprog.dir = dir
+ toRemove = append(toRemove, dir)
+ }
+
+ if testprog.target == nil {
+ testprog.target = make(map[string]*buildexe)
+ }
+ name := binary
+ if len(flags) > 0 {
+ name += "_" + strings.Join(flags, "_")
+ }
+ target, ok := testprog.target[name]
+ if !ok {
+ target = &buildexe{}
+ testprog.target[name] = target
+ }
+
+ dir := testprog.dir
+
+ // Unlock testprog while actually building, so that other
+ // tests can look up executables that were already built.
+ testprog.Unlock()
+
+ target.once.Do(func() {
+ // Only do two "go build"'s at a time,
+ // to keep load from getting too high.
+ serializeBuild <- true
+ defer func() { <-serializeBuild }()
+
+ // Don't get confused if goToolPath calls t.Skip.
+ target.err = errors.New("building test called t.Skip")
+
+ exe := filepath.Join(dir, name+".exe")
+
+ t.Logf("running go build -o %s %s", exe, strings.Join(flags, " "))
+ cmd := exec.Command(goToolPath(t), append([]string{"build", "-o", exe}, flags...)...)
+ cmd.Dir = "testdata/" + binary
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ target.err = fmt.Errorf("building %s %v: %v\n%s", binary, flags, err, out)
+ } else {
+ target.exe = exe
+ target.err = nil
+ }
+ })
+
+ return target.exe, target.err
+}
+
+// goTool reports the path to the Go tool.
+func goTool() (string, error) {
+ if !hasGoBuild() {
+ return "", errors.New("platform cannot run go tool")
+ }
+ exeSuffix := ".exe"
+ goroot, err := findGOROOT()
+ if err != nil {
+ return "", fmt.Errorf("cannot find go tool: %w", err)
+ }
+ path := filepath.Join(goroot, "bin", "go"+exeSuffix)
+ if _, err := os.Stat(path); err == nil {
+ return path, nil
+ }
+ goBin, err := exec.LookPath("go" + exeSuffix)
+ if err != nil {
+ return "", errors.New("cannot find go tool: " + err.Error())
+ }
+ return goBin, nil
+}
+
+// knownEnv is a list of environment variables that affect the operation
+// of the Go command.
+const knownEnv = `
+ AR
+ CC
+ CGO_CFLAGS
+ CGO_CFLAGS_ALLOW
+ CGO_CFLAGS_DISALLOW
+ CGO_CPPFLAGS
+ CGO_CPPFLAGS_ALLOW
+ CGO_CPPFLAGS_DISALLOW
+ CGO_CXXFLAGS
+ CGO_CXXFLAGS_ALLOW
+ CGO_CXXFLAGS_DISALLOW
+ CGO_ENABLED
+ CGO_FFLAGS
+ CGO_FFLAGS_ALLOW
+ CGO_FFLAGS_DISALLOW
+ CGO_LDFLAGS
+ CGO_LDFLAGS_ALLOW
+ CGO_LDFLAGS_DISALLOW
+ CXX
+ FC
+ GCCGO
+ GO111MODULE
+ GO386
+ GOAMD64
+ GOARCH
+ GOARM
+ GOBIN
+ GOCACHE
+ GOENV
+ GOEXE
+ GOEXPERIMENT
+ GOFLAGS
+ GOGCCFLAGS
+ GOHOSTARCH
+ GOHOSTOS
+ GOINSECURE
+ GOMIPS
+ GOMIPS64
+ GOMODCACHE
+ GONOPROXY
+ GONOSUMDB
+ GOOS
+ GOPATH
+ GOPPC64
+ GOPRIVATE
+ GOPROXY
+ GOROOT
+ GOSUMDB
+ GOTMPDIR
+ GOTOOLDIR
+ GOVCS
+ GOWASM
+ GOWORK
+ GO_EXTLINK_ENABLED
+ PKG_CONFIG
+`
+
+// goToolPath reports the path to the Go tool.
+// It is a convenience wrapper around goTool.
+// If the tool is unavailable goToolPath calls t.Skip.
+// If the tool should be available and isn't, goToolPath calls t.Fatal.
+func goToolPath(t testing.TB) string {
+ mustHaveGoBuild(t)
+ path, err := goTool()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Add all environment variables that affect the Go command to test metadata.
+ // Cached test results will be invalidate when these variables change.
+ // See golang.org/issue/32285.
+ for _, envVar := range strings.Fields(knownEnv) {
+ os.Getenv(envVar)
+ }
+ return path
+}
+
+// hasGoBuild reports whether the current system can build programs with “go build”
+// and then run them with os.StartProcess or exec.Command.
+func hasGoBuild() bool {
+ if os.Getenv("GO_GCFLAGS") != "" {
+ // It's too much work to require every caller of the go command
+ // to pass along "-gcflags="+os.Getenv("GO_GCFLAGS").
+ // For now, if $GO_GCFLAGS is set, report that we simply can't
+ // run go build.
+ return false
+ }
+ return true
+}
+
+// mustHaveGoBuild checks that the current system can build programs with “go build”
+// and then run them with os.StartProcess or exec.Command.
+// If not, mustHaveGoBuild calls t.Skip with an explanation.
+func mustHaveGoBuild(t testing.TB) {
+ if os.Getenv("GO_GCFLAGS") != "" {
+ t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS")
+ }
+ if !hasGoBuild() {
+ t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
+ }
+}
+
+// hasGoRun reports whether the current system can run programs with “go run.”
+func hasGoRun() bool {
+ // For now, having go run and having go build are the same.
+ return hasGoBuild()
+}
+
+// mustHaveGoRun checks that the current system can run programs with “go run.”
+// If not, mustHaveGoRun calls t.Skip with an explanation.
+func mustHaveGoRun(t testing.TB) {
+ if !hasGoRun() {
+ t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
+ }
+}
+
+var (
+ gorootOnce sync.Once
+ gorootPath string
+ gorootErr error
+)
+
+func findGOROOT() (string, error) {
+ gorootOnce.Do(func() {
+ gorootPath = runtime.GOROOT()
+ if gorootPath != "" {
+ // If runtime.GOROOT() is non-empty, assume that it is valid.
+ //
+ // (It might not be: for example, the user may have explicitly set GOROOT
+ // to the wrong directory, or explicitly set GOROOT_FINAL but not GOROOT
+ // and hasn't moved the tree to GOROOT_FINAL yet. But those cases are
+ // rare, and if that happens the user can fix what they broke.)
+ return
+ }
+
+ // runtime.GOROOT doesn't know where GOROOT is (perhaps because the test
+ // binary was built with -trimpath, or perhaps because GOROOT_FINAL was set
+ // without GOROOT and the tree hasn't been moved there yet).
+ //
+ // Since this is internal/testenv, we can cheat and assume that the caller
+ // is a test of some package in a subdirectory of GOROOT/src. ('go test'
+ // runs the test in the directory containing the packaged under test.) That
+ // means that if we start walking up the tree, we should eventually find
+ // GOROOT/src/go.mod, and we can report the parent directory of that.
+
+ cwd, err := os.Getwd()
+ if err != nil {
+ gorootErr = fmt.Errorf("finding GOROOT: %w", err)
+ return
+ }
+
+ dir := cwd
+ for {
+ parent := filepath.Dir(dir)
+ if parent == dir {
+ // dir is either "." or only a volume name.
+ gorootErr = fmt.Errorf("failed to locate GOROOT/src in any parent directory")
+ return
+ }
+
+ if base := filepath.Base(dir); base != "src" {
+ dir = parent
+ continue // dir cannot be GOROOT/src if it doesn't end in "src".
+ }
+
+ b, err := os.ReadFile(filepath.Join(dir, "go.mod"))
+ if err != nil {
+ if os.IsNotExist(err) {
+ dir = parent
+ continue
+ }
+ gorootErr = fmt.Errorf("finding GOROOT: %w", err)
+ return
+ }
+ goMod := string(b)
+
+ for goMod != "" {
+ var line string
+ line, goMod, _ = strings.Cut(goMod, "\n")
+ fields := strings.Fields(line)
+ if len(fields) >= 2 && fields[0] == "module" && fields[1] == "std" {
+ // Found "module std", which is the module declaration in GOROOT/src!
+ gorootPath = parent
+ return
+ }
+ }
+ }
+ })
+
+ return gorootPath, gorootErr
+}
+
+// runWithTimeout runs cmd and returns its combined output. If the
+// subprocess exits with a non-zero status, it will log that status
+// and return a non-nil error, but this is not considered fatal.
+func runWithTimeout(t testing.TB, cmd *exec.Cmd) ([]byte, error) {
+ args := cmd.Args
+ if args == nil {
+ args = []string{cmd.Path}
+ }
+
+ var b bytes.Buffer
+ cmd.Stdout = &b
+ cmd.Stderr = &b
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("starting %s: %v", args, err)
+ }
+
+ // If the process doesn't complete within 1 minute,
+ // assume it is hanging and kill it to get a stack trace.
+ p := cmd.Process
+ done := make(chan bool)
+ go func() {
+ scale := 2
+ if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" {
+ if sc, err := strconv.Atoi(s); err == nil {
+ scale = sc
+ }
+ }
+
+ select {
+ case <-done:
+ case <-time.After(time.Duration(scale) * time.Minute):
+ p.Signal(os.Kill)
+ // If SIGQUIT doesn't do it after a little
+ // while, kill the process.
+ select {
+ case <-done:
+ case <-time.After(time.Duration(scale) * 30 * time.Second):
+ p.Signal(os.Kill)
+ }
+ }
+ }()
+
+ err := cmd.Wait()
+ if err != nil {
+ t.Logf("%s exit status: %v", args, err)
+ }
+ close(done)
+
+ return b.Bytes(), err
+}
diff --git a/util/winutil/testdata/testprocessattributes/main_windows.go b/util/winutil/testdata/testprocessattributes/main_windows.go
new file mode 100644
index 000000000..0a83af630
--- /dev/null
+++ b/util/winutil/testdata/testprocessattributes/main_windows.go
@@ -0,0 +1,40 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package main
+
+import "os"
+
+var (
+ cmds = map[string]func(){}
+ err error
+)
+
+func register(name string, f func()) {
+ if cmds[name] != nil {
+ panic("duplicate registration: " + name)
+ }
+ cmds[name] = f
+}
+
+func registerInit(name string, f func()) {
+ if len(os.Args) >= 2 && os.Args[1] == name {
+ f()
+ }
+}
+
+func main() {
+ if len(os.Args) < 2 {
+ println("usage: " + os.Args[0] + " name-of-test")
+ return
+ }
+ f := cmds[os.Args[1]]
+ if f == nil {
+ println("unknown function: " + os.Args[1])
+ return
+ }
+ f()
+}
diff --git a/util/winutil/testdata/testprocessattributes/tests_windows.go b/util/winutil/testdata/testprocessattributes/tests_windows.go
new file mode 100644
index 000000000..93f543988
--- /dev/null
+++ b/util/winutil/testdata/testprocessattributes/tests_windows.go
@@ -0,0 +1,57 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "fmt"
+
+ "tailscale.com/util/winutil"
+)
+
+func init() {
+ // registerInit("Foo", FooInit)
+ // register("Foo", Foo)
+ register("MitigateSelf", MitigateSelf)
+}
+
+func MitigateSelf() {
+ var zero winutil.ProcessMitigationPolicies
+ initialPolicies, err := winutil.CurrentProcessMitigationPolicies()
+ if err != nil {
+ fmt.Printf("error: CurrentProcessMitigationPolicies: %v\n", err)
+ return
+ }
+
+ if initialPolicies != zero {
+ fmt.Println("error: initialPolicies not zero value")
+ return
+ }
+
+ setTo := winutil.ProcessMitigationPolicies{
+ DisableExtensionPoints: true,
+ PreferSystem32Images: true,
+ ProhibitDynamicCode: true,
+ ProhibitLowMandatoryLabelImages: true,
+ ProhibitNonMicrosoftSignedDLLs: true,
+ ProhibitRemoteImages: true,
+ }
+
+ if err := setTo.SetOnCurrentProcess(); err != nil {
+ fmt.Printf("error: SetOnCurrentProcess: %v\n", err)
+ return
+ }
+
+ checkPolicies, err := winutil.CurrentProcessMitigationPolicies()
+ if err != nil {
+ fmt.Printf("error: CurrentProcessMitigationPolicies: %v\n", err)
+ return
+ }
+
+ if checkPolicies != setTo {
+ fmt.Printf("error: checkPolicies got %#v, want %#v\n", checkPolicies, setTo)
+ return
+ }
+
+ fmt.Println("OK")
+}
diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go
index ed516ce6b..0e75d4174 100644
--- a/util/winutil/winutil_windows.go
+++ b/util/winutil/winutil_windows.go
@@ -8,7 +8,6 @@ import (
"fmt"
"log"
"os"
- "os/exec"
"os/user"
"runtime"
"strings"
@@ -248,84 +247,6 @@ func EnableCurrentThreadPrivilege(name string) error {
return windows.AdjustTokenPrivileges(t, false, &tp, 0, nil, nil)
}
-// StartProcessAsChild starts exePath process as a child of parentPID.
-// StartProcessAsChild copies parentPID's environment variables into
-// the new process, along with any optional environment variables in extraEnv.
-func StartProcessAsChild(parentPID uint32, exePath string, extraEnv []string) error {
- // The rest of this function requires SeDebugPrivilege to be held.
-
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- err := windows.ImpersonateSelf(windows.SecurityImpersonation)
- if err != nil {
- return err
- }
- defer windows.RevertToSelf()
-
- // According to https://docs.microsoft.com/en-us/windows/win32/procthread/process-security-and-access-rights
- //
- // ... To open a handle to another process and obtain full access rights,
- // you must enable the SeDebugPrivilege privilege. ...
- //
- // But we only need PROCESS_CREATE_PROCESS. So perhaps SeDebugPrivilege is too much.
- //
- // https://devblogs.microsoft.com/oldnewthing/20080314-00/?p=23113
- //
- // TODO: try look for something less than SeDebugPrivilege
-
- err = EnableCurrentThreadPrivilege("SeDebugPrivilege")
- if err != nil {
- return err
- }
-
- ph, err := windows.OpenProcess(
- windows.PROCESS_CREATE_PROCESS|windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_DUP_HANDLE,
- false, parentPID)
- if err != nil {
- return err
- }
- defer windows.CloseHandle(ph)
-
- var pt windows.Token
- err = windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &pt)
- if err != nil {
- return err
- }
- defer pt.Close()
-
- env, err := pt.Environ(false)
- if err != nil {
- return err
-
- }
- env = append(env, extraEnv...)
-
- sys := &syscall.SysProcAttr{ParentProcess: syscall.Handle(ph)}
-
- cmd := exec.Command(exePath)
- cmd.Env = env
- cmd.SysProcAttr = sys
-
- return cmd.Start()
-}
-
-// StartProcessAsCurrentGUIUser is like StartProcessAsChild, but if finds
-// current logged in user desktop process (normally explorer.exe),
-// and passes found PID to StartProcessAsChild.
-func StartProcessAsCurrentGUIUser(exePath string, extraEnv []string) error {
- // as described in https://devblogs.microsoft.com/oldnewthing/20190425-00/?p=102443
- desktop, err := GetDesktopPID()
- if err != nil {
- return fmt.Errorf("failed to find desktop: %v", err)
- }
- err = StartProcessAsChild(desktop, exePath, extraEnv)
- if err != nil {
- return fmt.Errorf("failed to start executable: %v", err)
- }
- return nil
-}
-
// CreateAppMutex creates a named Windows mutex, returning nil if the mutex
// is created successfully or an error if the mutex already exists or could not
// be created for some other reason.
diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go
index 77e9f36c8..ad3912092 100644
--- a/util/winutil/zsyscall_windows.go
+++ b/util/winutil/zsyscall_windows.go
@@ -43,12 +43,22 @@ var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
+ procGetProcessMitigationPolicy = modkernel32.NewProc("GetProcessMitigationPolicy")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
+ procSetProcessMitigationPolicy = modkernel32.NewProc("SetProcessMitigationPolicy")
)
func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procQueryServiceConfig2W.Addr(), 5, uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded)), 0)
- if r1 == 0 {
+ if int32(r1) == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
+func getProcessMitigationPolicy(hProcess windows.Handle, mitigationPolicy _PROCESS_MITIGATION_POLICY, buf unsafe.Pointer, bufLen uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall6(procGetProcessMitigationPolicy.Addr(), 4, uintptr(hProcess), uintptr(mitigationPolicy), uintptr(buf), uintptr(bufLen), 0, 0)
+ if int32(r1) == 0 {
err = errnoErr(e1)
}
return
@@ -59,3 +69,11 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w
ret = wingoes.HRESULT(r0)
return
}
+
+func setProcessMitigationPolicy(mitigationPolicy _PROCESS_MITIGATION_POLICY, buf unsafe.Pointer, bufLen uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procSetProcessMitigationPolicy.Addr(), 3, uintptr(mitigationPolicy), uintptr(buf), uintptr(bufLen))
+ if int32(r1) == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}